68 lines
1.9 KiB
Kotlin

package roomescape.util
import io.kotest.core.listeners.AfterSpecListener
import io.kotest.core.listeners.AfterTestListener
import io.kotest.core.spec.Spec
import io.kotest.core.test.TestCase
import io.kotest.core.test.TestResult
import io.kotest.extensions.spring.testContextManager
import jakarta.persistence.EntityManager
import org.springframework.jdbc.core.JdbcTemplate
import org.springframework.stereotype.Component
/**
 * Spring bean that resets the test database so tests do not see each other's rows.
 *
 * Table names are discovered lazily, once per bean instance, through `SHOW TABLES`. They are
 * lower-cased and cached in [tables]. The SQL issued here (`SET REFERENTIAL_INTEGRITY`,
 * `TRUNCATE TABLE ... RESTART IDENTITY`) is H2 syntax.
 */
@Component
class DatabaseCleaner(
    val entityManager: EntityManager,
    val jdbcTemplate: JdbcTemplate,
) {
    /** Lower-cased names of every table reported by `SHOW TABLES`, resolved on first access. */
    val tables: List<String> by lazy {
        jdbcTemplate.query("SHOW TABLES") { rs, _ ->
            rs.getString(1).lowercase()
        }
    }

    /**
     * Clears the JPA persistence context, then truncates every table in [tables] except
     * [preservedTables], restarting identity columns.
     *
     * Referential integrity is switched off while truncating so tables can be emptied in any
     * order. It is switched back on in a `finally` block, so a failing `TRUNCATE` cannot leave
     * foreign-key checks disabled for the tests that run afterwards.
     *
     * @param preservedTables table names to leave untouched (matched case-insensitively).
     *   Defaults to [DEFAULT_PRESERVED_TABLES], which keeps the previous behavior of only
     *   skipping `region`.
     */
    @JvmOverloads
    fun clear(preservedTables: Set<String> = DEFAULT_PRESERVED_TABLES) {
        // `tables` is lower-cased, so normalize the caller's names the same way before matching.
        val skipped = preservedTables.map { it.lowercase() }.toSet()

        entityManager.clear()
        jdbcTemplate.execute("SET REFERENTIAL_INTEGRITY FALSE")
        try {
            tables
                .filterNot { table -> table in skipped }
                .forEach { table -> jdbcTemplate.execute("TRUNCATE TABLE $table RESTART IDENTITY") }
        } finally {
            jdbcTemplate.execute("SET REFERENTIAL_INTEGRITY TRUE")
        }
    }

    companion object {
        /**
         * Tables that [clear] keeps intact by default.
         * NOTE(review): presumably seed/reference data loaded once per context; confirm.
         */
        val DEFAULT_PRESERVED_TABLES: Set<String> = setOf("region")
    }
}
/** Selects the listener callback in which [DatabaseCleanerExtension] clears the database. */
enum class CleanerMode {
    /** Clear from the `afterTest` callback, i.e. after every test case. */
    AFTER_EACH_TEST,

    /** Clear from the `afterSpec` callback, i.e. once after the whole spec has run. */
    AFTER_SPEC,
}
/**
 * Kotest listener that runs [DatabaseCleaner.clear] either after every test or once after the
 * spec, depending on [mode].
 *
 * The cleaner is looked up from the Spring application context returned by
 * `testContextManager()`.
 * NOTE(review): this assumes the Kotest Spring extension is active for the spec, so that the
 * test context manager can be resolved; confirm against the specs that register this listener.
 */
class DatabaseCleanerExtension(
    private val mode: CleanerMode,
) : AfterTestListener, AfterSpecListener {

    override suspend fun afterTest(testCase: TestCase, result: TestResult) {
        super.afterTest(testCase, result)
        cleanWhen(CleanerMode.AFTER_EACH_TEST)
    }

    override suspend fun afterSpec(spec: Spec) {
        super.afterSpec(spec)
        cleanWhen(CleanerMode.AFTER_SPEC)
    }

    /** Clears the database only if this listener was configured for [trigger]. */
    private suspend fun cleanWhen(trigger: CleanerMode) {
        if (mode != trigger) return

        val springContext = testContextManager().testContext.applicationContext
        springContext.getBean(DatabaseCleaner::class.java).clear()
    }
}