Commit 6e408c2

[SPARK-51395][SQL] Refine handling of default values in procedures
1 parent eae5ca7 commit 6e408c2

8 files changed: +809 −21 lines changed
sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DefaultValue.java

Lines changed: 95 additions & 0 deletions
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.connector.catalog;

import java.util.Map;
import java.util.Objects;
import javax.annotation.Nullable;

import org.apache.spark.SparkIllegalArgumentException;
import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.Expression;

/**
 * A class that represents default values.
 * <p>
 * Connectors can define default values using either a SQL string (Spark SQL dialect) or an
 * {@link Expression expression} if the default value can be expressed as a supported connector
 * expression. If both the SQL string and the expression are provided, Spark first attempts to
 * convert the given expression to its internal representation. If the expression cannot be
 * converted, and a SQL string is provided, Spark will fall back to parsing the SQL string.
 *
 * @since 4.1.0
 */
@Evolving
public class DefaultValue {
  private final String sql;
  private final Expression expr;

  public DefaultValue(String sql) {
    this(sql, null /* no expression */);
  }

  public DefaultValue(Expression expr) {
    this(null /* no sql */, expr);
  }

  public DefaultValue(String sql, Expression expr) {
    if (sql == null && expr == null) {
      throw new SparkIllegalArgumentException(
          "INTERNAL_ERROR",
          Map.of("message", "SQL and expression can't be both null"));
    }
    this.sql = sql;
    this.expr = expr;
  }

  /**
   * Returns the SQL representation of the default value (Spark SQL dialect), if provided.
   */
  @Nullable
  public String getSql() {
    return sql;
  }

  /**
   * Returns the expression representing the default value, if provided.
   */
  @Nullable
  public Expression getExpression() {
    return expr;
  }

  @Override
  public boolean equals(Object other) {
    if (this == other) return true;
    if (other == null || getClass() != other.getClass()) return false;
    DefaultValue that = (DefaultValue) other;
    return Objects.equals(sql, that.sql) && Objects.equals(expr, that.expr);
  }

  @Override
  public int hashCode() {
    return Objects.hash(sql, expr);
  }

  @Override
  public String toString() {
    return String.format("DefaultValue{sql=%s, expression=%s}", sql, expr);
  }
}
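
As a sketch of the intended usage, a connector might construct these values as follows. The arithmetic default is hypothetical; the example assumes the existing LiteralValue and GeneralScalarExpression classes from org.apache.spark.sql.connector.expressions:

import org.apache.spark.sql.connector.catalog.DefaultValue
import org.apache.spark.sql.connector.expressions.{Expression, GeneralScalarExpression, LiteralValue}
import org.apache.spark.sql.types.IntegerType

// Default expressed purely as a Spark SQL string.
val sqlOnly = new DefaultValue("100 + 23")

// The same default as a connector expression tree: 100 + 23.
val addExpr: Expression = new GeneralScalarExpression(
  "+",
  Array[Expression](LiteralValue(100, IntegerType), LiteralValue(23, IntegerType)))
val exprOnly = new DefaultValue(addExpr)

// Both provided: per the javadoc, Spark tries the expression first and
// falls back to parsing the SQL string if conversion fails.
val both = new DefaultValue("100 + 23", addExpr)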

sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/ProcedureParameter.java

Lines changed: 24 additions & 6 deletions
@@ -20,6 +20,8 @@
 import javax.annotation.Nullable;

 import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.catalog.DefaultValue;
+import org.apache.spark.sql.connector.expressions.Expression;
 import org.apache.spark.sql.internal.connector.ProcedureParameterImpl;
 import org.apache.spark.sql.types.DataType;

@@ -68,7 +70,7 @@ static Builder in(String name, DataType dataType) {
    * null if not provided.
    */
   @Nullable
-  String defaultValueExpression();
+  DefaultValue defaultValue();

   /**
    * Returns the comment of this parameter or null if not provided.
@@ -89,7 +91,7 @@ class Builder {
     private final Mode mode;
     private final String name;
     private final DataType dataType;
-    private String defaultValueExpression;
+    private DefaultValue defaultValue;
     private String comment;

     private Builder(Mode mode, String name, DataType dataType) {
@@ -99,10 +101,26 @@ private Builder(Mode mode, String name, DataType dataType) {
     }

     /**
-     * Sets the default value expression of the parameter.
+     * Sets the default value of the parameter using SQL.
      */
-    public Builder defaultValue(String defaultValueExpression) {
-      this.defaultValueExpression = defaultValueExpression;
+    public Builder defaultValue(String sql) {
+      this.defaultValue = new DefaultValue(sql);
+      return this;
+    }
+
+    /**
+     * Sets the default value of the parameter using an expression.
+     */
+    public Builder defaultValue(Expression expression) {
+      this.defaultValue = new DefaultValue(expression);
+      return this;
+    }
+
+    /**
+     * Sets the default value of the parameter.
+     */
+    public Builder defaultValue(DefaultValue defaultValue) {
+      this.defaultValue = defaultValue;
       return this;
     }

@@ -118,7 +136,7 @@ public Builder comment(String comment) {
      * Builds the stored procedure parameter.
      */
     public ProcedureParameter build() {
-      return new ProcedureParameterImpl(mode, name, dataType, defaultValueExpression, comment);
+      return new ProcedureParameterImpl(mode, name, dataType, defaultValue, comment);
     }
   }
 }
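
For illustration, a connector could declare procedure parameters against the new builder overloads like this; the parameter names and defaults are made up, not part of this commit:

import org.apache.spark.sql.connector.catalog.DefaultValue
import org.apache.spark.sql.connector.catalog.procedures.ProcedureParameter
import org.apache.spark.sql.connector.expressions.LiteralValue
import org.apache.spark.sql.types.IntegerType

// SQL-only default, as supported before this change.
val p1 = ProcedureParameter.in("min_ratio", IntegerType).defaultValue("10").build()

// Connector-expression default, newly possible with this API.
val p2 = ProcedureParameter.in("budget", IntegerType)
  .defaultValue(LiteralValue(100, IntegerType))
  .build()

// Expression default with an explicit SQL fallback string.
val p3 = ProcedureParameter.in("threshold", IntegerType)
  .defaultValue(new DefaultValue("42", LiteralValue(42, IntegerType)))
  .build()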

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala

Lines changed: 170 additions & 2 deletions
@@ -23,14 +23,15 @@ import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys.{FUNCTION_NAME, FUNCTION_PARAM}
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
-import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException
+import org.apache.spark.sql.catalyst.analysis.{NoSuchFunctionException, UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.encoders.EncoderUtils
 import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.connector.catalog.{FunctionCatalog, Identifier}
 import org.apache.spark.sql.connector.catalog.functions._
 import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME
-import org.apache.spark.sql.connector.expressions.{BucketTransform, Expression => V2Expression, FieldReference, IdentityTransform, Literal => V2Literal, NamedReference, NamedTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortOrder => V2SortOrder, SortValue, Transform}
+import org.apache.spark.sql.connector.expressions.{BucketTransform, Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, IdentityTransform, Literal => V2Literal, NamedReference, NamedTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortOrder => V2SortOrder, SortValue, Transform}
+import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue}
 import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.types._
@@ -205,4 +206,171 @@ object V2ExpressionUtils extends SQLConfHelper with Logging {
       None
     }
   }
+
+  def toCatalyst(expr: V2Expression): Option[Expression] = expr match {
+    case _: AlwaysTrue => Some(Literal.TrueLiteral)
+    case _: AlwaysFalse => Some(Literal.FalseLiteral)
+    case l: V2Literal[_] => Some(Literal(l.value, l.dataType))
+    case r: NamedReference => Some(UnresolvedAttribute(r.fieldNames.toImmutableArraySeq))
+    case c: V2Cast => toCatalyst(c.expression).map(Cast(_, c.dataType, ansiEnabled = true))
+    case e: GeneralScalarExpression => convertScalarExpr(e)
+    case _ => None
+  }
+
+  private def convertScalarExpr(expr: GeneralScalarExpression): Option[Expression] = {
+    convertPredicate(expr)
+      .orElse(convertConditionalFunc(expr))
+      .orElse(convertMathFunc(expr))
+      .orElse(convertBitwiseFunc(expr))
+      .orElse(convertTrigonometricFunc(expr))
+  }
+
+  private def convertPredicate(expr: GeneralScalarExpression): Option[Expression] = {
+    expr.name match {
+      case "IS_NULL" => convertUnaryExpr(expr, IsNull)
+      case "IS_NOT_NULL" => convertUnaryExpr(expr, IsNotNull)
+      case "NOT" => convertUnaryExpr(expr, Not)
+      case "=" => convertBinaryExpr(expr, EqualTo)
+      case "<=>" => convertBinaryExpr(expr, EqualNullSafe)
+      case ">" => convertBinaryExpr(expr, GreaterThan)
+      case ">=" => convertBinaryExpr(expr, GreaterThanOrEqual)
+      case "<" => convertBinaryExpr(expr, LessThan)
+      case "<=" => convertBinaryExpr(expr, LessThanOrEqual)
+      case "<>" => convertBinaryExpr(expr, (left, right) => Not(EqualTo(left, right)))
+      case "AND" => convertBinaryExpr(expr, And)
+      case "OR" => convertBinaryExpr(expr, Or)
+      case "STARTS_WITH" => convertBinaryExpr(expr, StartsWith)
+      case "ENDS_WITH" => convertBinaryExpr(expr, EndsWith)
+      case "CONTAINS" => convertBinaryExpr(expr, Contains)
+      case "IN" => convertExpr(expr, children => In(children.head, children.tail))
+      case _ => None
+    }
+  }
+
+  private def convertConditionalFunc(expr: GeneralScalarExpression): Option[Expression] = {
+    expr.name match {
+      case "CASE_WHEN" =>
+        convertExpr(expr, children =>
+          if (children.length % 2 == 0) {
+            val branches = children.grouped(2).map { case Seq(c, v) => (c, v) }.toSeq
+            CaseWhen(branches, None)
+          } else {
+            val (pairs, last) = children.splitAt(children.length - 1)
+            val branches = pairs.grouped(2).map { case Seq(c, v) => (c, v) }.toSeq
+            CaseWhen(branches, Some(last.head))
+          })
+      case _ => None
+    }
+  }
+
+  private def convertMathFunc(expr: GeneralScalarExpression): Option[Expression] = {
+    expr.name match {
+      case "+" => convertBinaryExpr(expr, Add(_, _, evalMode = EvalMode.ANSI))
+      case "-" =>
+        if (expr.children.length == 1) {
+          convertUnaryExpr(expr, UnaryMinus(_, failOnError = true))
+        } else if (expr.children.length == 2) {
+          convertBinaryExpr(expr, Subtract(_, _, evalMode = EvalMode.ANSI))
+        } else {
+          None
+        }
+      case "*" => convertBinaryExpr(expr, Multiply(_, _, evalMode = EvalMode.ANSI))
+      case "/" => convertBinaryExpr(expr, Divide(_, _, evalMode = EvalMode.ANSI))
+      case "%" => convertBinaryExpr(expr, Remainder(_, _, evalMode = EvalMode.ANSI))
+      case "ABS" => convertUnaryExpr(expr, Abs(_, failOnError = true))
+      case "COALESCE" => convertExpr(expr, Coalesce)
+      case "GREATEST" => convertExpr(expr, Greatest)
+      case "LEAST" => convertExpr(expr, Least)
+      case "RAND" =>
+        if (expr.children.isEmpty) {
+          Some(new Rand())
+        } else if (expr.children.length == 1) {
+          convertUnaryExpr(expr, new Rand(_))
+        } else {
+          None
+        }
+      case "LOG" => convertBinaryExpr(expr, Logarithm)
+      case "LOG10" => convertUnaryExpr(expr, Log10)
+      case "LOG2" => convertUnaryExpr(expr, Log2)
+      case "LN" => convertUnaryExpr(expr, Log)
+      case "EXP" => convertUnaryExpr(expr, Exp)
+      case "POWER" => convertBinaryExpr(expr, Pow)
+      case "SQRT" => convertUnaryExpr(expr, Sqrt)
+      case "FLOOR" => convertUnaryExpr(expr, Floor)
+      case "CEIL" => convertUnaryExpr(expr, Ceil)
+      case "ROUND" => convertBinaryExpr(expr, Round(_, _, ansiEnabled = true))
+      case "CBRT" => convertUnaryExpr(expr, Cbrt)
+      case "DEGREES" => convertUnaryExpr(expr, ToDegrees)
+      case "RADIANS" => convertUnaryExpr(expr, ToRadians)
+      case "SIGN" => convertUnaryExpr(expr, Signum)
+      case "WIDTH_BUCKET" =>
+        convertExpr(
+          expr,
+          children => WidthBucket(children(0), children(1), children(2), children(3)))
+      case _ => None
+    }
+  }
+
+  private def convertTrigonometricFunc(expr: GeneralScalarExpression): Option[Expression] = {
+    expr.name match {
+      case "SIN" => convertUnaryExpr(expr, Sin)
+      case "SINH" => convertUnaryExpr(expr, Sinh)
+      case "COS" => convertUnaryExpr(expr, Cos)
+      case "COSH" => convertUnaryExpr(expr, Cosh)
+      case "TAN" => convertUnaryExpr(expr, Tan)
+      case "TANH" => convertUnaryExpr(expr, Tanh)
+      case "COT" => convertUnaryExpr(expr, Cot)
+      case "ASIN" => convertUnaryExpr(expr, Asin)
+      case "ASINH" => convertUnaryExpr(expr, Asinh)
+      case "ACOS" => convertUnaryExpr(expr, Acos)
+      case "ACOSH" => convertUnaryExpr(expr, Acosh)
+      case "ATAN" => convertUnaryExpr(expr, Atan)
+      case "ATANH" => convertUnaryExpr(expr, Atanh)
+      case "ATAN2" => convertBinaryExpr(expr, Atan2)
+      case _ => None
+    }
+  }
+
+  private def convertBitwiseFunc(expr: GeneralScalarExpression): Option[Expression] = {
+    expr.name match {
+      case "~" => convertUnaryExpr(expr, BitwiseNot)
+      case "&" => convertBinaryExpr(expr, BitwiseAnd)
+      case "|" => convertBinaryExpr(expr, BitwiseOr)
+      case "^" => convertBinaryExpr(expr, BitwiseXor)
+      case _ => None
+    }
+  }
+
+  private def convertUnaryExpr(
+      expr: GeneralScalarExpression,
+      catalystExprBuilder: Expression => Expression): Option[Expression] = {
+    expr.children match {
+      case Array(child) => toCatalyst(child).map(catalystExprBuilder)
+      case _ => None
+    }
+  }
+
+  private def convertBinaryExpr(
+      expr: GeneralScalarExpression,
+      catalystExprBuilder: (Expression, Expression) => Expression): Option[Expression] = {
+    expr.children match {
+      case Array(left, right) =>
+        for {
+          catalystLeft <- toCatalyst(left)
+          catalystRight <- toCatalyst(right)
+        } yield catalystExprBuilder(catalystLeft, catalystRight)
+      case _ => None
+    }
+  }
+
+  private def convertExpr(
+      expr: GeneralScalarExpression,
+      catalystExprBuilder: Seq[Expression] => Expression): Option[Expression] = {
+    val catalystChildren = expr.children.flatMap(toCatalyst).toImmutableArraySeq
+    if (expr.children.length == catalystChildren.length) {
+      Some(catalystExprBuilder(catalystChildren))
+    } else {
+      None
+    }
+  }
 }
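
A sketch of what the new conversion handles, using connector expression classes that already exist; the expected catalyst output is noted in comments and follows from the cases above:

import org.apache.spark.sql.catalyst.expressions.V2ExpressionUtils
import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue}
import org.apache.spark.sql.types.IntegerType

// ABS(col + 1) as a connector expression tree.
val v2Expr: V2Expression = new GeneralScalarExpression(
  "ABS",
  Array[V2Expression](new GeneralScalarExpression(
    "+",
    Array[V2Expression](FieldReference("col"), LiteralValue(1, IntegerType)))))

// Expected: Some(Abs(Add(UnresolvedAttribute("col"), Literal(1), ANSI), failOnError = true)).
val catalyst = V2ExpressionUtils.toCatalyst(v2Expr)

// Unsupported function names fall through to None instead of failing.
val unsupported = V2ExpressionUtils.toCatalyst(
  new GeneralScalarExpression("UNKNOWN_FUNC", Array.empty[V2Expression]))
assert(unsupported.isEmpty)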

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala

Lines changed: 3 additions & 3 deletions
@@ -135,10 +135,10 @@ object NamedParametersSupport {
   }

   private def toInputParameter(param: ProcedureParameter): InputParameter = {
-    val defaultValue = Option(param.defaultValueExpression).map { expr =>
-      ResolveDefaultColumns.analyze(param.name, param.dataType, expr, "CALL")
+    val defaultValueExpr = Option(param.defaultValue).map { defaultValue =>
+      ResolveDefaultColumns.analyze(param.name, param.dataType, defaultValue, "CALL")
     }
-    InputParameter(param.name, defaultValue)
+    InputParameter(param.name, defaultValueExpr)
   }

   private def defaultRearrange(
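
The documented fallback order (try the connector expression first, then parse the SQL string) could be mirrored by a helper like the following sketch. This is a hypothetical illustration, not the code path this commit adds; it assumes CatalystSqlParser for the SQL side:

import org.apache.spark.sql.catalyst.expressions.{Expression, V2ExpressionUtils}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.connector.catalog.DefaultValue

// Hypothetical helper mirroring the documented expression-first fallback.
def resolveDefault(dv: DefaultValue): Expression = {
  val fromExpr = Option(dv.getExpression).flatMap(V2ExpressionUtils.toCatalyst)
  fromExpr
    .orElse(Option(dv.getSql).map(CatalystSqlParser.parseExpression))
    .getOrElse(throw new IllegalArgumentException(
      s"Cannot resolve default value: $dv"))
}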
