[SPARK-49552][PYTHON] Add DataFrame API support for new 'randstr' and 'uniform' SQL functions #48143

Open · wants to merge 28 commits into base: master · changes from 24 commits
2 changes: 2 additions & 0 deletions python/docs/source/reference/pyspark.sql/functions.rst
@@ -148,6 +148,7 @@ Mathematical Functions
try_multiply
try_subtract
unhex
uniform
width_bucket


@@ -189,6 +190,7 @@ String Functions
overlay
position
printf
randstr
regexp_count
regexp_extract
regexp_extract_all
26 changes: 26 additions & 0 deletions python/pyspark/sql/connect/functions/builtin.py
@@ -1008,6 +1008,22 @@ def unhex(col: "ColumnOrName") -> Column:
unhex.__doc__ = pysparkfuncs.unhex.__doc__


def uniform(
min: Union[Column, int, float],
max: Union[Column, int, float],
seed: Optional[Union[Column, int]] = None,
) -> Column:
if seed is None:
return _invoke_function_over_columns(
"uniform", min, max, lit(random.randint(0, sys.maxsize))
)
else:
return _invoke_function_over_columns("uniform", min, max, seed)


uniform.__doc__ = pysparkfuncs.uniform.__doc__


def approxCountDistinct(col: "ColumnOrName", rsd: Optional[float] = None) -> Column:
warnings.warn("Deprecated in 3.4, use approx_count_distinct instead.", FutureWarning)
return approx_count_distinct(col, rsd)
@@ -2578,6 +2594,16 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column:
regexp_like.__doc__ = pysparkfuncs.regexp_like.__doc__


def randstr(length: Union[Column, int], seed: Optional[Union[Column, int]] = None) -> Column:
if seed is None:
return _invoke_function_over_columns("randstr", length, lit(random.randint(0, sys.maxsize)))
else:
return _invoke_function_over_columns("randstr", length, seed)


randstr.__doc__ = pysparkfuncs.randstr.__doc__
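Both Connect wrappers above share the same pattern: when the caller omits `seed`, a concrete seed is drawn on the client with `random.randint(0, sys.maxsize)` and attached as a literal, so the expression sent to the server always carries an explicit seed. A minimal standalone sketch of that pattern, using the hypothetical helper name `_resolve_seed` (not part of this change):

import random
import sys
from typing import Optional, Union

from pyspark.sql import Column
from pyspark.sql.functions import lit


def _resolve_seed(seed: Optional[Union[Column, int]]) -> Union[Column, int]:
    # Hypothetical helper mirroring the wrappers above: materialize a
    # client-side seed when the caller did not supply one.
    if seed is None:
        return lit(random.randint(0, sys.maxsize))
    return seed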


def regexp_count(str: "ColumnOrName", regexp: "ColumnOrName") -> Column:
return _invoke_function_over_columns("regexp_count", str, regexp)

84 changes: 84 additions & 0 deletions python/pyspark/sql/functions/builtin.py
@@ -11861,6 +11861,44 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column:
return _invoke_function_over_columns("regexp_like", str, regexp)


@_try_remote_functions
def randstr(length: Union[Column, int], seed: Optional[Union[Column, int]] = None) -> Column:
"""Returns a string of the specified length whose characters are chosen uniformly at random from
the following pool of characters: 0-9, a-z, A-Z. The random seed is optional. The string length
must be a constant two-byte or four-byte integer (SMALLINT or INT, respectively).

.. versionadded:: 4.0.0

Parameters
----------
length : :class:`~pyspark.sql.Column` or int
Number of characters in the string to generate.
seed : :class:`~pyspark.sql.Column` or int
Optional random number seed to use.

Returns
-------
:class:`~pyspark.sql.Column`
The generated random string with the specified length.

Examples
--------
>>> from pyspark.sql import functions as sf
>>> df = spark.createDataFrame([('3',)], ['a'])
>>> df.select(sf.randstr(sf.lit(5), sf.lit(0)).alias('result')).show()
+------+
|result|
+------+
| ceV0P|
+------+

Review comment (Member): nit: we normally don't include an empty line at the end of the docstring

Reply (Contributor Author): Sounds good, this is done.

"""
if seed is None:
return _invoke_function_over_columns("randstr", length)
else:
return _invoke_function_over_columns("randstr", length, seed)


@_try_remote_functions
def regexp_count(str: "ColumnOrName", regexp: "ColumnOrName") -> Column:
r"""Returns a count of the number of times that the Java regex pattern `regexp` is matched
@@ -12227,6 +12265,52 @@ def unhex(col: "ColumnOrName") -> Column:
return _invoke_function_over_columns("unhex", col)


@_try_remote_functions
def uniform(
min: Union[Column, int, float],
max: Union[Column, int, float],
seed: Optional[Union[Column, int]] = None,
) -> Column:
"""Returns a random value with independent and identically distributed (i.i.d.) values with the
specified range of numbers. The random seed is optional. The provided numbers specifying the
minimum and maximum values of the range must be constant. If both of these numbers are integers,
then the result will also be an integer. Otherwise if one or both of these are floating-point
numbers, then the result will also be a floating-point number.

.. versionadded:: 4.0.0

Parameters
----------
min : :class:`~pyspark.sql.Column`, int, or float
Minimum value in the range.
max : :class:`~pyspark.sql.Column`, int, or float
Maximum value in the range.
seed : :class:`~pyspark.sql.Column` or int
Optional random number seed to use.

Returns
-------
:class:`~pyspark.sql.Column`
The generated random number within the specified range.

Examples
--------
>>> from pyspark.sql import functions as sf
>>> df = spark.createDataFrame([('3',)], ['a'])
>>> df.select(sf.uniform(sf.lit(0), sf.lit(10), sf.lit(0)).alias('result')).show()
+------+
|result|
+------+
| 7|
+------+

Review comment (Member): ditto (no empty line at the end of the docstring)

Reply (Contributor Author): Sounds good, this is done.

"""
if seed is None:
return _invoke_function_over_columns("uniform", min, max)
else:
return _invoke_function_over_columns("uniform", min, max, seed)


@_try_remote_functions
def length(col: "ColumnOrName") -> Column:
"""Computes the character length of string data or number of bytes of binary data.
21 changes: 20 additions & 1 deletion python/pyspark/sql/tests/test_functions.py
@@ -29,7 +29,7 @@
from pyspark.sql import Row, Window, functions as F, types
from pyspark.sql.avro.functions import from_avro, to_avro
from pyspark.sql.column import Column
from pyspark.sql.functions.builtin import nullifzero, zeroifnull
from pyspark.sql.functions.builtin import nullifzero, randstr, uniform, zeroifnull
from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils
from pyspark.testing.utils import have_numpy

@@ -1610,6 +1610,25 @@ def test_nullifzero_zeroifnull(self):
result = df.select(zeroifnull(df.a).alias("r")).collect()
self.assertEqual([Row(r=0), Row(r=1)], result)

def test_randstr_uniform(self):
df = self.spark.createDataFrame([(0,)], ["a"])
result = df.select(randstr(F.lit(5), F.lit(0)).alias("x")).selectExpr("length(x)").collect()
self.assertEqual([Row(5)], result)
# The random seed is optional.
result = df.select(randstr(F.lit(5)).alias("x")).selectExpr("length(x)").collect()
self.assertEqual([Row(5)], result)

df = self.spark.createDataFrame([(0,)], ["a"])
result = (
df.select(uniform(F.lit(10), F.lit(20), F.lit(0)).alias("x"))
.selectExpr("x > 5")
.collect()
)
self.assertEqual([Row(True)], result)
# The random seed is optional.
result = df.select(uniform(F.lit(10), F.lit(20)).alias("x")).selectExpr("x > 5").collect()
self.assertEqual([Row(True)], result)


class FunctionsTests(ReusedSQLTestCase, FunctionsTestsMixin):
pass
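To exercise just the new test case, something like the following should work (hedged; the `--testnames` form follows the Spark developer docs, and exact flags may vary by branch):

python/run-tests --testnames 'pyspark.sql.tests.test_functions FunctionsTests.test_randstr_uniform'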
45 changes: 45 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1884,6 +1884,26 @@ object functions {
*/
def randn(): Column = randn(SparkClassUtils.random.nextLong)

/**
* Returns a string of the specified length whose characters are chosen uniformly at random from
* the following pool of characters: 0-9, a-z, A-Z. The string length must be a constant
* two-byte or four-byte integer (SMALLINT or INT, respectively).
*
* @group string_funcs
* @since 4.0.0
*/
def randstr(length: Column): Column = Column.fn("randstr", length)

/**
* Returns a string of the specified length whose characters are chosen uniformly at random from
* the following pool of characters: 0-9, a-z, A-Z, with the chosen random seed. The string
* length must be a constant two-byte or four-byte integer (SMALLINT or INT, respectively).
*
* @group string_funcs
* @since 4.0.0
*/
def randstr(length: Column, seed: Column): Column = Column.fn("randstr", length, seed)
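As a quick cross-check of the behavior documented above, a hedged doctest-style sketch from the Python side (assumes an active `spark` session; only the length is deterministic enough to assert here):

>>> from pyspark.sql import functions as sf
>>> spark.range(1).select(sf.length(sf.randstr(sf.lit(5), sf.lit(0))).alias('n')).show()
+---+
|  n|
+---+
|  5|
+---+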

/**
* Partition ID.
*
@@ -3728,6 +3748,31 @@
*/
def stack(cols: Column*): Column = Column.fn("stack", cols: _*)

/**
* Returns a random value, independent and identically distributed (i.i.d.), drawn uniformly
* from within the specified range of numbers. The numbers specifying the minimum and maximum
* values of the range must be constants. If both are integers, the result is also an integer;
* otherwise, if one or both are floating-point numbers, the result is also a floating-point
* number.
*
* @group math_funcs
* @since 4.0.0
*/
def uniform(min: Column, max: Column): Column = Column.fn("uniform", min, max)

/**
* Returns a random value, independent and identically distributed (i.i.d.), drawn uniformly
* from within the specified range of numbers, using the given random seed. The numbers
* specifying the minimum and maximum values of the range must be constants. If both are
* integers, the result is also an integer; otherwise, if one or both are floating-point
* numbers, the result is also a floating-point number.
*
* @group math_funcs
* @since 4.0.0
*/
def uniform(min: Column, max: Column, seed: Column): Column =
Column.fn("uniform", min, max, seed)

/**
* Returns a random value with independent and identically distributed (i.i.d.) uniformly
* distributed values in [0, 1).
@@ -205,15 +205,18 @@ object Randn {
""",
since = "4.0.0",
group = "math_funcs")
case class Uniform(min: Expression, max: Expression, seedExpression: Expression)
case class Uniform(min: Expression, max: Expression, seedExpression: Expression, hideSeed: Boolean)
extends RuntimeReplaceable with TernaryLike[Expression] with RDG {
def this(min: Expression, max: Expression) = this(min, max, UnresolvedSeed)
def this(min: Expression, max: Expression) =
this(min, max, UnresolvedSeed, hideSeed = true)
def this(min: Expression, max: Expression, seedExpression: Expression) =
this(min, max, seedExpression, hideSeed = false)

final override lazy val deterministic: Boolean = false
override val nodePatterns: Seq[TreePattern] =
Seq(RUNTIME_REPLACEABLE, EXPRESSION_WITH_RANDOM_SEED)

override val dataType: DataType = {
override def dataType: DataType = {
val first = min.dataType
val second = max.dataType
(min.dataType, max.dataType) match {
@@ -239,6 +242,10 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression)
case _ => false
}

override def sql: String = {
s"uniform(${min.sql}, ${max.sql}${if (hideSeed) "" else s", ${seedExpression.sql}"})"
}
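The new `hideSeed` flag records whether the caller supplied the seed explicitly, so the regenerated SQL omits an analyzer-assigned seed instead of leaking it. A hedged sketch of the observable effect from PySpark (column names derive from the generated SQL; the exact names shown are an assumption):

>>> from pyspark.sql import functions as sf
>>> spark.range(1).select(sf.uniform(sf.lit(0), sf.lit(10))).columns
['uniform(0, 10)']
>>> spark.range(1).select(sf.uniform(sf.lit(0), sf.lit(10), sf.lit(42))).columns
['uniform(0, 10, 42)']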

override def checkInputDataTypes(): TypeCheckResult = {
var result: TypeCheckResult = TypeCheckResult.TypeCheckSuccess
def requiredType = "integer or floating-point"
@@ -276,11 +283,11 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression)
override def third: Expression = seedExpression

override def withNewSeed(newSeed: Long): Expression =
Uniform(min, max, Literal(newSeed, LongType))
Uniform(min, max, Literal(newSeed, LongType), hideSeed)

override def withNewChildrenInternal(
newFirst: Expression, newSecond: Expression, newThird: Expression): Expression =
Uniform(newFirst, newSecond, newThird)
Uniform(newFirst, newSecond, newThird, hideSeed)

override def replacement: Expression = {
if (Seq(min, max, seedExpression).exists(_.dataType == NullType)) {
@@ -299,6 +306,13 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression)
}
}

object Uniform {
def apply(min: Expression, max: Expression): Uniform =
Uniform(min, max, UnresolvedSeed, hideSeed = true)
def apply(min: Expression, max: Expression, seedExpression: Expression): Uniform =
Uniform(min, max, seedExpression, hideSeed = false)
}

@ExpressionDescription(
usage = """
_FUNC_(length[, seed]) - Returns a string of the specified length whose characters are chosen
@@ -314,9 +328,13 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression)
""",
since = "4.0.0",
group = "string_funcs")
case class RandStr(length: Expression, override val seedExpression: Expression)
case class RandStr(
length: Expression, override val seedExpression: Expression, hideSeed: Boolean)
extends ExpressionWithRandomSeed with BinaryLike[Expression] with Nondeterministic {
def this(length: Expression) = this(length, UnresolvedSeed)
def this(length: Expression) =
this(length, UnresolvedSeed, hideSeed = true)
def this(length: Expression, seedExpression: Expression) =
this(length, seedExpression, hideSeed = false)

override def nullable: Boolean = false
override def dataType: DataType = StringType
@@ -338,9 +356,14 @@ case class RandStr(length: Expression, override val seedExpression: Expression)
rng = new XORShiftRandom(seed + partitionIndex)
}

override def withNewSeed(newSeed: Long): Expression = RandStr(length, Literal(newSeed, LongType))
override def withNewSeed(newSeed: Long): Expression =
RandStr(length, Literal(newSeed, LongType), hideSeed)
override def withNewChildrenInternal(newFirst: Expression, newSecond: Expression): Expression =
RandStr(newFirst, newSecond)
RandStr(newFirst, newSecond, hideSeed)

override def sql: String = {
s"randstr(${length.sql}${if (hideSeed) "" else s", ${seedExpression.sql}"})"
}
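`RandStr.sql` applies the same seed-hiding rule; a hedged one-liner (again, the exact column name is an assumption):

>>> spark.range(1).select(sf.randstr(sf.lit(5))).columns
['randstr(5)']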

override def checkInputDataTypes(): TypeCheckResult = {
var result: TypeCheckResult = TypeCheckResult.TypeCheckSuccess
@@ -421,3 +444,11 @@ case class RandStr(length: Expression, override val seedExpression: Expression)
isNull = FalseLiteral)
}
}

object RandStr {
def apply(length: Expression): RandStr =
RandStr(length, UnresolvedSeed, hideSeed = true)
def apply(length: Expression, seedExpression: Expression): RandStr =
RandStr(length, seedExpression, hideSeed = false)
}
