diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 3dfafd20c176f..4dc0a84e1e589 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -40,7 +40,7 @@ import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadat
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
 import org.apache.spark.sql.execution.python.{FlatMapGroupsInPandasWithStateExec, TransformWithStateInPandasExec}
 import org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSourceV1
-import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataReader, OperatorStateMetadataV1, OperatorStateMetadataV2, OperatorStateMetadataWriter, StateSchemaCompatibilityChecker, StateSchemaMetadata}
+import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataReader, OperatorStateMetadataV1, OperatorStateMetadataV2, OperatorStateMetadataWriter, StateSchemaCompatibilityChecker, StateSchemaMetadata, StateSchemaMetadataKey, StateSchemaMetadataValue}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -261,6 +261,7 @@ class IncrementalExecution(
             case tws: TransformWithStateExec =>
               val stateSchemaMetadata = createStateSchemaMetadata(stateSchemaMapping.head)
               val ssmBc = sparkSession.sparkContext.broadcast(stateSchemaMetadata)
+              logStateSchemaMetadata(stateSchemaMetadata)
               tws.copy(stateSchemaMetadata = Some(ssmBc))
             case _ => ssw
           }
@@ -269,6 +270,14 @@ class IncrementalExecution(
       }
     }
 
+  def logStateSchemaMetadata(metadata: StateSchemaMetadata): Unit = {
+    // Diagnostic dump of every registered (key -> schema) entry. Logged at DEBUG level so
+    // that planning a batch does not flood the driver log with error-level noise.
+    metadata.schemas.valuesIterator.flatten.foreach { case (key, value) =>
+      logDebug(s"State schema metadata - key: $key, value: $value")
+    }
+  }
+
   private def createStateSchemaMetadata(
       stateSchemaMapping: Map[Int, String]
   ): StateSchemaMetadata = {
@@ -277,7 +286,10 @@ class IncrementalExecution(
         val inStream = fm.open(new Path(stateSchemaPath))
         val colFamilySchemas = StateSchemaCompatibilityChecker.readSchemaFile(inStream).map {
           schema =>
-            schema.colFamilyName -> SchemaConverters.toAvroType(schema.valueSchema)
+            StateSchemaMetadataKey(
+              stateSchemaId, schema.colFamilyName, runId.toString) ->
+              StateSchemaMetadataValue(
+                SchemaConverters.toAvroType(schema.valueSchema), schema.valueSchema)
         }.toMap
         stateSchemaId -> colFamilySchemas
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
index 4cb2c7dd02f4e..88a31c883ab3d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
@@ -482,12 +482,14 @@ case class TransformWithStateExec(
         }
       case None => None
     }
+
     List(StateSchemaCompatibilityChecker.
validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, newSchemas.values.toList, session.sessionState, stateSchemaVersion, storeName = StateStoreId.DEFAULT_STORE_NAME, oldSchemaFilePath = oldStateSchemaFilePath, - newSchemaFilePath = Some(newStateSchemaFilePath))) + newSchemaFilePath = Some(newStateSchemaFilePath), + usingAvro = session.sessionState.conf.stateStoreEncodingFormat == "avro")) } override def stateSchemaMapping( @@ -612,7 +614,8 @@ case class TransformWithStateExec( NoPrefixKeyStateEncoderSpec(keyEncoder.schema), session.sessionState, Some(session.streams.stateStoreCoordinator), - useColumnFamilies = true + useColumnFamilies = true, + stateSchemaMetadata = stateSchemaMetadata ) { case (store: StateStore, singleIterator: Iterator[InternalRow]) => processData(store, singleIterator) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index 2fdd8e51bbd6a..7f0d1fd712549 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter import org.apache.spark.sql.execution.streaming.StateStoreColumnFamilySchemaUtils -import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider.{STATE_ENCODING_NUM_VERSION_BYTES, STATE_ENCODING_VERSION, VIRTUAL_COL_FAMILY_PREFIX_BYTES} +import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider.{STATE_ENCODING_NUM_VERSION_BYTES, STATE_ENCODING_VERSION, STATE_ROW_SCHEMA_ID_PREFIX_BYTES, VIRTUAL_COL_FAMILY_PREFIX_BYTES} import 
org.apache.spark.sql.types._
 import org.apache.spark.unsafe.Platform
 
@@ -51,16 +51,24 @@ sealed trait RocksDBValueStateEncoder {
   def decodeValues(valueBytes: Array[Byte]): Iterator[UnsafeRow]
 }
 
-abstract class RocksDBValueStateEncoderWithProvider(
+abstract class RocksDBStateSchema(
     provider: RocksDBStateStoreProvider,
-    colFamilyName: String
-) extends RocksDBValueStateEncoder {
-  def getSchemaFromId(schemaId: Short): Schema = {
-    null
+    colFamilyName: String) {
+  def getSchemaFromId(schemaId: Int): StateSchemaMetadataValue = {
+    provider.getSchemaFromId(colFamilyName, schemaId)
   }
 
-  def getCurrentSchemaId: Short = {
-    0
+  def getCurrentSchemaId: Int = {
+    provider.getCurrentSchemaId
+  }
+
+  protected def encodeSchemaId(numBytes: Int, schemaId: Int): (Array[Byte], Int) = {
+    val encodedBytes = new Array[Byte](numBytes + STATE_ROW_SCHEMA_ID_PREFIX_BYTES)
+    // The schema id occupies the first STATE_ROW_SCHEMA_ID_PREFIX_BYTES of the buffer.
+    Platform.putInt(encodedBytes, Platform.BYTE_ARRAY_OFFSET, schemaId)
+    // Offset at which the caller must start writing the actual value payload.
+    val payloadOffset = Platform.BYTE_ARRAY_OFFSET + STATE_ROW_SCHEMA_ID_PREFIX_BYTES
+    (encodedBytes, payloadOffset)
   }
 }
 
@@ -220,6 +228,27 @@ object RocksDBStateEncoder extends Logging {
     valueProj.apply(internalRow)
   }
 
+  /**
+   * Decodes Avro-encoded `valueBytes` into an UnsafeRow, resolving schema differences:
+   * `valueAvroType` is the writer schema the bytes were encoded with, and
+   * `currentAvroType` is the reader schema the decoded record is resolved to.
+   */
+  def decodeFromAvroToUnsafeRow(
+      valueBytes: Array[Byte],
+      avroDeserializer: AvroDeserializer,
+      currentAvroType: Schema,
+      valueAvroType: Schema,
+      valueProj: UnsafeProjection): UnsafeRow = {
+    // GenericDatumReader takes (writer schema, reader schema), in that order.
+    val reader = new GenericDatumReader[Any](valueAvroType, currentAvroType)
+    val decoder = DecoderFactory.get().binaryDecoder(valueBytes, 0, valueBytes.length, null)
+    // bytes -> Avro.GenericDataRecord -> InternalRow -> UnsafeRow
+    val genericData = reader.read(null, decoder)
+    val internalRow = avroDeserializer.deserialize(
+      genericData).orNull.asInstanceOf[InternalRow]
+    valueProj.apply(internalRow)
+  }
+
   def
decodeToUnsafeRow(bytes: Array[Byte], reusedRow: UnsafeRow): UnsafeRow = {
     if (bytes != null) {
       // Platform.BYTE_ARRAY_OFFSET is the recommended way refer to the 1st offset. See Platform.
@@ -1161,7 +1190,8 @@ class MultiValuedStateEncoder(
     colFamilyName: String,
     valueSchema: StructType,
     avroEnc: Option[AvroEncoder] = None)
-  extends RocksDBValueStateEncoderWithProvider(provider, colFamilyName) with Logging {
+  extends RocksDBStateSchema(provider, colFamilyName) with RocksDBValueStateEncoder
+    with Logging {
 
   import RocksDBStateEncoder._
 
@@ -1259,7 +1289,8 @@ class SingleValueStateEncoder(
     colFamilyName: String,
     valueSchema: StructType,
     avroEnc: Option[AvroEncoder] = None)
-  extends RocksDBValueStateEncoderWithProvider(provider, colFamilyName) with Logging {
+  extends RocksDBStateSchema(provider, colFamilyName) with RocksDBValueStateEncoder
+    with Logging {
 
   import RocksDBStateEncoder._
 
@@ -1271,11 +1302,22 @@ class SingleValueStateEncoder(
   private val valueProj = UnsafeProjection.create(valueSchema)
 
   override def encodeValue(row: UnsafeRow): Array[Byte] = {
-    if (usingAvroEncoding) {
+    val valueBytes = if (usingAvroEncoding) {
       encodeUnsafeRowToAvro(row, avroEnc.get.valueSerializer, valueAvroType, out)
     } else {
       encodeUnsafeRow(row)
     }
+
+    // Create new array with space for schema ID
+    val (schemaVersionedBytes, offset) = encodeSchemaId(valueBytes.length, getCurrentSchemaId)
+
+    // Copy value bytes after schema ID
+    Platform.copyMemory(
+      valueBytes, Platform.BYTE_ARRAY_OFFSET,
+      schemaVersionedBytes, offset,
+      valueBytes.length)
+
+    schemaVersionedBytes
   }
 
   /**
@@ -1285,14 +1327,27 @@ class SingleValueStateEncoder(
    * the given byte array.
    */
   override def decodeValue(valueBytes: Array[Byte]): UnsafeRow = {
-    if (valueBytes == null) {
-      return null
-    }
+    if (valueBytes == null) return null
+
+    // The first STATE_ROW_SCHEMA_ID_PREFIX_BYTES hold the id of the schema used by the writer
+    val schemaId = Platform.getInt(valueBytes, Platform.BYTE_ARRAY_OFFSET)
+    // Strip the schema id prefix to get the actual value bytes
+    val actualValueBytes = new Array[Byte](valueBytes.length - STATE_ROW_SCHEMA_ID_PREFIX_BYTES)
+    Platform.copyMemory(
+      valueBytes, Platform.BYTE_ARRAY_OFFSET + STATE_ROW_SCHEMA_ID_PREFIX_BYTES,
+      actualValueBytes, Platform.BYTE_ARRAY_OFFSET,
+      actualValueBytes.length)
+
     if (usingAvroEncoding) {
+      // Only the Avro path needs the writer schema; resolve it here so that the UnsafeRow
+      // path does not depend on state schema metadata being available for this provider.
+      val writerAvroSchema = getSchemaFromId(schemaId).avroSchema
       decodeFromAvroToUnsafeRow(
-        valueBytes, avroEnc.get.valueDeserializer, valueAvroType, valueProj)
+        actualValueBytes, avroEnc.get.valueDeserializer,
+        valueAvroType,
+        writerAvroSchema, valueProj)
     } else {
-      decodeToUnsafeRow(valueBytes, valueRow)
+      decodeToUnsafeRow(actualValueBytes, valueRow)
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index 3c3f4b1274b6c..bb801e22b9186 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -378,6 +378,7 @@ private[sql] class RocksDBStateStoreProvider
     this.hadoopConf = hadoopConf
     this.useColumnFamilies = useColumnFamilies
     this.stateStoreEncoding = storeConf.stateStoreEncodingFormat
+    this.stateSchemaBroadcast = stateSchemaMetadata
 
     if (useMultipleValuesPerKey) {
       require(useColumnFamilies, "Multiple values per key support requires column families
to be" +
@@ -482,6 +483,26 @@ private[sql] class RocksDBStateStoreProvider
   @volatile private var hadoopConf: Configuration = _
   @volatile private var useColumnFamilies: Boolean = _
   @volatile private var stateStoreEncoding: String = _
+  @volatile private var stateSchemaBroadcast: Option[Broadcast[StateSchemaMetadata]] = None
+
+  /**
+   * Looks up the value schema registered for `colFamilyName` under `schemaId`. Returns null
+   * when this provider has no state schema metadata; throws NoSuchElementException when
+   * metadata exists but has no entry for the requested id / column family / run id.
+   */
+  def getSchemaFromId(colFamilyName: String, schemaId: Int): StateSchemaMetadataValue = {
+    stateSchemaBroadcast.map { broadcast =>
+      val key = StateSchemaMetadataKey(
+        schemaId, colFamilyName, hadoopConf.get(StreamExecution.RUN_ID_KEY))
+      broadcast.value.schemas.get(schemaId).flatMap(_.get(key)).getOrElse(
+        throw new NoSuchElementException(s"State schema not found for key: $key"))
+    }.orNull
+  }
+
+  /** Current schema id from the broadcast metadata, or -1 when none was provided. */
+  def getCurrentSchemaId: Int = {
+    stateSchemaBroadcast.map(_.value.currentSchemaId).getOrElse(-1)
+  }
 
   private[sql] lazy val rocksDB = {
     val dfsRootDir = stateStoreId.storeCheckpointLocation().toString
@@ -616,6 +637,7 @@ object RocksDBStateStoreProvider {
   val STATE_ENCODING_NUM_VERSION_BYTES = 1
   val STATE_ENCODING_VERSION: Byte = 0
   val VIRTUAL_COL_FAMILY_PREFIX_BYTES = 2
+  val STATE_ROW_SCHEMA_ID_PREFIX_BYTES = 4
 
   private val MAX_AVRO_ENCODERS_IN_CACHE = 1000
   // Add the cache at companion object level so it persists across provider instances
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala
index 7c765c1c29261..739ee9504222f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala
@@ -17,14 +17,16 @@
 package org.apache.spark.sql.execution.streaming.state
 
+import scala.jdk.CollectionConverters.IterableHasAsJava
 import scala.util.Try
 
+import org.apache.avro.SchemaValidatorBuilder
 import
org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FSDataInputStream, Path} import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.internal.{Logging, LogKeys, MDC} -import org.apache.spark.sql.avro.{AvroDeserializer, AvroSerializer} +import org.apache.spark.sql.avro.{AvroDeserializer, AvroSerializer, SchemaConverters} import org.apache.spark.sql.catalyst.util.UnsafeRowUtils import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, StatefulOperatorStateInfo} import org.apache.spark.sql.execution.streaming.state.SchemaHelper.{SchemaReader, SchemaWriter} @@ -38,21 +40,8 @@ case class StateSchemaValidationResult( schemaPath: String ) -/** - * An Avro-based encoder used for serializing between UnsafeRow and Avro - * byte arrays in RocksDB state stores. - * - * This encoder is primarily utilized by [[RocksDBStateStoreProvider]] and [[RocksDBStateEncoder]] - * to handle serialization and deserialization of state store data. - * - * @param keySerializer Serializer for converting state store keys to Avro format - * @param keyDeserializer Deserializer for converting Avro-encoded keys back to UnsafeRow - * @param valueSerializer Serializer for converting state store values to Avro format - * @param valueDeserializer Deserializer for converting Avro-encoded values back to UnsafeRow - * @param suffixKeySerializer Optional serializer for handling suffix keys in Avro format - * @param suffixKeyDeserializer Optional deserializer for converting Avro-encoded suffix - * keys back to UnsafeRow - */ +// Avro encoder that is used by the RocksDBStateStoreProvider and RocksDBStateEncoder +// in order to serialize from UnsafeRow to a byte array of Avro encoding. 
case class AvroEncoder( keySerializer: AvroSerializer, keyDeserializer: AvroDeserializer, @@ -153,7 +142,8 @@ class StateSchemaCompatibilityChecker( private def check( oldSchema: StateStoreColFamilySchema, newSchema: StateStoreColFamilySchema, - ignoreValueSchema: Boolean) : Unit = { + ignoreValueSchema: Boolean, + usingAvro: Boolean) : Boolean = { val (storedKeySchema, storedValueSchema) = (oldSchema.keySchema, oldSchema.valueSchema) val (keySchema, valueSchema) = (newSchema.keySchema, newSchema.valueSchema) @@ -161,14 +151,28 @@ class StateSchemaCompatibilityChecker( if (storedKeySchema.equals(keySchema) && (ignoreValueSchema || storedValueSchema.equals(valueSchema))) { // schema is exactly same + false } else if (!schemasCompatible(storedKeySchema, keySchema)) { throw StateStoreErrors.stateStoreKeySchemaNotCompatible(storedKeySchema.toString, keySchema.toString) + } else if (!ignoreValueSchema && usingAvro) { + // By this point, we know that old value schema is not equal to new value schema + val oldAvroSchema = SchemaConverters.toAvroType(storedValueSchema) + val newAvroSchema = SchemaConverters.toAvroType(valueSchema) + + val validator = new SchemaValidatorBuilder().canReadStrategy.validateAll() + // This will throw a SchemaValidation exception if the schema has evolved in an + // unacceptable way. + validator.validate(newAvroSchema, Iterable(oldAvroSchema).asJava) + // If no exception is thrown, then we know that the schema evolved in an + // acceptable way + true } else if (!ignoreValueSchema && !schemasCompatible(storedValueSchema, valueSchema)) { throw StateStoreErrors.stateStoreValueSchemaNotCompatible(storedValueSchema.toString, valueSchema.toString) } else { logInfo("Detected schema change which is compatible. 
Allowing to put rows.") + true } } @@ -182,7 +186,8 @@ class StateSchemaCompatibilityChecker( def validateAndMaybeEvolveStateSchema( newStateSchema: List[StateStoreColFamilySchema], ignoreValueSchema: Boolean, - stateSchemaVersion: Int): Boolean = { + stateSchemaVersion: Int, + usingAvro: Boolean): Boolean = { val existingStateSchemaList = getExistingKeyAndValueSchema() val newStateSchemaList = newStateSchema @@ -197,18 +202,18 @@ class StateSchemaCompatibilityChecker( }.toMap // For each new state variable, we want to compare it to the old state variable // schema with the same name - newStateSchemaList.foreach { newSchema => - existingSchemaMap.get(newSchema.colFamilyName).foreach { existingStateSchema => - check(existingStateSchema, newSchema, ignoreValueSchema) - } + val hasEvolvedSchema = newStateSchemaList.exists { newSchema => + existingSchemaMap.get(newSchema.colFamilyName) + .exists(existingSchema => check(existingSchema, newSchema, ignoreValueSchema, usingAvro)) } val colFamiliesAddedOrRemoved = (newStateSchemaList.map(_.colFamilyName).toSet != existingSchemaMap.keySet) - if (stateSchemaVersion == SCHEMA_FORMAT_V3 && colFamiliesAddedOrRemoved) { + val newSchemaFileWritten = hasEvolvedSchema || colFamiliesAddedOrRemoved + if (stateSchemaVersion == SCHEMA_FORMAT_V3 && newSchemaFileWritten) { createSchemaFile(newStateSchemaList.sortBy(_.colFamilyName), stateSchemaVersion) } // TODO: [SPARK-49535] Write Schema files after schema has changed for StateSchemaV3 - colFamiliesAddedOrRemoved + newSchemaFileWritten } } @@ -270,7 +275,8 @@ object StateSchemaCompatibilityChecker extends Logging { extraOptions: Map[String, String] = Map.empty, storeName: String = StateStoreId.DEFAULT_STORE_NAME, oldSchemaFilePath: Option[Path] = None, - newSchemaFilePath: Option[Path] = None): StateSchemaValidationResult = { + newSchemaFilePath: Option[Path] = None, + usingAvro: Boolean = false): StateSchemaValidationResult = { // SPARK-47776: collation introduces the concept of binary 
(in)equality, which means
     // in some collation we no longer be able to just compare the binary format of two
     // UnsafeRows to determine equality. For example, 'aaa' and 'AAA' can be "semantically"
@@ -301,7 +307,7 @@ object StateSchemaCompatibilityChecker extends Logging {
     val result = Try(
       checker.validateAndMaybeEvolveStateSchema(newStateSchema,
         ignoreValueSchema = !storeConf.formatValidationCheckValue,
-        stateSchemaVersion = stateSchemaVersion)
+        stateSchemaVersion = stateSchemaVersion, usingAvro = usingAvro)
     ).toEither.fold(Some(_),
       hasEvolvedSchema => {
         evolvedSchema = hasEvolvedSchema
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index a8197b2f179ec..8dff1b6f2a63d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -55,9 +55,33 @@ object StateStoreEncoding {
   case object Avro extends StateStoreEncoding
 }
 
+/**
+ * Key used to look up a state value schema in [[StateSchemaMetadata]].
+ *
+ * @param schemaId id of the state schema
+ * @param colFamilyName name of the column family the schema belongs to
+ * @param queryRunId run id of the streaming query
+ */
+case class StateSchemaMetadataKey(
+    schemaId: Int,
+    colFamilyName: String,
+    queryRunId: String
+)
+
+/**
+ * Value schema of a column family, kept in both representations.
+ *
+ * @param avroSchema Avro schema of the value
+ * @param valueSchema the same value schema as a Spark SQL StructType
+ */
+case class StateSchemaMetadataValue(
+    avroSchema: Schema,
+    valueSchema: StructType
+)
+
 case class StateSchemaMetadata(
     currentSchemaId: Int,
-    schemas: Map[Int, Map[String, Schema]]
+    schemas: Map[Int, Map[StateSchemaMetadataKey, StateSchemaMetadataValue]]
 )
 
 /**
@@ -821,7 +845,8 @@ object StateStore extends Logging {
           storeProviderId,
           StateStoreProvider.createAndInit(
             storeProviderId, keySchema, valueSchema, keyStateEncoderSpec,
-            useColumnFamilies, storeConf, hadoopConf, useMultipleValuesPerKey)
+            useColumnFamilies, storeConf, hadoopConf, useMultipleValuesPerKey,
+            stateSchemaMetadata)
         )
       }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
index b85093bc1fd23..1e7eef0932567 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
@@ -133,7 +133,7 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
       storeProviderId, keySchema, valueSchema, keyStateEncoderSpec, storeVersion,
       uniqueId.map(_.apply(partition.index).head),
       useColumnFamilies, storeConf, hadoopConfBroadcast.value.value,
-      useMultipleValuesPerKey)
+      useMultipleValuesPerKey, stateSchemaMetadata)
     storeUpdateFunction(store, inputIter)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
index 6c950cd234963..55680acca4368 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
@@ -89,7 +89,8 @@ package object state {
         storeCoordinator,
         useColumnFamilies,
         extraOptions,
-        useMultipleValuesPerKey)
+        useMultipleValuesPerKey,
+        stateSchemaMetadata)
     }
     // scalastyle:on
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
index 91a47645f4179..bcab0bcbcf54d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
@@ -571,6 +571,62 @@ class TransformWithStateSuite extends StateStoreMetricsTest
     }
   }
 
+  test("transformWithState - upcasting should succeed") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.SHUFFLE_PARTITIONS.key ->
+        TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
+      withTempDir { chkptDir =>
+        val dirPath = chkptDir.getCanonicalPath
+        val inputData = MemoryStream[String]
+        val result1 = inputData.toDS()
+          .groupByKey(x => x)
+          .transformWithState(new RunningCountStatefulProcessorInt(),
+            TimeMode.None(),
+            OutputMode.Update())
+
+        testStream(result1, OutputMode.Update())(
+          StartStream(checkpointLocation = dirPath),
+          AddData(inputData, "a"),
+          CheckNewAnswer(("a", "1")),
+          Execute { q =>
+            assert(q.lastProgress.stateOperators(0).customMetrics.get("numValueStateVars") > 0)
+            assert(q.lastProgress.stateOperators(0).customMetrics.get("numRegisteredTimers") == 0)
+            assert(q.lastProgress.stateOperators(0).numRowsUpdated === 1)
+          },
+          AddData(inputData, "a", "b"),
+          CheckNewAnswer(("a", "2"), ("b", "1")),
+          StopStream,
+          StartStream(checkpointLocation = dirPath),
+          AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a
+          CheckNewAnswer(("b", "2")),
+          StopStream,
+          Execute { q =>
+            assert(q.lastProgress.stateOperators(0).numRowsUpdated === 1)
+            assert(q.lastProgress.stateOperators(0).numRowsRemoved === 1)
+          },
+          StartStream(checkpointLocation = dirPath),
+          AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1
+          CheckNewAnswer(("a", "1"), ("c", "1"))
+        )
+
+        // Restart from the same checkpoint with a different processor to exercise upcasting
+        val result2 = inputData.toDS()
+          .groupByKey(x => x)
+          .transformWithState(new RunningCountStatefulProcessor(),
+            TimeMode.None(),
+            OutputMode.Update())
+
+        testStream(result2, OutputMode.Update())(
+          StartStream(checkpointLocation = dirPath),
+          AddData(inputData, "a"),
+          CheckNewAnswer(("a", "2")),
+          StopStream
+        )
+      }
+    }
+  }
+
   test("transformWithState - streaming with rocksdb and processing time timer " +
     "should succeed") {
     withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->