Commit

checking if is batch before applying rule
ericm-db committed Jan 24, 2024
1 parent e825bfa commit a1fc108
Showing 2 changed files with 25 additions and 27 deletions.
@@ -476,7 +476,7 @@ object QueryExecution {
       Nil
     } else {
       Seq(ReuseExchangeAndSubquery)
-    }) ++ () // add new rule here
+    }) ++ Seq(StateOpIdBatchRule)
   }

   /**
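The hunk above wires the new rule into QueryExecution's physical-preparation sequence: the parenthesized if/else contributes either Nil or Seq(ReuseExchangeAndSubquery), and appending Seq(StateOpIdBatchRule) with ++ places the batch rule at the tail, so it is applied after the reuse rules when the preparation rules run in order. The second hunk, below, defines the rule itself. As a rough, self-contained sketch of that "ordered sequence of plan rules" pattern (Plan, Node, Rule, and TagStateOps are toy stand-ins invented for illustration, not Spark classes):

// Toy sketch only: Plan, Node, and Rule are invented for this illustration, not
// Spark classes; only the "Seq of rules applied in order" shape mirrors the hunk above.
sealed trait Plan
case class Node(name: String, children: Seq[Plan] = Nil) extends Plan

trait Rule { def apply(plan: Plan): Plan }

// Stand-in for StateOpIdBatchRule: rewrites matching operators, leaves others untouched.
object TagStateOps extends Rule {
  def apply(plan: Plan): Plan = plan match {
    case Node("transformWithState", children) =>
      Node("transformWithState[stateInfo assigned]", children.map(apply))
    case Node(name, children) =>
      Node(name, children.map(apply))
  }
}

object PreparationsSketch extends App {
  // Appending at the end, as in `}) ++ Seq(StateOpIdBatchRule)`, puts the rule last.
  val rules: Seq[Rule] = Seq.empty[Rule] ++ Seq(TagStateOps)
  val plan: Plan = Node("project", Seq(Node("transformWithState", Seq(Node("scan")))))
  val prepared = rules.foldLeft(plan)((p, r) => r(p))
  println(prepared)
}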
@@ -16,44 +16,42 @@
  */
 package org.apache.spark.sql.execution

+import java.util.UUID
+import java.util.concurrent.atomic.AtomicInteger
+
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.streaming.{StatefulOperatorStateInfo, TransformWithStateExec}
-import org.apache.spark.sql.internal.SQLConf
-
-import java.util.concurrent.atomic.AtomicInteger

 // Create batch equivalent of StateOpIdRule for streaming queries

 object StateOpIdBatchRule extends Rule[SparkPlan] {

-  private[sql] val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key)
-    .map(SQLConf.SHUFFLE_PARTITIONS.valueConverter)
-    .getOrElse(sparkSession.sessionState.conf.numShufflePartitions)
-
   /**
    * Records the current id for a given stateful operator in the query plan as the `state`
    * preparation walks the query plan.
    */
   private val statefulOperatorId = new AtomicInteger(0)
-  private def nextStatefulOperationStateInfo(): StatefulOperatorStateInfo = {
-    StatefulOperatorStateInfo(
-      checkpointLocation,
-      runId,
-      statefulOperatorId.getAndIncrement(),
-      currentBatchId,
-      numStateStores)
-  }
   override def apply(plan: SparkPlan): SparkPlan = {
-    case t: TransformWithStateExec =>
-      // get numShufflePartitions from SQLConf
-
-      val numShufflePartitions = plan.session.sessionState.conf.numShufflePartitions
-
-      t.copy(
-        stateInfo = Some(nextStatefulOperationStateInfo()),
-        batchTimestampMs = t.batchTimestampMs,
-        eventTimeWatermarkForLateEvents = None,
-        eventTimeWatermarkForEviction = None
-      )
+    // fill out fields in TransformWithStateExec
+    if (plan.logicalLink.isDefined && !plan.logicalLink.get.isStreaming) {
+      logError(s"## IN BATCH RULE ##")
+      plan transform {
+        case t: TransformWithStateExec =>
+          t.copy(
+            stateInfo = Some(StatefulOperatorStateInfo(
+              "",
+              queryRunId = UUID.randomUUID(),
+              operatorId = statefulOperatorId.getAndIncrement(),
+              storeVersion = 0,
+              numPartitions = 0)
+            ),
+            batchTimestampMs = t.batchTimestampMs,
+            eventTimeWatermarkForLateEvents = None,
+            eventTimeWatermarkForEviction = None
+          )
+      }
+    } else {
+      plan
+    }
   }
 }
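The rule only rewrites when the plan's logical link exists and is not streaming (per the commit title, the state-info assignment applies to batch plans only), and it fills each TransformWithStateExec with a fresh StatefulOperatorStateInfo: an empty checkpoint location, a random query run id, the next value of the shared AtomicInteger as the operator id, and zeroed store version and partition count. A minimal, self-contained illustration of that id-assignment pattern follows; ToyStateInfo and BatchStateInfoSketch are invented stand-ins whose field names simply mirror the named arguments in the diff:

import java.util.UUID
import java.util.concurrent.atomic.AtomicInteger

// Invented stand-in for StatefulOperatorStateInfo; field names follow the diff's named arguments.
case class ToyStateInfo(
    checkpointLocation: String,
    queryRunId: UUID,
    operatorId: Long,
    storeVersion: Long,
    numPartitions: Int)

object BatchStateInfoSketch extends App {
  // Shared counter, as in the rule: one increment per stateful operator the rewrite visits.
  private val statefulOperatorId = new AtomicInteger(0)

  def nextInfo(): ToyStateInfo =
    ToyStateInfo(
      checkpointLocation = "",
      queryRunId = UUID.randomUUID(),
      operatorId = statefulOperatorId.getAndIncrement(),
      storeVersion = 0,
      numPartitions = 0)

  println(nextInfo()) // operatorId = 0
  println(nextInfo()) // operatorId = 1
}

Because the counter lives on the singleton rule object, operator ids keep incrementing across every plan the rule is applied to rather than resetting per query.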
