Skip to content

Commit

Permalink
moving avro enc creation into RocksDBStateENcoder
Browse files Browse the repository at this point in the history
  • Loading branch information
ericm-db committed Dec 13, 2024
1 parent bc373b0 commit ff1e0f4
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWri
import org.apache.avro.io.{DecoderFactory, EncoderFactory}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.avro.{AvroDeserializer, AvroSerializer, SchemaConverters}
import org.apache.spark.sql.avro.{AvroDeserializer, AvroOptions, AvroSerializer, SchemaConverters}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
Expand Down Expand Up @@ -423,10 +423,10 @@ class UnsafeRowDataEncoder(

class AvroStateEncoder(
keyStateEncoderSpec: KeyStateEncoderSpec,
valueSchema: StructType,
avroEncoder: AvroEncoder) extends RocksDBDataEncoder(keyStateEncoderSpec, valueSchema)
valueSchema: StructType) extends RocksDBDataEncoder(keyStateEncoderSpec, valueSchema)
with Logging {

private val avroEncoder = createAvroEnc(keyStateEncoderSpec, valueSchema)
// Avro schema used by the avro encoders
private lazy val keyAvroType: Schema = SchemaConverters.toAvroType(keySchema)
private lazy val keyProj = UnsafeProjection.create(keySchema)
Expand Down Expand Up @@ -478,6 +478,80 @@ class AvroStateEncoder(

private lazy val remainingKeyAvroProjection = UnsafeProjection.create(remainingKeySchema)



private def getAvroSerializer(schema: StructType): AvroSerializer = {
val avroType = SchemaConverters.toAvroType(schema)
new AvroSerializer(schema, avroType, nullable = false)
}

private def getAvroDeserializer(schema: StructType): AvroDeserializer = {
val avroType = SchemaConverters.toAvroType(schema)
val avroOptions = AvroOptions(Map.empty)
new AvroDeserializer(avroType, schema,
avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType,
avroOptions.stableIdPrefixForUnionType, avroOptions.recursiveFieldMaxDepth)
}

/**
* Creates an AvroEncoder that handles both key and value serialization/deserialization.
* This method sets up the complete encoding infrastructure needed for state store operations.
*
* The encoder handles different key encoding specifications:
* - NoPrefixKeyStateEncoderSpec: Simple key encoding without prefix
* - PrefixKeyScanStateEncoderSpec: Keys with prefix for efficient scanning
* - RangeKeyScanStateEncoderSpec: Keys with ordering requirements for range scans
*
* For prefix scan cases, it also creates separate encoders for the suffix portion of keys.
*
* @param keyStateEncoderSpec Specification for how to encode keys
* @param valueSchema Schema for the values to be encoded
* @return An AvroEncoder containing all necessary serializers and deserializers
*/
private def createAvroEnc(
keyStateEncoderSpec: KeyStateEncoderSpec,
valueSchema: StructType
): AvroEncoder = {
val valueSerializer = getAvroSerializer(valueSchema)
val valueDeserializer = getAvroDeserializer(valueSchema)

// Get key schema based on encoder spec type
val keySchema = keyStateEncoderSpec match {
case NoPrefixKeyStateEncoderSpec(schema) =>
schema
case PrefixKeyScanStateEncoderSpec(schema, numColsPrefixKey) =>
StructType(schema.take(numColsPrefixKey))
case RangeKeyScanStateEncoderSpec(schema, orderingOrdinals) =>
val remainingSchema = {
0.until(schema.length).diff(orderingOrdinals).map { ordinal =>
schema(ordinal)
}
}
StructType(remainingSchema)
}

// Handle suffix key schema for prefix scan case
val suffixKeySchema = keyStateEncoderSpec match {
case PrefixKeyScanStateEncoderSpec(schema, numColsPrefixKey) =>
Some(StructType(schema.drop(numColsPrefixKey)))
case _ =>
None
}

val keySerializer = getAvroSerializer(keySchema)
val keyDeserializer = getAvroDeserializer(keySchema)

// Create the AvroEncoder with all components
AvroEncoder(
keySerializer,
keyDeserializer,
valueSerializer,
valueDeserializer,
suffixKeySchema.map(getAvroSerializer),
suffixKeySchema.map(getAvroDeserializer)
)
}

/**
* This method takes an UnsafeRow, and serializes to a byte array using Avro encoding.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import org.apache.spark.{SparkConf, SparkEnv, SparkException}
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys._
import org.apache.spark.io.CompressionCodec
import org.apache.spark.sql.avro.{AvroDeserializer, AvroOptions, AvroSerializer, SchemaConverters}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, StreamExecution}
Expand Down Expand Up @@ -664,8 +663,7 @@ object RocksDBStateStoreProvider {
new java.util.concurrent.Callable[RocksDBDataEncoder] {
override def call(): RocksDBDataEncoder = {
if (stateStoreEncoding == "avro") {
val avroEncoder = createAvroEnc(keyStateEncoderSpec, valueSchema)
new AvroStateEncoder(keyStateEncoderSpec, valueSchema, avroEncoder)
new AvroStateEncoder(keyStateEncoderSpec, valueSchema)
} else {
new UnsafeRowDataEncoder(keyStateEncoderSpec, valueSchema)
}
Expand All @@ -684,78 +682,6 @@ object RocksDBStateStoreProvider {
}
}

private def getAvroSerializer(schema: StructType): AvroSerializer = {
val avroType = SchemaConverters.toAvroType(schema)
new AvroSerializer(schema, avroType, nullable = false)
}

private def getAvroDeserializer(schema: StructType): AvroDeserializer = {
val avroType = SchemaConverters.toAvroType(schema)
val avroOptions = AvroOptions(Map.empty)
new AvroDeserializer(avroType, schema,
avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType,
avroOptions.stableIdPrefixForUnionType, avroOptions.recursiveFieldMaxDepth)
}

/**
* Creates an AvroEncoder that handles both key and value serialization/deserialization.
* This method sets up the complete encoding infrastructure needed for state store operations.
*
* The encoder handles different key encoding specifications:
* - NoPrefixKeyStateEncoderSpec: Simple key encoding without prefix
* - PrefixKeyScanStateEncoderSpec: Keys with prefix for efficient scanning
* - RangeKeyScanStateEncoderSpec: Keys with ordering requirements for range scans
*
* For prefix scan cases, it also creates separate encoders for the suffix portion of keys.
*
* @param keyStateEncoderSpec Specification for how to encode keys
* @param valueSchema Schema for the values to be encoded
* @return An AvroEncoder containing all necessary serializers and deserializers
*/
private def createAvroEnc(
keyStateEncoderSpec: KeyStateEncoderSpec,
valueSchema: StructType
): AvroEncoder = {
val valueSerializer = getAvroSerializer(valueSchema)
val valueDeserializer = getAvroDeserializer(valueSchema)

// Get key schema based on encoder spec type
val keySchema = keyStateEncoderSpec match {
case NoPrefixKeyStateEncoderSpec(schema) =>
schema
case PrefixKeyScanStateEncoderSpec(schema, numColsPrefixKey) =>
StructType(schema.take(numColsPrefixKey))
case RangeKeyScanStateEncoderSpec(schema, orderingOrdinals) =>
val remainingSchema = {
0.until(schema.length).diff(orderingOrdinals).map { ordinal =>
schema(ordinal)
}
}
StructType(remainingSchema)
}

// Handle suffix key schema for prefix scan case
val suffixKeySchema = keyStateEncoderSpec match {
case PrefixKeyScanStateEncoderSpec(schema, numColsPrefixKey) =>
Some(StructType(schema.drop(numColsPrefixKey)))
case _ =>
None
}

val keySerializer = getAvroSerializer(keySchema)
val keyDeserializer = getAvroDeserializer(keySchema)

// Create the AvroEncoder with all components
AvroEncoder(
keySerializer,
keyDeserializer,
valueSerializer,
valueDeserializer,
suffixKeySchema.map(getAvroSerializer),
suffixKeySchema.map(getAvroDeserializer)
)
}

// Native operation latencies report as latency in microseconds
// as SQLMetrics support millis. Convert the value to millis
val CUSTOM_METRIC_GET_TIME = StateStoreCustomTimingMetric(
Expand Down

0 comments on commit ff1e0f4

Please sign in to comment.