Commit

IT WORKS
ericm-db committed Dec 12, 2024
1 parent 94cb378 commit c7103a4
Showing 9 changed files with 222 additions and 50 deletions.
@@ -40,7 +40,7 @@ import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadat
import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
import org.apache.spark.sql.execution.python.{FlatMapGroupsInPandasWithStateExec, TransformWithStateInPandasExec}
import org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSourceV1
-import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataReader, OperatorStateMetadataV1, OperatorStateMetadataV2, OperatorStateMetadataWriter, StateSchemaCompatibilityChecker, StateSchemaMetadata}
+import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataReader, OperatorStateMetadataV1, OperatorStateMetadataV2, OperatorStateMetadataWriter, StateSchemaCompatibilityChecker, StateSchemaMetadata, StateSchemaMetadataKey, StateSchemaMetadataValue}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -261,6 +261,7 @@ class IncrementalExecution(
case tws: TransformWithStateExec =>
val stateSchemaMetadata = createStateSchemaMetadata(stateSchemaMapping.head)
val ssmBc = sparkSession.sparkContext.broadcast(stateSchemaMetadata)
logStateSchemaMetadata(stateSchemaMetadata)
tws.copy(stateSchemaMetadata = Some(ssmBc))
case _ => ssw
}
@@ -269,6 +270,14 @@
}
}
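
A note on the broadcast handoff above: the metadata is built once on the driver and shipped to executors through a regular Spark broadcast. A minimal sketch of that pattern with made-up local names (the real wiring goes through TransformWithStateExec.copy and the state store provider further down):

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.SparkSession

// Driver side: broadcast the schema metadata once for the query run.
def broadcastMetadata(
    spark: SparkSession,
    metadata: StateSchemaMetadata): Broadcast[StateSchemaMetadata] =
  spark.sparkContext.broadcast(metadata)

// Executor side (inside a task closure): bc.value fetches and deserializes the
// metadata at most once per executor, then serves it from memory, e.g.
//   val schemas = bc.value.schemas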

def logStateSchemaMetadata(metadata: StateSchemaMetadata): Unit = {
metadata.schemas.foreach { case (schemaId, keyValueMap) =>
keyValueMap.foreach { case (key, value) =>
logError(s"### Key: $key, Value: $value")
}
}
}

private def createStateSchemaMetadata(
stateSchemaMapping: Map[Int, String]
): StateSchemaMetadata = {
@@ -277,7 +286,10 @@
val inStream = fm.open(new Path(stateSchemaPath))
val colFamilySchemas =
StateSchemaCompatibilityChecker.readSchemaFile(inStream).map { schema =>
-        schema.colFamilyName -> SchemaConverters.toAvroType(schema.valueSchema)
+        StateSchemaMetadataKey(
+          stateSchemaId, schema.colFamilyName, runId.toString) ->
+          StateSchemaMetadataValue(
+            SchemaConverters.toAvroType(schema.valueSchema), schema.valueSchema)
}.toMap
stateSchemaId -> colFamilySchemas
}
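
For orientation, a minimal sketch of the metadata types this change keys on. These definitions are not part of the diff; the shapes and field names below are inferred from how the types are used in createStateSchemaMetadata and in the RocksDB provider further down, and may differ from the real ones:

import org.apache.avro.Schema
import org.apache.spark.sql.types.StructType

// Hypothetical shapes, inferred from usage in this commit.
case class StateSchemaMetadataKey(schemaId: Int, colFamilyName: String, runId: String)

case class StateSchemaMetadataValue(avroSchema: Schema, valueSchema: StructType)

case class StateSchemaMetadata(
    currentSchemaId: Int,
    // outer key: state schema file id; inner key: (schema id, column family, run id)
    schemas: Map[Int, Map[StateSchemaMetadataKey, StateSchemaMetadataValue]])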
@@ -482,12 +482,14 @@ case class TransformWithStateExec(
}
case None => None
}

List(StateSchemaCompatibilityChecker.
validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf,
newSchemas.values.toList, session.sessionState, stateSchemaVersion,
storeName = StateStoreId.DEFAULT_STORE_NAME,
oldSchemaFilePath = oldStateSchemaFilePath,
-      newSchemaFilePath = Some(newStateSchemaFilePath)))
+      newSchemaFilePath = Some(newStateSchemaFilePath),
+      usingAvro = session.sessionState.conf.stateStoreEncodingFormat == "avro"))
}

override def stateSchemaMapping(
@@ -612,7 +614,8 @@ case class TransformWithStateExec(
NoPrefixKeyStateEncoderSpec(keyEncoder.schema),
session.sessionState,
Some(session.streams.stateStoreCoordinator),
-      useColumnFamilies = true
+      useColumnFamilies = true,
+      stateSchemaMetadata = stateSchemaMetadata
) {
case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
processData(store, singleIterator)
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
import org.apache.spark.sql.execution.streaming.StateStoreColumnFamilySchemaUtils
-import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider.{STATE_ENCODING_NUM_VERSION_BYTES, STATE_ENCODING_VERSION, VIRTUAL_COL_FAMILY_PREFIX_BYTES}
+import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider.{STATE_ENCODING_NUM_VERSION_BYTES, STATE_ENCODING_VERSION, STATE_ROW_SCHEMA_ID_PREFIX_BYTES, VIRTUAL_COL_FAMILY_PREFIX_BYTES}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform

@@ -51,16 +51,24 @@ sealed trait RocksDBValueStateEncoder {
def decodeValues(valueBytes: Array[Byte]): Iterator[UnsafeRow]
}

-abstract class RocksDBValueStateEncoderWithProvider(
+abstract class RocksDBStateSchema(
     provider: RocksDBStateStoreProvider,
-    colFamilyName: String
-) extends RocksDBValueStateEncoder {
-  def getSchemaFromId(schemaId: Short): Schema = {
-    null
+    colFamilyName: String) {
+  def getSchemaFromId(schemaId: Int): StateSchemaMetadataValue = {
+    provider.getSchemaFromId(colFamilyName, schemaId)
   }

-  def getCurrentSchemaId: Short = {
-    0
+  def getCurrentSchemaId: Int = {
+    provider.getCurrentSchemaId
   }
}

protected def encodeSchemaId(numBytes: Int, schemaId: Int): (Array[Byte], Int) = {
val encodedBytes = new Array[Byte](numBytes + STATE_ROW_SCHEMA_ID_PREFIX_BYTES)
var offset = Platform.BYTE_ARRAY_OFFSET
Platform.putInt(encodedBytes, Platform.BYTE_ARRAY_OFFSET, schemaId)
offset = Platform.BYTE_ARRAY_OFFSET + STATE_ROW_SCHEMA_ID_PREFIX_BYTES

(encodedBytes, offset)
}
}
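
The prefixing scheme above reserves the first STATE_ROW_SCHEMA_ID_PREFIX_BYTES (4) bytes of every value row for the schema id, with the encoded value following:

  [ 4-byte schema id ][ Avro- or UnsafeRow-encoded value bytes ... ]

A self-contained sketch of the same round trip (illustration only, not part of the change), using the same Platform primitives:

import org.apache.spark.unsafe.Platform

object SchemaIdPrefixSketch {
  val SchemaIdPrefixBytes = 4

  // Prepend a 4-byte schema id to an already-encoded value payload.
  def prefixWithSchemaId(schemaId: Int, payload: Array[Byte]): Array[Byte] = {
    val out = new Array[Byte](SchemaIdPrefixBytes + payload.length)
    Platform.putInt(out, Platform.BYTE_ARRAY_OFFSET, schemaId)
    Platform.copyMemory(
      payload, Platform.BYTE_ARRAY_OFFSET,
      out, Platform.BYTE_ARRAY_OFFSET + SchemaIdPrefixBytes,
      payload.length)
    out
  }

  // Read the schema id back and strip the prefix to recover the payload.
  def splitSchemaId(bytes: Array[Byte]): (Int, Array[Byte]) = {
    val schemaId = Platform.getInt(bytes, Platform.BYTE_ARRAY_OFFSET)
    val payload = new Array[Byte](bytes.length - SchemaIdPrefixBytes)
    Platform.copyMemory(
      bytes, Platform.BYTE_ARRAY_OFFSET + SchemaIdPrefixBytes,
      payload, Platform.BYTE_ARRAY_OFFSET,
      payload.length)
    (schemaId, payload)
  }
}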

@@ -220,6 +228,27 @@ object RocksDBStateEncoder extends Logging {
valueProj.apply(internalRow)
}

/**
 * Deserializes a byte array written with Avro encoding into an UnsafeRow, resolving
 * the schema the value was written with (`valueAvroType`) against the current reader
 * schema (`currentAvroType`) before projecting the result to an UnsafeRow.
 */
def decodeFromAvroToUnsafeRow(
valueBytes: Array[Byte],
avroDeserializer: AvroDeserializer,
currentAvroType: Schema,
valueAvroType: Schema,
valueProj: UnsafeProjection): UnsafeRow = {
val reader = new GenericDatumReader[Any](valueAvroType, currentAvroType)
val decoder = DecoderFactory.get().binaryDecoder(valueBytes, 0, valueBytes.length, null)
// bytes -> Avro.GenericDataRecord
val genericData = reader.read(null, decoder)
// Avro.GenericDataRecord -> InternalRow
val internalRow = avroDeserializer.deserialize(
genericData).orNull.asInstanceOf[InternalRow]
// InternalRow -> UnsafeRow
valueProj.apply(internalRow)
}

def decodeToUnsafeRow(bytes: Array[Byte], reusedRow: UnsafeRow): UnsafeRow = {
if (bytes != null) {
// Platform.BYTE_ARRAY_OFFSET is the recommended way to refer to the first offset. See Platform.
@@ -1161,7 +1190,8 @@ class MultiValuedStateEncoder(
colFamilyName: String,
valueSchema: StructType,
avroEnc: Option[AvroEncoder] = None)
-    extends RocksDBValueStateEncoderWithProvider(provider, colFamilyName) with Logging {
+    extends RocksDBStateSchema(provider, colFamilyName) with RocksDBValueStateEncoder
+    with Logging {

import RocksDBStateEncoder._

@@ -1259,7 +1289,8 @@ class SingleValueStateEncoder(
colFamilyName: String,
valueSchema: StructType,
avroEnc: Option[AvroEncoder] = None)
-    extends RocksDBValueStateEncoderWithProvider(provider, colFamilyName) with Logging {
+    extends RocksDBStateSchema(provider, colFamilyName) with RocksDBValueStateEncoder
+    with Logging {

import RocksDBStateEncoder._

@@ -1271,11 +1302,22 @@
private val valueProj = UnsafeProjection.create(valueSchema)

override def encodeValue(row: UnsafeRow): Array[Byte] = {
-    if (usingAvroEncoding) {
+    val valueBytes = if (usingAvroEncoding) {
encodeUnsafeRowToAvro(row, avroEnc.get.valueSerializer, valueAvroType, out)
} else {
encodeUnsafeRow(row)
}

// Create new array with space for schema ID
val (schemaVersionedBytes, offset) = encodeSchemaId(valueBytes.length, getCurrentSchemaId)

// Copy value bytes after schema ID
Platform.copyMemory(
valueBytes, Platform.BYTE_ARRAY_OFFSET,
schemaVersionedBytes, offset,
valueBytes.length)

schemaVersionedBytes
}

/**
@@ -1285,14 +1327,28 @@
* the given byte array.
*/
override def decodeValue(valueBytes: Array[Byte]): UnsafeRow = {
-    if (valueBytes == null) {
-      return null
-    }
+    if (valueBytes == null) return null

// Get schema ID from first 4 bytes
val schemaId = Platform.getInt(valueBytes, Platform.BYTE_ARRAY_OFFSET)
val schemaMetadataValue = getSchemaFromId(schemaId)
val projection = UnsafeProjection.create(schemaMetadataValue.valueSchema)
// Get actual value bytes after schema ID
val actualValueBytes = new Array[Byte](valueBytes.length - STATE_ROW_SCHEMA_ID_PREFIX_BYTES)
Platform.copyMemory(
valueBytes, Platform.BYTE_ARRAY_OFFSET + STATE_ROW_SCHEMA_ID_PREFIX_BYTES,
actualValueBytes, Platform.BYTE_ARRAY_OFFSET,
actualValueBytes.length)
logError(s"### schemaId: ${schemaId}")
logError(s"### schemaMetadataValue: ${schemaMetadataValue}")

if (usingAvroEncoding) {
decodeFromAvroToUnsafeRow(
-        valueBytes, avroEnc.get.valueDeserializer, valueAvroType, valueProj)
+        actualValueBytes, avroEnc.get.valueDeserializer,
+        valueAvroType,
+        schemaMetadataValue.avroSchema, valueProj)
} else {
-      decodeToUnsafeRow(valueBytes, valueRow)
+      decodeToUnsafeRow(actualValueBytes, valueRow)
}
}
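
The Avro branch above leans on standard Avro schema resolution: decodeFromAvroToUnsafeRow builds a GenericDatumReader with the schema the row was written with (looked up from the stored schema id) as the writer schema and the encoder's current schema as the reader schema. A standalone sketch of that resolution in plain Avro, independent of the Spark types here and assuming an added string field with a default value:

import java.io.ByteArrayOutputStream

import org.apache.avro.SchemaBuilder
import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWriter, GenericRecord}
import org.apache.avro.io.{DecoderFactory, EncoderFactory}

object AvroResolutionSketch {
  def main(args: Array[String]): Unit = {
    // Old (writer) schema: a single long field.
    val writerSchema = SchemaBuilder.record("Value").fields()
      .requiredLong("count").endRecord()
    // New (reader) schema: adds a string field with a default, a compatible evolution.
    val readerSchema = SchemaBuilder.record("Value").fields()
      .requiredLong("count")
      .name("label").`type`().stringType().stringDefault("").endRecord()

    // Encode a record with the old schema.
    val record = new GenericData.Record(writerSchema)
    record.put("count", 7L)
    val out = new ByteArrayOutputStream()
    val encoder = EncoderFactory.get().binaryEncoder(out, null)
    new GenericDatumWriter[GenericRecord](writerSchema).write(record, encoder)
    encoder.flush()

    // Decode it with the new schema; the missing field is filled from its default.
    val decoder = DecoderFactory.get().binaryDecoder(out.toByteArray, null)
    val decoded = new GenericDatumReader[GenericRecord](writerSchema, readerSchema)
      .read(null, decoder)
    assert(decoded.get("count") == 7L)
    assert(decoded.get("label").toString == "")
  }
}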

@@ -378,6 +378,7 @@ private[sql] class RocksDBStateStoreProvider
this.hadoopConf = hadoopConf
this.useColumnFamilies = useColumnFamilies
this.stateStoreEncoding = storeConf.stateStoreEncodingFormat
this.stateSchemaBroadcast = stateSchemaMetadata

if (useMultipleValuesPerKey) {
require(useColumnFamilies, "Multiple values per key support requires column families to be" +
@@ -482,6 +483,30 @@
@volatile private var hadoopConf: Configuration = _
@volatile private var useColumnFamilies: Boolean = _
@volatile private var stateStoreEncoding: String = _
@volatile private var stateSchemaBroadcast: Option[Broadcast[StateSchemaMetadata]] = _

def getSchemaFromId(colFamilyName: String, schemaId: Int): StateSchemaMetadataValue = {
if (stateSchemaBroadcast.isDefined) {
val metadata = stateSchemaBroadcast.get.value
metadata.schemas(schemaId)(
StateSchemaMetadataKey(
schemaId,
colFamilyName,
hadoopConf.get(StreamExecution.RUN_ID_KEY)
)
)
} else {
null
}
}

def getCurrentSchemaId: Int = {
if (stateSchemaBroadcast.isDefined) {
stateSchemaBroadcast.get.value.currentSchemaId
} else {
-1
}
}
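
Resolving a value's schema on an executor is then a lookup keyed by (schema id, column family, run id) into the broadcast metadata. A minimal restatement of that lookup using the hypothetical shapes sketched earlier (illustration only; the provider above indexes the maps directly rather than returning an Option):

def lookupValueSchema(
    metadata: StateSchemaMetadata,
    colFamilyName: String,
    schemaId: Int,
    runId: String): Option[StateSchemaMetadataValue] = {
  metadata.schemas.get(schemaId)
    .flatMap(_.get(StateSchemaMetadataKey(schemaId, colFamilyName, runId)))
}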

private[sql] lazy val rocksDB = {
val dfsRootDir = stateStoreId.storeCheckpointLocation().toString
@@ -616,6 +641,7 @@ object RocksDBStateStoreProvider {
val STATE_ENCODING_NUM_VERSION_BYTES = 1
val STATE_ENCODING_VERSION: Byte = 0
val VIRTUAL_COL_FAMILY_PREFIX_BYTES = 2
val STATE_ROW_SCHEMA_ID_PREFIX_BYTES = 4

private val MAX_AVRO_ENCODERS_IN_CACHE = 1000
// Add the cache at companion object level so it persists across provider instances
@@ -17,14 +17,16 @@

package org.apache.spark.sql.execution.streaming.state

import scala.jdk.CollectionConverters.IterableHasAsJava
import scala.util.Try

import org.apache.avro.SchemaValidatorBuilder
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FSDataInputStream, Path}

import org.apache.spark.SparkUnsupportedOperationException
import org.apache.spark.internal.{Logging, LogKeys, MDC}
-import org.apache.spark.sql.avro.{AvroDeserializer, AvroSerializer}
+import org.apache.spark.sql.avro.{AvroDeserializer, AvroSerializer, SchemaConverters}
import org.apache.spark.sql.catalyst.util.UnsafeRowUtils
import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, StatefulOperatorStateInfo}
import org.apache.spark.sql.execution.streaming.state.SchemaHelper.{SchemaReader, SchemaWriter}
@@ -38,21 +40,8 @@ case class StateSchemaValidationResult(
schemaPath: String
)

-/**
- * An Avro-based encoder used for serializing between UnsafeRow and Avro
- * byte arrays in RocksDB state stores.
- *
- * This encoder is primarily utilized by [[RocksDBStateStoreProvider]] and [[RocksDBStateEncoder]]
- * to handle serialization and deserialization of state store data.
- *
- * @param keySerializer Serializer for converting state store keys to Avro format
- * @param keyDeserializer Deserializer for converting Avro-encoded keys back to UnsafeRow
- * @param valueSerializer Serializer for converting state store values to Avro format
- * @param valueDeserializer Deserializer for converting Avro-encoded values back to UnsafeRow
- * @param suffixKeySerializer Optional serializer for handling suffix keys in Avro format
- * @param suffixKeyDeserializer Optional deserializer for converting Avro-encoded suffix
- *                              keys back to UnsafeRow
- */
+// Avro encoder that is used by the RocksDBStateStoreProvider and RocksDBStateEncoder
+// in order to serialize from UnsafeRow to a byte array of Avro encoding.
case class AvroEncoder(
keySerializer: AvroSerializer,
keyDeserializer: AvroDeserializer,
@@ -153,22 +142,37 @@ class StateSchemaCompatibilityChecker(
private def check(
oldSchema: StateStoreColFamilySchema,
newSchema: StateStoreColFamilySchema,
-      ignoreValueSchema: Boolean) : Unit = {
+      ignoreValueSchema: Boolean,
+      usingAvro: Boolean) : Boolean = {
val (storedKeySchema, storedValueSchema) = (oldSchema.keySchema,
oldSchema.valueSchema)
val (keySchema, valueSchema) = (newSchema.keySchema, newSchema.valueSchema)

if (storedKeySchema.equals(keySchema) &&
(ignoreValueSchema || storedValueSchema.equals(valueSchema))) {
// schemas are exactly the same
false
} else if (!schemasCompatible(storedKeySchema, keySchema)) {
throw StateStoreErrors.stateStoreKeySchemaNotCompatible(storedKeySchema.toString,
keySchema.toString)
} else if (!ignoreValueSchema && usingAvro) {
// By this point, we know that old value schema is not equal to new value schema
val oldAvroSchema = SchemaConverters.toAvroType(storedValueSchema)
val newAvroSchema = SchemaConverters.toAvroType(valueSchema)

val validator = new SchemaValidatorBuilder().canReadStrategy.validateAll()
// This will throw a SchemaValidationException if the schema has evolved in an
// unacceptable way.
validator.validate(newAvroSchema, Iterable(oldAvroSchema).asJava)
// If no exception is thrown, then we know that the schema evolved in an
// acceptable way
true
} else if (!ignoreValueSchema && !schemasCompatible(storedValueSchema, valueSchema)) {
throw StateStoreErrors.stateStoreValueSchemaNotCompatible(storedValueSchema.toString,
valueSchema.toString)
} else {
logInfo("Detected schema change which is compatible. Allowing to put rows.")
true
}
}
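
For reference, a self-contained sketch of what the canReadStrategy validator used above accepts and rejects (plain Avro, independent of this checker):

import scala.jdk.CollectionConverters._

import org.apache.avro.{SchemaBuilder, SchemaValidationException, SchemaValidatorBuilder}

object AvroCompatibilitySketch {
  def main(args: Array[String]): Unit = {
    val oldSchema = SchemaBuilder.record("Value").fields().requiredLong("count").endRecord()
    val widened = SchemaBuilder.record("Value").fields()
      .requiredLong("count").optionalString("label").endRecord()
    val narrowed = SchemaBuilder.record("Value").fields().requiredInt("count").endRecord()

    // "Can read": the new schema must be able to read data written with the old schema.
    val validator = new SchemaValidatorBuilder().canReadStrategy.validateAll()

    // Adding a nullable field with a default is an acceptable evolution.
    validator.validate(widened, Iterable(oldSchema).asJava)

    // Narrowing long -> int is not: the new schema cannot read the old data.
    try {
      validator.validate(narrowed, Iterable(oldSchema).asJava)
    } catch {
      case e: SchemaValidationException => println(s"Rejected as expected: ${e.getMessage}")
    }
  }
}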

@@ -182,7 +186,8 @@
def validateAndMaybeEvolveStateSchema(
newStateSchema: List[StateStoreColFamilySchema],
ignoreValueSchema: Boolean,
-      stateSchemaVersion: Int): Boolean = {
+      stateSchemaVersion: Int,
+      usingAvro: Boolean): Boolean = {
val existingStateSchemaList = getExistingKeyAndValueSchema()
val newStateSchemaList = newStateSchema

@@ -197,18 +202,18 @@
}.toMap
// For each new state variable, we want to compare it to the old state variable
// schema with the same name
-    newStateSchemaList.foreach { newSchema =>
-      existingSchemaMap.get(newSchema.colFamilyName).foreach { existingStateSchema =>
-        check(existingStateSchema, newSchema, ignoreValueSchema)
-      }
+    val hasEvolvedSchema = newStateSchemaList.exists { newSchema =>
+      existingSchemaMap.get(newSchema.colFamilyName)
+        .exists(existingSchema => check(existingSchema, newSchema, ignoreValueSchema, usingAvro))
}
val colFamiliesAddedOrRemoved =
(newStateSchemaList.map(_.colFamilyName).toSet != existingSchemaMap.keySet)
-    if (stateSchemaVersion == SCHEMA_FORMAT_V3 && colFamiliesAddedOrRemoved) {
+    val newSchemaFileWritten = hasEvolvedSchema || colFamiliesAddedOrRemoved
+    if (stateSchemaVersion == SCHEMA_FORMAT_V3 && newSchemaFileWritten) {
createSchemaFile(newStateSchemaList.sortBy(_.colFamilyName), stateSchemaVersion)
}
// TODO: [SPARK-49535] Write Schema files after schema has changed for StateSchemaV3
-    colFamiliesAddedOrRemoved
+    newSchemaFileWritten
}
}

@@ -270,7 +275,8 @@ object StateSchemaCompatibilityChecker extends Logging {
extraOptions: Map[String, String] = Map.empty,
storeName: String = StateStoreId.DEFAULT_STORE_NAME,
oldSchemaFilePath: Option[Path] = None,
-      newSchemaFilePath: Option[Path] = None): StateSchemaValidationResult = {
+      newSchemaFilePath: Option[Path] = None,
+      usingAvro: Boolean = false): StateSchemaValidationResult = {
// SPARK-47776: collation introduces the concept of binary (in)equality, which means
// that under some collations we are no longer able to just compare the binary format of two
// UnsafeRows to determine equality. For example, 'aaa' and 'AAA' can be "semantically"
@@ -301,7 +307,7 @@
val result = Try(
checker.validateAndMaybeEvolveStateSchema(newStateSchema,
ignoreValueSchema = !storeConf.formatValidationCheckValue,
-        stateSchemaVersion = stateSchemaVersion)
+        stateSchemaVersion = stateSchemaVersion, usingAvro)
).toEither.fold(Some(_),
hasEvolvedSchema => {
evolvedSchema = hasEvolvedSchema