
Commit 492d1b1

aokolnychyi authored and cloud-fan committed
[SPARK-48782][SQL] Add support for executing procedures in catalogs
### What changes were proposed in this pull request?

This PR adds support for executing procedures in catalogs.

### Why are the changes needed?

These changes are needed per the [discussed and voted](https://lists.apache.org/thread/w586jr53fxwk4pt9m94b413xyjr1v25m) SPIP tracked in [SPARK-44167](https://issues.apache.org/jira/browse/SPARK-44167).

### Does this PR introduce _any_ user-facing change?

Yes. This PR adds CALL commands.

### How was this patch tested?

This PR comes with tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #47943 from aokolnychyi/spark-48782.

Authored-by: Anton Okolnychyi <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 3bdf146 commit 492d1b1
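As a usage sketch (the catalog `cat` and the procedure `sys.refresh_table` below are hypothetical, not procedures shipped by this patch), a connector that implements `ProcedureCatalog` can expose procedures that users invoke with the new command:

```scala
// Hypothetical invocation sketch: assumes a catalog plugin `cat` that
// implements ProcedureCatalog and exposes a procedure sys.refresh_table.
// Arguments may be passed positionally or by name (param => value).
spark.sql("CALL cat.sys.refresh_table('db.tbl')")
spark.sql("CALL cat.sys.refresh_table(name => 'db.tbl', recursive => true)")
```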

File tree

34 files changed: +1162 -22 lines changed

common/utils/src/main/resources/error/error-conditions.json

Lines changed: 6 additions & 0 deletions
@@ -1456,6 +1456,12 @@
     ],
     "sqlState" : "2203G"
   },
+  "FAILED_TO_LOAD_ROUTINE" : {
+    "message" : [
+      "Failed to load routine <routineName>."
+    ],
+    "sqlState" : "38000"
+  },
   "FAILED_TO_PARSE_TOO_COMPLEX" : {
     "message" : [
       "The statement, including potential SQL functions and referenced views, was too complex to parse.",

docs/sql-ref-ansi-compliance.md

Lines changed: 1 addition & 0 deletions
@@ -426,6 +426,7 @@ Below is a list of all the keywords in Spark SQL.
 |BY|non-reserved|non-reserved|reserved|
 |BYTE|non-reserved|non-reserved|non-reserved|
 |CACHE|non-reserved|non-reserved|non-reserved|
+|CALL|reserved|non-reserved|reserved|
 |CALLED|non-reserved|non-reserved|non-reserved|
 |CASCADE|non-reserved|non-reserved|non-reserved|
 |CASE|reserved|non-reserved|reserved|

sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4

Lines changed: 1 addition & 0 deletions
@@ -146,6 +146,7 @@ BUCKETS: 'BUCKETS';
 BY: 'BY';
 BYTE: 'BYTE';
 CACHE: 'CACHE';
+CALL: 'CALL';
 CALLED: 'CALLED';
 CASCADE: 'CASCADE';
 CASE: 'CASE';

sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4

Lines changed: 5 additions & 0 deletions
@@ -298,6 +298,10 @@ statement
         LEFT_PAREN columns=multipartIdentifierPropertyList RIGHT_PAREN
         (OPTIONS options=propertyList)?                                #createIndex
     | DROP INDEX (IF EXISTS)? identifier ON TABLE? identifierReference #dropIndex
+    | CALL identifierReference
+        LEFT_PAREN
+        (functionArgument (COMMA functionArgument)*)?
+        RIGHT_PAREN                                                    #call
     | unsupportedHiveNativeCommands .*?                                #failNativeCommand
     ;

@@ -1851,6 +1855,7 @@ nonReserved
     | BY
     | BYTE
     | CACHE
+    | CALL
     | CALLED
     | CASCADE
     | CASE

sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/ProcedureParameter.java

Lines changed: 5 additions & 0 deletions
@@ -32,6 +32,11 @@
  */
 @Evolving
 public interface ProcedureParameter {
+  /**
+   * A field metadata key that indicates whether an argument is passed by name.
+   */
+  String BY_NAME_METADATA_KEY = "BY_NAME";
+
   /**
    * Creates a builder for an IN procedure parameter.
    *

sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/UnboundProcedure.java

Lines changed: 6 additions & 0 deletions
@@ -35,6 +35,12 @@ public interface UnboundProcedure extends Procedure {
    * validate if the input types are compatible while binding or delegate that to Spark. Regardless,
    * Spark will always perform the final validation of the arguments and rearrange them as needed
    * based on {@link BoundProcedure#parameters() reported parameters}.
+   * <p>
+   * The provided {@code inputType} is based on the procedure arguments. If an argument is passed
+   * by name, its metadata will indicate this with {@link ProcedureParameter#BY_NAME_METADATA_KEY}
+   * set to {@code true}. In such cases, the field name will match the name of the target procedure
+   * parameter. If the argument is not named, {@link ProcedureParameter#BY_NAME_METADATA_KEY} will
+   * not be set and the name will be assigned randomly.
    *
    * @param inputType the input types to bind to
    * @return the bound procedure that is most suitable for the given input types
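To make this contract concrete, here is a sketch (not part of the diff) of the `inputType` Spark would hand to `bind` for a call like `CALL cat.proc(1, threshold => 0.5)`; the procedure `cat.proc` and parameter `threshold` are hypothetical, and only public `StructType`/`Metadata` APIs are used:

```scala
import org.apache.spark.sql.connector.catalog.procedures.ProcedureParameter
import org.apache.spark.sql.types._

// Positional argument: auto-generated field name, no BY_NAME metadata.
val positional = StructField("param0", IntegerType, nullable = false)

// Named argument `threshold => 0.5`: the field carries the target
// parameter name and BY_NAME = true in its metadata.
val byNameMeta = new MetadataBuilder()
  .putBoolean(ProcedureParameter.BY_NAME_METADATA_KEY, value = true)
  .build()
val named = StructField("threshold", DoubleType, nullable = false, byNameMeta)

val inputType = StructType(Seq(positional, named))

// Inside bind(), an implementation can check how an argument was passed:
val passedByName = inputType("threshold").metadata
  .contains(ProcedureParameter.BY_NAME_METADATA_KEY)
```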

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 64 additions & 1 deletion
@@ -26,7 +26,7 @@ import scala.collection.mutable.ArrayBuffer
 import scala.jdk.CollectionConverters._
 import scala.util.{Failure, Random, Success, Try}

-import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
+import org.apache.spark.{SparkException, SparkThrowable, SparkUnsupportedOperationException}
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst._
 import org.apache.spark.sql.catalyst.catalog._

@@ -50,6 +50,7 @@ import org.apache.spark.sql.connector.catalog.{View => _, _}
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
 import org.apache.spark.sql.connector.catalog.TableChange.{After, ColumnPosition}
 import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction, ScalarFunction, UnboundFunction}
+import org.apache.spark.sql.connector.catalog.procedures.{BoundProcedure, ProcedureParameter, UnboundProcedure}
 import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation

@@ -310,6 +311,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       ExtractGenerator ::
       ResolveGenerate ::
       ResolveFunctions ::
+      ResolveProcedures ::
+      BindProcedures ::
       ResolveTableSpec ::
       ResolveAliases ::
       ResolveSubquery ::

@@ -2611,6 +2614,66 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
     }
   }

+  /**
+   * A rule that resolves procedures.
+   */
+  object ResolveProcedures extends Rule[LogicalPlan] {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
+      _.containsPattern(UNRESOLVED_PROCEDURE), ruleId) {
+      case Call(UnresolvedProcedure(CatalogAndIdentifier(catalog, ident)), args, execute) =>
+        val procedureCatalog = catalog.asProcedureCatalog
+        val procedure = load(procedureCatalog, ident)
+        Call(ResolvedProcedure(procedureCatalog, ident, procedure), args, execute)
+    }
+
+    private def load(catalog: ProcedureCatalog, ident: Identifier): UnboundProcedure = {
+      try {
+        catalog.loadProcedure(ident)
+      } catch {
+        case e: Exception if !e.isInstanceOf[SparkThrowable] =>
+          val nameParts = catalog.name +: ident.asMultipartIdentifier
+          throw QueryCompilationErrors.failedToLoadRoutineError(nameParts, e)
+      }
+    }
+  }
+
+  /**
+   * A rule that binds procedures to the input types and rearranges arguments as needed.
+   */
+  object BindProcedures extends Rule[LogicalPlan] {
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+      case Call(ResolvedProcedure(catalog, ident, unbound: UnboundProcedure), args, execute)
+          if args.forall(_.resolved) =>
+        val inputType = extractInputType(args)
+        val bound = unbound.bind(inputType)
+        validateParameterModes(bound)
+        val rearrangedArgs = NamedParametersSupport.defaultRearrange(bound, args)
+        Call(ResolvedProcedure(catalog, ident, bound), rearrangedArgs, execute)
+    }
+
+    private def extractInputType(args: Seq[Expression]): StructType = {
+      val fields = args.zipWithIndex.map {
+        case (NamedArgumentExpression(name, value), _) =>
+          StructField(name, value.dataType, value.nullable, byNameMetadata)
+        case (arg, index) =>
+          StructField(s"param$index", arg.dataType, arg.nullable)
+      }
+      StructType(fields)
+    }
+
+    private def byNameMetadata: Metadata = {
+      new MetadataBuilder()
+        .putBoolean(ProcedureParameter.BY_NAME_METADATA_KEY, value = true)
+        .build()
+    }
+
+    private def validateParameterModes(procedure: BoundProcedure): Unit = {
+      procedure.parameters.find(_.mode != ProcedureParameter.Mode.IN).foreach { param =>
+        throw SparkException.internalError(s"Unsupported parameter mode: ${param.mode}")
+      }
+    }
+  }
+
   /**
    * This rule resolves and rewrites subqueries inside expressions.
    *

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala

Lines changed: 1 addition & 0 deletions
@@ -77,6 +77,7 @@ object AnsiTypeCoercion extends TypeCoercionBase {
   override def typeCoercionRules: List[Rule[LogicalPlan]] =
     UnpivotCoercion ::
     WidenSetOperationTypes ::
+    ProcedureArgumentCoercion ::
     new AnsiCombinedTypeCoercionRule(
       CollationTypeCasts ::
       InConversion ::

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 8 additions & 0 deletions
@@ -676,6 +676,14 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
             varName,
             c.defaultExpr.originalSQL)

+        case c: Call if c.resolved && c.bound && c.checkArgTypes().isFailure =>
+          c.checkArgTypes() match {
+            case mismatch: TypeCheckResult.DataTypeMismatch =>
+              c.dataTypeMismatch("CALL", mismatch)
+            case _ =>
+              throw SparkException.internalError("Invalid input for procedure")
+          }
+
         case _ => // Falls back to the following checks
       }
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
2929
import org.apache.spark.sql.catalyst.rules.Rule
3030
import org.apache.spark.sql.catalyst.trees.AlwaysProcess
3131
import org.apache.spark.sql.catalyst.types.DataTypeUtils
32+
import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure
3233
import org.apache.spark.sql.errors.QueryCompilationErrors
3334
import org.apache.spark.sql.internal.SQLConf
3435
import org.apache.spark.sql.internal.types.{AbstractArrayType, AbstractMapType, AbstractStringType, StringTypeAnyCollation}
@@ -202,6 +203,20 @@ abstract class TypeCoercionBase {
202203
}
203204
}
204205

206+
/**
207+
* A type coercion rule that implicitly casts procedure arguments to expected types.
208+
*/
209+
object ProcedureArgumentCoercion extends Rule[LogicalPlan] {
210+
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
211+
case c @ Call(ResolvedProcedure(_, _, procedure: BoundProcedure), args, _) if c.resolved =>
212+
val expectedDataTypes = procedure.parameters.map(_.dataType)
213+
val coercedArgs = args.zip(expectedDataTypes).map {
214+
case (arg, expectedType) => implicitCast(arg, expectedType).getOrElse(arg)
215+
}
216+
c.copy(args = coercedArgs)
217+
}
218+
}
219+
205220
/**
206221
* Widens the data types of the [[Unpivot]] values.
207222
*/
@@ -838,6 +853,7 @@ object TypeCoercion extends TypeCoercionBase {
838853
override def typeCoercionRules: List[Rule[LogicalPlan]] =
839854
UnpivotCoercion ::
840855
WidenSetOperationTypes ::
856+
ProcedureArgumentCoercion ::
841857
new CombinedTypeCoercionRule(
842858
CollationTypeCasts ::
843859
InConversion ::
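To illustrate the effect of this rule (the procedure `cat.set_threshold` and its single DOUBLE parameter below are assumptions for the example, not part of the patch):

```scala
// Hypothetical: cat.set_threshold is bound with one DOUBLE parameter.
// The INT literal does not match exactly, so ProcedureArgumentCoercion
// rewrites the argument with an implicit cast during analysis, roughly
// as if the user had written CALL cat.set_threshold(CAST(1 AS DOUBLE)).
spark.sql("CALL cat.set_threshold(1)")
```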

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala

Lines changed: 5 additions & 1 deletion
@@ -67,9 +67,13 @@ package object analysis {
     }

     def dataTypeMismatch(expr: Expression, mismatch: DataTypeMismatch): Nothing = {
+      dataTypeMismatch(toSQLExpr(expr), mismatch)
+    }
+
+    def dataTypeMismatch(sqlExpr: String, mismatch: DataTypeMismatch): Nothing = {
       throw new AnalysisException(
         errorClass = s"DATATYPE_MISMATCH.${mismatch.errorSubClass}",
-        messageParameters = mismatch.messageParameters + ("sqlExpr" -> toSQLExpr(expr)),
+        messageParameters = mismatch.messageParameters + ("sqlExpr" -> sqlExpr),
         origin = t.origin)
     }

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala

Lines changed: 15 additions & 2 deletions
@@ -23,13 +23,14 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTable
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Attribute, LeafExpression, Unevaluable}
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
-import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UNRESOLVED_FUNC}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UNRESOLVED_FUNC, UNRESOLVED_PROCEDURE}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
-import org.apache.spark.sql.connector.catalog.{CatalogPlugin, FunctionCatalog, Identifier, Table, TableCatalog}
+import org.apache.spark.sql.connector.catalog.{CatalogPlugin, FunctionCatalog, Identifier, ProcedureCatalog, Table, TableCatalog}
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
 import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition
 import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
+import org.apache.spark.sql.connector.catalog.procedures.Procedure
 import org.apache.spark.sql.types.{DataType, StructField}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.util.ArrayImplicits._

@@ -135,6 +136,12 @@ case class UnresolvedFunctionName(
 case class UnresolvedIdentifier(nameParts: Seq[String], allowTemp: Boolean = false)
   extends UnresolvedLeafNode

+/**
+ * A procedure identifier that should be resolved into [[ResolvedProcedure]].
+ */
+case class UnresolvedProcedure(nameParts: Seq[String]) extends UnresolvedLeafNode {
+  final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_PROCEDURE)
+}

 /**
  * A resolved leaf node whose statistics has no meaning.

@@ -192,6 +199,12 @@ case class ResolvedFieldName(path: Seq[String], field: StructField) extends Fiel

 case class ResolvedFieldPosition(position: ColumnPosition) extends FieldPosition

+case class ResolvedProcedure(
+    catalog: ProcedureCatalog,
+    ident: Identifier,
+    procedure: Procedure) extends LeafNodeWithoutStats {
+  override def output: Seq[Attribute] = Nil
+}

 /**
  * A plan containing resolved persistent views.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 22 additions & 0 deletions
@@ -5697,6 +5697,28 @@ class AstBuilder extends DataTypeAstBuilder
       ctx.EXISTS != null)
   }

+  /**
+   * Creates a plan for invoking a procedure.
+   *
+   * For example:
+   * {{{
+   *   CALL multi_part_name(v1, v2, ...);
+   *   CALL multi_part_name(v1, param2 => v2, ...);
+   *   CALL multi_part_name(param1 => v1, param2 => v2, ...);
+   * }}}
+   */
+  override def visitCall(ctx: CallContext): LogicalPlan = withOrigin(ctx) {
+    val procedure = withIdentClause(ctx.identifierReference, UnresolvedProcedure)
+    val args = ctx.functionArgument.asScala.map {
+      case expr if expr.namedArgumentExpression != null =>
+        val namedExpr = expr.namedArgumentExpression
+        NamedArgumentExpression(namedExpr.key.getText, expression(namedExpr.value))
+      case expr =>
+        expression(expr)
+    }.toSeq
+    Call(procedure, args)
+  }
+
   /**
    * Create a TimestampAdd expression.
    */
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ExecutableDuringAnalysis.scala

Lines changed: 28 additions & 0 deletions

@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical
+
+/**
+ * A logical plan node that requires execution during analysis.
+ */
+trait ExecutableDuringAnalysis extends LogicalPlan {
+  /**
+   * Returns the logical plan node that should be used for EXPLAIN.
+   */
+  def stageForExplain(): LogicalPlan
+}
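As a sketch of how a node might implement this trait (a hypothetical node, not part of the commit), a command that runs eagerly during analysis can stage an inert copy of itself for EXPLAIN so the side effect is described rather than performed:

```scala
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{ExecutableDuringAnalysis, LeafNode, LogicalPlan}

// Hypothetical implementor: a command executed eagerly during analysis,
// mirroring the `execute` flag the Call node carries in this patch.
case class InvokeSomething(name: String, execute: Boolean = true)
  extends LeafNode with ExecutableDuringAnalysis {
  override def output: Seq[Attribute] = Nil
  // For EXPLAIN, return a copy with execute = false so the plan can be
  // printed without triggering the side effect.
  override def stageForExplain(): LogicalPlan = copy(execute = false)
}
```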
