Skip to content

Commit

Permalink
disable name hack by default again, added JCP case for auto-applying …
Browse files Browse the repository at this point in the history
…the expression encoder without spark-connect
  • Loading branch information
Jolanrensen committed Mar 18, 2024
1 parent 0c8f4b1 commit 9149607
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 24 deletions.
7 changes: 6 additions & 1 deletion buildSrc/src/main/kotlin/Versions.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ object Versions {
inline val scala get() = System.getProperty("scala") as String
inline val sparkMinor get() = spark.substringBeforeLast('.')
inline val scalaCompat get() = scala.substringBeforeLast('.')

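// TODO: hard-coded for now; exported through versionMap under the key "sparkConnect".
// NOTE(review): presumably this drives the `//#if sparkConnect` preprocessor switch — confirm.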
const val sparkConnect = false

const val jupyter = "0.12.0-32-1"

const val kotest = "5.5.4"
Expand All @@ -25,14 +29,15 @@ object Versions {
const val jacksonDatabind = "2.13.4.2"
const val kotlinxDateTime = "0.6.0-RC.2"

inline val versionMap
inline val versionMap: Map<String, String>
get() = mapOf(
"kotlin" to kotlin,
"scala" to scala,
"scalaCompat" to scalaCompat,
"spark" to spark,
"sparkMinor" to sparkMinor,
"version" to project,
"sparkConnect" to sparkConnect.toString(),
)

}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ import org.apache.spark.sql.catalyst.SerializerBuildHelper
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.ProductEncoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.encoders.OuterScopes
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.types.DataType
Expand Down Expand Up @@ -69,14 +68,15 @@ fun <T : Any> kotlinEncoderFor(
arguments: List<KTypeProjection> = emptyList(),
nullable: Boolean = false,
annotations: List<Annotation> = emptyList()
): Encoder<T> = ExpressionEncoder.apply(
KotlinTypeInference.encoderFor(
kClass = kClass,
arguments = arguments,
nullable = nullable,
annotations = annotations,
): Encoder<T> =
applyEncoder(
KotlinTypeInference.encoderFor(
kClass = kClass,
arguments = arguments,
nullable = nullable,
annotations = annotations,
)
)
)

/**
* Main method of API, which gives you seamless integration with Spark:
Expand All @@ -88,15 +88,26 @@ fun <T : Any> kotlinEncoderFor(
* @return generated encoder
*/
inline fun <reified T> kotlinEncoderFor(): Encoder<T> =
ExpressionEncoder.apply(
KotlinTypeInference.encoderFor<T>()
kotlinEncoderFor(
typeOf<T>()
)

fun <T> kotlinEncoderFor(kType: KType): Encoder<T> =
ExpressionEncoder.apply(
applyEncoder(
KotlinTypeInference.encoderFor(kType)
)

/**
* Turns the given [agnosticEncoder] into the [Encoder] handed out by the `kotlinEncoderFor` functions.
*
* Which `return` is compiled is decided by the `//#if` / `//#else` / `//#endif` comment lines
* below. NOTE(review): these look like Java Comment Preprocessor (JCP) directives keyed on the
* `sparkConnect` build variable — confirm; either way they must stay exactly as written.
*
* - `sparkConnect == false`: the encoder is wrapped using `ExpressionEncoder.apply`.
* - otherwise: for spark-connect, no ExpressionEncoder is needed, so the [AgnosticEncoder]
*   itself is returned.
*
* @param agnosticEncoder the encoder to wrap (or to return as-is for spark-connect)
* @return an [Encoder] for [T]
*/
private fun <T> applyEncoder(agnosticEncoder: AgnosticEncoder<T>): Encoder<T> {
//#if sparkConnect == false
return org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.apply(agnosticEncoder)
//#else
//$return agnosticEncoder
//#endif
}


/**
 * Legacy entry point kept for source compatibility.
 *
 * Resolves the reified type [T] to a `KType` and delegates to `kotlinEncoderFor`.
 *
 * @return the encoder produced by `kotlinEncoderFor` for [T]
 */
@Deprecated("Use kotlinEncoderFor instead", ReplaceWith("kotlinEncoderFor<T>()"))
inline fun <reified T> encoder(): Encoder<T> {
    val requestedType = typeOf<T>()
    return kotlinEncoderFor(requestedType)
}
Expand All @@ -112,7 +123,7 @@ object KotlinTypeInference {
// TODO this hack is a WIP and can give errors.
// It makes data classes get column names like "age" when the value is read through
// functions like "getAge", instead of getting column names like "getAge".
var DO_NAME_HACK = true
var DO_NAME_HACK = false

/**
* @param kClass the class for which to infer the encoder.
Expand Down Expand Up @@ -151,7 +162,6 @@ object KotlinTypeInference {
currentType = kType,
seenTypeSet = emptySet(),
typeVariables = emptyMap(),
isTopLevel = true,
) as AgnosticEncoder<T>


Expand Down Expand Up @@ -218,7 +228,6 @@ object KotlinTypeInference {

// how the generic types of the data class (like T, S) are filled in for this instance of the class
typeVariables: Map<String, KType>,
isTopLevel: Boolean = false,
): AgnosticEncoder<*> {
val kClass =
currentType.classifier as? KClass<*> ?: throw IllegalArgumentException("Unsupported type $currentType")
Expand Down Expand Up @@ -328,7 +337,7 @@ object KotlinTypeInference {
AgnosticEncoders.UDTEncoder(udt, udt.javaClass)
}

currentType.isSubtypeOf<scala.Option<*>>() -> {
currentType.isSubtypeOf<scala.Option<*>?>() -> {
val elementEncoder = encoderFor(
currentType = tArguments.first().type!!,
seenTypeSet = seenTypeSet,
Expand Down Expand Up @@ -506,7 +515,6 @@ object KotlinTypeInference {

DirtyProductEncoderField(
doNameHack = DO_NAME_HACK,
isTopLevel = isTopLevel,
columnName = paramName,
readMethodName = readMethodName,
writeMethodName = writeMethodName,
Expand All @@ -525,7 +533,7 @@ object KotlinTypeInference {
if (currentType in seenTypeSet) throw IllegalStateException("Circular reference detected for type $currentType")
val constructorParams = currentType.getScalaConstructorParameters(typeVariables, kClass)

val params: List<AgnosticEncoders.EncoderField> = constructorParams.map { (paramName, paramType) ->
val params = constructorParams.map { (paramName, paramType) ->
val encoder = encoderFor(
currentType = paramType,
seenTypeSet = seenTypeSet + currentType,
Expand Down Expand Up @@ -564,7 +572,6 @@ internal open class DirtyProductEncoderField(
private val readMethodName: String, // the name of the method used to read the value
private val writeMethodName: String?,
private val doNameHack: Boolean,
private val isTopLevel: Boolean,
encoder: AgnosticEncoder<*>,
nullable: Boolean,
metadata: Metadata = Metadata.empty(),
Expand All @@ -577,18 +584,18 @@ internal open class DirtyProductEncoderField(
/* writeMethod = */ writeMethodName.toOption(),
), Serializable {

private var isFirstNameCall = true
private var noNameCalls = 0

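/**
* This dirty trick only works because in [SerializerBuildHelper], [ProductEncoder]
* creates an [Invoke] using [name] first and then calls [name] again to retrieve
* the name of the column. This way, when the name hack is enabled, the first call
* returns the read-method name and every later call returns the column name.
* When the hack is disabled, the read-method name is always returned.
*/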
override fun name(): String =
if (doNameHack && !isFirstNameCall) {
if (doNameHack && noNameCalls > 0) {
columnName
} else {
isFirstNameCall = false
noNameCalls++
readMethodName
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ package org.jetbrains.kotlinx.spark.api
import ch.tutteli.atrium.api.fluent.en_GB.*
import ch.tutteli.atrium.api.verbs.expect
import io.kotest.core.spec.style.ShouldSpec
import io.kotest.matchers.collections.shouldContain
import io.kotest.matchers.collections.shouldContainExactly
import io.kotest.matchers.shouldBe
import io.kotest.matchers.string.shouldContain
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types.Decimal
import org.apache.spark.unsafe.types.CalendarInterval
Expand Down Expand Up @@ -210,7 +210,7 @@ class EncodingTest : ShouldSpec({
context("schema") {
withSpark(props = mapOf("spark.sql.codegen.comments" to true)) {

context("Give proper names to columns of data classe") {
context("Give proper names to columns of data classes") {
val old = KotlinTypeInference.DO_NAME_HACK
KotlinTypeInference.DO_NAME_HACK = true

Expand Down Expand Up @@ -240,6 +240,142 @@ class EncodingTest : ShouldSpec({
dataset.collectAsList() shouldBe pairs
}

should("Be able to serialize pairs of pairs of pairs") {
val pairs = listOf(
1 to (1 to (1 to "1")),
2 to (2 to (2 to "2")),
3 to (3 to (3 to "3")),
)
val dataset = pairs.toDS()
dataset.show()
dataset.printSchema()
dataset.columns().shouldContainExactly("first", "second")
dataset.select("second.*").columns().shouldContainExactly("first", "second")
dataset.select("second.second.*").columns().shouldContainExactly("first", "second")
dataset.collectAsList() shouldBe pairs
}

should("Be able to serialize lists of pairs") {
val pairs = listOf(
listOf(1 to "1", 2 to "2"),
listOf(3 to "3", 4 to "4"),
)
val dataset = pairs.toDS()
dataset.show()
dataset.printSchema()
dataset.schema().toString().let {
it shouldContain "first"
it shouldContain "second"
}
dataset.collectAsList() shouldBe pairs
}

should("Be able to serialize lists of lists of pairs") {
val pairs = listOf(
listOf(
listOf(1 to "1", 2 to "2"),
listOf(3 to "3", 4 to "4")
)
)
val dataset = pairs.toDS()
dataset.show()
dataset.printSchema()
dataset.schema().toString().let {
it shouldContain "first"
it shouldContain "second"
}
dataset.collectAsList() shouldBe pairs
}

should("Be able to serialize lists of lists of lists of pairs") {
val pairs = listOf(
listOf(
listOf(
listOf(1 to "1", 2 to "2"),
listOf(3 to "3", 4 to "4"),
)
)
)
val dataset = pairs.toDS()
dataset.show()
dataset.printSchema()
dataset.schema().toString().let {
it shouldContain "first"
it shouldContain "second"
}
dataset.collectAsList() shouldBe pairs
}

should("Be able to serialize lists of lists of lists of pairs of pairs") {
val pairs = listOf(
listOf(
listOf(
listOf(1 to ("1" to 3.0), 2 to ("2" to 3.0)),
listOf(3 to ("3" to 3.0), 4 to ("4" to 3.0)),
)
)
)
val dataset = pairs.toDS()
dataset.show()
dataset.printSchema()
dataset.schema().toString().let {
it shouldContain "first"
it shouldContain "second"
}
dataset.collectAsList() shouldBe pairs
}

should("Be able to serialize arrays of pairs") {
val pairs = arrayOf(
arrayOf(1 to "1", 2 to "2"),
arrayOf(3 to "3", 4 to "4"),
)
val dataset = pairs.toDS()
dataset.show()
dataset.printSchema()
dataset.schema().toString().let {
it shouldContain "first"
it shouldContain "second"
}
dataset.collectAsList() shouldBe pairs
}

should("Be able to serialize arrays of arrays of pairs") {
val pairs = arrayOf(
arrayOf(
arrayOf(1 to "1", 2 to "2"),
arrayOf(3 to "3", 4 to "4")
)
)
val dataset = pairs.toDS()
dataset.show()
dataset.printSchema()
dataset.schema().toString().let {
it shouldContain "first"
it shouldContain "second"
}
dataset.collectAsList() shouldBe pairs
}

should("Be able to serialize arrays of arrays of arrays of pairs") {
val pairs = arrayOf(
arrayOf(
arrayOf(
arrayOf(1 to "1", 2 to "2"),
arrayOf(3 to "3", 4 to "4"),
)
)
)
val dataset = pairs.toDS()
dataset.show()
dataset.printSchema()
dataset.schema().toString().let {
it shouldContain "first"
it shouldContain "second"
}
dataset.collectAsList() shouldBe pairs
}

KotlinTypeInference.DO_NAME_HACK = old
}

Expand Down Expand Up @@ -351,6 +487,7 @@ class EncodingTest : ShouldSpec({
listOf(SomeClass(intArrayOf(1, 2, 3), 4)),
listOf(SomeClass(intArrayOf(3, 2, 1), 0)),
)
dataset.printSchema()

val (first, second) = dataset.collectAsList()

Expand Down

0 comments on commit 9149607

Please sign in to comment.