Skip to content

Commit

Permalink
conforming to api change
Browse files Browse the repository at this point in the history
  • Loading branch information
ericm-db committed Jan 5, 2024
1 parent dc8380c commit 6438379
Showing 1 changed file with 22 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.streaming

import java.sql.Timestamp

import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, SaveMode}
Expand Down Expand Up @@ -140,7 +142,7 @@ class RunningCountStatefulProcessorWithAddRemoveProcTimeTimer

// Class to verify stateful processor usage with adding event time timers
class RunningCountStatefulProcessorWithEventTimeTimer
extends StatefulProcessor[String, (String, java.sql.Timestamp), (String, String)] {
extends StatefulProcessor[String, (String, Timestamp), (String, String)] {

@transient var _countState: ValueState[Long] = _
@transient var _processorHandle: StatefulProcessorHandle = _
Expand All @@ -155,19 +157,17 @@ class RunningCountStatefulProcessorWithEventTimeTimer
_countState = _processorHandle.getValueState[Long]("countState")
}

override def close(): Unit = {}

override def handleInputRows(
key: String,
inputRows: Iterator[(String, java.sql.Timestamp)],
timerValues: TimerValues): Iterator[(String, String)] = {
override def handleInputRow(
key: String,
inputRow: (String, Timestamp),
timerValues: TimerValues): Iterator[(String, String)] = {
val currCount = _countState.getOption().getOrElse(0L)
if (currCount == 0 && (key == "a" || key == "c")) {
_processorHandle.registerEventTimeTimer(timerValues.getCurrentWatermarkInMs()
+ 5000)
}

val count = currCount + inputRows.size
val count = currCount + 1
if (count == 3) {
_countState.remove()
Iterator.empty
Expand All @@ -177,6 +177,7 @@ class RunningCountStatefulProcessorWithEventTimeTimer
}
}

override def close(): Unit = {}
override def handleEventTimeTimers(
key: String,
expiryTimestampMs: Long,
Expand All @@ -198,12 +199,20 @@ class RunningCountStatefulProcessorWithAddRemoveEventTimeTimer
_timerState = _processorHandle.getValueState[Long]("timerState")
}

override def handleInputRows(
key: String,
inputRows: Iterator[(String, java.sql.Timestamp)],
timerValues: TimerValues): Iterator[(String, String)] = {
override def handleEventTimeTimers(
key: String,
expiryTimestampMs: Long,
timerValues: TimerValues): Iterator[(String, String)] = {
_timerState.remove()
Iterator((key, "-1"))
}

override def handleInputRow(
key: String,
inputRow: (String, Timestamp),
timerValues: TimerValues): Iterator[(String, String)] = {
val currCount = _countState.getOption().getOrElse(0L)
val count = currCount + inputRows.size
val count = currCount + 1
_countState.update(count)
if (key == "a") {
var nextTimerTs: Long = 0L
Expand All @@ -220,14 +229,6 @@ class RunningCountStatefulProcessorWithAddRemoveEventTimeTimer
}
Iterator((key, count.toString))
}

override def handleEventTimeTimers(
key: String,
expiryTimestampMs: Long,
timerValues: TimerValues): Iterator[(String, String)] = {
_timerState.remove()
Iterator((key, "-1"))
}
}

// Class to verify incorrect usage of stateful processor
Expand Down Expand Up @@ -383,7 +384,6 @@ class TransformWithStateSuite extends StateStoreMetricsTest
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName) {

import java.sql.Timestamp
val inputData = MemoryStream[(String, Timestamp)]
val result = inputData.toDS()
.select($"_1".as("value"), $"_2".as("eventTime"))
Expand Down Expand Up @@ -422,7 +422,6 @@ class TransformWithStateSuite extends StateStoreMetricsTest
classOf[RocksDBStateStoreProvider].getName) {
val clock = new StreamManualClock

import java.sql.Timestamp
val inputData = MemoryStream[(String, Timestamp)]
val result = inputData.toDS()
.select($"_1".as("value"), $"_2".as("eventTime"))
Expand Down

0 comments on commit 6438379

Please sign in to comment.