Skip to content

Commit 5d61f52

Browse files
[SPARK-54776][SQL] Improve error message for SQL UDF in higher-order functions
1 parent 20af8bd commit 5d61f52

File tree

3 files changed

+68
-1
lines changed

3 files changed

+68
-1
lines changed

common/utils/src/main/resources/error/error-conditions.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6747,6 +6747,11 @@
67476747
"Lambda function with Python UDF <funcName> in a higher order function."
67486748
]
67496749
},
6750+
"LAMBDA_FUNCTION_WITH_SQL_UDF" : {
6751+
"message" : [
6752+
"Lambda function with SQL UDF <funcName> in a higher order function."
6753+
]
6754+
},
67506755
"LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC" : {
67516756
"message" : [
67526757
"Referencing a lateral column alias <lca> in the aggregate function <aggFunc>."

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
3939
import org.apache.spark.sql.catalyst.analysis.TableFunctionRegistry.TableFunctionBuilder
4040
import org.apache.spark.sql.catalyst.catalog.SQLFunction.parseDefault
4141
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Cast, Expression, ExpressionInfo, LateralSubquery, NamedArgumentExpression, NamedExpression, OuterReference, ScalarSubquery, UpCast}
42+
import org.apache.spark.sql.catalyst.expressions.NamedLambdaVariable
43+
import org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable
4244
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface}
4345
import org.apache.spark.sql.catalyst.plans.Inner
4446
import org.apache.spark.sql.catalyst.plans.logical.{FunctionSignature, InputParameter, LateralJoin, LogicalPlan, NamedParametersSupport, OneRowRelation, Project, SubqueryAlias, View}
@@ -1633,6 +1635,25 @@ class SessionCatalog(
16331635
throw UserDefinedFunctionErrors.notAScalarFunction(function.name.nameParts)
16341636
}
16351637
(input: Seq[Expression]) => {
1638+
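// SQL UDFs cannot be applied to lambda variables of a higher-order function: look for a
// resolved or unresolved lambda variable anywhere in the arguments so that we can fail
// with a dedicated LAMBDA_FUNCTION_WITH_SQL_UDF error.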
1639+
val hasLambdaVar = input.exists { expr =>
1640+
expr.find {
1641+
case _: NamedLambdaVariable => true
1642+
case _: UnresolvedNamedLambdaVariable => true
1643+
case _ => false
1644+
}.isDefined
1645+
}
1646+
if (hasLambdaVar) {
1647+
val formattedInputs = input.map {
1648+
case v: NamedLambdaVariable => "lambda " + v.name
1649+
case v: UnresolvedNamedLambdaVariable => "lambda " + v.name
1650+
case e => e.sql
1651+
}.mkString(", ")
1652+
throw new AnalysisException(
1653+
errorClass = "UNSUPPORTED_FEATURE.LAMBDA_FUNCTION_WITH_SQL_UDF",
1654+
messageParameters = Map(
1655+
"funcName" -> ("\"" + function.name.unquotedString + "(" + formattedInputs + ")\"")))
1656+
}
16361657
val args = rearrangeArguments(function.inputParam, input, function.name.toString)
16371658
val returnType = function.getScalarFuncReturnType
16381659
SQLFunctionExpression(
@@ -1712,6 +1733,26 @@ class SessionCatalog(
17121733
name, paramSize.toString, input.size)
17131734
}
17141735

1736+
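// SQL UDFs cannot be applied to lambda variables of a higher-order function: look for a
// resolved or unresolved lambda variable anywhere in the arguments so that we can fail
// with a dedicated LAMBDA_FUNCTION_WITH_SQL_UDF error.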
1737+
val hasLambdaVar = input.exists { expr =>
1738+
expr.find {
1739+
case _: NamedLambdaVariable => true
1740+
case _: UnresolvedNamedLambdaVariable => true
1741+
case _ => false
1742+
}.isDefined
1743+
}
1744+
if (hasLambdaVar) {
1745+
val formattedInputs = input.map {
1746+
case v: NamedLambdaVariable => "lambda " + v.name
1747+
case v: UnresolvedNamedLambdaVariable => "lambda " + v.name
1748+
case e => e.sql
1749+
}.mkString(", ")
1750+
throw new AnalysisException(
1751+
errorClass = "UNSUPPORTED_FEATURE.LAMBDA_FUNCTION_WITH_SQL_UDF",
1752+
messageParameters = Map(
1753+
"funcName" -> ("\"" + function.name.unquotedString + "(" + formattedInputs + ")\"")))
1754+
}
1755+
17151756
val inputs = inputParam.map { param =>
17161757
// Attributes referencing the input parameters inside the function can use the
17171758
// function name as a qualifier. E.G.:

sql/core/src/test/scala/org/apache/spark/sql/execution/SQLFunctionSuite.scala

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.sql.execution
1919

20-
import org.apache.spark.sql.{QueryTest, Row}
20+
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
2121
import org.apache.spark.sql.test.SharedSparkSession
2222

2323
/**
@@ -87,4 +87,25 @@ class SQLFunctionSuite extends QueryTest with SharedSparkSession {
8787
checkAnswer(sql("SELECT bar(1)"), Row(2))
8888
}
8989
}
90+
91+
test("SQL UDF in higher-order function should fail with clear error message") {
  // Calling a SQL UDF on a lambda variable inside a higher-order function is expected to
  // raise an AnalysisException carrying the dedicated error condition, with a message that
  // names both the function and the lambda argument it was applied to.
  val createFunctionDdl =
    """
      |CREATE FUNCTION test_lower_udf(s STRING)
      |RETURNS STRING
      |RETURN lower(s)
      |""".stripMargin
  val higherOrderQuery = "SELECT transform(array('A', 'B', 'C'), x -> test_lower_udf(x))"

  withUserDefinedFunction("test_lower_udf" -> false) {
    sql(createFunctionDdl)

    val error = intercept[AnalysisException] {
      sql(higherOrderQuery).collect()
    }

    // The error condition must be the SQL-UDF-specific one.
    assert(error.getCondition == "UNSUPPORTED_FEATURE.LAMBDA_FUNCTION_WITH_SQL_UDF",
      s"Expected UNSUPPORTED_FEATURE.LAMBDA_FUNCTION_WITH_SQL_UDF " +
        s"but got ${error.getCondition}: ${error.getMessage}")

    // The rendered message must identify the UDF and the lambda variable passed to it.
    assert(error.getMessage.contains("test_lower_udf"),
      s"Error message should contain function name: ${error.getMessage}")
    assert(error.getMessage.contains("lambda x"),
      s"Error message should contain 'lambda x': ${error.getMessage}")
  }
}
90111
}

0 commit comments

Comments
 (0)