Skip to content

Commit

Permalink
[SPARK-48395][PYTHON] Fix StructType.treeString for parameterized t…
Browse files Browse the repository at this point in the history
…ypes

### What changes were proposed in this pull request?
This PR is a follow-up of #46685.

### Why are the changes needed?
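`StructType.treeString` uses `DataType.typeName` to generate the tree string. However, `typeName` in Python is a class method, so it cannot include the parameters of a parameterized type (length, precision/scale, interval fields) and therefore does not return the same string as the Scala side.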

```
In [2]: schema = StructType().add("c", CharType(10), True).add("v", VarcharType(10), True).add("d", DecimalType(10, 2), True).add("ym00", YearM
   ...: onthIntervalType(0, 0)).add("ym01", YearMonthIntervalType(0, 1)).add("ym11", YearMonthIntervalType(1, 1))

In [3]: print(schema.treeString())
root
 |-- c: char (nullable = true)
 |-- v: varchar (nullable = true)
 |-- d: decimal (nullable = true)
 |-- ym00: yearmonthinterval (nullable = true)
 |-- ym01: yearmonthinterval (nullable = true)
 |-- ym11: yearmonthinterval (nullable = true)
```

It should be:
```
In [4]: print(schema.treeString())
root
 |-- c: char(10) (nullable = true)
 |-- v: varchar(10) (nullable = true)
 |-- d: decimal(10,2) (nullable = true)
 |-- ym00: interval year (nullable = true)
 |-- ym01: interval year to month (nullable = true)
 |-- ym11: interval month (nullable = true)
```

### Does this PR introduce _any_ user-facing change?
No, this feature was just added and has not been released yet.

### How was this patch tested?
Added tests.

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #46711 from zhengruifeng/tree_string_fix.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
  • Loading branch information
zhengruifeng committed May 23, 2024
1 parent e8f58a9 commit 14d3f44
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 4 deletions.
67 changes: 67 additions & 0 deletions python/pyspark/sql/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
FloatType,
DateType,
TimestampType,
TimestampNTZType,
DayTimeIntervalType,
YearMonthIntervalType,
CalendarIntervalType,
Expand Down Expand Up @@ -1411,6 +1412,72 @@ def test_tree_string(self):
],
)

def test_tree_string_for_builtin_types(self):
    """Check the tree string rendered for every built-in atomic type.

    Parameterized types (char, varchar, decimal and the two ANSI interval
    types) are expected to show their parameters, e.g. ``char(10)`` or
    ``interval hour to second``.
    """
    # (field name, data type) pairs, in the order they must appear in the tree.
    columns = [
        ("n", NullType()),
        ("str", StringType()),
        ("c", CharType(10)),
        ("v", VarcharType(10)),
        ("bin", BinaryType()),
        ("bool", BooleanType()),
        ("date", DateType()),
        ("ts", TimestampType()),
        ("ts_ntz", TimestampNTZType()),
        ("dec", DecimalType(10, 2)),
        ("double", DoubleType()),
        ("float", FloatType()),
        ("long", LongType()),
        ("int", IntegerType()),
        ("short", ShortType()),
        ("byte", ByteType()),
        ("ym_interval_1", YearMonthIntervalType()),
        ("ym_interval_2", YearMonthIntervalType(YearMonthIntervalType.YEAR)),
        (
            "ym_interval_3",
            YearMonthIntervalType(YearMonthIntervalType.YEAR, YearMonthIntervalType.MONTH),
        ),
        ("dt_interval_1", DayTimeIntervalType()),
        ("dt_interval_2", DayTimeIntervalType(DayTimeIntervalType.DAY)),
        (
            "dt_interval_3",
            DayTimeIntervalType(DayTimeIntervalType.HOUR, DayTimeIntervalType.SECOND),
        ),
        ("cal_interval", CalendarIntervalType()),
        ("var", VariantType()),
    ]

    schema = StructType()
    for field_name, field_type in columns:
        schema = schema.add(field_name, field_type)

    # The rendered string ends with a newline, hence the trailing "" entry.
    expected_lines = [
        "root",
        " |-- n: void (nullable = true)",
        " |-- str: string (nullable = true)",
        " |-- c: char(10) (nullable = true)",
        " |-- v: varchar(10) (nullable = true)",
        " |-- bin: binary (nullable = true)",
        " |-- bool: boolean (nullable = true)",
        " |-- date: date (nullable = true)",
        " |-- ts: timestamp (nullable = true)",
        " |-- ts_ntz: timestamp_ntz (nullable = true)",
        " |-- dec: decimal(10,2) (nullable = true)",
        " |-- double: double (nullable = true)",
        " |-- float: float (nullable = true)",
        " |-- long: long (nullable = true)",
        " |-- int: integer (nullable = true)",
        " |-- short: short (nullable = true)",
        " |-- byte: byte (nullable = true)",
        " |-- ym_interval_1: interval year to month (nullable = true)",
        " |-- ym_interval_2: interval year (nullable = true)",
        " |-- ym_interval_3: interval year to month (nullable = true)",
        " |-- dt_interval_1: interval day to second (nullable = true)",
        " |-- dt_interval_2: interval day (nullable = true)",
        " |-- dt_interval_3: interval hour to second (nullable = true)",
        " |-- cal_interval: interval (nullable = true)",
        " |-- var: variant (nullable = true)",
        "",
    ]
    actual_lines = schema.treeString().split("\n")
    self.assertEqual(actual_lines, expected_lines)

def test_metadata_null(self):
schema = StructType(
[
Expand Down
27 changes: 23 additions & 4 deletions python/pyspark/sql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,24 @@ def _data_type_build_formatted_string(
if isinstance(dataType, (ArrayType, StructType, MapType)):
dataType._build_formatted_string(prefix, stringConcat, maxDepth - 1)

# The method typeName() is not always the same as the Scala side.
# This helper keeps the output of treeString() compatible with the Scala side.
@classmethod
def _get_jvm_type_name(cls, dataType: "DataType") -> str:
    """Return the type name to print for ``dataType`` in a tree string.

    Instances of the types listed below are rendered with
    ``simpleString()``; every other type is rendered with ``typeName()``.
    """
    rendered_with_simple_string = (
        DecimalType,
        CharType,
        VarcharType,
        DayTimeIntervalType,
        YearMonthIntervalType,
    )
    if not isinstance(dataType, rendered_with_simple_string):
        return dataType.typeName()
    return dataType.simpleString()


# This singleton pattern does not work with pickle, you will get
# another object after pickle and unpickle
Expand Down Expand Up @@ -758,7 +776,7 @@ def _build_formatted_string(
) -> None:
if maxDepth > 0:
stringConcat.append(
f"{prefix}-- element: {self.elementType.typeName()} "
f"{prefix}-- element: {DataType._get_jvm_type_name(self.elementType)} "
+ f"(containsNull = {str(self.containsNull).lower()})\n"
)
DataType._data_type_build_formatted_string(
Expand Down Expand Up @@ -906,12 +924,12 @@ def _build_formatted_string(
maxDepth: int = JVM_INT_MAX,
) -> None:
if maxDepth > 0:
stringConcat.append(f"{prefix}-- key: {self.keyType.typeName()}\n")
stringConcat.append(f"{prefix}-- key: {DataType._get_jvm_type_name(self.keyType)}\n")
DataType._data_type_build_formatted_string(
self.keyType, f"{prefix} |", stringConcat, maxDepth
)
stringConcat.append(
f"{prefix}-- value: {self.valueType.typeName()} "
f"{prefix}-- value: {DataType._get_jvm_type_name(self.valueType)} "
+ f"(valueContainsNull = {str(self.valueContainsNull).lower()})\n"
)
DataType._data_type_build_formatted_string(
Expand Down Expand Up @@ -1074,7 +1092,8 @@ def _build_formatted_string(
) -> None:
if maxDepth > 0:
stringConcat.append(
f"{prefix}-- {escape_meta_characters(self.name)}: {self.dataType.typeName()} "
f"{prefix}-- {escape_meta_characters(self.name)}: "
+ f"{DataType._get_jvm_type_name(self.dataType)} "
+ f"(nullable = {str(self.nullable).lower()})\n"
)
DataType._data_type_build_formatted_string(
Expand Down

0 comments on commit 14d3f44

Please sign in to comment.