Skip to content

Commit

Permalink
[SPARK-49734][PYTHON] Add seed argument for function shuffle
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
1, Add `seed` argument for function `shuffle`;
2, Rewrite and enable the doctest by specifying the seed and controlling the partitioning;

### Why are the changes needed?
feature parity; the seed is already supported on the SQL side

### Does this PR introduce _any_ user-facing change?
yes, function `shuffle` gains a new optional `seed` argument

### How was this patch tested?
updated doctest

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #48184 from zhengruifeng/py_func_shuffle.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
  • Loading branch information
zhengruifeng committed Sep 23, 2024
1 parent 719b57a commit 0eeb61f
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 39 deletions.
10 changes: 3 additions & 7 deletions python/pyspark/sql/connect/functions/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@
from pyspark.sql.types import (
_from_numpy_type,
DataType,
LongType,
StructType,
ArrayType,
StringType,
Expand Down Expand Up @@ -2206,12 +2205,9 @@ def schema_of_xml(xml: Union[str, Column], options: Optional[Mapping[str, str]]
schema_of_xml.__doc__ = pysparkfuncs.schema_of_xml.__doc__


def shuffle(col: "ColumnOrName") -> Column:
    """Return a column that randomly permutes each array in ``col``."""
    # Draw a fresh random seed per call and ship it as a long literal,
    # since the server-side "shuffle" expression takes an explicit seed.
    seed = LiteralExpression(random.randint(0, sys.maxsize), LongType())
    return _invoke_function("shuffle", _to_col(col), seed)
def shuffle(col: "ColumnOrName", seed: Optional[Union[Column, int]] = None) -> Column:
    # Every invocation passes an explicit seed expression to the server:
    # use the caller's seed when given, otherwise draw a random one here.
    if seed is None:
        seed_col = lit(random.randint(0, sys.maxsize))
    else:
        seed_col = lit(seed)
    return _invoke_function("shuffle", _to_col(col), seed_col)


shuffle.__doc__ = pysparkfuncs.shuffle.__doc__
Expand Down
69 changes: 38 additions & 31 deletions python/pyspark/sql/functions/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17723,7 +17723,7 @@ def array_sort(


@_try_remote_functions
def shuffle(col: "ColumnOrName") -> Column:
def shuffle(col: "ColumnOrName", seed: Optional[Union[Column, int]] = None) -> Column:
"""
Array function: Generates a random permutation of the given array.

Expand All @@ -17736,6 +17736,10 @@ def shuffle(col: "ColumnOrName") -> Column:
----------
col : :class:`~pyspark.sql.Column` or str
The name of the column or expression to be shuffled.
seed : :class:`~pyspark.sql.Column` or int, optional
Seed value for the random generator.

.. versionadded:: 4.0.0

Returns
-------
Expand All @@ -17752,48 +17756,51 @@ def shuffle(col: "ColumnOrName") -> Column:
Example 1: Shuffling a simple array

>>> import pyspark.sql.functions as sf
>>> df = spark.createDataFrame([([1, 20, 3, 5],)], ['data'])
>>> df.select(sf.shuffle(df.data)).show() # doctest: +SKIP
+-------------+
|shuffle(data)|
+-------------+
|[1, 3, 20, 5]|
+-------------+
>>> df = spark.sql("SELECT ARRAY(1, 20, 3, 5) AS data")
>>> df.select("*", sf.shuffle(df.data, sf.lit(123))).show()
+-------------+-------------+
| data|shuffle(data)|
+-------------+-------------+
|[1, 20, 3, 5]|[5, 1, 20, 3]|
+-------------+-------------+

Example 2: Shuffling an array with null values

>>> import pyspark.sql.functions as sf
>>> df = spark.createDataFrame([([1, 20, None, 3],)], ['data'])
>>> df.select(sf.shuffle(df.data)).show() # doctest: +SKIP
+----------------+
| shuffle(data)|
+----------------+
|[20, 3, NULL, 1]|
+----------------+
>>> df = spark.sql("SELECT ARRAY(1, 20, NULL, 5) AS data")
>>> df.select("*", sf.shuffle(sf.col("data"), 234)).show()
+----------------+----------------+
| data| shuffle(data)|
+----------------+----------------+
|[1, 20, NULL, 5]|[NULL, 5, 20, 1]|
+----------------+----------------+

Example 3: Shuffling an array with duplicate values

>>> import pyspark.sql.functions as sf
>>> df = spark.createDataFrame([([1, 2, 2, 3, 3, 3],)], ['data'])
>>> df.select(sf.shuffle(df.data)).show() # doctest: +SKIP
+------------------+
| shuffle(data)|
+------------------+
|[3, 2, 1, 3, 2, 3]|
+------------------+
>>> df = spark.sql("SELECT ARRAY(1, 2, 2, 3, 3, 3) AS data")
>>> df.select("*", sf.shuffle("data", 345)).show()
+------------------+------------------+
| data| shuffle(data)|
+------------------+------------------+
|[1, 2, 2, 3, 3, 3]|[2, 3, 3, 1, 2, 3]|
+------------------+------------------+

Example 4: Shuffling an array with different types of elements
Example 4: Shuffling an array with random seed

>>> import pyspark.sql.functions as sf
>>> df = spark.createDataFrame([(['a', 'b', 'c', 1, 2, 3],)], ['data'])
>>> df.select(sf.shuffle(df.data)).show() # doctest: +SKIP
+------------------+
| shuffle(data)|
+------------------+
|[1, c, 2, a, b, 3]|
+------------------+
>>> df = spark.sql("SELECT ARRAY(1, 2, 2, 3, 3, 3) AS data")
>>> df.select("*", sf.shuffle("data")).show() # doctest: +SKIP
+------------------+------------------+
| data| shuffle(data)|
+------------------+------------------+
|[1, 2, 2, 3, 3, 3]|[3, 3, 2, 3, 2, 1]|
+------------------+------------------+
"""
return _invoke_function_over_columns("shuffle", col)
if seed is not None:
return _invoke_function_over_columns("shuffle", col, lit(seed))
else:
return _invoke_function_over_columns("shuffle", col)


@_try_remote_functions
Expand Down
13 changes: 12 additions & 1 deletion sql/api/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7252,7 +7252,18 @@ object functions {
* @group array_funcs
* @since 2.4.0
*/
def shuffle(e: Column): Column = Column.fn("shuffle", e, lit(SparkClassUtils.random.nextLong))
// No-seed variant: delegates to the seeded overload with a random seed,
// so each constructed Column carries its own concrete seed literal.
def shuffle(e: Column): Column = shuffle(e, lit(SparkClassUtils.random.nextLong))

/**
 * Returns a random permutation of the given array.
 *
 * @note
 *   The function is non-deterministic.
 *
 * @param e
 *   the array column to permute.
 * @param seed
 *   seed value for the random generator.
 *
 * @group array_funcs
 * @since 4.0.0
 */
def shuffle(e: Column, seed: Column): Column = Column.fn("shuffle", e, seed)

/**
* Returns a reversed string or an array with reverse order of elements.
Expand Down

0 comments on commit 0eeb61f

Please sign in to comment.