[SPARK-46852][SS] Remove use of explicit key encoder and pass it implicitly to the operator for transformWithState operator

### What changes were proposed in this pull request?
Remove use of explicit key encoder and pass it implicitly to the operator for transformWithState operator
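
For illustration, a minimal sketch of the user-facing difference, modeled on the test processors touched by this PR; the processor name, state name, and counting logic below are hypothetical, and the imports and `handleInputRows` shape reflect the existing API as assumed here, not anything added by this change:

```scala
import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessor, StatefulProcessorHandle, TimerValues, ValueState}

// Hypothetical processor; only the getValueState signature change is taken from this PR.
class RunningCountProcessor extends StatefulProcessor[String, String, (String, String)] {
  @transient private var _countState: ValueState[Long] = _

  override def init(
      handle: StatefulProcessorHandle,
      outputMode: OutputMode): Unit = {
    // Before this change, an explicit key encoder was required:
    //   _countState = handle.getValueState[String, Long]("countState", Encoders.STRING)
    // After this change, the key encoder is passed to the operator implicitly:
    _countState = handle.getValueState[Long]("countState")
  }

  override def handleInputRows(
      key: String,
      inputRows: Iterator[String],
      timerValues: TimerValues): Iterator[(String, String)] = {
    // Record how many rows were seen for this key in the current batch and emit it.
    _countState.update(inputRows.size.toLong)
    Iterator.single((key, _countState.get().toString))
  }
}
```

The operator derives the key `ExpressionEncoder` from the grouping key type and passes it down to `StatefulProcessorHandleImpl` and `ValueStateImpl`, so state definitions only declare the value type.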

### Why are the changes needed?
Avoids requiring users to provide an explicit key encoder; the implicitly passed encoder may also be needed for subsequent timer-related changes.

### Does this PR introduce _any_ user-facing change?
Yes. `StatefulProcessorHandle.getValueState` no longer takes an explicit key encoder argument.

### How was this patch tested?
Existing unit tests

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #44974 from anishshri-db/task/SPARK-46852.

Authored-by: Anish Shrigondekar <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
anishshri-db authored and HeartSaVioR committed Feb 1, 2024
1 parent 1870de0 commit e610d1d
Showing 8 changed files with 39 additions and 33 deletions.
@@ -19,7 +19,6 @@ package org.apache.spark.sql.streaming
import java.io.Serializable

import org.apache.spark.annotation.{Evolving, Experimental}
import org.apache.spark.sql.Encoder

/**
* Represents the operation handle provided to the stateful processor used in the
@@ -34,12 +33,10 @@ private[sql] trait StatefulProcessorHandle extends Serializable {
* The user must ensure to call this function only within the `init()` method of the
* StatefulProcessor.
* @param stateName - name of the state variable
* @param keyEncoder - Spark SQL Encoder for key
* @tparam K - type of key
* @tparam T - type of state variable
* @return - instance of ValueState of type T that can be used to store state persistently
*/
def getValueState[K, T](stateName: String, keyEncoder: Encoder[K]): ValueState[T]
def getValueState[T](stateName: String): ValueState[T]

/** Function to return queryInfo for currently running task */
def getQueryInfo(): QueryInfo
@@ -577,6 +577,7 @@ object TransformWithState {
timeoutMode: TimeoutMode,
outputMode: OutputMode,
child: LogicalPlan): LogicalPlan = {
val keyEncoder = encoderFor[K]
val mapped = new TransformWithState(
UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
UnresolvedDeserializer(encoderFor[V].deserializer, dataAttributes),
@@ -585,6 +586,7 @@
statefulProcessor.asInstanceOf[StatefulProcessor[Any, Any, Any]],
timeoutMode,
outputMode,
keyEncoder.asInstanceOf[ExpressionEncoder[Any]],
CatalystSerde.generateObjAttr[U],
child
)
@@ -600,6 +602,7 @@ case class TransformWithState(
statefulProcessor: StatefulProcessor[Any, Any, Any],
timeoutMode: TimeoutMode,
outputMode: OutputMode,
keyEncoder: ExpressionEncoder[Any],
outputObjAttr: Attribute,
child: LogicalPlan) extends UnaryNode with ObjectProducer {

@@ -728,7 +728,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case TransformWithState(
keyDeserializer, valueDeserializer, groupingAttributes,
dataAttributes, statefulProcessor, timeoutMode, outputMode,
outputAttr, child) =>
keyEncoder, outputAttr, child) =>
val execPlan = TransformWithStateExec(
keyDeserializer,
valueDeserializer,
@@ -737,6 +737,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
statefulProcessor,
timeoutMode,
outputMode,
keyEncoder,
outputAttr,
stateInfo = None,
batchTimestampMs = None,
@@ -20,7 +20,7 @@ import java.util.UUID

import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.streaming.state.StateStore
import org.apache.spark.sql.streaming.{QueryInfo, StatefulProcessorHandle, ValueState}
import org.apache.spark.util.Utils
@@ -67,8 +67,13 @@ class QueryInfoImpl(
* Class that provides a concrete implementation of a StatefulProcessorHandle. Note that we keep
* track of valid transitions as various functions are invoked to track object lifecycle.
* @param store - instance of state store
* @param runId - unique id for the current run
* @param keyEncoder - encoder for the key
*/
class StatefulProcessorHandleImpl(store: StateStore, runId: UUID)
class StatefulProcessorHandleImpl(
store: StateStore,
runId: UUID,
keyEncoder: ExpressionEncoder[Any])
extends StatefulProcessorHandle with Logging {
import StatefulProcessorHandleState._

@@ -108,11 +113,11 @@ class StatefulProcessorHandleImpl(store: StateStore, runId: UUID)

def getHandleState: StatefulProcessorHandleState = currState

override def getValueState[K, T](stateName: String, keyEncoder: Encoder[K]): ValueState[T] = {
override def getValueState[T](stateName: String): ValueState[T] = {
verify(currState == CREATED, s"Cannot create state variable with name=$stateName after " +
"initialization is complete")
store.createColFamilyIfAbsent(stateName)
val resultState = new ValueStateImpl[K, T](store, stateName, keyEncoder)
val resultState = new ValueStateImpl[T](store, stateName, keyEncoder)
resultState
}

@@ -20,6 +20,7 @@ import java.util.concurrent.TimeUnit.NANOSECONDS

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.physical.Distribution
import org.apache.spark.sql.execution._
@@ -38,6 +39,7 @@ import org.apache.spark.util.CompletionIterator
* @param statefulProcessor processor methods called on underlying data
* @param timeoutMode defines the timeout mode
* @param outputMode defines the output mode for the statefulProcessor
* @param keyEncoder expression encoder for the key type
* @param outputObjAttr Defines the output object
* @param batchTimestampMs processing timestamp of the current batch.
* @param eventTimeWatermarkForLateEvents event time watermark for filtering late events
@@ -52,6 +54,7 @@ case class TransformWithStateExec(
statefulProcessor: StatefulProcessor[Any, Any, Any],
timeoutMode: TimeoutMode,
outputMode: OutputMode,
keyEncoder: ExpressionEncoder[Any],
outputObjAttr: Attribute,
stateInfo: Option[StatefulOperatorStateInfo],
batchTimestampMs: Option[Long],
@@ -162,7 +165,8 @@
useColumnFamilies = true
) {
case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId)
val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId,
keyEncoder)
assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED)
statefulProcessor.init(processorHandle, outputMode)
processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED)
@@ -21,9 +21,8 @@ import java.io.Serializable
import org.apache.commons.lang3.SerializationUtils

import org.apache.spark.internal.Logging
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.streaming.state.StateStore
import org.apache.spark.sql.streaming.ValueState
@@ -38,10 +37,10 @@ import org.apache.spark.sql.types._
* @tparam K - data type of key
* @tparam S - data type of object that will be stored
*/
class ValueStateImpl[K, S](
class ValueStateImpl[S](
store: StateStore,
stateName: String,
keyEnc: Encoder[K]) extends ValueState[S] with Logging {
keyExprEnc: ExpressionEncoder[Any]) extends ValueState[S] with Logging {

// TODO: validate places that are trying to encode the key and check if we can eliminate/
// add caching for some of these calls.
@@ -52,10 +51,9 @@ class ValueStateImpl[K, S](
s"stateName=$stateName")
}

val exprEnc: ExpressionEncoder[K] = encoderFor(keyEnc)
val toRow = exprEnc.createSerializer()
val toRow = keyExprEnc.createSerializer()
val keyByteArr = toRow
.apply(keyOption.get.asInstanceOf[K]).asInstanceOf[UnsafeRow].getBytes()
.apply(keyOption.get).asInstanceOf[UnsafeRow].getBytes()

val schemaForKeyRow: StructType = new StructType().add("key", BinaryType)
val keyEncoder = UnsafeProjection.create(schemaForKeyRow)
@@ -25,6 +25,7 @@ import org.apache.hadoop.conf.Configuration
import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.Encoders
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.ValueState
@@ -87,10 +88,10 @@ class ValueStateSuite extends SharedSparkSession
test("Implicit key operations") {
tryWithProviderResource(newStoreProviderWithValueState(true)) { provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID())
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])

val testState: ValueState[Long] = handle.getValueState[String, Long]("testState",
Encoders.STRING)
val testState: ValueState[Long] = handle.getValueState[Long]("testState")
assert(ImplicitGroupingKeyTracker.getImplicitKeyOption.isEmpty)
val ex = intercept[Exception] {
testState.update(123)
@@ -118,10 +119,10 @@
test("Value state operations for single instance") {
tryWithProviderResource(newStoreProviderWithValueState(true)) { provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID())
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])

val testState: ValueState[Long] = handle.getValueState[String, Long]("testState",
Encoders.STRING)
val testState: ValueState[Long] = handle.getValueState[Long]("testState")
ImplicitGroupingKeyTracker.setImplicitKey("test_key")
testState.update(123)
assert(testState.get() === 123)
@@ -144,12 +145,11 @@
test("Value state operations for multiple instances") {
tryWithProviderResource(newStoreProviderWithValueState(true)) { provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID())
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])

val testState1: ValueState[Long] = handle.getValueState[String, Long]("testState1",
Encoders.STRING)
val testState2: ValueState[Long] = handle.getValueState[String, Long]("testState2",
Encoders.STRING)
val testState1: ValueState[Long] = handle.getValueState[Long]("testState1")
val testState2: ValueState[Long] = handle.getValueState[Long]("testState2")
ImplicitGroupingKeyTracker.setImplicitKey("test_key")
testState1.update(123)
assert(testState1.get() === 123)
@@ -19,7 +19,7 @@ package org.apache.spark.sql.streaming

import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, Encoders, SaveMode}
import org.apache.spark.sql.{AnalysisException, SaveMode}
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider}
import org.apache.spark.sql.internal.SQLConf
@@ -38,8 +38,7 @@ class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (S
outputMode: OutputMode) : Unit = {
_processorHandle = handle
assert(handle.getQueryInfo().getBatchId >= 0)
_countState = _processorHandle.getValueState[String, Long]("countState",
Encoders.STRING)
_countState = _processorHandle.getValueState[Long]("countState")
}

override def handleInputRows(
@@ -67,8 +66,7 @@ class RunningCountStatefulProcessorWithError extends RunningCountStatefulProcess
inputRows: Iterator[String],
timerValues: TimerValues): Iterator[(String, String)] = {
// Trying to create value state here should fail
_tempState = _processorHandle.getValueState[String, Long]("tempState",
Encoders.STRING)
_tempState = _processorHandle.getValueState[Long]("tempState")
Iterator.empty
}
}
