[SPARK-48536][PYTHON][CONNECT] Cache user specified schema in applyInPandas and applyInArrow

### What changes were proposed in this pull request?
Cache the user-specified schema in `applyInPandas` and `applyInArrow` (and their cogrouped variants): when the caller passes a `StructType`, store it as `_cached_schema` on the returned Connect DataFrame.

### Why are the changes needed?
To avoid extra RPCs: with the user-provided schema cached on the client, reading the result DataFrame's schema no longer needs a separate `AnalyzePlan` round trip to the server.
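
A minimal sketch of the idea, using a toy stand-in for the Connect `DataFrame` (the class name and the `_analyze_via_rpc` helper below are illustrative placeholders, not the real Connect internals; only the `_cached_schema` attribute mirrors the actual change):

```python
from pyspark.sql.types import DoubleType, LongType, StructField, StructType


class ConnectDataFrameSketch:
    """Toy stand-in for pyspark.sql.connect.dataframe.DataFrame (illustrative only)."""

    def __init__(self):
        # Filled in when the caller already told us the schema,
        # e.g. applyInPandas(func, schema) with an explicit StructType.
        self._cached_schema = None

    @property
    def schema(self) -> StructType:
        # Without a cached schema the client must ask the server, i.e. one
        # AnalyzePlan round trip; with it, the answer is purely local.
        if self._cached_schema is None:
            self._cached_schema = self._analyze_via_rpc()
        return self._cached_schema

    def _analyze_via_rpc(self) -> StructType:
        raise RuntimeError("placeholder for the AnalyzePlan RPC")


# Conceptually what this commit does: the user passed an explicit StructType,
# so the returned DataFrame can reuse it instead of contacting the server.
user_schema = StructType([StructField("id", LongType()), StructField("v", DoubleType())])
df = ConnectDataFrameSketch()
df._cached_schema = user_schema
assert df.schema is user_schema  # resolved locally, no RPC issued
```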

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
Added unit tests covering the grouped and cogrouped `applyInPandas` / `applyInArrow` cases.

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #46877 from zhengruifeng/cache_schema_apply_in_x.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
zhengruifeng committed Jun 5, 2024
1 parent 88b8dc2 commit 34ac7de
Showing 2 changed files with 160 additions and 5 deletions.
20 changes: 16 additions & 4 deletions python/pyspark/sql/connect/group.py
@@ -301,7 +301,7 @@ def applyInPandas(
evalType=PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
)

return DataFrame(
res = DataFrame(
plan.GroupMap(
child=self._df._plan,
grouping_cols=self._grouping_cols,
@@ -310,6 +310,9 @@ def applyInPandas(
),
session=self._df._session,
)
if isinstance(schema, StructType):
res._cached_schema = schema
return res

applyInPandas.__doc__ = PySparkGroupedData.applyInPandas.__doc__

@@ -370,7 +373,7 @@ def applyInArrow(
evalType=PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
)

return DataFrame(
res = DataFrame(
plan.GroupMap(
child=self._df._plan,
grouping_cols=self._grouping_cols,
@@ -379,6 +382,9 @@ def applyInArrow(
),
session=self._df._session,
)
if isinstance(schema, StructType):
res._cached_schema = schema
return res

applyInArrow.__doc__ = PySparkGroupedData.applyInArrow.__doc__

@@ -410,7 +416,7 @@ def applyInPandas(
evalType=PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
)

return DataFrame(
res = DataFrame(
plan.CoGroupMap(
input=self._gd1._df._plan,
input_grouping_cols=self._gd1._grouping_cols,
@@ -420,6 +426,9 @@ def applyInPandas(
),
session=self._gd1._df._session,
)
if isinstance(schema, StructType):
res._cached_schema = schema
return res

applyInPandas.__doc__ = PySparkPandasCogroupedOps.applyInPandas.__doc__

@@ -436,7 +445,7 @@ def applyInArrow(
evalType=PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
)

return DataFrame(
res = DataFrame(
plan.CoGroupMap(
input=self._gd1._df._plan,
input_grouping_cols=self._gd1._grouping_cols,
@@ -446,6 +455,9 @@ def applyInArrow(
),
session=self._gd1._df._session,
)
if isinstance(schema, StructType):
res._cached_schema = schema
return res

applyInArrow.__doc__ = PySparkPandasCogroupedOps.applyInArrow.__doc__

145 changes: 144 additions & 1 deletion python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
@@ -17,7 +17,7 @@

import unittest

from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, DoubleType
from pyspark.sql.utils import is_remote

from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
@@ -30,6 +30,7 @@

if have_pyarrow:
import pyarrow as pa
import pyarrow.compute as pc

if have_pandas:
import pandas as pd
@@ -127,6 +128,148 @@ def func(iterator):
self.assertEqual(cdf1.schema, sdf1.schema)
self.assertEqual(cdf1.collect(), sdf1.collect())

@unittest.skipIf(
not have_pandas or not have_pyarrow,
pandas_requirement_message or pyarrow_requirement_message,
)
def test_cached_schema_group_apply_in_pandas(self):
data = [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)]
cdf = self.connect.createDataFrame(data, ("id", "v"))
sdf = self.spark.createDataFrame(data, ("id", "v"))

def normalize(pdf):
v = pdf.v
return pdf.assign(v=(v - v.mean()) / v.std())

schema = StructType(
[
StructField("id", LongType(), True),
StructField("v", DoubleType(), True),
]
)

with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": "1"}):
self.assertTrue(is_remote())
cdf1 = cdf.groupby("id").applyInPandas(normalize, schema)
self.assertEqual(cdf1._cached_schema, schema)

with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": None}):
self.assertFalse(is_remote())
sdf1 = sdf.groupby("id").applyInPandas(normalize, schema)

self.assertEqual(cdf1.schema, sdf1.schema)
self.assertEqual(cdf1.collect(), sdf1.collect())

@unittest.skipIf(
not have_pandas or not have_pyarrow,
pandas_requirement_message or pyarrow_requirement_message,
)
def test_cached_schema_group_apply_in_arrow(self):
data = [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)]
cdf = self.connect.createDataFrame(data, ("id", "v"))
sdf = self.spark.createDataFrame(data, ("id", "v"))

def normalize(table):
v = table.column("v")
norm = pc.divide(pc.subtract(v, pc.mean(v)), pc.stddev(v, ddof=1))
return table.set_column(1, "v", norm)

schema = StructType(
[
StructField("id", LongType(), True),
StructField("v", DoubleType(), True),
]
)

with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": "1"}):
self.assertTrue(is_remote())
cdf1 = cdf.groupby("id").applyInArrow(normalize, schema)
self.assertEqual(cdf1._cached_schema, schema)

with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": None}):
self.assertFalse(is_remote())
sdf1 = sdf.groupby("id").applyInArrow(normalize, schema)

self.assertEqual(cdf1.schema, sdf1.schema)
self.assertEqual(cdf1.collect(), sdf1.collect())

@unittest.skipIf(
not have_pandas or not have_pyarrow,
pandas_requirement_message or pyarrow_requirement_message,
)
def test_cached_schema_cogroup_apply_in_pandas(self):
data1 = [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)]
data2 = [(20000101, 1, "x"), (20000101, 2, "y")]

cdf1 = self.connect.createDataFrame(data1, ("time", "id", "v1"))
sdf1 = self.spark.createDataFrame(data1, ("time", "id", "v1"))
cdf2 = self.connect.createDataFrame(data2, ("time", "id", "v2"))
sdf2 = self.spark.createDataFrame(data2, ("time", "id", "v2"))

def asof_join(left, right):
return pd.merge_asof(left, right, on="time", by="id")

schema = StructType(
[
StructField("time", IntegerType(), True),
StructField("id", IntegerType(), True),
StructField("v1", DoubleType(), True),
StructField("v2", StringType(), True),
]
)

with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": "1"}):
self.assertTrue(is_remote())
cdf3 = cdf1.groupby("id").cogroup(cdf2.groupby("id")).applyInPandas(asof_join, schema)
self.assertEqual(cdf3._cached_schema, schema)

with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": None}):
self.assertFalse(is_remote())
sdf3 = sdf1.groupby("id").cogroup(sdf2.groupby("id")).applyInPandas(asof_join, schema)

self.assertEqual(cdf3.schema, sdf3.schema)
self.assertEqual(cdf3.collect(), sdf3.collect())

@unittest.skipIf(
not have_pandas or not have_pyarrow,
pandas_requirement_message or pyarrow_requirement_message,
)
def test_cached_schema_cogroup_apply_in_arrow(self):
data1 = [(1, 1.0), (2, 2.0), (1, 3.0), (2, 4.0)]
data2 = [(1, "x"), (2, "y")]

cdf1 = self.connect.createDataFrame(data1, ("id", "v1"))
sdf1 = self.spark.createDataFrame(data1, ("id", "v1"))
cdf2 = self.connect.createDataFrame(data2, ("id", "v2"))
sdf2 = self.spark.createDataFrame(data2, ("id", "v2"))

def summarize(left, right):
return pa.Table.from_pydict(
{
"left": [left.num_rows],
"right": [right.num_rows],
}
)

schema = StructType(
[
StructField("left", LongType(), True),
StructField("right", LongType(), True),
]
)

with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": "1"}):
self.assertTrue(is_remote())
cdf3 = cdf1.groupby("id").cogroup(cdf2.groupby("id")).applyInArrow(summarize, schema)
self.assertEqual(cdf3._cached_schema, schema)

with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": None}):
self.assertFalse(is_remote())
sdf3 = sdf1.groupby("id").cogroup(sdf2.groupby("id")).applyInArrow(summarize, schema)

self.assertEqual(cdf3.schema, sdf3.schema)
self.assertEqual(cdf3.collect(), sdf3.collect())


if __name__ == "__main__":
from pyspark.sql.tests.connect.test_connect_dataframe_property import * # noqa: F401
