From 492d1b14c0d19fa89b9ce9c0e48fc0e4c120b70c Mon Sep 17 00:00:00 2001 From: Anton Okolnychyi Date: Thu, 19 Sep 2024 11:09:40 +0200 Subject: [PATCH] [SPARK-48782][SQL] Add support for executing procedures in catalogs ### What changes were proposed in this pull request? This PR adds support for executing procedures in catalogs. ### Why are the changes needed? These changes are needed per [discussed and voted](https://lists.apache.org/thread/w586jr53fxwk4pt9m94b413xyjr1v25m) SPIP tracked in [SPARK-44167](https://issues.apache.org/jira/browse/SPARK-44167). ### Does this PR introduce _any_ user-facing change? Yes. This PR adds CALL commands. ### How was this patch tested? This PR comes with tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #47943 from aokolnychyi/spark-48782. Authored-by: Anton Okolnychyi Signed-off-by: Wenchen Fan --- .../resources/error/error-conditions.json | 6 + docs/sql-ref-ansi-compliance.md | 1 + .../spark/sql/catalyst/parser/SqlBaseLexer.g4 | 1 + .../sql/catalyst/parser/SqlBaseParser.g4 | 5 + .../procedures/ProcedureParameter.java | 5 + .../catalog/procedures/UnboundProcedure.java | 6 + .../sql/catalyst/analysis/Analyzer.scala | 65 +- .../catalyst/analysis/AnsiTypeCoercion.scala | 1 + .../sql/catalyst/analysis/CheckAnalysis.scala | 8 + .../sql/catalyst/analysis/TypeCoercion.scala | 16 + .../spark/sql/catalyst/analysis/package.scala | 6 +- .../catalyst/analysis/v2ResolutionPlans.scala | 17 +- .../sql/catalyst/parser/AstBuilder.scala | 22 + .../logical/ExecutableDuringAnalysis.scala | 28 + .../plans/logical/FunctionBuilderBase.scala | 36 +- .../catalyst/plans/logical/MultiResult.scala | 30 + .../catalyst/plans/logical/v2Commands.scala | 67 +- .../sql/catalyst/rules/RuleIdCollection.scala | 1 + .../sql/catalyst/trees/TreePatterns.scala | 1 + .../catalog/CatalogV2Implicits.scala | 7 + .../sql/errors/QueryCompilationErrors.scala | 7 + .../connector/catalog/InMemoryCatalog.scala | 19 +- .../catalyst/analysis/InvokeProcedures.scala | 71 ++ .../spark/sql/execution/MultiResultExec.scala | 36 + .../spark/sql/execution/SparkStrategies.scala | 2 + .../sql/execution/command/commands.scala | 11 +- .../datasources/v2/DataSourceV2Strategy.scala | 6 +- .../datasources/v2/ExplainOnlySparkPlan.scala | 38 + .../internal/BaseSessionStateBuilder.scala | 3 +- .../sql-tests/results/ansi/keywords.sql.out | 2 + .../sql-tests/results/keywords.sql.out | 1 + .../spark/sql/connector/ProcedureSuite.scala | 654 ++++++++++++++++++ .../ThriftServerWithSparkContextSuite.scala | 2 +- .../sql/hive/HiveSessionStateBuilder.scala | 3 +- 34 files changed, 1162 insertions(+), 22 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ExecutableDuringAnalysis.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MultiResult.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/InvokeProcedures.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/MultiResultExec.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExplainOnlySparkPlan.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/connector/ProcedureSuite.scala diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 6463cc2c12da7..72985de6631f0 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -1456,6 +1456,12 @@ ], "sqlState" : "2203G" }, + "FAILED_TO_LOAD_ROUTINE" : { + "message" : [ + "Failed to load routine ." + ], + "sqlState" : "38000" + }, "FAILED_TO_PARSE_TOO_COMPLEX" : { "message" : [ "The statement, including potential SQL functions and referenced views, was too complex to parse.", diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index fff6906457f7d..12dff1e325c49 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -426,6 +426,7 @@ Below is a list of all the keywords in Spark SQL. |BY|non-reserved|non-reserved|reserved| |BYTE|non-reserved|non-reserved|non-reserved| |CACHE|non-reserved|non-reserved|non-reserved| +|CALL|reserved|non-reserved|reserved| |CALLED|non-reserved|non-reserved|non-reserved| |CASCADE|non-reserved|non-reserved|non-reserved| |CASE|reserved|non-reserved|reserved| diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 index e704f9f58b964..de28041acd41f 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 @@ -146,6 +146,7 @@ BUCKETS: 'BUCKETS'; BY: 'BY'; BYTE: 'BYTE'; CACHE: 'CACHE'; +CALL: 'CALL'; CALLED: 'CALLED'; CASCADE: 'CASCADE'; CASE: 'CASE'; diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index f13dde773496a..e591a43b84d1a 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -298,6 +298,10 @@ statement LEFT_PAREN columns=multipartIdentifierPropertyList RIGHT_PAREN (OPTIONS options=propertyList)? #createIndex | DROP INDEX (IF EXISTS)? identifier ON TABLE? identifierReference #dropIndex + | CALL identifierReference + LEFT_PAREN + (functionArgument (COMMA functionArgument)*)? + RIGHT_PAREN #call | unsupportedHiveNativeCommands .*? #failNativeCommand ; @@ -1851,6 +1855,7 @@ nonReserved | BY | BYTE | CACHE + | CALL | CALLED | CASCADE | CASE diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/ProcedureParameter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/ProcedureParameter.java index 90d531ae21892..18c76833c5879 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/ProcedureParameter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/ProcedureParameter.java @@ -32,6 +32,11 @@ */ @Evolving public interface ProcedureParameter { + /** + * A field metadata key that indicates whether an argument is passed by name. + */ + String BY_NAME_METADATA_KEY = "BY_NAME"; + /** * Creates a builder for an IN procedure parameter. * diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/UnboundProcedure.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/UnboundProcedure.java index ee9a09055243b..1a91fd21bf07e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/UnboundProcedure.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/UnboundProcedure.java @@ -35,6 +35,12 @@ public interface UnboundProcedure extends Procedure { * validate if the input types are compatible while binding or delegate that to Spark. Regardless, * Spark will always perform the final validation of the arguments and rearrange them as needed * based on {@link BoundProcedure#parameters() reported parameters}. + *

+ * The provided {@code inputType} is based on the procedure arguments. If an argument is passed + * by name, its metadata will indicate this with {@link ProcedureParameter#BY_NAME_METADATA_KEY} + * set to {@code true}. In such cases, the field name will match the name of the target procedure + * parameter. If the argument is not named, {@link ProcedureParameter#BY_NAME_METADATA_KEY} will + * not be set and the name will be assigned randomly. * * @param inputType the input types to bind to * @return the bound procedure that is most suitable for the given input types diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 0164af945ca28..9e5b1d1254c87 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -26,7 +26,7 @@ import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ import scala.util.{Failure, Random, Success, Try} -import org.apache.spark.{SparkException, SparkUnsupportedOperationException} +import org.apache.spark.{SparkException, SparkThrowable, SparkUnsupportedOperationException} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.catalog._ @@ -50,6 +50,7 @@ import org.apache.spark.sql.connector.catalog.{View => _, _} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.TableChange.{After, ColumnPosition} import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction, ScalarFunction, UnboundFunction} +import org.apache.spark.sql.connector.catalog.procedures.{BoundProcedure, ProcedureParameter, UnboundProcedure} import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation @@ -310,6 +311,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor ExtractGenerator :: ResolveGenerate :: ResolveFunctions :: + ResolveProcedures :: + BindProcedures :: ResolveTableSpec :: ResolveAliases :: ResolveSubquery :: @@ -2611,6 +2614,66 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } } + /** + * A rule that resolves procedures. + */ + object ResolveProcedures extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning( + _.containsPattern(UNRESOLVED_PROCEDURE), ruleId) { + case Call(UnresolvedProcedure(CatalogAndIdentifier(catalog, ident)), args, execute) => + val procedureCatalog = catalog.asProcedureCatalog + val procedure = load(procedureCatalog, ident) + Call(ResolvedProcedure(procedureCatalog, ident, procedure), args, execute) + } + + private def load(catalog: ProcedureCatalog, ident: Identifier): UnboundProcedure = { + try { + catalog.loadProcedure(ident) + } catch { + case e: Exception if !e.isInstanceOf[SparkThrowable] => + val nameParts = catalog.name +: ident.asMultipartIdentifier + throw QueryCompilationErrors.failedToLoadRoutineError(nameParts, e) + } + } + } + + /** + * A rule that binds procedures to the input types and rearranges arguments as needed. + */ + object BindProcedures extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case Call(ResolvedProcedure(catalog, ident, unbound: UnboundProcedure), args, execute) + if args.forall(_.resolved) => + val inputType = extractInputType(args) + val bound = unbound.bind(inputType) + validateParameterModes(bound) + val rearrangedArgs = NamedParametersSupport.defaultRearrange(bound, args) + Call(ResolvedProcedure(catalog, ident, bound), rearrangedArgs, execute) + } + + private def extractInputType(args: Seq[Expression]): StructType = { + val fields = args.zipWithIndex.map { + case (NamedArgumentExpression(name, value), _) => + StructField(name, value.dataType, value.nullable, byNameMetadata) + case (arg, index) => + StructField(s"param$index", arg.dataType, arg.nullable) + } + StructType(fields) + } + + private def byNameMetadata: Metadata = { + new MetadataBuilder() + .putBoolean(ProcedureParameter.BY_NAME_METADATA_KEY, value = true) + .build() + } + + private def validateParameterModes(procedure: BoundProcedure): Unit = { + procedure.parameters.find(_.mode != ProcedureParameter.Mode.IN).foreach { param => + throw SparkException.internalError(s"Unsupported parameter mode: ${param.mode}") + } + } + } + /** * This rule resolves and rewrites subqueries inside expressions. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala index 17b1c4e249f57..3afe0ec8e9a7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala @@ -77,6 +77,7 @@ object AnsiTypeCoercion extends TypeCoercionBase { override def typeCoercionRules: List[Rule[LogicalPlan]] = UnpivotCoercion :: WidenSetOperationTypes :: + ProcedureArgumentCoercion :: new AnsiCombinedTypeCoercionRule( CollationTypeCasts :: InConversion :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 752ff49e1f90d..5a9d5cd87ecc7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -676,6 +676,14 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB varName, c.defaultExpr.originalSQL) + case c: Call if c.resolved && c.bound && c.checkArgTypes().isFailure => + c.checkArgTypes() match { + case mismatch: TypeCheckResult.DataTypeMismatch => + c.dataTypeMismatch("CALL", mismatch) + case _ => + throw SparkException.internalError("Invalid input for procedure") + } + case _ => // Falls back to the following checks } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 08c5b3531b4c8..5983346ff1e27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.AlwaysProcess import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.types.{AbstractArrayType, AbstractMapType, AbstractStringType, StringTypeAnyCollation} @@ -202,6 +203,20 @@ abstract class TypeCoercionBase { } } + /** + * A type coercion rule that implicitly casts procedure arguments to expected types. + */ + object ProcedureArgumentCoercion extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case c @ Call(ResolvedProcedure(_, _, procedure: BoundProcedure), args, _) if c.resolved => + val expectedDataTypes = procedure.parameters.map(_.dataType) + val coercedArgs = args.zip(expectedDataTypes).map { + case (arg, expectedType) => implicitCast(arg, expectedType).getOrElse(arg) + } + c.copy(args = coercedArgs) + } + } + /** * Widens the data types of the [[Unpivot]] values. */ @@ -838,6 +853,7 @@ object TypeCoercion extends TypeCoercionBase { override def typeCoercionRules: List[Rule[LogicalPlan]] = UnpivotCoercion :: WidenSetOperationTypes :: + ProcedureArgumentCoercion :: new CombinedTypeCoercionRule( CollationTypeCasts :: InConversion :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala index c0689eb121679..daab9e4d78bf5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala @@ -67,9 +67,13 @@ package object analysis { } def dataTypeMismatch(expr: Expression, mismatch: DataTypeMismatch): Nothing = { + dataTypeMismatch(toSQLExpr(expr), mismatch) + } + + def dataTypeMismatch(sqlExpr: String, mismatch: DataTypeMismatch): Nothing = { throw new AnalysisException( errorClass = s"DATATYPE_MISMATCH.${mismatch.errorSubClass}", - messageParameters = mismatch.messageParameters + ("sqlExpr" -> toSQLExpr(expr)), + messageParameters = mismatch.messageParameters + ("sqlExpr" -> sqlExpr), origin = t.origin) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala index ecdf40e87a894..dee78b8f03af4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala @@ -23,13 +23,14 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, LeafExpression, Unevaluable} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} -import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UNRESOLVED_FUNC} +import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UNRESOLVED_FUNC, UNRESOLVED_PROCEDURE} import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.CharVarcharUtils -import org.apache.spark.sql.connector.catalog.{CatalogPlugin, FunctionCatalog, Identifier, Table, TableCatalog} +import org.apache.spark.sql.connector.catalog.{CatalogPlugin, FunctionCatalog, Identifier, ProcedureCatalog, Table, TableCatalog} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition import org.apache.spark.sql.connector.catalog.functions.UnboundFunction +import org.apache.spark.sql.connector.catalog.procedures.Procedure import org.apache.spark.sql.types.{DataType, StructField} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.ArrayImplicits._ @@ -135,6 +136,12 @@ case class UnresolvedFunctionName( case class UnresolvedIdentifier(nameParts: Seq[String], allowTemp: Boolean = false) extends UnresolvedLeafNode +/** + * A procedure identifier that should be resolved into [[ResolvedProcedure]]. + */ +case class UnresolvedProcedure(nameParts: Seq[String]) extends UnresolvedLeafNode { + final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_PROCEDURE) +} /** * A resolved leaf node whose statistics has no meaning. @@ -192,6 +199,12 @@ case class ResolvedFieldName(path: Seq[String], field: StructField) extends Fiel case class ResolvedFieldPosition(position: ColumnPosition) extends FieldPosition +case class ResolvedProcedure( + catalog: ProcedureCatalog, + ident: Identifier, + procedure: Procedure) extends LeafNodeWithoutStats { + override def output: Seq[Attribute] = Nil +} /** * A plan containing resolved persistent views. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index cb0e0e35c3704..52529bb4b789b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -5697,6 +5697,28 @@ class AstBuilder extends DataTypeAstBuilder ctx.EXISTS != null) } + /** + * Creates a plan for invoking a procedure. + * + * For example: + * {{{ + * CALL multi_part_name(v1, v2, ...); + * CALL multi_part_name(v1, param2 => v2, ...); + * CALL multi_part_name(param1 => v1, param2 => v2, ...); + * }}} + */ + override def visitCall(ctx: CallContext): LogicalPlan = withOrigin(ctx) { + val procedure = withIdentClause(ctx.identifierReference, UnresolvedProcedure) + val args = ctx.functionArgument.asScala.map { + case expr if expr.namedArgumentExpression != null => + val namedExpr = expr.namedArgumentExpression + NamedArgumentExpression(namedExpr.key.getText, expression(namedExpr.value)) + case expr => + expression(expr) + }.toSeq + Call(procedure, args) + } + /** * Create a TimestampAdd expression. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ExecutableDuringAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ExecutableDuringAnalysis.scala new file mode 100644 index 0000000000000..dc8dbf701f6a9 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ExecutableDuringAnalysis.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +/** + * A logical plan node that requires execution during analysis. + */ +trait ExecutableDuringAnalysis extends LogicalPlan { + /** + * Returns the logical plan node that should be used for EXPLAIN. + */ + def stageForExplain(): LogicalPlan +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala index 4701f4ea1e172..75b2fcd3a5f34 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions.{Expression, NamedArgumentExpression} +import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns +import org.apache.spark.sql.connector.catalog.procedures.{BoundProcedure, ProcedureParameter} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.util.ArrayImplicits._ @@ -122,12 +124,32 @@ object NamedParametersSupport { functionSignature: FunctionSignature, args: Seq[Expression], functionName: String): Seq[Expression] = { - val parameters: Seq[InputParameter] = functionSignature.parameters + defaultRearrange(functionName, functionSignature.parameters, args) + } + + final def defaultRearrange(procedure: BoundProcedure, args: Seq[Expression]): Seq[Expression] = { + defaultRearrange( + procedure.name, + procedure.parameters.map(toInputParameter).toSeq, + args) + } + + private def toInputParameter(param: ProcedureParameter): InputParameter = { + val defaultValue = Option(param.defaultValueExpression).map { expr => + ResolveDefaultColumns.analyze(param.name, param.dataType, expr, "CALL") + } + InputParameter(param.name, defaultValue) + } + + private def defaultRearrange( + routineName: String, + parameters: Seq[InputParameter], + args: Seq[Expression]): Seq[Expression] = { if (parameters.dropWhile(_.default.isEmpty).exists(_.default.isEmpty)) { - throw QueryCompilationErrors.unexpectedRequiredParameter(functionName, parameters) + throw QueryCompilationErrors.unexpectedRequiredParameter(routineName, parameters) } - val (positionalArgs, namedArgs) = splitAndCheckNamedArguments(args, functionName) + val (positionalArgs, namedArgs) = splitAndCheckNamedArguments(args, routineName) val namedParameters: Seq[InputParameter] = parameters.drop(positionalArgs.size) // The following loop checks for the following: @@ -140,12 +162,12 @@ object NamedParametersSupport { namedArgs.foreach { namedArg => val parameterName = namedArg.key if (!parameterNamesSet.contains(parameterName)) { - throw QueryCompilationErrors.unrecognizedParameterName(functionName, namedArg.key, + throw QueryCompilationErrors.unrecognizedParameterName(routineName, namedArg.key, parameterNamesSet.toSeq) } if (positionalParametersSet.contains(parameterName)) { throw QueryCompilationErrors.positionalAndNamedArgumentDoubleReference( - functionName, namedArg.key) + routineName, namedArg.key) } } @@ -154,7 +176,7 @@ object NamedParametersSupport { val validParameterSizes = Array.range(parameters.count(_.default.isEmpty), parameters.size + 1).toImmutableArraySeq throw QueryCompilationErrors.wrongNumArgsError( - functionName, validParameterSizes, args.length) + routineName, validParameterSizes, args.length) } // This constructs a map from argument name to value for argument rearrangement. @@ -168,7 +190,7 @@ object NamedParametersSupport { namedArgMap.getOrElse( param.name, if (param.default.isEmpty) { - throw QueryCompilationErrors.requiredParameterNotFound(functionName, param.name, index) + throw QueryCompilationErrors.requiredParameterNotFound(routineName, param.name, index) } else { param.default.get } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MultiResult.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MultiResult.scala new file mode 100644 index 0000000000000..f249e5c87eba2 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MultiResult.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute + +case class MultiResult(children: Seq[LogicalPlan]) extends LogicalPlan { + + override def output: Seq[Attribute] = children.lastOption.map(_.output).getOrElse(Nil) + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[LogicalPlan]): MultiResult = { + copy(children = newChildren) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index fdd43404e1d98..b465e0e11612f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -19,17 +19,22 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.{SparkIllegalArgumentException, SparkUnsupportedOperationException} import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AssignmentUtils, EliminateSubqueryAliases, FieldName, NamedRelation, PartitionSpec, ResolvedIdentifier, UnresolvedException, ViewSchemaMode} +import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AssignmentUtils, EliminateSubqueryAliases, FieldName, NamedRelation, PartitionSpec, ResolvedIdentifier, ResolvedProcedure, TypeCheckResult, UnresolvedException, UnresolvedProcedure, ViewSchemaMode} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.catalog.FunctionResource import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, MetadataAttribute, NamedExpression, UnaryExpression, Unevaluable, V2ExpressionUtils} import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema import org.apache.spark.sql.catalyst.trees.BinaryLike -import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, RowDeltaUtils, WriteDeltaProjections} +import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, truncatedString, CharVarcharUtils, RowDeltaUtils, WriteDeltaProjections} +import org.apache.spark.sql.catalyst.util.TypeUtils.{ordinalNumber, toSQLExpr} import org.apache.spark.sql.connector.catalog._ +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.{IdentifierHelper, MultipartIdentifierHelper} +import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.write.{DeltaWrite, RowLevelOperation, RowLevelOperationTable, SupportsDelta, Write} +import org.apache.spark.sql.errors.DataTypeErrors.toSQLType import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, MapType, MetadataBuilder, StringType, StructField, StructType} import org.apache.spark.util.ArrayImplicits._ @@ -1571,3 +1576,61 @@ case class SetVariable( override protected def withNewChildInternal(newChild: LogicalPlan): SetVariable = copy(sourceQuery = newChild) } + +/** + * The logical plan of the CALL statement. + */ +case class Call( + procedure: LogicalPlan, + args: Seq[Expression], + execute: Boolean = true) + extends UnaryNode with ExecutableDuringAnalysis { + + override def output: Seq[Attribute] = Nil + + override def child: LogicalPlan = procedure + + def bound: Boolean = procedure match { + case ResolvedProcedure(_, _, _: BoundProcedure) => true + case _ => false + } + + def checkArgTypes(): TypeCheckResult = { + require(resolved && bound, "can check arg types only after resolution and binding") + + val params = procedure match { + case ResolvedProcedure(_, _, bound: BoundProcedure) => bound.parameters + } + require(args.length == params.length, "number of args and params must match after binding") + + args.zip(params).zipWithIndex.collectFirst { + case ((arg, param), idx) + if !DataType.equalsIgnoreCompatibleNullability(arg.dataType, param.dataType) => + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> ordinalNumber(idx), + "requiredType" -> toSQLType(param.dataType), + "inputSql" -> toSQLExpr(arg), + "inputType" -> toSQLType(arg.dataType))) + }.getOrElse(TypeCheckSuccess) + } + + override def simpleString(maxFields: Int): String = { + val name = procedure match { + case ResolvedProcedure(catalog, ident, _) => + s"${quoteIfNeeded(catalog.name)}.${ident.quoted}" + case UnresolvedProcedure(nameParts) => + nameParts.quoted + } + val argsString = truncatedString(args, ", ", maxFields) + s"Call $name($argsString)" + } + + override def stageForExplain(): Call = { + copy(execute = false) + } + + override protected def withNewChildInternal(newChild: LogicalPlan): Call = + copy(procedure = newChild) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala index c70b43f0db173..b5556cbae7cd9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala @@ -54,6 +54,7 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveDeserializer" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveEncodersInUDF" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveFunctions" :: + "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveProcedures" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveGenerate" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveGroupingAnalytics" :: "org.apache.spark.sql.catalyst.analysis.ResolveHigherOrderFunctions" :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala index 826ac52c2b817..0f1c98b53e0b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -157,6 +157,7 @@ object TreePattern extends Enumeration { // Unresolved Plan patterns (Alphabetically ordered) val UNRESOLVED_FUNC: Value = Value + val UNRESOLVED_PROCEDURE: Value = Value val UNRESOLVED_SUBQUERY_COLUMN_ALIAS: Value = Value val UNRESOLVED_TABLE_VALUED_FUNCTION: Value = Value val UNRESOLVED_TRANSPOSE: Value = Value diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala index 65bdae85be12a..282350dda67d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala @@ -126,6 +126,13 @@ private[sql] object CatalogV2Implicits { case _ => throw QueryCompilationErrors.missingCatalogAbilityError(plugin, "functions") } + + def asProcedureCatalog: ProcedureCatalog = plugin match { + case procedureCatalog: ProcedureCatalog => + procedureCatalog + case _ => + throw QueryCompilationErrors.missingCatalogAbilityError(plugin, "procedures") + } } implicit class NamespaceHelper(namespace: Array[String]) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index ad0e1d07bf93d..0b5255e95f073 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -853,6 +853,13 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat origin = origin) } + def failedToLoadRoutineError(nameParts: Seq[String], e: Exception): Throwable = { + new AnalysisException( + errorClass = "FAILED_TO_LOAD_ROUTINE", + messageParameters = Map("routineName" -> toSQLId(nameParts)), + cause = Some(e)) + } + def unresolvedRoutineError(name: FunctionIdentifier, searchPath: Seq[String]): Throwable = { new AnalysisException( errorClass = "UNRESOLVED_ROUTINE", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryCatalog.scala index 8d8d2317f0986..411a88b8765f6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryCatalog.scala @@ -24,10 +24,13 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.sql.catalyst.analysis.{NoSuchFunctionException, NoSuchNamespaceException} import org.apache.spark.sql.connector.catalog.functions.UnboundFunction +import org.apache.spark.sql.connector.catalog.procedures.UnboundProcedure -class InMemoryCatalog extends InMemoryTableCatalog with FunctionCatalog { +class InMemoryCatalog extends InMemoryTableCatalog with FunctionCatalog with ProcedureCatalog { protected val functions: util.Map[Identifier, UnboundFunction] = new ConcurrentHashMap[Identifier, UnboundFunction]() + protected val procedures: util.Map[Identifier, UnboundProcedure] = + new ConcurrentHashMap[Identifier, UnboundProcedure]() override protected def allNamespaces: Seq[Seq[String]] = { (tables.keySet.asScala.map(_.namespace.toSeq) ++ @@ -63,4 +66,18 @@ class InMemoryCatalog extends InMemoryTableCatalog with FunctionCatalog { def clearFunctions(): Unit = { functions.clear() } + + override def loadProcedure(ident: Identifier): UnboundProcedure = { + val procedure = procedures.get(ident) + if (procedure == null) throw new RuntimeException("Procedure not found: " + ident) + procedure + } + + def createProcedure(ident: Identifier, procedure: UnboundProcedure): UnboundProcedure = { + procedures.put(ident, procedure) + } + + def clearProcedures(): Unit = { + procedures.clear() + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/InvokeProcedures.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/InvokeProcedures.scala new file mode 100644 index 0000000000000..c7320d350a7ff --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/InvokeProcedures.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import scala.jdk.CollectionConverters.IteratorHasAsScala + +import org.apache.spark.SparkException +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow} +import org.apache.spark.sql.catalyst.plans.logical.{Call, LocalRelation, LogicalPlan, MultiResult} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure +import org.apache.spark.sql.connector.read.{LocalScan, Scan} +import org.apache.spark.util.ArrayImplicits._ + +class InvokeProcedures(session: SparkSession) extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case c: Call if c.resolved && c.bound && c.execute && c.checkArgTypes().isSuccess => + session.sessionState.optimizer.execute(c) match { + case Call(ResolvedProcedure(_, _, procedure: BoundProcedure), args, _) => + invoke(procedure, args) + case _ => + throw SparkException.internalError("Unexpected plan for optimized CALL statement") + } + } + + private def invoke(procedure: BoundProcedure, args: Seq[Expression]): LogicalPlan = { + val input = toInternalRow(args) + val scanIterator = procedure.call(input) + val relations = scanIterator.asScala.map(toRelation).toSeq + relations match { + case Nil => LocalRelation(Nil) + case Seq(relation) => relation + case _ => MultiResult(relations) + } + } + + private def toRelation(scan: Scan): LogicalPlan = scan match { + case s: LocalScan => + val attrs = DataTypeUtils.toAttributes(s.readSchema) + val data = s.rows.toImmutableArraySeq + LocalRelation(attrs, data) + case _ => + throw SparkException.internalError( + s"Only local scans are temporarily supported as procedure output: ${scan.getClass.getName}") + } + + private def toInternalRow(args: Seq[Expression]): InternalRow = { + require(args.forall(_.foldable), "args must be foldable") + val values = args.map(_.eval()).toArray + new GenericInternalRow(values) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/MultiResultExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/MultiResultExec.scala new file mode 100644 index 0000000000000..c2b12b053c927 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/MultiResultExec.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute + +case class MultiResultExec(children: Seq[SparkPlan]) extends SparkPlan { + + override def output: Seq[Attribute] = children.lastOption.map(_.output).getOrElse(Nil) + + override protected def doExecute(): RDD[InternalRow] = { + children.lastOption.map(_.execute()).getOrElse(sparkContext.emptyRDD) + } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[SparkPlan]): MultiResultExec = { + copy(children = newChildren) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 6d940a30619fb..aee735e48fc5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -1041,6 +1041,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case WriteFiles(child, fileFormat, partitionColumns, bucket, options, staticPartitions) => WriteFilesExec(planLater(child), fileFormat, partitionColumns, bucket, options, staticPartitions) :: Nil + case MultiResult(children) => + MultiResultExec(children.map(planLater)) :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index ea2736b2c1266..ea9d53190546e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan, SupervisingCommand} +import org.apache.spark.sql.catalyst.plans.logical.{Command, ExecutableDuringAnalysis, LogicalPlan, SupervisingCommand} import org.apache.spark.sql.catalyst.trees.{LeafLike, UnaryLike} import org.apache.spark.sql.connector.ExternalCommandRunner import org.apache.spark.sql.execution.{CommandExecutionMode, ExplainMode, LeafExecNode, SparkPlan, UnaryExecNode} @@ -165,14 +165,19 @@ case class ExplainCommand( // Run through the optimizer to generate the physical plan. override def run(sparkSession: SparkSession): Seq[Row] = try { - val outputString = sparkSession.sessionState.executePlan(logicalPlan, CommandExecutionMode.SKIP) - .explainString(mode) + val stagedLogicalPlan = stageForAnalysis(logicalPlan) + val qe = sparkSession.sessionState.executePlan(stagedLogicalPlan, CommandExecutionMode.SKIP) + val outputString = qe.explainString(mode) Seq(Row(outputString)) } catch { case NonFatal(cause) => ("Error occurred during query planning: \n" + cause.getMessage).split("\n") .map(Row(_)).toImmutableArraySeq } + private def stageForAnalysis(plan: LogicalPlan): LogicalPlan = plan transform { + case p: ExecutableDuringAnalysis => p.stageForExplain() + } + def withTransformedSupervisedPlan(transformer: LogicalPlan => LogicalPlan): LogicalPlan = copy(logicalPlan = transformer(logicalPlan)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index d7f46c32f99a0..76cd33b815edd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -32,8 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruning import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.{toPrettySQL, GeneratedColumn, - IdentityColumn, ResolveDefaultColumns, V2ExpressionBuilder} +import org.apache.spark.sql.catalyst.util.{toPrettySQL, GeneratedColumn, IdentityColumn, ResolveDefaultColumns, V2ExpressionBuilder} import org.apache.spark.sql.connector.catalog.{Identifier, StagingTableCatalog, SupportsDeleteV2, SupportsNamespaces, SupportsPartitionManagement, SupportsWrite, Table, TableCapability, TableCatalog, TruncatableTable} import org.apache.spark.sql.connector.catalog.index.SupportsIndex import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue} @@ -554,6 +553,9 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat systemScope, pattern) :: Nil + case c: Call => + ExplainOnlySparkPlan(c) :: Nil + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExplainOnlySparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExplainOnlySparkPlan.scala new file mode 100644 index 0000000000000..bbf56eaa71184 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExplainOnlySparkPlan.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.trees.LeafLike +import org.apache.spark.sql.execution.SparkPlan + +case class ExplainOnlySparkPlan(toExplain: LogicalPlan) extends SparkPlan with LeafLike[SparkPlan] { + + override def output: Seq[Attribute] = Nil + + override def simpleString(maxFields: Int): String = { + toExplain.simpleString(maxFields) + } + + override protected def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index a2539828733fc..0d0258f11efb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.internal import org.apache.spark.annotation.Unstable import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _} import org.apache.spark.sql.artifact.ArtifactManager -import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, FunctionRegistry, ReplaceCharWithVarchar, ResolveSessionCatalog, ResolveTranspose, TableFunctionRegistry} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, FunctionRegistry, InvokeProcedures, ReplaceCharWithVarchar, ResolveSessionCatalog, ResolveTranspose, TableFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{FunctionExpressionBuilder, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.optimizer.Optimizer @@ -206,6 +206,7 @@ abstract class BaseSessionStateBuilder( ResolveWriteToStream +: new EvalSubqueriesForTimeTravel +: new ResolveTranspose(session) +: + new InvokeProcedures(session) +: customResolutionRules override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out index 6497a46c68ccd..7c694503056ab 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out @@ -32,6 +32,7 @@ BUCKETS false BY false BYTE false CACHE false +CALL true CALLED false CASCADE false CASE true @@ -378,6 +379,7 @@ ANY AS AUTHORIZATION BOTH +CALL CASE CAST CHECK diff --git a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out index 0dfd62599afa6..2c16d961b1313 100644 --- a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out @@ -32,6 +32,7 @@ BUCKETS false BY false BYTE false CACHE false +CALL false CALLED false CASCADE false CASE false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/ProcedureSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/ProcedureSuite.scala new file mode 100644 index 0000000000000..e39a1b7ea340a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/ProcedureSuite.scala @@ -0,0 +1,654 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +import java.util.Collections + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SPARK_DOC_ROOT, SparkException, SparkNumberFormatException} +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId +import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, InMemoryCatalog} +import org.apache.spark.sql.connector.catalog.procedures.{BoundProcedure, ProcedureParameter, UnboundProcedure} +import org.apache.spark.sql.connector.catalog.procedures.ProcedureParameter.Mode +import org.apache.spark.sql.connector.catalog.procedures.ProcedureParameter.Mode.{IN, INOUT, OUT} +import org.apache.spark.sql.connector.read.{LocalScan, Scan} +import org.apache.spark.sql.errors.DataTypeErrors.{toSQLType, toSQLValue} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType} +import org.apache.spark.unsafe.types.UTF8String + +class ProcedureSuite extends QueryTest with SharedSparkSession with BeforeAndAfter { + + before { + spark.conf.set(s"spark.sql.catalog.cat", classOf[InMemoryCatalog].getName) + } + + after { + spark.sessionState.catalogManager.reset() + spark.sessionState.conf.unsetConf(s"spark.sql.catalog.cat") + } + + private def catalog: InMemoryCatalog = { + val catalog = spark.sessionState.catalogManager.catalog("cat") + catalog.asInstanceOf[InMemoryCatalog] + } + + test("position arguments") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkAnswer(sql("CALL cat.ns.sum(5, 5)"), Row(10) :: Nil) + } + + test("named arguments") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkAnswer(sql("CALL cat.ns.sum(in2 => 3, in1 => 5)"), Row(8) :: Nil) + } + + test("position and named arguments") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkAnswer(sql("CALL cat.ns.sum(3, in2 => 1)"), Row(4) :: Nil) + } + + test("foldable expressions") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkAnswer(sql("CALL cat.ns.sum(1 + 1, in2 => 2)"), Row(4) :: Nil) + checkAnswer(sql("CALL cat.ns.sum(in2 => 1, in1 => 2 + 1)"), Row(4) :: Nil) + checkAnswer(sql("CALL cat.ns.sum((1 + 1) * 2, in2 => (2 + 1) / 3)"), Row(5) :: Nil) + } + + test("type coercion") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundLongSum) + checkAnswer(sql("CALL cat.ns.sum(1, 2)"), Row(3) :: Nil) + checkAnswer(sql("CALL cat.ns.sum(1L, 2)"), Row(3) :: Nil) + checkAnswer(sql("CALL cat.ns.sum(1, 2L)"), Row(3) :: Nil) + } + + test("multiple output rows") { + catalog.createProcedure(Identifier.of(Array("ns"), "complex"), UnboundComplexProcedure) + checkAnswer( + sql("CALL cat.ns.complex('X', 'Y', 3)"), + Row(1, "X1", "Y1") :: Row(2, "X2", "Y2") :: Row(3, "X3", "Y3") :: Nil) + } + + test("parameters with default values") { + catalog.createProcedure(Identifier.of(Array("ns"), "complex"), UnboundComplexProcedure) + checkAnswer(sql("CALL cat.ns.complex()"), Row(1, "A1", "B1") :: Nil) + checkAnswer(sql("CALL cat.ns.complex('X', 'Y')"), Row(1, "X1", "Y1") :: Nil) + } + + test("parameters with invalid default values") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundInvalidDefaultProcedure) + checkError( + exception = intercept[AnalysisException]( + sql("CALL cat.ns.sum()") + ), + condition = "INVALID_DEFAULT_VALUE.DATA_TYPE", + parameters = Map( + "statement" -> "CALL", + "colName" -> toSQLId("in2"), + "defaultValue" -> toSQLValue("B"), + "expectedType" -> toSQLType("INT"), + "actualType" -> toSQLType("STRING"))) + } + + test("IDENTIFIER") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkAnswer( + spark.sql("CALL IDENTIFIER(:p1)(1, 2)", Map("p1" -> "cat.ns.sum")), + Row(3) :: Nil) + } + + test("parameterized statements") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkAnswer( + spark.sql("CALL cat.ns.sum(?, ?)", Array(2, 3)), + Row(5) :: Nil) + } + + test("undefined procedure") { + checkError( + exception = intercept[AnalysisException]( + sql("CALL cat.non_exist(1, 2)") + ), + sqlState = Some("38000"), + condition = "FAILED_TO_LOAD_ROUTINE", + parameters = Map("routineName" -> "`cat`.`non_exist`") + ) + } + + test("non-procedure catalog") { + withSQLConf("spark.sql.catalog.testcat" -> classOf[BasicInMemoryTableCatalog].getName) { + checkError( + exception = intercept[AnalysisException]( + sql("CALL testcat.procedure(1, 2)") + ), + condition = "_LEGACY_ERROR_TEMP_1184", + parameters = Map("plugin" -> "testcat", "ability" -> "procedures") + ) + } + } + + test("too many arguments") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkError( + exception = intercept[AnalysisException]( + sql("CALL cat.ns.sum(1, 2, 3)") + ), + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + parameters = Map( + "functionName" -> toSQLId("sum"), + "expectedNum" -> "2", + "actualNum" -> "3", + "docroot" -> SPARK_DOC_ROOT)) + } + + test("custom default catalog") { + withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "cat") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + val df = sql("CALL ns.sum(1, 2)") + checkAnswer(df, Row(3) :: Nil) + } + } + + test("custom default catalog and namespace") { + withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "cat") { + catalog.createNamespace(Array("ns"), Collections.emptyMap) + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + sql("USE ns") + val df = sql("CALL sum(1, 2)") + checkAnswer(df, Row(3) :: Nil) + } + } + + test("required parameter not found") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkError( + exception = intercept[AnalysisException] { + sql("CALL cat.ns.sum()") + }, + condition = "REQUIRED_PARAMETER_NOT_FOUND", + parameters = Map( + "routineName" -> toSQLId("sum"), + "parameterName" -> toSQLId("in1"), + "index" -> "0")) + } + + test("conflicting position and named parameter assignments") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkError( + exception = intercept[AnalysisException] { + sql("CALL cat.ns.sum(1, in1 => 2)") + }, + condition = "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.BOTH_POSITIONAL_AND_NAMED", + parameters = Map( + "routineName" -> toSQLId("sum"), + "parameterName" -> toSQLId("in1"))) + } + + test("duplicate named parameter assignments") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkError( + exception = intercept[AnalysisException] { + sql("CALL cat.ns.sum(in1 => 1, in1 => 2)") + }, + condition = "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE", + parameters = Map( + "routineName" -> toSQLId("sum"), + "parameterName" -> toSQLId("in1"))) + } + + test("unknown parameter name") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkError( + exception = intercept[AnalysisException] { + sql("CALL cat.ns.sum(in1 => 1, in5 => 2)") + }, + condition = "UNRECOGNIZED_PARAMETER_NAME", + parameters = Map( + "routineName" -> toSQLId("sum"), + "argumentName" -> toSQLId("in5"), + "proposal" -> (toSQLId("in1") + " " + toSQLId("in2")))) + } + + test("position parameter after named parameter") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkError( + exception = intercept[AnalysisException] { + sql("CALL cat.ns.sum(in1 => 1, 2)") + }, + condition = "UNEXPECTED_POSITIONAL_ARGUMENT", + parameters = Map( + "routineName" -> toSQLId("sum"), + "parameterName" -> toSQLId("in1"))) + } + + test("invalid argument type") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + val call = "CALL cat.ns.sum(1, TIMESTAMP '2016-11-15 20:54:00.000')" + checkError( + exception = intercept[AnalysisException] { + sql(call) + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "sqlExpr" -> "CALL", + "paramIndex" -> "second", + "inputSql" -> "\"TIMESTAMP '2016-11-15 20:54:00'\"", + "inputType" -> toSQLType("TIMESTAMP"), + "requiredType" -> toSQLType("INT")), + context = ExpectedContext(fragment = call, start = 0, stop = call.length - 1)) + } + + test("malformed input to implicit cast") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + val call = "CALL cat.ns.sum('A', 2)" + checkError( + exception = intercept[SparkNumberFormatException]( + sql(call) + ), + condition = "CAST_INVALID_INPUT", + parameters = Map( + "expression" -> toSQLValue("A"), + "sourceType" -> toSQLType("STRING"), + "targetType" -> toSQLType("INT")), + context = ExpectedContext(fragment = call, start = 0, stop = call.length - 1)) + } + + test("required parameters after optional") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundInvalidSum) + val e = intercept[SparkException] { + sql("CALL cat.ns.sum(in2 => 1)") + } + assert(e.getMessage.contains("required arguments should come before optional arguments")) + } + + test("INOUT parameters are not supported") { + catalog.createProcedure(Identifier.of(Array("ns"), "procedure"), UnboundInoutProcedure) + val e = intercept[SparkException] { + sql("CALL cat.ns.procedure(1)") + } + assert(e.getMessage.contains(" Unsupported parameter mode: INOUT")) + } + + test("OUT parameters are not supported") { + catalog.createProcedure(Identifier.of(Array("ns"), "procedure"), UnboundOutProcedure) + val e = intercept[SparkException] { + sql("CALL cat.ns.procedure(1)") + } + assert(e.getMessage.contains("Unsupported parameter mode: OUT")) + } + + test("EXPLAIN") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundNonExecutableSum) + val explain1 = sql("EXPLAIN CALL cat.ns.sum(5, 5)").head().get(0) + assert(explain1.toString.contains("cat.ns.sum(5, 5)")) + val explain2 = sql("EXPLAIN EXTENDED CALL cat.ns.sum(10, 10)").head().get(0) + assert(explain2.toString.contains("cat.ns.sum(10, 10)")) + } + + test("void procedure") { + catalog.createProcedure(Identifier.of(Array("ns"), "proc"), UnboundVoidProcedure) + checkAnswer(sql("CALL cat.ns.proc('A', 'B')"), Nil) + } + + test("multi-result procedure") { + catalog.createProcedure(Identifier.of(Array("ns"), "proc"), UnboundMultiResultProcedure) + checkAnswer(sql("CALL cat.ns.proc()"), Row("last") :: Nil) + } + + test("invalid input to struct procedure") { + catalog.createProcedure(Identifier.of(Array("ns"), "proc"), UnboundStructProcedure) + val actualType = + StructType(Seq( + StructField("X", DataTypes.DateType, nullable = false), + StructField("Y", DataTypes.IntegerType, nullable = false))) + val expectedType = StructProcedure.parameters.head.dataType + val call = "CALL cat.ns.proc(named_struct('X', DATE '2011-11-11', 'Y', 2), 'VALUE')" + checkError( + exception = intercept[AnalysisException](sql(call)), + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "sqlExpr" -> "CALL", + "paramIndex" -> "first", + "inputSql" -> "\"named_struct(X, DATE '2011-11-11', Y, 2)\"", + "inputType" -> toSQLType(actualType), + "requiredType" -> toSQLType(expectedType)), + context = ExpectedContext(fragment = call, start = 0, stop = call.length - 1)) + } + + test("save execution summary") { + withTable("summary") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + val result = sql("CALL cat.ns.sum(1, 2)") + result.write.saveAsTable("summary") + checkAnswer(spark.table("summary"), Row(3) :: Nil) + } + } + + object UnboundVoidProcedure extends UnboundProcedure { + override def name: String = "void" + override def description: String = "void procedure" + override def bind(inputType: StructType): BoundProcedure = VoidProcedure + } + + object VoidProcedure extends BoundProcedure { + override def name: String = "void" + + override def description: String = "void procedure" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + ProcedureParameter.in("in1", DataTypes.StringType).build(), + ProcedureParameter.in("in2", DataTypes.StringType).build() + ) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + Collections.emptyIterator + } + } + + object UnboundMultiResultProcedure extends UnboundProcedure { + override def name: String = "multi" + override def description: String = "multi-result procedure" + override def bind(inputType: StructType): BoundProcedure = MultiResultProcedure + } + + object MultiResultProcedure extends BoundProcedure { + override def name: String = "multi" + + override def description: String = "multi-result procedure" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array() + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + val scans = java.util.Arrays.asList[Scan]( + Result( + new StructType().add("out", DataTypes.IntegerType), + Array(InternalRow(1))), + Result( + new StructType().add("out", DataTypes.StringType), + Array(InternalRow(UTF8String.fromString("last")))) + ) + scans.iterator() + } + } + + object UnboundNonExecutableSum extends UnboundProcedure { + override def name: String = "sum" + override def description: String = "sum integers" + override def bind(inputType: StructType): BoundProcedure = Sum + } + + object NonExecutableSum extends BoundProcedure { + override def name: String = "sum" + + override def description: String = "sum integers" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + ProcedureParameter.in("in1", DataTypes.IntegerType).build(), + ProcedureParameter.in("in2", DataTypes.IntegerType).build() + ) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + throw new UnsupportedOperationException() + } + } + + object UnboundSum extends UnboundProcedure { + override def name: String = "sum" + override def description: String = "sum integers" + override def bind(inputType: StructType): BoundProcedure = Sum + } + + object Sum extends BoundProcedure { + override def name: String = "sum" + + override def description: String = "sum integers" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + ProcedureParameter.in("in1", DataTypes.IntegerType).build(), + ProcedureParameter.in("in2", DataTypes.IntegerType).build() + ) + + def outputType: StructType = new StructType().add("out", DataTypes.IntegerType) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + val in1 = input.getInt(0) + val in2 = input.getInt(1) + val result = Result(outputType, Array(InternalRow(in1 + in2))) + Collections.singleton[Scan](result).iterator() + } + } + + object UnboundLongSum extends UnboundProcedure { + override def name: String = "long_sum" + override def description: String = "sum longs" + override def bind(inputType: StructType): BoundProcedure = LongSum + } + + object LongSum extends BoundProcedure { + override def name: String = "long_sum" + + override def description: String = "sum longs" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + ProcedureParameter.in("in1", DataTypes.LongType).build(), + ProcedureParameter.in("in2", DataTypes.LongType).build() + ) + + def outputType: StructType = new StructType().add("out", DataTypes.LongType) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + val in1 = input.getLong(0) + val in2 = input.getLong(1) + val result = Result(outputType, Array(InternalRow(in1 + in2))) + Collections.singleton[Scan](result).iterator() + } + } + + object UnboundInvalidSum extends UnboundProcedure { + override def name: String = "invalid" + override def description: String = "sum integers" + override def bind(inputType: StructType): BoundProcedure = InvalidSum + } + + object InvalidSum extends BoundProcedure { + override def name: String = "invalid" + + override def description: String = "sum integers" + + override def isDeterministic: Boolean = false + + override def parameters: Array[ProcedureParameter] = Array( + ProcedureParameter.in("in1", DataTypes.IntegerType).defaultValue("1").build(), + ProcedureParameter.in("in2", DataTypes.IntegerType).build() + ) + + def outputType: StructType = new StructType().add("out", DataTypes.IntegerType) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + throw new UnsupportedOperationException() + } + } + + object UnboundInvalidDefaultProcedure extends UnboundProcedure { + override def name: String = "sum" + override def description: String = "invalid default value procedure" + override def bind(inputType: StructType): BoundProcedure = InvalidDefaultProcedure + } + + object InvalidDefaultProcedure extends BoundProcedure { + override def name: String = "sum" + + override def description: String = "invalid default value procedure" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + ProcedureParameter.in("in1", DataTypes.IntegerType).defaultValue("10").build(), + ProcedureParameter.in("in2", DataTypes.IntegerType).defaultValue("'B'").build() + ) + + def outputType: StructType = new StructType().add("out", DataTypes.IntegerType) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + throw new UnsupportedOperationException() + } + } + + object UnboundComplexProcedure extends UnboundProcedure { + override def name: String = "complex" + override def description: String = "complex procedure" + override def bind(inputType: StructType): BoundProcedure = ComplexProcedure + } + + object ComplexProcedure extends BoundProcedure { + override def name: String = "complex" + + override def description: String = "complex procedure" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + ProcedureParameter.in("in1", DataTypes.StringType).defaultValue("'A'").build(), + ProcedureParameter.in("in2", DataTypes.StringType).defaultValue("'B'").build(), + ProcedureParameter.in("in3", DataTypes.IntegerType).defaultValue("1 + 1 - 1").build() + ) + + def outputType: StructType = new StructType() + .add("out1", DataTypes.IntegerType) + .add("out2", DataTypes.StringType) + .add("out3", DataTypes.StringType) + + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + val in1 = input.getString(0) + val in2 = input.getString(1) + val in3 = input.getInt(2) + + val rows = (1 to in3).map { index => + val v1 = UTF8String.fromString(s"$in1$index") + val v2 = UTF8String.fromString(s"$in2$index") + InternalRow(index, v1, v2) + }.toArray + + val result = Result(outputType, rows) + Collections.singleton[Scan](result).iterator() + } + } + + object UnboundStructProcedure extends UnboundProcedure { + override def name: String = "struct_input" + override def description: String = "struct procedure" + override def bind(inputType: StructType): BoundProcedure = StructProcedure + } + + object StructProcedure extends BoundProcedure { + override def name: String = "struct_input" + + override def description: String = "struct procedure" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + ProcedureParameter + .in( + "in1", + StructType(Seq( + StructField("nested1", DataTypes.IntegerType), + StructField("nested2", DataTypes.StringType)))) + .build(), + ProcedureParameter.in("in2", DataTypes.StringType).build() + ) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + Collections.emptyIterator + } + } + + object UnboundInoutProcedure extends UnboundProcedure { + override def name: String = "procedure" + override def description: String = "inout procedure" + override def bind(inputType: StructType): BoundProcedure = InoutProcedure + } + + object InoutProcedure extends BoundProcedure { + override def name: String = "procedure" + + override def description: String = "inout procedure" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + CustomParameterImpl(INOUT, "in1", DataTypes.IntegerType) + ) + + def outputType: StructType = new StructType().add("out", DataTypes.IntegerType) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + throw new UnsupportedOperationException() + } + } + + object UnboundOutProcedure extends UnboundProcedure { + override def name: String = "procedure" + override def description: String = "out procedure" + override def bind(inputType: StructType): BoundProcedure = OutProcedure + } + + object OutProcedure extends BoundProcedure { + override def name: String = "procedure" + + override def description: String = "out procedure" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + CustomParameterImpl(IN, "in1", DataTypes.IntegerType), + CustomParameterImpl(OUT, "out1", DataTypes.IntegerType) + ) + + def outputType: StructType = new StructType().add("out", DataTypes.IntegerType) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + throw new UnsupportedOperationException() + } + } + + case class Result(readSchema: StructType, rows: Array[InternalRow]) extends LocalScan + + case class CustomParameterImpl( + mode: Mode, + name: String, + dataType: DataType) extends ProcedureParameter { + override def defaultValueExpression: String = null + override def comment: String = null + } +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala index 4bc4116a23da7..dcf3bd8c71731 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala @@ -214,7 +214,7 @@ trait ThriftServerWithSparkContextSuite extends SharedThriftServer { val sessionHandle = client.openSession(user, "") val infoValue = client.getInfo(sessionHandle, GetInfoType.CLI_ODBC_KEYWORDS) // scalastyle:off line.size.limit - assert(infoValue.getStringValue == "ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONSTRAINT,CONTAINS,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DECLARE,DEFAULT,DEFINED,DEFINER,DELETE,DELIMITED,DESC,DESCRIBE,DETERMINISTIC,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DO,DOUBLE,DROP,ELSE,END,ESCAPE,ESCAPED,EVOLUTION,EXCEPT,EXCHANGE,EXCLUDE,EXECUTE,EXISTS,EXPLAIN,EXPORT,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GLOBAL,GRANT,GROUP,GROUPING,HAVING,HOUR,HOURS,IDENTIFIER,IDENTITY,IF,IGNORE,ILIKE,IMMEDIATE,IMPORT,IN,INCLUDE,INCREMENT,INDEX,INDEXES,INNER,INPATH,INPUT,INPUTFORMAT,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,INVOKER,IS,ITEMS,ITERATE,JOIN,KEYS,LANGUAGE,LAST,LATERAL,LAZY,LEADING,LEAVE,LEFT,LIKE,LIMIT,LINES,LIST,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,MACRO,MAP,MATCHED,MERGE,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MODIFIES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NO,NONE,NOT,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,PARTITIONS,PERCENT,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROPERTIES,PURGE,QUARTER,QUERY,RANGE,READS,REAL,RECORDREADER,RECORDWRITER,RECOVER,REDUCE,REFERENCES,REFRESH,RENAME,REPAIR,REPEAT,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,RETURN,RETURNS,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SECURITY,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,SPECIFIC,SQL,START,STATISTICS,STORED,STRATIFY,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMEDIFF,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UNTIL,UPDATE,USE,USER,USING,VALUES,VAR,VARCHAR,VARIABLE,VARIANT,VERSION,VIEW,VIEWS,VOID,WEEK,WEEKS,WHEN,WHERE,WHILE,WINDOW,WITH,WITHIN,X,YEAR,YEARS,ZONE") + assert(infoValue.getStringValue == "ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALL,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONSTRAINT,CONTAINS,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DECLARE,DEFAULT,DEFINED,DEFINER,DELETE,DELIMITED,DESC,DESCRIBE,DETERMINISTIC,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DO,DOUBLE,DROP,ELSE,END,ESCAPE,ESCAPED,EVOLUTION,EXCEPT,EXCHANGE,EXCLUDE,EXECUTE,EXISTS,EXPLAIN,EXPORT,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GLOBAL,GRANT,GROUP,GROUPING,HAVING,HOUR,HOURS,IDENTIFIER,IDENTITY,IF,IGNORE,ILIKE,IMMEDIATE,IMPORT,IN,INCLUDE,INCREMENT,INDEX,INDEXES,INNER,INPATH,INPUT,INPUTFORMAT,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,INVOKER,IS,ITEMS,ITERATE,JOIN,KEYS,LANGUAGE,LAST,LATERAL,LAZY,LEADING,LEAVE,LEFT,LIKE,LIMIT,LINES,LIST,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,MACRO,MAP,MATCHED,MERGE,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MODIFIES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NO,NONE,NOT,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,PARTITIONS,PERCENT,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROPERTIES,PURGE,QUARTER,QUERY,RANGE,READS,REAL,RECORDREADER,RECORDWRITER,RECOVER,REDUCE,REFERENCES,REFRESH,RENAME,REPAIR,REPEAT,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,RETURN,RETURNS,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SECURITY,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,SPECIFIC,SQL,START,STATISTICS,STORED,STRATIFY,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMEDIFF,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UNTIL,UPDATE,USE,USER,USING,VALUES,VAR,VARCHAR,VARIABLE,VARIANT,VERSION,VIEW,VIEWS,VOID,WEEK,WEEKS,WHEN,WHERE,WHILE,WINDOW,WITH,WITHIN,X,YEAR,YEARS,ZONE") // scalastyle:on line.size.limit } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 44c1ecd6902ce..dbeb8607facc2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.hive.ql.exec.{UDAF, UDF} import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, ReplaceCharWithVarchar, ResolveSessionCatalog, ResolveTranspose} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, InvokeProcedures, ReplaceCharWithVarchar, ResolveSessionCatalog, ResolveTranspose} import org.apache.spark.sql.catalyst.catalog.{ExternalCatalogWithListener, InvalidUDFClassException} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -95,6 +95,7 @@ class HiveSessionStateBuilder( new EvalSubqueriesForTimeTravel +: new DetermineTableStats(session) +: new ResolveTranspose(session) +: + new InvokeProcedures(session) +: customResolutionRules override val postHocResolutionRules: Seq[Rule[LogicalPlan]] =