Commit e825bfa

init

ericm-db committed Jan 24, 2024
1 parent 1acd719 commit e825bfa
Showing 7 changed files with 126 additions and 20 deletions.
@@ -42,10 +42,6 @@ object UnsupportedOperationChecker extends Logging {
case d: DeduplicateWithinWatermark =>
throwError("dropDuplicatesWithinWatermark is not supported with batch " +
"DataFrames/DataSets")(d)

-      case t: TransformWithState =>
-        throwError("transformWithState is not supported with batch DataFrames/Datasets")(t)

case _ =>
}
}
@@ -476,7 +476,7 @@ object QueryExecution {
Nil
} else {
Seq(ReuseExchangeAndSubquery)
-    })
+    }) ++ Nil // add new rule here, e.g. StateOpIdBatchRule below (see the sketch after this section)
}

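A minimal sketch of what "add new rule here" might expand to, assuming the StateOpIdBatchRule introduced later in this commit is the rule meant to run during batch preparation; the commit itself leaves the hook empty, so this wiring is an assumption, not part of the change:

      // QueryExecution and StateOpIdBatchRule live in the same
      // org.apache.spark.sql.execution package, so the rule can be appended to
      // the preparation sequence directly.
      Seq(ReuseExchangeAndSubquery)
    }) ++ Seq(StateOpIdBatchRule) // assigns StatefulOperatorStateInfo to batch stateful operators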
@@ -38,6 +38,7 @@ import org.apache.spark.sql.internal.SQLConf
*/
object RemoveRedundantProjects extends Rule[SparkPlan] {
def apply(plan: SparkPlan): SparkPlan = {
    // TODO: try calling SparkPartialRule here, with a flag to enable/disable it
    // (see the sketch after this section)
if (!conf.getConf(SQLConf.REMOVE_REDUNDANT_PROJECTS_ENABLED)) {
plan
} else {
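The TODO above names a SparkPartialRule and an enable/disable flag, neither of which exists in this commit. A hedged sketch of the gating pattern it describes, with the rule body and the conf key invented purely for illustration:

import org.apache.spark.sql.catalyst.rules.Rule

// Hypothetical stand-in for the "SparkPartialRule" the TODO mentions.
object SparkPartialRule extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = plan // no-op placeholder
}

// Inside RemoveRedundantProjects.apply, the flag-gated call could then read:
//   val planAfterPartialRule =
//     if (conf.getConfString("spark.sql.execution.sparkPartialRule.enabled", "false").toBoolean) {
//       SparkPartialRule(plan)
//     } else {
//       plan
//     }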
@@ -891,6 +891,12 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
initialStateGroupAttrs, data, initialStateDataAttrs, output, timeout,
hasInitialState, planLater(initialState), planLater(child)
) :: Nil
case logical.TransformWithState(keyDeserializer, valueDeserializer, groupingAttributes,
dataAttributes, statefulProcessor, timeoutMode, outputMode, outputObjAttr, child) =>
TransformWithStateExec.generateSparkPlanForBatchQueries(keyDeserializer, valueDeserializer,
groupingAttributes, dataAttributes, statefulProcessor, timeoutMode, outputMode,
outputObjAttr, None, planLater(child)) :: Nil

case _: FlatMapGroupsInPandasWithState =>
// TODO(SPARK-40443): support applyInPandasWithState in batch query
throw new UnsupportedOperationException(
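With this strategy case in place, a batch transformWithState query should plan to TransformWithStateExec instead of failing analysis. A quick sanity check, sketched here using the RunningCountStatefulProcessor from the test suite further down (collectFirst is the standard Catalyst TreeNode API):

val ds = Seq("a", "a", "b").toDS()
  .groupByKey(x => x)
  .transformWithState(new RunningCountStatefulProcessor(),
    TimeoutMode.NoTimeouts(),
    OutputMode.Append())

// The physical plan should now contain the new batch operator.
assert(ds.queryExecution.executedPlan
  .collectFirst { case t: TransformWithStateExec => t }.nonEmpty)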
@@ -0,0 +1,59 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution

import java.util.UUID
import java.util.concurrent.atomic.AtomicInteger

import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.streaming.{StatefulOperatorStateInfo, TransformWithStateExec}
import org.apache.spark.util.Utils

// Batch equivalent of the StateOpIdRule used for streaming queries: assigns a
// StatefulOperatorStateInfo to every TransformWithStateExec node in a batch
// physical plan.
object StateOpIdBatchRule extends Rule[SparkPlan] {

  /**
   * Records the current id for a given stateful operator in the query plan as
   * the preparation rule walks the query plan.
   */
  private val statefulOperatorId = new AtomicInteger(0)

  private def nextStatefulOperationStateInfo(
      checkpointLocation: String,
      numStateStores: Int): StatefulOperatorStateInfo = {
    StatefulOperatorStateInfo(
      checkpointLocation,
      UUID.randomUUID(), // batch queries have no streaming run id
      statefulOperatorId.getAndIncrement(),
      0, // store version: a batch query always starts from an empty store
      numStateStores)
  }

  override def apply(plan: SparkPlan): SparkPlan = plan transform {
    case t: TransformWithStateExec =>
      // Batch queries have no offset log to read SHUFFLE_PARTITIONS from, so
      // the number of state stores comes straight from the session's SQLConf.
      val numShufflePartitions = plan.session.sessionState.conf.numShufflePartitions
      // No checkpoint location exists for a batch query; a temporary directory
      // stands in for it.
      val checkpointLocation = Utils.createTempDir().getAbsolutePath
      t.copy(
        stateInfo = Some(nextStatefulOperationStateInfo(checkpointLocation, numShufflePartitions)),
        batchTimestampMs = t.batchTimestampMs,
        eventTimeWatermarkForLateEvents = None, // watermarks do not apply in batch
        eventTimeWatermarkForEviction = None
      )
  }
}
@@ -152,6 +152,8 @@ case class TransformWithStateExec(
override protected def doExecute(): RDD[InternalRow] = {
metrics // force lazy init at driver

    // getStateInfo is populated for streaming queries during state preparation
    // and, for batch queries, by StateOpIdBatchRule; resolve it once and reuse it.
    val stateInfo = getStateInfo
    child.execute().mapPartitionsWithStateStore[InternalRow](
      stateInfo,
schemaForKeyRow,
@@ -171,3 +173,36 @@
}
}
}

object TransformWithStateExec {

  // Plans the logical transformWithState operator for batch queries. stateInfo
  // is None at planning time; the batch preparation rule (StateOpIdBatchRule)
  // fills it in before execution.
  def generateSparkPlanForBatchQueries(
      keyDeserializer: Expression,
      valueDeserializer: Expression,
      groupingAttributes: Seq[Attribute],
      dataAttributes: Seq[Attribute],
      statefulProcessor: StatefulProcessor[Any, Any, Any],
      timeoutMode: TimeoutMode,
      outputMode: OutputMode,
      outputObjAttr: Attribute,
      stateInfo: Option[StatefulOperatorStateInfo],
      child: SparkPlan): SparkPlan = {
    // TransformWithStateExec is a case class, so no `new` is needed.
    TransformWithStateExec(
      keyDeserializer,
      valueDeserializer,
      groupingAttributes,
      dataAttributes,
      statefulProcessor,
      timeoutMode,
      outputMode,
      outputObjAttr,
      stateInfo,
      Some(System.currentTimeMillis), // batch timestamp: driver wall-clock time
      None, // no event-time watermark for late events in batch
      None, // no event-time watermark for eviction in batch
      child)
  }
}
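Read together with StateOpIdBatchRule above, the intended flow appears to be two-phase (an inference from the pieces of this commit, not something it states): the strategy plans the operator with stateInfo = None, and the batch preparation rule fills it in before doExecute runs. Roughly, with the planner's local values elided:

// Phase 1, in SparkStrategies: plan the batch operator without state info.
val planned = TransformWithStateExec.generateSparkPlanForBatchQueries(
  keyDeserializer, valueDeserializer, groupingAttributes, dataAttributes,
  statefulProcessor, timeoutMode, outputMode, outputObjAttr,
  None, planLater(child))

// Phase 2, during query preparation: StateOpIdBatchRule assigns the checkpoint
// directory, run id, operator id and partition count.
val prepared = StateOpIdBatchRule(planned)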
@@ -19,7 +19,7 @@ package org.apache.spark.sql.streaming

import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{AnalysisException, SaveMode}
+import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider}
import org.apache.spark.sql.internal.SQLConf
@@ -132,25 +132,34 @@ class TransformWithStateSuite extends StateStoreMetricsTest
)
}
}


test("transformWithState - batch should succeed") {
val inputData = Seq("a", "a", "b")
val result = inputData.toDS()
.groupByKey(x => x)
.transformWithState(new RunningCountStatefulProcessor(),
TimeoutMode.NoTimeouts(),
OutputMode.Append())

val df = result.toDF()
checkAnswer(df, Seq(("a", "1"), ("b", "1")).toDF())
}
}

class TransformWithStateValidationSuite extends StateStoreMetricsTest {
import testImplicits._

test("transformWithState - batch should fail") {
val ex = intercept[Exception] {
val df = Seq("a", "a", "b").toDS()
.groupByKey(x => x)
.transformWithState(new RunningCountStatefulProcessor,
TimeoutMode.NoTimeouts(),
OutputMode.Append())
.write
.format("noop")
.mode(SaveMode.Append)
.save()
}
assert(ex.isInstanceOf[AnalysisException])
assert(ex.getMessage.contains("not supported"))
test("transformWithState - batch should not fail") {
val _ = Seq("a", "a", "b").toDS()
.groupByKey(x => x)
.transformWithState(new RunningCountStatefulProcessor,
TimeoutMode.NoTimeouts(),
OutputMode.Append())
.write
.format("noop")
.mode(SaveMode.Append)
.save()
}

test("transformWithState - streaming with hdfsStateStoreProvider should fail") {
