Skip to content

Commit 6f50c79

Browse files
Improved the logs message regarding lambda function with SQL UDF
1 parent 1e6d743 commit 6f50c79

File tree

3 files changed

+65
-2
lines changed

3 files changed

+65
-2
lines changed

common/utils/src/main/resources/error/error-conditions.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6747,6 +6747,11 @@
67476747
"Lambda function with Python UDF <funcName> in a higher order function."
67486748
]
67496749
},
6750+
"LAMBDA_FUNCTION_WITH_SQL_UDF" : {
6751+
"message" : [
6752+
"Lambda function with SQL UDF <funcName> in a higher order function."
6753+
]
6754+
},
67506755
"LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC" : {
67516756
"message" : [
67526757
"Referencing a lateral column alias <lca> in the aggregate function <aggFunc>."

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.analysis._
3838
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
3939
import org.apache.spark.sql.catalyst.analysis.TableFunctionRegistry.TableFunctionBuilder
4040
import org.apache.spark.sql.catalyst.catalog.SQLFunction.parseDefault
41-
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Cast, Expression, ExpressionInfo, LateralSubquery, NamedArgumentExpression, NamedExpression, OuterReference, ScalarSubquery, UpCast}
41+
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Cast, Expression, ExpressionInfo, LateralSubquery, NamedArgumentExpression, NamedExpression, NamedLambdaVariable, OuterReference, ScalarSubquery, UnresolvedNamedLambdaVariable, UpCast}
4242
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface}
4343
import org.apache.spark.sql.catalyst.plans.Inner
4444
import org.apache.spark.sql.catalyst.plans.logical.{FunctionSignature, InputParameter, LateralJoin, LogicalPlan, NamedParametersSupport, OneRowRelation, Project, SubqueryAlias, View}
@@ -1633,6 +1633,24 @@ class SessionCatalog(
16331633
throw UserDefinedFunctionErrors.notAScalarFunction(function.name.nameParts)
16341634
}
16351635
(input: Seq[Expression]) => {
1636+
// Check if any input contains a lambda variable - SQL functions cannot be used
1637+
// inside lambda functions in higher-order functions
1638+
val lambdaVar = input.flatMap(_.find {
1639+
case _: NamedLambdaVariable | _: UnresolvedNamedLambdaVariable => true
1640+
case _ => false
1641+
}).headOption
1642+
if (lambdaVar.isDefined) {
1643+
// Format input expressions, replacing lambda variables with "lambda <name>" format
1644+
val formattedInputs = input.map {
1645+
case v: NamedLambdaVariable => s"lambda ${v.name}"
1646+
case v: UnresolvedNamedLambdaVariable => s"lambda ${v.name}"
1647+
case e => e.sql
1648+
}.mkString(", ")
1649+
throw new AnalysisException(
1650+
errorClass = "UNSUPPORTED_FEATURE.LAMBDA_FUNCTION_WITH_SQL_UDF",
1651+
messageParameters = Map(
1652+
"funcName" -> s"\"${function.name.unquotedString}($formattedInputs)\""))
1653+
}
16361654
val args = rearrangeArguments(function.inputParam, input, function.name.toString)
16371655
val returnType = function.getScalarFuncReturnType
16381656
SQLFunctionExpression(
@@ -1712,6 +1730,25 @@ class SessionCatalog(
17121730
name, paramSize.toString, input.size)
17131731
}
17141732

1733+
// Check if any input contains a lambda variable - SQL functions cannot be used
1734+
// inside lambda functions in higher-order functions
1735+
val lambdaVar = input.flatMap(_.find {
1736+
case _: NamedLambdaVariable | _: UnresolvedNamedLambdaVariable => true
1737+
case _ => false
1738+
}).headOption
1739+
if (lambdaVar.isDefined) {
1740+
// Format input expressions, replacing lambda variables with "lambda <name>" format
1741+
val formattedInputs = input.map {
1742+
case v: NamedLambdaVariable => s"lambda ${v.name}"
1743+
case v: UnresolvedNamedLambdaVariable => s"lambda ${v.name}"
1744+
case e => e.sql
1745+
}.mkString(", ")
1746+
throw new AnalysisException(
1747+
errorClass = "UNSUPPORTED_FEATURE.LAMBDA_FUNCTION_WITH_SQL_UDF",
1748+
messageParameters = Map(
1749+
"funcName" -> s"\"${function.name.unquotedString}($formattedInputs)\""))
1750+
}
1751+
17151752
val inputs = inputParam.map { param =>
17161753
// Attributes referencing the input parameters inside the function can use the
17171754
// function name as a qualifier. E.G.:

sql/core/src/test/scala/org/apache/spark/sql/execution/SQLFunctionSuite.scala

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.sql.execution
1919

20-
import org.apache.spark.sql.{QueryTest, Row}
20+
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
2121
import org.apache.spark.sql.test.SharedSparkSession
2222

2323
/**
@@ -87,4 +87,25 @@ class SQLFunctionSuite extends QueryTest with SharedSparkSession {
8787
checkAnswer(sql("SELECT bar(1)"), Row(2))
8888
}
8989
}
90+
91+
test("SQL UDF in higher-order function should fail with clear error message") {
92+
withUserDefinedFunction("test_lower_udf" -> false) {
93+
sql(
94+
"""
95+
|CREATE FUNCTION test_lower_udf(s STRING)
96+
|RETURNS STRING
97+
|RETURN lower(s)
98+
|""".stripMargin)
99+
val exception = intercept[AnalysisException] {
100+
sql("SELECT transform(array('A', 'B', 'C'), x -> test_lower_udf(x))").collect()
101+
}
102+
assert(exception.getCondition == "UNSUPPORTED_FEATURE.LAMBDA_FUNCTION_WITH_SQL_UDF",
103+
s"Expected UNSUPPORTED_FEATURE.LAMBDA_FUNCTION_WITH_SQL_UDF " +
104+
s"but got ${exception.getCondition}: ${exception.getMessage}")
105+
assert(exception.getMessage.contains("test_lower_udf"),
106+
s"Error message should contain function name: ${exception.getMessage}")
107+
assert(exception.getMessage.contains("lambda x"),
108+
s"Error message should contain 'lambda x': ${exception.getMessage}")
109+
}
110+
}
90111
}

0 commit comments

Comments
 (0)