Commit 85aa8da

feedback

ericm-db committed Dec 13, 2024
1 parent ca1353c commit 85aa8da
Showing 4 changed files with 187 additions and 81 deletions.
@@ -60,18 +60,94 @@ sealed trait RocksDBValueStateEncoder {
* by the callers. The metadata in each row does not need to be written as Avro or UnsafeRow,
* but the actual data provided by the caller does.
*/
/** Interface for encoding and decoding state store data between UnsafeRow and raw bytes.
*
* @note All encode methods expect non-null input rows. Handling of null values is left to the
* implementing classes.
*/
trait DataEncoder {
/** Encodes a complete key row into bytes. Used as the primary key for state lookups.
*
* @param row An UnsafeRow containing all key columns as defined in the keySchema
* @return Serialized byte array representation of the key
*/
def encodeKey(row: UnsafeRow): Array[Byte]

/** Encodes the non-prefix portion of a key row. Used with prefix scan and
* range scan state lookups where the key is split into prefix and remaining portions.
*
* For prefix scans: Encodes columns after the prefix columns
* For range scans: Encodes columns not included in the ordering columns
*
* @param row An UnsafeRow containing only the remaining key columns
* @return Serialized byte array of the remaining key portion
* @throws UnsupportedOperationException if called on an encoder that doesn't support split keys
*/
def encodeRemainingKey(row: UnsafeRow): Array[Byte]

/** Encodes key columns used for range scanning, ensuring proper sort order in RocksDB.
*
* This method handles special encoding for numeric types to maintain correct sort order:
* - Adds sign byte markers for numeric types
* - Flips bits for negative floating point values
* - Preserves null ordering
*
* @param row An UnsafeRow containing the columns needed for range scan
* (specified by orderingOrdinals)
* @return Serialized bytes that will maintain correct sort order in RocksDB
* @throws UnsupportedOperationException if called on an encoder that doesn't support range scans
*/
def encodePrefixKeyForRangeScan(row: UnsafeRow): Array[Byte]

/** Encodes a value row into bytes.
*
* @param row An UnsafeRow containing the value columns as defined in the valueSchema
* @return Serialized byte array representation of the value
*/
def encodeValue(row: UnsafeRow): Array[Byte]

/** Decodes a complete key from its serialized byte form.
*
* For NoPrefixKeyStateEncoder: Decodes the entire key
* For PrefixKeyScanStateEncoder: Decodes only the prefix portion
*
* @param bytes Serialized byte array containing the encoded key
* @return UnsafeRow containing the decoded key columns
* @throws UnsupportedOperationException for unsupported encoder types
*/
def decodeKey(bytes: Array[Byte]): UnsafeRow

/** Decodes the remaining portion of a split key from its serialized form.
*
* For PrefixKeyScanStateEncoder: Decodes columns after the prefix
* For RangeKeyScanStateEncoder: Decodes non-ordering columns
*
* @param bytes Serialized byte array containing the encoded remaining key portion
* @return UnsafeRow containing the decoded remaining key columns
* @throws UnsupportedOperationException if called on an encoder that doesn't support split keys
*/
def decodeRemainingKey(bytes: Array[Byte]): UnsafeRow

/** Decodes range scan key bytes back into an UnsafeRow, reversing the order-preserving encoding.
*
* This method reverses the special encoding done by encodePrefixKeyForRangeScan:
* - Interprets sign byte markers
* - Reverses bit flipping for negative floating point values
* - Handles null values
*
* @param bytes Serialized byte array containing the encoded range scan key
* @return UnsafeRow containing the decoded range scan columns
* @throws UnsupportedOperationException if called on an encoder that doesn't support range scans
*/
def decodePrefixKeyForRangeScan(bytes: Array[Byte]): UnsafeRow

/** Decodes a value from its serialized byte form.
*
* @param bytes Serialized byte array containing the encoded value
* @return UnsafeRow containing the decoded value columns
*/
def decodeValue(bytes: Array[Byte]): UnsafeRow
}
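
Note for reviewers: the ordering guarantees documented on encodePrefixKeyForRangeScan and
decodePrefixKeyForRangeScan boil down to a standard order-preserving byte encoding. A minimal
standalone sketch of the idea (illustrative names only, not the implementation in this file):

object RangeScanEncodingSketch {
  // Longs: flip the sign bit and write big-endian, so RocksDB's unsigned
  // lexicographic byte comparison matches signed numeric order.
  def encodeLong(v: Long): Array[Byte] =
    java.nio.ByteBuffer.allocate(8).putLong(v ^ Long.MinValue).array()

  // Doubles: flip all bits for negatives (reversing their order) and only
  // the sign bit otherwise, so -Inf sorts first and +Inf sorts last.
  def encodeDouble(d: Double): Array[Byte] = {
    val bits = java.lang.Double.doubleToLongBits(d)
    val adjusted = if (bits < 0) ~bits else bits ^ Long.MinValue
    java.nio.ByteBuffer.allocate(8).putLong(adjusted).array()
  }
}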

abstract class RocksDBDataEncoder(
keyStateEncoderSpec: KeyStateEncoderSpec,
valueSchema: StructType) extends DataEncoder {
@@ -789,44 +865,58 @@ abstract class RocksDBKeyStateEncoderBase(
}
}

/** Factory object for creating state encoders used by the RocksDB state store.
*
* The encoders created by this object handle serialization and deserialization of state data,
* supporting both key and value encoding with various access patterns
* (e.g., prefix scan, range scan).
*/
object RocksDBStateEncoder extends Logging {

/** Creates a key encoder based on the specified encoding strategy and configuration.
*
* @param dataEncoder The underlying encoder that handles the actual data encoding/decoding
* @param keyStateEncoderSpec Specification defining the key encoding strategy
* (no prefix, prefix scan, or range scan)
* @param useColumnFamilies Whether to use RocksDB column families for storage
* @param virtualColFamilyId Optional column family identifier when column families are enabled
* @return A configured RocksDBKeyStateEncoder instance
*/
def getKeyEncoder(
dataEncoder: RocksDBDataEncoder,
keyStateEncoderSpec: KeyStateEncoderSpec,
useColumnFamilies: Boolean,
virtualColFamilyId: Option[Short] = None,
avroEnc: Option[AvroEncoder] = None): RocksDBKeyStateEncoder = {
// Return the key state encoder based on the requested type
keyStateEncoderSpec match {
case NoPrefixKeyStateEncoderSpec(keySchema) =>
new NoPrefixKeyStateEncoder(dataEncoder, keySchema, useColumnFamilies, virtualColFamilyId)

case PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey) =>
new PrefixKeyScanStateEncoder(dataEncoder, keySchema, numColsPrefixKey,
useColumnFamilies, virtualColFamilyId)

case RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) =>
new RangeKeyScanStateEncoder(dataEncoder, keySchema, orderingOrdinals,
useColumnFamilies, virtualColFamilyId)

case _ =>
throw new IllegalArgumentException(s"Unsupported key state encoder spec: " +
s"$keyStateEncoderSpec")
}
virtualColFamilyId: Option[Short] = None): RocksDBKeyStateEncoder = {
keyStateEncoderSpec.toEncoder(dataEncoder, useColumnFamilies, virtualColFamilyId)
}

/** Creates a value encoder that supports either single or multiple values per key.
*
* @param dataEncoder The underlying encoder that handles the actual data encoding/decoding
* @param valueSchema Schema defining the structure of values to be encoded
* @param useMultipleValuesPerKey If true, creates an encoder that can handle multiple values
* per key; if false, creates an encoder for single values
* @return A configured RocksDBValueStateEncoder instance
*/
def getValueEncoder(
dataEncoder: RocksDBDataEncoder,
valueSchema: StructType,
useMultipleValuesPerKey: Boolean,
avroEnc: Option[AvroEncoder] = None): RocksDBValueStateEncoder = {
useMultipleValuesPerKey: Boolean): RocksDBValueStateEncoder = {
if (useMultipleValuesPerKey) {
new MultiValuedStateEncoder(dataEncoder, valueSchema)
} else {
new SingleValueStateEncoder(dataEncoder, valueSchema)
}
}
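
A hypothetical call site for the two factories above; the schemas, spec, and dataEncoder value
below are placeholders for illustration and do not appear in this commit:

import org.apache.spark.sql.types.{LongType, StringType, StructType}

val keySchema = new StructType().add("user", StringType)
val valueSchema = new StructType().add("count", LongType)
val spec = NoPrefixKeyStateEncoderSpec(keySchema)
// dataEncoder would come from RocksDBStateStoreProvider.getDataEncoder
val keyEncoder = RocksDBStateEncoder.getKeyEncoder(
  dataEncoder, spec, useColumnFamilies = true, Some(0.toShort))
val valueEncoder = RocksDBStateEncoder.getValueEncoder(
  dataEncoder, valueSchema, useMultipleValuesPerKey = false)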

/** Encodes a virtual column family ID into a byte array suitable for RocksDB.
*
* This method creates a fixed-size byte array prefixed with the virtual column family ID,
* which is used to partition data within RocksDB.
*
* @param virtualColFamilyId The column family identifier to encode
* @return A byte array containing the encoded column family ID
*/
def getColumnFamilyIdBytes(virtualColFamilyId: Short): Array[Byte] = {
val encodedBytes = new Array[Byte](VIRTUAL_COL_FAMILY_PREFIX_BYTES)
Platform.putShort(encodedBytes, Platform.BYTE_ARRAY_OFFSET, virtualColFamilyId)
@@ -871,18 +961,6 @@ class PrefixKeyScanStateEncoder(
UnsafeProjection.create(refs)
}

// Prefix Key schema and projection definitions used by the Avro Serializers
// and Deserializers
private val prefixKeySchema = StructType(keySchema.take(numColsPrefixKey))
private lazy val prefixKeyAvroType = SchemaConverters.toAvroType(prefixKeySchema)
private val prefixKeyProj = UnsafeProjection.create(prefixKeySchema)

// Remaining Key schema and projection definitions used by the Avro Serializers
// and Deserializers
private val remainingKeySchema = StructType(keySchema.drop(numColsPrefixKey))
private lazy val remainingKeyAvroType = SchemaConverters.toAvroType(remainingKeySchema)
private val remainingKeyProj = UnsafeProjection.create(remainingKeySchema)

// This is quite simple to do - just bind sequentially, as we don't change the order.
private val restoreKeyProjection: UnsafeProjection = UnsafeProjection.create(keySchema)

@@ -1056,22 +1134,6 @@ class RangeKeyScanStateEncoder(
UnsafeProjection.create(refs)
}

private val rangeScanAvroSchema = StateStoreColumnFamilySchemaUtils.convertForRangeScan(
StructType(rangeScanKeyFieldsWithOrdinal.map(_._1).toArray))

private lazy val rangeScanAvroType = SchemaConverters.toAvroType(rangeScanAvroSchema)

private val rangeScanAvroProjection = UnsafeProjection.create(rangeScanAvroSchema)

// Existing remainder key schema stuff
private val remainingKeySchema = StructType(
0.to(keySchema.length - 1).diff(orderingOrdinals).map(keySchema(_))
)

private lazy val remainingKeyAvroType = SchemaConverters.toAvroType(remainingKeySchema)

private val remainingKeyAvroProjection = UnsafeProjection.create(remainingKeySchema)

// Reusable objects
private val joinedRowOnKey = new JoinedRow()

@@ -82,10 +82,18 @@ private[sql] class RocksDBStateStoreProvider
val dataEncoder = getDataEncoder(
stateStoreEncoding, dataEncoderCacheKey, keyStateEncoderSpec, valueSchema)

keyValueEncoderMap.putIfAbsent(colFamilyName,
(RocksDBStateEncoder.getKeyEncoder(dataEncoder, keyStateEncoderSpec, useColumnFamilies,
Some(newColFamilyId)), RocksDBStateEncoder.getValueEncoder(dataEncoder, valueSchema,
useMultipleValuesPerKey)))
val keyEncoder = RocksDBStateEncoder.getKeyEncoder(
dataEncoder,
keyStateEncoderSpec,
useColumnFamilies,
Some(newColFamilyId)
)
val valueEncoder = RocksDBStateEncoder.getValueEncoder(
dataEncoder,
valueSchema,
useMultipleValuesPerKey
)
keyValueEncoderMap.putIfAbsent(colFamilyName, (keyEncoder, valueEncoder))
}

override def get(key: UnsafeRow, colFamilyName: String): UnsafeRow = {
@@ -392,10 +400,18 @@ private[sql] class RocksDBStateStoreProvider
val dataEncoder = getDataEncoder(
stateStoreEncoding, dataEncoderCacheKey, keyStateEncoderSpec, valueSchema)

keyValueEncoderMap.putIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME,
(RocksDBStateEncoder.getKeyEncoder(dataEncoder, keyStateEncoderSpec,
useColumnFamilies, defaultColFamilyId),
RocksDBStateEncoder.getValueEncoder(dataEncoder, valueSchema, useMultipleValuesPerKey)))
val keyEncoder = RocksDBStateEncoder.getKeyEncoder(
dataEncoder,
keyStateEncoderSpec,
useColumnFamilies,
defaultColFamilyId
)
val valueEncoder = RocksDBStateEncoder.getValueEncoder(
dataEncoder,
valueSchema,
useMultipleValuesPerKey
)
keyValueEncoderMap.putIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME, (keyEncoder, valueEncoder))
}

override def stateStoreId: StateStoreId = stateStoreId_
@@ -642,28 +658,20 @@ object RocksDBStateStoreProvider {
encoderCacheKey: String,
keyStateEncoderSpec: KeyStateEncoderSpec,
valueSchema: StructType): RocksDBDataEncoder = {

stateStoreEncoding match {
case "avro" =>
RocksDBStateStoreProvider.dataEncoderCache.get(
encoderCacheKey,
new java.util.concurrent.Callable[AvroStateEncoder] {
override def call(): AvroStateEncoder = {
val avroEncoder = createAvroEnc(keyStateEncoderSpec, valueSchema)
new AvroStateEncoder(keyStateEncoderSpec, valueSchema, avroEncoder)
}
}
)
case "unsaferow" =>
RocksDBStateStoreProvider.dataEncoderCache.get(
encoderCacheKey,
new java.util.concurrent.Callable[UnsafeRowDataEncoder] {
override def call(): UnsafeRowDataEncoder = {
new UnsafeRowDataEncoder(keyStateEncoderSpec, valueSchema)
}
assert(Set("avro", "unsaferow").contains(stateStoreEncoding))
RocksDBStateStoreProvider.dataEncoderCache.get(
encoderCacheKey,
new java.util.concurrent.Callable[RocksDBDataEncoder] {
override def call(): RocksDBDataEncoder = {
if (stateStoreEncoding == "avro") {
val avroEncoder = createAvroEnc(keyStateEncoderSpec, valueSchema)
new AvroStateEncoder(keyStateEncoderSpec, valueSchema, avroEncoder)
} else {
new UnsafeRowDataEncoder(keyStateEncoderSpec, valueSchema)
}
)
}
}
}
)
}
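
The collapsed branches above also consolidate the cache interaction: both encodings now share a
single dataEncoderCache.get call, with the Callable choosing the concrete encoder. The
memoization pattern in isolation (a sketch assuming a Guava cache, which the get(key, Callable)
signature suggests; the size bound is illustrative):

import com.google.common.cache.{Cache, CacheBuilder}

val encoders: Cache[String, RocksDBDataEncoder] =
  CacheBuilder.newBuilder().maximumSize(1000).build[String, RocksDBDataEncoder]()

// Returns the cached encoder for `key`, building it at most once.
def cachedEncoder(key: String)(mk: => RocksDBDataEncoder): RocksDBDataEncoder =
  encoders.get(key, new java.util.concurrent.Callable[RocksDBDataEncoder] {
    override def call(): RocksDBDataEncoder = mk
  })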

private def getRunId(hadoopConf: Configuration): String = {
@@ -325,6 +325,18 @@ sealed trait KeyStateEncoderSpec {
def keySchema: StructType
def jsonValue: JValue
def json: String = compact(render(jsonValue))

/** Creates a RocksDBKeyStateEncoder for this specification.
*
* @param dataEncoder The encoder to handle the actual data encoding/decoding
* @param useColumnFamilies Whether to use RocksDB column families
* @param virtualColFamilyId Optional column family ID when column families are used
* @return A RocksDBKeyStateEncoder configured for this spec
*/
def toEncoder(
dataEncoder: RocksDBDataEncoder,
useColumnFamilies: Boolean,
virtualColFamilyId: Option[Short]): RocksDBKeyStateEncoder
}
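
Moving construction onto the sealed trait replaces the match (and its IllegalArgumentException
fallback) in RocksDBStateEncoder.getKeyEncoder with compiler-checked polymorphic dispatch: each
spec subtype below supplies its own toEncoder. Sketch of a call site (values illustrative):

val spec: KeyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
val keyEncoder = spec.toEncoder(dataEncoder, useColumnFamilies = false, virtualColFamilyId = None)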

object KeyStateEncoderSpec {
@@ -348,6 +360,14 @@ case class NoPrefixKeyStateEncoderSpec(keySchema: StructType) extends KeyStateEncoderSpec {
override def jsonValue: JValue = {
("keyStateEncoderType" -> JString("NoPrefixKeyStateEncoderSpec"))
}

override def toEncoder(
dataEncoder: RocksDBDataEncoder,
useColumnFamilies: Boolean,
virtualColFamilyId: Option[Short]): RocksDBKeyStateEncoder = {
new NoPrefixKeyStateEncoder(
dataEncoder, keySchema, useColumnFamilies, virtualColFamilyId)
}
}

case class PrefixKeyScanStateEncoderSpec(
@@ -356,6 +376,14 @@ case class PrefixKeyScanStateEncoderSpec(
if (numColsPrefixKey == 0 || numColsPrefixKey >= keySchema.length) {
throw StateStoreErrors.incorrectNumOrderingColsForPrefixScan(numColsPrefixKey.toString)
}

override def toEncoder(
dataEncoder: RocksDBDataEncoder,
useColumnFamilies: Boolean,
virtualColFamilyId: Option[Short]): RocksDBKeyStateEncoder = {
new PrefixKeyScanStateEncoder(
dataEncoder, keySchema, numColsPrefixKey, useColumnFamilies, virtualColFamilyId)
}

override def jsonValue: JValue = {
("keyStateEncoderType" -> JString("PrefixKeyScanStateEncoderSpec")) ~
@@ -371,6 +399,14 @@ case class RangeKeyScanStateEncoderSpec(
throw StateStoreErrors.incorrectNumOrderingColsForRangeScan(orderingOrdinals.length.toString)
}

override def toEncoder(
dataEncoder: RocksDBDataEncoder,
useColumnFamilies: Boolean,
virtualColFamilyId: Option[Short]): RocksDBKeyStateEncoder = {
new RangeKeyScanStateEncoder(
dataEncoder, keySchema, orderingOrdinals, useColumnFamilies, virtualColFamilyId)
}

override def jsonValue: JValue = {
("keyStateEncoderType" -> JString("RangeKeyScanStateEncoderSpec")) ~
("orderingOrdinals" -> orderingOrdinals.map(JInt(_)))
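
For reference, json on these specs renders a compact object; e.g. for a range-scan spec
(ordinals illustrative):

val spec = RangeKeyScanStateEncoderSpec(keySchema, Seq(0, 1))
spec.json
// {"keyStateEncoderType":"RangeKeyScanStateEncoderSpec","orderingOrdinals":[0,1]}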
@@ -21,7 +21,7 @@ import java.sql.Timestamp
import java.time.Duration

import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithEncodingTypes, RocksDBStateStoreProvider}
import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, AlsoTestWithEncodingTypes, RocksDBStateStoreProvider}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.util.StreamManualClock

@@ -41,7 +41,7 @@ case class OutputEvent(
* Test suite base for TransformWithState with TTL support.
*/
abstract class TransformWithStateTTLTest
extends StreamTest
extends StreamTest with AlsoTestWithChangelogCheckpointingEnabled
with AlsoTestWithEncodingTypes {
import testImplicits._
