From a1fc10892833891c2118b3ec4c496cc48c653cfd Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Wed, 24 Jan 2024 11:52:25 -0800 Subject: [PATCH] checking if plan is batch before applying rule --- .../spark/sql/execution/QueryExecution.scala | 2 +- .../sql/execution/StateOpIdBatchRule.scala | 50 +++++++++---------- 2 files changed, 25 insertions(+), 27 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 7218380e9a3b7..034ff29c3ce22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -476,7 +476,7 @@ object QueryExecution { Nil } else { Seq(ReuseExchangeAndSubquery) - }) ++ () // add new rule here + }) ++ Seq(StateOpIdBatchRule) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/StateOpIdBatchRule.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/StateOpIdBatchRule.scala index 9a19a85857d67..91d3006dc2085 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/StateOpIdBatchRule.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/StateOpIdBatchRule.scala @@ -16,44 +16,42 @@ */ package org.apache.spark.sql.execution +import java.util.UUID +import java.util.concurrent.atomic.AtomicInteger + import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.streaming.{StatefulOperatorStateInfo, TransformWithStateExec} -import org.apache.spark.sql.internal.SQLConf - -import java.util.concurrent.atomic.AtomicInteger // Create batch equivalent of StateOpIdRule for streaming queries object StateOpIdBatchRule extends Rule[SparkPlan] { - private[sql] val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key) .map(SQLConf.SHUFFLE_PARTITIONS.valueConverter) .getOrElse(sparkSession.sessionState.conf.numShufflePartitions) - - /** 
* Records the current id for a given stateful operator in the query plan as the `state` * preparation walks the query plan. */ private val statefulOperatorId = new AtomicInteger(0) - private def nextStatefulOperationStateInfo(): StatefulOperatorStateInfo = { - StatefulOperatorStateInfo( - checkpointLocation, - runId, - statefulOperatorId.getAndIncrement(), - currentBatchId, - numStateStores) - } override def apply(plan: SparkPlan): SparkPlan = { - case t: TransformWithStateExec => - // get numShufflePartitions from SQLConf - - val numShufflePartitions = plan.session.sessionState.conf.numShufflePartitions - - t.copy( - stateInfo = Some(nextStatefulOperationStateInfo()), - batchTimestampMs = t.batchTimestampMs, - eventTimeWatermarkForLateEvents = None, - eventTimeWatermarkForEviction = None - ) + // fill out fields in TransformWithStateExec + if (plan.logicalLink.isDefined && !plan.logicalLink.get.isStreaming) { + logError(s"## IN BATCH RULE ##") + plan transform { + case t: TransformWithStateExec => + t.copy( + stateInfo = Some(StatefulOperatorStateInfo( + "", + queryRunId = UUID.randomUUID(), + operatorId = statefulOperatorId.getAndIncrement(), + storeVersion = 0, + numPartitions = 0) + ), + batchTimestampMs = t.batchTimestampMs, + eventTimeWatermarkForLateEvents = None, + eventTimeWatermarkForEviction = None + ) + } + } else { + plan + } } }