diff --git a/build.gradle b/build.gradle index 2bfc5e4..dbdd048 100644 --- a/build.gradle +++ b/build.gradle @@ -1,5 +1,5 @@ plugins { - id 'org.jetbrains.kotlin.jvm' version '1.3.61' apply false + id 'org.jetbrains.kotlin.jvm' version '1.3.72' apply false } subprojects { @@ -16,7 +16,7 @@ subprojects { testImplementation 'org.jetbrains.kotlin:kotlin-test' testImplementation 'org.jetbrains.kotlin:kotlin-test-junit5' testImplementation 'io.kotlintest:kotlintest-runner-junit5:3.4.2' - testCompile 'org.testcontainers:postgresql:1.12.5' + testCompile 'org.testcontainers:postgresql:1.14.3' testCompile 'postgresql:postgresql:9.1-901-1.jdbc4' } @@ -34,10 +34,10 @@ project(":codegen") { dependencies { implementation project(":runtime") - implementation 'com.squareup:kotlinpoet:1.5.0' + implementation 'com.squareup:kotlinpoet:1.6.0' implementation 'org.atteo:evo-inflector:1.2.2' implementation 'org.apache.commons:commons-text:1.8' - implementation 'org.postgresql:postgresql:42.2.9' + implementation 'org.postgresql:postgresql:42.2.14' testImplementation("io.kotlintest:kotlintest-runner-junit5:3.3.0") } diff --git a/codegen/src/main/kotlin/norm/CodeGenerator.kt b/codegen/src/main/kotlin/norm/CodeGenerator.kt index d761c9d..71d89ce 100644 --- a/codegen/src/main/kotlin/norm/CodeGenerator.kt +++ b/codegen/src/main/kotlin/norm/CodeGenerator.kt @@ -15,20 +15,15 @@ class CodeGenerator(private val typeMapper: DbToKtDefaultTypeMapper = DbToKtDefa fileBuilder.addType( TypeSpec.classBuilder(ClassName(packageName, paramsClassName)) - .addModifiers(KModifier.DATA) + .also { if (params.isNotEmpty()) it.addModifiers(KModifier.DATA) } .primaryConstructor( FunSpec.constructorBuilder() .addParameters(params.distinctBy { it.name }.map { - ParameterSpec.builder( - it.name, - getTypeName(it) - ).build() + ParameterSpec.builder(it.name, getTypeName(it)).build() }).build() ) .addProperties(params.distinctBy { it.name }.map { - PropertySpec.builder(it.name, - getTypeName(it) - ) + PropertySpec.builder(it.name, getTypeName(it)) .initializer(it.name) .build() }) @@ -88,28 +83,24 @@ class CodeGenerator(private val typeMapper: DbToKtDefaultTypeMapper = DbToKtDefa FunSpec.constructorBuilder() .addParameters(cols.map { ParameterSpec.builder(it.fieldName, - if(it.colType.startsWith("_")) ARRAY.parameterizedBy(typeMapper.getType(it.colType).asTypeName()) - else typeMapper.getType(it.colType).asTypeName() + getTypeName(it) ).build() }).build() ) .addProperties(cols.map { - PropertySpec.builder(it.fieldName, - if(it.colType.startsWith("_")) ARRAY.parameterizedBy(typeMapper.getType(it.colType).asTypeName()) - else typeMapper.getType(it.colType).asTypeName() - ) + PropertySpec.builder(it.fieldName, getTypeName(it)) .initializer(it.fieldName) .build() }) .build() ) - val constructArgs = "\n" + cols.map { - if(it.colType.startsWith("_")) - "${it.fieldName} = rs.getArray(\"${it.colName}\").array as Array<${typeMapper.getType(it.colType).asClassName().simpleName}>" + val constructArgs = "\n" + cols.joinToString(",\n ") { + if (it.colType.startsWith("_")) + "${it.fieldName} = rs.getArray(\"${it.colName}\").array as ${getTypeName(it)}>" else - "${it.fieldName} = rs.getObject(\"${it.colName}\") as ${typeMapper.getType(it.colType).asClassName().canonicalName}" - }.joinToString(",\n ") + "${it.fieldName} = rs.getObject(\"${it.colName}\") as ${getTypeName(it)}" + } fileBuilder.addType( TypeSpec.classBuilder(ClassName(packageName, rowMapperClassName)) @@ -164,16 +155,18 @@ class CodeGenerator(private val typeMapper: DbToKtDefaultTypeMapper = DbToKtDefa return fileBuilder.build().toString() } - private fun getTypeName(it: ParamModel): TypeName { - return if (it.dbType.startsWith("_")) - ARRAY.parameterizedBy(typeMapper.getType(it.dbType).asTypeName()).copy(nullable = it.isNullable) - else typeMapper.getType(it.dbType).asTypeName().copy(nullable = it.isNullable) - } + private fun getTypeName(it: ColumnModel) = + if (it.colType.startsWith("_")) ARRAY.parameterizedBy(typeMapper.getType(it.colType, false)).copy(nullable = it.isNullable) + else typeMapper.getType(it.colType, it.isNullable) + + private fun getTypeName(it: ParamModel) = + if (it.dbType.startsWith("_")) ARRAY.parameterizedBy(typeMapper.getType(it.dbType, false)).copy(nullable = it.isNullable) + else typeMapper.getType(it.dbType, it.isNullable) private fun addStatementsForParams(fb: FunSpec.Builder, params: List) = params.forEachIndexed { i, pm -> - when (pm.paramClassName) { - "java.sql.Array" -> fb.addStatement("ps.setArray(${i + 1}, ps.connection.createArrayOf(\"${pm.dbType.removePrefix("_")}\", params.${pm.name}))") + when { + pm.dbType.startsWith("_") -> fb.addStatement("ps.setArray(${i + 1}, ps.connection.createArrayOf(\"${pm.dbType.removePrefix("_")}\", params.${pm.name}))") else -> fb.addStatement("ps.setObject(${i + 1}, params.${pm.name})") } diff --git a/codegen/src/main/kotlin/norm/DbToKtDefaultTypeMapper.kt b/codegen/src/main/kotlin/norm/DbToKtDefaultTypeMapper.kt index 1eee314..5989dc7 100644 --- a/codegen/src/main/kotlin/norm/DbToKtDefaultTypeMapper.kt +++ b/codegen/src/main/kotlin/norm/DbToKtDefaultTypeMapper.kt @@ -1,15 +1,13 @@ package norm -import com.squareup.kotlinpoet.ARRAY -import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy +import com.squareup.kotlinpoet.ClassName +import com.squareup.kotlinpoet.TypeName +import com.squareup.kotlinpoet.asTypeName import org.postgresql.util.PGobject import java.math.BigDecimal -import kotlin.reflect.KClass - -typealias typeMapper = (String) -> KClass<*> class DbToKtDefaultTypeMapper { - fun getType(colType: String): KClass<*> { + fun getType(colType: String, nullable: Boolean): TypeName { return when (colType.toLowerCase()) { "int4" -> Int::class "int" -> Int::class @@ -34,6 +32,6 @@ class DbToKtDefaultTypeMapper { "_varchar" -> String::class "_int4" -> Int::class else -> String::class - } + }.asTypeName().copy(nullable = nullable) } } diff --git a/codegen/src/main/kotlin/norm/SqlAnalyzer.kt b/codegen/src/main/kotlin/norm/SqlAnalyzer.kt index 8493b3e..3eb60e8 100644 --- a/codegen/src/main/kotlin/norm/SqlAnalyzer.kt +++ b/codegen/src/main/kotlin/norm/SqlAnalyzer.kt @@ -19,8 +19,7 @@ data class ColumnModel( data class ParamModel( val name: String, val dbType: String, - val isNullable: Boolean, - val paramClassName:String + val isNullable: Boolean ) data class SqlModel( @@ -46,8 +45,7 @@ class SqlAnalyzer(private val connection: Connection) { ParamModel( paramNames[it - 1].substring(1), parameterMetaData.getParameterTypeName(it), // db type - parameterMetaData.isNullable(it) != ParameterMetaData.parameterNoNulls, - parameterMetaData.getParameterClassName(it) + parameterMetaData.isNullable(it) != ParameterMetaData.parameterNoNulls ) } diff --git a/codegen/src/test/kotlin/norm/CodeGeneratorTest.kt b/codegen/src/test/kotlin/norm/CodeGeneratorTest.kt index c251bfa..51fac5a 100644 --- a/codegen/src/test/kotlin/norm/CodeGeneratorTest.kt +++ b/codegen/src/test/kotlin/norm/CodeGeneratorTest.kt @@ -1,7 +1,10 @@ package norm import io.kotlintest.matchers.string.shouldContain +import io.kotlintest.matchers.string.shouldNotContain +import io.kotlintest.shouldBe import io.kotlintest.specs.StringSpec +import org.apache.commons.io.FileUtils import org.junit.ClassRule import org.postgresql.ds.PGSimpleDataSource @@ -21,15 +24,9 @@ class CodeGeneratorTest : StringSpec() { "Query class generator" { dataSource.connection.use { - val generatedFileContent = codegen(it, "select * from employees where first_name = :name order by :field", "com.foo", "Foo") - - generatedFileContent shouldContain "data class FooResult(" - generatedFileContent shouldContain "data class FooParams(" - generatedFileContent shouldContain "class FooParamSetter : ParamSetter {" - generatedFileContent shouldContain "class FooRowMapper : RowMapper {" - generatedFileContent shouldContain "class FooQuery : Query {" - - println(generatedFileContent) + val expectedFileContent = FileUtils.getFile( "src", "test", "resources", "generated/employee-query").readText().trimIndent() + val generatedFileContent = codegen(it, "select * from employees where first_name = :name order by :field", "com.foo", "Foo").trimIndent() + generatedFileContent shouldBe expectedFileContent } } @@ -77,5 +74,32 @@ class CodeGeneratorTest : StringSpec() { } + "should support jsonb type along with array"{ + + dataSource.connection.use { + val generatedFileContent = codegen(it, "insert into owners(colors,details) VALUES(:colors,:details)", "com.foo", "Foo") + generatedFileContent shouldContain "data class FooParams(" + generatedFileContent shouldContain " val colors: Array?" + generatedFileContent shouldContain " val details: PGobject?" + + generatedFileContent shouldContain "class FooParamSetter : ParamSetter {" + generatedFileContent shouldContain " override fun map(ps: PreparedStatement, params: FooParams) {" + generatedFileContent shouldContain " ps.setArray(1, ps.connection.createArrayOf(\"varchar\", params.colors))" + generatedFileContent shouldContain " ps.setObject(2, params.details)" + generatedFileContent shouldContain " }" + + println(generatedFileContent) + } + } + + "should generate empty params class if inputs params are not present" { + dataSource.connection.use { + val generatedFileContent = codegen(it, "select * from employees", "com.foo", "Foo") + generatedFileContent shouldNotContain "data class FooParams" + generatedFileContent shouldContain "class FooParams" + + println(generatedFileContent) + } + } } } diff --git a/codegen/src/test/resources/generated/employee-query b/codegen/src/test/resources/generated/employee-query new file mode 100644 index 0000000..b703c2b --- /dev/null +++ b/codegen/src/test/resources/generated/employee-query @@ -0,0 +1,42 @@ +package com.foo + +import java.sql.PreparedStatement +import java.sql.ResultSet +import kotlin.Int +import kotlin.String +import norm.ParamSetter +import norm.Query +import norm.RowMapper + +data class FooParams( + val name: String?, + val field: String? +) + +class FooParamSetter : ParamSetter { + override fun map(ps: PreparedStatement, params: FooParams) { + ps.setObject(1, params.name) + ps.setObject(2, params.field) + } +} + +data class FooResult( + val id: Int, + val firstName: String?, + val lastName: String? +) + +class FooRowMapper : RowMapper { + override fun map(rs: ResultSet): FooResult = FooResult( + id = rs.getObject("id") as kotlin.Int, + firstName = rs.getObject("first_name") as kotlin.String?, + lastName = rs.getObject("last_name") as kotlin.String?) +} + +class FooQuery : Query { + override val sql: String = "select * from employees where first_name = ? order by ?" + + override val mapper: RowMapper = FooRowMapper() + + override val paramSetter: ParamSetter = FooParamSetter() +} diff --git a/codegen/src/test/resources/init_postgres.sql b/codegen/src/test/resources/init_postgres.sql index 40ce520..f3a186a 100644 --- a/codegen/src/test/resources/init_postgres.sql +++ b/codegen/src/test/resources/init_postgres.sql @@ -12,3 +12,9 @@ CREATE TABLE departments ( CREATE TABLE combinations( id serial PRIMARY KEY, colors varchar[]); + +CREATE TABLE owners( +id serial PRIMARY KEY, +colors varchar[], +details jsonb +) diff --git a/runtime/src/main/kotlin/norm/SqlExtensions.kt b/runtime/src/main/kotlin/norm/SqlExtensions.kt index 03d19a5..0be07b6 100644 --- a/runtime/src/main/kotlin/norm/SqlExtensions.kt +++ b/runtime/src/main/kotlin/norm/SqlExtensions.kt @@ -35,12 +35,12 @@ fun ResultSet.toTable(columnNames: List = getColumnNames()): List = listOf()): PreparedStatement = +fun PreparedStatement.withParams(params: List = listOf()): PreparedStatement = this.also { self -> params.forEachIndexed { index, param -> self.setObject(index + 1, param) } } -fun PreparedStatement.withBatches(batchedParams: List> = listOf()) = +fun PreparedStatement.withBatches(batchedParams: List> = listOf()) = this.also { ps -> batchedParams.forEach { params -> ps.withParams(params).addBatch() @@ -52,7 +52,7 @@ fun Connection.executeCommand(sql: String, params: List = listOf()): Int = .withParams(params) .use { it.executeUpdate() } // auto-close ps -fun Connection.batchExecuteCommand(sql: String, batchedParams: List> = listOf()): List = +fun Connection.batchExecuteCommand(sql: String, batchedParams: List> = listOf()): List = this.prepareStatement(sql) .withBatches(batchedParams) .use { it.executeBatch() }