2 changes: 2 additions & 0 deletions dev/sparktestsupport/modules.py
@@ -580,6 +580,8 @@ def __hash__(self):
         "pyspark.sql.tests.test_session",
         "pyspark.sql.tests.test_subquery",
         "pyspark.sql.tests.test_types",
+        "pyspark.sql.tests.test_coercion",
+        "pyspark.sql.tests.test_arrow_udf_coercion",
         "pyspark.sql.tests.test_geographytype",
         "pyspark.sql.tests.test_geometrytype",
         "pyspark.sql.tests.test_udf",
22 changes: 19 additions & 3 deletions python/pyspark/sql/pandas/serializers.py
@@ -805,6 +805,8 @@ class ArrowBatchUDFSerializer(ArrowStreamArrowUDFSerializer):
         This has performance penalties.
     binary_as_bytes : bool
         If True, binary type will be deserialized as bytes, otherwise as bytearray.
+    coercion_policy : CoercionPolicy
+        The coercion policy for type conversion (PERMISSIVE, WARN, or STRICT).
     """

     def __init__(
@@ -814,6 +816,7 @@ def __init__(
         input_types,
         int_to_decimal_coercion_enabled,
         binary_as_bytes,
+        coercion_policy,
     ):
         super().__init__(
             timezone=timezone,
@@ -824,6 +827,7 @@ def __init__(
         self._input_types = input_types
         self._int_to_decimal_coercion_enabled = int_to_decimal_coercion_enabled
         self._binary_as_bytes = binary_as_bytes
+        self._coercion_policy = coercion_policy

     def load_stream(self, stream):
         """
@@ -874,18 +878,30 @@ def dump_stream(self, iterator, stream):
             Result of writing the Arrow stream via ArrowStreamArrowUDFSerializer dump_stream
         """
         import pyarrow as pa
+        from pyspark.sql.types import CoercionPolicy

+        coercion_policy = self._coercion_policy
+
         def create_array(results, arrow_type, spark_type):
             conv = LocalDataToArrowConversion._create_converter(
                 spark_type,
                 none_on_identity=True,
                 int_to_decimal_coercion_enabled=self._int_to_decimal_coercion_enabled,
             )
-            converted = [conv(res) for res in results] if conv is not None else results
+
+            if (
+                coercion_policy == CoercionPolicy.PERMISSIVE
+                or coercion_policy == CoercionPolicy.WARN
+            ) and spark_type.needsCoercion:
+                results = [spark_type.coerce(v, coercion_policy) for v in results]
+
+            if conv is not None:
+                results = [conv(v) for v in results]
+
             try:
-                return pa.array(converted, type=arrow_type)
+                return pa.array(results, type=arrow_type)
             except pa.lib.ArrowInvalid:
-                return pa.array(converted).cast(target_type=arrow_type, safe=self._safecheck)
+                return pa.array(results).cast(target_type=arrow_type, safe=self._safecheck)

         def py_to_batch():
             for packed in iterator:
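The gist of the serializer change: under PERMISSIVE or WARN, values are first coerced by the Spark type itself (via its needsCoercion/coerce hooks) before the existing Arrow conversion runs, while STRICT skips that step and leaves mismatches to Arrow's own casting. Below is a minimal, self-contained sketch of just that gating logic; the CoercionPolicy enum and the type with needsCoercion/coerce are stand-ins written for illustration (the real ones live in pyspark.sql.types, per the import in the hunk above), not the actual implementation.

    from enum import Enum

    class CoercionPolicy(Enum):
        # Stand-in for pyspark.sql.types.CoercionPolicy introduced by this patch.
        PERMISSIVE = "permissive"
        WARN = "warn"
        STRICT = "strict"

    class StringToIntType:
        # Hypothetical Spark type exposing the needsCoercion/coerce hooks the diff relies on.
        needsCoercion = True

        def coerce(self, value, policy):
            try:
                return int(value)
            except (TypeError, ValueError):
                if policy == CoercionPolicy.WARN:
                    print(f"warning: could not coerce {value!r}")
                return value  # pass through; Arrow conversion decides downstream

    def coerce_results(results, spark_type, policy):
        # Mirrors the diff: coerce only under PERMISSIVE or WARN, and only when the
        # target type opts in via needsCoercion; STRICT skips this step entirely.
        if policy in (CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN) and spark_type.needsCoercion:
            return [spark_type.coerce(v, policy) for v in results]
        return results

    print(coerce_results(["1", "2"], StringToIntType(), CoercionPolicy.PERMISSIVE))  # [1, 2]
    print(coerce_results(["1", "2"], StringToIntType(), CoercionPolicy.STRICT))      # ['1', '2']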
81 changes: 45 additions & 36 deletions python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
@@ -160,45 +160,54 @@ def test_nested_array_input(self):
         )

     def test_type_coercion_string_to_numeric(self):
-        df_int_value = self.spark.createDataFrame(["1", "2"], schema="string")
-        df_floating_value = self.spark.createDataFrame(["1.1", "2.2"], schema="string")
-
-        int_ddl_types = ["tinyint", "smallint", "int", "bigint"]
-        floating_ddl_types = ["double", "float"]
-
-        for ddl_type in int_ddl_types:
-            # df_int_value
-            res = df_int_value.select(udf(lambda x: x, ddl_type)("value").alias("res"))
-            self.assertEqual(res.collect(), [Row(res=1), Row(res=2)])
-            self.assertEqual(res.dtypes[0][1], ddl_type)
-
-        floating_results = [
-            [Row(res=1.1), Row(res=2.2)],
-            [Row(res=1.100000023841858), Row(res=2.200000047683716)],
-        ]
-        for ddl_type, floating_res in zip(floating_ddl_types, floating_results):
-            # df_int_value
-            res = df_int_value.select(udf(lambda x: x, ddl_type)("value").alias("res"))
-            self.assertEqual(res.collect(), [Row(res=1.0), Row(res=2.0)])
-            self.assertEqual(res.dtypes[0][1], ddl_type)
-            # df_floating_value
-            res = df_floating_value.select(udf(lambda x: x, ddl_type)("value").alias("res"))
-            self.assertEqual(res.collect(), floating_res)
-            self.assertEqual(res.dtypes[0][1], ddl_type)
-
-        # invalid
-        with self.assertRaises(PythonException):
-            df_floating_value.select(udf(lambda x: x, "int")("value").alias("res")).collect()
-
-        with self.assertRaises(PythonException):
-            df_int_value.select(udf(lambda x: x, "decimal")("value").alias("res")).collect()
-
-        with self.assertRaises(PythonException):
-            df_floating_value.select(udf(lambda x: x, "decimal")("value").alias("res")).collect()
+        # Use STRICT policy to preserve original Arrow coercion behavior
+        with self.sql_conf({"spark.sql.execution.pythonUDF.coercion.policy": "strict"}):
+            df_int_value = self.spark.createDataFrame(["1", "2"], schema="string")
+            df_floating_value = self.spark.createDataFrame(["1.1", "2.2"], schema="string")
+
+            int_ddl_types = ["tinyint", "smallint", "int", "bigint"]
+            floating_ddl_types = ["double", "float"]
+
+            for ddl_type in int_ddl_types:
+                # df_int_value
+                res = df_int_value.select(udf(lambda x: x, ddl_type)("value").alias("res"))
+                self.assertEqual(res.collect(), [Row(res=1), Row(res=2)])
+                self.assertEqual(res.dtypes[0][1], ddl_type)
+
+            floating_results = [
+                [Row(res=1.1), Row(res=2.2)],
+                [Row(res=1.100000023841858), Row(res=2.200000047683716)],
+            ]
+            for ddl_type, floating_res in zip(floating_ddl_types, floating_results):
+                # df_int_value
+                res = df_int_value.select(udf(lambda x: x, ddl_type)("value").alias("res"))
+                self.assertEqual(res.collect(), [Row(res=1.0), Row(res=2.0)])
+                self.assertEqual(res.dtypes[0][1], ddl_type)
+                # df_floating_value
+                res = df_floating_value.select(udf(lambda x: x, ddl_type)("value").alias("res"))
+                self.assertEqual(res.collect(), floating_res)
+                self.assertEqual(res.dtypes[0][1], ddl_type)
+
+            # invalid
+            with self.assertRaises(PythonException):
+                df_floating_value.select(udf(lambda x: x, "int")("value").alias("res")).collect()
+
+            with self.assertRaises(PythonException):
+                df_int_value.select(udf(lambda x: x, "decimal")("value").alias("res")).collect()
+
+            with self.assertRaises(PythonException):
+                df_floating_value.select(
+                    udf(lambda x: x, "decimal")("value").alias("res")
+                ).collect()

     def test_arrow_udf_int_to_decimal_coercion(self):
+        # Use STRICT policy to let Arrow handle type coercion natively,
+        # so the intToDecimalCoercionEnabled flag can control the behavior
         with self.sql_conf(
-            {"spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled": False}
+            {
+                "spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled": False,
+                "spark.sql.execution.pythonUDF.coercion.policy": "strict",
+            }
         ):
             df = self.spark.range(0, 3)

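The knob these tests pin is an ordinary SQL conf, so the same pinning should work outside the test harness. A hedged usage sketch, assuming the conf name from the diff above, that it is runtime-settable as the sql_conf helper suggests, and a Spark build that includes this patch:

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import udf

    spark = SparkSession.builder.getOrCreate()

    # Pin STRICT so Arrow's native casting governs string-to-numeric conversion,
    # matching the pre-patch behavior the updated tests preserve.
    spark.conf.set("spark.sql.execution.pythonUDF.coercion.policy", "strict")

    df = spark.createDataFrame(["1", "2"], schema="string")
    res = df.select(udf(lambda x: x, "int")("value").alias("res"))
    res.show()  # rows res=1 and res=2, per the assertions in the test above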