Skip to content

Commit

Permalink
stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
ericm-db committed Jan 2, 2025
1 parent e66598c commit 8fdce94
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProj
import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore}
import org.apache.spark.sql.types._

object ListStateMetricsImpl {
def getRowCounterCFName(stateName: String): String = "$rowCounter_" + stateName
}

/**
* Trait that provides helper methods to maintain metrics for a list state.
* For list state, we keep track of the count of entries in the list in a separate column family
Expand All @@ -43,18 +47,18 @@ trait ListStateMetricsImpl {

private val updatedCountRow = new GenericInternalRow(1)

private def getRowCounterCFName(stateName: String) = "$rowCounter_" + stateName

stateStore.createColFamilyIfAbsent(getRowCounterCFName(baseStateName), exprEncSchema,
counterCFValueSchema, NoPrefixKeyStateEncoderSpec(exprEncSchema), isInternal = true)
stateStore.createColFamilyIfAbsent(ListStateMetricsImpl.getRowCounterCFName(baseStateName),
exprEncSchema, counterCFValueSchema, NoPrefixKeyStateEncoderSpec(exprEncSchema),
isInternal = true)

/**
* Function to get the number of entries in the list state for a given grouping key
* @param encodedKey - encoded grouping key
* @return - number of entries in the list state
*/
def getEntryCount(encodedKey: UnsafeRow): Long = {
val countRow = stateStore.get(encodedKey, getRowCounterCFName(baseStateName))
val countRow = stateStore.get(encodedKey,
ListStateMetricsImpl.getRowCounterCFName(baseStateName))
if (countRow != null) {
countRow.getLong(0)
} else {
Expand All @@ -73,14 +77,15 @@ trait ListStateMetricsImpl {
updatedCountRow.setLong(0, updatedCount)
stateStore.put(encodedKey,
counterCFProjection(updatedCountRow.asInstanceOf[InternalRow]),
getRowCounterCFName(baseStateName))
ListStateMetricsImpl.getRowCounterCFName(baseStateName))
}

/**
* Function to remove the number of entries in the list state for a given grouping key
* @param encodedKey - encoded grouping key
*/
def removeEntryCount(encodedKey: UnsafeRow): Unit = {
stateStore.remove(encodedKey, getRowCounterCFName(baseStateName))
stateStore.remove(encodedKey,
ListStateMetricsImpl.getRowCounterCFName(baseStateName))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
*/
package org.apache.spark.sql.execution.streaming

import scala.collection.mutable

import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._
Expand Down Expand Up @@ -67,39 +69,108 @@ object StateStoreColumnFamilySchemaUtils {
stateName: String,
keyEncoder: ExpressionEncoder[Any],
valEncoder: Encoder[T],
hasTtl: Boolean): StateStoreColFamilySchema = {
StateStoreColFamilySchema(
hasTtl: Boolean): Map[String, StateStoreColFamilySchema] = {
val schemas = mutable.Map[String, StateStoreColFamilySchema]()

// Add main value state schema
schemas.put(stateName, StateStoreColFamilySchema(
stateName, 0,
keyEncoder.schema, 0,
getValueSchemaWithTTL(valEncoder.schema, hasTtl),
Some(NoPrefixKeyStateEncoderSpec(keyEncoder.schema)))
Some(NoPrefixKeyStateEncoderSpec(keyEncoder.schema))))

// Add TTL index if needed
if (hasTtl) {
val ttlIndexSchema = StateStoreColFamilySchema(
getTtlColFamilyName(stateName), 0,
getTTLRowKeySchema(keyEncoder.schema), 0,
StructType(Array(StructField("__empty__", NullType))),
Some(RangeKeyScanStateEncoderSpec(getTTLRowKeySchema(keyEncoder.schema), Seq(0))))
schemas.put(ttlIndexSchema.colFamilyName, ttlIndexSchema)
}

schemas.toMap
}

def getListStateSchema[T](
stateName: String,
keyEncoder: ExpressionEncoder[Any],
valEncoder: Encoder[T],
hasTtl: Boolean): StateStoreColFamilySchema = {
StateStoreColFamilySchema(
hasTtl: Boolean): Map[String, StateStoreColFamilySchema] = {
val schemas = mutable.Map[String, StateStoreColFamilySchema]()

// Add main list state schema
schemas.put(stateName, StateStoreColFamilySchema(
stateName, 0,
keyEncoder.schema, 0,
getValueSchemaWithTTL(valEncoder.schema, hasTtl),
Some(NoPrefixKeyStateEncoderSpec(keyEncoder.schema))))
// Add row counter schema
val counterSchema = StateStoreColFamilySchema(
ListStateMetricsImpl.getRowCounterCFName(stateName), 0,
keyEncoder.schema, 0,
StructType(Seq(StructField("count", LongType, nullable = false))),
Some(NoPrefixKeyStateEncoderSpec(keyEncoder.schema)))
schemas.put(counterSchema.colFamilyName, counterSchema)

// Add TTL-related schemas if needed
if (hasTtl) {
// TTL index
val ttlIndexSchema = StateStoreColFamilySchema(
getTtlColFamilyName(stateName), 0,
getTTLRowKeySchema(keyEncoder.schema), 0,
StructType(Array(StructField("__empty__", NullType))),
Some(RangeKeyScanStateEncoderSpec(getTTLRowKeySchema(keyEncoder.schema), Seq(0))))
schemas.put(ttlIndexSchema.colFamilyName, ttlIndexSchema)

// Min expiry index
val minIndexSchema = StateStoreColFamilySchema(
s"$$min_$stateName", 0,
keyEncoder.schema, 0,
getExpirationMsRowSchema(),
Some(NoPrefixKeyStateEncoderSpec(keyEncoder.schema)))
schemas.put(minIndexSchema.colFamilyName, minIndexSchema)

// Count index
val countSchema = StateStoreColFamilySchema(
s"$$count_$stateName", 0,
keyEncoder.schema, 0,
StructType(Seq(StructField("count", LongType, nullable = false))),
Some(NoPrefixKeyStateEncoderSpec(keyEncoder.schema)))
schemas.put(countSchema.colFamilyName, countSchema)
}

schemas.toMap
}

def getMapStateSchema[K, V](
stateName: String,
keyEncoder: ExpressionEncoder[Any],
userKeyEnc: Encoder[K],
valEncoder: Encoder[V],
hasTtl: Boolean): StateStoreColFamilySchema = {
hasTtl: Boolean): Map[String, StateStoreColFamilySchema] = {
val schemas = mutable.Map[String, StateStoreColFamilySchema]()
val compositeKeySchema = getCompositeKeySchema(keyEncoder.schema, userKeyEnc.schema)
StateStoreColFamilySchema(

// Add main map state schema
schemas.put(stateName, StateStoreColFamilySchema(
stateName, 0,
compositeKeySchema, 0,
getValueSchemaWithTTL(valEncoder.schema, hasTtl),
Some(PrefixKeyScanStateEncoderSpec(compositeKeySchema, 1)),
Some(userKeyEnc.schema))
Some(userKeyEnc.schema)))

// Add TTL index if needed
if (hasTtl) {
val ttlIndexSchema = StateStoreColFamilySchema(
getTtlColFamilyName(stateName), 0,
getTTLRowKeySchema(compositeKeySchema), 0,
StructType(Array(StructField("__empty__", NullType))),
Some(RangeKeyScanStateEncoderSpec(getTTLRowKeySchema(compositeKeySchema), Seq(0))))
schemas.put(ttlIndexSchema.colFamilyName, ttlIndexSchema)
}

schemas.toMap
}

def getTimerStateSchema(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ class StatefulProcessorHandleImpl(
* the StatefulProcessor is initialized.
*/
class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: ExpressionEncoder[Any])
extends StatefulProcessorHandleImplBase(timeMode, keyExprEnc) {
extends StatefulProcessorHandleImplBase(timeMode, keyExprEnc) with Logging {

// Because this is only happening on the driver side, there is only
// one task modifying and accessing these maps at a time
Expand Down Expand Up @@ -410,7 +410,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
val colFamilySchema = StateStoreColumnFamilySchemaUtils.
getValueStateSchema(stateName, keyExprEnc, stateEncoder, ttlEnabled)
checkIfDuplicateVariableDefined(stateName)
columnFamilySchemas.put(stateName, colFamilySchema)
columnFamilySchemas ++= colFamilySchema
val stateVariableInfo = TransformWithStateVariableUtils.
getValueState(stateName, ttlEnabled = ttlEnabled)
stateVariableInfos.put(stateName, stateVariableInfo)
Expand Down Expand Up @@ -444,9 +444,10 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
val colFamilySchema = StateStoreColumnFamilySchemaUtils.
getListStateSchema(stateName, keyExprEnc, stateEncoder, ttlEnabled)
checkIfDuplicateVariableDefined(stateName)
columnFamilySchemas.put(stateName, colFamilySchema)
columnFamilySchemas ++= colFamilySchema
val stateVariableInfo = TransformWithStateVariableUtils.
getListState(stateName, ttlEnabled = ttlEnabled)
logError(s"### colFamilySchema: $colFamilySchema")
stateVariableInfos.put(stateName, stateVariableInfo)
addTTLSchemas(
columnFamilySchemas,
Expand Down Expand Up @@ -480,16 +481,10 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
val valEncoder = encoderFor[V]
val colFamilySchema = StateStoreColumnFamilySchemaUtils.
getMapStateSchema(stateName, keyExprEnc, userKeyEnc, valEncoder, ttlEnabled)
columnFamilySchemas.put(stateName, colFamilySchema)
columnFamilySchemas ++= colFamilySchema
val stateVariableInfo = TransformWithStateVariableUtils.
getMapState(stateName, ttlEnabled = ttlEnabled)
stateVariableInfos.put(stateName, stateVariableInfo)
addTTLSchemas(
columnFamilySchemas,
stateVariableInfo,
stateName,
keyExprEnc.schema
)
null.asInstanceOf[MapState[K, V]]
}

Expand Down

0 comments on commit 8fdce94

Please sign in to comment.