
Commit 85c9fd1

[SPARK-53867][PYTHON] Limit Arrow batch sizes in SQL_GROUPED_AGG_ARROW_UDF
### What changes were proposed in this pull request?
Limit Arrow batch sizes in SQL_GROUPED_AGG_ARROW_UDF.

### Why are the changes needed?
To avoid OOM on the JVM side.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Added UTs.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #52605 from zhengruifeng/limit_grouped_agg_arrow.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent: 37564db
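
For context, here is a minimal sketch of the kind of workload this change targets. It is not part of the commit: it assumes a Spark build with Arrow UDF support, an active SparkSession named `spark`, and the import path for arrow_udf / ArrowUDFType shown below; the UDF and the two confs mirror the new test added in this commit.

    import pyarrow as pa
    import pyspark.sql.functions as sf
    from pyspark.sql.functions import arrow_udf, ArrowUDFType  # assumed import path

    # Cap each Arrow batch sent between the JVM and the Python worker. Before this
    # change, grouped-agg Arrow UDF input was not sliced into size-limited batches
    # on the JVM side, which could OOM for large groups.
    spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 1000)
    spark.conf.set("spark.sql.execution.arrow.maxBytesPerBatch", 1048576)

    @arrow_udf("long", ArrowUDFType.GROUPED_AGG)
    def arrow_max(v):
        # The UDF still receives the whole group as a single Arrow column;
        # the slicing only limits the size of each batch on the wire.
        return pa.compute.max(v)

    df = spark.range(10000000).select((sf.col("id") % 2).alias("key"), sf.col("id").alias("v"))
    df.groupBy("key").agg(arrow_max("v").alias("res")).show()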

File tree

7 files changed (+90 additions, -30 deletions)


python/pyspark/sql/pandas/serializers.py

Lines changed: 43 additions & 0 deletions
@@ -1143,6 +1143,49 @@ def __repr__(self):
         return "GroupArrowUDFSerializer"


+class AggArrowUDFSerializer(ArrowStreamArrowUDFSerializer):
+    def __init__(
+        self,
+        timezone,
+        safecheck,
+        assign_cols_by_name,
+        arrow_cast,
+    ):
+        super().__init__(
+            timezone=timezone,
+            safecheck=safecheck,
+            assign_cols_by_name=False,
+            arrow_cast=True,
+        )
+        self._timezone = timezone
+        self._safecheck = safecheck
+        self._assign_cols_by_name = assign_cols_by_name
+        self._arrow_cast = arrow_cast
+
+    def load_stream(self, stream):
+        """
+        Flatten the struct into Arrow's record batches.
+        """
+        import pyarrow as pa
+
+        dataframes_in_group = None
+
+        while dataframes_in_group is None or dataframes_in_group > 0:
+            dataframes_in_group = read_int(stream)
+
+            if dataframes_in_group == 1:
+                yield pa.concat_batches(ArrowStreamSerializer.load_stream(self, stream))
+
+            elif dataframes_in_group != 0:
+                raise PySparkValueError(
+                    errorClass="INVALID_NUMBER_OF_DATAFRAMES_IN_GROUP",
+                    messageParameters={"dataframes_in_group": str(dataframes_in_group)},
+                )
+
+    def __repr__(self):
+        return "AggArrowUDFSerializer"
+
+
 class GroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
     def __init__(
         self,
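
The important call above is pa.concat_batches: every sliced record batch the JVM emits for one group is stitched back into a single batch before the aggregating UDF runs, so the UDF still sees the whole group even when maxRecordsPerBatch / maxBytesPerBatch force the JVM to split it. A small standalone illustration of that step, with made-up data, assuming a pyarrow version recent enough to provide concat_batches (the new serializer relies on it):

    import pyarrow as pa

    # Two slices of the same logical group, as the JVM might emit them
    # once the per-batch record/byte limits kick in.
    slice1 = pa.RecordBatch.from_pydict({"key": [0, 0], "v": [1, 2]})
    slice2 = pa.RecordBatch.from_pydict({"key": [0, 0], "v": [3, 4]})

    # The serializer concatenates the slices into one batch per group.
    whole_group = pa.concat_batches([slice1, slice2])
    assert whole_group.num_rows == 4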

python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py

Lines changed: 2 additions & 2 deletions
@@ -360,6 +360,7 @@ def test_arrow_batch_slicing(self):
         df = df.withColumns(cols)

         def min_max_v(table):
+            assert len(table) == 10000000 / 2, len(table)
             return pa.Table.from_pydict(
                 {
                     "key": [table.column("key")[0].as_py()],
@@ -372,8 +373,7 @@ def min_max_v(table):
             df.groupby("key").agg(sf.min("v").alias("min"), sf.max("v").alias("max")).sort("key")
         ).collect()

-        int_max = 2147483647
-        for maxRecords, maxBytes in [(1000, int_max), (0, 1048576), (1000, 1048576)]:
+        for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 1048576), (1000, 1048576)]:
             with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
                 with self.sql_conf(
                     {

python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py

Lines changed: 28 additions & 0 deletions
@@ -986,6 +986,34 @@ def arrow_lit_1() -> int:
         )
         self.assertEqual(expected2.collect(), result2.collect())

+    def test_arrow_batch_slicing(self):
+        import pyarrow as pa
+
+        df = self.spark.range(10000000).select(
+            (sf.col("id") % 2).alias("key"), sf.col("id").alias("v")
+        )
+
+        @arrow_udf("long", ArrowUDFType.GROUPED_AGG)
+        def arrow_max(v):
+            assert len(v) == 10000000 / 2, len(v)
+            return pa.compute.max(v)
+
+        expected = (df.groupby("key").agg(sf.max("v").alias("res")).sort("key")).collect()
+
+        for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 1048576), (1000, 1048576)]:
+            with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
+                with self.sql_conf(
+                    {
+                        "spark.sql.execution.arrow.maxRecordsPerBatch": maxRecords,
+                        "spark.sql.execution.arrow.maxBytesPerBatch": maxBytes,
+                    }
+                ):
+                    result = (
+                        df.groupBy("key").agg(arrow_max("v").alias("res")).sort("key")
+                    ).collect()
+
+                    self.assertEqual(expected, result)
+

 class GroupedAggArrowUDFTests(GroupedAggArrowUDFTestsMixin, ReusedSQLTestCase):
     pass

python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py

Lines changed: 2 additions & 2 deletions
@@ -954,6 +954,7 @@ def test_arrow_batch_slicing(self):
         df = df.withColumns(cols)

         def min_max_v(pdf):
+            assert len(pdf) == 10000000 / 2, len(pdf)
             return pd.DataFrame(
                 {
                     "key": [pdf.key.iloc[0]],
@@ -966,8 +967,7 @@ def min_max_v(pdf):
             df.groupby("key").agg(sf.min("v").alias("min"), sf.max("v").alias("max")).sort("key")
         ).collect()

-        int_max = 2147483647
-        for maxRecords, maxBytes in [(1000, int_max), (0, 1048576), (1000, 1048576)]:
+        for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 1048576), (1000, 1048576)]:
             with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
                 with self.sql_conf(
                     {

python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py

Lines changed: 1 addition & 0 deletions
@@ -801,6 +801,7 @@ def test_arrow_batch_slicing(self):

         @pandas_udf("long", PandasUDFType.GROUPED_AGG)
         def pandas_max(v):
+            assert len(v) == 10000000 / 2, len(v)
             return v.max()

         expected = (df.groupby("key").agg(sf.max("v").alias("res")).sort("key")).collect()

python/pyspark/worker.py

Lines changed: 3 additions & 1 deletion
@@ -52,6 +52,7 @@
 from pyspark.sql.conversion import LocalDataToArrowConversion, ArrowTableToRowsConversion
 from pyspark.sql.functions import SkipRestOfInputTableException
 from pyspark.sql.pandas.serializers import (
+    AggArrowUDFSerializer,
     ArrowStreamPandasUDFSerializer,
     ArrowStreamPandasUDTFSerializer,
     GroupPandasUDFSerializer,
@@ -2611,6 +2612,8 @@ def read_udfs(pickleSer, infile, eval_type):
         or eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF
     ):
         ser = GroupArrowUDFSerializer(_assign_cols_by_name)
+    elif eval_type == PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF:
+        ser = AggArrowUDFSerializer(timezone, True, _assign_cols_by_name, True)
     elif eval_type in (
         PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
         PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
@@ -2700,7 +2703,6 @@ def read_udfs(pickleSer, infile, eval_type):
     elif eval_type in (
         PythonEvalType.SQL_SCALAR_ARROW_UDF,
         PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
-        PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
         PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF,
     ):
         # Arrow cast and safe check are always enabled
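
For readability only, here is the new dispatch line with its positional arguments labeled using the parameter names from AggArrowUDFSerializer.__init__ above. The names timezone and _assign_cols_by_name come from the surrounding read_udfs scope, so this is an annotated restatement rather than standalone code:

    # eval_type == PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF
    ser = AggArrowUDFSerializer(
        timezone,              # timezone: session-local timezone for the Arrow serializer
        True,                  # safecheck: always enabled for grouped-agg Arrow UDFs
        _assign_cols_by_name,  # assign_cols_by_name: stored, though the base class is passed False
        True,                  # arrow_cast: always enabled for grouped-agg Arrow UDFs
    )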

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowAggregatePythonExec.scala

Lines changed: 11 additions & 25 deletions
@@ -180,31 +180,17 @@ case class ArrowAggregatePythonExec(
       rows
     }

-    val runner = if (evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF) {
-      new ArrowPythonWithNamedArgumentRunner(
-        pyFuncs,
-        evalType,
-        argMetas,
-        aggInputSchema,
-        sessionLocalTimeZone,
-        largeVarTypes,
-        pythonRunnerConf,
-        pythonMetrics,
-        jobArtifactUUID,
-        conf.pythonUDFProfiler) with GroupedPythonArrowInput
-    } else {
-      new ArrowPythonWithNamedArgumentRunner(
-        pyFuncs,
-        evalType,
-        argMetas,
-        aggInputSchema,
-        sessionLocalTimeZone,
-        largeVarTypes,
-        pythonRunnerConf,
-        pythonMetrics,
-        jobArtifactUUID,
-        conf.pythonUDFProfiler)
-    }
+    val runner = new ArrowPythonWithNamedArgumentRunner(
+      pyFuncs,
+      evalType,
+      argMetas,
+      aggInputSchema,
+      sessionLocalTimeZone,
+      largeVarTypes,
+      pythonRunnerConf,
+      pythonMetrics,
+      jobArtifactUUID,
+      conf.pythonUDFProfiler) with GroupedPythonArrowInput

     val columnarBatchIter = runner.compute(projectedRowIter, context.partitionId(), context)
0 commit comments
