RocksDBStateStoreSuite passes
ericm-db committed Dec 31, 2024
1 parent 4826d12 commit 3e1d01e
Showing 8 changed files with 80 additions and 31 deletions.

@@ -339,7 +339,7 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
       storeConf: StateStoreConf,
       hadoopConf: Configuration,
       useMultipleValuesPerKey: Boolean = false,
-      stateSchemaBroadcast: Option[StateSchemaBroadcast] = None): Unit = {
+      stateSchemaProvider: Option[StateSchemaProvider] = None): Unit = {
     assert(
       !storeConf.enableStateStoreCheckpointIds,
       "HDFS State Store Provider doesn't support checkpointFormatVersion >= 2 " +

@@ -52,6 +52,44 @@ sealed trait RocksDBValueStateEncoder {
   def decodeValues(valueBytes: Array[Byte]): Iterator[UnsafeRow]
 }

+trait StateSchemaProvider {
+  def getSchemaMetadataValue(key: StateSchemaMetadataKey): StateSchemaMetadataValue
+
+  def getCurrentStateSchemaId(colFamilyName: String, isKey: Boolean): Short
+}
+
+
+// Test implementation that can be dynamically updated
+class TestStateSchemaProvider extends StateSchemaProvider {
+  private var schemas = Map.empty[StateSchemaMetadataKey, StateSchemaMetadataValue]
+
+  def addSchema(
+      colFamilyName: String,
+      keySchema: StructType,
+      valueSchema: StructType,
+      keySchemaId: Short = 0,
+      valueSchemaId: Short = 0): Unit = {
+    schemas ++= Map(
+      StateSchemaMetadataKey(colFamilyName, keySchemaId, isKey = true) ->
+        StateSchemaMetadataValue(keySchema, SchemaConverters.toAvroType(keySchema)),
+      StateSchemaMetadataKey(colFamilyName, valueSchemaId, isKey = false) ->
+        StateSchemaMetadataValue(valueSchema, SchemaConverters.toAvroType(valueSchema))
+    )
+  }
+
+  override def getSchemaMetadataValue(key: StateSchemaMetadataKey): StateSchemaMetadataValue = {
+    schemas(key)
+  }
+
+  override def getCurrentStateSchemaId(colFamilyName: String, isKey: Boolean): Short = {
+    schemas.keys
+      .filter(key =>
+        key.colFamilyName == colFamilyName &&
+          key.isKey == isKey)
+      .map(_.schemaId).max
+  }
+}
+
 /**
  * Broadcasts schema metadata information for stateful operators in a streaming query.
  *
@@ -63,19 +101,19 @@ sealed trait RocksDBValueStateEncoder {
  */
 case class StateSchemaBroadcast(
     broadcast: Broadcast[StateSchemaMetadata]
-) extends Logging {
+) extends Logging with StateSchemaProvider {

   /**
    * Retrieves the schema information for a given column family and schema version
    *
    * @param key A combination of column family name and schema ID
    * @return The corresponding schema metadata value containing both SQL and Avro schemas
    */
-  def getSchemaMetadataValue(key: StateSchemaMetadataKey): StateSchemaMetadataValue = {
+  override def getSchemaMetadataValue(key: StateSchemaMetadataKey): StateSchemaMetadataValue = {
     broadcast.value.activeSchemas(key)
   }

-  def getCurrentStateSchemaId(colFamilyName: String, isKey: Boolean): Short = {
+  override def getCurrentStateSchemaId(colFamilyName: String, isKey: Boolean): Short = {
     broadcast.value.activeSchemas
       .keys
       .filter(key =>
@@ -370,7 +408,7 @@ abstract class RocksDBDataEncoder(
 class UnsafeRowDataEncoder(
     keyStateEncoderSpec: KeyStateEncoderSpec,
     valueSchema: StructType,
-    stateSchemaBroadcast: Option[StateSchemaBroadcast],
+    stateSchemaProvider: Option[StateSchemaProvider],
     columnFamilyInfo: Option[ColumnFamilyInfo]
 ) extends RocksDBDataEncoder(keyStateEncoderSpec, valueSchema) {

@@ -602,18 +640,18 @@ class UnsafeRowDataEncoder(
 class AvroStateEncoder(
     keyStateEncoderSpec: KeyStateEncoderSpec,
     valueSchema: StructType,
-    stateSchemaBroadcast: Option[StateSchemaBroadcast],
+    stateSchemaProvider: Option[StateSchemaProvider],
     columnFamilyInfo: Option[ColumnFamilyInfo]
 ) extends RocksDBDataEncoder(keyStateEncoderSpec, valueSchema) with Logging {


   // schema information
-  private lazy val currentKeySchemaId: Short = getStateSchemaBroadcast.getCurrentStateSchemaId(
+  private lazy val currentKeySchemaId: Short = getStateSchemaProvider.getCurrentStateSchemaId(
     getColFamilyName,
     isKey = true
   )

-  private lazy val currentValSchemaId: Short = getStateSchemaBroadcast.getCurrentStateSchemaId(
+  private lazy val currentValSchemaId: Short = getStateSchemaProvider.getCurrentStateSchemaId(
     getColFamilyName,
     isKey = false
   )
@@ -748,8 +786,8 @@ class AvroStateEncoder(
     columnFamilyInfo.get.colFamilyName
   }

-  private def getStateSchemaBroadcast: StateSchemaBroadcast = {
-    stateSchemaBroadcast.get
+  private def getStateSchemaProvider: StateSchemaProvider = {
+    stateSchemaProvider.get
   }

   /**
@@ -1129,7 +1167,7 @@ class AvroStateEncoder(

   override def decodeValue(bytes: Array[Byte]): UnsafeRow = {
     val schemaIdRow = decodeStateSchemaIdRow(bytes)
-    val writerSchema = stateSchemaBroadcast.get.getSchemaMetadataValue(
+    val writerSchema = stateSchemaProvider.get.getSchemaMetadataValue(
       StateSchemaMetadataKey(
         getColFamilyName,
         schemaIdRow.schemaId,
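
The StateSchemaProvider trait and its TestStateSchemaProvider implementation added above can be exercised on their own. What follows is a minimal, hypothetical sketch, not part of this commit; it assumes the state package classes above are in scope, and the column family name, field names, and schema ids are purely illustrative.

    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

    val provider = new TestStateSchemaProvider
    val keySchema = StructType(Seq(StructField("groupKey", StringType)))
    val valueSchemaV0 = StructType(Seq(StructField("count", IntegerType)))
    // Hypothetical evolved value schema with one extra field.
    val valueSchemaV1 = valueSchemaV0.add(StructField("lastUpdated", StringType))

    // Register two value-schema versions under one column family.
    provider.addSchema("default", keySchema, valueSchemaV0, keySchemaId = 0, valueSchemaId = 0)
    provider.addSchema("default", keySchema, valueSchemaV1, keySchemaId = 0, valueSchemaId = 1)

    // The highest registered id is reported as the current schema id.
    assert(provider.getCurrentStateSchemaId("default", isKey = false) == 1)

    // Older rows are decoded by looking up the writer's schema by its id,
    // mirroring what AvroStateEncoder.decodeValue does above.
    val writerSchema = provider.getSchemaMetadataValue(
      StateSchemaMetadataKey("default", 0, isKey = false))

In production queries the same contract is satisfied by StateSchemaBroadcast, which now mixes in StateSchemaProvider, so the encoders depend only on the trait.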

@@ -84,12 +84,18 @@ private[sql] class RocksDBStateStoreProvider

     val columnFamilyInfo = Some(ColumnFamilyInfo(colFamilyName, newColFamilyId))

+    stateSchemaProvider match {
+      case Some(t: TestStateSchemaProvider) =>
+        t.addSchema(colFamilyName, keySchema, valueSchema)
+      case _ =>
+    }
+
     val dataEncoder = getDataEncoder(
       stateStoreEncoding,
       dataEncoderCacheKey,
       keyStateEncoderSpec,
       valueSchema,
-      stateSchemaBroadcast,
+      stateSchemaProvider,
       columnFamilyInfo
     )

@@ -385,15 +391,15 @@ private[sql] class RocksDBStateStoreProvider
       storeConf: StateStoreConf,
       hadoopConf: Configuration,
       useMultipleValuesPerKey: Boolean = false,
-      stateSchemaBroadcast: Option[StateSchemaBroadcast]): Unit = {
+      stateSchemaProvider: Option[StateSchemaProvider]): Unit = {
     this.stateStoreId_ = stateStoreId
     this.keySchema = keySchema
     this.valueSchema = valueSchema
     this.storeConf = storeConf
     this.hadoopConf = hadoopConf
     this.useColumnFamilies = useColumnFamilies
     this.stateStoreEncoding = storeConf.stateStoreEncodingFormat
-    this.stateSchemaBroadcast = stateSchemaBroadcast
+    this.stateSchemaProvider = stateSchemaProvider

     if (useMultipleValuesPerKey) {
       require(useColumnFamilies, "Multiple values per key support requires column families to be" +
@@ -410,19 +416,22 @@ private[sql] class RocksDBStateStoreProvider
       stateStoreName = stateStoreId.storeName,
       colFamilyName = StateStore.DEFAULT_COL_FAMILY_NAME)

-    val columnFamilyInfo = if (useColumnFamilies) {
-      defaultColFamilyId = Some(rocksDB.createColFamilyIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME))
-      Some(ColumnFamilyInfo(StateStore.DEFAULT_COL_FAMILY_NAME, defaultColFamilyId.get))
-    } else {
-      None
+    stateSchemaProvider match {
+      case Some(t: TestStateSchemaProvider) =>
+        t.addSchema(StateStore.DEFAULT_COL_FAMILY_NAME, keySchema, valueSchema)
+      case _ =>
     }

+    defaultColFamilyId = Some(rocksDB.createColFamilyIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME))
+    val columnFamilyInfo =
+      Some(ColumnFamilyInfo(StateStore.DEFAULT_COL_FAMILY_NAME, defaultColFamilyId.get))
+
     val dataEncoder = getDataEncoder(
       stateStoreEncoding,
       dataEncoderCacheKey,
       keyStateEncoderSpec,
       valueSchema,
-      stateSchemaBroadcast,
+      stateSchemaProvider,
       columnFamilyInfo
     )

@@ -516,7 +525,7 @@ private[sql] class RocksDBStateStoreProvider
   @volatile private var hadoopConf: Configuration = _
   @volatile private var useColumnFamilies: Boolean = _
   @volatile private var stateStoreEncoding: String = _
-  @volatile private var stateSchemaBroadcast: Option[StateSchemaBroadcast] = _
+  @volatile private var stateSchemaProvider: Option[StateSchemaProvider] = _

   private[sql] lazy val rocksDB = {
     val dfsRootDir = stateStoreId.storeCheckpointLocation().toString
@@ -696,7 +705,7 @@ object RocksDBStateStoreProvider {
       encoderCacheKey: StateRowEncoderCacheKey,
       keyStateEncoderSpec: KeyStateEncoderSpec,
       valueSchema: StructType,
-      stateSchemaBroadcast: Option[StateSchemaBroadcast],
+      stateSchemaProvider: Option[StateSchemaProvider],
       columnFamilyInfo: Option[ColumnFamilyInfo] = None): RocksDBDataEncoder = {
     assert(Set("avro", "unsaferow").contains(stateStoreEncoding))
     RocksDBStateStoreProvider.dataEncoderCache.get(
@@ -707,14 +716,14 @@
         new AvroStateEncoder(
           keyStateEncoderSpec,
           valueSchema,
-          stateSchemaBroadcast,
+          stateSchemaProvider,
           columnFamilyInfo
         )
       } else {
         new UnsafeRowDataEncoder(
           keyStateEncoderSpec,
           valueSchema,
-          stateSchemaBroadcast,
+          stateSchemaProvider,
           columnFamilyInfo
         )
       }
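
One design note on the provider changes above: both createColFamilyIfAbsent and init pattern-match on TestStateSchemaProvider, so suites that construct a store directly get their key and value schemas registered on the fly, while any other provider, such as the StateSchemaBroadcast a real query supplies, flows through to getDataEncoder unchanged.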

@@ -466,7 +466,7 @@ trait StateStoreProvider {
       storeConfs: StateStoreConf,
       hadoopConf: Configuration,
       useMultipleValuesPerKey: Boolean = false,
-      stateSchemaBroadcast: Option[StateSchemaBroadcast] = None): Unit
+      stateSchemaProvider: Option[StateSchemaProvider] = None): Unit

   /**
    * Return the id of the StateStores this provider will generate.

@@ -146,7 +146,7 @@ class CkptIdCollectingStateStoreProviderWrapper extends StateStoreProvider {
       storeConfs: StateStoreConf,
       hadoopConf: Configuration,
       useMultipleValuesPerKey: Boolean = false,
-      stateSchemaBroadcast: Option[StateSchemaBroadcast] = None): Unit = {
+      stateSchemaProvider: Option[StateSchemaProvider] = None): Unit = {
     innerProvider.init(
       stateStoreId,
       keySchema,

@@ -1323,6 +1323,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
       useColumnFamilies: Boolean = false,
       useMultipleValuesPerKey: Boolean = false): RocksDBStateStoreProvider = {
     val provider = new RocksDBStateStoreProvider()
+    val testStateSchemaProvider = new TestStateSchemaProvider
     provider.init(
       storeId,
       keySchema,
@@ -1331,7 +1332,8 @@
       useColumnFamilies,
       new StateStoreConf(sqlConf.getOrElse(SQLConf.get)),
       conf,
-      useMultipleValuesPerKey)
+      useMultipleValuesPerKey,
+      stateSchemaProvider = Some(testStateSchemaProvider))
     provider
   }


@@ -66,7 +66,7 @@ class MaintenanceErrorOnCertainPartitionsProvider extends HDFSBackedStateStorePr
       storeConfs: StateStoreConf,
       hadoopConf: Configuration,
       useMultipleValuesPerKey: Boolean = false,
-      stateSchemaBroadcast: Option[StateSchemaBroadcast] = None): Unit = {
+      stateSchemaProvider: Option[StateSchemaProvider] = None): Unit = {
     id = stateStoreId

     super.init(
@@ -102,7 +102,7 @@ class FakeStateStoreProviderWithMaintenanceError extends StateStoreProvider {
       storeConfs: StateStoreConf,
       hadoopConf: Configuration,
       useMultipleValuesPerKey: Boolean = false,
-      stateSchemaBroadcast: Option[StateSchemaBroadcast] = None): Unit = {
+      stateSchemaProvider: Option[StateSchemaProvider] = None): Unit = {
     id = stateStoreId
   }


@@ -42,7 +42,7 @@ import org.apache.spark.sql.execution.{LocalLimitExec, SimpleMode, SparkPlan}
 import org.apache.spark.sql.execution.command.ExplainCommand
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.sources.{ContinuousMemoryStream, ForeachBatchUserFuncException, MemorySink}
-import org.apache.spark.sql.execution.streaming.state.{KeyStateEncoderSpec, StateSchemaBroadcast, StateStore, StateStoreConf, StateStoreId, StateStoreProvider}
+import org.apache.spark.sql.execution.streaming.state.{KeyStateEncoderSpec, StateSchemaProvider, StateStore, StateStoreConf, StateStoreId, StateStoreProvider}
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -1480,7 +1480,7 @@ class TestStateStoreProvider extends StateStoreProvider {
       storeConfs: StateStoreConf,
       hadoopConf: Configuration,
       useMultipleValuesPerKey: Boolean = false,
-      stateSchemaBroadcast: Option[StateSchemaBroadcast] = None): Unit = {
+      stateSchemaProvider: Option[StateSchemaProvider] = None): Unit = {
     throw new Exception("Successfully instantiated")
   }

