-
Notifications
You must be signed in to change notification settings - Fork 28.7k
[SPARK-51395][SQL] Refine handling of default values in procedures #50197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.sql.connector.catalog; | ||
|
||
import java.util.Map; | ||
import java.util.Objects; | ||
import javax.annotation.Nullable; | ||
|
||
import org.apache.spark.SparkIllegalArgumentException; | ||
import org.apache.spark.annotation.Evolving; | ||
import org.apache.spark.sql.connector.expressions.Expression; | ||
|
||
/** | ||
* A class that represents default values. | ||
* <p> | ||
* Connectors can define default values using either a SQL string (Spark SQL dialect) or an | ||
* {@link Expression expression} if the default value can be expressed as a supported connector | ||
* expression. If both the SQL string and the expression are provided, Spark first attempts to | ||
* convert the given expression to its internal representation. If the expression cannot be | ||
* converted, and a SQL string is provided, Spark will fall back to parsing the SQL string. | ||
* | ||
* @since 4.1.0 | ||
*/ | ||
@Evolving | ||
public class DefaultValue { | ||
private final String sql; | ||
private final Expression expr; | ||
|
||
public DefaultValue(String sql) { | ||
this(sql, null /* no expression */); | ||
} | ||
|
||
public DefaultValue(Expression expr) { | ||
this(null /* no sql */, expr); | ||
} | ||
|
||
public DefaultValue(String sql, Expression expr) { | ||
if (sql == null && expr == null) { | ||
throw new SparkIllegalArgumentException( | ||
"INTERNAL_ERROR", | ||
Map.of("message", "SQL and expression can't be both null")); | ||
} | ||
this.sql = sql; | ||
this.expr = expr; | ||
} | ||
|
||
/** | ||
* Returns the SQL representation of the default value (Spark SQL dialect), if provided. | ||
*/ | ||
@Nullable | ||
public String getSql() { | ||
return sql; | ||
} | ||
|
||
/** | ||
* Returns the expression representing the default value, if provided. | ||
*/ | ||
@Nullable | ||
public Expression getExpression() { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we really need to follow the java getter naming style? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would personally prefer not to have There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. makes sense, let's keep it |
||
return expr; | ||
} | ||
|
||
@Override | ||
public boolean equals(Object other) { | ||
if (this == other) return true; | ||
if (other == null || getClass() != other.getClass()) return false; | ||
DefaultValue that = (DefaultValue) other; | ||
return Objects.equals(sql, that.sql) && Objects.equals(expr, that.expr); | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Objects.hash(sql, expr); | ||
} | ||
|
||
@Override | ||
public String toString() { | ||
return String.format("DefaultValue{sql=%s, expression=%s}", sql, expr); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,14 +23,15 @@ import org.apache.spark.internal.{Logging, MDC} | |
import org.apache.spark.internal.LogKeys.{FUNCTION_NAME, FUNCTION_PARAM} | ||
import org.apache.spark.sql.AnalysisException | ||
import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper} | ||
import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException | ||
import org.apache.spark.sql.catalyst.analysis.{NoSuchFunctionException, UnresolvedAttribute} | ||
import org.apache.spark.sql.catalyst.encoders.EncoderUtils | ||
import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke} | ||
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan | ||
import org.apache.spark.sql.connector.catalog.{FunctionCatalog, Identifier} | ||
import org.apache.spark.sql.connector.catalog.functions._ | ||
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME | ||
import org.apache.spark.sql.connector.expressions.{BucketTransform, Expression => V2Expression, FieldReference, IdentityTransform, Literal => V2Literal, NamedReference, NamedTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortOrder => V2SortOrder, SortValue, Transform} | ||
import org.apache.spark.sql.connector.expressions.{BucketTransform, Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, IdentityTransform, Literal => V2Literal, NamedReference, NamedTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortOrder => V2SortOrder, SortValue, Transform} | ||
import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue} | ||
import org.apache.spark.sql.errors.DataTypeErrors.toSQLId | ||
import org.apache.spark.sql.errors.QueryCompilationErrors | ||
import org.apache.spark.sql.types._ | ||
|
@@ -205,4 +206,171 @@ object V2ExpressionUtils extends SQLConfHelper with Logging { | |
None | ||
} | ||
} | ||
|
||
  /**
   * Converts a connector V2 expression into a (possibly unresolved) Catalyst expression.
   *
   * NOTE(review): the case order appears significant — connector predicates (AlwaysTrue,
   * AlwaysFalse) and literals are presumably subtypes that must be handled before the generic
   * GeneralScalarExpression case; confirm before reordering.
   *
   * Returns None when the expression (or any sub-expression) has no Catalyst counterpart.
   */
  def toCatalyst(expr: V2Expression): Option[Expression] = expr match {
    case _: AlwaysTrue => Some(Literal.TrueLiteral)
    case _: AlwaysFalse => Some(Literal.FalseLiteral)
    case l: V2Literal[_] => Some(Literal(l.value, l.dataType))
    // Field references become unresolved attributes, to be resolved later by the analyzer.
    case r: NamedReference => Some(UnresolvedAttribute(r.fieldNames.toImmutableArraySeq))
    // V2 casts map to ANSI casts so invalid conversions raise errors instead of returning null.
    case c: V2Cast => toCatalyst(c.expression).map(Cast(_, c.dataType, ansiEnabled = true))
    case e: GeneralScalarExpression => convertScalarExpr(e)
    case _ => None
  }
|
||
private def convertScalarExpr(expr: GeneralScalarExpression): Option[Expression] = { | ||
convertPredicate(expr) | ||
.orElse(convertConditionalFunc(expr)) | ||
.orElse(convertMathFunc(expr)) | ||
.orElse(convertBitwiseFunc(expr)) | ||
.orElse(convertTrigonometricFunc(expr)) | ||
} | ||
|
||
  /**
   * Converts V2 predicate expressions (null checks, comparisons, boolean logic,
   * string predicates, IN) to their Catalyst equivalents.
   *
   * Returns None for names outside this category so the caller can try other converters.
   */
  private def convertPredicate(expr: GeneralScalarExpression): Option[Expression] = {
    expr.name match {
      case "IS_NULL" => convertUnaryExpr(expr, IsNull)
      case "IS_NOT_NULL" => convertUnaryExpr(expr, IsNotNull)
      case "NOT" => convertUnaryExpr(expr, Not)
      case "=" => convertBinaryExpr(expr, EqualTo)
      case "<=>" => convertBinaryExpr(expr, EqualNullSafe)
      case ">" => convertBinaryExpr(expr, GreaterThan)
      case ">=" => convertBinaryExpr(expr, GreaterThanOrEqual)
      case "<" => convertBinaryExpr(expr, LessThan)
      case "<=" => convertBinaryExpr(expr, LessThanOrEqual)
      // There is no dedicated "not equal" Catalyst node; expressed as NOT(=).
      case "<>" => convertBinaryExpr(expr, (left, right) => Not(EqualTo(left, right)))
      case "AND" => convertBinaryExpr(expr, And)
      case "OR" => convertBinaryExpr(expr, Or)
      case "STARTS_WITH" => convertBinaryExpr(expr, StartsWith)
      case "ENDS_WITH" => convertBinaryExpr(expr, EndsWith)
      case "CONTAINS" => convertBinaryExpr(expr, Contains)
      // IN carries the tested value as the first child and the list elements as the rest.
      case "IN" => convertExpr(expr, children => In(children.head, children.tail))
      case _ => None
    }
  }
|
||
private def convertConditionalFunc(expr: GeneralScalarExpression): Option[Expression] = { | ||
expr.name match { | ||
case "CASE_WHEN" => | ||
convertExpr(expr, children => | ||
if (children.length % 2 == 0) { | ||
val branches = children.grouped(2).map { case Seq(c, v) => (c, v) }.toSeq | ||
CaseWhen(branches, None) | ||
} else { | ||
val (pairs, last) = children.splitAt(children.length - 1) | ||
val branches = pairs.grouped(2).map { case Seq(c, v) => (c, v) }.toSeq | ||
CaseWhen(branches, Some(last.head)) | ||
}) | ||
case _ => None | ||
} | ||
} | ||
|
||
  /**
   * Converts V2 arithmetic and math function expressions to Catalyst.
   *
   * Arithmetic operators are converted with ANSI evaluation mode so overflow and
   * division-by-zero raise errors rather than producing nulls.
   * Returns None for names outside this category so the caller can try other converters.
   */
  private def convertMathFunc(expr: GeneralScalarExpression): Option[Expression] = {
    expr.name match {
      case "+" => convertBinaryExpr(expr, Add(_, _, evalMode = EvalMode.ANSI))
      // "-" is overloaded: unary negation with one child, subtraction with two.
      case "-" =>
        if (expr.children.length == 1) {
          convertUnaryExpr(expr, UnaryMinus(_, failOnError = true))
        } else if (expr.children.length == 2) {
          convertBinaryExpr(expr, Subtract(_, _, evalMode = EvalMode.ANSI))
        } else {
          None
        }
      case "*" => convertBinaryExpr(expr, Multiply(_, _, evalMode = EvalMode.ANSI))
      case "/" => convertBinaryExpr(expr, Divide(_, _, evalMode = EvalMode.ANSI))
      case "%" => convertBinaryExpr(expr, Remainder(_, _, evalMode = EvalMode.ANSI))
      case "ABS" => convertUnaryExpr(expr, Abs(_, failOnError = true))
      case "COALESCE" => convertExpr(expr, Coalesce)
      case "GREATEST" => convertExpr(expr, Greatest)
      case "LEAST" => convertExpr(expr, Least)
      // RAND takes an optional seed argument.
      case "RAND" =>
        if (expr.children.isEmpty) {
          Some(new Rand())
        } else if (expr.children.length == 1) {
          convertUnaryExpr(expr, new Rand(_))
        } else {
          None
        }
      case "LOG" => convertBinaryExpr(expr, Logarithm)
      case "LOG10" => convertUnaryExpr(expr, Log10)
      case "LOG2" => convertUnaryExpr(expr, Log2)
      case "LN" => convertUnaryExpr(expr, Log)
      case "EXP" => convertUnaryExpr(expr, Exp)
      case "POWER" => convertBinaryExpr(expr, Pow)
      case "SQRT" => convertUnaryExpr(expr, Sqrt)
      case "FLOOR" => convertUnaryExpr(expr, Floor)
      case "CEIL" => convertUnaryExpr(expr, Ceil)
      case "ROUND" => convertBinaryExpr(expr, Round(_, _, ansiEnabled = true))
      case "CBRT" => convertUnaryExpr(expr, Cbrt)
      case "DEGREES" => convertUnaryExpr(expr, ToDegrees)
      case "RADIANS" => convertUnaryExpr(expr, ToRadians)
      case "SIGN" => convertUnaryExpr(expr, Signum)
      // WIDTH_BUCKET(value, min, max, numBuckets) takes exactly four arguments.
      case "WIDTH_BUCKET" =>
        convertExpr(
          expr,
          children => WidthBucket(children(0), children(1), children(2), children(3)))
      case _ => None
    }
  }
|
||
private def convertTrigonometricFunc(expr: GeneralScalarExpression): Option[Expression] = { | ||
expr.name match { | ||
case "SIN" => convertUnaryExpr(expr, Sin) | ||
case "SINH" => convertUnaryExpr(expr, Sinh) | ||
case "COS" => convertUnaryExpr(expr, Cos) | ||
case "COSH" => convertUnaryExpr(expr, Cosh) | ||
case "TAN" => convertUnaryExpr(expr, Tan) | ||
case "TANH" => convertUnaryExpr(expr, Tanh) | ||
case "COT" => convertUnaryExpr(expr, Cot) | ||
case "ASIN" => convertUnaryExpr(expr, Asin) | ||
case "ASINH" => convertUnaryExpr(expr, Asinh) | ||
case "ACOS" => convertUnaryExpr(expr, Acos) | ||
case "ACOSH" => convertUnaryExpr(expr, Acosh) | ||
case "ATAN" => convertUnaryExpr(expr, Atan) | ||
case "ATANH" => convertUnaryExpr(expr, Atanh) | ||
case "ATAN2" => convertBinaryExpr(expr, Atan2) | ||
case _ => None | ||
} | ||
} | ||
|
||
private def convertBitwiseFunc(expr: GeneralScalarExpression): Option[Expression] = { | ||
expr.name match { | ||
case "~" => convertUnaryExpr(expr, BitwiseNot) | ||
case "&" => convertBinaryExpr(expr, BitwiseAnd) | ||
case "|" => convertBinaryExpr(expr, BitwiseOr) | ||
case "^" => convertBinaryExpr(expr, BitwiseXor) | ||
case _ => None | ||
} | ||
} | ||
|
||
private def convertUnaryExpr( | ||
expr: GeneralScalarExpression, | ||
catalystExprBuilder: Expression => Expression): Option[Expression] = { | ||
expr.children match { | ||
case Array(child) => toCatalyst(child).map(catalystExprBuilder) | ||
case _ => None | ||
} | ||
} | ||
|
||
private def convertBinaryExpr( | ||
expr: GeneralScalarExpression, | ||
catalystExprBuilder: (Expression, Expression) => Expression): Option[Expression] = { | ||
expr.children match { | ||
case Array(left, right) => | ||
for { | ||
catalystLeft <- toCatalyst(left) | ||
catalystRight <- toCatalyst(right) | ||
} yield catalystExprBuilder(catalystLeft, catalystRight) | ||
case _ => None | ||
} | ||
} | ||
|
||
private def convertExpr( | ||
expr: GeneralScalarExpression, | ||
catalystExprBuilder: Seq[Expression] => Expression): Option[Expression] = { | ||
val catalystChildren = expr.children.flatMap(toCatalyst).toImmutableArraySeq | ||
if (expr.children.length == catalystChildren.length) { | ||
Some(catalystExprBuilder(catalystChildren)) | ||
} else { | ||
None | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should
ColumnDefaultValue
extend it?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the future, yes. That's the whole idea.