Skip to content

Commit

Permalink
[SPARK-49306][PYTHON][SQL] Create SQL function aliases for 'zeroifnul…
Browse files Browse the repository at this point in the history
…l' and 'nullifzero'

### What changes were proposed in this pull request?

In #47817 we added new SQL functions `zeroifnull` and `nullifzero`.

In this PR we add Scala and Python DataFrame API endpoints for them.

For example, in Scala:

```
var df = Seq((0)).toDF("a")
df.selectExpr("nullifzero(0)").collect()
> null
df.select(nullifzero(lit(0))).collect()
> null

df.selectExpr("nullifzero(a)").collect()
> null
df.select(nullifzero(lit(5))).collect()
> 5

df = Seq[(Integer)]((null)).toDF("a")
df.selectExpr("zeroifnull(null)").collect()
> 5
df.select(nullifzero(lit(null))).collect()
> 0

df.selectExpr("zeroifnull(a)").collect()
> 0
df.select(zeroifnull(lit(5)))
> 5
```

### Why are the changes needed?

This improves DataFrame parity with the SQL API.

### Does this PR introduce _any_ user-facing change?

Yes, see above.

### How was this patch tested?

This PR adds unit test coverage.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #47851 from dtenedor/dataframe-zeroifnull.

Authored-by: Daniel Tenedorio <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
  • Loading branch information
dtenedor authored and MaxGekk committed Aug 28, 2024
1 parent 620f16e commit a3cb064
Show file tree
Hide file tree
Showing 7 changed files with 218 additions and 4 deletions.
6 changes: 4 additions & 2 deletions core/src/test/scala/org/apache/spark/SparkFunSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,10 @@ abstract class SparkFunSuite
} else if (actual.contextType() == QueryContextType.DataFrame) {
assert(actual.fragment() === expected.fragment,
"Invalid code fragment of a query context. Actual:" + actual.toString)
assert(actual.callSite().matches(expected.callSitePattern),
"Invalid callSite of a query context. Actual:" + actual.toString)
if (expected.callSitePattern.nonEmpty) {
assert(actual.callSite().matches(expected.callSitePattern),
"Invalid callSite of a query context. Actual:" + actual.toString)
}
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions python/docs/source/reference/pyspark.sql/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,11 @@ Conditional Functions
ifnull
nanvl
nullif
nullifzero
nvl
nvl2
when
zeroifnull


Predicate Functions
Expand Down
14 changes: 14 additions & 0 deletions python/pyspark/sql/connect/functions/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3921,6 +3921,13 @@ def nullif(col1: "ColumnOrName", col2: "ColumnOrName") -> Column:
nullif.__doc__ = pysparkfuncs.nullif.__doc__


def nullifzero(col: "ColumnOrName") -> Column:
return _invoke_function_over_columns("nullifzero", col)


nullifzero.__doc__ = pysparkfuncs.nullifzero.__doc__


def nvl(col1: "ColumnOrName", col2: "ColumnOrName") -> Column:
return _invoke_function_over_columns("nvl", col1, col2)

Expand All @@ -3935,6 +3942,13 @@ def nvl2(col1: "ColumnOrName", col2: "ColumnOrName", col3: "ColumnOrName") -> Co
nvl2.__doc__ = pysparkfuncs.nvl2.__doc__


def zeroifnull(col: "ColumnOrName") -> Column:
return _invoke_function_over_columns("zeroifnull", col)


zeroifnull.__doc__ = pysparkfuncs.zeroifnull.__doc__


def aes_encrypt(
input: "ColumnOrName",
key: "ColumnOrName",
Expand Down
50 changes: 50 additions & 0 deletions python/pyspark/sql/functions/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -20681,6 +20681,31 @@ def nullif(col1: "ColumnOrName", col2: "ColumnOrName") -> Column:
return _invoke_function_over_columns("nullif", col1, col2)


@_try_remote_functions
def nullifzero(col: "ColumnOrName") -> Column:
"""
Returns null if `col` is equal to zero, or `col` otherwise.

.. versionadded:: 4.0.0

Parameters
----------
col : :class:`~pyspark.sql.Column` or str

Examples
--------
>>> df = spark.createDataFrame([(0,), (1,)], ["a"])
>>> df.select(nullifzero(df.a).alias("result")).show()
+------+
|result|
+------+
| NULL|
| 1|
+------+
"""
return _invoke_function_over_columns("nullifzero", col)


@_try_remote_functions
def nvl(col1: "ColumnOrName", col2: "ColumnOrName") -> Column:
"""
Expand Down Expand Up @@ -20724,6 +20749,31 @@ def nvl2(col1: "ColumnOrName", col2: "ColumnOrName", col3: "ColumnOrName") -> Co
return _invoke_function_over_columns("nvl2", col1, col2, col3)


@_try_remote_functions
def zeroifnull(col: "ColumnOrName") -> Column:
"""
Returns zero if `col` is null, or `col` otherwise.

.. versionadded:: 4.0.0

Parameters
----------
col : :class:`~pyspark.sql.Column` or str

Examples
--------
>>> df = spark.createDataFrame([(None,), (1,)], ["a"])
>>> df.select(zeroifnull(df.a).alias("result")).show()
+------+
|result|
+------+
| 0|
| 1|
+------+
"""
return _invoke_function_over_columns("zeroifnull", col)


@_try_remote_functions
def aes_encrypt(
input: "ColumnOrName",
Expand Down
10 changes: 10 additions & 0 deletions python/pyspark/sql/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from pyspark.sql import Row, Window, functions as F, types
from pyspark.sql.avro.functions import from_avro, to_avro
from pyspark.sql.column import Column
from pyspark.sql.functions.builtin import nullifzero, zeroifnull
from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils
from pyspark.testing.utils import have_numpy

Expand Down Expand Up @@ -1593,6 +1594,15 @@ class IntEnum(Enum):
for r, c, e in zip(result, cols, expected):
self.assertEqual(r, e, str(c))

def test_nullifzero_zeroifnull(self):
df = self.spark.createDataFrame([(0,), (1,)], ["a"])
result = df.select(nullifzero(df.a).alias("r")).collect()
self.assertEqual([Row(r=None), Row(r=1)], result)

df = self.spark.createDataFrame([(None,), (1,)], ["a"])
result = df.select(zeroifnull(df.a).alias("r")).collect()
self.assertEqual([Row(r=0), Row(r=1)], result)


class FunctionsTests(ReusedSQLTestCase, FunctionsTestsMixin):
pass
Expand Down
16 changes: 16 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7834,6 +7834,14 @@ object functions {
*/
def nullif(col1: Column, col2: Column): Column = Column.fn("nullif", col1, col2)

/**
* Returns null if `col` is equal to zero, or `col` otherwise.
*
* @group conditional_funcs
* @since 4.0.0
*/
def nullifzero(col: Column): Column = Column.fn("nullifzero", col)

/**
* Returns `col2` if `col1` is null, or `col1` otherwise.
*
Expand All @@ -7850,6 +7858,14 @@ object functions {
*/
def nvl2(col1: Column, col2: Column, col3: Column): Column = Column.fn("nvl2", col1, col2, col3)

/**
* Returns zero if `col` is null, or `col` otherwise.
*
* @group conditional_funcs
* @since 4.0.0
*/
def zeroifnull(col: Column): Column = Column.fn("zeroifnull", col)

// scalastyle:off line.size.limit
// scalastyle:off parameter.number

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ import java.sql.{Date, Timestamp}

import scala.util.Random

import org.apache.spark.{SPARK_DOC_ROOT, SparkException, SparkRuntimeException}
import org.apache.spark.{QueryContextType, SPARK_DOC_ROOT, SparkException, SparkRuntimeException}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.Cast._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
Expand Down Expand Up @@ -331,6 +331,66 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
checkAnswer(df.select(nullif(lit(5), lit(5))), Seq(Row(null)))
}

test("nullifzero function") {
withTable("t") {
// Here we exercise a non-nullable, non-foldable column.
sql("create table t(col int not null) using csv")
sql("insert into t values (0)")
val df = sql("select col from t")
checkAnswer(df.select(nullifzero($"col")), Seq(Row(null)))
}
// Here we exercise invalid cases including types that do not support ordering.
val df = Seq((0)).toDF("a")
var expr = nullifzero(map(lit(1), lit("a")))
checkError(
intercept[AnalysisException](df.select(expr)),
errorClass = "DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES",
parameters = Map(
"left" -> "\"MAP<INT, STRING>\"",
"right" -> "\"INT\"",
"sqlExpr" -> "\"(map(1, a) = 0)\""),
context = ExpectedContext(
contextType = QueryContextType.DataFrame,
fragment = "nullifzero",
objectType = "",
objectName = "",
callSitePattern = "",
startIndex = 0,
stopIndex = 0))
expr = nullifzero(array(lit(1), lit(2)))
checkError(
intercept[AnalysisException](df.select(expr)),
errorClass = "DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES",
parameters = Map(
"left" -> "\"ARRAY<INT>\"",
"right" -> "\"INT\"",
"sqlExpr" -> "\"(array(1, 2) = 0)\""),
context = ExpectedContext(
contextType = QueryContextType.DataFrame,
fragment = "nullifzero",
objectType = "",
objectName = "",
callSitePattern = "",
startIndex = 0,
stopIndex = 0))
expr = nullifzero(Literal.create(20201231, DateType))
checkError(
intercept[AnalysisException](df.select(expr)),
errorClass = "DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES",
parameters = Map(
"left" -> "\"DATE\"",
"right" -> "\"INT\"",
"sqlExpr" -> "\"(DATE '+57279-02-03' = 0)\""),
context = ExpectedContext(
contextType = QueryContextType.DataFrame,
fragment = "nullifzero",
objectType = "",
objectName = "",
callSitePattern = "",
startIndex = 0,
stopIndex = 0))
}

test("nvl") {
val df = Seq[(Integer, Integer)]((null, 8)).toDF("a", "b")
checkAnswer(df.selectExpr("nvl(a, b)"), Seq(Row(8)))
Expand All @@ -349,6 +409,66 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
checkAnswer(df.select(nvl2(col("b"), col("a"), col("c"))), Seq(Row(null)))
}

test("zeroifnull function") {
withTable("t") {
// Here we exercise a non-nullable, non-foldable column.
sql("create table t(col int not null) using csv")
sql("insert into t values (0)")
val df = sql("select col from t")
checkAnswer(df.select(zeroifnull($"col")), Seq(Row(0)))
}
// Here we exercise invalid cases including types that do not support ordering.
val df = Seq((0)).toDF("a")
var expr = zeroifnull(map(lit(1), lit("a")))
checkError(
intercept[AnalysisException](df.select(expr)),
errorClass = "DATATYPE_MISMATCH.DATA_DIFF_TYPES",
parameters = Map(
"functionName" -> "`coalesce`",
"dataType" -> "(\"MAP<INT, STRING>\" or \"INT\")",
"sqlExpr" -> "\"coalesce(map(1, a), 0)\""),
context = ExpectedContext(
contextType = QueryContextType.DataFrame,
fragment = "zeroifnull",
objectType = "",
objectName = "",
callSitePattern = "",
startIndex = 0,
stopIndex = 0))
expr = zeroifnull(array(lit(1), lit(2)))
checkError(
intercept[AnalysisException](df.select(expr)),
errorClass = "DATATYPE_MISMATCH.DATA_DIFF_TYPES",
parameters = Map(
"functionName" -> "`coalesce`",
"dataType" -> "(\"ARRAY<INT>\" or \"INT\")",
"sqlExpr" -> "\"coalesce(array(1, 2), 0)\""),
context = ExpectedContext(
contextType = QueryContextType.DataFrame,
fragment = "zeroifnull",
objectType = "",
objectName = "",
callSitePattern = "",
startIndex = 0,
stopIndex = 0))
expr = zeroifnull(Literal.create(20201231, DateType))
checkError(
intercept[AnalysisException](df.select(expr)),
errorClass = "DATATYPE_MISMATCH.DATA_DIFF_TYPES",
parameters = Map(
"functionName" -> "`coalesce`",
"dataType" -> "(\"DATE\" or \"INT\")",
"sqlExpr" -> "\"coalesce(DATE '+57279-02-03', 0)\""),
context = ExpectedContext(
contextType = QueryContextType.DataFrame,
fragment = "zeroifnull",
objectType = "",
objectName = "",
callSitePattern = "",
startIndex = 0,
stopIndex = 0))
}

test("misc md5 function") {
val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b")
checkAnswer(
Expand Down

0 comments on commit a3cb064

Please sign in to comment.