From 64383796aa5c253468ad085c36d35c903477f25c Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Fri, 5 Jan 2024 15:15:26 -0800 Subject: [PATCH] conforming to api change --- .../streaming/TransformWithStateSuite.scala | 45 +++++++++---------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 03649488f630c..f777d1daae019 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.streaming +import java.sql.Timestamp + import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, SaveMode} @@ -140,7 +142,7 @@ class RunningCountStatefulProcessorWithAddRemoveProcTimeTimer // Class to verify stateful processor usage with adding event time timers class RunningCountStatefulProcessorWithEventTimeTimer - extends StatefulProcessor[String, (String, java.sql.Timestamp), (String, String)] { + extends StatefulProcessor[String, (String, Timestamp), (String, String)] { @transient var _countState: ValueState[Long] = _ @transient var _processorHandle: StatefulProcessorHandle = _ @@ -155,19 +157,17 @@ class RunningCountStatefulProcessorWithEventTimeTimer _countState = _processorHandle.getValueState[Long]("countState") } - override def close(): Unit = {} - - override def handleInputRows( - key: String, - inputRows: Iterator[(String, java.sql.Timestamp)], - timerValues: TimerValues): Iterator[(String, String)] = { + override def handleInputRow( + key: String, + inputRow: (String, Timestamp), + timerValues: TimerValues): Iterator[(String, String)] = { val currCount = _countState.getOption().getOrElse(0L) if (currCount == 0 && (key == "a" || key == "c")) { _processorHandle.registerEventTimeTimer(timerValues.getCurrentWatermarkInMs() + 5000) } - val count = currCount + inputRows.size + val count = currCount + 1 if (count == 3) { _countState.remove() Iterator.empty @@ -177,6 +177,7 @@ class RunningCountStatefulProcessorWithEventTimeTimer } } + override def close(): Unit = {} override def handleEventTimeTimers( key: String, expiryTimestampMs: Long, @@ -198,12 +199,20 @@ class RunningCountStatefulProcessorWithAddRemoveEventTimeTimer _timerState = _processorHandle.getValueState[Long]("timerState") } - override def handleInputRows( - key: String, - inputRows: Iterator[(String, java.sql.Timestamp)], - timerValues: TimerValues): Iterator[(String, String)] = { + override def handleEventTimeTimers( + key: String, + expiryTimestampMs: Long, + timerValues: TimerValues): Iterator[(String, String)] = { + _timerState.remove() + Iterator((key, "-1")) + } + + override def handleInputRow( + key: String, + inputRow: (String, Timestamp), + timerValues: TimerValues): Iterator[(String, String)] = { val currCount = _countState.getOption().getOrElse(0L) - val count = currCount + inputRows.size + val count = currCount + 1 _countState.update(count) if (key == "a") { var nextTimerTs: Long = 0L @@ -220,14 +229,6 @@ class RunningCountStatefulProcessorWithAddRemoveEventTimeTimer } Iterator((key, count.toString)) } - - override def handleEventTimeTimers( - key: String, - expiryTimestampMs: Long, - timerValues: TimerValues): Iterator[(String, String)] = { - _timerState.remove() - Iterator((key, "-1")) - } } // Class to verify incorrect usage of stateful processor @@ -383,7 +384,6 @@ class TransformWithStateSuite extends StateStoreMetricsTest withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { - import java.sql.Timestamp val inputData = MemoryStream[(String, Timestamp)] val result = inputData.toDS() .select($"_1".as("value"), $"_2".as("eventTime")) @@ -422,7 +422,6 @@ class TransformWithStateSuite extends StateStoreMetricsTest classOf[RocksDBStateStoreProvider].getName) { val clock = new StreamManualClock - import java.sql.Timestamp val inputData = MemoryStream[(String, Timestamp)] val result = inputData.toDS() .select($"_1".as("value"), $"_2".as("eventTime"))