From 14d3f447360b66663c8979a8cdb4c40c480a1e04 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 23 May 2024 16:12:38 +0800 Subject: [PATCH] [SPARK-48395][PYTHON] Fix `StructType.treeString` for parameterized types ### What changes were proposed in this pull request? This PR is a follow-up of https://github.com/apache/spark/pull/46685. ### Why are the changes needed? `StructType.treeString` uses `DataType.typeName` to generate the tree string, however, the `typeName` in python is a class method and can not return the same string for parameterized types. ``` In [2]: schema = StructType().add("c", CharType(10), True).add("v", VarcharType(10), True).add("d", DecimalType(10, 2), True).add("ym00", YearM ...: onthIntervalType(0, 0)).add("ym01", YearMonthIntervalType(0, 1)).add("ym11", YearMonthIntervalType(1, 1)) In [3]: print(schema.treeString()) root |-- c: char (nullable = true) |-- v: varchar (nullable = true) |-- d: decimal (nullable = true) |-- ym00: yearmonthinterval (nullable = true) |-- ym01: yearmonthinterval (nullable = true) |-- ym11: yearmonthinterval (nullable = true) ``` it should be ``` In [4]: print(schema.treeString()) root |-- c: char(10) (nullable = true) |-- v: varchar(10) (nullable = true) |-- d: decimal(10,2) (nullable = true) |-- ym00: interval year (nullable = true) |-- ym01: interval year to month (nullable = true) |-- ym11: interval month (nullable = true) ``` ### Does this PR introduce _any_ user-facing change? no, this feature was just added and not released yet. ### How was this patch tested? added tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #46711 from zhengruifeng/tree_string_fix. 
Authored-by: Ruifeng Zheng Signed-off-by: Ruifeng Zheng --- python/pyspark/sql/tests/test_types.py | 67 ++++++++++++++++++++++++++ python/pyspark/sql/types.py | 27 +++++++++-- 2 files changed, 90 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index ec07406b11912..6c64a9471363a 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -41,6 +41,7 @@ FloatType, DateType, TimestampType, + TimestampNTZType, DayTimeIntervalType, YearMonthIntervalType, CalendarIntervalType, @@ -1411,6 +1412,72 @@ def test_tree_string(self): ], ) + def test_tree_string_for_builtin_types(self): + schema = ( + StructType() + .add("n", NullType()) + .add("str", StringType()) + .add("c", CharType(10)) + .add("v", VarcharType(10)) + .add("bin", BinaryType()) + .add("bool", BooleanType()) + .add("date", DateType()) + .add("ts", TimestampType()) + .add("ts_ntz", TimestampNTZType()) + .add("dec", DecimalType(10, 2)) + .add("double", DoubleType()) + .add("float", FloatType()) + .add("long", LongType()) + .add("int", IntegerType()) + .add("short", ShortType()) + .add("byte", ByteType()) + .add("ym_interval_1", YearMonthIntervalType()) + .add("ym_interval_2", YearMonthIntervalType(YearMonthIntervalType.YEAR)) + .add( + "ym_interval_3", + YearMonthIntervalType(YearMonthIntervalType.YEAR, YearMonthIntervalType.MONTH), + ) + .add("dt_interval_1", DayTimeIntervalType()) + .add("dt_interval_2", DayTimeIntervalType(DayTimeIntervalType.DAY)) + .add( + "dt_interval_3", + DayTimeIntervalType(DayTimeIntervalType.HOUR, DayTimeIntervalType.SECOND), + ) + .add("cal_interval", CalendarIntervalType()) + .add("var", VariantType()) + ) + self.assertEqual( + schema.treeString().split("\n"), + [ + "root", + " |-- n: void (nullable = true)", + " |-- str: string (nullable = true)", + " |-- c: char(10) (nullable = true)", + " |-- v: varchar(10) (nullable = true)", + " |-- bin: binary (nullable = true)", + 
" |-- bool: boolean (nullable = true)", + " |-- date: date (nullable = true)", + " |-- ts: timestamp (nullable = true)", + " |-- ts_ntz: timestamp_ntz (nullable = true)", + " |-- dec: decimal(10,2) (nullable = true)", + " |-- double: double (nullable = true)", + " |-- float: float (nullable = true)", + " |-- long: long (nullable = true)", + " |-- int: integer (nullable = true)", + " |-- short: short (nullable = true)", + " |-- byte: byte (nullable = true)", + " |-- ym_interval_1: interval year to month (nullable = true)", + " |-- ym_interval_2: interval year (nullable = true)", + " |-- ym_interval_3: interval year to month (nullable = true)", + " |-- dt_interval_1: interval day to second (nullable = true)", + " |-- dt_interval_2: interval day (nullable = true)", + " |-- dt_interval_3: interval hour to second (nullable = true)", + " |-- cal_interval: interval (nullable = true)", + " |-- var: variant (nullable = true)", + "", + ], + ) + def test_metadata_null(self): schema = StructType( [ diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index ee0cc9db5c445..17b019240f826 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -215,6 +215,24 @@ def _data_type_build_formatted_string( if isinstance(dataType, (ArrayType, StructType, MapType)): dataType._build_formatted_string(prefix, stringConcat, maxDepth - 1) + # The method typeName() is not always the same as the Scala side. + # Add this helper method to make TreeString() compatible with Scala side. 
+ @classmethod + def _get_jvm_type_name(cls, dataType: "DataType") -> str: + if isinstance( + dataType, + ( + DecimalType, + CharType, + VarcharType, + DayTimeIntervalType, + YearMonthIntervalType, + ), + ): + return dataType.simpleString() + else: + return dataType.typeName() + # This singleton pattern does not work with pickle, you will get # another object after pickle and unpickle @@ -758,7 +776,7 @@ def _build_formatted_string( ) -> None: if maxDepth > 0: stringConcat.append( - f"{prefix}-- element: {self.elementType.typeName()} " + f"{prefix}-- element: {DataType._get_jvm_type_name(self.elementType)} " + f"(containsNull = {str(self.containsNull).lower()})\n" ) DataType._data_type_build_formatted_string( @@ -906,12 +924,12 @@ def _build_formatted_string( maxDepth: int = JVM_INT_MAX, ) -> None: if maxDepth > 0: - stringConcat.append(f"{prefix}-- key: {self.keyType.typeName()}\n") + stringConcat.append(f"{prefix}-- key: {DataType._get_jvm_type_name(self.keyType)}\n") DataType._data_type_build_formatted_string( self.keyType, f"{prefix} |", stringConcat, maxDepth ) stringConcat.append( - f"{prefix}-- value: {self.valueType.typeName()} " + f"{prefix}-- value: {DataType._get_jvm_type_name(self.valueType)} " + f"(valueContainsNull = {str(self.valueContainsNull).lower()})\n" ) DataType._data_type_build_formatted_string( @@ -1074,7 +1092,8 @@ def _build_formatted_string( ) -> None: if maxDepth > 0: stringConcat.append( - f"{prefix}-- {escape_meta_characters(self.name)}: {self.dataType.typeName()} " + f"{prefix}-- {escape_meta_characters(self.name)}: " + + f"{DataType._get_jvm_type_name(self.dataType)} " + f"(nullable = {str(self.nullable).lower()})\n" ) DataType._data_type_build_formatted_string(