diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index fd3fc8bdaa178..75307a60649a3 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -580,6 +580,8 @@ def __hash__(self): "pyspark.sql.tests.test_session", "pyspark.sql.tests.test_subquery", "pyspark.sql.tests.test_types", + "pyspark.sql.tests.test_coercion", + "pyspark.sql.tests.test_arrow_udf_coercion", "pyspark.sql.tests.test_geographytype", "pyspark.sql.tests.test_geometrytype", "pyspark.sql.tests.test_udf", diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 3b4e2677933f4..6b3ca3e4c52bb 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -805,6 +805,8 @@ class ArrowBatchUDFSerializer(ArrowStreamArrowUDFSerializer): This has performance penalties. binary_as_bytes : bool If True, binary type will be deserialized as bytes, otherwise as bytearray. + coercion_policy : CoercionPolicy + The coercion policy for type conversion (PERMISSIVE, WARN, or STRICT). """ def __init__( @@ -814,6 +816,7 @@ def __init__( input_types, int_to_decimal_coercion_enabled, binary_as_bytes, + coercion_policy, ): super().__init__( timezone=timezone, @@ -824,6 +827,7 @@ def __init__( self._input_types = input_types self._int_to_decimal_coercion_enabled = int_to_decimal_coercion_enabled self._binary_as_bytes = binary_as_bytes + self._coercion_policy = coercion_policy def load_stream(self, stream): """ @@ -874,6 +878,9 @@ def dump_stream(self, iterator, stream): Result of writing the Arrow stream via ArrowStreamArrowUDFSerializer dump_stream """ import pyarrow as pa + from pyspark.sql.types import CoercionPolicy + + coercion_policy = self._coercion_policy def create_array(results, arrow_type, spark_type): conv = LocalDataToArrowConversion._create_converter( @@ -881,11 +888,20 @@ def create_array(results, arrow_type, spark_type): none_on_identity=True, int_to_decimal_coercion_enabled=self._int_to_decimal_coercion_enabled, ) - converted = [conv(res) for res in results] if conv is not None else results + + if ( + coercion_policy == CoercionPolicy.PERMISSIVE + or coercion_policy == CoercionPolicy.WARN + ) and spark_type.needsCoercion: + results = [spark_type.coerce(v, coercion_policy) for v in results] + + if conv is not None: + results = [conv(v) for v in results] + try: - return pa.array(converted, type=arrow_type) + return pa.array(results, type=arrow_type) except pa.lib.ArrowInvalid: - return pa.array(converted).cast(target_type=arrow_type, safe=self._safecheck) + return pa.array(results).cast(target_type=arrow_type, safe=self._safecheck) def py_to_batch(): for packed in iterator: diff --git a/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py b/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py index 84190f1b10749..ad5f1ca2e2b48 100644 --- a/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py +++ b/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py @@ -160,45 +160,54 @@ def test_nested_array_input(self): ) def test_type_coercion_string_to_numeric(self): - df_int_value = self.spark.createDataFrame(["1", "2"], schema="string") - df_floating_value = self.spark.createDataFrame(["1.1", "2.2"], schema="string") - - int_ddl_types = ["tinyint", "smallint", "int", "bigint"] - floating_ddl_types = ["double", "float"] - - for ddl_type in int_ddl_types: - # df_int_value - res = df_int_value.select(udf(lambda x: x, ddl_type)("value").alias("res")) - self.assertEqual(res.collect(), [Row(res=1), 
Row(res=2)]) - self.assertEqual(res.dtypes[0][1], ddl_type) - - floating_results = [ - [Row(res=1.1), Row(res=2.2)], - [Row(res=1.100000023841858), Row(res=2.200000047683716)], - ] - for ddl_type, floating_res in zip(floating_ddl_types, floating_results): - # df_int_value - res = df_int_value.select(udf(lambda x: x, ddl_type)("value").alias("res")) - self.assertEqual(res.collect(), [Row(res=1.0), Row(res=2.0)]) - self.assertEqual(res.dtypes[0][1], ddl_type) - # df_floating_value - res = df_floating_value.select(udf(lambda x: x, ddl_type)("value").alias("res")) - self.assertEqual(res.collect(), floating_res) - self.assertEqual(res.dtypes[0][1], ddl_type) - - # invalid - with self.assertRaises(PythonException): - df_floating_value.select(udf(lambda x: x, "int")("value").alias("res")).collect() - - with self.assertRaises(PythonException): - df_int_value.select(udf(lambda x: x, "decimal")("value").alias("res")).collect() - - with self.assertRaises(PythonException): - df_floating_value.select(udf(lambda x: x, "decimal")("value").alias("res")).collect() + # Use STRICT policy to preserve original Arrow coercion behavior + with self.sql_conf({"spark.sql.execution.pythonUDF.coercion.policy": "strict"}): + df_int_value = self.spark.createDataFrame(["1", "2"], schema="string") + df_floating_value = self.spark.createDataFrame(["1.1", "2.2"], schema="string") + + int_ddl_types = ["tinyint", "smallint", "int", "bigint"] + floating_ddl_types = ["double", "float"] + + for ddl_type in int_ddl_types: + # df_int_value + res = df_int_value.select(udf(lambda x: x, ddl_type)("value").alias("res")) + self.assertEqual(res.collect(), [Row(res=1), Row(res=2)]) + self.assertEqual(res.dtypes[0][1], ddl_type) + + floating_results = [ + [Row(res=1.1), Row(res=2.2)], + [Row(res=1.100000023841858), Row(res=2.200000047683716)], + ] + for ddl_type, floating_res in zip(floating_ddl_types, floating_results): + # df_int_value + res = df_int_value.select(udf(lambda x: x, ddl_type)("value").alias("res")) + self.assertEqual(res.collect(), [Row(res=1.0), Row(res=2.0)]) + self.assertEqual(res.dtypes[0][1], ddl_type) + # df_floating_value + res = df_floating_value.select(udf(lambda x: x, ddl_type)("value").alias("res")) + self.assertEqual(res.collect(), floating_res) + self.assertEqual(res.dtypes[0][1], ddl_type) + + # invalid + with self.assertRaises(PythonException): + df_floating_value.select(udf(lambda x: x, "int")("value").alias("res")).collect() + + with self.assertRaises(PythonException): + df_int_value.select(udf(lambda x: x, "decimal")("value").alias("res")).collect() + + with self.assertRaises(PythonException): + df_floating_value.select( + udf(lambda x: x, "decimal")("value").alias("res") + ).collect() def test_arrow_udf_int_to_decimal_coercion(self): + # Use STRICT policy to let Arrow handle type coercion natively, + # so the intToDecimalCoercionEnabled flag can control the behavior with self.sql_conf( - {"spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled": False} + { + "spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled": False, + "spark.sql.execution.pythonUDF.coercion.policy": "strict", + } ): df = self.spark.range(0, 3) diff --git a/python/pyspark/sql/tests/test_arrow_udf_coercion.py b/python/pyspark/sql/tests/test_arrow_udf_coercion.py new file mode 100644 index 0000000000000..9616003f7f153 --- /dev/null +++ b/python/pyspark/sql/tests/test_arrow_udf_coercion.py @@ -0,0 +1,356 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Integration tests for unified type coercion in Arrow-backed Python UDFs. + +These tests verify: +1. PERMISSIVE policy: Arrow-enabled UDFs produce the same results as pickle-based UDFs +2. WARN policy: Same as PERMISSIVE but with warnings (tested at unit level) +3. STRICT policy: Arrow handles type conversion natively (no coercion applied) + +The goal is to ensure backward compatibility when enabling Arrow optimization. +""" + +import array +import datetime +import re +import unittest +from decimal import Decimal + +from pyspark.sql import Row +from pyspark.sql.functions import udf +from pyspark.sql.types import ( + ArrayType, + BinaryType, + BooleanType, + ByteType, + DateType, + DecimalType, + DoubleType, + FloatType, + IntegerType, + LongType, + MapType, + ShortType, + StringType, + StructField, + StructType, + TimestampType, +) +from pyspark.testing.utils import ( + have_pyarrow, + have_pandas, + pyarrow_requirement_message, + pandas_requirement_message, +) +from pyspark.testing.sqlutils import ReusedSQLTestCase + + +def normalize_result(value): + """Normalize result for comparison, handling Java object hash codes.""" + result_str = repr(value) + # Normalize Java object hash codes to make tests deterministic + return re.sub(r"@[a-fA-F0-9]+", "@", result_str) + + +@unittest.skipIf( + not have_pandas or not have_pyarrow, + pandas_requirement_message or pyarrow_requirement_message, +) +class ArrowUDFCoercionTests(ReusedSQLTestCase): + """ + Integration tests comparing Arrow-enabled UDFs (with PERMISSIVE coercion) + against pickle-based UDFs to ensure identical behavior. 
+ """ + + @classmethod + def setUpClass(cls): + super().setUpClass() + + def setUp(self): + super().setUp() + # Test values covering various Python types + self.test_data = [ + None, + True, + 1, + "a", + datetime.date(1970, 1, 1), + datetime.datetime(1970, 1, 1, 0, 0), + 1.0, + array.array("i", [1]), + [1], + (1,), + bytearray([65, 66, 67]), + Decimal(1), + {"a": 1}, + Row(kwargs=1), + Row("namedtuple")(1), + ] + + # SQL types to test coercion against + self.test_types = [ + BooleanType(), + ByteType(), + ShortType(), + IntegerType(), + LongType(), + StringType(), + DateType(), + TimestampType(), + FloatType(), + DoubleType(), + ArrayType(IntegerType()), + BinaryType(), + DecimalType(10, 0), + MapType(StringType(), IntegerType()), + StructType([StructField("_1", IntegerType())]), + ] + + def _run_udf(self, value, spark_type, use_arrow): + """Run a UDF that returns a specific value with a given return type.""" + try: + test_udf = udf(lambda _: value, spark_type, useArrow=use_arrow) + row = self.spark.range(1).select(test_udf("id")).first() + return ("success", normalize_result(row[0])) + except Exception as e: + return ("error", type(e).__name__) + + def _results_match(self, pickle_result, arrow_result, spark_type, value): + """ + Check if pickle and Arrow results match, with tolerance for known differences. + + Returns True if the results are equivalent for the purposes of coercion testing. + """ + # Exact match + if pickle_result == arrow_result: + return True + + # Both error - consider equivalent (error types may differ between Py4J and Python) + if pickle_result[0] == "error" and arrow_result[0] == "error": + return True + + # String type has known representation differences between Java and Python + # (Java's toString() vs Python's str()) + if isinstance(spark_type, StringType) and pickle_result[0] == "success": + # datetime objects: Java GregorianCalendar vs Python iso format + if isinstance(value, (datetime.date, datetime.datetime)): + return True + # Container types: Java array/object notation vs Python repr + if isinstance(value, (array.array, tuple, bytearray, dict)): + return True + + return False + + def test_arrow_with_permissive_matches_pickle(self): + """ + Test that Arrow-enabled UDFs with PERMISSIVE coercion produce + the same results as pickle-based UDFs. 
+ """ + mismatches = [] + + for spark_type in self.test_types: + for value in self.test_data: + # Run with pickle (Arrow disabled) + with self.sql_conf({"spark.sql.execution.pythonUDF.arrow.enabled": "false"}): + pickle_result = self._run_udf(value, spark_type, use_arrow=False) + + # Run with Arrow enabled (uses PERMISSIVE coercion by default) + with self.sql_conf({"spark.sql.execution.pythonUDF.arrow.enabled": "true"}): + arrow_result = self._run_udf(value, spark_type, use_arrow=True) + + # Compare results + if not self._results_match(pickle_result, arrow_result, spark_type, value): + mismatches.append( + { + "type": spark_type.simpleString(), + "value": f"{value!r} ({type(value).__name__})", + "pickle": pickle_result, + "arrow": arrow_result, + } + ) + + if mismatches: + mismatch_report = "\n".join( + f" - {m['type']} <- {m['value']}: pickle={m['pickle']}, arrow={m['arrow']}" + for m in mismatches + ) + self.fail( + f"Arrow with PERMISSIVE coercion does not match pickle behavior:\n{mismatch_report}" + ) + + def test_specific_coercion_cases(self): + """Test specific coercion cases that are known to differ between Arrow and pickle.""" + test_cases = [ + # (value, spark_type, description) + (1, BooleanType(), "int -> boolean should return None"), + (1.0, BooleanType(), "float -> boolean should return None"), + (True, IntegerType(), "bool -> int should return None"), + (1.5, IntegerType(), "float -> int should return None"), + (Decimal(1), IntegerType(), "Decimal -> int should return None"), + (True, FloatType(), "bool -> float should return None"), + (1, FloatType(), "int -> float should return None"), + (Decimal(1), FloatType(), "Decimal -> float should return None"), + (1, DecimalType(10, 0), "int -> decimal should return None"), + ] + + for value, spark_type, description in test_cases: + with self.subTest(msg=description): + # Run with pickle (Arrow disabled) + with self.sql_conf({"spark.sql.execution.pythonUDF.arrow.enabled": "false"}): + pickle_result = self._run_udf(value, spark_type, use_arrow=False) + + # Run with Arrow enabled + with self.sql_conf({"spark.sql.execution.pythonUDF.arrow.enabled": "true"}): + arrow_result = self._run_udf(value, spark_type, use_arrow=True) + + self.assertEqual( + pickle_result, + arrow_result, + f"{description}: pickle={pickle_result}, arrow={arrow_result}", + ) + + def test_warn_policy_matches_permissive(self): + """ + Test that WARN policy produces the same results as PERMISSIVE. + WARN should behave identically to PERMISSIVE, just with warnings logged. 
+ """ + mismatches = [] + + for spark_type in self.test_types: + for value in self.test_data: + # Run with PERMISSIVE policy + with self.sql_conf( + { + "spark.sql.execution.pythonUDF.arrow.enabled": "true", + "spark.sql.execution.pythonUDF.coercion.policy": "permissive", + } + ): + permissive_result = self._run_udf(value, spark_type, use_arrow=True) + + # Run with WARN policy + with self.sql_conf( + { + "spark.sql.execution.pythonUDF.arrow.enabled": "true", + "spark.sql.execution.pythonUDF.coercion.policy": "warn", + } + ): + warn_result = self._run_udf(value, spark_type, use_arrow=True) + + # Compare results (should be identical) + if permissive_result != warn_result: + mismatches.append( + { + "type": spark_type.simpleString(), + "value": f"{value!r} ({type(value).__name__})", + "permissive": permissive_result, + "warn": warn_result, + } + ) + + if mismatches: + mismatch_report = "\n".join( + f" - {m['type']} <- {m['value']}: permissive={m['permissive']}, warn={m['warn']}" + for m in mismatches + ) + self.fail(f"WARN policy does not match PERMISSIVE behavior:\n{mismatch_report}") + + def test_strict_policy_differs_from_permissive(self): + """ + Test that STRICT policy (no coercion) produces different results than PERMISSIVE + for cases where Arrow's native type conversion differs from pickle behavior. + + STRICT is a no-op - it lets Arrow handle type conversion natively, + which is more aggressive than pickle (PERMISSIVE). + """ + # Cases where STRICT (Arrow native) should produce different results than PERMISSIVE + # For each case: (value, spark_type, description) + test_cases = [ + # int -> boolean: PERMISSIVE returns None, STRICT (Arrow) converts to True + (1, BooleanType(), "int -> boolean"), + # float -> boolean: PERMISSIVE returns None, STRICT (Arrow) converts to True + (1.0, BooleanType(), "float -> boolean"), + # float -> int: PERMISSIVE returns None, STRICT (Arrow) truncates to 1 + (1.5, IntegerType(), "float -> int"), + # Decimal -> int: PERMISSIVE returns None, STRICT (Arrow) converts to 1 + (Decimal(1), IntegerType(), "Decimal -> int"), + # int -> float: PERMISSIVE returns None, STRICT (Arrow) converts to 1.0 + (1, FloatType(), "int -> float"), + # bool -> float: PERMISSIVE returns None, STRICT (Arrow) converts to 1.0 + (True, FloatType(), "bool -> float"), + ] + + for value, spark_type, description in test_cases: + with self.subTest(msg=description): + # Run with PERMISSIVE policy + with self.sql_conf( + { + "spark.sql.execution.pythonUDF.arrow.enabled": "true", + "spark.sql.execution.pythonUDF.coercion.policy": "permissive", + } + ): + permissive_result = self._run_udf(value, spark_type, use_arrow=True) + + # Run with STRICT policy (Arrow native behavior) + with self.sql_conf( + { + "spark.sql.execution.pythonUDF.arrow.enabled": "true", + "spark.sql.execution.pythonUDF.coercion.policy": "strict", + } + ): + strict_result = self._run_udf(value, spark_type, use_arrow=True) + + # Skip if either result is an error (environment issue, not test logic) + if permissive_result[0] == "error" or strict_result[0] == "error": + # If both error, that's unexpected - fail + if permissive_result[0] == "error" and strict_result[0] == "error": + self.skipTest( + f"{description}: Both policies errored - likely environment issue" + ) + continue + + # PERMISSIVE should return None for these cases + self.assertIn( + "None", + permissive_result[1], + f"{description}: PERMISSIVE should return None, got {permissive_result[1]}", + ) + + # STRICT should succeed with a converted value (not None) + 
+                self.assertNotIn(
+                    "None",
+                    strict_result[1],
+                    f"{description}: STRICT should not return None (Arrow converts), "
+                    f"got {strict_result[1]}",
+                )
+
+                # The results should be different
+                self.assertNotEqual(
+                    permissive_result,
+                    strict_result,
+                    f"{description}: PERMISSIVE and STRICT should produce different results",
+                )
+
+
+if __name__ == "__main__":
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/test_coercion.py b/python/pyspark/sql/tests/test_coercion.py
new file mode 100644
index 0000000000000..d8ca48a8e9163
--- /dev/null
+++ b/python/pyspark/sql/tests/test_coercion.py
@@ -0,0 +1,588 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Unit tests for unified type coercion in PySpark.
+
+These tests verify that the CoercionPolicy enum and DataType.coerce() method
+correctly handle type coercion with different policies:
+- PERMISSIVE: matches legacy pickle behavior (returns None for most type mismatches)
+- WARN: same as PERMISSIVE but logs warnings when Arrow would behave differently
+
+Note: STRICT policy skips coercion entirely in the serializer
+(ArrowBatchUDFSerializer), so coerce() is never called with STRICT. These unit
+tests only cover PERMISSIVE and WARN.
+
+The goal is to enable Arrow by default without breaking existing code.
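+
+As a quick illustration of the PERMISSIVE semantics pinned down below (the same
+values are exercised in the unit tests in this file):
+
+    >>> from pyspark.sql.types import IntegerType, StringType
+    >>> IntegerType().coerce(1.5) is None  # pickle silently drops float -> int
+    True
+    >>> StringType().coerce(True)  # Java-style toString
+    'true'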
+""" + +import array +import datetime +import unittest +from decimal import Decimal + +from pyspark.errors import PySparkTypeError +from pyspark.sql import Row +from pyspark.sql.types import ( + ArrayType, + BinaryType, + BooleanType, + ByteType, + DateType, + DecimalType, + DoubleType, + FloatType, + IntegerType, + LongType, + MapType, + ShortType, + StringType, + StructField, + StructType, + TimestampType, + CoercionPolicy, + CoercionWarning, + _reset_coercion_warnings, +) + + +class CoercionPolicyTests(unittest.TestCase): + """Tests for the CoercionPolicy enum.""" + + def test_policy_values(self): + """Test that all expected policy values exist.""" + # With StrEnum and auto(), values are lowercase names + self.assertEqual(CoercionPolicy.PERMISSIVE, "permissive") + self.assertEqual(CoercionPolicy.WARN, "warn") + self.assertEqual(CoercionPolicy.STRICT, "strict") + + def test_policy_from_string(self): + """Test creating policy from string value.""" + self.assertEqual(CoercionPolicy("permissive"), CoercionPolicy.PERMISSIVE) + self.assertEqual(CoercionPolicy("warn"), CoercionPolicy.WARN) + self.assertEqual(CoercionPolicy("strict"), CoercionPolicy.STRICT) + + +class BooleanCoercionTests(unittest.TestCase): + """Tests for BooleanType coercion.""" + + def setUp(self): + _reset_coercion_warnings() + self.boolean_type = BooleanType() + + def test_bool_to_boolean_permissive_and_warn(self): + """bool -> boolean should work for PERMISSIVE and WARN.""" + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + self.assertEqual(self.boolean_type.coerce(True, policy), True) + self.assertEqual(self.boolean_type.coerce(False, policy), False) + + def test_none_to_boolean_permissive_and_warn(self): + """None -> boolean should return None for PERMISSIVE and WARN.""" + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + self.assertIsNone(self.boolean_type.coerce(None, policy)) + + def test_int_to_boolean_permissive(self): + """int -> boolean: PERMISSIVE returns None (pickle behavior).""" + self.assertIsNone(self.boolean_type.coerce(1, CoercionPolicy.PERMISSIVE)) + self.assertIsNone(self.boolean_type.coerce(0, CoercionPolicy.PERMISSIVE)) + + def test_int_to_boolean_warn(self): + """int -> boolean: WARN returns None but logs warning.""" + with self.assertWarns(CoercionWarning): + result = self.boolean_type.coerce(1, CoercionPolicy.WARN) + self.assertIsNone(result) + + def test_float_to_boolean_permissive(self): + """float -> boolean: PERMISSIVE returns None (pickle behavior).""" + self.assertIsNone(self.boolean_type.coerce(1.0, CoercionPolicy.PERMISSIVE)) + self.assertIsNone(self.boolean_type.coerce(0.0, CoercionPolicy.PERMISSIVE)) + + def test_string_to_boolean_permissive(self): + """str -> boolean: PERMISSIVE returns None.""" + self.assertIsNone(self.boolean_type.coerce("true", CoercionPolicy.PERMISSIVE)) + + +class IntegerCoercionTests(unittest.TestCase): + """Tests for integer types (ByteType, ShortType, IntegerType, LongType).""" + + def setUp(self): + _reset_coercion_warnings() + self.int_types = [ByteType(), ShortType(), IntegerType(), LongType()] + + def test_int_to_int_permissive_and_warn(self): + """int -> int should work for PERMISSIVE and WARN.""" + for int_type in self.int_types: + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + self.assertEqual(int_type.coerce(1, policy), 1) + self.assertEqual(int_type.coerce(0, policy), 0) + self.assertEqual(int_type.coerce(-1, policy), -1) + + def test_none_to_int_permissive_and_warn(self): + """None -> int should return 
None for PERMISSIVE and WARN.""" + for int_type in self.int_types: + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + self.assertIsNone(int_type.coerce(None, policy)) + + def test_bool_to_int_permissive(self): + """bool -> int: PERMISSIVE returns None (pickle behavior).""" + for int_type in self.int_types: + self.assertIsNone(int_type.coerce(True, CoercionPolicy.PERMISSIVE)) + self.assertIsNone(int_type.coerce(False, CoercionPolicy.PERMISSIVE)) + + def test_float_to_int_permissive(self): + """float -> int: PERMISSIVE returns None (pickle behavior).""" + for int_type in self.int_types: + self.assertIsNone(int_type.coerce(1.0, CoercionPolicy.PERMISSIVE)) + self.assertIsNone(int_type.coerce(1.9, CoercionPolicy.PERMISSIVE)) + + def test_decimal_to_int_permissive(self): + """Decimal -> int: PERMISSIVE returns None (pickle behavior).""" + for int_type in self.int_types: + self.assertIsNone(int_type.coerce(Decimal(1), CoercionPolicy.PERMISSIVE)) + + def test_string_to_int_permissive(self): + """str -> int: PERMISSIVE returns None.""" + for int_type in self.int_types: + self.assertIsNone(int_type.coerce("1", CoercionPolicy.PERMISSIVE)) + + +class FloatCoercionTests(unittest.TestCase): + """Tests for FloatType and DoubleType coercion.""" + + def setUp(self): + _reset_coercion_warnings() + self.float_types = [FloatType(), DoubleType()] + + def test_float_to_float_permissive_and_warn(self): + """float -> float should work for PERMISSIVE and WARN.""" + for float_type in self.float_types: + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + self.assertEqual(float_type.coerce(1.0, policy), 1.0) + self.assertEqual(float_type.coerce(0.0, policy), 0.0) + + def test_none_to_float_permissive_and_warn(self): + """None -> float should return None for PERMISSIVE and WARN.""" + for float_type in self.float_types: + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + self.assertIsNone(float_type.coerce(None, policy)) + + def test_int_to_float_permissive(self): + """int -> float: PERMISSIVE returns None (pickle behavior).""" + for float_type in self.float_types: + self.assertIsNone(float_type.coerce(1, CoercionPolicy.PERMISSIVE)) + + def test_bool_to_float_permissive(self): + """bool -> float: PERMISSIVE returns None (pickle behavior).""" + for float_type in self.float_types: + self.assertIsNone(float_type.coerce(True, CoercionPolicy.PERMISSIVE)) + self.assertIsNone(float_type.coerce(False, CoercionPolicy.PERMISSIVE)) + + def test_decimal_to_float_permissive(self): + """Decimal -> float: PERMISSIVE returns None (pickle behavior).""" + for float_type in self.float_types: + self.assertIsNone(float_type.coerce(Decimal(1), CoercionPolicy.PERMISSIVE)) + + +class StringCoercionTests(unittest.TestCase): + """Tests for StringType coercion.""" + + def setUp(self): + _reset_coercion_warnings() + self.string_type = StringType() + + def test_str_to_string_permissive_and_warn(self): + """str -> string should work for PERMISSIVE and WARN.""" + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + self.assertEqual(self.string_type.coerce("hello", policy), "hello") + self.assertEqual(self.string_type.coerce("", policy), "") + + def test_none_to_string_permissive_and_warn(self): + """None -> string should return None for PERMISSIVE and WARN.""" + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + self.assertIsNone(self.string_type.coerce(None, policy)) + + def test_int_to_string_permissive_and_warn(self): + """int -> string: all paths convert via str().""" + 
for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + self.assertEqual(self.string_type.coerce(1, policy), "1") + + def test_bool_to_string_permissive(self): + """bool -> string: PERMISSIVE returns 'true'/'false' (Java toString).""" + self.assertEqual(self.string_type.coerce(True, CoercionPolicy.PERMISSIVE), "true") + self.assertEqual(self.string_type.coerce(False, CoercionPolicy.PERMISSIVE), "false") + + def test_float_to_string_permissive_and_warn(self): + """float -> string: converted via str().""" + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + self.assertEqual(self.string_type.coerce(1.0, policy), "1.0") + + +class DateCoercionTests(unittest.TestCase): + """Tests for DateType coercion.""" + + def setUp(self): + _reset_coercion_warnings() + self.date_type = DateType() + + def test_date_to_date_permissive_and_warn(self): + """date -> date should work for PERMISSIVE and WARN.""" + date_val = datetime.date(1970, 1, 1) + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + self.assertEqual(self.date_type.coerce(date_val, policy), date_val) + + def test_none_to_date_permissive_and_warn(self): + """None -> date should return None for PERMISSIVE and WARN.""" + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + self.assertIsNone(self.date_type.coerce(None, policy)) + + def test_datetime_to_date_permissive_and_warn(self): + """datetime -> date: PERMISSIVE/WARN extract date part.""" + dt_val = datetime.datetime(1970, 1, 1, 12, 30, 45) + expected = datetime.date(1970, 1, 1) + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + self.assertEqual(self.date_type.coerce(dt_val, policy), expected) + + def test_int_to_date_permissive(self): + """int -> date: PERMISSIVE raises PySparkTypeError (pickle behavior).""" + with self.assertRaises(PySparkTypeError): + self.date_type.coerce(1, CoercionPolicy.PERMISSIVE) + + +class TimestampCoercionTests(unittest.TestCase): + """Tests for TimestampType coercion.""" + + def setUp(self): + _reset_coercion_warnings() + self.timestamp_type = TimestampType() + + def test_datetime_to_timestamp_permissive_and_warn(self): + """datetime -> timestamp should work for PERMISSIVE and WARN.""" + dt_val = datetime.datetime(1970, 1, 1, 0, 0, 0) + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + self.assertEqual(self.timestamp_type.coerce(dt_val, policy), dt_val) + + def test_none_to_timestamp_permissive_and_warn(self): + """None -> timestamp should return None for PERMISSIVE and WARN.""" + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + self.assertIsNone(self.timestamp_type.coerce(None, policy)) + + def test_date_to_timestamp_permissive_and_warn(self): + """date -> timestamp: raises PySparkTypeError for PERMISSIVE and WARN.""" + date_val = datetime.date(1970, 1, 1) + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + with self.assertRaises(PySparkTypeError): + self.timestamp_type.coerce(date_val, policy) + + def test_int_to_timestamp_permissive_and_warn(self): + """int -> timestamp: raises PySparkTypeError for PERMISSIVE and WARN.""" + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + with self.assertRaises(PySparkTypeError): + self.timestamp_type.coerce(1, policy) + + +class BinaryCoercionTests(unittest.TestCase): + """Tests for BinaryType coercion.""" + + def setUp(self): + _reset_coercion_warnings() + self.binary_type = BinaryType() + + def test_bytes_to_binary_permissive_and_warn(self): + """bytes -> binary should work for PERMISSIVE and 
WARN.""" + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + self.assertEqual(self.binary_type.coerce(b"ABC", policy), b"ABC") + + def test_bytearray_to_binary_permissive_and_warn(self): + """bytearray -> binary: PERMISSIVE/WARN convert to bytes.""" + ba = bytearray([65, 66, 67]) + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + self.assertEqual(self.binary_type.coerce(ba, policy), b"ABC") + + def test_none_to_binary_permissive_and_warn(self): + """None -> binary should return None for PERMISSIVE and WARN.""" + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + self.assertIsNone(self.binary_type.coerce(None, policy)) + + def test_str_to_binary_permissive(self): + """str -> binary: PERMISSIVE encodes (pickle behavior).""" + self.assertEqual(self.binary_type.coerce("a", CoercionPolicy.PERMISSIVE), b"a") + + +class ArrayCoercionTests(unittest.TestCase): + """Tests for ArrayType coercion.""" + + def setUp(self): + _reset_coercion_warnings() + self.array_int_type = ArrayType(IntegerType()) + + def test_list_to_array_permissive_and_warn(self): + """list -> array should work for PERMISSIVE and WARN.""" + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + self.assertEqual(self.array_int_type.coerce([1, 2, 3], policy), [1, 2, 3]) + + def test_tuple_to_array_permissive_and_warn(self): + """tuple -> array: PERMISSIVE/WARN convert to list.""" + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + self.assertEqual(self.array_int_type.coerce((1, 2, 3), policy), [1, 2, 3]) + + def test_none_to_array_permissive_and_warn(self): + """None -> array should return None for PERMISSIVE and WARN.""" + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + self.assertIsNone(self.array_int_type.coerce(None, policy)) + + def test_python_array_to_array_permissive_and_warn(self): + """array.array -> array: PERMISSIVE/WARN convert to list.""" + arr = array.array("i", [1, 2, 3]) + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + self.assertEqual(self.array_int_type.coerce(arr, policy), [1, 2, 3]) + + def test_bytearray_to_int_array_permissive_and_warn(self): + """bytearray -> array: PERMISSIVE/WARN convert to list.""" + ba = bytearray([65, 66, 67]) + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + self.assertEqual(self.array_int_type.coerce(ba, policy), [65, 66, 67]) + + +class StructCoercionTests(unittest.TestCase): + """Tests for StructType coercion.""" + + def setUp(self): + _reset_coercion_warnings() + self.struct_type = StructType([StructField("_1", IntegerType())]) + + def test_row_to_struct_permissive_and_warn(self): + """Row -> struct should work for PERMISSIVE and WARN.""" + row = Row(_1=1) + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + result = self.struct_type.coerce(row, policy) + self.assertEqual(result._1, 1) + + def test_tuple_to_struct_permissive_and_warn(self): + """tuple -> struct: PERMISSIVE/WARN convert to Row.""" + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + result = self.struct_type.coerce((1,), policy) + self.assertEqual(result._1, 1) + + def test_dict_to_struct_permissive_and_warn(self): + """dict -> struct: PERMISSIVE/WARN convert to Row.""" + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + result = self.struct_type.coerce({"_1": 1}, policy) + self.assertEqual(result._1, 1) + + def test_none_to_struct_permissive_and_warn(self): + """None -> struct should return None for PERMISSIVE and WARN.""" + for policy in 
[CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + self.assertIsNone(self.struct_type.coerce(None, policy)) + + def test_list_to_struct_permissive(self): + """list -> struct: PERMISSIVE converts (pickle behavior).""" + result = self.struct_type.coerce([1], CoercionPolicy.PERMISSIVE) + self.assertEqual(result._1, 1) + + +class MapCoercionTests(unittest.TestCase): + """Tests for MapType coercion.""" + + def setUp(self): + _reset_coercion_warnings() + self.map_type = MapType(StringType(), IntegerType()) + + def test_dict_to_map_permissive_and_warn(self): + """dict -> map should work for PERMISSIVE and WARN.""" + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + self.assertEqual(self.map_type.coerce({"a": 1}, policy), {"a": 1}) + + def test_none_to_map_permissive_and_warn(self): + """None -> map should return None for PERMISSIVE and WARN.""" + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + self.assertIsNone(self.map_type.coerce(None, policy)) + + def test_other_to_map_permissive(self): + """other -> map: PERMISSIVE returns None.""" + self.assertIsNone(self.map_type.coerce([1, 2], CoercionPolicy.PERMISSIVE)) + + +class DecimalCoercionTests(unittest.TestCase): + """Tests for DecimalType coercion.""" + + def setUp(self): + _reset_coercion_warnings() + self.decimal_type = DecimalType(10, 0) + + def test_decimal_to_decimal_permissive_and_warn(self): + """Decimal -> decimal should work for PERMISSIVE and WARN.""" + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + self.assertEqual(self.decimal_type.coerce(Decimal("123"), policy), Decimal("123")) + + def test_none_to_decimal_permissive_and_warn(self): + """None -> decimal should return None for PERMISSIVE and WARN.""" + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + self.assertIsNone(self.decimal_type.coerce(None, policy)) + + def test_int_to_decimal_permissive(self): + """int -> decimal: PERMISSIVE returns None (pickle behavior).""" + self.assertIsNone(self.decimal_type.coerce(1, CoercionPolicy.PERMISSIVE)) + + +class NestedTypeCoercionTests(unittest.TestCase): + """Tests for recursive coercion in nested types (ArrayType, MapType, StructType).""" + + def setUp(self): + _reset_coercion_warnings() + + def test_array_element_coercion(self): + """ArrayType should recursively coerce elements.""" + array_int_type = ArrayType(IntegerType()) + # float elements should be coerced to None for IntegerType + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + result = array_int_type.coerce([1.5, 2.5, 3.5], policy) + self.assertEqual(result, [None, None, None]) + + def test_array_element_coercion_mixed(self): + """ArrayType should coerce each element independently.""" + array_int_type = ArrayType(IntegerType()) + # Mix of valid and invalid types + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + result = array_int_type.coerce([1, 2.5, 3], policy) + self.assertEqual(result, [1, None, 3]) + + def test_array_element_coercion_with_none(self): + """ArrayType should handle None elements correctly.""" + array_int_type = ArrayType(IntegerType()) + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + result = array_int_type.coerce([1, None, 3], policy) + self.assertEqual(result, [1, None, 3]) + + def test_nested_array_coercion(self): + """Nested ArrayType should coerce recursively.""" + nested_array_type = ArrayType(ArrayType(IntegerType())) + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + result = nested_array_type.coerce([[1, 2.5], [3.5, 
4]], policy) + self.assertEqual(result, [[1, None], [None, 4]]) + + def test_map_value_coercion(self): + """MapType should recursively coerce values.""" + map_type = MapType(StringType(), IntegerType()) + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + result = map_type.coerce({"a": 1.5, "b": 2}, policy) + self.assertEqual(result, {"a": None, "b": 2}) + + def test_map_key_coercion(self): + """MapType should recursively coerce keys.""" + map_type = MapType(StringType(), IntegerType()) + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + # int key should be coerced to string "1" + result = map_type.coerce({1: 2}, policy) + self.assertEqual(result, {"1": 2}) + + def test_map_nested_value_coercion(self): + """MapType with nested value type should coerce recursively.""" + map_type = MapType(StringType(), ArrayType(IntegerType())) + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + result = map_type.coerce({"a": [1, 2.5, 3]}, policy) + self.assertEqual(result, {"a": [1, None, 3]}) + + def test_struct_field_coercion_from_row(self): + """StructType should recursively coerce fields from Row.""" + struct_type = StructType([StructField("x", IntegerType())]) + row = Row(x=1.5) + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + result = struct_type.coerce(row, policy) + self.assertIsNone(result.x) + + def test_struct_field_coercion_from_tuple(self): + """StructType should recursively coerce fields from tuple.""" + struct_type = StructType([StructField("x", IntegerType())]) + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + result = struct_type.coerce((1.5,), policy) + self.assertIsNone(result.x) + + def test_struct_field_coercion_from_dict(self): + """StructType should recursively coerce fields from dict.""" + struct_type = StructType([StructField("x", IntegerType())]) + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + result = struct_type.coerce({"x": 1.5}, policy) + self.assertIsNone(result.x) + + def test_struct_field_coercion_from_list(self): + """StructType should recursively coerce fields from list.""" + struct_type = StructType([StructField("x", IntegerType())]) + result = struct_type.coerce([1.5], CoercionPolicy.PERMISSIVE) + self.assertIsNone(result.x) + + def test_struct_multiple_fields_coercion(self): + """StructType with multiple fields should coerce each field.""" + struct_type = StructType([StructField("a", IntegerType()), StructField("b", StringType())]) + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + result = struct_type.coerce((1.5, 123), policy) + self.assertIsNone(result.a) # float -> int: None + self.assertEqual(result.b, "123") # int -> string: "123" + + def test_nested_struct_coercion(self): + """Nested StructType should coerce recursively.""" + inner_struct = StructType([StructField("x", IntegerType())]) + outer_struct = StructType([StructField("inner", inner_struct)]) + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + result = outer_struct.coerce((Row(x=1.5),), policy) + self.assertIsNone(result.inner.x) + + def test_struct_with_array_field_coercion(self): + """StructType with ArrayType field should coerce recursively.""" + struct_type = StructType([StructField("arr", ArrayType(IntegerType()))]) + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + result = struct_type.coerce(([1, 2.5, 3],), policy) + self.assertEqual(result.arr, [1, None, 3]) + + def test_array_of_struct_coercion(self): + """ArrayType of StructType should coerce 
recursively.""" + struct_type = StructType([StructField("x", IntegerType())]) + array_struct_type = ArrayType(struct_type) + for policy in [CoercionPolicy.PERMISSIVE, CoercionPolicy.WARN]: + result = array_struct_type.coerce([Row(x=1), Row(x=2.5)], policy) + self.assertEqual(result[0].x, 1) + self.assertIsNone(result[1].x) + + +class DefaultPolicyTests(unittest.TestCase): + """Tests that coerce() defaults to PERMISSIVE policy.""" + + def setUp(self): + _reset_coercion_warnings() + + def test_default_policy_is_permissive(self): + """coerce() without policy should behave like PERMISSIVE.""" + boolean_type = BooleanType() + # int -> boolean: PERMISSIVE returns None (pickle behavior) + self.assertIsNone(boolean_type.coerce(1)) + + int_type = IntegerType() + # float -> int: PERMISSIVE returns None (pickle behavior) + self.assertIsNone(int_type.coerce(1.0)) + + date_type = DateType() + # int -> date: PERMISSIVE raises PySparkTypeError (pickle behavior) + with self.assertRaises(PySparkTypeError): + date_type.coerce(1) + + +if __name__ == "__main__": + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/udf_type_tests/test_udf_return_types.py b/python/pyspark/sql/tests/udf_type_tests/test_udf_return_types.py index ccec277d4c78e..e5cdf1a40b29a 100644 --- a/python/pyspark/sql/tests/udf_type_tests/test_udf_return_types.py +++ b/python/pyspark/sql/tests/udf_type_tests/test_udf_return_types.py @@ -69,7 +69,7 @@ pandas_requirement_message or pyarrow_requirement_message or numpy_requirement_message - or "float128 not supported on macos", + or "float128 not supported on macOS", ) class UDFReturnTypeTests(ReusedSQLTestCase): @classmethod @@ -177,14 +177,17 @@ def test_udf_return_type_coercion_arrow_enabled(self): ) def _run_udf_return_type_coercion_test(self, use_arrow, legacy_pandas, golden_file, test_name): - with self.sql_conf( - { - "spark.sql.execution.pythonUDF.arrow.enabled": str(use_arrow).lower(), - "spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled": str( - legacy_pandas - ).lower(), - } - ): + conf = { + "spark.sql.execution.pythonUDF.arrow.enabled": str(use_arrow).lower(), + "spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled": str( + legacy_pandas + ).lower(), + } + # Use STRICT policy for Arrow tests to preserve original Arrow behavior + # (PERMISSIVE coercion would make Arrow behave like pickle) + if use_arrow: + conf["spark.sql.execution.pythonUDF.coercion.policy"] = "strict" + with self.sql_conf(conf): results = self._generate_udf_return_type_coercion_results(use_arrow) header = ["SQL Type \\ Python Value(Type)"] + [ f"{str(v)}({type(v).__name__})" for v in self.test_data @@ -244,6 +247,8 @@ def test_pandas_udf_return_type_coercion(self): test_name = "Pandas UDF type coercion" + # Note: coercion.policy config only affects ArrowBatchUDFSerializer (regular Arrow UDFs), + # not ArrowStreamPandasUDFSerializer (Pandas UDFs), so we don't set it here. 
results = self._generate_pandas_udf_type_coercion_results() header = ["SQL Type \\ Pandas Value(Type)"] + [ f"{str(v).replace(chr(10), ' ')}({type(v).__name__})" for v in self.pandas_test_data diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 943e0943ca489..8dc1736b600e3 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -28,7 +28,9 @@ from array import array import ctypes from collections.abc import Iterable +from enum import Enum from functools import reduce +import warnings from typing import ( cast, overload, @@ -77,6 +79,8 @@ U = TypeVar("U") __all__ = [ + "CoercionPolicy", + "CoercionWarning", "DataType", "NullType", "CharType", @@ -112,6 +116,94 @@ ] +class CoercionPolicy(str, Enum): + """ + Policy for type coercion in Python UDFs. + + This enum controls how values are coerced when the Python UDF returns + a value that doesn't exactly match the declared return type. + + .. versionadded:: 4.2.0 + + Notes + ----- + The three policies represent different trade-offs between compatibility + and strictness: + + - PERMISSIVE: Matches legacy pickle-based UDF behavior. Invalid coercions + silently return None. This is the default for backward compatibility. + - WARN: Same behavior as PERMISSIVE but logs warnings when Arrow would + behave differently. Useful for migration testing. + - STRICT: Matches Arrow-optimized UDF behavior. Invalid coercions either + raise exceptions or apply aggressive type conversion. + + Examples + -------- + >>> from pyspark.sql.types import CoercionPolicy, IntegerType + >>> int_type = IntegerType() + + PERMISSIVE policy returns None for float -> int (pickle behavior): + + >>> int_type.coerce(1.5, CoercionPolicy.PERMISSIVE) is None + True + + Note: In STRICT mode, the serializer skips calling coerce() entirely, + letting Arrow handle type conversion natively. + """ + + PERMISSIVE = "permissive" + """ + Matches legacy pickle-based UDF behavior. + Invalid type coercions silently return None. + This is the default policy for backward compatibility. + """ + + WARN = "warn" + """ + Same coercion behavior as PERMISSIVE, but logs warnings when + the Arrow path would produce different results. + Useful for testing migration to STRICT mode. + """ + + STRICT = "strict" + """ + Skips coercion entirely, letting Arrow handle type conversion natively. + This preserves Arrow's aggressive type conversion behavior + (e.g., int -> bool, float -> int truncation). + """ + + +class CoercionWarning(UserWarning): + """Warning issued when type coercion behavior differs between pickle and Arrow modes.""" + + pass + + +# Configure default filter to show each unique CoercionWarning only once +warnings.filterwarnings("once", category=CoercionWarning) + + +def _warn_coercion_once(message: str) -> None: + """Issue a coercion warning only once per unique message. + + Uses Python's warnings module with 'once' filter. + The stacklevel=3 accounts for: _warn_coercion_once -> coerce() -> caller. + """ + warnings.warn(message, CoercionWarning, stacklevel=3) + + +def _reset_coercion_warnings() -> None: + """Reset coercion warnings so they can be shown again. Used for testing.""" + warnings.filterwarnings("once", category=CoercionWarning) + + +def _is_row_object(value: Any) -> bool: + """Check if value is a Row object. 
Row causes pickle to error on most type conversions.""" + # Check if it's a tuple subclass with Row-specific attributes + # We can't use isinstance(value, Row) here because Row is defined later in this file + return isinstance(value, tuple) and hasattr(value, "__fields__") and hasattr(value, "asDict") + + class DataType: """Base class for data types.""" @@ -160,6 +252,20 @@ def fromInternal(self, obj: Any) -> Any: """ return obj + @property + def needsCoercion(self) -> bool: + """ + Whether this type needs coercion logic applied. + + Returns False for types where coerce() is a no-op (just returns value unchanged). + This allows container types to skip recursion when element types don't need coercion. + """ + return False + + def coerce(self, value: Any, policy: "CoercionPolicy" = CoercionPolicy.PERMISSIVE) -> Any: + """Coerce a Python value to this data type. Base implementation is a no-op.""" + return value + def _as_nullable(self) -> "DataType": return self @@ -269,12 +375,64 @@ class NumericType(AtomicType): class IntegralType(NumericType, metaclass=DataTypeSingleton): """Integral data types.""" - pass + @property + def needsCoercion(self) -> bool: + return True + + def coerce(self, value: Any, policy: "CoercionPolicy" = CoercionPolicy.PERMISSIVE) -> Any: + # Fast path: check most common case first (int that's not bool) + if type(value) is int: + return value + if value is None: + return None + # int subclass (non-bool) -> int: exact match + if isinstance(value, int) and not isinstance(value, bool): + return value + # Row -> int: pickle raises + if _is_row_object(value): + raise PySparkTypeError( + errorClass="CANNOT_CONVERT_TYPE", + messageParameters={"from_type": "Row", "to_type": "IntegerType"}, + ) + # Other types: pickle returns None + if policy == CoercionPolicy.WARN and isinstance(value, (bool, float, decimal.Decimal)): + _warn_coercion_once( + f"Coercing {type(value).__name__} to integer returns None in pickle mode " + "but would convert or raise in Arrow mode" + ) + return None class FractionalType(NumericType): """Fractional data types.""" + @property + def needsCoercion(self) -> bool: + return True + + def coerce(self, value: Any, policy: "CoercionPolicy" = CoercionPolicy.PERMISSIVE) -> Any: + # Fast path: check most common case first + if type(value) is float: + return value + if value is None: + return None + # float subclass -> float: exact match + if isinstance(value, float): + return value + # Row -> float: pickle raises + if _is_row_object(value): + raise PySparkTypeError( + errorClass="CANNOT_CONVERT_TYPE", + messageParameters={"from_type": "Row", "to_type": "FloatType"}, + ) + # Other types: pickle returns None + if policy == CoercionPolicy.WARN and isinstance(value, (bool, int, decimal.Decimal)): + _warn_coercion_once( + f"Coercing {type(value).__name__} to float returns None in pickle mode " + "but would convert in Arrow mode" + ) + return None + class StringType(AtomicType): """String data type. 
@@ -316,6 +474,31 @@ def __repr__(self) -> str: def isUTF8BinaryCollation(self) -> bool: return self.collation == "UTF8_BINARY" + @property + def needsCoercion(self) -> bool: + return True + + def coerce(self, value: Any, policy: "CoercionPolicy" = CoercionPolicy.PERMISSIVE) -> Any: + # Fast path: check most common case first + if type(value) is str: + return value + if value is None: + return None + # str subclass -> str + if isinstance(value, str): + return value + # Row -> string: pickle raises + if _is_row_object(value): + raise PySparkTypeError( + errorClass="CANNOT_CONVERT_TYPE", + messageParameters={"from_type": "Row", "to_type": "StringType"}, + ) + # bool -> string: pickle gives 'true'/'false' (Java toString) + if isinstance(value, bool): + return "true" if value else "false" + # Most types can be converted via str() + return str(value) + class CharType(AtomicType): """Char data type @@ -364,13 +547,65 @@ def __repr__(self) -> str: class BinaryType(AtomicType, metaclass=DataTypeSingleton): """Binary (byte array) data type.""" - pass + @property + def needsCoercion(self) -> bool: + return True + + def coerce(self, value: Any, policy: "CoercionPolicy" = CoercionPolicy.PERMISSIVE) -> Any: + if value is None: + return None + # bytes -> binary: exact match + if isinstance(value, bytes): + return value + # bytearray -> binary: both paths convert + if isinstance(value, bytearray): + return bytes(value) + # str -> binary: pickle encodes to bytes + if isinstance(value, str): + if policy == CoercionPolicy.WARN: + _warn_coercion_once( + "Coercing str to binary encodes in pickle mode but raises in Arrow mode" + ) + return value.encode("utf-8") + # Row -> binary: pickle raises + if _is_row_object(value): + raise PySparkTypeError( + errorClass="CANNOT_CONVERT_TYPE", + messageParameters={"from_type": "Row", "to_type": "BinaryType"}, + ) + # Other types: return None + return None class BooleanType(AtomicType, metaclass=DataTypeSingleton): """Boolean data type.""" - pass + @property + def needsCoercion(self) -> bool: + return True + + def coerce(self, value: Any, policy: "CoercionPolicy" = CoercionPolicy.PERMISSIVE) -> Any: + # Fast path: check most common case first + if type(value) is bool: + return value + if value is None: + return None + # bool subclass -> boolean: exact match + if isinstance(value, bool): + return value + # Row -> boolean: pickle raises + if _is_row_object(value): + raise PySparkTypeError( + errorClass="CANNOT_CONVERT_TYPE", + messageParameters={"from_type": "Row", "to_type": "BooleanType"}, + ) + # Other types: pickle returns None + if policy == CoercionPolicy.WARN and isinstance(value, (int, float)): + _warn_coercion_once( + f"Coercing {type(value).__name__} to boolean returns None in pickle mode " + "but would convert to bool in Arrow mode" + ) + return None class DatetimeType(AtomicType): @@ -393,6 +628,40 @@ def fromInternal(self, v: int) -> datetime.date: if v is not None: return datetime.date.fromordinal(v + self.EPOCH_ORDINAL) + def _days_to_date(self, days: int) -> datetime.date: + """Convert days since epoch to date.""" + return datetime.date.fromordinal(days + self.EPOCH_ORDINAL) + + @property + def needsCoercion(self) -> bool: + return True + + def coerce(self, value: Any, policy: "CoercionPolicy" = CoercionPolicy.PERMISSIVE) -> Any: + if value is None: + return None + # date -> date: exact match + if isinstance(value, datetime.date) and not isinstance(value, datetime.datetime): + return value + # datetime -> date: both paths extract date + if isinstance(value, 
datetime.datetime): + return value.date() + # int/float/Decimal -> date: pickle raises + if isinstance(value, (int, float, decimal.Decimal)) and not isinstance(value, bool): + if policy == CoercionPolicy.WARN: + _warn_coercion_once( + f"Coercing {type(value).__name__} to date raises in pickle mode " + "but converts (days since epoch) in Arrow mode" + ) + raise PySparkTypeError( + errorClass="CANNOT_CONVERT_TYPE", + messageParameters={"from_type": type(value).__name__, "to_type": "DateType"}, + ) + # Other types: raise + raise PySparkTypeError( + errorClass="CANNOT_CONVERT_TYPE", + messageParameters={"from_type": type(value).__name__, "to_type": "DateType"}, + ) + class AnyTimeType(DatetimeType): """A TIME type of any valid precision.""" @@ -469,6 +738,20 @@ def fromInternal(self, ts: int) -> datetime.datetime: microsecond=ts % 1000000, tzinfo=None ) + @property + def needsCoercion(self) -> bool: + return True + + def coerce(self, value: Any, policy: "CoercionPolicy" = CoercionPolicy.PERMISSIVE) -> Any: + # datetime -> timestamp: exact match + if value is None or isinstance(value, datetime.datetime): + return value + # All other types raise in both pickle and Arrow + raise PySparkTypeError( + errorClass="CANNOT_CONVERT_TYPE", + messageParameters={"from_type": type(value).__name__, "to_type": "TimestampType"}, + ) + class TimestampNTZType(DatetimeType, metaclass=DataTypeSingleton): """Timestamp (datetime.datetime) data type without timezone information.""" @@ -527,6 +810,34 @@ def jsonValue(self) -> str: def __repr__(self) -> str: return "DecimalType(%d,%d)" % (self.precision, self.scale) + @property + def needsCoercion(self) -> bool: + return True + + def coerce(self, value: Any, policy: "CoercionPolicy" = CoercionPolicy.PERMISSIVE) -> Any: + if value is None: + return None + # Decimal -> decimal: exact match + if isinstance(value, decimal.Decimal): + return value + # Row -> decimal: pickle raises + if _is_row_object(value): + raise PySparkTypeError( + errorClass="CANNOT_CONVERT_TYPE", + messageParameters={"from_type": "Row", "to_type": "DecimalType"}, + ) + # Other types: pickle returns None + if ( + policy == CoercionPolicy.WARN + and isinstance(value, (int, float)) + and not isinstance(value, bool) + ): + _warn_coercion_once( + f"Coercing {type(value).__name__} to decimal returns None in pickle mode " + "but raises in Arrow mode" + ) + return None + class DoubleType(FractionalType, metaclass=DataTypeSingleton): """Double data type, representing double precision floats.""" @@ -1052,6 +1363,43 @@ def fromInternal(self, obj: List[Optional[T]]) -> List[Optional[T]]: return obj return obj and [self.elementType.fromInternal(v) for v in obj] + @property + def needsCoercion(self) -> bool: + # ArrayType needs coercion if element type needs coercion, + # or if we need to handle tuple/array/bytearray -> list conversion + return True + + def coerce(self, value: Any, policy: "CoercionPolicy" = CoercionPolicy.PERMISSIVE) -> Any: + # Fast path: check most common case first (list) + if type(value) is list: + if self.elementType.needsCoercion: + # Cache attribute access for inner loop + elem_coerce = self.elementType.coerce + return [elem_coerce(v, policy) for v in value] + return value + if value is None: + return None + # Row -> array: pickle raises (check BEFORE tuple since Row is a tuple subclass) + if _is_row_object(value): + raise PySparkTypeError( + errorClass="CANNOT_CONVERT_TYPE", + messageParameters={"from_type": "Row", "to_type": "ArrayType"}, + ) + # tuple -> array: both paths convert, 
@@ -1052,6 +1363,43 @@ def fromInternal(self, obj: List[Optional[T]]) -> List[Optional[T]]:
             return obj
         return obj and [self.elementType.fromInternal(v) for v in obj]
 
+    @property
+    def needsCoercion(self) -> bool:
+        # ArrayType needs coercion if the element type needs coercion,
+        # or if we need to handle tuple/array/bytearray -> list conversion
+        return True
+
+    def coerce(self, value: Any, policy: "CoercionPolicy" = CoercionPolicy.PERMISSIVE) -> Any:
+        # Fast path: check most common case first (list)
+        if type(value) is list:
+            if self.elementType.needsCoercion:
+                # Cache attribute access for the inner loop
+                elem_coerce = self.elementType.coerce
+                return [elem_coerce(v, policy) for v in value]
+            return value
+        if value is None:
+            return None
+        # Row -> array: pickle raises (check BEFORE tuple since Row is a tuple subclass)
+        if _is_row_object(value):
+            raise PySparkTypeError(
+                errorClass="CANNOT_CONVERT_TYPE",
+                messageParameters={"from_type": "Row", "to_type": "ArrayType"},
+            )
+        # tuple/array/bytearray -> array: both paths convert; coerce elements only if needed
+        if isinstance(value, (tuple, array, bytearray)):
+            if self.elementType.needsCoercion:
+                elem_coerce = self.elementType.coerce
+                return [elem_coerce(v, policy) for v in value]
+            return list(value)
+        # list subclass
+        if isinstance(value, list):
+            if self.elementType.needsCoercion:
+                elem_coerce = self.elementType.coerce
+                return [elem_coerce(v, policy) for v in value]
+            return value
+        # Other types: pickle returns None
+        return None
+
     def _build_formatted_string(
         self,
         prefix: str,
@@ -1203,6 +1551,39 @@ def fromInternal(self, obj: Dict[T, Optional[U]]) -> Dict[T, Optional[U]]:
             (self.keyType.fromInternal(k), self.valueType.fromInternal(v)) for k, v in obj.items()
         )
 
+    @property
+    def needsCoercion(self) -> bool:
+        return True
+
+    def coerce(self, value: Any, policy: "CoercionPolicy" = CoercionPolicy.PERMISSIVE) -> Any:
+        if value is None:
+            return None
+        # dict or dict subclass -> map
+        if isinstance(value, dict):
+            key_needs = self.keyType.needsCoercion
+            val_needs = self.valueType.needsCoercion
+            if not key_needs and not val_needs:
+                return value
+            elif key_needs and val_needs:
+                # Cache method lookups for the inner loop
+                key_coerce = self.keyType.coerce
+                val_coerce = self.valueType.coerce
+                return {key_coerce(k, policy): val_coerce(v, policy) for k, v in value.items()}
+            elif key_needs:
+                key_coerce = self.keyType.coerce
+                return {key_coerce(k, policy): v for k, v in value.items()}
+            else:
+                val_coerce = self.valueType.coerce
+                return {k: val_coerce(v, policy) for k, v in value.items()}
+        # Row -> map: pickle raises
+        if _is_row_object(value):
+            raise PySparkTypeError(
+                errorClass="CANNOT_CONVERT_TYPE",
+                messageParameters={"from_type": "Row", "to_type": "MapType"},
+            )
+        # Other types: pickle returns None
+        return None
+
     def _build_formatted_string(
         self,
         prefix: str,
@@ -1835,6 +2216,57 @@ def fromInternal(self, obj: Tuple) -> "Row":
             values = obj
         return _create_row(self.names, values)
 
+    @property
+    def needsCoercion(self) -> bool:
+        return True
+
+    def coerce(self, value: Any, policy: "CoercionPolicy" = CoercionPolicy.PERMISSIVE) -> Any:
+        if value is None:
+            return None
+
+        fields = self.fields
+        names = self.names
+
+        # Row -> struct: exact match; coerce fields recursively only if needed
+        if isinstance(value, Row):
+            coerced_values = [
+                f.dataType.coerce(value[i], policy) if f.dataType.needsCoercion else value[i]
+                for i, f in enumerate(fields)
+            ]
+            return _create_row(names, coerced_values)
+        # tuple -> struct: both paths convert; coerce fields recursively only if needed
+        if isinstance(value, tuple):
+            coerced_values = [
+                f.dataType.coerce(value[i], policy) if f.dataType.needsCoercion else value[i]
+                for i, f in enumerate(fields)
+            ]
+            return _create_row(names, coerced_values)
+        # dict -> struct: both paths convert (by field name); coerce fields only if needed
+        if isinstance(value, dict):
+            coerced_values = [
+                f.dataType.coerce(value.get(names[i]), policy)
+                if f.dataType.needsCoercion
+                else value.get(names[i])
+                for i, f in enumerate(fields)
+            ]
+            return _create_row(names, coerced_values)
+        # list -> struct: pickle converts; coerce fields recursively only if needed
+        if isinstance(value, list):
+            if policy == CoercionPolicy.WARN:
+                _warn_coercion_once(
+                    "Coercing list to struct works in pickle mode but raises in Arrow mode"
+                )
+            coerced_values = [
+                f.dataType.coerce(value[i], policy) if f.dataType.needsCoercion else value[i]
+                for i, f in enumerate(fields)
+            ]
+            return _create_row(names, coerced_values)
+        # Other types: raise
+        raise PySparkTypeError(
+            errorClass="CANNOT_CONVERT_TYPE",
+            messageParameters={"from_type": type(value).__name__, "to_type": "StructType"},
+        )
+
     def _build_formatted_string(
         self,
         prefix: str,
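
The container hooks compose the atomic ones. A short sketch (illustrative, same
assumptions as above):

    from pyspark.sql.types import ArrayType, MapType, StringType, StructField, StructType

    ArrayType(StringType()).coerce((1, True))              # ['1', 'true']; tuples accepted
    MapType(StringType(), StringType()).coerce({1: True})  # {'1': 'true'}

    st = StructType([StructField("a", StringType()), StructField("b", StringType())])
    st.coerce({"a": 1})  # Row(a='1', b=None); missing dict keys become None
    st.coerce([1, 2])    # Row(a='1', b='2'); warns under WARN (Arrow mode raises)
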
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index b65636857cfe3..ea08bfd0d7f42 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -201,6 +201,10 @@ def arrow_concurrency_level(self) -> int:
     def profiler(self) -> Optional[str]:
         return self.get("spark.sql.pyspark.udf.profiler", None)
 
+    @property
+    def coercion_policy(self) -> str:
+        return self.get("spark.sql.execution.pythonUDF.coercion.policy", "permissive")
+
 
 def report_times(outfile, boot, init, finish):
     write_int(SpecialLengths.TIMING_DATA, outfile)
@@ -327,6 +331,10 @@ def wrap_arrow_batch_udf_arrow(f, args_offsets, kwargs_offsets, return_type, run
         return_type, timezone="UTC", prefers_large_types=runner_conf.use_large_var_types
     )
 
+    # Coercion is handled in the serializer (ArrowBatchUDFSerializer).
+    # In PERMISSIVE/WARN mode: coerce -> convert -> pa.array
+    # In STRICT mode: convert -> pa.array (no coercion)
+
     if zero_arg_exec:
 
         def get_args(*args: list):
@@ -2865,15 +2873,19 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf):
         eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
         and not runner_conf.use_legacy_pandas_udf_conversion
     ):
+        from pyspark.sql.types import CoercionPolicy
+
         input_types = [
             f.dataType for f in _parse_datatype_json_string(utf8_deserializer.loads(infile))
         ]
+        coercion_policy = CoercionPolicy(runner_conf.coercion_policy.lower())
         ser = ArrowBatchUDFSerializer(
             runner_conf.timezone,
             runner_conf.safecheck,
             input_types,
             runner_conf.int_to_decimal_coercion_enabled,
             runner_conf.binary_as_bytes,
             coercion_policy,
         )
     else:
         # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of
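
The policy reaches the worker as a plain string, so read_udfs builds the enum by value; a
minimal sketch of that mapping (assuming CoercionPolicy is a string-valued Enum, as the
lookup above implies):

    from pyspark.sql.types import CoercionPolicy

    CoercionPolicy("permissive") is CoercionPolicy.PERMISSIVE  # True
    CoercionPolicy("strict") is CoercionPolicy.STRICT          # True
    CoercionPolicy("WARN".lower()) is CoercionPolicy.WARN      # True; hence .lower() above
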
" + @@ -7716,6 +7727,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def legacyPandasConversionUDF: Boolean = getConf(PYTHON_UDF_LEGACY_PANDAS_CONVERSION_ENABLED) + def pythonUDFCoercionPolicy: String = getConf(PYTHON_UDF_COERCION_POLICY) + def pythonPlannerExecMemory: Option[Long] = getConf(PYTHON_PLANNER_EXEC_MEMORY) def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 39d82b4b037b0..103fc068aff83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -175,9 +175,12 @@ object ArrowPythonRunner { val profiler = conf.pythonUDFProfiler.map(p => Seq(SQLConf.PYTHON_UDF_PROFILER.key -> p) ).getOrElse(Seq.empty) + val coercionPolicy = Seq( + SQLConf.PYTHON_UDF_COERCION_POLICY.key -> + conf.pythonUDFCoercionPolicy) Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck ++ arrowAyncParallelism ++ useLargeVarTypes ++ intToDecimalCoercion ++ binaryAsBytes ++ - legacyPandasConversion ++ legacyPandasConversionUDF ++ profiler: _*) + legacyPandasConversion ++ legacyPandasConversionUDF ++ profiler ++ coercionPolicy: _*) } }