diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index 15a856b273edf..32aa55301a17f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -42,10 +42,6 @@ object UnsupportedOperationChecker extends Logging {
       case d: DeduplicateWithinWatermark =>
         throwError("dropDuplicatesWithinWatermark is not supported with batch " +
           "DataFrames/DataSets")(d)
-
-      case t: TransformWithState =>
-        throwError("transformWithState is not supported with batch DataFrames/Datasets")(t)
-
       case _ =>
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index a2cfad800e006..7218380e9a3b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -476,7 +476,7 @@ object QueryExecution {
         Nil
       } else {
         Seq(ReuseExchangeAndSubquery)
-      })
+      }) ++ Seq(StateOpIdBatchRule) // assign state info to batch stateful operators
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/RemoveRedundantProjects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/RemoveRedundantProjects.scala
index 8f4ce0f49a89a..eace04f127598 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/RemoveRedundantProjects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/RemoveRedundantProjects.scala
@@ -38,6 +38,7 @@ import org.apache.spark.sql.internal.SQLConf
  */
 object RemoveRedundantProjects extends Rule[SparkPlan] {
   def apply(plan: SparkPlan): SparkPlan = {
+    // TODO: try calling SparkPartialRule here; add a flag to enable/disable it
     if (!conf.getConf(SQLConf.REMOVE_REDUNDANT_PROJECTS_ENABLED)) {
       plan
     } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index f0ea40c104f52..881e205828f7a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -891,6 +891,12 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           initialStateGroupAttrs, data, initialStateDataAttrs, output, timeout,
           hasInitialState, planLater(initialState), planLater(child)
         ) :: Nil
+      case logical.TransformWithState(keyDeserializer, valueDeserializer, groupingAttributes,
+        dataAttributes, statefulProcessor, timeoutMode, outputMode, outputObjAttr, child) =>
+        TransformWithStateExec.generateSparkPlanForBatchQueries(keyDeserializer, valueDeserializer,
+          groupingAttributes, dataAttributes, statefulProcessor, timeoutMode, outputMode,
+          outputObjAttr, None, planLater(child)) :: Nil
+
       case _: FlatMapGroupsInPandasWithState =>
         // TODO(SPARK-40443): support applyInPandasWithState in batch query
         throw new UnsupportedOperationException(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/StateOpIdBatchRule.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/StateOpIdBatchRule.scala
new file mode 100644
index 0000000000000..9a19a85857d67
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/StateOpIdBatchRule.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.util.UUID
+import java.util.concurrent.atomic.AtomicInteger
+
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.streaming.{StatefulOperatorStateInfo, TransformWithStateExec}
+import org.apache.spark.util.Utils
+
+/**
+ * Batch equivalent of the state info assignment that IncrementalExecution performs for
+ * streaming queries: assigns a placeholder [[StatefulOperatorStateInfo]] to
+ * [[TransformWithStateExec]] operators that are planned as part of a batch query.
+ */
+object StateOpIdBatchRule extends Rule[SparkPlan] {
+
+  /**
+   * Records the current id for a given stateful operator in the query plan as the rule
+   * walks the query plan.
+   */
+  private val statefulOperatorId = new AtomicInteger(0)
+
+  private def nextStatefulOperationStateInfo(numStateStores: Int): StatefulOperatorStateInfo = {
+    StatefulOperatorStateInfo(
+      // Batch queries have no checkpoint location, run id or batch id, so use placeholders.
+      checkpointLocation = Utils.createTempDir().getCanonicalPath,
+      queryRunId = UUID.randomUUID(),
+      operatorId = statefulOperatorId.getAndIncrement(),
+      storeVersion = 0,
+      numPartitions = numStateStores)
+  }
+
+  override def apply(plan: SparkPlan): SparkPlan = plan transform {
+    case t: TransformWithStateExec if t.stateInfo.isEmpty =>
+      // Get the number of state store partitions from the session's shuffle partitions config.
+      val numStateStores = t.session.sessionState.conf.numShufflePartitions
+      t.copy(
+        stateInfo = Some(nextStatefulOperationStateInfo(numStateStores)),
+        eventTimeWatermarkForLateEvents = None,
+        eventTimeWatermarkForEviction = None)
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
index ce651d959afc6..aeabf3ad7d969 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
@@ -152,6 +152,8 @@ case class TransformWithStateExec(
   override protected def doExecute(): RDD[InternalRow] = {
     metrics // force lazy init at driver
 
+    // resolve the operator's state info once on the driver; getStateInfo throws if it is missing
+    val stateInfo = getStateInfo
     child.execute().mapPartitionsWithStateStore[InternalRow](
-      getStateInfo,
+      stateInfo,
       schemaForKeyRow,
@@ -171,3 +173,36 @@ case class TransformWithStateExec(
     }
   }
 }
+
+object TransformWithStateExec {
+
+  // Plans the logical TransformWithState operator for batch queries.
+  def generateSparkPlanForBatchQueries(
+      keyDeserializer: Expression,
+      valueDeserializer: Expression,
+      groupingAttributes: Seq[Attribute],
+      dataAttributes: Seq[Attribute],
+      statefulProcessor: StatefulProcessor[Any, Any, Any],
+      timeoutMode: TimeoutMode,
+      outputMode: OutputMode,
+      outputObjAttr: Attribute,
+      stateInfo: Option[StatefulOperatorStateInfo],
+      child: SparkPlan): SparkPlan = {
+
+    new TransformWithStateExec(
+      keyDeserializer,
+      valueDeserializer,
+      groupingAttributes,
+      dataAttributes,
+      statefulProcessor,
+      timeoutMode,
+      outputMode,
+      outputObjAttr,
+      stateInfo,
+      Some(System.currentTimeMillis), // batchTimestampMs
+      None, // eventTimeWatermarkForLateEvents
+      None, // eventTimeWatermarkForEviction
+      child)
+  }
+
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
index b85353c5d676f..1561a7632080b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.streaming
 
 import org.apache.spark.SparkException
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{AnalysisException, SaveMode}
+import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider}
 import org.apache.spark.sql.internal.SQLConf
@@ -132,25 +132,33 @@ class TransformWithStateSuite extends StateStoreMetricsTest
       )
     }
   }
+
+  test("transformWithState - batch should succeed") {
+    val inputData = Seq("a", "a", "b")
+    val result = inputData.toDS()
+      .groupByKey(x => x)
+      .transformWithState(new RunningCountStatefulProcessor(),
+        TimeoutMode.NoTimeouts(),
+        OutputMode.Append())
+
+    val df = result.toDF()
+    checkAnswer(df, Seq(("a", "1"), ("b", "1")).toDF())
+  }
 }
 
 class TransformWithStateValidationSuite extends StateStoreMetricsTest {
   import testImplicits._
 
-  test("transformWithState - batch should fail") {
-    val ex = intercept[Exception] {
-      val df = Seq("a", "a", "b").toDS()
-        .groupByKey(x => x)
-        .transformWithState(new RunningCountStatefulProcessor,
-          TimeoutMode.NoTimeouts(),
-          OutputMode.Append())
-        .write
-        .format("noop")
-        .mode(SaveMode.Append)
-        .save()
-    }
-    assert(ex.isInstanceOf[AnalysisException])
-    assert(ex.getMessage.contains("not supported"))
+  test("transformWithState - batch should not fail") {
+    Seq("a", "a", "b").toDS()
+      .groupByKey(x => x)
+      .transformWithState(new RunningCountStatefulProcessor,
+        TimeoutMode.NoTimeouts(),
+        OutputMode.Append())
+      .write
+      .format("noop")
+      .mode(SaveMode.Append)
+      .save()
   }
 
   test("transformWithState - streaming with hdfsStateStoreProvider should fail") {
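
Reviewer note, not part of the diff: once StateOpIdBatchRule is wired into QueryExecution.preparations as above, a plan-level assertion along the following lines could be added to TransformWithStateSuite to confirm that the batch TransformWithStateExec node actually receives a populated stateInfo. This is only a sketch; it assumes adaptive query execution is disabled so the stateful node is visible at the top of the executed plan, and it reuses the RunningCountStatefulProcessor helper already present in the suite.

  // Sketch only: verify that the batch planner plus StateOpIdBatchRule attach state info.
  // Assumes spark.sql.adaptive.enabled=false so TransformWithStateExec appears directly
  // in the executed plan.
  val df = Seq("a", "a", "b").toDS()
    .groupByKey(x => x)
    .transformWithState(new RunningCountStatefulProcessor(),
      TimeoutMode.NoTimeouts(),
      OutputMode.Append())
    .toDF()

  val statefulNodes = df.queryExecution.executedPlan.collect {
    case t: TransformWithStateExec => t
  }
  assert(statefulNodes.nonEmpty, "expected a TransformWithStateExec node in the batch plan")
  assert(statefulNodes.forall(_.stateInfo.isDefined), "expected stateInfo to be populated")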