Skip to content

Commit

Permalink
reordering test case
Browse files Browse the repository at this point in the history
  • Loading branch information
ericm-db committed Jan 3, 2025
1 parent fc56474 commit 47f624d
Showing 1 changed file with 106 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,70 @@ object TransformWithStateSuiteUtils {
val NUM_SHUFFLE_PARTITIONS = 5
}

case class TwoLongs(
value1: Long,
value2: Long
)


case class ReorderedLongs(
value2: Long,
value1: Long
)

class RunningCountStatefulProcessorInitialOrder
extends StatefulProcessor[String, String, (String, String)] with Logging {
@transient protected var _countState: ValueState[TwoLongs] = _

override def init(
outputMode: OutputMode,
timeMode: TimeMode): Unit = {
_countState = getHandle.getValueState[TwoLongs]("countState",
Encoders.product[TwoLongs], TTLConfig.NONE)
}

override def handleInputRows(
key: String,
inputRows: Iterator[String],
timerValues: TimerValues): Iterator[(String, String)] = {
val count = _countState.getOption().getOrElse(TwoLongs(0L, -1L)).value1 + 1
if (count == 3) {
_countState.clear()
Iterator.empty
} else {
_countState.update(TwoLongs(count, -1L))
Iterator((key, count.toString))
}
}
}

class RunningCountStatefulProcessorReorderedFields
extends StatefulProcessor[String, String, (String, String)] with Logging {
@transient protected var _countState: ValueState[ReorderedLongs] = _

override def init(
outputMode: OutputMode,
timeMode: TimeMode): Unit = {
_countState = getHandle.getValueState[ReorderedLongs]("countState",
Encoders.product[ReorderedLongs], TTLConfig.NONE)
}

override def handleInputRows(
key: String,
inputRows: Iterator[String],
timerValues: TimerValues): Iterator[(String, String)] = {
val count = _countState.getOption().getOrElse(ReorderedLongs(-1L, 0L)).value1 + 1
if (count == 3) {
_countState.clear()
Iterator.empty
} else {
// And update value1 here
_countState.update(ReorderedLongs(-1L, count))
Iterator((key, count.toString))
}
}
}

class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (String, String)]
with Logging {
import implicits._
Expand All @@ -69,11 +133,6 @@ class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (S
}
}

case class TwoLongs(
value: Long,
value2: Long
)

case class NestedLongs(
value: Long,
value2: TwoLongs
Expand All @@ -94,7 +153,7 @@ class RunningCountStatefulProcessorTwoLongs
key: String,
inputRows: Iterator[String],
timerValues: TimerValues): Iterator[(String, String)] = {
val count = _countState.getOption().getOrElse(TwoLongs(0L, 0L)).value + 1
val count = _countState.getOption().getOrElse(TwoLongs(0L, 0L)).value1 + 1
if (count == 3) {
_countState.clear()
Iterator.empty
Expand Down Expand Up @@ -750,6 +809,47 @@ class TransformWithStateSuite extends StateStoreMetricsTest
}
}

test("transformWithState - reordering fields should succeed") {
withSQLConf(
SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName,
SQLConf.SHUFFLE_PARTITIONS.key ->
TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
withTempDir { chkptDir =>
val dirPath = chkptDir.getCanonicalPath
val inputData = MemoryStream[String]

// First run with initial field order
val result1 = inputData.toDS()
.groupByKey(x => x)
.transformWithState(new RunningCountStatefulProcessorInitialOrder(),
TimeMode.None(),
OutputMode.Update())

testStream(result1, OutputMode.Update())(
StartStream(checkpointLocation = dirPath),
AddData(inputData, "a"),
CheckNewAnswer(("a", "1")),
StopStream
)

// Second run with reordered fields
val result2 = inputData.toDS()
.groupByKey(x => x)
.transformWithState(new RunningCountStatefulProcessorReorderedFields(),
TimeMode.None(),
OutputMode.Update())

testStream(result2, OutputMode.Update())(
StartStream(checkpointLocation = dirPath),
AddData(inputData, "a"),
CheckNewAnswer(("a", "2")), // Should continue counting from previous state
StopStream
)
}
}
}

test("transformWithState - adding field should succeed") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName,
Expand Down

0 comments on commit 47f624d

Please sign in to comment.