# [SPARK-48782][SQL] Add support for executing procedures in catalogs
### What changes were proposed in this pull request?

This PR adds support for executing procedures in catalogs.

### Why are the changes needed?

These changes are needed per the [discussed and voted](https://lists.apache.org/thread/w586jr53fxwk4pt9m94b413xyjr1v25m) SPIP tracked in [SPARK-44167](https://issues.apache.org/jira/browse/SPARK-44167).

### Does this PR introduce _any_ user-facing change?

Yes. This PR adds the CALL command.
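
For illustration, a hedged sketch of the new syntax (this assumes a `SparkSession` named `spark` and a catalog `cat` that implements `ProcedureCatalog`; the `system.add_partition` procedure and its parameters are made up):

```scala
// Arguments may be passed positionally, by name, or mixed
// (positional arguments first).
spark.sql("CALL cat.system.add_partition('t1', 'p1')")
spark.sql("CALL cat.system.add_partition('t1', part => 'p1')")
spark.sql("CALL cat.system.add_partition(tbl => 't1', part => 'p1')")
```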

### How was this patch tested?

This PR comes with tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #47943 from aokolnychyi/spark-48782.

Authored-by: Anton Okolnychyi <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
aokolnychyi authored and cloud-fan committed Sep 19, 2024
1 parent 3bdf146 commit 492d1b1
Showing 34 changed files with 1,162 additions and 22 deletions.
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -1456,6 +1456,12 @@
],
"sqlState" : "2203G"
},
"FAILED_TO_LOAD_ROUTINE" : {
"message" : [
"Failed to load routine <routineName>."
],
"sqlState" : "38000"
},
"FAILED_TO_PARSE_TOO_COMPLEX" : {
"message" : [
"The statement, including potential SQL functions and referenced views, was too complex to parse.",
1 change: 1 addition & 0 deletions docs/sql-ref-ansi-compliance.md
@@ -426,6 +426,7 @@ Below is a list of all the keywords in Spark SQL.
|BY|non-reserved|non-reserved|reserved|
|BYTE|non-reserved|non-reserved|non-reserved|
|CACHE|non-reserved|non-reserved|non-reserved|
|CALL|reserved|non-reserved|reserved|
|CALLED|non-reserved|non-reserved|non-reserved|
|CASCADE|non-reserved|non-reserved|non-reserved|
|CASE|reserved|non-reserved|reserved|
@@ -146,6 +146,7 @@ BUCKETS: 'BUCKETS';
BY: 'BY';
BYTE: 'BYTE';
CACHE: 'CACHE';
CALL: 'CALL';
CALLED: 'CALLED';
CASCADE: 'CASCADE';
CASE: 'CASE';
@@ -298,6 +298,10 @@ statement
LEFT_PAREN columns=multipartIdentifierPropertyList RIGHT_PAREN
(OPTIONS options=propertyList)? #createIndex
| DROP INDEX (IF EXISTS)? identifier ON TABLE? identifierReference #dropIndex
| CALL identifierReference
LEFT_PAREN
(functionArgument (COMMA functionArgument)*)?
RIGHT_PAREN #call
| unsupportedHiveNativeCommands .*? #failNativeCommand
;

@@ -1851,6 +1855,7 @@ nonReserved
| BY
| BYTE
| CACHE
| CALL
| CALLED
| CASCADE
| CASE
@@ -32,6 +32,11 @@
*/
@Evolving
public interface ProcedureParameter {
/**
* A field metadata key that indicates whether an argument is passed by name.
*/
String BY_NAME_METADATA_KEY = "BY_NAME";

/**
* Creates a builder for an IN procedure parameter.
*
@@ -35,6 +35,12 @@ public interface UnboundProcedure extends Procedure {
* validate if the input types are compatible while binding or delegate that to Spark. Regardless,
* Spark will always perform the final validation of the arguments and rearrange them as needed
* based on {@link BoundProcedure#parameters() reported parameters}.
* <p>
* The provided {@code inputType} is based on the procedure arguments. If an argument is passed
* by name, its metadata will indicate this with {@link ProcedureParameter#BY_NAME_METADATA_KEY}
* set to {@code true}. In such cases, the field name will match the name of the target procedure
* parameter. If the argument is not named, {@link ProcedureParameter#BY_NAME_METADATA_KEY} will
* not be set and the name will be assigned randomly.
*
* @param inputType the input types to bind to
* @return the bound procedure that is most suitable for the given input types
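
To make this contract concrete, a minimal sketch of a helper a connector's `bind` implementation might use (the `namedArguments` helper is hypothetical, not part of this PR):

```scala
import org.apache.spark.sql.connector.catalog.procedures.ProcedureParameter
import org.apache.spark.sql.types.{StructField, StructType}

// Returns the fields of `inputType` that correspond to arguments passed
// by name; per the contract above, their field names match the target
// procedure parameter names.
def namedArguments(inputType: StructType): Array[StructField] = {
  inputType.fields.filter { field =>
    field.metadata.contains(ProcedureParameter.BY_NAME_METADATA_KEY) &&
      field.metadata.getBoolean(ProcedureParameter.BY_NAME_METADATA_KEY)
  }
}
```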
@@ -26,7 +26,7 @@ import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._
import scala.util.{Failure, Random, Success, Try}

-import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
+import org.apache.spark.{SparkException, SparkThrowable, SparkUnsupportedOperationException}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.catalog._
@@ -50,6 +50,7 @@ import org.apache.spark.sql.connector.catalog.{View => _, _}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.TableChange.{After, ColumnPosition}
import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction, ScalarFunction, UnboundFunction}
import org.apache.spark.sql.connector.catalog.procedures.{BoundProcedure, ProcedureParameter, UnboundProcedure}
import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
@@ -310,6 +311,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
ExtractGenerator ::
ResolveGenerate ::
ResolveFunctions ::
ResolveProcedures ::
BindProcedures ::
ResolveTableSpec ::
ResolveAliases ::
ResolveSubquery ::
@@ -2611,6 +2614,66 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
}
}

/**
* A rule that resolves procedures.
*/
object ResolveProcedures extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
_.containsPattern(UNRESOLVED_PROCEDURE), ruleId) {
case Call(UnresolvedProcedure(CatalogAndIdentifier(catalog, ident)), args, execute) =>
val procedureCatalog = catalog.asProcedureCatalog
val procedure = load(procedureCatalog, ident)
Call(ResolvedProcedure(procedureCatalog, ident, procedure), args, execute)
}

private def load(catalog: ProcedureCatalog, ident: Identifier): UnboundProcedure = {
try {
catalog.loadProcedure(ident)
} catch {
case e: Exception if !e.isInstanceOf[SparkThrowable] =>
val nameParts = catalog.name +: ident.asMultipartIdentifier
throw QueryCompilationErrors.failedToLoadRoutineError(nameParts, e)
}
}
}

/**
* A rule that binds procedures to the input types and rearranges arguments as needed.
*/
object BindProcedures extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case Call(ResolvedProcedure(catalog, ident, unbound: UnboundProcedure), args, execute)
if args.forall(_.resolved) =>
val inputType = extractInputType(args)
val bound = unbound.bind(inputType)
validateParameterModes(bound)
val rearrangedArgs = NamedParametersSupport.defaultRearrange(bound, args)
Call(ResolvedProcedure(catalog, ident, bound), rearrangedArgs, execute)
}

private def extractInputType(args: Seq[Expression]): StructType = {
val fields = args.zipWithIndex.map {
case (NamedArgumentExpression(name, value), _) =>
StructField(name, value.dataType, value.nullable, byNameMetadata)
case (arg, index) =>
StructField(s"param$index", arg.dataType, arg.nullable)
}
StructType(fields)
}

private def byNameMetadata: Metadata = {
new MetadataBuilder()
.putBoolean(ProcedureParameter.BY_NAME_METADATA_KEY, value = true)
.build()
}

private def validateParameterModes(procedure: BoundProcedure): Unit = {
procedure.parameters.find(_.mode != ProcedureParameter.Mode.IN).foreach { param =>
throw SparkException.internalError(s"Unsupported parameter mode: ${param.mode}")
}
}
}
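
For a hypothetical mixed invocation such as `CALL cat.system.add_partition('t1', part => 'p1')`, `extractInputType` above would build roughly the following input type (names and types are illustrative):

```scala
import org.apache.spark.sql.connector.catalog.procedures.ProcedureParameter
import org.apache.spark.sql.types._

val byName = new MetadataBuilder()
  .putBoolean(ProcedureParameter.BY_NAME_METADATA_KEY, value = true)
  .build()

// Positional arguments get synthetic param<index> names; named arguments
// keep their names and carry the BY_NAME metadata marker.
val inputType = StructType(Seq(
  StructField("param0", StringType, nullable = false),         // 't1'
  StructField("part", StringType, nullable = false, byName)))  // part => 'p1'
```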

/**
* This rule resolves and rewrites subqueries inside expressions.
*
@@ -77,6 +77,7 @@ object AnsiTypeCoercion extends TypeCoercionBase {
override def typeCoercionRules: List[Rule[LogicalPlan]] =
UnpivotCoercion ::
WidenSetOperationTypes ::
ProcedureArgumentCoercion ::
new AnsiCombinedTypeCoercionRule(
CollationTypeCasts ::
InConversion ::
@@ -676,6 +676,14 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
varName,
c.defaultExpr.originalSQL)

case c: Call if c.resolved && c.bound && c.checkArgTypes().isFailure =>
c.checkArgTypes() match {
case mismatch: TypeCheckResult.DataTypeMismatch =>
c.dataTypeMismatch("CALL", mismatch)
case _ =>
throw SparkException.internalError("Invalid input for procedure")
}

case _ => // Falls back to the following checks
}

Expand Down
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.AlwaysProcess
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.{AbstractArrayType, AbstractMapType, AbstractStringType, StringTypeAnyCollation}
@@ -202,6 +203,20 @@ abstract class TypeCoercionBase {
}
}

/**
* A type coercion rule that implicitly casts procedure arguments to expected types.
*/
object ProcedureArgumentCoercion extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case c @ Call(ResolvedProcedure(_, _, procedure: BoundProcedure), args, _) if c.resolved =>
val expectedDataTypes = procedure.parameters.map(_.dataType)
val coercedArgs = args.zip(expectedDataTypes).map {
case (arg, expectedType) => implicitCast(arg, expectedType).getOrElse(arg)
}
c.copy(args = coercedArgs)
}
}
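
A hedged example of the rule's effect (the `set_ttl` procedure is hypothetical and assumed to declare a single `LONG` parameter):

```scala
// ProcedureArgumentCoercion implicitly casts the INT literal 30 to LONG,
// so the call analyzes cleanly without an explicit CAST.
spark.sql("CALL cat.system.set_ttl(30)")
```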

/**
* Widens the data types of the [[Unpivot]] values.
*/
@@ -838,6 +853,7 @@ object TypeCoercion extends TypeCoercionBase {
override def typeCoercionRules: List[Rule[LogicalPlan]] =
UnpivotCoercion ::
WidenSetOperationTypes ::
ProcedureArgumentCoercion ::
new CombinedTypeCoercionRule(
CollationTypeCasts ::
InConversion ::
@@ -67,9 +67,13 @@ package object analysis {
}

def dataTypeMismatch(expr: Expression, mismatch: DataTypeMismatch): Nothing = {
dataTypeMismatch(toSQLExpr(expr), mismatch)
}

def dataTypeMismatch(sqlExpr: String, mismatch: DataTypeMismatch): Nothing = {
throw new AnalysisException(
errorClass = s"DATATYPE_MISMATCH.${mismatch.errorSubClass}",
messageParameters = mismatch.messageParameters + ("sqlExpr" -> toSQLExpr(expr)),
messageParameters = mismatch.messageParameters + ("sqlExpr" -> sqlExpr),
origin = t.origin)
}

@@ -23,13 +23,14 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Attribute, LeafExpression, Unevaluable}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
-import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UNRESOLVED_FUNC}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UNRESOLVED_FUNC, UNRESOLVED_PROCEDURE}
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
-import org.apache.spark.sql.connector.catalog.{CatalogPlugin, FunctionCatalog, Identifier, Table, TableCatalog}
+import org.apache.spark.sql.connector.catalog.{CatalogPlugin, FunctionCatalog, Identifier, ProcedureCatalog, Table, TableCatalog}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
import org.apache.spark.sql.connector.catalog.procedures.Procedure
import org.apache.spark.sql.types.{DataType, StructField}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.ArrayImplicits._
@@ -135,6 +136,12 @@ case class UnresolvedFunctionName(
case class UnresolvedIdentifier(nameParts: Seq[String], allowTemp: Boolean = false)
extends UnresolvedLeafNode

/**
* A procedure identifier that should be resolved into [[ResolvedProcedure]].
*/
case class UnresolvedProcedure(nameParts: Seq[String]) extends UnresolvedLeafNode {
final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_PROCEDURE)
}

/**
* A resolved leaf node whose statistics has no meaning.
@@ -192,6 +199,12 @@ case class ResolvedFieldName(path: Seq[String], field: StructField) extends Fiel

case class ResolvedFieldPosition(position: ColumnPosition) extends FieldPosition

case class ResolvedProcedure(
catalog: ProcedureCatalog,
ident: Identifier,
procedure: Procedure) extends LeafNodeWithoutStats {
override def output: Seq[Attribute] = Nil
}

/**
* A plan containing resolved persistent views.
@@ -5697,6 +5697,28 @@ class AstBuilder extends DataTypeAstBuilder
ctx.EXISTS != null)
}

/**
* Creates a plan for invoking a procedure.
*
* For example:
* {{{
* CALL multi_part_name(v1, v2, ...);
* CALL multi_part_name(v1, param2 => v2, ...);
* CALL multi_part_name(param1 => v1, param2 => v2, ...);
* }}}
*/
override def visitCall(ctx: CallContext): LogicalPlan = withOrigin(ctx) {
val procedure = withIdentClause(ctx.identifierReference, UnresolvedProcedure)
val args = ctx.functionArgument.asScala.map {
case expr if expr.namedArgumentExpression != null =>
val namedExpr = expr.namedArgumentExpression
NamedArgumentExpression(namedExpr.key.getText, expression(namedExpr.value))
case expr =>
expression(expr)
}.toSeq
Call(procedure, args)
}
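
As a rough check of the parsing logic, a sketch of the plan this produces (the comment shows the expected shape, not exact `toString` output):

```scala
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

// Roughly: Call(UnresolvedProcedure(Seq("cat", "ns", "proc")),
//   Seq(Literal(1), NamedArgumentExpression("p2", Literal(2))))
val plan = CatalystSqlParser.parsePlan("CALL cat.ns.proc(1, p2 => 2)")
```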

/**
* Create a TimestampAdd expression.
*/
@@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.plans.logical

/**
* A logical plan node that requires execution during analysis.
*/
trait ExecutableDuringAnalysis extends LogicalPlan {
/**
* Returns the logical plan node that should be used for EXPLAIN.
*/
def stageForExplain(): LogicalPlan
}
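
A minimal sketch of a node opting into this trait (the `MyAnalysisTimeCommand` class is hypothetical; in this PR, `Call` carries an `execute` flag, presumably so a staged copy can skip execution when the plan is rendered for EXPLAIN):

```scala
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{
  ExecutableDuringAnalysis, LeafNode, LogicalPlan}

// Hypothetical command that must run while the plan is analyzed.
case class MyAnalysisTimeCommand(execute: Boolean = true)
  extends LeafNode with ExecutableDuringAnalysis {

  override def output: Seq[Attribute] = Nil

  // For EXPLAIN, stage a copy that will not actually execute, so the
  // plan can be shown without side effects.
  override def stageForExplain(): LogicalPlan = copy(execute = false)
}
```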