[SPARK-49552][PYTHON] Add DataFrame API support for new 'randstr' and 'uniform' SQL functions #48143

Status: Closed. This view shows changes from 21 of 29 commits.
2 changes: 2 additions & 0 deletions python/docs/source/reference/pyspark.sql/functions.rst
@@ -148,6 +148,7 @@ Mathematical Functions
try_multiply
try_subtract
unhex
uniform
width_bucket


@@ -189,6 +190,7 @@ String Functions
overlay
position
printf
randstr
regexp_count
regexp_extract
regexp_extract_all
22 changes: 22 additions & 0 deletions python/pyspark/sql/connect/functions/builtin.py
@@ -1008,6 +1008,18 @@ def unhex(col: "ColumnOrName") -> Column:
unhex.__doc__ = pysparkfuncs.unhex.__doc__


def uniform(
min: "ColumnOrName", max: "ColumnOrName", seed: Optional["ColumnOrName"] = None
) -> Column:
if seed is None:
return _invoke_function_over_columns("uniform", min, max)
else:
return _invoke_function_over_columns("uniform", min, max, seed)
Review comment (Contributor):
Suggested change:
-    return _invoke_function_over_columns("uniform", min, max, seed)
+    return _invoke_function_over_columns("uniform", lit(min), lit(max), lit(seed))

Reply (Contributor Author):
Thanks, this is done.



uniform.__doc__ = pysparkfuncs.uniform.__doc__
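Since this view is at commit 21 of 29, the diff above still shows the pre-review code. A minimal sketch of the wrapper after applying the reviewer's lit() suggestion might look like the following (a hypothetical reconstruction, not the verbatim merged code; it assumes lit is available in this module and passes Column arguments through unchanged):

def uniform(
    min: "ColumnOrName", max: "ColumnOrName", seed: Optional["ColumnOrName"] = None
) -> Column:
    # lit() turns plain Python ints/floats into Column literals and leaves
    # Column arguments as-is, so both call styles work.
    if seed is None:
        return _invoke_function_over_columns("uniform", lit(min), lit(max))
    return _invoke_function_over_columns("uniform", lit(min), lit(max), lit(seed))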


def approxCountDistinct(col: "ColumnOrName", rsd: Optional[float] = None) -> Column:
warnings.warn("Deprecated in 3.4, use approx_count_distinct instead.", FutureWarning)
return approx_count_distinct(col, rsd)
@@ -2578,6 +2590,16 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column:
regexp_like.__doc__ = pysparkfuncs.regexp_like.__doc__


def randstr(length: "ColumnOrName", seed: Optional["ColumnOrName"] = None) -> Column:
if seed is None:
return _invoke_function_over_columns("randstr", length)
else:
return _invoke_function_over_columns("randstr", length, seed)
Review comment (Contributor):
Suggested change:
-    return _invoke_function_over_columns("randstr", length, seed)
+    return _invoke_function_over_columns("randstr", lit(length), lit(seed))

Reply (Contributor Author):
Thanks, this is done.



randstr.__doc__ = pysparkfuncs.randstr.__doc__
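The same lit() treatment applies to randstr per the matching review thread above (again a hypothetical sketch of the post-review shape, not the verbatim merged code):

def randstr(length: "ColumnOrName", seed: Optional["ColumnOrName"] = None) -> Column:
    # lit() lets callers pass a plain int for length/seed as well as a Column.
    if seed is None:
        return _invoke_function_over_columns("randstr", lit(length))
    return _invoke_function_over_columns("randstr", lit(length), lit(seed))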


def regexp_count(str: "ColumnOrName", regexp: "ColumnOrName") -> Column:
return _invoke_function_over_columns("regexp_count", str, regexp)

68 changes: 68 additions & 0 deletions python/pyspark/sql/functions/builtin.py
@@ -11861,6 +11861,37 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column:
return _invoke_function_over_columns("regexp_like", str, regexp)


@_try_remote_functions
def randstr(length: "ColumnOrName", seed: Optional["ColumnOrName"] = None) -> Column:
"""Returns a string of the specified length whose characters are chosen uniformly at random from
the following pool of characters: 0-9, a-z, A-Z. The random seed is optional. The string length
must be a constant two-byte or four-byte integer (SMALLINT or INT, respectively).

.. versionadded:: 4.0.0

Parameters
----------
length : :class:`~pyspark.sql.Column` or int
Number of characters in the string to generate.
seed : :class:`~pyspark.sql.Column` or int
Optional random number seed to use.

Returns
-------
:class:`~pyspark.sql.Column`
The generated random string with the specified length.

Examples
--------
>>> spark.createDataFrame([('3',)], ['a']).select(randstr(lit(5)).alias('x')).select(isnull('x').alias('y')).collect()
[Row(y=False)]
"""
if seed is None:
return _invoke_function_over_columns("randstr", length)
else:
return _invoke_function_over_columns("randstr", length, seed)
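A quick way to check the documented behavior (fixed length, character pool limited to 0-9, a-z, A-Z) from a PySpark shell; this sketch assumes the final signature accepts Column arguments built with lit():

from pyspark.sql import functions as F

row = spark.range(1).select(F.randstr(F.lit(10), F.lit(42)).alias("s")).head()
assert len(row["s"]) == 10                 # exactly the requested length
assert all(c.isalnum() for c in row["s"])  # only 0-9, a-z, A-Z appear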


@_try_remote_functions
def regexp_count(str: "ColumnOrName", regexp: "ColumnOrName") -> Column:
r"""Returns a count of the number of times that the Java regex pattern `regexp` is matched
Expand Down Expand Up @@ -12227,6 +12258,43 @@ def unhex(col: "ColumnOrName") -> Column:
return _invoke_function_over_columns("unhex", col)


@_try_remote_functions
def uniform(
min: "ColumnOrName", max: "ColumnOrName", seed: Optional["ColumnOrName"] = None
) -> Column:
"""Returns a random value with independent and identically distributed (i.i.d.) values with the
specified range of numbers. The random seed is optional. The provided numbers specifying the
minimum and maximum values of the range must be constant. If both of these numbers are integers,
then the result will also be an integer. Otherwise if one or both of these are floating-point
numbers, then the result will also be a floating-point number.

.. versionadded:: 4.0.0

Parameters
----------
min : :class:`~pyspark.sql.Column`, int, or float
Minimum value in the range.
max : :class:`~pyspark.sql.Column`, int, or float
Maximum value in the range.
seed : :class:`~pyspark.sql.Column` or int
Optional random number seed to use.

Returns
-------
:class:`~pyspark.sql.Column`
The generated random number within the specified range.

Examples
--------
>>> spark.createDataFrame([('3',)], ['a']).select(uniform(lit(0), lit(10)).alias('x')).select(isnull('x').alias('y')).collect()
[Row(y=False)]
"""
if seed is None:
return _invoke_function_over_columns("uniform", min, max)
else:
return _invoke_function_over_columns("uniform", min, max, seed)
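A sketch illustrating the documented result-type rule (integer bounds yield an integer result, floating-point bounds yield a floating-point result); the concrete Spark SQL types named in the comments are assumptions:

from pyspark.sql import functions as F

df = spark.range(1)
ints = df.select(F.uniform(F.lit(0), F.lit(10), F.lit(0)).alias("x"))
floats = df.select(F.uniform(F.lit(0.0), F.lit(10.0), F.lit(0)).alias("x"))
print(ints.schema["x"].dataType)    # integral, e.g. IntegerType()
print(floats.schema["x"].dataType)  # floating-point, e.g. DoubleType()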


@_try_remote_functions
def length(col: "ColumnOrName") -> Column:
"""Computes the character length of string data or number of bytes of binary data.
21 changes: 20 additions & 1 deletion python/pyspark/sql/tests/test_functions.py
@@ -29,7 +29,7 @@
from pyspark.sql import Row, Window, functions as F, types
from pyspark.sql.avro.functions import from_avro, to_avro
from pyspark.sql.column import Column
-from pyspark.sql.functions.builtin import nullifzero, zeroifnull
+from pyspark.sql.functions.builtin import isnull, length, nullifzero, randstr, uniform, zeroifnull
from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils
from pyspark.testing.utils import have_numpy

@@ -1610,6 +1610,25 @@ def test_nullifzero_zeroifnull(self):
result = df.select(zeroifnull(df.a).alias("r")).collect()
self.assertEqual([Row(r=0), Row(r=1)], result)

def test_randstr_uniform(self):
df = self.spark.createDataFrame([(0,)], ["a"])
result = df.select(randstr(F.lit(5), F.lit(0)).alias("x")).selectExpr("length(x)").collect()
self.assertEqual([Row(5)], result)
# The random seed is optional.
result = df.select(randstr(F.lit(5)).alias("x")).selectExpr("length(x)").collect()
self.assertEqual([Row(5)], result)

df = self.spark.createDataFrame([(0,)], ["a"])
result = (
df.select(uniform(F.lit(10), F.lit(20), F.lit(0)).alias("x"))
.selectExpr("x > 5")
.collect()
)
self.assertEqual([Row(True)], result)
# The random seed is optional.
result = df.select(uniform(F.lit(10), F.lit(20)).alias("x")).selectExpr("x > 5").collect()
self.assertEqual([Row(True)], result)


class FunctionsTests(ReusedSQLTestCase, FunctionsTestsMixin):
pass
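To run just the new test locally, the standard PySpark test runner should work (shown as a sketch; exact flags may vary by branch):

python/run-tests --testnames 'pyspark.sql.tests.test_functions FunctionsTests.test_randstr_uniform'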
45 changes: 45 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1884,6 +1884,26 @@ object functions {
*/
def randn(): Column = randn(SparkClassUtils.random.nextLong)

/**
* Returns a string of the specified length whose characters are chosen uniformly at random from
* the following pool of characters: 0-9, a-z, A-Z. The string length must be a constant
* two-byte or four-byte integer (SMALLINT or INT, respectively).
*
* @group string_funcs
* @since 4.0.0
*/
def randstr(length: Column): Column = Column.fn("randstr", length)

/**
* Returns a string of the specified length whose characters are chosen uniformly at random from
* the following pool of characters: 0-9, a-z, A-Z, with the chosen random seed. The string
* length must be a constant two-byte or four-byte integer (SMALLINT or INT, respectively).
*
* @group string_funcs
* @since 4.0.0
*/
def randstr(length: Column, seed: Column): Column = Column.fn("randstr", length, seed)

/**
* Partition ID.
*
@@ -3728,6 +3748,31 @@
*/
def stack(cols: Column*): Column = Column.fn("stack", cols: _*)

/**
* Returns a random value with independent and identically distributed (i.i.d.) values drawn
* uniformly from within the specified range of numbers. The provided numbers specifying the
* minimum and maximum values of the range must be constant. If both of these numbers are
* integers, then the result will also be an integer. Otherwise, if one or both of these are
* floating-point numbers, then the result will also be a floating-point number.
*
* @group math_funcs
* @since 4.0.0
*/
def uniform(min: Column, max: Column): Column = Column.fn("uniform", min, max)

/**
* Returns a random value with independent and identically distributed (i.i.d.) values drawn
* uniformly from within the specified range of numbers, using the given random seed. The
* provided numbers specifying the minimum and maximum values of the range must be constant.
* If both of these numbers are integers, then the result will also be an integer. Otherwise,
* if one or both of these are floating-point numbers, then the result will also be a
* floating-point number.
*
* @group math_funcs
* @since 4.0.0
*/
def uniform(min: Column, max: Column, seed: Column): Column =
Column.fn("uniform", min, max, seed)

/**
* Returns a random value with independent and identically distributed (i.i.d.) uniformly
* distributed values in [0, 1).
sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -409,6 +409,110 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
checkAnswer(df.select(nvl2(col("b"), col("a"), col("c"))), Seq(Row(null)))
}

test("randstr function") {
withTable("t") {
sql("create table t(col int not null) using csv")
sql("insert into t values (0)")
val df = sql("select col from t")
checkAnswer(
df.select(randstr(lit(5), lit(0)).alias("x")).select(length(col("x"))),
Seq(Row(5)))
// The random seed is optional.
checkAnswer(
df.select(randstr(lit(5)).alias("x")).select(length(col("x"))),
Seq(Row(5)))
}
// Here we exercise some error cases.
val df = Seq((0)).toDF("a")
var expr = randstr(lit(10), lit("a"))
checkError(
intercept[AnalysisException](df.select(expr)),
condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
parameters = Map(
"sqlExpr" -> "\"randstr(10, a)\"",
"paramIndex" -> "second",
"inputSql" -> "\"a\"",
"inputType" -> "\"STRING\"",
"requiredType" -> "INT or SMALLINT"),
context = ExpectedContext(
contextType = QueryContextType.DataFrame,
fragment = "randstr",
objectType = "",
objectName = "",
callSitePattern = "",
startIndex = 0,
stopIndex = 0))
expr = randstr(col("a"), lit(10))
checkError(
intercept[AnalysisException](df.select(expr)),
condition = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT",
parameters = Map(
"inputName" -> "length",
"inputType" -> "INT or SMALLINT",
"inputExpr" -> "\"a\"",
"sqlExpr" -> "\"randstr(a, 10)\""),
context = ExpectedContext(
contextType = QueryContextType.DataFrame,
fragment = "randstr",
objectType = "",
objectName = "",
callSitePattern = "",
startIndex = 0,
stopIndex = 0))
}

test("uniform function") {
withTable("t") {
sql("create table t(col int not null) using csv")
sql("insert into t values (0)")
val df = sql("select col from t")
checkAnswer(
df.select(uniform(lit(10), lit(20), lit(0)).alias("x")).selectExpr("x > 5"),
Seq(Row(true)))
// The random seed is optional.
checkAnswer(
df.select(uniform(lit(10), lit(20)).alias("x")).selectExpr("x > 5"),
Seq(Row(true)))
}
// Here we exercise some error cases.
val df = Seq((0)).toDF("a")
var expr = uniform(lit(10), lit("a"))
checkError(
intercept[AnalysisException](df.select(expr)),
condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
parameters = Map(
"sqlExpr" -> "\"uniform(10, a)\"",
"paramIndex" -> "second",
"inputSql" -> "\"a\"",
"inputType" -> "\"STRING\"",
"requiredType" -> "INT or SMALLINT"),
context = ExpectedContext(
contextType = QueryContextType.DataFrame,
fragment = "uniform",
objectType = "",
objectName = "",
callSitePattern = "",
startIndex = 0,
stopIndex = 0))
expr = uniform(col("a"), lit(10))
checkError(
intercept[AnalysisException](df.select(expr)),
condition = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT",
parameters = Map(
"inputName" -> "length",
"inputType" -> "INT or SMALLINT",
"inputExpr" -> "\"a\"",
"sqlExpr" -> "\"uniform(a, 10)\""),
context = ExpectedContext(
contextType = QueryContextType.DataFrame,
fragment = "uniform",
objectType = "",
objectName = "",
callSitePattern = "",
startIndex = 0,
stopIndex = 0))
}

test("zeroifnull function") {
withTable("t") {
// Here we exercise a non-nullable, non-foldable column.