diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
index 1d8a6d004ef97..22cd10bef45a0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
@@ -148,6 +148,7 @@ trait DataEncoder {
    */
   def decodeValue(bytes: Array[Byte]): UnsafeRow
 }
+
 abstract class RocksDBDataEncoder(
     keyStateEncoderSpec: KeyStateEncoderSpec,
     valueSchema: StructType) extends DataEncoder {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index 44751ddc45e60..fb0bf84d7aabc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -75,8 +75,12 @@ private[sql] class RocksDBStateStoreProvider
       isInternal: Boolean = false): Unit = {
     verifyColFamilyCreationOrDeletion("create_col_family", colFamilyName, isInternal)
     val newColFamilyId = rocksDB.createColFamilyIfAbsent(colFamilyName)
-    val dataEncoderCacheKey = s"${getRunId(hadoopConf)}_${stateStoreId.operatorId}_" +
-      s"${stateStoreId.partitionId}_${colFamilyName}"
+    val dataEncoderCacheKey = StateRowEncoderCacheKey(
+      queryRunId = getRunId(hadoopConf),
+      operatorId = stateStoreId.operatorId,
+      partitionId = stateStoreId.partitionId,
+      stateStoreName = stateStoreId.storeName,
+      colFamilyName = colFamilyName)
 
     val dataEncoder = getDataEncoder(
       stateStoreEncoding, dataEncoderCacheKey, keyStateEncoderSpec, valueSchema)
@@ -393,8 +397,12 @@ private[sql] class RocksDBStateStoreProvider
       defaultColFamilyId = Some(rocksDB.createColFamilyIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME))
     }
 
-    val dataEncoderCacheKey = s"${getRunId(hadoopConf)}_${stateStoreId.operatorId}_" +
-      s"${stateStoreId.partitionId}_${StateStore.DEFAULT_COL_FAMILY_NAME}"
+    val dataEncoderCacheKey = StateRowEncoderCacheKey(
+      queryRunId = getRunId(hadoopConf),
+      operatorId = stateStoreId.operatorId,
+      partitionId = stateStoreId.partitionId,
+      stateStoreName = stateStoreId.storeName,
+      colFamilyName = StateStore.DEFAULT_COL_FAMILY_NAME)
 
     val dataEncoder = getDataEncoder(
       stateStoreEncoding, dataEncoderCacheKey, keyStateEncoderSpec, valueSchema)
@@ -618,6 +626,15 @@ private[sql] class RocksDBStateStoreProvider
 
   }
 }
+
+case class StateRowEncoderCacheKey(
+    queryRunId: String,
+    operatorId: Long,
+    partitionId: Int,
+    stateStoreName: String,
+    colFamilyName: String
+)
+
 object RocksDBStateStoreProvider {
   // Version as a single byte that specifies the encoding of the row data in RocksDB
   val STATE_ENCODING_NUM_VERSION_BYTES = 1
@@ -628,7 +645,7 @@ object RocksDBStateStoreProvider {
   private val AVRO_ENCODER_LIFETIME_HOURS = 1L
 
   // Add the cache at companion object level so it persists across provider instances
-  private val dataEncoderCache: NonFateSharingCache[String, RocksDBDataEncoder] =
+  private val dataEncoderCache: NonFateSharingCache[StateRowEncoderCacheKey, RocksDBDataEncoder] =
     NonFateSharingCache(
       maximumSize = MAX_AVRO_ENCODERS_IN_CACHE,
       expireAfterAccessTime = AVRO_ENCODER_LIFETIME_HOURS,
@@ -654,7 +671,7 @@ object RocksDBStateStoreProvider {
    */
   def getDataEncoder(
       stateStoreEncoding: String,
-      encoderCacheKey: String,
+      encoderCacheKey: StateRowEncoderCacheKey,
       keyStateEncoderSpec: KeyStateEncoderSpec,
       valueSchema: StructType): RocksDBDataEncoder = {
     assert(Set("avro", "unsaferow").contains(stateStoreEncoding))
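
Aside (not part of the patch above): the change replaces the underscore-joined string cache key with a StateRowEncoderCacheKey case class and adds the store name as a field. Below is a minimal, hypothetical Scala sketch of why a structured key behaves better for the encoder cache. CacheKeyExample, oldStyleKey, and the sample store names are illustrative assumptions; only the shape of the case class is taken from the patch.

// Hypothetical, standalone illustration; not taken from the Spark sources.
object CacheKeyExample extends App {
  // Shape of the key introduced by the patch, re-declared so this sketch compiles on its own.
  case class StateRowEncoderCacheKey(
      queryRunId: String,
      operatorId: Long,
      partitionId: Int,
      stateStoreName: String,
      colFamilyName: String)

  // Two different state stores in the same operator and partition, using the same column family.
  val left = StateRowEncoderCacheKey("run-1", 0L, 0, "store-left", "default")
  val right = StateRowEncoderCacheKey("run-1", 0L, 0, "store-right", "default")

  // The old string key was built without the store name, so both stores
  // would have mapped to the same cache entry.
  def oldStyleKey(k: StateRowEncoderCacheKey): String =
    s"${k.queryRunId}_${k.operatorId}_${k.partitionId}_${k.colFamilyName}"

  assert(oldStyleKey(left) == oldStyleKey(right)) // collision under the old scheme
  assert(left != right)                           // case-class equality keeps them distinct
}

A case class also supplies field-wise equals and hashCode automatically, so lookups in the NonFateSharingCache need no manual string formatting.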