diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 132d8e8d63b34..d55242def0cfd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -893,7 +893,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { hasInitialState, planLater(initialState), planLater(child) ) :: Nil case logical.TransformWithState(keyDeserializer, valueDeserializer, groupingAttributes, - dataAttributes, statefulProcessor, timeoutMode, outputMode, outputObjAttr, child) => + dataAttributes, statefulProcessor, timeoutMode, outputMode, outputObjAttr, child) => TransformWithStateExec.generateSparkPlanForBatchQueries(keyDeserializer, valueDeserializer, groupingAttributes, dataAttributes, statefulProcessor, timeoutMode, outputMode, outputObjAttr, planLater(child)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 347a3a288e3bd..ae6455c0fc956 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -154,7 +154,6 @@ case class TransformWithStateExec( override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver - // populate stateInfo if this is a streaming query child.execute().mapPartitionsWithStateStore[InternalRow]( getStateInfo, schemaForKeyRow, @@ -181,15 +180,15 @@ object TransformWithStateExec { // Plan logical transformWithState for batch queries def generateSparkPlanForBatchQueries( - keyDeserializer: Expression, - valueDeserializer: Expression, - groupingAttributes: Seq[Attribute], - dataAttributes: Seq[Attribute], - statefulProcessor: StatefulProcessor[Any, Any, Any], - timeoutMode: TimeoutMode, - outputMode: OutputMode, - outputObjAttr: Attribute, - child: SparkPlan): SparkPlan = { + keyDeserializer: Expression, + valueDeserializer: Expression, + groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], + statefulProcessor: StatefulProcessor[Any, Any, Any], + timeoutMode: TimeoutMode, + outputMode: OutputMode, + outputObjAttr: Attribute, + child: SparkPlan): SparkPlan = { val shufflePartitions = child.session.sessionState.conf.numShufflePartitions val statefulOperatorStateInfo = StatefulOperatorStateInfo( Utils.createTempDir().getAbsolutePath,