Commit a54eb81

feedback

ericm-db committed Jan 25, 2024
1 parent 579884f
Showing 5 changed files with 7 additions and 23 deletions.
@@ -38,7 +38,6 @@ import org.apache.spark.sql.internal.SQLConf
  */
 object RemoveRedundantProjects extends Rule[SparkPlan] {
   def apply(plan: SparkPlan): SparkPlan = {
-    // try calling SparkPartialRule here, add flag to enable/disable
     if (!conf.getConf(SQLConf.REMOVE_REDUNDANT_PROJECTS_ENABLED)) {
       plan
     } else {
@@ -74,15 +74,16 @@ class QueryInfoImpl(
  * track of valid transitions as various functions are invoked to track object lifecycle.
  * @param store - instance of state store
  */
-class StatefulProcessorHandleImpl(store: StateStore, runId: UUID, isStreaming: Boolean)
+class StatefulProcessorHandleImpl(store: StateStore, runId: UUID, isStreaming: Boolean = true)
   extends StatefulProcessorHandle with Logging {
   import StatefulProcessorHandleState._

   private def buildQueryInfo(): QueryInfo = {
     val taskCtxOpt = Option(TaskContext.get())
     // Task context is not available in tests, so we generate a random query id and batch id here
+    // For batch queries, we populate the queryId manually
     val queryId = if (!isStreaming) {
-      UUID.randomUUID().toString
+      UUID.fromString("00000000-0000-0000-0000-000000000000")
     } else if (taskCtxOpt.isDefined) {
       taskCtxOpt.get.getLocalProperty(StreamExecution.QUERY_ID_KEY)
     } else {
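Taken together, this hunk makes isStreaming default to true and pins batch queries to a fixed all-zeros query id instead of a random one, so a batch run gets a deterministic id while streaming runs keep reading theirs from the task-local properties. A minimal, self-contained sketch of that selection logic (QueryIdSketch, resolveQueryId, and the Option-based property argument are illustrative stand-ins, not the real helpers):

import java.util.UUID

object QueryIdSketch {
  // Fixed sentinel id for batch queries, mirroring the constant introduced above.
  private val batchQueryId = UUID.fromString("00000000-0000-0000-0000-000000000000")

  // Hypothetical stand-in for buildQueryInfo's id selection: batch gets the
  // sentinel; streaming reads the query id from task-local properties and
  // falls back to a random id when no TaskContext is available (e.g. tests).
  def resolveQueryId(isStreaming: Boolean, propertyQueryId: Option[String]): String =
    if (!isStreaming) batchQueryId.toString
    else propertyQueryId.getOrElse(UUID.randomUUID().toString)

  def main(args: Array[String]): Unit = {
    println(resolveQueryId(isStreaming = false, None))            // zero UUID, every run
    println(resolveQueryId(isStreaming = true, Some("query-42"))) // query-42
  }
}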
@@ -58,7 +58,7 @@ case class TransformWithStateExec(
     batchTimestampMs: Option[Long],
     eventTimeWatermarkForLateEvents: Option[Long],
     eventTimeWatermarkForEviction: Option[Long],
-    isStreaming: Boolean,
+    isStreaming: Boolean = true,
     child: SparkPlan)
   extends UnaryExecNode with StateStoreWriter with WatermarkSupport with ObjectProducerExec {

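Because the flag now defaults to true, existing streaming call sites that construct TransformWithStateExec compile unchanged; only the batch path needs to pass it explicitly. A tiny self-contained illustration of that source-compatibility property (ExecSketch is a made-up stand-in for the operator):

case class ExecSketch(name: String, isStreaming: Boolean = true)

object ExecSketchDemo {
  def main(args: Array[String]): Unit = {
    val streaming = ExecSketch("transformWithState")                  // default applies
    val batch = ExecSketch("transformWithState", isStreaming = false) // batch opts out
    assert(streaming.isStreaming && !batch.isStreaming)
  }
}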
@@ -155,7 +155,6 @@
     metrics // force lazy init at driver

-    // populate stateInfo if this is a streaming query
     val stateInfo = getStateInfo
     child.execute().mapPartitionsWithStateStore[InternalRow](
       getStateInfo,
       schemaForKeyRow,
@@ -199,7 +198,6 @@ object TransformWithStateExec {
       numPartitions = shufflePartitions
     )

-    // Rewrite physical operator to TransformWithStateForBatchExec
     new TransformWithStateExec(
       keyDeserializer,
       valueDeserializer,
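With the dedicated TransformWithStateForBatchExec gone (only its leftover comment is deleted here), the companion object's batch factory builds the same physical operator as the streaming planner, presumably passing isStreaming = false lower in the truncated argument list. A compact sketch of that one-operator-plus-flag design (all names hypothetical):

object SingleOperatorSketch {
  case class StateInfo(numPartitions: Int)
  case class Operator(stateInfo: StateInfo, isStreaming: Boolean = true)

  // Batch-side factory: fill in batch-only state info and flip the flag,
  // rather than instantiating a separate batch operator class.
  def forBatch(shufflePartitions: Int): Operator =
    Operator(StateInfo(numPartitions = shufflePartitions), isStreaming = false)

  def main(args: Array[String]): Unit =
    assert(!forBatch(8).isStreaming)
}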
@@ -86,7 +86,7 @@ class ValueStateSuite extends SharedSparkSession
   test("Implicit key operations") {
     tryWithProviderResource(newStoreProviderWithValueState(true)) { provider =>
       val store = provider.getStore(0)
-      val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), true)
+      val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID())
       assert(handle.getQueryInfo().getPartitionId === 0)

       val testState: ValueState[Long] = handle.getValueState[Long]("testState")
@@ -117,7 +117,7 @@ class ValueStateSuite extends SharedSparkSession
   test("Value state operations for single instance") {
     tryWithProviderResource(newStoreProviderWithValueState(true)) { provider =>
       val store = provider.getStore(0)
-      val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), true)
+      val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID())
       assert(handle.getQueryInfo().getPartitionId === 0)

       val testState: ValueState[Long] = handle.getValueState[Long]("testState")
@@ -143,7 +143,7 @@ class ValueStateSuite extends SharedSparkSession
   test("Value state operations for multiple instances") {
     tryWithProviderResource(newStoreProviderWithValueState(true)) { provider =>
       val store = provider.getStore(0)
-      val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), true)
+      val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID())
       assert(handle.getQueryInfo().getPartitionId === 0)

       val testState1: ValueState[Long] = handle.getValueState[Long]("testState1")
@@ -19,7 +19,6 @@ package org.apache.spark.sql.streaming

 import org.apache.spark.SparkException
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider}
 import org.apache.spark.sql.internal.SQLConf
@@ -133,7 +132,6 @@ class TransformWithStateSuite extends StateStoreMetricsTest
     }
   }

-
   test("transformWithState - batch should succeed") {
     val inputData = Seq("a", "a", "b")
     val result = inputData.toDS()
@@ -150,18 +148,6 @@
 class TransformWithStateValidationSuite extends StateStoreMetricsTest {
   import testImplicits._

-  test("transformWithState - batch should not fail") {
-    val _ = Seq("a", "a", "b").toDS()
-      .groupByKey(x => x)
-      .transformWithState(new RunningCountStatefulProcessor,
-        TimeoutMode.NoTimeouts(),
-        OutputMode.Append())
-      .write
-      .format("noop")
-      .mode(SaveMode.Append)
-      .save()
-  }
-
   test("transformWithState - streaming with hdfsStateStoreProvider should fail") {
     val inputData = MemoryStream[String]
     val result = inputData.toDS()
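The deleted "batch should not fail" test had become redundant with the "transformWithState - batch should succeed" test kept in TransformWithStateSuite, and removing it also let the SaveMode import above go. For reference, the batch invocation shape it exercised was the following (sketch only, not standalone: it assumes the suite's testImplicits, its RunningCountStatefulProcessor, and this PR's in-progress transformWithState signature):

Seq("a", "a", "b").toDS()
  .groupByKey(x => x)
  .transformWithState(new RunningCountStatefulProcessor,
    TimeoutMode.NoTimeouts(),
    OutputMode.Append())
  .write
  .format("noop")        // noop sink runs the plan and discards the rows
  .mode(SaveMode.Append)
  .save()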
