diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
index d30b475863551..d1aee2f1e4363 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
@@ -217,7 +217,7 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
 
     // Read the schema file path from operator metadata version v2 onwards
     // for the transformWithState operator
-    val oldSchemaFilePath = if (storeMetadata.length > 0 && storeMetadata.head.version == 2
+    val oldSchemaFilePaths = if (storeMetadata.length > 0 && storeMetadata.head.version == 2
       && twsShortNameSeq.exists(storeMetadata.head.operatorName.contains)) {
       val storeMetadataEntry = storeMetadata.head
       val operatorProperties = TransformWithStateOperatorProperties.fromJson(
@@ -241,12 +241,10 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
         schemaFilePaths
       )
       stateSchemaProvider = Some(new InMemoryStateSchemaProvider(stateSchemaMetadata))
-      schemaFilePaths.lastOption.map { schemaFilePath =>
-        new Path(schemaFilePath)
-      }
+      schemaFilePaths.map(new Path(_))
     } else {
       None
-    }
+    }.toList
 
     try {
       // Read the actual state schema from the provided path for v2 or from the dedicated path
@@ -257,7 +255,7 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
         partitionId, sourceOptions.storeName)
       val providerId = new StateStoreProviderId(storeId, UUID.randomUUID())
       val manager = new StateSchemaCompatibilityChecker(providerId, hadoopConf,
-        oldSchemaFilePath = oldSchemaFilePath)
+        oldSchemaFilePaths = oldSchemaFilePaths)
       val stateSchema = manager.readSchemaFile()
 
       // Based on the version and read schema, populate the keyStateEncoderSpec used for
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala
index 291bbd8e1263f..f99093a7d0a5a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala
@@ -222,24 +222,21 @@ trait TransformWithStateMetadataUtils extends Logging {
       None
     }
 
-    val oldStateSchemaFilePath: Option[Path] = operatorStateMetadata match {
+    val oldStateSchemaFilePaths: List[Path] = operatorStateMetadata match {
       case Some(metadata) =>
         metadata match {
           case v2: OperatorStateMetadataV2 =>
-            // We pick the last entry in the schema list because it contains the most recent
-            // StateStoreColFamilySchemas
-            val schemaPath = v2.stateStoreInfo.head.stateSchemaFilePaths.last
-            Some(new Path(schemaPath))
-          case _ => None
+            v2.stateStoreInfo.head.stateSchemaFilePaths.map(new Path(_))
+          case _ => List.empty
         }
-      case None => None
+      case None => List.empty
     }
 
     // state schema file written here, writing the new schema list we passed here
     List(StateSchemaCompatibilityChecker.
       validateAndMaybeEvolveStateSchema(info, hadoopConf, newSchemas.values.toList,
         session.sessionState, stateSchemaVersion, storeName = StateStoreId.DEFAULT_STORE_NAME,
-        oldSchemaFilePath = oldStateSchemaFilePath,
+        oldSchemaFilePaths = oldStateSchemaFilePaths,
         newSchemaFilePath = Some(newStateSchemaFilePath),
         schemaEvolutionEnabled = stateStoreEncodingFormat == "avro"))
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala
index be4f49331de54..5533f749d32dc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, Stateful
 import org.apache.spark.sql.execution.streaming.state.SchemaHelper.{SchemaReader, SchemaWriter}
 import org.apache.spark.sql.execution.streaming.state.StateSchemaCompatibilityChecker.SCHEMA_FORMAT_V3
 import org.apache.spark.sql.internal.SessionState
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types._
 
 // Result returned after validating the schema of the state store for schema changes
 case class StateSchemaValidationResult(
@@ -78,37 +78,39 @@ case class StateStoreColFamilySchema(
 class StateSchemaCompatibilityChecker(
     providerId: StateStoreProviderId,
     hadoopConf: Configuration,
-    oldSchemaFilePath: Option[Path] = None,
+    oldSchemaFilePaths: List[Path] = List.empty,
     newSchemaFilePath: Option[Path] = None) extends Logging {
 
-  private val schemaFileLocation = if (oldSchemaFilePath.isEmpty) {
+  // For OperatorStateMetadataV1: Only one schema file present per operator
+  // per query
+  // For OperatorStateMetadataV2: Multiple schema files present per operator
+  // per query. This variable is the latest one
+  private val schemaFileLocation = if (oldSchemaFilePaths.isEmpty) {
     val storeCpLocation = providerId.storeId.storeCheckpointLocation()
     schemaFile(storeCpLocation)
   } else {
-    oldSchemaFilePath.get
+    oldSchemaFilePaths.last
   }
 
   private val fm = CheckpointFileManager.create(schemaFileLocation, hadoopConf)
 
   fm.mkdirs(schemaFileLocation.getParent)
 
+  // Read most recent schema file
  def readSchemaFile(): List[StateStoreColFamilySchema] = {
    val inStream = fm.open(schemaFileLocation)
    StateSchemaCompatibilityChecker.readSchemaFile(inStream)
  }
 
-  /**
-   * Function to read and return the list of existing state store column family schemas from the
-   * schema file, if it exists
-   * @return - List of state store column family schemas if the schema file exists and empty l
-   *           otherwise
-   */
-  private def getExistingKeyAndValueSchema(): List[StateStoreColFamilySchema] = {
-    if (fm.exists(schemaFileLocation)) {
-      readSchemaFile()
-    } else {
-      List.empty
-    }
+  // Read all old schema files, group by column family name
+  // This method is used for OperatorStateMetadataV2 when schema evolution
+  // is supported, to read all active schemas in the StateStore for this operator
+  def readSchemaFiles(): Map[String, List[StateStoreColFamilySchema]] = {
+    oldSchemaFilePaths.flatMap { schemaFile =>
+      val inStream = fm.open(schemaFile)
+      StateSchemaCompatibilityChecker.readSchemaFile(inStream)
+    }
+    .groupBy(_.colFamilyName)
  }
 
  private def createSchemaFile(
@@ -155,40 +157,45 @@ class StateSchemaCompatibilityChecker(
    * @param ignoreValueSchema - whether to ignore value schema or not
    */
   private def check(
-      oldSchema: StateStoreColFamilySchema,
+      oldSchemas: List[StateStoreColFamilySchema],
       newSchema: StateStoreColFamilySchema,
       ignoreValueSchema: Boolean,
-      schemaEvolutionEnabled: Boolean): StateStoreColFamilySchema = {
+      schemaEvolutionEnabled: Boolean): (StateStoreColFamilySchema, Boolean) = {
 
     def incrementSchemaId(id: Short): Short = (id + 1).toShort
 
+    val mostRecentSchema = oldSchemas.last
     // Initialize with old schema IDs
-    var resultSchema = newSchema.copy(
-      keySchemaId = oldSchema.keySchemaId,
-      valueSchemaId = oldSchema.valueSchemaId
+    val resultSchema = newSchema.copy(
+      keySchemaId = mostRecentSchema.keySchemaId,
+      valueSchemaId = mostRecentSchema.valueSchemaId
     )
 
-    val (storedKeySchema, storedValueSchema) = (oldSchema.keySchema,
-      oldSchema.valueSchema)
+    val (storedKeySchema, storedValueSchema) = (mostRecentSchema.keySchema,
+      mostRecentSchema.valueSchema)
     val (keySchema, valueSchema) = (newSchema.keySchema, newSchema.valueSchema)
 
     if (storedKeySchema.equals(keySchema) &&
       (ignoreValueSchema || storedValueSchema.equals(valueSchema))) {
       // schema is exactly same
-      oldSchema
+      (mostRecentSchema, false)
     } else if (!schemasCompatible(storedKeySchema, keySchema)) {
       throw StateStoreErrors.stateStoreKeySchemaNotCompatible(storedKeySchema.toString,
         keySchema.toString)
     } else if (!ignoreValueSchema && schemaEvolutionEnabled) {
+
       // Check value schema evolution
-      val oldAvroSchema = SchemaConverters.toAvroTypeWithDefaults(storedValueSchema)
+      // Sort schemas by most recent to least recent
+      val oldAvroSchemas = oldSchemas.sortBy(_.valueSchemaId).reverse.map { oldSchema =>
+        SchemaConverters.toAvroTypeWithDefaults(oldSchema.valueSchema)
+      }.asJava
       val newAvroSchema = SchemaConverters.toAvroTypeWithDefaults(valueSchema)
       val validator = new SchemaValidatorBuilder().canReadStrategy.validateAll()
       try {
-        validator.validate(newAvroSchema, Iterable(oldAvroSchema).asJava)
+        validator.validate(newAvroSchema, oldAvroSchemas)
       } catch {
         case _: SchemaValidationException =>
-          StateStoreErrors.stateStoreInvalidValueSchemaEvolution(
+          throw StateStoreErrors.stateStoreInvalidValueSchemaEvolution(
            storedValueSchema.toString,
            valueSchema.toString)
        case e: Throwable =>
@@ -196,13 +203,13 @@ class StateSchemaCompatibilityChecker(
       }
 
       // Schema evolved - increment value schema ID
-      resultSchema.copy(valueSchemaId = incrementSchemaId(oldSchema.valueSchemaId))
+      (resultSchema.copy(valueSchemaId = incrementSchemaId(mostRecentSchema.valueSchemaId)), true)
     } else if (!ignoreValueSchema && !schemasCompatible(storedValueSchema, valueSchema)) {
       throw StateStoreErrors.stateStoreValueSchemaNotCompatible(storedValueSchema.toString,
         valueSchema.toString)
     } else {
       logInfo("Detected schema change which is compatible. Allowing to put rows.")
-      oldSchema
+      (mostRecentSchema, true)
     }
   }
 
@@ -218,9 +225,9 @@ class StateSchemaCompatibilityChecker(
       ignoreValueSchema: Boolean,
       stateSchemaVersion: Int,
       schemaEvolutionEnabled: Boolean): Boolean = {
-    val existingStateSchemaList = getExistingKeyAndValueSchema()
+    val existingStateSchemaMap = readSchemaFiles()
 
-    if (existingStateSchemaList.isEmpty) {
+    if (existingStateSchemaMap.isEmpty) {
       // Initialize schemas with ID 0 when no existing schema
       val initializedSchemas = newStateSchema.map(schema =>
         schema.copy(keySchemaId = 0, valueSchemaId = 0)
@@ -228,19 +235,15 @@ class StateSchemaCompatibilityChecker(
       createSchemaFile(initializedSchemas.sortBy(_.colFamilyName), stateSchemaVersion)
       true
     } else {
-      val existingSchemaMap = existingStateSchemaList.map(schema =>
-        schema.colFamilyName -> schema
-      ).toMap
 
       // Process each new schema and track if any have evolved
       val (evolvedSchemas, hasEvolutions) = newStateSchema.foldLeft(
         (List.empty[StateStoreColFamilySchema], false)) {
         case ((schemas, evolved), newSchema) =>
-          existingSchemaMap.get(newSchema.colFamilyName) match {
-            case Some(existingSchema) =>
-              val updatedSchema = check(
-                existingSchema, newSchema, ignoreValueSchema, schemaEvolutionEnabled)
-              val hasEvolved = !updatedSchema.equals(existingSchema)
+          existingStateSchemaMap.get(newSchema.colFamilyName) match {
+            case Some(existingSchemas) =>
+              val (updatedSchema, hasEvolved) = check(
+                existingSchemas, newSchema, ignoreValueSchema, schemaEvolutionEnabled)
               (updatedSchema :: schemas, evolved || hasEvolved)
             case None =>
               // New column family - initialize with schema ID 0
@@ -250,7 +253,7 @@ class StateSchemaCompatibilityChecker(
       }
 
       val colFamiliesAddedOrRemoved =
-        (newStateSchema.map(_.colFamilyName).toSet != existingSchemaMap.keySet)
+        (newStateSchema.map(_.colFamilyName).toSet != existingStateSchemaMap.keySet)
       val newSchemaFileWritten = hasEvolutions || colFamiliesAddedOrRemoved
 
       if (stateSchemaVersion == SCHEMA_FORMAT_V3 && newSchemaFileWritten) {
@@ -319,7 +322,7 @@ object StateSchemaCompatibilityChecker extends Logging {
       stateSchemaVersion: Int,
       extraOptions: Map[String, String] = Map.empty,
       storeName: String = StateStoreId.DEFAULT_STORE_NAME,
-      oldSchemaFilePath: Option[Path] = None,
+      oldSchemaFilePaths: List[Path] = List.empty,
       newSchemaFilePath: Option[Path] = None,
       schemaEvolutionEnabled: Boolean = false): StateSchemaValidationResult = {
     // SPARK-47776: collation introduces the concept of binary (in)equality, which means
@@ -339,7 +342,7 @@ object StateSchemaCompatibilityChecker extends Logging {
     val providerId = StateStoreProviderId(StateStoreId(stateInfo.checkpointLocation,
       stateInfo.operatorId, 0, storeName), stateInfo.queryRunId)
     val checker = new StateSchemaCompatibilityChecker(providerId, hadoopConf,
-      oldSchemaFilePath = oldSchemaFilePath, newSchemaFilePath = newSchemaFilePath)
+      oldSchemaFilePaths = oldSchemaFilePaths, newSchemaFilePath = newSchemaFilePath)
     // regardless of configuration, we check compatibility to at least write schema file
     // if necessary
     // if the format validation for value schema is disabled, we also disable the schema
@@ -380,7 +383,7 @@ object StateSchemaCompatibilityChecker extends Logging {
         // so we would just populate the next run's metadata file with this
         // file path
         if (stateSchemaVersion == SCHEMA_FORMAT_V3) {
-          oldSchemaFilePath.get.toString
+          oldSchemaFilePaths.last.toString
         } else {
           // if we are using any version less than v3, we have written
           // the schema to this static location, which we will return
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
index a35188b689717..80ac28b439b78 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
@@ -320,8 +320,6 @@ class StateStoreKeySchemaNotCompatible(
       "storedKeySchema" -> storedKeySchema,
       "newKeySchema" -> newKeySchema))
 
-trait StateStoreInvalidValueSchema extends Throwable
-
 class StateStoreValueSchemaNotCompatible(
     storedValueSchema: String,
     newValueSchema: String)
@@ -329,7 +327,7 @@ class StateStoreValueSchemaNotCompatible(
     errorClass = "STATE_STORE_VALUE_SCHEMA_NOT_COMPATIBLE",
     messageParameters = Map(
       "storedValueSchema" -> storedValueSchema,
-      "newValueSchema" -> newValueSchema)) with StateStoreInvalidValueSchema
+      "newValueSchema" -> newValueSchema))
 
 class StateStoreInvalidValueSchemaEvolution(
     storedValueSchema: String,
@@ -338,7 +336,7 @@ class StateStoreInvalidValueSchemaEvolution(
     errorClass = "STATE_STORE_INVALID_VALUE_SCHEMA_EVOLUTION",
     messageParameters = Map(
       "storedValueSchema" -> storedValueSchema,
-      "newValueSchema" -> newValueSchema)) with StateStoreInvalidValueSchema
+      "newValueSchema" -> newValueSchema))
 
 class StateStoreSnapshotFileNotFound(fileToRead: String, clazz: String)
   extends SparkRuntimeException(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
index 3cbee419f0e5d..f15d16e179444 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
@@ -123,6 +123,20 @@ trait AlsoTestWithEncodingTypes extends SQLTestUtils {
       }
     }
   }
+
+  // New method for Avro-only tests
+  protected def testWithAvroOnly(testName: String, testTags: Tag*)(testBody: => Any)
+    (implicit pos: Position): Unit = {
+    super.test(s"$testName (encoding = avro)", testTags: _*) {
+      withSQLConf(SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT.key -> "avro") {
+        testBody
+      }
+    }
+  }
+
+  protected def getCurrentEncoding(): String = {
+    spark.conf.get(SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT.key)
+  }
 }
 
 trait AlsoTestWithRocksDBFeatures
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala
index 36d6888ff850b..000c1b7a225ea 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala
@@ -275,7 +275,7 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession {
     val schemaFilePath = Some(new Path(stateSchemaDir,
       s"${batchId}_${UUID.randomUUID().toString}"))
     val checker = new StateSchemaCompatibilityChecker(providerId, hadoopConf,
-      oldSchemaFilePath = schemaFilePath,
+      oldSchemaFilePaths = schemaFilePath.toList,
       newSchemaFilePath = schemaFilePath)
     checker.createSchemaFile(storeColFamilySchema,
       SchemaHelper.SchemaWriter.createSchemaWriter(stateSchemaVersion))
@@ -397,7 +397,7 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession {
     val result = Try(
       StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(stateInfo, hadoopConf,
         oldStateSchema, spark.sessionState, stateSchemaVersion = stateSchemaVersion,
-        oldSchemaFilePath = schemaFilePath,
+        oldSchemaFilePaths = schemaFilePath.toList,
         newSchemaFilePath = newSchemaFilePath, extraOptions = extraOptions)
     ).toEither.fold(Some(_), _ => None)
 
@@ -412,9 +412,9 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession {
     StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(stateInfo, hadoopConf,
       newStateSchema, spark.sessionState, stateSchemaVersion = stateSchemaVersion,
       extraOptions = extraOptions,
-      oldSchemaFilePath = stateSchemaVersion match {
-        case 3 => newSchemaFilePath
-        case _ => None
+      oldSchemaFilePaths = stateSchemaVersion match {
+        case 3 => newSchemaFilePath.toList
+        case _ => List.empty
       },
       newSchemaFilePath = getNewSchemaPath(stateSchemaDir, stateSchemaVersion))
   }
@@ -463,7 +463,7 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession {
       keyStateEncoderSpec = getKeyStateEncoderSpec(stateSchemaVersion, oldKeySchema)))
     StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(stateInfo, hadoopConf,
       oldStateSchema, spark.sessionState, stateSchemaVersion = stateSchemaVersion,
-      oldSchemaFilePath = schemaFilePath,
+      oldSchemaFilePaths = schemaFilePath.toList,
       newSchemaFilePath = getNewSchemaPath(stateSchemaDir, stateSchemaVersion),
       extraOptions = extraOptions)
 
@@ -472,7 +472,7 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession {
       keyStateEncoderSpec = getKeyStateEncoderSpec(stateSchemaVersion, newKeySchema)))
     StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(stateInfo, hadoopConf,
       newStateSchema, spark.sessionState, stateSchemaVersion = stateSchemaVersion,
-      oldSchemaFilePath = schemaFilePath,
+      oldSchemaFilePaths = schemaFilePath.toList,
       newSchemaFilePath = getNewSchemaPath(stateSchemaDir, stateSchemaVersion),
       extraOptions = extraOptions)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
index 91e420c1b14b9..cac686eccfb91 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
@@ -33,11 +33,11 @@ import org.apache.spark.sql.catalyst.util.stringToFile
 import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.streaming._
-import org.apache.spark.sql.execution.streaming.state._
+import org.apache.spark.sql.execution.streaming.state.{StateStoreInvalidValueSchemaEvolution, _}
 import org.apache.spark.sql.functions.timestamp_seconds
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.util.StreamManualClock
-import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}
+import org.apache.spark.sql.types._
 
 object TransformWithStateSuiteUtils {
   val NUM_SHUFFLE_PARTITIONS = 5
@@ -854,7 +854,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest
     }
   }
 
-  test("transformWithState - upcasting should succeed") {
+  testWithAvroOnly("transformWithState - upcasting should succeed") {
     withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
       classOf[RocksDBStateStoreProvider].getName,
       SQLConf.SHUFFLE_PARTITIONS.key ->
@@ -911,7 +911,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest
     }
   }
 
-  test("transformWithState - reordering fields should succeed") {
+  testWithAvroOnly("transformWithState - reordering fields should succeed") {
     withSQLConf(
       SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
         classOf[RocksDBStateStoreProvider].getName,
@@ -952,7 +952,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest
     }
   }
 
-  test("transformWithState - adding field should succeed") {
+  testWithAvroOnly("transformWithState - adding field should succeed") {
     withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
       classOf[RocksDBStateStoreProvider].getName,
       SQLConf.SHUFFLE_PARTITIONS.key ->
@@ -1007,7 +1007,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest
     }
   }
 
-  test("transformWithState - rename field") {
+  testWithAvroOnly("transformWithState - add and remove field between runs") {
     withSQLConf(
       SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName,
       SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
@@ -1049,7 +1049,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest
     }
   }
 
-  test("transformWithState - verify default values during schema evolution") {
+  testWithAvroOnly("transformWithState - verify default values during schema evolution") {
     withSQLConf(
       SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName,
       SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
@@ -1109,7 +1109,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest
     }
   }
 
-  test("transformWithState - removing field should succeed") {
+  testWithAvroOnly("transformWithState - removing field should succeed") {
     withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
       classOf[RocksDBStateStoreProvider].getName,
       SQLConf.SHUFFLE_PARTITIONS.key ->
@@ -1140,7 +1140,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest
       testStream(result1, OutputMode.Update())(
         StartStream(checkpointLocation = dirPath),
         AddData(inputData, "a"),
-        CheckNewAnswer(("a", "2")),
+        CheckNewAnswer(("a", "1")),
        StopStream
      )
    }
@@ -1614,6 +1614,22 @@ class TransformWithStateSuite extends StateStoreMetricsTest
       Option(userKeySchema)
     )
 
+    val schema3 = StateStoreColFamilySchema(
+      "$rowCounter_listState", 0,
+      keySchema, 0,
+      new StructType().add("count", LongType, false),
+      Some(NoPrefixKeyStateEncoderSpec(keySchema)),
+      None
+    )
+
+    val schema4 = StateStoreColFamilySchema(
+      "default", 0,
+      keySchema, 0,
+      new StructType().add("value", BinaryType),
+      Some(NoPrefixKeyStateEncoderSpec(keySchema)),
+      None
+    )
+
     val inputData = MemoryStream[String]
     val result = inputData.toDS()
       .groupByKey(x => x)
@@ -1631,7 +1647,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest
             val providerId = StateStoreProviderId(StateStoreId(
               checkpointDir.getCanonicalPath, 0, 0), q.lastProgress.runId)
             val checker = new StateSchemaCompatibilityChecker(providerId,
-              hadoopConf, Some(schemaFilePath))
+              hadoopConf, List(schemaFilePath))
             val colFamilySeq = checker.readSchemaFile()
 
             assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS ==
@@ -1641,9 +1657,9 @@ class TransformWithStateSuite extends StateStoreMetricsTest
             assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS ==
               q.lastProgress.stateOperators.head.customMetrics.get("numMapStateVars").toInt)
 
-            assert(colFamilySeq.length == 3)
+            assert(colFamilySeq.length == 5)
             assert(colFamilySeq.map(_.toString).toSet == Set(
-              schema0, schema1, schema2
+              schema0, schema1, schema2, schema3, schema4
             ).map(_.toString))
           },
           StopStream
@@ -1721,10 +1737,11 @@ class TransformWithStateSuite extends StateStoreMetricsTest
         .transformWithState(new RunningCountStatefulProcessorInt(),
           TimeMode.None(),
          OutputMode.Update())
+
      testStream(result2, OutputMode.Update())(
        StartStream(checkpointLocation = checkpointDir.getCanonicalPath),
        AddData(inputData, "a"),
-        ExpectFailure[StateStoreInvalidValueSchema] {
+        ExpectFailure[StateStoreInvalidValueSchemaEvolution] {
          (t: Throwable) => {
            assert(t.getMessage.contains("Please check number and type of fields."))
          }
@@ -2137,7 +2154,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest
         // and we only need to keep metadata files for batches 2, 3, and the since schema
         // hasn't changed between batches 2, 3, we only keep the schema file for batch 2
         assert(getFiles(metadataPath).length == 2)
-        assert(getFiles(stateSchemaPath).length == 1)
+        assert(getFiles(stateSchemaPath).length == 2)
       }
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
index 4d3cdadbe8356..57dfbe6c8d4fb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
@@ -360,7 +360,7 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
             val providerId = StateStoreProviderId(StateStoreId(
               checkpointDir.getCanonicalPath, 0, 0), q.lastProgress.runId)
             val checker = new StateSchemaCompatibilityChecker(providerId,
-              hadoopConf, Some(schemaFilePath))
+              hadoopConf, List(schemaFilePath))
             val colFamilySeq = checker.readSchemaFile()
 
             assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS ==