From 4826d126f90d45fa316e05e0c30d0893aba7815c Mon Sep 17 00:00:00 2001
From: Eric Marnadi
Date: Mon, 30 Dec 2024 16:35:00 -0800
Subject: [PATCH] Fix Avro state encoding: document and privatize schema
 defaults, register the default column family schema, defer schema-id
 lookups, and use the configured store encoding

---
 .../spark/sql/avro/SchemaConverters.scala          | 27 ++++++++++++++-------
 .../streaming/TransformWithStateExec.scala         |  8 +++++-
 .../streaming/state/RocksDBStateEncoder.scala      |  4 +--
 .../state/RocksDBStateStoreProvider.scala          |  2 +-
 4 files changed, 30 insertions(+), 11 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
index 8b0a15403c54c..f7a66967ca186 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
@@ -300,7 +300,11 @@ object SchemaConverters extends Logging {
     }
   }
 
-  def getDefaultValue(dataType: DataType): Any = {
+  /**
+   * Creates default values for Spark SQL data types when converting to Avro.
+   * This ensures fields have appropriate defaults during schema evolution.
+   */
+  private def getDefaultValue(dataType: DataType): Any = {
     def createNestedDefault(st: StructType): java.util.HashMap[String, Any] = {
       val defaultMap = new java.util.HashMap[String, Any]()
       st.fields.foreach { field =>
@@ -310,6 +314,7 @@ object SchemaConverters extends Logging {
     }
 
     dataType match {
+      // Basic types
       case BooleanType => false
       case ByteType | ShortType | IntegerType => 0
       case LongType => 0L
@@ -317,15 +322,19 @@ object SchemaConverters extends Logging {
       case FloatType => 0.0f
       case DoubleType => 0.0
       case StringType => ""
       case BinaryType => java.nio.ByteBuffer.allocate(0)
+
+      // Complex types
       case ArrayType(elementType, _) =>
         val defaultArray = new java.util.ArrayList[Any]()
-        defaultArray.add(getDefaultValue(elementType)) // Add one default element
+        defaultArray.add(getDefaultValue(elementType))
         defaultArray
       case MapType(StringType, valueType, _) =>
         val defaultMap = new java.util.HashMap[String, Any]()
-        defaultMap.put("defaultKey", getDefaultValue(valueType)) // Add one default entry
+        defaultMap.put("defaultKey", getDefaultValue(valueType))
         defaultMap
-      case st: StructType => createNestedDefault(st) // Handle nested structs recursively
+      case st: StructType => createNestedDefault(st)
+
+      // Special types
       case _: DecimalType => java.nio.ByteBuffer.allocate(0)
       case DateType => 0
       case TimestampType => 0L
@@ -335,6 +344,10 @@ object SchemaConverters extends Logging {
     }
   }
 
+  /**
+   * Converts a Spark SQL schema to a corresponding Avro schema.
+   * Handles nested types and adds support for schema evolution.
+   */
   def toAvroType(
       catalystType: DataType,
       nullable: Boolean = false,
@@ -377,7 +390,7 @@ object SchemaConverters extends Logging {
     }
 
     val schema = catalystType match {
-      // Basic types remain the same
+      // Basic types
       case BooleanType => builder.booleanType()
       case ByteType | ShortType | IntegerType => builder.intType()
       case LongType => builder.longType()
@@ -386,7 +399,7 @@ object SchemaConverters extends Logging {
       case StringType => builder.stringType()
       case NullType => builder.nullType()
 
-      // Date and Timestamp types
+      // Date/Timestamp types
       case DateType =>
         LogicalTypes.date().addToSchema(builder.intType())
       case TimestampType =>
@@ -406,7 +419,7 @@ object SchemaConverters extends Logging {
 
       case BinaryType => builder.bytesType()
 
-      // Complex types with improved nesting handling
+      // Complex types
       case ArrayType(elementType, containsNull) =>
         builder.array()
           .items(toAvroType(elementType, containsNull, recordName,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
index fc74278fdfce4..c794ec8f52f56 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
@@ -140,7 +140,13 @@ case class TransformWithStateExec(
    * after init is called.
    */
   override def getColFamilySchemas(): Map[String, StateStoreColFamilySchema] = {
-    val columnFamilySchemas = getDriverProcessorHandle().getColumnFamilySchemas
+    val keySchema = keyExpressions.toStructType
+    val columnFamilySchemas = getDriverProcessorHandle().getColumnFamilySchemas ++
+      Map(
+        StateStore.DEFAULT_COL_FAMILY_NAME ->
+          StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME,
+            0, keySchema, 0, DUMMY_VALUE_ROW_SCHEMA,
+            Some(NoPrefixKeyStateEncoderSpec(keySchema))))
     closeProcessorHandle()
     columnFamilySchemas
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
index e7121260c8d2a..c6c14ba182789 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
@@ -608,12 +608,12 @@ class AvroStateEncoder(
 
   // schema information
 
-  private val currentKeySchemaId: Short = getStateSchemaBroadcast.getCurrentStateSchemaId(
+  private lazy val currentKeySchemaId: Short = getStateSchemaBroadcast.getCurrentStateSchemaId(
     getColFamilyName,
     isKey = true
   )
 
-  private val currentValSchemaId: Short = getStateSchemaBroadcast.getCurrentStateSchemaId(
+  private lazy val currentValSchemaId: Short = getStateSchemaBroadcast.getCurrentStateSchemaId(
     getColFamilyName,
     isKey = false
   )
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index 1e477a6fbbcb2..409486b93c3ee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -418,7 +418,7 @@ private[sql] class RocksDBStateStoreProvider
     }
     val dataEncoder =
       getDataEncoder(
-        "unsaferow",
+        stateStoreEncoding,
         dataEncoderCacheKey,
         keyStateEncoderSpec,
         valueSchema,
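
Note on the SchemaConverters changes: getDefaultValue exists so that every field in a generated Avro schema carries a default, which is what lets Avro resolve records written with an older schema against a newer reader schema. Below is a minimal, self-contained sketch of that mechanism using Avro's public SchemaBuilder and SchemaCompatibility APIs; the record and field names ("StateRow", "key", "count") are illustrative and not taken from the patch.

import org.apache.avro.SchemaBuilder
import org.apache.avro.SchemaCompatibility

object DefaultValueSketch {
  def main(args: Array[String]): Unit = {
    // Writer schema: the shape of rows already persisted in the state store.
    val writer = SchemaBuilder.record("StateRow").fields()
      .requiredString("key")
      .endRecord()

    // Reader schema: a later version that adds a field. Because the new field
    // carries a default (0L, mirroring getDefaultValue's LongType case), Avro
    // can still resolve rows written with the old schema.
    val reader = SchemaBuilder.record("StateRow").fields()
      .requiredString("key")
      .name("count").`type`().longType().longDefault(0L)
      .endRecord()

    // Prints COMPATIBLE; dropping the default would make the pair incompatible.
    println(SchemaCompatibility.checkReaderWriterCompatibility(reader, writer).getType)
  }
}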
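Note on the TransformWithStateExec change: getColFamilySchemas now merges the user-registered column-family schemas with an always-present entry for the default column family, so consumers of the map never have to special-case it. A simplified sketch of the merge pattern follows, using hypothetical stand-in types rather than Spark's internal StateStoreColFamilySchema:

// Stand-in for StateStoreColFamilySchema; fields reduced for illustration.
case class ColFamilySchema(name: String, keySchemaId: Short, valueSchemaId: Short)

object DefaultColFamilySketch {
  def main(args: Array[String]): Unit = {
    val DefaultName = "default" // stand-in for StateStore.DEFAULT_COL_FAMILY_NAME
    val userSchemas = Map("countState" -> ColFamilySchema("countState", 0, 0))

    // `++` keeps the user entries and appends the default one; on a name
    // collision the right-hand entry would win, which is one reason the
    // default column family name is reserved.
    val all = userSchemas ++ Map(DefaultName -> ColFamilySchema(DefaultName, 0, 0))
    println(all.keys.mkString(", ")) // countState, default
  }
}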
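Note on the RocksDBStateEncoder change: turning the schema-id lookups into lazy vals keeps getStateSchemaBroadcast from being evaluated while the encoder is still being constructed, when the members it depends on may not be initialized yet. A self-contained sketch of that initialization-order hazard, with stand-in types (Encoder and schemaIds are illustrative, not Spark's classes):

abstract class Encoder {
  def schemaIds: Map[String, Short] // supplied by the concrete subclass

  // An eager `val` here would run during Encoder's construction, before the
  // subclass body has initialized the map backing `schemaIds`, and fail with
  // a NullPointerException. `lazy val` defers the lookup to first access.
  lazy val currentKeySchemaId: Short = schemaIds("key")
}

object LazyInitSketch {
  def main(args: Array[String]): Unit = {
    val enc = new Encoder {
      private val ids = Map("key" -> 0.toShort)
      def schemaIds: Map[String, Short] = ids
    }
    println(enc.currentKeySchemaId) // 0; safe because evaluation is deferred
  }
}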