[SPARK-49383][SQL][PYTHON][CONNECT] Support Transpose DataFrame API
### What changes were proposed in this pull request?
This PR proposes to support `transpose` as a Scala/Python DataFrame API in both Spark Connect and Classic Spark.

Please see https://docs.google.com/document/d/1QSmG81qQ-muab0UOeqgDAELqF7fJTH8GnxCJF4Ir-kA/edit for a detailed design.

### Why are the changes needed?
Transposing data is a crucial operation in data analysis, enabling the transformation of rows into columns. It is widely used in tools like pandas and NumPy, allowing for more flexible data manipulation and visualization.

While Apache Spark supports unpivot and pivot operations, it currently lacks a built-in transpose function. Implementing a transpose operation in Spark would enhance its data processing capabilities, align it with the functionality available in pandas and NumPy, and further empower users in their data analysis workflows.

### Does this PR introduce _any_ user-facing change?
Yes, a `transpose` DataFrame API is added.

**Scala**
```scala
scala> df.show()
+---+---+---+
|  a|  b|  c|
+---+---+---+
|  x|  y|  z|
+---+---+---+

scala> df.transpose().show()
+---+---+
|key|  x|
+---+---+
|  b|  y|
|  c|  z|
+---+---+

scala> df.transpose($"b").show()
+---+---+
|key|  y|
+---+---+
|  a|  x|
|  c|  z|
+---+---+

```

**Python**
```py
>>> df.show()
+---+---+---+
|  a|  b|  c|
+---+---+---+
|  x|  y|  z|
+---+---+---+

>>> df.transpose().show()
+---+---+
|key|  x|
+---+---+
|  b|  y|
|  c|  z|
+---+---+
>>> df.transpose(df.b).show()
+---+---+
|key|  y|
+---+---+
|  a|  x|
|  c|  z|
+---+---+
```
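
For a multi-row input, the behavior looks as follows (this sketch mirrors the example in the new `transpose` docstring added by this PR; `spark` is an existing session):

```python
>>> df = spark.createDataFrame([("A", 1, 2), ("B", 3, 4)], ["id", "val1", "val2"])
>>> df.transpose().show()  # the first column ("id") is used as the default index column
+----+---+---+
| key|  A|  B|
+----+---+---+
|val1|  1|  3|
|val2|  2|  4|
+----+---+---+
```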

**Spark Plan**
```scala
scala> df.show()
+---+---+---+
|  a|  b|  c|
+---+---+---+
|  x|  y|  z|
+---+---+---+
scala> df.transpose().explain(true)
== Parsed Logical Plan ==
'UnresolvedTranspose a#48: string
+- LocalRelation [a#48, b#49, c#50]

== Analyzed Logical Plan ==
key: string, x: string
Transpose [key#83, x#84], [[b,y], [c,z]], true

== Optimized Logical Plan ==
LocalRelation [key#83, x#84]

== Physical Plan ==
LocalTableScan [key#83, x#84]

```
```python
# empty frame with no column headers

>>> empty_df.show()
++
||
++
++

>>> empty_df.transpose().explain(True)
== Parsed Logical Plan ==
'UnresolvedTranspose
+- LogicalRDD false

== Analyzed Logical Plan ==
Transpose false

== Optimized Logical Plan ==
LocalRelation <empty>

== Physical Plan ==
LocalTableScan <empty>

# empty frame with column headers

>>> empty_df.show()
+-------+-------+-------+
|column1|column2|column3|
+-------+-------+-------+
+-------+-------+-------+

>>> empty_df.transpose().explain(True)
== Parsed Logical Plan ==
'UnresolvedTranspose column1#0: string
+- LogicalRDD [column1#0, column2#1, column3#2], false

== Analyzed Logical Plan ==
key: string
Transpose [key#32], [[column2], [column3]], true

== Optimized Logical Plan ==
LocalRelation [key#32]

== Physical Plan ==
LocalTableScan [key#32]
```

### How was this patch tested?
**Spark Connect**

- Python
	- doctest
	- module: python.pyspark.sql.tests.connect.test_parity_dataframe
	- case: test_transpose
- Proto
	- suite: org.apache.spark.sql.PlanGenerationTestSuite
	- cases: transpose index_column, transpose no_index_column

**Spark Classic**

- Python
	- doctest
	- module: python.pyspark.sql.tests.test_dataframe
	- case: test_transpose
- Scala
	- suite: org.apache.spark.sql.DataFrameTransposeSuite
	- case: all

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #47884 from xinrong-meng/transpose_dataframe_api.

Authored-by: Xinrong Meng <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
xinrong-meng authored and cloud-fan committed Sep 6, 2024
1 parent 26e59f2 commit 23bea28
Showing 30 changed files with 1,010 additions and 155 deletions.
18 changes: 18 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -4356,6 +4356,24 @@
],
"sqlState" : "428EK"
},
"TRANSPOSE_EXCEED_ROW_LIMIT" : {
"message" : [
"Number of rows exceeds the allowed limit of <maxValues> for TRANSPOSE. If this was intended, set <config> to at least the current row count."
],
"sqlState" : "54006"
},
"TRANSPOSE_INVALID_INDEX_COLUMN" : {
"message" : [
"Invalid index column for TRANSPOSE because: <reason>"
],
"sqlState" : "42804"
},
"TRANSPOSE_NO_LEAST_COMMON_TYPE" : {
"message" : [
"Transpose requires non-index columns to share a least common type, but <dt1> and <dt2> do not."
],
"sqlState" : "42K09"
},
"UDTF_ALIAS_NUMBER_MISMATCH" : {
"message" : [
"The number of aliases supplied in the AS clause does not match the number of columns output by the UDTF.",
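As a hedged illustration of when the new `TRANSPOSE_EXCEED_ROW_LIMIT` condition is raised (mirroring the new `test_transpose` case further below; the config key comes from the error message, and `spark` is an existing session):

```python
df = spark.createDataFrame([{"a": "x", "b": "y", "c": "z"}])

# Lowering spark.sql.transposeMaxValues below the input row count should make
# analysis fail with TRANSPOSE_EXCEED_ROW_LIMIT, as exercised in test_transpose.
spark.conf.set("spark.sql.transposeMaxValues", 0)
try:
    df.transpose().collect()
except Exception as e:
    print(e)  # [TRANSPOSE_EXCEED_ROW_LIMIT] Number of rows exceeds the allowed limit ...
finally:
    spark.conf.unset("spark.sql.transposeMaxValues")
```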
@@ -489,6 +489,14 @@ class Dataset[T] private[sql] (
}
}

private def buildTranspose(indices: Seq[Column]): DataFrame =
sparkSession.newDataFrame { builder =>
val transpose = builder.getTransposeBuilder.setInput(plan.getRoot)
indices.foreach { indexColumn =>
transpose.addIndexColumns(indexColumn.expr)
}
}

/** @inheritdoc */
@scala.annotation.varargs
def groupBy(cols: Column*): RelationalGroupedDataset = {
@@ -582,6 +590,14 @@
buildUnpivot(ids, None, variableColumnName, valueColumnName)
}

/** @inheritdoc */
def transpose(indexColumn: Column): DataFrame =
buildTranspose(Seq(indexColumn))

/** @inheritdoc */
def transpose(): DataFrame =
buildTranspose(Seq.empty)

/** @inheritdoc */
def limit(n: Int): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { builder =>
builder.getLimitBuilder
@@ -552,6 +552,14 @@ class PlanGenerationTestSuite
valueColumnName = "value")
}

test("transpose index_column") {
simple.transpose(indexColumn = fn.col("id"))
}

test("transpose no_index_column") {
simple.transpose()
}

test("offset") {
simple.offset(1000)
}
6 changes: 6 additions & 0 deletions python/pyspark/sql/classic/dataframe.py
@@ -1849,6 +1849,12 @@ def toArrow(self) -> "pa.Table":
def toPandas(self) -> "PandasDataFrameLike":
return PandasConversionMixin.toPandas(self)

def transpose(self, indexColumn: Optional["ColumnOrName"] = None) -> ParentDataFrame:
if indexColumn is not None:
return DataFrame(self._jdf.transpose(_to_java_column(indexColumn)), self.sparkSession)
else:
return DataFrame(self._jdf.transpose(), self.sparkSession)

@property
def executionInfo(self) -> Optional["ExecutionInfo"]:
raise PySparkValueError(
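Since the classic implementation accepts `ColumnOrName`, the index column may be passed either by name or as a `Column`; a small sketch (hypothetical data, assuming an existing `spark` session):

```python
from pyspark.sql.functions import col

df = spark.createDataFrame([("A", 1, 2), ("B", 3, 4)], ["id", "val1", "val2"])

df.transpose("id").show()       # index column given by name
df.transpose(df.id).show()      # equivalent, as a Column attribute
df.transpose(col("id")).show()  # equivalent, as a Column expression
```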
6 changes: 6 additions & 0 deletions python/pyspark/sql/connect/dataframe.py
@@ -1858,6 +1858,12 @@ def toPandas(self) -> "PandasDataFrameLike":
self._execution_info = ei
return pdf

def transpose(self, indexColumn: Optional["ColumnOrName"] = None) -> ParentDataFrame:
return DataFrame(
plan.Transpose(self._plan, [F._to_col(indexColumn)] if indexColumn is not None else []),
self._session,
)

@property
def schema(self) -> StructType:
# Schema caching is correct in most cases. Connect is lazy by nature. This means that
21 changes: 21 additions & 0 deletions python/pyspark/sql/connect/plan.py
@@ -1329,6 +1329,27 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:
return plan


class Transpose(LogicalPlan):
"""Logical plan object for a transpose operation."""

def __init__(
self,
child: Optional["LogicalPlan"],
index_columns: Sequence[Column],
) -> None:
super().__init__(child)
self.index_columns = index_columns

def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.transpose.input.CopyFrom(self._child.plan(session))
if self.index_columns is not None and len(self.index_columns) > 0:
for index_column in self.index_columns:
plan.transpose.index_columns.append(index_column.to_plan(session))
return plan


class CollectMetrics(LogicalPlan):
"""Logical plan object for a CollectMetrics operation."""

302 changes: 152 additions & 150 deletions python/pyspark/sql/connect/proto/relations_pb2.py

Large diffs are not rendered by default.

50 changes: 50 additions & 0 deletions python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -104,6 +104,7 @@ class Relation(google.protobuf.message.Message):
AS_OF_JOIN_FIELD_NUMBER: builtins.int
COMMON_INLINE_USER_DEFINED_DATA_SOURCE_FIELD_NUMBER: builtins.int
WITH_RELATIONS_FIELD_NUMBER: builtins.int
TRANSPOSE_FIELD_NUMBER: builtins.int
FILL_NA_FIELD_NUMBER: builtins.int
DROP_NA_FIELD_NUMBER: builtins.int
REPLACE_FIELD_NUMBER: builtins.int
@@ -205,6 +206,8 @@
@property
def with_relations(self) -> global___WithRelations: ...
@property
def transpose(self) -> global___Transpose: ...
@property
def fill_na(self) -> global___NAFill:
"""NA functions"""
@property
@@ -284,6 +287,7 @@
common_inline_user_defined_data_source: global___CommonInlineUserDefinedDataSource
| None = ...,
with_relations: global___WithRelations | None = ...,
transpose: global___Transpose | None = ...,
fill_na: global___NAFill | None = ...,
drop_na: global___NADrop | None = ...,
replace: global___NAReplace | None = ...,
@@ -402,6 +406,8 @@
b"to_df",
"to_schema",
b"to_schema",
"transpose",
b"transpose",
"unknown",
b"unknown",
"unpivot",
@@ -519,6 +525,8 @@
b"to_df",
"to_schema",
b"to_schema",
"transpose",
b"transpose",
"unknown",
b"unknown",
"unpivot",
@@ -577,6 +585,7 @@
"as_of_join",
"common_inline_user_defined_data_source",
"with_relations",
"transpose",
"fill_na",
"drop_na",
"replace",
@@ -3141,6 +3150,47 @@

global___Unpivot = Unpivot

class Transpose(google.protobuf.message.Message):
"""Transpose a DataFrame, switching rows to columns.
Transforms the DataFrame such that the values in the specified index column
become the new columns of the DataFrame.
"""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

INPUT_FIELD_NUMBER: builtins.int
INDEX_COLUMNS_FIELD_NUMBER: builtins.int
@property
def input(self) -> global___Relation:
"""(Required) The input relation."""
@property
def index_columns(
self,
) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
pyspark.sql.connect.proto.expressions_pb2.Expression
]:
"""(Optional) A list of columns that will be treated as the indices.
Only single column is supported now.
"""
def __init__(
self,
*,
input: global___Relation | None = ...,
index_columns: collections.abc.Iterable[
pyspark.sql.connect.proto.expressions_pb2.Expression
]
| None = ...,
) -> None: ...
def HasField(
self, field_name: typing_extensions.Literal["input", b"input"]
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal["index_columns", b"index_columns", "input", b"input"],
) -> None: ...

global___Transpose = Transpose

class ToSchema(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

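For illustration only, a rough sketch of building the new `Transpose` relation directly from the generated protobuf classes (field names follow the stub above; the empty child relation and the unresolved-attribute expression are placeholders chosen for this example):

```python
from pyspark.sql.connect.proto import expressions_pb2, relations_pb2

# Placeholder child relation; in practice this comes from the child plan.
child = relations_pb2.Relation()

# Index column referenced by name via an unresolved attribute expression.
index_col = expressions_pb2.Expression(
    unresolved_attribute=expressions_pb2.Expression.UnresolvedAttribute(
        unparsed_identifier="id"
    )
)

rel = relations_pb2.Relation(
    transpose=relations_pb2.Transpose(input=child, index_columns=[index_col])
)
print(rel)
```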
66 changes: 66 additions & 0 deletions python/pyspark/sql/dataframe.py
@@ -6311,6 +6311,72 @@ def toPandas(self) -> "PandasDataFrameLike":
"""
...

@dispatch_df_method
def transpose(self, indexColumn: Optional["ColumnOrName"] = None) -> "DataFrame":
"""
Transposes a DataFrame such that the values in the specified index column become the new
columns of the DataFrame. If no index column is provided, the first column is used as
the default.

Please note:
- All columns except the index column must share a least common data type. Unless they
are the same data type, all columns are cast to the nearest common data type.
- The name of the column into which the original column names are transposed defaults
to "key".
- null values in the index column are excluded from the column names for the
transposed table, which are ordered in ascending order.

.. versionadded:: 4.0.0

Parameters
----------
indexColumn : str or :class:`Column`, optional
The single column that will be treated as the index for the transpose operation. This
column will be used to transform the DataFrame such that the values of the indexColumn
become the new columns in the transposed DataFrame. If not provided, the first column of
the DataFrame will be used as the default.

Returns
-------
:class:`DataFrame`
Transposed DataFrame.

Notes
-----
Supports Spark Connect.

Examples
--------
>>> df = spark.createDataFrame(
... [("A", 1, 2), ("B", 3, 4)],
... ["id", "val1", "val2"],
... )
>>> df.show()
+---+----+----+
| id|val1|val2|
+---+----+----+
| A| 1| 2|
| B| 3| 4|
+---+----+----+
>>> df.transpose().show()
+----+---+---+
| key| A| B|
+----+---+---+
|val1| 1| 3|
|val2| 2| 4|
+----+---+---+
>>> df.transpose(df.id).show()
+----+---+---+
| key| A| B|
+----+---+---+
|val1| 1| 3|
|val2| 2| 4|
+----+---+---+
"""
...

@property
def executionInfo(self) -> Optional["ExecutionInfo"]:
"""
69 changes: 69 additions & 0 deletions python/pyspark/sql/tests/test_dataframe.py
@@ -39,6 +39,7 @@
PySparkTypeError,
PySparkValueError,
)
from pyspark.testing import assertDataFrameEqual
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
have_pyarrow,
@@ -955,6 +956,74 @@ def test_checkpoint_dataframe(self):
self.spark.range(1).localCheckpoint().explain()
self.assertIn("ExistingRDD", buf.getvalue())

def test_transpose(self):
df = self.spark.createDataFrame([{"a": "x", "b": "y", "c": "z"}])

# default index column
transposed_df = df.transpose()
expected_schema = StructType(
[StructField("key", StringType(), False), StructField("x", StringType(), True)]
)
expected_data = [Row(key="b", x="y"), Row(key="c", x="z")]
expected_df = self.spark.createDataFrame(expected_data, schema=expected_schema)
assertDataFrameEqual(transposed_df, expected_df, checkRowOrder=True)

# specified index column
transposed_df = df.transpose("c")
expected_schema = StructType(
[StructField("key", StringType(), False), StructField("z", StringType(), True)]
)
expected_data = [Row(key="a", z="x"), Row(key="b", z="y")]
expected_df = self.spark.createDataFrame(expected_data, schema=expected_schema)
assertDataFrameEqual(transposed_df, expected_df, checkRowOrder=True)

# enforce transpose max values
with self.sql_conf({"spark.sql.transposeMaxValues": 0}):
with self.assertRaises(AnalysisException) as pe:
df.transpose().collect()
self.check_error(
exception=pe.exception,
errorClass="TRANSPOSE_EXCEED_ROW_LIMIT",
messageParameters={"maxValues": "0", "config": "spark.sql.transposeMaxValues"},
)

# enforce ascending order based on index column values for transposed columns
df = self.spark.createDataFrame([{"a": "z"}, {"a": "y"}, {"a": "x"}])
transposed_df = df.transpose()
expected_schema = StructType(
[
StructField("key", StringType(), False),
StructField("x", StringType(), True),
StructField("y", StringType(), True),
StructField("z", StringType(), True),
]
) # z, y, x -> x, y, z
expected_df = self.spark.createDataFrame([], schema=expected_schema)
assertDataFrameEqual(transposed_df, expected_df, checkRowOrder=True)

# enforce AtomicType Attribute for index column values
df = self.spark.createDataFrame([{"a": ["x", "x"], "b": "y", "c": "z"}])
with self.assertRaises(AnalysisException) as pe:
df.transpose().collect()
self.check_error(
exception=pe.exception,
errorClass="TRANSPOSE_INVALID_INDEX_COLUMN",
messageParameters={
"reason": "Index column must be of atomic type, "
"but found: ArrayType(StringType,true)"
},
)

# enforce least common type for non-index columns
df = self.spark.createDataFrame([{"a": "x", "b": "y", "c": 1}])
with self.assertRaises(AnalysisException) as pe:
df.transpose().collect()
self.check_error(
exception=pe.exception,
errorClass="TRANSPOSE_NO_LEAST_COMMON_TYPE",
messageParameters={"dt1": "STRING", "dt2": "BIGINT"},
)


class DataFrameTests(DataFrameTestsMixin, ReusedSQLTestCase):
def test_query_execution_unsupported_in_classic(self):
