From 9149607484c91d92484daa9d49a57ca6c4bacf15 Mon Sep 17 00:00:00 2001 From: Jolan Rensen Date: Mon, 18 Mar 2024 13:11:43 +0100 Subject: [PATCH] disable name hack by default again, added JCP case for auto-applying the expression encoder without spark-connect --- buildSrc/src/main/kotlin/Versions.kt | 7 +- .../jetbrains/kotlinx/spark/api/Encoding.kt | 49 +++--- .../kotlinx/spark/api/EncodingTest.kt | 141 +++++++++++++++++- 3 files changed, 173 insertions(+), 24 deletions(-) diff --git a/buildSrc/src/main/kotlin/Versions.kt b/buildSrc/src/main/kotlin/Versions.kt index 02de97c6..dd10547a 100644 --- a/buildSrc/src/main/kotlin/Versions.kt +++ b/buildSrc/src/main/kotlin/Versions.kt @@ -9,6 +9,10 @@ object Versions { inline val scala get() = System.getProperty("scala") as String inline val sparkMinor get() = spark.substringBeforeLast('.') inline val scalaCompat get() = scala.substringBeforeLast('.') + + // TODO + const val sparkConnect = false + const val jupyter = "0.12.0-32-1" const val kotest = "5.5.4" @@ -25,7 +29,7 @@ object Versions { const val jacksonDatabind = "2.13.4.2" const val kotlinxDateTime = "0.6.0-RC.2" - inline val versionMap + inline val versionMap: Map get() = mapOf( "kotlin" to kotlin, "scala" to scala, @@ -33,6 +37,7 @@ object Versions { "spark" to spark, "sparkMinor" to sparkMinor, "version" to project, + "sparkConnect" to sparkConnect.toString(), ) } diff --git a/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt b/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt index a123c0c0..dd28c18f 100644 --- a/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt +++ b/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt @@ -36,7 +36,6 @@ import org.apache.spark.sql.catalyst.SerializerBuildHelper import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.ProductEncoder -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.encoders.OuterScopes import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.types.DataType @@ -69,14 +68,15 @@ fun kotlinEncoderFor( arguments: List = emptyList(), nullable: Boolean = false, annotations: List = emptyList() -): Encoder = ExpressionEncoder.apply( - KotlinTypeInference.encoderFor( - kClass = kClass, - arguments = arguments, - nullable = nullable, - annotations = annotations, +): Encoder = + applyEncoder( + KotlinTypeInference.encoderFor( + kClass = kClass, + arguments = arguments, + nullable = nullable, + annotations = annotations, + ) ) -) /** * Main method of API, which gives you seamless integration with Spark: @@ -88,15 +88,26 @@ fun kotlinEncoderFor( * @return generated encoder */ inline fun kotlinEncoderFor(): Encoder = - ExpressionEncoder.apply( - KotlinTypeInference.encoderFor() + kotlinEncoderFor( + typeOf() ) fun kotlinEncoderFor(kType: KType): Encoder = - ExpressionEncoder.apply( + applyEncoder( KotlinTypeInference.encoderFor(kType) ) +/** + * For spark-connect, no ExpressionEncoder is needed, so we can just return the AgnosticEncoder. + */ +private fun applyEncoder(agnosticEncoder: AgnosticEncoder): Encoder { + //#if sparkConnect == false + return org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.apply(agnosticEncoder) + //#else + //$return agnosticEncoder + //#endif +} + @Deprecated("Use kotlinEncoderFor instead", ReplaceWith("kotlinEncoderFor()")) inline fun encoder(): Encoder = kotlinEncoderFor(typeOf()) @@ -112,7 +123,7 @@ object KotlinTypeInference { // TODO this hack is a WIP and can give errors // TODO it's to make data classes get column names like "age" with functions like "getAge" // TODO instead of column names like "getAge" - var DO_NAME_HACK = true + var DO_NAME_HACK = false /** * @param kClass the class for which to infer the encoder. @@ -151,7 +162,6 @@ object KotlinTypeInference { currentType = kType, seenTypeSet = emptySet(), typeVariables = emptyMap(), - isTopLevel = true, ) as AgnosticEncoder @@ -218,7 +228,6 @@ object KotlinTypeInference { // how the generic types of the data class (like T, S) are filled in for this instance of the class typeVariables: Map, - isTopLevel: Boolean = false, ): AgnosticEncoder<*> { val kClass = currentType.classifier as? KClass<*> ?: throw IllegalArgumentException("Unsupported type $currentType") @@ -328,7 +337,7 @@ object KotlinTypeInference { AgnosticEncoders.UDTEncoder(udt, udt.javaClass) } - currentType.isSubtypeOf>() -> { + currentType.isSubtypeOf?>() -> { val elementEncoder = encoderFor( currentType = tArguments.first().type!!, seenTypeSet = seenTypeSet, @@ -506,7 +515,6 @@ object KotlinTypeInference { DirtyProductEncoderField( doNameHack = DO_NAME_HACK, - isTopLevel = isTopLevel, columnName = paramName, readMethodName = readMethodName, writeMethodName = writeMethodName, @@ -525,7 +533,7 @@ object KotlinTypeInference { if (currentType in seenTypeSet) throw IllegalStateException("Circular reference detected for type $currentType") val constructorParams = currentType.getScalaConstructorParameters(typeVariables, kClass) - val params: List = constructorParams.map { (paramName, paramType) -> + val params = constructorParams.map { (paramName, paramType) -> val encoder = encoderFor( currentType = paramType, seenTypeSet = seenTypeSet + currentType, @@ -564,7 +572,6 @@ internal open class DirtyProductEncoderField( private val readMethodName: String, // the name of the method used to read the value private val writeMethodName: String?, private val doNameHack: Boolean, - private val isTopLevel: Boolean, encoder: AgnosticEncoder<*>, nullable: Boolean, metadata: Metadata = Metadata.empty(), @@ -577,7 +584,7 @@ internal open class DirtyProductEncoderField( /* writeMethod = */ writeMethodName.toOption(), ), Serializable { - private var isFirstNameCall = true + private var noNameCalls = 0 /** * This dirty trick only works because in [SerializerBuildHelper], [ProductEncoder] @@ -585,10 +592,10 @@ internal open class DirtyProductEncoderField( * the name of the column. This way, we can alternate between the two names. */ override fun name(): String = - if (doNameHack && !isFirstNameCall) { + if (doNameHack && noNameCalls > 0) { columnName } else { - isFirstNameCall = false + noNameCalls++ readMethodName } diff --git a/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/EncodingTest.kt b/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/EncodingTest.kt index 151bca14..ad390a1a 100644 --- a/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/EncodingTest.kt +++ b/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/EncodingTest.kt @@ -22,9 +22,9 @@ package org.jetbrains.kotlinx.spark.api import ch.tutteli.atrium.api.fluent.en_GB.* import ch.tutteli.atrium.api.verbs.expect import io.kotest.core.spec.style.ShouldSpec -import io.kotest.matchers.collections.shouldContain import io.kotest.matchers.collections.shouldContainExactly import io.kotest.matchers.shouldBe +import io.kotest.matchers.string.shouldContain import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.Decimal import org.apache.spark.unsafe.types.CalendarInterval @@ -210,7 +210,7 @@ class EncodingTest : ShouldSpec({ context("schema") { withSpark(props = mapOf("spark.sql.codegen.comments" to true)) { - context("Give proper names to columns of data classe") { + context("Give proper names to columns of data classes") { val old = KotlinTypeInference.DO_NAME_HACK KotlinTypeInference.DO_NAME_HACK = true @@ -240,6 +240,142 @@ class EncodingTest : ShouldSpec({ dataset.collectAsList() shouldBe pairs } + should("Be able to serialize pairs of pairs of pairs") { + val pairs = listOf( + 1 to (1 to (1 to "1")), + 2 to (2 to (2 to "2")), + 3 to (3 to (3 to "3")), + ) + val dataset = pairs.toDS() + dataset.show() + dataset.printSchema() + dataset.columns().shouldContainExactly("first", "second") + dataset.select("second.*").columns().shouldContainExactly("first", "second") + dataset.select("second.second.*").columns().shouldContainExactly("first", "second") + dataset.collectAsList() shouldBe pairs + } + + should("Be able to serialize lists of pairs") { + val pairs = listOf( + listOf(1 to "1", 2 to "2"), + listOf(3 to "3", 4 to "4"), + ) + val dataset = pairs.toDS() + dataset.show() + dataset.printSchema() + dataset.schema().toString().let { + it shouldContain "first" + it shouldContain "second" + } + dataset.collectAsList() shouldBe pairs + } + + should("Be able to serialize lists of lists of pairs") { + val pairs = listOf( + listOf( + listOf(1 to "1", 2 to "2"), + listOf(3 to "3", 4 to "4") + ) + ) + val dataset = pairs.toDS() + dataset.show() + dataset.printSchema() + dataset.schema().toString().let { + it shouldContain "first" + it shouldContain "second" + } + dataset.collectAsList() shouldBe pairs + } + + should("Be able to serialize lists of lists of lists of pairs") { + val pairs = listOf( + listOf( + listOf( + listOf(1 to "1", 2 to "2"), + listOf(3 to "3", 4 to "4"), + ) + ) + ) + val dataset = pairs.toDS() + dataset.show() + dataset.printSchema() + dataset.schema().toString().let { + it shouldContain "first" + it shouldContain "second" + } + dataset.collectAsList() shouldBe pairs + } + + should("Be able to serialize lists of lists of lists of pairs of pairs") { + val pairs = listOf( + listOf( + listOf( + listOf(1 to ("1" to 3.0), 2 to ("2" to 3.0)), + listOf(3 to ("3" to 3.0), 4 to ("4" to 3.0)), + ) + ) + ) + val dataset = pairs.toDS() + dataset.show() + dataset.printSchema() + dataset.schema().toString().let { + it shouldContain "first" + it shouldContain "second" + } + dataset.collectAsList() shouldBe pairs + } + + should("Be able to serialize arrays of pairs") { + val pairs = arrayOf( + arrayOf(1 to "1", 2 to "2"), + arrayOf(3 to "3", 4 to "4"), + ) + val dataset = pairs.toDS() + dataset.show() + dataset.printSchema() + dataset.schema().toString().let { + it shouldContain "first" + it shouldContain "second" + } + dataset.collectAsList() shouldBe pairs + } + + should("Be able to serialize arrays of arrays of pairs") { + val pairs = arrayOf( + arrayOf( + arrayOf(1 to "1", 2 to "2"), + arrayOf(3 to "3", 4 to "4") + ) + ) + val dataset = pairs.toDS() + dataset.show() + dataset.printSchema() + dataset.schema().toString().let { + it shouldContain "first" + it shouldContain "second" + } + dataset.collectAsList() shouldBe pairs + } + + should("Be able to serialize arrays of arrays of arrays of pairs") { + val pairs = arrayOf( + arrayOf( + arrayOf( + arrayOf(1 to "1", 2 to "2"), + arrayOf(3 to "3", 4 to "4"), + ) + ) + ) + val dataset = pairs.toDS() + dataset.show() + dataset.printSchema() + dataset.schema().toString().let { + it shouldContain "first" + it shouldContain "second" + } + dataset.collectAsList() shouldBe pairs + } + KotlinTypeInference.DO_NAME_HACK = old } @@ -351,6 +487,7 @@ class EncodingTest : ShouldSpec({ listOf(SomeClass(intArrayOf(1, 2, 3), 4)), listOf(SomeClass(intArrayOf(3, 2, 1), 0)), ) + dataset.printSchema() val (first, second) = dataset.collectAsList()