Commit fe5a9eb

returning pdf for initial state processing
jingz-db committed Sep 10, 2024
1 parent a6f48f3 commit fe5a9eb
Showing 7 changed files with 21 additions and 24 deletions.
4 changes: 1 addition & 3 deletions python/pyspark/sql/pandas/group_ops.py
@@ -509,11 +509,9 @@ def transformWithStateUDF(

            # only process initial state if first batch
            batch_id = statefulProcessorApiClient.get_batch_id()
-           """
-           if batch_id == 0 and initialState is not None:
+           if batch_id == 0:
                initial_state = statefulProcessorApiClient.get_initial_state(key)
                statefulProcessor.handleInitialState(key, initial_state)
-           """

            statefulProcessorApiClient.set_implicit_key(key)
            result = statefulProcessor.handleInputRows(key, inputRows)
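For orientation, a minimal sketch of the flow this hunk enables, with invented stand-ins for the PySpark-provided API client and the user's StatefulProcessor (illustrative only, not the actual internals): on the first micro-batch (batch_id == 0) the UDF fetches the key's initial state as a pandas DataFrame and hands it to handleInitialState before processing input rows.

    import pandas as pd

    def transform_with_state_udf(api_client, processor, key, input_rows):
        # Initial state is only consulted on the very first micro-batch.
        if api_client.get_batch_id() == 0:
            initial_state = api_client.get_initial_state(key)  # a pandas DataFrame
            processor.handleInitialState(key, initial_state)
        api_client.set_implicit_key(key)
        return processor.handleInputRows(key, input_rows)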
8 changes: 4 additions & 4 deletions python/pyspark/sql/streaming/stateful_processor_api_client.py
@@ -147,7 +147,7 @@ def get_initial_state(self, key: Tuple) -> "PandasDataFrameLike":
        from pandas import DataFrame
        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage

-       bytes = self._stateful_processor_api_client._serialize_to_bytes(self.key_schema, key)
+       bytes = self._serialize_to_bytes(self.key_schema, key)

        get_initial_state = stateMessage.GetInitialState(value=bytes)
        request = stateMessage.UtilsCallCommand(getInitialState=get_initial_state)
@@ -159,12 +159,12 @@ def get_initial_state(self, key: Tuple) -> "PandasDataFrameLike":
        status = response_message[0]
        if status == 1:
            DataFrame()
-       if status == 0:
-           iterator = self._stateful_processor_api_client._read_arrow_state()
+       elif status == 0:
+           iterator = self._read_arrow_state()
            batch = next(iterator)
            return batch.to_pandas()
        else:
-           raise StopIteration()
+           raise PySparkRuntimeError(f"Error getting initial state: " f"{response_message[1]}")

    def _send_proto_message(self, message: bytes) -> None:
        # Writing zero here to indicate message version. This allows us to evolve the message
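The success path above ends with _read_arrow_state() followed by batch.to_pandas(). As a rough illustration of that pattern (not PySpark's actual transport code), the sketch below round-trips a single Arrow record batch through an IPC stream with pyarrow and converts it back to pandas; the data and function name are invented for the example.

    import io

    import pandas as pd
    import pyarrow as pa

    def read_initial_state(stream_bytes: bytes) -> pd.DataFrame:
        # Read the first record batch off an Arrow IPC stream and return it as pandas.
        reader = pa.ipc.open_stream(io.BytesIO(stream_bytes))
        batch = next(iter(reader))
        return batch.to_pandas()

    # Quick round trip to show the shape of the data.
    table = pa.table({"value": ["init"], "count": [3]})
    sink = pa.BufferOutputStream()
    writer = pa.ipc.new_stream(sink, table.schema)
    writer.write_table(table)
    writer.close()
    print(read_initial_state(sink.getvalue().to_pybytes()))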
@@ -308,7 +308,7 @@ def handleInputRows(self, key, rows) -> Iterator[pd.DataFrame]:
        yield pd.DataFrame({"id": key, "countAsString": str(count)})

    def handleInitialState(self, key, initialState) -> None:
-       pass
+       raise Exception(f"I am inside handleInitialState, init state: {initialState.get('sth')}")

    def close(self) -> None:
        pass
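The raise above reads like a debugging probe to confirm the initial-state DataFrame reaches the Python side. For contrast, a hypothetical user-facing processor (class and column names invented, assuming rows arrive as an iterator of pandas DataFrames) would typically use handleInitialState to seed per-key state:

    import pandas as pd

    class InitialStateCounter:
        """Seeds a per-key count from the initial-state DataFrame, then keeps counting."""

        def __init__(self):
            self.count = 0

        def handleInitialState(self, key, initialState: pd.DataFrame) -> None:
            # initialState is the pandas DataFrame delivered for this grouping key.
            if initialState is not None and "count" in initialState.columns:
                self.count = int(initialState["count"].iloc[0])

        def handleInputRows(self, key, rows):
            for batch in rows:  # rows is an iterator of pandas DataFrames
                self.count += len(batch)
            yield pd.DataFrame({"id": [str(key)], "countAsString": [str(self.count)]})

        def close(self) -> None:
            pass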
@@ -511,10 +511,17 @@ class RelationalGroupedDataset protected[sql](
    val leftChild = df.logicalPlan
    val rightChild = initialState.df.logicalPlan

+   /*
    val left = df.sparkSession.sessionState.executePlan(
      Project(groupingAttrs ++ leftChild.output, leftChild)).analyzed
    val right = initialState.df.sparkSession.sessionState.executePlan(
-     Project(initGroupingAttrs ++ rightChild.output, rightChild)).analyzed
+     Project(initGroupingAttrs ++ rightChild.output, rightChild)).analyzed */
+
+   val left = df.sparkSession.sessionState.executePlan(
+     Project(leftChild.output, leftChild)).analyzed
+   val right = initialState.df.sparkSession.sessionState.executePlan(
+     Project(rightChild.output, rightChild)).analyzed
+

    TransformWithStateInPandas(
      func.expr,
@@ -209,7 +209,6 @@ case class TransformWithStateInPandasExec(
jobArtifactUUID,
groupingKeySchema,
hasInitialState,
-initialStateGroupingAttrs,
initialStateSchema,
initStateIterator
)
@@ -29,7 +29,6 @@ import org.apache.spark.TaskContext
import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.TransformWithStateInPandasPythonRunner.{InType, OutType}
import org.apache.spark.sql.execution.streaming.StatefulProcessorHandleImpl
@@ -53,7 +52,6 @@ class TransformWithStateInPandasPythonRunner(
jobArtifactUUID: Option[String],
groupingKeySchema: StructType,
hasInitialState: Boolean,
-initialStateGroupingAttrs: Seq[Attribute],
initialStateSchema: StructType,
initialStateDataIterator: Iterator[(InternalRow, Iterator[InternalRow])])
extends BasePythonRunner[InType, OutType](funcs.map(_._1), evalType, argOffsets, jobArtifactUUID)
@@ -111,7 +109,6 @@ class TransformWithStateInPandasPythonRunner(
groupingKeySchema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes,
arrowMaxRecordsPerBatch,
hasInitialState = hasInitialState,
-initialStateGroupingAttrs = initialStateGroupingAttrs,
initialStateSchema = initialStateSchema,
initialStateDataIterator = initialStateDataIterator))

@@ -31,7 +31,6 @@ import org.apache.spark.sql.{Encoders, Row}
import org.apache.spark.sql.api.python.PythonSQLUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl, StatefulProcessorHandleState}
import org.apache.spark.sql.execution.streaming.state.StateMessage.{HandleState, ImplicitGroupingKeyRequest, StatefulProcessorCall, StateRequest, StateResponse, StateVariableRequest, UtilsCallCommand, ValueStateCall}
import org.apache.spark.sql.execution.streaming.state.StateStoreErrors
@@ -61,7 +60,6 @@ class TransformWithStateInPandasStateServer(
valueStateMapForTest: mutable.HashMap[String,
(ValueState[Row], StructType, ExpressionEncoder.Deserializer[Row])] = null,
hasInitialState: Boolean,
-initialStateGroupingAttrs: Seq[Attribute],
initialStateSchema: StructType,
initialStateDataIterator: Iterator[(InternalRow, Iterator[InternalRow])])
extends Runnable with Logging {
@@ -187,14 +185,11 @@ class TransformWithStateInPandasStateServer(
sendResponse(0, null, ByteString.copyFromUtf8(valueStr))

case UtilsCallCommand.MethodCase.GETINITIALSTATE =>
val keyBytes = message.getGetInitialState.getValue.toByteArray
// The key row is serialized as a byte array, we need to convert it back to a Row
val keyRow = PythonSQLUtils.toJVMRow(keyBytes, groupingKeySchema, keyRowDeserializer)
if (!initialStateDataIterator.isEmpty || !initialStateDataIterator.hasNext) {
if (!hasInitialState || initialStateKeyToRowMap.isEmpty) {
sendResponse(1)
} else {
sendResponse(0)
// TODO check if has initial state

outputStream.flush()
val arrowStreamWriter = {
val outputSchema = initialStateSchema
@@ -207,21 +202,22 @@
arrowTransformWithStateInPandasMaxRecordsPerBatch)
}

val keyBytes = message.getGetInitialState.getValue.toByteArray
// The key row is serialized as a byte array, we need to convert it back to a Row
val keyRow = PythonSQLUtils.toJVMRow(keyBytes, groupingKeySchema, keyRowDeserializer)
val groupingKeyToInternalRow =
ExpressionEncoder(groupingKeySchema).createSerializer().apply(keyRow)

throw new Exception(s"I am inside initial state processing, grouping key row" +
s"received: ${keyRow}")

val iter = initialStateKeyToRowMap
.get(groupingKeyToInternalRow).getOrElse(Iterator.empty)

var seenInitStateOnKey = false
while (iter.hasNext) {
if (seenInitStateOnKey) {
throw StateStoreErrors.cannotReInitializeStateOnKey(
keyRowDeserializer.apply(groupingKeyToInternalRow).toString)
} else {
val initialStateRow = iter.next()
seenInitStateOnKey = true
arrowStreamWriter.writeRow(initialStateRow)
}
}
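On the server side, the GETINITIALSTATE handler above looks the incoming grouping key up in initialStateKeyToRowMap, writes the key's initial-state row over the Arrow stream, and errors out if more than one initial-state row exists for the same key. A plain-Python analogue of that bookkeeping, simplified and with invented names, looks roughly like this:

    from collections import defaultdict

    def build_initial_state_map(key_row_pairs):
        # Group initial-state rows by grouping key, one list of rows per key.
        key_to_rows = defaultdict(list)
        for key, row in key_row_pairs:
            key_to_rows[key].append(row)
        return key_to_rows

    def initial_state_for_key(key_to_rows, key):
        # At most one initial-state row may exist per key; more than one is an error.
        rows = key_to_rows.get(key, [])
        if len(rows) > 1:
            raise RuntimeError(f"Cannot re-initialize state on key {key}")
        return rows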
