diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 02f0d9ae41fef..91e420c1b14b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -53,6 +53,11 @@ case class ReorderedLongs( value1: Long ) +case class RenamedFields( + value4: Long, + value2: Long +) + // Initial state with basic fields case class BasicState( id: Int, @@ -150,6 +155,33 @@ class RunningCountStatefulProcessorInitialOrder } } +// Evolved processor with renamed field +class RenameEvolvedProcessor extends StatefulProcessor[String, String, (String, String)] { + @transient protected var _countState: ValueState[RenamedFields] = _ + + override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = { + _countState = getHandle.getValueState[RenamedFields]( + "countState", + Encoders.product[RenamedFields], + TTLConfig.NONE) + } + + override def handleInputRows( + key: String, + rows: Iterator[String], + timerValues: TimerValues): Iterator[(String, String)] = { + val count = _countState.getOption().getOrElse(RenamedFields(0L, -1L)).value4 + 1 + if (count == 3) { + _countState.clear() + Iterator.empty + } else { + _countState.update(RenamedFields(count, -1L)) + Iterator((key, count.toString)) + } + } +} + + class RunningCountStatefulProcessorReorderedFields extends StatefulProcessor[String, String, (String, String)] with Logging { @transient protected var _countState: ValueState[ReorderedLongs] = _ @@ -975,6 +1007,48 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } + test("transformWithState - rename field") { + withSQLConf( + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + withTempDir { dir => + val inputData = MemoryStream[String] + + // First run with original field names + val result1 = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessorInitialOrder(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result1, OutputMode.Update())( + StartStream(checkpointLocation = dir.getCanonicalPath), + AddData(inputData, "test1"), + CheckNewAnswer(("test1", "1")), + StopStream + ) + + // Second run with renamed field + val result2 = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RenameEvolvedProcessor(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result2, OutputMode.Update())( + StartStream(checkpointLocation = dir.getCanonicalPath), + // Uses default value, does not factor previous value1 into this + AddData(inputData, "test1"), + CheckNewAnswer(("test1", "1")), + // Verify we can write state with new field name + AddData(inputData, "test2"), + CheckNewAnswer(("test2", "1")), + StopStream + ) + } + } + } + test("transformWithState - verify default values during schema evolution") { withSQLConf( SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName,