[SPARK-46865][SS] Add Batch Support for TransformWithState Operator
ericm-db committed Jan 25, 2024
1 parent 617014c commit 3d56a4a
Showing 5 changed files with 87 additions and 36 deletions.
@@ -43,9 +43,6 @@ object UnsupportedOperationChecker extends Logging {
throwError("dropDuplicatesWithinWatermark is not supported with batch " +
"DataFrames/DataSets")(d)

case t: TransformWithState =>
throwError("transformWithState is not supported with batch DataFrames/Datasets")(t)

case _ =>
}
}
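With the guard above removed, transformWithState is no longer rejected by UnsupportedOperationChecker for batch DataFrames/Datasets. A minimal usage sketch, mirroring the new test added at the bottom of this commit; it assumes a SparkSession with its implicits in scope and reuses RunningCountStatefulProcessor, the helper processor already defined in the test suite:

  import org.apache.spark.sql.streaming.{OutputMode, TimeoutMode}

  // Batch (non-streaming) Dataset; before this commit the same call failed analysis with
  // "transformWithState is not supported with batch DataFrames/Datasets".
  val inputData = Seq("a", "a", "b")
  val counts = inputData.toDS()
    .groupByKey(x => x)
    .transformWithState(new RunningCountStatefulProcessor(),
      TimeoutMode.NoTimeouts(),
      OutputMode.Append())

  // Per the new test below, the batch result is ("a", "1") and ("b", "1").
  counts.toDF().show()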
@@ -742,6 +742,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
batchTimestampMs = None,
eventTimeWatermarkForLateEvents = None,
eventTimeWatermarkForEviction = None,
isStreaming = true,
planLater(child))
execPlan :: Nil
case _ =>
@@ -891,6 +892,12 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
initialStateGroupAttrs, data, initialStateDataAttrs, output, timeout,
hasInitialState, planLater(initialState), planLater(child)
) :: Nil
case logical.TransformWithState(keyDeserializer, valueDeserializer, groupingAttributes,
dataAttributes, statefulProcessor, timeoutMode, outputMode, outputObjAttr, child) =>
TransformWithStateExec.generateSparkPlanForBatchQueries(keyDeserializer, valueDeserializer,
groupingAttributes, dataAttributes, statefulProcessor, timeoutMode, outputMode,
outputObjAttr, planLater(child)) :: Nil

case _: FlatMapGroupsInPandasWithState =>
// TODO(SPARK-40443): support applyInPandasWithState in batch query
throw new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_3176")
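The new case above is what routes a batch query's logical TransformWithState node to TransformWithStateExec.generateSparkPlanForBatchQueries. A quick, hedged way to confirm the batch planning path is taken (assumes the same session and RunningCountStatefulProcessor as before, and that the operator class is visible from your code; otherwise explain() shows the same operator name):

  import org.apache.spark.sql.execution.streaming.TransformWithStateExec

  val planned = Seq("a", "b").toDS()
    .groupByKey(x => x)
    .transformWithState(new RunningCountStatefulProcessor(),
      TimeoutMode.NoTimeouts(),
      OutputMode.Append())
    .toDF()
    .queryExecution
    .sparkPlan

  // The batch physical plan should now contain the stateful operator instead of failing analysis.
  assert(planned.collect { case t: TransformWithStateExec => t }.nonEmpty)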
@@ -68,28 +68,34 @@ class QueryInfoImpl(
* track of valid transitions as various functions are invoked to track object lifecycle.
* @param store - instance of state store
*/
class StatefulProcessorHandleImpl(store: StateStore, runId: UUID)
class StatefulProcessorHandleImpl(store: StateStore, runId: UUID, isStreaming: Boolean = true)
extends StatefulProcessorHandle with Logging {
import StatefulProcessorHandleState._

private def buildQueryInfo(): QueryInfo = {
val taskCtxOpt = Option(TaskContext.get())
// Task context is not available in tests, so we generate a random query id and batch id here
val queryId = if (taskCtxOpt.isDefined) {
taskCtxOpt.get.getLocalProperty(StreamExecution.QUERY_ID_KEY)
} else {
assert(Utils.isTesting, "Failed to find query id in task context")
UUID.randomUUID().toString
}

val batchId = if (taskCtxOpt.isDefined) {
taskCtxOpt.get.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY).toLong
if (!isStreaming) {
val queryId = "00000000-0000-0000-0000-000000000000"
val batchId = 0L
new QueryInfoImpl(UUID.fromString(queryId), runId, batchId)
} else {
assert(Utils.isTesting, "Failed to find batch id in task context")
0
val taskCtxOpt = Option(TaskContext.get())
// Task context is not available in tests, so we generate a random query id and batch id here
val queryId = if (taskCtxOpt.isDefined) {
taskCtxOpt.get.getLocalProperty(StreamExecution.QUERY_ID_KEY)
} else {
assert(Utils.isTesting, "Failed to find query id in task context")
UUID.randomUUID().toString
}

val batchId = if (taskCtxOpt.isDefined) {
taskCtxOpt.get.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY).toLong
} else {
assert(Utils.isTesting, "Failed to find batch id in task context")
0
}
new QueryInfoImpl(UUID.fromString(queryId), runId, batchId)
}

new QueryInfoImpl(UUID.fromString(queryId), runId, batchId)
}

private lazy val currQueryInfo: QueryInfo = buildQueryInfo()
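In short, a batch run gets a fixed all-zero query id and batch id 0, while a streaming run keeps reading both from the task-local properties set by the streaming engine. A stripped-down sketch of that decision in plain Scala (not Spark code; the names and the property parameters are illustrative):

  import java.util.UUID

  final case class QueryInfoSketch(queryId: UUID, runId: UUID, batchId: Long)

  def buildQueryInfoSketch(
      isStreaming: Boolean,
      runId: UUID,
      queryIdProperty: Option[String],   // stands in for StreamExecution.QUERY_ID_KEY lookup
      batchIdProperty: Option[String]    // stands in for MicroBatchExecution.BATCH_ID_KEY lookup
    ): QueryInfoSketch = {
    if (!isStreaming) {
      // Batch: there is no micro-batch, so sentinel values are used.
      QueryInfoSketch(UUID.fromString("00000000-0000-0000-0000-000000000000"), runId, 0L)
    } else {
      // Streaming: ids come from the task context; fall back only when they are absent (tests).
      val queryId = queryIdProperty.getOrElse(UUID.randomUUID().toString)
      val batchId = batchIdProperty.map(_.toLong).getOrElse(0L)
      QueryInfoSketch(UUID.fromString(queryId), runId, batchId)
    }
  }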
@@ -16,6 +16,7 @@
*/
package org.apache.spark.sql.execution.streaming

import java.util.UUID
import java.util.concurrent.TimeUnit.NANOSECONDS

import org.apache.spark.rdd.RDD
@@ -26,7 +27,7 @@ import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessor, TimeoutMode}
import org.apache.spark.sql.types._
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.{CompletionIterator, Utils}

/**
* Physical operator for executing `TransformWithState`
@@ -57,6 +58,7 @@ case class TransformWithStateExec(
batchTimestampMs: Option[Long],
eventTimeWatermarkForLateEvents: Option[Long],
eventTimeWatermarkForEviction: Option[Long],
isStreaming: Boolean = true,
child: SparkPlan)
extends UnaryExecNode with StateStoreWriter with WatermarkSupport with ObjectProducerExec {

@@ -152,6 +154,7 @@ case class TransformWithStateExec(
override protected def doExecute(): RDD[InternalRow] = {
metrics // force lazy init at driver

// populate stateInfo if this is a streaming query
child.execute().mapPartitionsWithStateStore[InternalRow](
getStateInfo,
schemaForKeyRow,
@@ -162,7 +165,8 @@ case class TransformWithStateExec(
useColumnFamilies = true
) {
case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId)
val processorHandle = new StatefulProcessorHandleImpl(
store, getStateInfo.queryRunId, isStreaming)
assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED)
statefulProcessor.init(processorHandle, outputMode)
processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED)
@@ -171,3 +175,44 @@ case class TransformWithStateExec(
}
}
}


object TransformWithStateExec {

// Plan logical transformWithState for batch queries
def generateSparkPlanForBatchQueries(
keyDeserializer: Expression,
valueDeserializer: Expression,
groupingAttributes: Seq[Attribute],
dataAttributes: Seq[Attribute],
statefulProcessor: StatefulProcessor[Any, Any, Any],
timeoutMode: TimeoutMode,
outputMode: OutputMode,
outputObjAttr: Attribute,
child: SparkPlan): SparkPlan = {
val shufflePartitions = child.session.sessionState.conf.numShufflePartitions
val statefulOperatorStateInfo = StatefulOperatorStateInfo(
Utils.createTempDir().getAbsolutePath,
queryRunId = UUID.randomUUID(),
operatorId = 0,
storeVersion = 0,
numPartitions = shufflePartitions
)

new TransformWithStateExec(
keyDeserializer,
valueDeserializer,
groupingAttributes,
dataAttributes,
statefulProcessor,
timeoutMode,
outputMode,
outputObjAttr,
Some(statefulOperatorStateInfo),
Some(System.currentTimeMillis),
None,
None,
isStreaming = false,
child)
}
}
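For batch queries there is no checkpoint to resume from, so the helper above fabricates a one-off StatefulOperatorStateInfo: a throwaway temp directory as the state location, a fresh run id, store version 0, and as many state partitions as spark.sql.shuffle.partitions. The practical consequence is that every batch run starts from empty state and discards it afterwards. A plain-Scala sketch of those values and what they imply (illustrative names, not Spark code):

  import java.nio.file.Files
  import java.util.UUID

  final case class EphemeralStateInfo(
      checkpointLocation: String,  // throwaway temp dir, never reused across runs
      queryRunId: UUID,            // fresh per run, so there is nothing to resume
      operatorId: Long,
      storeVersion: Long,          // 0 => the run always starts from empty state
      numPartitions: Int)          // follows spark.sql.shuffle.partitions

  def ephemeralStateInfo(numShufflePartitions: Int): EphemeralStateInfo =
    EphemeralStateInfo(
      checkpointLocation = Files.createTempDirectory("tws-batch-state-").toString,
      queryRunId = UUID.randomUUID(),
      operatorId = 0L,
      storeVersion = 0L,
      numPartitions = numShufflePartitions)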
@@ -131,27 +131,23 @@ class TransformWithStateSuite extends StateStoreMetricsTest
)
}
}

test("transformWithState - batch should succeed") {
val inputData = Seq("a", "a", "b")
val result = inputData.toDS()
.groupByKey(x => x)
.transformWithState(new RunningCountStatefulProcessor(),
TimeoutMode.NoTimeouts(),
OutputMode.Append())

val df = result.toDF()
checkAnswer(df, Seq(("a", "1"), ("b", "1")).toDF())
}
}

class TransformWithStateValidationSuite extends StateStoreMetricsTest {
import testImplicits._

test("transformWithState - batch should fail") {
val ex = intercept[Exception] {
val df = Seq("a", "a", "b").toDS()
.groupByKey(x => x)
.transformWithState(new RunningCountStatefulProcessor,
TimeoutMode.NoTimeouts(),
OutputMode.Append())
.write
.format("noop")
.mode(SaveMode.Append)
.save()
}
assert(ex.isInstanceOf[AnalysisException])
assert(ex.getMessage.contains("not supported"))
}

test("transformWithState - streaming with hdfsStateStoreProvider should fail") {
val inputData = MemoryStream[String]
val result = inputData.toDS()
