[SPARK-49467][SS] Add support for state data source reader and list state

### What changes were proposed in this pull request?
Add support for state data source reader and list state

### Why are the changes needed?
This change adds support for reading state written through the list state variable type, which is used within stateful processors running under the `transformWithState` operator.
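
For reference, a minimal sketch of a stateful processor that writes list state via `transformWithState`, adapted from the `SessionGroupsStatefulProcessor` added to the test suite in this PR (illustrative only, not a prescribed usage):

```
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.streaming.{ExpiredTimerInfo, ListState, OutputMode,
  StatefulProcessor, TimeMode, TimerValues}

// Tracks the groups seen for each session key using a list state variable.
class SessionGroupsStatefulProcessor
  extends StatefulProcessor[String, (String, String), String] {
  @transient private var _groupsList: ListState[String] = _

  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
    // Register the list state variable "groupsList" on the stateful processor handle.
    _groupsList = getHandle.getListState("groupsList", Encoders.STRING)
  }

  override def handleInputRows(
      key: String,
      inputRows: Iterator[(String, String)],
      timerValues: TimerValues,
      expiredTimerInfo: ExpiredTimerInfo): Iterator[String] = {
    // Append the group id of every incoming (session, group) pair to the session's list.
    inputRows.foreach { inputRow => _groupsList.appendValue(inputRow._2) }
    Iterator.empty
  }
}
```

The state written under the "groupsList" variable is what the reader query below targets through `StateSourceOptions.STATE_VAR_NAME`.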

### Does this PR introduce _any_ user-facing change?
Yes

Users can read the list state and `explode` its entries using the following query:
```
val stateReaderDf = spark.read
  .format("statestore")
  .option(StateSourceOptions.PATH, <checkpoint_location>)
  .option(StateSourceOptions.STATE_VAR_NAME, <state_var_name>)
  .load()

val listStateDf = stateReaderDf
  .selectExpr(
    "key.value AS groupingKey",
    "list_value AS valueList",
    "partition_id")
  .select($"groupingKey",
    explode($"valueList").as("valueList"))
```
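
For TTL-enabled list state, each element of `list_value` nests the user value under `value` alongside a `ttlExpirationMs` field. A sketch of reading those fields, mirroring the TTL test added in this PR (the `groupsListWithTTL` variable name comes from that test and is illustrative):

```
val ttlListStateDf = spark.read
  .format("statestore")
  .option(StateSourceOptions.PATH, <checkpoint_location>)
  .option(StateSourceOptions.STATE_VAR_NAME, "groupsListWithTTL")
  .load()
  .selectExpr("key.value AS groupingKey", "list_value AS valueList", "partition_id")
  .select($"groupingKey", explode($"valueList").as("valueList"))

// Each exploded element exposes the nested user value and its TTL expiration timestamp.
val valuesDf = ttlListStateDf.selectExpr(
  "groupingKey",
  "valueList.value.value AS groupId",
  "valueList.ttlExpirationMs AS expirationMs")
```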

### How was this patch tested?
Added unit tests

```
[info] Run completed in 1 minute, 3 seconds.
[info] Total number of tests run: 8
[info] Suites: completed 1, aborted 0
[info] Tests: succeeded 8, failed 0, canceled 0, ignored 0, pending 0
[info] All tests passed.
```

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #47978 from anishshri-db/task/SPARK-49467.

Authored-by: Anish Shrigondekar <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
anishshri-db authored and HeartSaVioR committed Sep 6, 2024
1 parent fdeb288 commit 5d1d44f
Showing 3 changed files with 216 additions and 46 deletions.
@@ -19,12 +19,13 @@ package org.apache.spark.sql.execution.datasources.v2.state
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
import org.apache.spark.sql.execution.streaming.{StateVariableType, TransformWithStateVariableInfo}
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.execution.streaming.state.RecordType.{getRecordTypeAsString, RecordType}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{NullType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.{NextIterator, SerializableConfiguration}

@@ -68,10 +69,20 @@ abstract class StatePartitionReaderBase(
stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema])
extends PartitionReader[InternalRow] with Logging {
// Used primarily as a placeholder for the value schema in the context of
// state variables used within the transformWithState operator.
private val schemaForValueRow: StructType =
StructType(Array(StructField("__dummy__", NullType)))

protected val keySchema = SchemaUtil.getSchemaAsDataType(
schema, "key").asInstanceOf[StructType]
protected val valueSchema = SchemaUtil.getSchemaAsDataType(
schema, "value").asInstanceOf[StructType]

protected val valueSchema = if (stateVariableInfoOpt.isDefined) {
schemaForValueRow
} else {
SchemaUtil.getSchemaAsDataType(
schema, "value").asInstanceOf[StructType]
}

protected lazy val provider: StateStoreProvider = {
val stateStoreId = StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString,
@@ -84,10 +95,17 @@
false
}

val useMultipleValuesPerKey = if (stateVariableInfoOpt.isDefined &&
stateVariableInfoOpt.get.stateVariableType == StateVariableType.ListState) {
true
} else {
false
}

val provider = StateStoreProvider.createAndInit(
stateStoreProviderId, keySchema, valueSchema, keyStateEncoderSpec,
useColumnFamilies = useColFamilies, storeConf, hadoopConf.value,
useMultipleValuesPerKey = false)
useMultipleValuesPerKey = useMultipleValuesPerKey)

if (useColFamilies) {
val store = provider.getStore(partition.sourceOptions.batchId + 1)
@@ -99,7 +117,7 @@
stateStoreColFamilySchema.keySchema,
stateStoreColFamilySchema.valueSchema,
stateStoreColFamilySchema.keyStateEncoderSpec.get,
useMultipleValuesPerKey = false)
useMultipleValuesPerKey = useMultipleValuesPerKey)
}
provider
}
@@ -166,16 +184,22 @@ class StatePartitionReader(
stateVariableInfoOpt match {
case Some(stateVarInfo) =>
val stateVarType = stateVarInfo.stateVariableType
val hasTTLEnabled = stateVarInfo.ttlEnabled

stateVarType match {
case StateVariableType.ValueState =>
if (hasTTLEnabled) {
SchemaUtil.unifyStateRowPairWithTTL((pair.key, pair.value), valueSchema,
partition.partition)
} else {
SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition)
SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition)

case StateVariableType.ListState =>
val key = pair.key
val result = store.valuesIterator(key, stateVarName)
var unsafeRowArr: Seq[UnsafeRow] = Seq.empty
result.foreach { entry =>
unsafeRowArr = unsafeRowArr :+ entry.copy()
}
// convert the list of values to array type
val arrData = new GenericArrayData(unsafeRowArr.toArray)
SchemaUtil.unifyStateRowPairWithMultipleValues((pair.key, arrData),
partition.partition)

case _ =>
throw new IllegalStateException(
@@ -19,10 +19,11 @@ package org.apache.spark.sql.execution.datasources.v2.state.utils
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.execution.datasources.v2.state.{StateDataSourceErrors, StateSourceOptions}
import org.apache.spark.sql.execution.streaming.{StateVariableType, TransformWithStateVariableInfo}
import org.apache.spark.sql.execution.streaming.state.StateStoreColFamilySchema
import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructType}
import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, LongType, StringType, StructType}
import org.apache.spark.util.ArrayImplicits._

object SchemaUtil {
@@ -70,15 +71,13 @@
row
}

def unifyStateRowPairWithTTL(
pair: (UnsafeRow, UnsafeRow),
valueSchema: StructType,
def unifyStateRowPairWithMultipleValues(
pair: (UnsafeRow, GenericArrayData),
partition: Int): InternalRow = {
val row = new GenericInternalRow(4)
val row = new GenericInternalRow(3)
row.update(0, pair._1)
row.update(1, pair._2.get(0, valueSchema))
row.update(2, pair._2.get(1, LongType))
row.update(3, partition)
row.update(1, pair._2)
row.update(2, partition)
row
}

@@ -91,23 +90,22 @@
"change_type" -> classOf[StringType],
"key" -> classOf[StructType],
"value" -> classOf[StructType],
"partition_id" -> classOf[IntegerType],
"expiration_timestamp" -> classOf[LongType])
"single_value" -> classOf[StructType],
"list_value" -> classOf[ArrayType],
"partition_id" -> classOf[IntegerType])

val expectedFieldNames = if (sourceOptions.readChangeFeed) {
Seq("batch_id", "change_type", "key", "value", "partition_id")
} else if (transformWithStateVariableInfoOpt.isDefined) {
val stateVarInfo = transformWithStateVariableInfoOpt.get
val hasTTLEnabled = stateVarInfo.ttlEnabled
val stateVarType = stateVarInfo.stateVariableType

stateVarType match {
case StateVariableType.ValueState =>
if (hasTTLEnabled) {
Seq("key", "value", "expiration_timestamp", "partition_id")
} else {
Seq("key", "value", "partition_id")
}
Seq("key", "single_value", "partition_id")

case StateVariableType.ListState =>
Seq("key", "list_value", "partition_id")

case _ =>
throw StateDataSourceErrors
@@ -131,24 +129,19 @@
stateVarInfo: TransformWithStateVariableInfo,
stateStoreColFamilySchema: StateStoreColFamilySchema): StructType = {
val stateVarType = stateVarInfo.stateVariableType
val hasTTLEnabled = stateVarInfo.ttlEnabled

stateVarType match {
case StateVariableType.ValueState =>
if (hasTTLEnabled) {
val ttlValueSchema = SchemaUtil.getSchemaAsDataType(
stateStoreColFamilySchema.valueSchema, "value").asInstanceOf[StructType]
new StructType()
.add("key", stateStoreColFamilySchema.keySchema)
.add("value", ttlValueSchema)
.add("expiration_timestamp", LongType)
.add("partition_id", IntegerType)
} else {
new StructType()
.add("key", stateStoreColFamilySchema.keySchema)
.add("value", stateStoreColFamilySchema.valueSchema)
.add("partition_id", IntegerType)
}
new StructType()
.add("key", stateStoreColFamilySchema.keySchema)
.add("single_value", stateStoreColFamilySchema.valueSchema)
.add("partition_id", IntegerType)

case StateVariableType.ListState =>
new StructType()
.add("key", stateStoreColFamilySchema.keySchema)
.add("list_value", ArrayType(stateStoreColFamilySchema.valueSchema))
.add("partition_id", IntegerType)

case _ =>
throw StateDataSourceErrors.internalError(s"Unsupported state variable type $stateVarType")
@@ -21,8 +21,9 @@ import java.time.Duration
import org.apache.spark.sql.{Encoders, Row}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider, TestClass}
import org.apache.spark.sql.functions.explode
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.{ExpiredTimerInfo, OutputMode, RunningCountStatefulProcessor, StatefulProcessor, StateStoreMetricsTest, TimeMode, TimerValues, TransformWithStateSuiteUtils, TTLConfig, ValueState}
import org.apache.spark.sql.streaming.{ExpiredTimerInfo, ListState, OutputMode, RunningCountStatefulProcessor, StatefulProcessor, StateStoreMetricsTest, TimeMode, TimerValues, TransformWithStateSuiteUtils, TTLConfig, ValueState}

/** Stateful processor of single value state var with non-primitive type */
class StatefulProcessorWithSingleValueVar extends RunningCountStatefulProcessor {
@@ -73,6 +74,52 @@ class StatefulProcessorWithTTL
}
}

/** Stateful processor tracking groups belonging to sessions with/without TTL */
class SessionGroupsStatefulProcessor extends
StatefulProcessor[String, (String, String), String] {
@transient private var _groupsList: ListState[String] = _

override def init(
outputMode: OutputMode,
timeMode: TimeMode): Unit = {
_groupsList = getHandle.getListState("groupsList", Encoders.STRING)
}

override def handleInputRows(
key: String,
inputRows: Iterator[(String, String)],
timerValues: TimerValues,
expiredTimerInfo: ExpiredTimerInfo): Iterator[String] = {
inputRows.foreach { inputRow =>
_groupsList.appendValue(inputRow._2)
}
Iterator.empty
}
}

class SessionGroupsStatefulProcessorWithTTL extends
StatefulProcessor[String, (String, String), String] {
@transient private var _groupsListWithTTL: ListState[String] = _

override def init(
outputMode: OutputMode,
timeMode: TimeMode): Unit = {
_groupsListWithTTL = getHandle.getListState("groupsListWithTTL", Encoders.STRING,
TTLConfig(Duration.ofMillis(30000)))
}

override def handleInputRows(
key: String,
inputRows: Iterator[(String, String)],
timerValues: TimerValues,
expiredTimerInfo: ExpiredTimerInfo): Iterator[String] = {
inputRows.foreach { inputRow =>
_groupsListWithTTL.appendValue(inputRow._2)
}
Iterator.empty
}
}

/**
* Test suite to verify integration of state data source reader with the transformWithState operator
*/
@@ -111,7 +158,7 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest

val resultDf = stateReaderDf.selectExpr(
"key.value AS groupingKey",
"value.id AS valueId", "value.name AS valueName",
"single_value.id AS valueId", "single_value.name AS valueName",
"partition_id")

checkAnswer(resultDf,
@@ -174,7 +221,7 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest
.load()

val resultDf = stateReaderDf.selectExpr(
"key.value", "value.value", "expiration_timestamp", "partition_id")
"key.value", "single_value.value", "single_value.ttlExpirationMs", "partition_id")

var count = 0L
resultDf.collect().foreach { row =>
@@ -187,7 +234,7 @@

val answerDf = stateReaderDf.selectExpr(
"key.value AS groupingKey",
"value.value AS valueId", "partition_id")
"single_value.value.value AS valueId", "partition_id")
checkAnswer(answerDf,
Seq(Row("a", 1L, 0), Row("b", 1L, 1)))

@@ -217,4 +264,110 @@
}
}
}

test("state data source integration - list state") {
withTempDir { tempDir =>
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName) {

val inputData = MemoryStream[(String, String)]
val result = inputData.toDS()
.groupByKey(x => x._1)
.transformWithState(new SessionGroupsStatefulProcessor(),
TimeMode.None(),
OutputMode.Update())

testStream(result, OutputMode.Update())(
StartStream(checkpointLocation = tempDir.getAbsolutePath),
AddData(inputData, ("session1", "group2")),
AddData(inputData, ("session1", "group1")),
AddData(inputData, ("session2", "group1")),
CheckNewAnswer(),
AddData(inputData, ("session3", "group7")),
AddData(inputData, ("session1", "group4")),
CheckNewAnswer(),
StopStream
)

val stateReaderDf = spark.read
.format("statestore")
.option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
.option(StateSourceOptions.STATE_VAR_NAME, "groupsList")
.load()

val listStateDf = stateReaderDf
.selectExpr(
"key.value AS groupingKey",
"list_value.value AS valueList",
"partition_id")
.select($"groupingKey",
explode($"valueList"))

checkAnswer(listStateDf,
Seq(Row("session1", "group1"), Row("session1", "group2"), Row("session1", "group4"),
Row("session2", "group1"), Row("session3", "group7")))
}
}
}

test("state data source integration - list state and TTL") {
withTempDir { tempDir =>
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName,
SQLConf.SHUFFLE_PARTITIONS.key ->
TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
val inputData = MemoryStream[(String, String)]
val result = inputData.toDS()
.groupByKey(x => x._1)
.transformWithState(new SessionGroupsStatefulProcessorWithTTL(),
TimeMode.ProcessingTime(),
OutputMode.Update())

testStream(result, OutputMode.Update())(
StartStream(checkpointLocation = tempDir.getAbsolutePath),
AddData(inputData, ("session1", "group2")),
AddData(inputData, ("session1", "group1")),
AddData(inputData, ("session2", "group1")),
AddData(inputData, ("session3", "group7")),
AddData(inputData, ("session1", "group4")),
Execute { _ =>
// wait for the batch to run since we are using processing time
Thread.sleep(5000)
},
StopStream
)

val stateReaderDf = spark.read
.format("statestore")
.option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
.option(StateSourceOptions.STATE_VAR_NAME, "groupsListWithTTL")
.load()

val listStateDf = stateReaderDf
.selectExpr(
"key.value AS groupingKey",
"list_value AS valueList",
"partition_id")
.select($"groupingKey",
explode($"valueList").as("valueList"))

val resultDf = listStateDf.selectExpr("valueList.ttlExpirationMs")
var count = 0L
resultDf.collect().foreach { row =>
count = count + 1
assert(row.getLong(0) > 0)
}

// verify that 5 state rows are present
assert(count === 5)

val valuesDf = listStateDf.selectExpr("groupingKey",
"valueList.value.value AS groupId")

checkAnswer(valuesDf,
Seq(Row("session1", "group1"), Row("session1", "group2"), Row("session1", "group4"),
Row("session2", "group1"), Row("session3", "group7")))
}
}
}
}
