need to return iterator & unit tests, mostly done
jingz-db committed Sep 10, 2024
1 parent fe5a9eb commit 47067f8
Showing 10 changed files with 256 additions and 209 deletions.
10 changes: 6 additions & 4 deletions python/pyspark/sql/pandas/group_ops.py
@@ -508,12 +508,14 @@ def transformWithStateUDF(
             )

             # only process initial state if first batch
-            batch_id = statefulProcessorApiClient.get_batch_id()
-            if batch_id == 0:
+            is_first_batch = statefulProcessorApiClient.is_first_batch()
+            statefulProcessorApiClient.set_implicit_key(key)
+            if is_first_batch:
                 initial_state = statefulProcessorApiClient.get_initial_state(key)
-                statefulProcessor.handleInitialState(key, initial_state)
+                # if user did not provide initial state df, initial_state will be None
+                if initial_state is not None:
+                    statefulProcessor.handleInitialState(key, initial_state)

-            statefulProcessorApiClient.set_implicit_key(key)
             result = statefulProcessor.handleInputRows(key, inputRows)

             return result
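
For context, a hedged usage sketch of the user-facing call path that reaches this UDF; the processor class, column names, and exact keyword signature are illustrative and may differ slightly at this commit:

    # MyProcessor is an assumed user-defined StatefulProcessor subclass that
    # implements handleInitialState (see the test changes below for a real one).
    initial_state = spark.createDataFrame(
        [("0", 789)], "id string, initVal int"
    ).groupBy("id")

    out = df.groupBy("id").transformWithStateInPandas(
        statefulProcessor=MyProcessor(),
        outputStructType="id string, value string",
        outputMode="Update",
        timeMode="None",
        # optional: without it, get_initial_state(key) yields None per key
        # and handleInitialState is skipped by the guard above
        initialState=initial_state,
    )
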
16 changes: 8 additions & 8 deletions python/pyspark/sql/streaming/StateMessage_pb2.py

Some generated files are not rendered by default.

16 changes: 8 additions & 8 deletions python/pyspark/sql/streaming/StateMessage_pb2.pyi
@@ -21,10 +21,6 @@ class Get(_message.Message):
     __slots__ = []
     def __init__(self) -> None: ...

-class GetBatchId(_message.Message):
-    __slots__ = []
-    def __init__(self) -> None: ...
-
 class GetInitialState(_message.Message):
     __slots__ = ["value"]
     VALUE_FIELD_NUMBER: ClassVar[int]
@@ -39,6 +35,10 @@ class ImplicitGroupingKeyRequest(_message.Message):
     setImplicitKey: SetImplicitKey
     def __init__(self, setImplicitKey: Optional[Union[SetImplicitKey, Mapping]] = ..., removeImplicitKey: Optional[Union[RemoveImplicitKey, Mapping]] = ...) -> None: ...

+class IsFirstBatch(_message.Message):
+    __slots__ = []
+    def __init__(self) -> None: ...
+
 class RemoveImplicitKey(_message.Message):
     __slots__ = []
     def __init__(self) -> None: ...
@@ -106,12 +106,12 @@ class StatefulProcessorCall(_message.Message):
     def __init__(self, setHandleState: Optional[Union[SetHandleState, Mapping]] = ..., getValueState: Optional[Union[StateCallCommand, Mapping]] = ..., getListState: Optional[Union[StateCallCommand, Mapping]] = ..., getMapState: Optional[Union[StateCallCommand, Mapping]] = ..., utilsCall: Optional[Union[UtilsCallCommand, Mapping]] = ...) -> None: ...

 class UtilsCallCommand(_message.Message):
-    __slots__ = ["getBatchId", "getInitialState"]
-    GETBATCHID_FIELD_NUMBER: ClassVar[int]
+    __slots__ = ["getInitialState", "isFirstBatch"]
     GETINITIALSTATE_FIELD_NUMBER: ClassVar[int]
-    getBatchId: GetBatchId
+    ISFIRSTBATCH_FIELD_NUMBER: ClassVar[int]
     getInitialState: GetInitialState
-    def __init__(self, getBatchId: Optional[Union[GetBatchId, Mapping]] = ..., getInitialState: Optional[Union[GetInitialState, Mapping]] = ...) -> None: ...
+    isFirstBatch: IsFirstBatch
+    def __init__(self, isFirstBatch: Optional[Union[IsFirstBatch, Mapping]] = ..., getInitialState: Optional[Union[GetInitialState, Mapping]] = ...) -> None: ...

 class ValueStateCall(_message.Message):
     __slots__ = ["clear", "exists", "get", "stateName", "valueStateUpdate"]
3 changes: 2 additions & 1 deletion python/pyspark/sql/streaming/stateful_processor.py
@@ -164,6 +164,7 @@ def handleInitialState(
         self, key: Any, initialState: "PandasDataFrameLike"
     ) -> None:
         """
-        Optional to implement
+        Optional to implement. Acts as a no-op if not defined or if no initial state is provided.
+        Invoked only once, in the first batch, to allow users to perform initial state processing.
         """
         pass
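
For illustration, a minimal sketch of a processor implementing this hook; the state name and "initVal" column are assumptions, and the imports match those used by the test module below (StatefulProcessor, StatefulProcessorHandle, StructType, StructField, IntegerType):

    class SeededProcessor(StatefulProcessor):
        def init(self, handle: StatefulProcessorHandle) -> None:
            schema = StructType([StructField("value", IntegerType(), True)])
            self.value_state = handle.getValueState("value_state", schema)

        def handleInitialState(self, key, initialState) -> None:
            # runs at most once per key, during the first batch only
            self.value_state.update((int(initialState.at[0, "initVal"]),))

        def handleInputRows(self, key, rows):
            for pdf in rows:
                yield pdf  # pass-through; a real processor aggregates here

        def close(self) -> None:
            pass
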
18 changes: 8 additions & 10 deletions python/pyspark/sql/streaming/stateful_processor_api_client.py
@@ -122,26 +122,24 @@ def get_value_state(self, state_name: str, schema: Union[StructType, str]) -> None:
             # TODO(SPARK-49233): Classify user facing errors.
             raise PySparkRuntimeError(f"Error initializing value state: " f"{response_message[1]}")

-    def get_batch_id(self) -> int:
+    def is_first_batch(self) -> bool:
         import pyspark.sql.streaming.StateMessage_pb2 as stateMessage

-        get_batch_id = stateMessage.GetBatchId()
-        request = stateMessage.UtilsCallCommand(getBatchId=get_batch_id)
+        is_first_batch = stateMessage.IsFirstBatch()
+        request = stateMessage.UtilsCallCommand(isFirstBatch=is_first_batch)
         stateful_processor_call = stateMessage.StatefulProcessorCall(utilsCall=request)
         message = stateMessage.StateRequest(statefulProcessorCall=stateful_processor_call)

         self._send_proto_message(message.SerializeToString())
         response_message = self._receive_proto_message()
         status = response_message[0]
-        if status != 0:
+        if status == 0:
+            return True
+        elif status == 1:
+            return False
+        else:
             # TODO(SPARK-49233): Classify user facing errors.
             raise PySparkRuntimeError(f"Error getting batch id: " f"{response_message[1]}")
-        else:
-            if len(response_message[2]) == 0:
-                return -1
-            # TODO: can we simply parse from utf8 string here?
-            batch_id = int(response_message[2])
-            return batch_id

     def get_initial_state(self, key: Tuple) -> "PandasDataFrameLike":
         from pandas import DataFrame
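
A note on the design: the boolean now rides on the proto status code itself (0 means true, 1 means false, anything else is an error), so the old payload-parsing branch of get_batch_id is gone. A small self-contained sketch of that mapping; the (status, error, payload) tuple layout is inferred from its uses in this file:

    def _interpret_is_first_batch(response_message):
        # (status, error_message, payload): layout inferred, illustration only
        status = response_message[0]
        if status == 0:
            return True
        elif status == 1:
            return False
        raise RuntimeError(f"unexpected status: {response_message[1]}")

    assert _interpret_is_first_batch((0, "", b"")) is True
    assert _interpret_is_first_batch((1, "", b"")) is False
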
65 changes: 52 additions & 13 deletions python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
@@ -74,6 +74,11 @@ def _prepare_test_resource2(self, input_path):
            fw.write("1, 246\n")
            fw.write("1, 6\n")

+    def _prepare_test_resource3(self, input_path):
+        with open(input_path + "/text-test2.txt", "w") as fw:
+            fw.write("3, 12\n")
+            fw.write("0, 67\n")
+
     def _build_test_df(self, input_path):
         df = self.spark.readStream.format("text").option("maxFilesPerTrigger", 1).load(input_path)
         df_split = df.withColumn("split_values", split(df["value"], ","))
@@ -213,12 +218,11 @@ def test_transform_with_state_in_pandas_query_restarts(self):
         """

     def _test_transform_with_state_in_pandas_basic(
-        self, stateful_processor, check_results, single_batch=False
+        self, stateful_processor, check_results
     ):
         input_path = tempfile.mkdtemp()
         self._prepare_test_resource1(input_path)
-        if not single_batch:
-            self._prepare_test_resource2(input_path)
+        self._prepare_test_resource3(input_path)

         df = self._build_test_df(input_path)

@@ -229,15 +233,15 @@ def _test_transform_with_state_in_pandas_basic(
         output_schema = StructType(
             [
                 StructField("id", StringType(), True),
-                StructField("countAsString", StringType(), True),
+                StructField("value", StringType(), True),
             ]
         )

         from pyspark.sql import GroupedData

-        data = [("0", 789), ("1", 987)]
+        data = [("0", 789), ("3", 987)]
         initial_state =\
-            self.spark.createDataFrame(data, "id string, sth int")\
+            self.spark.createDataFrame(data, "id string, initVal int")\
             .groupBy("id")

         q = (
@@ -264,20 +268,58 @@ def test_transform_with_state_in_pandas_basic(self):
     def test_transform_with_state_in_pandas_basic(self):
         def check_results(batch_df, batch_id):
             if batch_id == 0:
+                # for key 0, initial state was processed and it was only processed once;
+                # for key 1, it did not appear in the initial state df;
+                # for key 3, it did not appear in the first batch of input keys
+                # so it won't be emitted
                 assert set(batch_df.sort("id").collect()) == {
-                    Row(id="0", countAsString="2"),
-                    Row(id="1", countAsString="2"),
+                    Row(id="0", value=str(789 + 123 + 46)),
+                    Row(id="1", value=str(146 + 346)),
                 }
             else:
+                # for key 0, verify initial state was only processed once in the first batch;
+                # for key 3, verify init state was now processed
                 assert set(batch_df.sort("id").collect()) == {
-                    Row(id="0", countAsString="3"),
-                    Row(id="1", countAsString="2"),
+                    Row(id="0", value=str(789 + 123 + 46 + 67)),
+                    # Row(id="3", value=str(987 + 12)),
+                    Row(id="3", value=str(12)),
                 }

+        self._test_transform_with_state_in_pandas_basic(SimpleStatefulProcessorWithInitialState(), check_results)
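
A worked trace of these expected values, assuming _prepare_test_resource1 (not shown in this diff) writes rows "0, 123", "0, 46", "1, 146", "1, 346":

    # initial state:  key "0" -> 789, key "3" -> 987
    # batch 0 input:  (0, 123), (0, 46), (1, 146), (1, 346)   assumed resource1
    # batch 1 input:  (3, 12), (0, 67)                        from resource3
    #
    # key "0": 789 (seeded) + 123 + 46 = 958 in batch 0, then + 67 = 1025
    # key "1": no initial state, so 146 + 346 = 492
    # key "3": absent in batch 0, so nothing emitted there; emits 12 in batch 1.
    #          The commented-out Row records the intended 987 + 12 once initial
    #          state for keys arriving after the first batch is handled.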


+class SimpleStatefulProcessorWithInitialState(StatefulProcessor):
+
+    def init(self, handle: StatefulProcessorHandle) -> None:
+        state_schema = StructType([StructField("value", IntegerType(), True)])
+        self.value_state = handle.getValueState("value_state", state_schema)
+
+    def handleInputRows(self, key, rows) -> Iterator[pd.DataFrame]:
+        exists = self.value_state.exists()
+        if exists:
+            value_row = self.value_state.get()
+            existing_value = value_row[0]
+        else:
+            existing_value = 0
+
+        accumulated_value = existing_value
+
+        for pdf in rows:
+            value = pdf['temperature'].astype(int).sum()
+            accumulated_value += value
+
+        self.value_state.update((accumulated_value,))
+
+        yield pd.DataFrame({"id": key, "value": str(accumulated_value)})
+
+    def handleInitialState(self, key, initialState) -> None:
+        initVal = initialState.at[0, "initVal"]
+        self.value_state.update((initVal,))
+
+    def close(self) -> None:
+        pass
+
 class SimpleStatefulProcessor(StatefulProcessor):
     dict = {0: {"0": 1, "1": 2}, 1: {"0": 4, "1": 3}}
     batch_id = 0

Expand Down Expand Up @@ -307,9 +349,6 @@ def handleInputRows(self, key, rows) -> Iterator[pd.DataFrame]:
self.num_violations_state.update((updated_violations,))
yield pd.DataFrame({"id": key, "countAsString": str(count)})

def handleInitialState(self, key, initialState) -> None:
raise Exception(f"I am inside handleInitialState, init state: {initialState.get('sth')}")

def close(self) -> None:
pass

StateMessage.proto

@@ -105,12 +105,12 @@ message SetHandleState {

 message UtilsCallCommand {
   oneof method {
-    GetBatchId getBatchId = 1;
+    IsFirstBatch isFirstBatch = 1;
     GetInitialState getInitialState = 2;
   }
 }

-message GetBatchId {
+message IsFirstBatch {
 }

 message GetInitialState {
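
Because isFirstBatch and getInitialState share a oneof, setting one clears the other. A hedged round-trip sketch, assuming the regenerated StateMessage_pb2 module is importable:

    import pyspark.sql.streaming.StateMessage_pb2 as sm

    cmd = sm.UtilsCallCommand(isFirstBatch=sm.IsFirstBatch())
    assert cmd.WhichOneof("method") == "isFirstBatch"

    # the oneof selection survives a serialize/parse round trip
    parsed = sm.UtilsCallCommand.FromString(cmd.SerializeToString())
    assert parsed.WhichOneof("method") == "isFirstBatch"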
(Diffs for the remaining changed files were not loaded.)
