Commit f22bcbf: tests pass
ericm-db committed Nov 22, 2024
1 parent 3b753d1 commit f22bcbf
Showing 1 changed file with 88 additions and 71 deletions.
@@ -78,84 +78,18 @@ private[sql] class RocksDBStateStoreProvider
    verifyColFamilyCreationOrDeletion("create_col_family", colFamilyName, isInternal)
    val newColFamilyId = rocksDB.createColFamilyIfAbsent(colFamilyName)
    // Create cache key using store ID to avoid collisions
-   val avroEncCacheKey = s"${getRunId}_${stateStoreId.operatorId}_" +
+   val avroEncCacheKey = s"${getRunId(hadoopConf)}_${stateStoreId.operatorId}_" +
      s"${stateStoreId.partitionId}_$colFamilyName"

-   def avroEnc = stateStoreEncoding match {
-     case "avro" => Some(
-       RocksDBStateStoreProvider.avroEncoderMap.get(
-         avroEncCacheKey,
-         new java.util.concurrent.Callable[AvroEncoder] {
-           override def call(): AvroEncoder = getAvroEnc(keyStateEncoderSpec, valueSchema)
-         }
-       )
-     )
-     case "unsaferow" => None
-   }
+   val avroEnc = getAvroEnc(
+     stateStoreEncoding, avroEncCacheKey, keyStateEncoderSpec, valueSchema)

    keyValueEncoderMap.putIfAbsent(colFamilyName,
      (RocksDBStateEncoder.getKeyEncoder(keyStateEncoderSpec, useColumnFamilies,
        Some(newColFamilyId), avroEnc), RocksDBStateEncoder.getValueEncoder(valueSchema,
        useMultipleValuesPerKey, avroEnc)))
  }

-  private def getRunId: String = {
-    val runId = hadoopConf.get(StreamExecution.RUN_ID_KEY)
-    if (runId != null) {
-      runId
-    } else {
-      assert(Utils.isTesting, "Failed to find query id/batch Id in task context")
-      UUID.randomUUID().toString
-    }
-  }
-
-  private def getAvroSerializer(schema: StructType): AvroSerializer = {
-    val avroType = SchemaConverters.toAvroType(schema)
-    new AvroSerializer(schema, avroType, nullable = false)
-  }
-
-  private def getAvroDeserializer(schema: StructType): AvroDeserializer = {
-    val avroType = SchemaConverters.toAvroType(schema)
-    val avroOptions = AvroOptions(Map.empty)
-    new AvroDeserializer(avroType, schema,
-      avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType,
-      avroOptions.stableIdPrefixForUnionType, avroOptions.recursiveFieldMaxDepth)
-  }
-
-  private def getAvroEnc(
-      keyStateEncoderSpec: KeyStateEncoderSpec,
-      valueSchema: StructType): AvroEncoder = {
-    val valueSerializer = getAvroSerializer(valueSchema)
-    val valueDeserializer = getAvroDeserializer(valueSchema)
-    val keySchema = keyStateEncoderSpec match {
-      case NoPrefixKeyStateEncoderSpec(schema) =>
-        schema
-      case PrefixKeyScanStateEncoderSpec(schema, numColsPrefixKey) =>
-        StructType(schema.take(numColsPrefixKey))
-      case RangeKeyScanStateEncoderSpec(schema, orderingOrdinals) =>
-        val remainingSchema = {
-          0.until(schema.length).diff(orderingOrdinals).map { ordinal =>
-            schema(ordinal)
-          }
-        }
-        StructType(remainingSchema)
-    }
-    val suffixKeySchema = keyStateEncoderSpec match {
-      case PrefixKeyScanStateEncoderSpec(schema, numColsPrefixKey) =>
-        Some(StructType(schema.drop(numColsPrefixKey)))
-      case _ => None
-    }
-    AvroEncoder(
-      getAvroSerializer(keySchema),
-      getAvroDeserializer(keySchema),
-      valueSerializer,
-      valueDeserializer,
-      suffixKeySchema.map(getAvroSerializer),
-      suffixKeySchema.map(getAvroDeserializer)
-    )
-  }
-
  override def get(key: UnsafeRow, colFamilyName: String): UnsafeRow = {
    verify(key != null, "Key cannot be null")
    verifyColFamilyOperations("get", colFamilyName)
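Note on the cache key above: it concatenates the query run ID, operator ID, partition ID, and column family name, so a cached encoder is reused across batches within one run but never shared between stores or across query restarts. A minimal sketch of that composition (all names here are illustrative, not the provider's actual fields):

    object CacheKeySketch {
      // Mirrors the s"${runId}_${operatorId}_${partitionId}_$colFamilyName" pattern.
      def cacheKey(runId: String, operatorId: Long, partitionId: Int, colFamily: String): String =
        s"${runId}_${operatorId}_${partitionId}_$colFamily"

      def main(args: Array[String]): Unit = {
        val a = cacheKey("run-1", operatorId = 0L, partitionId = 7, colFamily = "default")
        val b = cacheKey("run-1", operatorId = 0L, partitionId = 8, colFamily = "default")
        assert(a != b)  // different partitions never collide
        println(a)      // prints: run-1_0_7_default
      }
    }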
@@ -454,10 +388,17 @@ private[sql] class RocksDBStateStoreProvider
      defaultColFamilyId = Some(rocksDB.createColFamilyIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME))
    }

+   val colFamilyName = StateStore.DEFAULT_COL_FAMILY_NAME
+   // Create cache key using store ID to avoid collisions
+   val avroEncCacheKey = s"${getRunId(hadoopConf)}_${stateStoreId.operatorId}_" +
+     s"${stateStoreId.partitionId}_$colFamilyName"
+   val avroEnc = getAvroEnc(
+     stateStoreEncoding, avroEncCacheKey, keyStateEncoderSpec, valueSchema)
+
    keyValueEncoderMap.putIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME,
      (RocksDBStateEncoder.getKeyEncoder(keyStateEncoderSpec,
-       useColumnFamilies, defaultColFamilyId),
-      RocksDBStateEncoder.getValueEncoder(valueSchema, useMultipleValuesPerKey)))
+       useColumnFamilies, defaultColFamilyId, avroEnc),
+      RocksDBStateEncoder.getValueEncoder(valueSchema, useMultipleValuesPerKey, avroEnc)))
  }

  override def stateStoreId: StateStoreId = stateStoreId_
Expand Down Expand Up @@ -682,6 +623,82 @@ object RocksDBStateStoreProvider {
    new NonFateSharingCache(guavaCache)
  }

+  def getAvroEnc(
+      stateStoreEncoding: String,
+      avroEncCacheKey: String,
+      keyStateEncoderSpec: KeyStateEncoderSpec,
+      valueSchema: StructType): Option[AvroEncoder] = {
+    stateStoreEncoding match {
+      case "avro" => Some(
+        RocksDBStateStoreProvider.avroEncoderMap.get(
+          avroEncCacheKey,
+          new java.util.concurrent.Callable[AvroEncoder] {
+            override def call(): AvroEncoder = createAvroEnc(keyStateEncoderSpec, valueSchema)
+          }
+        )
+      )
+      case "unsaferow" => None
+    }
+  }
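getAvroEnc goes through the shared avroEncoderMap, which wraps a Guava cache (see the NonFateSharingCache construction above), so the serializer and deserializer construction in createAvroEnc runs at most once per store per run. A standalone sketch of the same get-or-compute pattern using a plain Guava Cache, with a String standing in for AvroEncoder:

    import com.google.common.cache.{Cache, CacheBuilder}
    import java.util.concurrent.Callable

    object EncoderCacheSketch {
      // Bounded cache keyed the same way as avroEncoderMap.
      private val cache: Cache[String, String] =
        CacheBuilder.newBuilder().maximumSize(1000).build[String, String]()

      def getOrCreate(key: String): String =
        cache.get(key, new Callable[String] {
          // Runs at most once per key while the entry is cached; concurrent
          // callers for the same key block and then share the one result.
          override def call(): String = s"encoder-for-$key"
        })

      def main(args: Array[String]): Unit = {
        val k = "run-1_0_7_default"
        assert(getOrCreate(k) eq getOrCreate(k))  // second call hits the cache
      }
    }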

+  private def getRunId(hadoopConf: Configuration): String = {
+    val runId = hadoopConf.get(StreamExecution.RUN_ID_KEY)
+    if (runId != null) {
+      runId
+    } else {
+      assert(Utils.isTesting, "Failed to find query id/batch Id in task context")
+      UUID.randomUUID().toString
+    }
+  }
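getRunId reads the run ID that was stashed in the Hadoop configuration; Configuration.get returns null for an absent key, which is exactly what the test-only fallback relies on. A self-contained sketch of that lookup and fallback (the key string is illustrative, not the actual value of StreamExecution.RUN_ID_KEY):

    import java.util.UUID
    import org.apache.hadoop.conf.Configuration

    object RunIdSketch {
      val RunIdKey = "example.streaming.runId"  // illustrative, not Spark's constant

      def getRunId(conf: Configuration, isTesting: Boolean): String = {
        val runId = conf.get(RunIdKey)  // null when the key was never set
        if (runId != null) runId
        else {
          assert(isTesting, "run id must be set outside of tests")
          UUID.randomUUID().toString    // throwaway id for tests
        }
      }

      def main(args: Array[String]): Unit = {
        val conf = new Configuration(false)  // empty config, skip default resources
        conf.set(RunIdKey, "run-42")
        println(getRunId(conf, isTesting = false))  // prints: run-42
      }
    }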

+  private def getAvroSerializer(schema: StructType): AvroSerializer = {
+    val avroType = SchemaConverters.toAvroType(schema)
+    new AvroSerializer(schema, avroType, nullable = false)
+  }
+
+  private def getAvroDeserializer(schema: StructType): AvroDeserializer = {
+    val avroType = SchemaConverters.toAvroType(schema)
+    val avroOptions = AvroOptions(Map.empty)
+    new AvroDeserializer(avroType, schema,
+      avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType,
+      avroOptions.stableIdPrefixForUnionType, avroOptions.recursiveFieldMaxDepth)
+  }

+  private def createAvroEnc(
+      keyStateEncoderSpec: KeyStateEncoderSpec,
+      valueSchema: StructType): AvroEncoder = {
+    val valueSerializer = getAvroSerializer(valueSchema)
+    val valueDeserializer = getAvroDeserializer(valueSchema)
+    val keySchema = keyStateEncoderSpec match {
+      case NoPrefixKeyStateEncoderSpec(schema) =>
+        schema
+      case PrefixKeyScanStateEncoderSpec(schema, numColsPrefixKey) =>
+        StructType(schema.take(numColsPrefixKey))
+      case RangeKeyScanStateEncoderSpec(schema, orderingOrdinals) =>
+        val remainingSchema = {
+          0.until(schema.length).diff(orderingOrdinals).map { ordinal =>
+            schema(ordinal)
+          }
+        }
+        StructType(remainingSchema)
+    }
+    val suffixKeySchema = keyStateEncoderSpec match {
+      case PrefixKeyScanStateEncoderSpec(schema, numColsPrefixKey) =>
+        Some(StructType(schema.drop(numColsPrefixKey)))
+      case _ => None
+    }
+    AvroEncoder(
+      getAvroSerializer(keySchema),
+      getAvroDeserializer(keySchema),
+      valueSerializer,
+      valueDeserializer,
+      suffixKeySchema.map(getAvroSerializer),
+      suffixKeySchema.map(getAvroDeserializer)
+    )
+  }
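createAvroEnc derives the Avro key schema from the encoder spec: no-prefix keys encode the full key schema, prefix-scan keys keep only the leading numColsPrefixKey columns (the remainder becomes the optional suffix serializer/deserializer pair), and range-scan keys keep exactly the columns whose ordinals are not ordering ordinals. A self-contained illustration of that column selection using public Spark SQL types (field names invented for the example):

    import org.apache.spark.sql.types._

    object KeySchemaSketch {
      def main(args: Array[String]): Unit = {
        val schema = StructType(Seq(
          StructField("window", TimestampType),
          StructField("user", StringType),
          StructField("session", LongType)))

        // PrefixKeyScanStateEncoderSpec(schema, numColsPrefixKey = 1)
        val prefix = StructType(schema.take(1))  // window
        val suffix = StructType(schema.drop(1))  // user, session

        // RangeKeyScanStateEncoderSpec(schema, orderingOrdinals = Seq(0)):
        // keep every column whose ordinal is not an ordering ordinal.
        val remaining = StructType(
          (0 until schema.length).diff(Seq(0)).map(schema(_)))  // user, session

        println(prefix.fieldNames.mkString(","))     // window
        println(suffix.fieldNames.mkString(","))     // user,session
        println(remaining.fieldNames.mkString(","))  // user,session
      }
    }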

  // Native operation latencies are reported in microseconds, while SQLMetrics
  // supports millis, so convert the value to millis.
  val CUSTOM_METRIC_GET_TIME = StateStoreCustomTimingMetric(