From 425e456c6efff1a5819eebb02741a144caec3683 Mon Sep 17 00:00:00 2001
From: Anton Okolnychyi
Date: Thu, 6 Mar 2025 08:32:35 -0800
Subject: [PATCH] [SPARK-51395][SQL] Refine handling of default values in procedures

---
 .../sql/connector/catalog/DefaultValue.java   |  74 +++++
 .../procedures/ProcedureParameter.java        |  30 +-
 .../plans/logical/FunctionBuilderBase.scala   |   7 +-
 .../catalyst/util/ExpressionConverter.scala   | 292 ++++++++++++++++++
 .../util/ResolveDefaultColumnsUtil.scala      |  13 +-
 .../connector/ProcedureParameterImpl.scala    |   3 +-
 .../spark/sql/connector/ProcedureSuite.scala  |  45 ++-
 .../v2/DataSourceV2StrategySuite.scala        |  20 +-
 8 files changed, 469 insertions(+), 15 deletions(-)
 create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DefaultValue.java
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ExpressionConverter.scala

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DefaultValue.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DefaultValue.java
new file mode 100644
index 0000000000000..ce380179e547f
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DefaultValue.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog;
+
+import java.util.Objects;
+import javax.annotation.Nullable;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.Expression;
+
+@Evolving
+public class DefaultValue {
+  private final String sql;
+  private final Expression expr;
+
+  public DefaultValue(String sql) {
+    this(sql, null /* no expression */);
+  }
+
+  public DefaultValue(Expression expr) {
+    this(null /* no sql */, expr);
+  }
+
+  public DefaultValue(String sql, Expression expr) {
+    if (sql == null && expr == null) {
+      throw new IllegalArgumentException("SQL and expression can't be both null");
+    }
+    this.sql = sql;
+    this.expr = expr;
+  }
+
+  @Nullable
+  public String getSql() {
+    return sql;
+  }
+
+  @Nullable
+  public Expression getExpression() {
+    return expr;
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (this == other) return true;
+    if (other == null || getClass() != other.getClass()) return false;
+    DefaultValue that = (DefaultValue) other;
+    return Objects.equals(sql, that.sql) && Objects.equals(expr, that.expr);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(sql, expr);
+  }
+
+  @Override
+  public String toString() {
+    return String.format("DefaultValue{sql=%s, expression=%s}", sql, expr);
+  }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/ProcedureParameter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/ProcedureParameter.java
index 18c76833c5879..3d837be366f7f 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/ProcedureParameter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/ProcedureParameter.java
@@ -20,6 +20,8 @@
 import javax.annotation.Nullable;
 
 import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.catalog.DefaultValue;
+import org.apache.spark.sql.connector.expressions.Expression;
 import org.apache.spark.sql.internal.connector.ProcedureParameterImpl;
 import org.apache.spark.sql.types.DataType;
 
@@ -68,7 +70,7 @@ static Builder in(String name, DataType dataType) {
    * null if not provided.
    */
   @Nullable
-  String defaultValueExpression();
+  DefaultValue defaultValue();
 
   /**
    * Returns the comment of this parameter or null if not provided.
@@ -89,7 +91,7 @@ class Builder {
     private final Mode mode;
     private final String name;
     private final DataType dataType;
-    private String defaultValueExpression;
+    private DefaultValue defaultValue;
     private String comment;
 
     private Builder(Mode mode, String name, DataType dataType) {
@@ -99,10 +101,26 @@ private Builder(Mode mode, String name, DataType dataType) {
     }
 
     /**
-     * Sets the default value expression of the parameter.
+     * Sets the default value of the parameter using SQL.
      */
-    public Builder defaultValue(String defaultValueExpression) {
-      this.defaultValueExpression = defaultValueExpression;
+    public Builder defaultValue(String sql) {
+      this.defaultValue = new DefaultValue(sql);
+      return this;
+    }
+
+    /**
+     * Sets the default value of the parameter using an expression.
+     */
+    public Builder defaultValue(Expression expression) {
+      this.defaultValue = new DefaultValue(expression);
+      return this;
+    }
+
+    /**
+     * Sets the default value of the parameter.
+     */
+    public Builder defaultValue(DefaultValue defaultValue) {
+      this.defaultValue = defaultValue;
       return this;
     }
 
@@ -118,7 +136,7 @@ public Builder comment(String comment) {
      * Builds the stored procedure parameter.
      */
     public ProcedureParameter build() {
-      return new ProcedureParameterImpl(mode, name, dataType, defaultValueExpression, comment);
+      return new ProcedureParameterImpl(mode, name, dataType, defaultValue, comment);
     }
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala
index 75b2fcd3a5f34..e0020e0f848c3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala
@@ -135,10 +135,11 @@ object NamedParametersSupport {
   }
 
   private def toInputParameter(param: ProcedureParameter): InputParameter = {
-    val defaultValue = Option(param.defaultValueExpression).map { expr =>
-      ResolveDefaultColumns.analyze(param.name, param.dataType, expr, "CALL")
+    val defaultValueExpr = Option(param.defaultValue).map { defaultValue =>
+      val sql = ResolveDefaultColumns.generateSQL(defaultValue)
+      ResolveDefaultColumns.analyze(param.name, param.dataType, sql, "CALL")
     }
-    InputParameter(param.name, defaultValue)
+    InputParameter(param.name, defaultValueExpr)
   }
 
   private def defaultRearrange(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ExpressionConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ExpressionConverter.scala
new file mode 100644
index 0000000000000..a26a3ab3e44e0
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ExpressionConverter.scala
@@ -0,0 +1,292 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import org.apache.spark.SparkException
+import org.apache.spark.internal.{Logging, MDC}
+import org.apache.spark.internal.LogKeys.EXPR
+import org.apache.spark.sql.catalyst.SQLConfHelper
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, Literal => V2Literal, LiteralValue, NamedReference}
+import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, Predicate => V2Predicate}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.BooleanType
+import org.apache.spark.util.ArrayImplicits._
+
+object ExpressionConverter extends SQLConfHelper with Logging {
+
+  private val converters = Seq(
+    LiteralConverter,
+    ReferenceConverter,
+    MathFunctionConverter,
+    PredicateConverter,
+    ConditionalFunctionConverter,
+    CastConverter)
+
+  def toV1(expr: V2Expression): Option[Expression] = {
+    converters.find(_.toV1Func.isDefinedAt(expr)).flatMap(_.toV1Func(expr))
+  }
+
+  def toV2(expr: Expression): Option[V2Expression] = {
+    converters.find(_.toV2Func.isDefinedAt(expr)).flatMap(_.toV2Func(expr))
+  }
+
+  def toV2Predicate(expr: Expression): Option[V2Predicate] = {
+    toV2(expr) match {
+      case Some(p: V2Predicate) =>
+        Some(p)
+
+      case Some(e) if conf.getConf(SQLConf.DATA_SOURCE_DONT_ASSERT_ON_PREDICATE) =>
+        // if a predicate is expected but the translation yields something else,
+        // log a warning and proceed as if the translation was not possible
+        logWarning(log"Predicate expected but got: ${MDC(EXPR, e.describe())}")
+        None
+
+      case Some(e) =>
+        throw SparkException.internalError(s"Expected Predicate but got ${e.describe()}")
+
+      case None => None
+    }
+  }
+
+  private object LiteralConverter extends Converter {
+    override def toV1Func: PartialFunction[V2Expression, Option[Literal]] = {
+      case _: AlwaysTrue => Some(Literal.TrueLiteral)
+      case _: AlwaysFalse => Some(Literal.FalseLiteral)
+      case l: V2Literal[_] => Some(Literal(l.value, l.dataType))
+    }
+
+    override def toV2Func: PartialFunction[Expression, Option[V2Literal[_]]] = {
+      case Literal(true, BooleanType) => Some(new AlwaysTrue())
+      case Literal(false, BooleanType) => Some(new AlwaysFalse())
+      case Literal(value, dataType) => Some(LiteralValue(value, dataType))
+    }
+  }
+
+  private object ReferenceConverter extends Converter {
+    override def toV1Func: PartialFunction[V2Expression, Option[Attribute]] = {
+      case r: NamedReference => Some(UnresolvedAttribute(r.fieldNames.toIndexedSeq))
+    }
+
+    override def toV2Func: PartialFunction[Expression, Option[V2Expression]] = {
+      case c @ ColumnOrField(nameParts) =>
+        val v2Ref = FieldReference(nameParts)
+        if (c.dataType.isInstanceOf[BooleanType]) {
+          Some(new V2Predicate("=", Array(v2Ref, new AlwaysTrue())))
+        } else {
+          Some(v2Ref)
+        }
+    }
+  }
+
+  private object MathFunctionConverter extends Converter {
+    private val leafToV1Conversions: Map[String, () => Expression] = Map(
+      "RAND" -> (() => new Rand()))
+
+    private val unaryToV1Conversions: Map[String, Expression => Expression] = Map(
+      "-" -> (child => UnaryMinus(child, failOnError = true)),
+      "ABS" -> (child => Abs(child, failOnError = true)),
+      "RAND" -> (child => new Rand(child)),
+      "LOG10" -> (child => Log10(child)),
+      "LOG2" -> (child => Log2(child)),
+      "LN" -> (child => Log(child)),
"EXP" -> (child => Exp(child)), + "SQRT" -> (child => Sqrt(child)), + "FLOOR" -> (child => Floor(child)), + "CEIL" -> (child => Ceil(child)), + "SIN" -> (child => Sin(child)), + "SINH" -> (child => Sinh(child)), + "COS" -> (child => Cos(child)), + "COSH" -> (child => Cosh(child)), + "TAN" -> (child => Tan(child)), + "TANH" -> (child => Tanh(child)), + "COT" -> (child => Cot(child)), + "ASIN" -> (child => Asin(child)), + "ASINH" -> (child => Asinh(child)), + "ACOS" -> (child => Acos(child)), + "ACOSH" -> (child => Acosh(child)), + "ATAN" -> (child => Atan(child)), + "ATANH" -> (child => Atanh(child)), + "CBRT" -> (child => Cbrt(child)), + "DEGREES" -> (child => ToDegrees(child)), + "RADIANS" -> (child => ToRadians(child)), + "SIGN" -> (child => Signum(child))) + + private val binaryToV1Conversions: Map[String, (Expression, Expression) => Expression] = Map( + "+" -> ((left, right) => Add(left, right, evalMode = EvalMode.ANSI)), + "-" -> ((left, right) => Subtract(left, right, evalMode = EvalMode.ANSI)), + "*" -> ((left, right) => Multiply(left, right, evalMode = EvalMode.ANSI)), + "/" -> ((left, right) => Divide(left, right, evalMode = EvalMode.ANSI)), + "%" -> ((left, right) => Remainder(left, right, evalMode = EvalMode.ANSI)), + "LOG" -> ((left, right) => Logarithm(left, right)), + "POWER" -> ((left, right) => Pow(left, right)), + "ROUND" -> ((left, right) => Round(left, right, ansiEnabled = true)), + "ATAN2" -> ((left, right) => Atan2(left, right))) + + private val toV1Conversions: Map[String, Seq[Expression] => Expression] = Map( + "COALESCE" -> (children => Coalesce(children)), + "GREATEST" -> (children => Greatest(children)), + "LEAST" -> (children => Least(children)), + "WIDTH_BUCKET" -> (children => WidthBucket( + children(0), + children(1), + children(2), + children(3)))) + + override def toV1Func: PartialFunction[V2Expression, Option[Expression]] = { + case e: GeneralScalarExpression + if e.children.isEmpty && leafToV1Conversions.contains(e.name) => + Some(leafToV1Conversions(e.name)()) + + case UnaryScalarExpr(name, child) if unaryToV1Conversions.contains(name) => + toV1(child).map(v1Child => unaryToV1Conversions(name)(v1Child)) + + case BinaryScalarExpr(name, left, right) if binaryToV1Conversions.contains(name) => + for { + v1Left <- toV1(left) + v1Right <- toV1(right) + } yield binaryToV1Conversions(name)(v1Left, v1Right) + + case e: GeneralScalarExpression if toV1Conversions.contains(e.name) => + val v1Children = e.children.flatMap(toV1) + if (e.children.length == v1Children.length) { + Some(toV1Conversions(e.name)(v1Children.toImmutableArraySeq)) + } else { + None + } + } + } + + private object PredicateConverter extends Converter { + + private val unaryToV1Conversions: Map[String, Expression => Predicate] = Map( + "IS_NULL" -> (child => IsNull(child)), + "IS_NOT_NULL" -> (child => IsNotNull(child)), + "NOT" -> (child => Not(child))) + + private val binaryToV1Conversions: Map[String, (Expression, Expression) => Predicate] = Map( + "=" -> ((left, right) => EqualTo(left, right)), + "<=>" -> ((left, right) => EqualNullSafe(left, right)), + ">" -> ((left, right) => GreaterThan(left, right)), + ">=" -> ((left, right) => GreaterThanOrEqual(left, right)), + "<" -> ((left, right) => LessThan(left, right)), + "<=" -> ((left, right) => LessThanOrEqual(left, right)), + "<>" -> ((left, right) => Not(EqualTo(left, right))), + "AND" -> ((left, right) => And(left, right)), + "OR" -> ((left, right) => Or(left, right)), + "STARTS_WITH" -> ((left, right) => StartsWith(left, right)), + 
"ENDS_WITH" -> ((left, right) => EndsWith(left, right)), + "CONTAINS" -> ((left, right) => Contains(left, right))) + + private val toV1Conversions: Map[String, Seq[Expression] => Predicate] = Map( + "IN" -> (children => In(children.head, children.tail))) + + override def toV1Func: PartialFunction[V2Expression, Option[Predicate]] = { + case UnaryPredicate(name, child) if unaryToV1Conversions.contains(name) => + toV1(child).map(v1Child => unaryToV1Conversions(name)(v1Child)) + + case BinaryPredicate(name, left, right) if binaryToV1Conversions.contains(name) => + for { + v1Left <- toV1(left) + v1Right <- toV1(right) + } yield binaryToV1Conversions(name)(v1Left, v1Right) + + case p: V2Predicate if toV1Conversions.contains(p.name) => + val v1Children = p.children.flatMap(toV1) + if (p.children.length == v1Children.length) { + Some(toV1Conversions(p.name)(v1Children.toImmutableArraySeq)) + } else { + None + } + } + } + + private object ConditionalFunctionConverter extends Converter { + override def toV1Func: PartialFunction[V2Expression, Option[Expression]] = { + case e: GeneralScalarExpression if e.name == "CASE_WHEN" => + val v1Children = e.children.flatMap(toV1) + if (e.children.length == v1Children.length) { + if (v1Children.length % 2 == 0) { + val branches = v1Children.grouped(2).map { case Array(a, b) => (a, b) }.toSeq + Some(CaseWhen(branches, None)) + } else { + val (pairs, last) = v1Children.splitAt(v1Children.length - 1) + val branches = pairs.grouped(2).map { case Array(a, b) => (a, b) }.toSeq + Some(CaseWhen(branches, Some(last.head))) + } + } else { + None + } + } + } + + private object CastConverter extends Converter { + override def toV1Func: PartialFunction[V2Expression, Option[Cast]] = { + case c: V2Cast => + toV1(c.expression).map(v1Child => Cast(v1Child, c.dataType, ansiEnabled = true)) + } + + override def toV2Func: PartialFunction[Expression, Option[V2Cast]] = { + case Cast(child, dataType, _, evalMode) + if evalMode == EvalMode.ANSI || Cast.canUpCast(child.dataType, dataType) => + toV2(child).map(v2Child => new V2Cast(v2Child, child.dataType, dataType)) + } + } + + private trait Converter { + def toV1Func: PartialFunction[V2Expression, Option[Expression]] = PartialFunction.empty + def toV2Func: PartialFunction[Expression, Option[V2Expression]] = PartialFunction.empty + } + + private object UnaryScalarExpr { + def unapply(expr: V2Expression): Option[(String, V2Expression)] = expr match { + case e: GeneralScalarExpression if e.children.length == 1 => + Some(e.name, e.children.head) + case _ => + None + } + } + + private object UnaryPredicate { + def unapply(expr: V2Expression): Option[(String, V2Expression)] = expr match { + case p: V2Predicate if p.children.length == 1 => + Some(p.name, p.children.head) + case _ => + None + } + } + + private object BinaryScalarExpr { + def unapply(expr: V2Expression): Option[(String, V2Expression, V2Expression)] = expr match { + case e: GeneralScalarExpression if e.children.length == 2 => + Some(e.name, e.children.apply(0), e.children.apply(1)) + case _ => + None + } + } + + private object BinaryPredicate { + def unapply(expr: V2Expression): Option[(String, V2Expression, V2Expression)] = expr match { + case p: V2Predicate if p.children.length == 2 => + Some(p.name, p.children.apply(0), p.children.apply(1)) + case _ => + None + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala 
index 58b6314e27ade..3cc89d1bca757 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.optimizer.{ConstantFolding, Optimizer}
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION
-import org.apache.spark.sql.connector.catalog.{CatalogManager, FunctionCatalog, Identifier, TableCatalog, TableCatalogCapability}
+import org.apache.spark.sql.connector.catalog.{CatalogManager, DefaultValue, FunctionCatalog, Identifier, TableCatalog, TableCatalogCapability}
 import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
 import org.apache.spark.sql.internal.SQLConf
@@ -243,6 +243,17 @@ object ResolveDefaultColumns extends QueryErrorsBase
     getDefaultValueExprOrNullLit(field, useNullAsDefault)
   }
 
+  def generateSQL(defaultValue: DefaultValue): String = {
+    if (defaultValue.getSql != null) {
+      defaultValue.getSql
+    } else {
+      ExpressionConverter.toV1(defaultValue.getExpression) match {
+        case Some(e) if !e.isInstanceOf[NonSQLExpression] => e.sql
+        case _ => throw SparkException.internalError(s"Can't generate SQL for $defaultValue")
+      }
+    }
+  }
+
   /**
    * Parses and analyzes the DEFAULT column text in `field`, returning an error upon failure.
    *
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/ProcedureParameterImpl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/ProcedureParameterImpl.scala
index 01ea48af1537c..ede9c0915ef24 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/ProcedureParameterImpl.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/ProcedureParameterImpl.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.internal.connector
 
+import org.apache.spark.sql.connector.catalog.DefaultValue
 import org.apache.spark.sql.connector.catalog.procedures.ProcedureParameter
 import org.apache.spark.sql.connector.catalog.procedures.ProcedureParameter.Mode
 import org.apache.spark.sql.types.DataType
@@ -25,5 +26,5 @@ case class ProcedureParameterImpl(
     mode: Mode,
     name: String,
     dataType: DataType,
-    defaultValueExpression: String,
+    defaultValue: DefaultValue,
     comment: String) extends ProcedureParameter
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/ProcedureSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/ProcedureSuite.scala
index f5a750a26a741..aaf88259c1d6c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/ProcedureSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/ProcedureSuite.scala
@@ -25,15 +25,16 @@ import org.apache.spark.{SPARK_DOC_ROOT, SparkException, SparkNumberFormatExcept
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId
-import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, InMemoryCatalog}
+import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, DefaultValue, Identifier, InMemoryCatalog}
 import org.apache.spark.sql.connector.catalog.procedures.{BoundProcedure, ProcedureParameter, UnboundProcedure}
 import org.apache.spark.sql.connector.catalog.procedures.ProcedureParameter.Mode
 import org.apache.spark.sql.connector.catalog.procedures.ProcedureParameter.Mode.{IN, INOUT, OUT}
+import org.apache.spark.sql.connector.expressions.{Expression, GeneralScalarExpression, LiteralValue}
 import org.apache.spark.sql.connector.read.{LocalScan, Scan}
 import org.apache.spark.sql.errors.DataTypeErrors.{toSQLType, toSQLValue}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}
+import org.apache.spark.sql.types.{DataType, DataTypes, IntegerType, StructField, StructType}
 import org.apache.spark.unsafe.types.UTF8String
 
 class ProcedureSuite extends QueryTest with SharedSparkSession with BeforeAndAfter {
@@ -348,6 +349,11 @@ class ProcedureSuite extends QueryTest with SharedSparkSession with BeforeAndAft
     }
   }
 
+  test("default values with expressions") {
+    catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSumWithDefaultExpr)
+    checkAnswer(sql("CALL cat.ns.sum(5)"), Row(9) :: Nil)
+  }
+
   object UnboundVoidProcedure extends UnboundProcedure {
     override def name: String = "void"
     override def description: String = "void procedure"
@@ -647,13 +653,46 @@ class ProcedureSuite extends QueryTest with SharedSparkSession with BeforeAndAft
     }
   }
 
+  object UnboundSumWithDefaultExpr extends UnboundProcedure {
+    override def name: String = "sum"
+    override def description: String = "sum longs"
+    override def bind(inputType: StructType): BoundProcedure = SumWithDefaultExpr
+  }
+
+  object SumWithDefaultExpr extends BoundProcedure {
+    override def name: String = "sum"
+
+    override def description: String = "sum longs"
+
+    override def isDeterministic: Boolean = true
+
+    override def parameters: Array[ProcedureParameter] = Array(
+      ProcedureParameter.in("in1", DataTypes.LongType).build(),
+      ProcedureParameter.in("in2", DataTypes.LongType)
+        .defaultValue(
+          new GeneralScalarExpression(
+            "+",
+            Array[Expression](LiteralValue(1, IntegerType), LiteralValue(3, IntegerType))))
+        .build()
+    )
+
+    def outputType: StructType = new StructType().add("out", DataTypes.LongType)
+
+    override def call(input: InternalRow): java.util.Iterator[Scan] = {
+      val in1 = input.getLong(0)
+      val in2 = input.getLong(1)
+      val result = Result(outputType, Array(InternalRow(in1 + in2)))
+      Collections.singleton[Scan](result).iterator()
+    }
+  }
+
   case class Result(readSchema: StructType, rows: Array[InternalRow]) extends LocalScan
 
   case class CustomParameterImpl(
       mode: Mode,
       name: String,
       dataType: DataType) extends ProcedureParameter {
-    override def defaultValueExpression: String = null
+    override def defaultValue: DefaultValue = null
     override def comment: String = null
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala
index 3c4f5814375dc..727162d7df0a9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala
@@ -20,7 +20,8 @@ package org.apache.spark.sql.execution.datasources.v2
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue}
+import org.apache.spark.sql.catalyst.util.{ExpressionConverter, V2ExpressionBuilder}
+import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue}
 import org.apache.spark.sql.connector.expressions.filter.{And => V2And, Not => V2Not, Or => V2Or, Predicate}
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType, StructField, StructType}
@@ -317,6 +318,12 @@ class DataSourceV2StrategySuite extends PlanTest with SharedSparkSession {
       Some(new Predicate("=", Array(FieldReference("col"), LiteralValue(true, BooleanType)))))
   }
 
+  test("round trip conversion of math functions") {
+    checkRoundTripConversion(
+      v1Expr = Log10(Literal(100)),
+      v2Expr = new GeneralScalarExpression("LOG10", Array(LiteralValue(100, IntegerType))))
+  }
+
   /**
    * Translate the given Catalyst [[Expression]] into data source V2 [[Predicate]]
    * then verify against the given [[Predicate]].
@@ -326,4 +333,15 @@ class DataSourceV2StrategySuite extends PlanTest with SharedSparkSession {
       DataSourceV2Strategy.translateFilterV2(catalystFilter)
     }
   }
+
+  private def checkRoundTripConversion(v1Expr: Expression, v2Expr: V2Expression): Unit = {
+    val v2ExprActual = new V2ExpressionBuilder(v1Expr).build().getOrElse {
+      fail(s"can't convert to V2 expression: $v1Expr")
+    }
+    assert(v2ExprActual == v2Expr, "V2 expressions must match")
+    val v1ExprActual = ExpressionConverter.toV1(v2ExprActual).getOrElse {
+      fail(s"can't convert back to V1 expression: $v2ExprActual")
+    }
+    assert(v1ExprActual == v1Expr, "V1 expressions must match")
+  }
 }
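
Usage sketch (not part of the patch): with this change a connector can declare a parameter default either as SQL text or as a connector expression. The snippet below mirrors the SumWithDefaultExpr test procedure above; the parameter names and the "1 + 3" SQL text are illustrative only.

    import org.apache.spark.sql.connector.catalog.procedures.ProcedureParameter
    import org.apache.spark.sql.connector.expressions.{Expression, GeneralScalarExpression, LiteralValue}
    import org.apache.spark.sql.types.{DataTypes, IntegerType}

    // SQL-based default: kept as text and analyzed against the parameter type at CALL time.
    val withSqlDefault = ProcedureParameter.in("in1", DataTypes.LongType)
      .defaultValue("1 + 3")
      .build()

    // Expression-based default: converted back to SQL via ResolveDefaultColumns.generateSQL
    // (which uses ExpressionConverter.toV1) before being analyzed.
    val withExprDefault = ProcedureParameter.in("in2", DataTypes.LongType)
      .defaultValue(
        new GeneralScalarExpression(
          "+",
          Array[Expression](LiteralValue(1, IntegerType), LiteralValue(3, IntegerType))))
      .build()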