[SPARK-50892][SQL] Add UnionLoopExec, physical operator for recursion, to perform execution of recursive queries #49955

Closed
36 commits
74aefca
Apply Milan's already existing changes
Feb 14, 2025
fe8f5d0
Delete sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/opti…
Pajaraja Feb 14, 2025
5101b1a
Add space to make uniform code
Feb 14, 2025
22ced7c
Make changes according to part of Wenchen's comments
Feb 24, 2025
ebda64d
Separate global and local limits in recursive CTEs
Feb 27, 2025
eef24d9
Fix compile error caused by typo
Mar 3, 2025
dda4e54
Add unionloop pruning
Mar 3, 2025
e1d4932
Fix pruning and regenerate golden files for different types of limits
Mar 4, 2025
c968b0d
Remove debug output
Mar 4, 2025
8187004
Separate UnionLoopExec into separate file
Mar 4, 2025
7558695
Add skip shuffle when the recursion is simple
Mar 4, 2025
a73d3c6
Stylistic changes
Mar 4, 2025
1554707
Update common/utils/src/main/resources/error/error-conditions.json
Pajaraja Mar 5, 2025
bddeade
Update sql/core/src/main/scala/org/apache/spark/sql/execution/Recursi…
Pajaraja Mar 5, 2025
13f25d4
Make small changes according to Wenchen's comments
Mar 5, 2025
6a15bb1
Change LocalLimit and GlobalLimit handling
Mar 5, 2025
ada9f8a
Remove debug output
Mar 5, 2025
2875c8f
Revert avoiding caching simple queries and pruning for union loop
Mar 6, 2025
322bbb6
Remove useless import
Mar 6, 2025
27024a2
Merge branch 'master' into UnionLoopExecCont
Mar 6, 2025
37f13f6
Regenerate golden files
Mar 6, 2025
328d328
Make changes according to Peter's and Wenchen's comments
Mar 7, 2025
6b7b558
Optimization thinking
Mar 12, 2025
baf1e00
Revert "Optimization thinking"
Mar 12, 2025
a82885f
Revert changes to limit, introduce RecursionRowLimit and some new tes…
Mar 18, 2025
4aa8065
Remove LocalLimit Node above UnionLoop
Mar 18, 2025
06676e5
Remove unnecessary variable in LimitPushDown and revert coalescing th…
Mar 18, 2025
bc712cb
Update common/utils/src/main/resources/error/error-conditions.json
Pajaraja Mar 18, 2025
36636f5
Update sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLo…
Pajaraja Mar 18, 2025
3480f7b
Update sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLo…
Pajaraja Mar 18, 2025
56027f9
Update sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLo…
Pajaraja Mar 18, 2025
99b71ba
Update common/utils/src/main/resources/error/error-conditions.json
Pajaraja Mar 18, 2025
edeffa7
Update sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLo…
Pajaraja Mar 18, 2025
904f0bd
Make changes according to Wenchen's comments.
Mar 18, 2025
1f1020f
Change recursion row limit and change golden file testcase break new …
Mar 18, 2025
622b66a
Change way we find count to be faster; remove inconsistent test from G…
Mar 19, 2025
18 changes: 18 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -4447,6 +4447,18 @@
],
"sqlState" : "38000"
},
"RECURSION_LEVEL_LIMIT_EXCEEDED" : {
"message" : [
"Recursion level limit <levelLimit> reached but query has not exhausted, try increasing 'spark.sql.cteRecursionLevelLimit'"
],
"sqlState" : "42836"
},
"RECURSION_ROW_LIMIT_EXCEEDED" : {
"message" : [
"Recursion row limit <rowLimit> reached but query has not exhausted, try increasing 'spark.sql.cteRecursionRowLimit'"
],
"sqlState" : "42836"
},
"RECURSIVE_CTE_IN_LEGACY_MODE" : {
"message" : [
"Recursive definitions cannot be used in legacy CTE precedence mode (spark.sql.legacy.ctePrecedencePolicy=LEGACY)."
@@ -5206,6 +5218,12 @@
],
"sqlState" : "42846"
},
"UNION_NOT_SUPPORTED_IN_RECURSIVE_CTE" : {
"message" : [
"The UNION operator is not yet supported within recursive common table expressions (WITH clauses that refer to themselves, directly or indirectly). Please use UNION ALL instead."
],
"sqlState" : "42836"
},
"UNKNOWN_PRIMITIVE_TYPE_IN_VARIANT" : {
"message" : [
"Unknown primitive type with id <id> was found in a variant value."
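To make the two new error conditions concrete, here is a minimal sketch (assuming a Spark 4.1 session handle named spark and the recursive CTE support added by this PR) of a recursive query with no terminating predicate, which is expected to trip the level limit instead of running forever:

// Hypothetical illustration: with the default spark.sql.cteRecursionLevelLimit of 100,
// this query is expected to fail with RECURSION_LEVEL_LIMIT_EXCEEDED because the
// recursive term never becomes empty on its own.
val unbounded =
  """WITH RECURSIVE t(n) AS (
    |  SELECT 1
    |  UNION ALL
    |  SELECT n + 1 FROM t
    |)
    |SELECT * FROM t""".stripMargin
spark.sql(unbounded).collect()
// Raising the limit (or setting it to -1) defers or disables the check; the analogous
// RECURSION_ROW_LIMIT_EXCEEDED error is governed by spark.sql.cteRecursionRowLimit.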
@@ -105,6 +105,9 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
// and we exclude those rows from the current iteration result.
case alias @ SubqueryAlias(_,
Distinct(Union(Seq(anchor, recursion), false, false))) =>
cteDef.failAnalysis(
errorClass = "UNION_NOT_SUPPORTED_IN_RECURSIVE_CTE",
messageParameters = Map.empty)
if (!anchor.resolved) {
cteDef
} else {
@@ -126,6 +129,9 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
colNames,
Distinct(Union(Seq(anchor, recursion), false, false))
)) =>
cteDef.failAnalysis(
errorClass = "UNION_NOT_SUPPORTED_IN_RECURSIVE_CTE",
messageParameters = Map.empty)
if (!anchor.resolved) {
cteDef
} else {
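For reference, a sketch of the shape of query that this analysis check now rejects (illustrative only; spark is the usual SparkSession handle, and the supported rewrite is simply UNION ALL):

// Using UNION (distinct) between the anchor and the recursive term is expected to fail
// analysis with UNION_NOT_SUPPORTED_IN_RECURSIVE_CTE.
spark.sql(
  """WITH RECURSIVE t(n) AS (
    |  SELECT 1
    |  UNION
    |  SELECT n + 1 FROM t WHERE n < 5
    |)
    |SELECT * FROM t""".stripMargin)
// Rewriting UNION as UNION ALL lets the query be planned through the new UnionLoop/UnionLoopExec path.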
@@ -848,6 +848,13 @@ object LimitPushDown extends Rule[LogicalPlan] {
case LocalLimit(exp, u: Union) =>
LocalLimit(exp, u.copy(children = u.children.map(maybePushLocalLimit(exp, _))))

// If a limit node is present, we should propagate it down to UnionLoop, so that it is later
// propagated to UnionLoopExec.
case LocalLimit(IntegerLiteral(limit), p @ Project(_, ul: UnionLoop)) =>
p.copy(child = ul.copy(limit = Some(limit)))
case LocalLimit(IntegerLiteral(limit), ul: UnionLoop) =>
ul.copy(limit = Some(limit))

// Add extra limits below JOIN:
// 1. For LEFT OUTER and RIGHT OUTER JOIN, we push limits to the left and right sides
// respectively if join condition is not empty.
@@ -1032,6 +1039,10 @@ object ColumnPruning extends Rule[LogicalPlan] {
p
}

// TODO: Pruning `UnionLoop`s needs to take into account both the outer `Project` and the inner
// `UnionLoopRef` nodes.
case p @ Project(_, _: UnionLoop) => p

// Prune unnecessary window expressions
case p @ Project(_, w: Window) if !w.windowOutputSet.subsetOf(p.references) =>
val windowExprs = w.windowExpressions.filter(p.references.contains)
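A small sketch of the case the new LimitPushDown rule targets (hedged; whether the limit actually reaches UnionLoop depends on the final optimized plan shape, and spark is an assumed SparkSession handle):

// An outer LIMIT on the query that consumes the recursive CTE can be propagated into
// UnionLoop.limit, letting UnionLoopExec stop iterating once enough rows are produced.
val limited =
  """WITH RECURSIVE t(n) AS (
    |  SELECT 1
    |  UNION ALL
    |  SELECT n + 1 FROM t
    |)
    |SELECT * FROM t LIMIT 10""".stripMargin
spark.sql(limited).show()
// Expected: 10 rows, without hitting the recursion level or row limits.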
@@ -4537,6 +4537,22 @@ object SQLConf {
.checkValues(LegacyBehaviorPolicy.values.map(_.toString))
.createWithDefault(LegacyBehaviorPolicy.CORRECTED.toString)

val CTE_RECURSION_LEVEL_LIMIT = buildConf("spark.sql.cteRecursionLevelLimit")
.doc("Maximum level of recursion that is allowed while executing a recursive CTE definition." +
"If a query does not get exhausted before reaching this limit it fails. Use -1 for " +
"unlimited.")
.version("4.1.0")
.intConf
.createWithDefault(100)

val CTE_RECURSION_ROW_LIMIT = buildConf("spark.sql.cteRecursionRowLimit")
.doc("Maximum number of rows that can be returned when executing a recursive CTE definition." +
"If a query does not get exhausted before reaching this limit it fails. Use -1 for " +
"unlimited.")
.version("4.1.0")
.intConf
.createWithDefault(1000000)

val LEGACY_INLINE_CTE_IN_COMMANDS = buildConf("spark.sql.legacy.inlineCTEInCommands")
.internal()
.doc("If true, always inline the CTE relations for the queries in commands. This is the " +
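Both limits are ordinary session configs, so they can be adjusted at runtime; a hedged example (spark is the assumed SparkSession handle):

// Raise the iteration ceiling and the total-row ceiling for deeper recursions;
// -1 disables the corresponding check entirely.
spark.conf.set("spark.sql.cteRecursionLevelLimit", "1000")
spark.conf.set("spark.sql.cteRecursionRowLimit", "10000000")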
@@ -1031,6 +1031,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
GlobalLimitExec(child = planLater(child), offset = offset) :: Nil
case union: logical.Union =>
execution.UnionExec(union.children.map(planLater)) :: Nil
case u @ logical.UnionLoop(id, anchor, recursion, limit) =>
execution.UnionLoopExec(id, anchor, recursion, u.output, limit) :: Nil
case g @ logical.Generate(generator, _, outer, _, _, child) =>
execution.GenerateExec(
generator, g.requiredChildOutput, outer,
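One way to see this planner mapping in action (a sketch; exact plan text varies by version): the logical UnionLoop node introduced during analysis should appear as UnionLoopExec in the physical plan. The new operator itself is added in the file below.

val df = spark.sql(
  """WITH RECURSIVE t(n) AS (
    |  SELECT 1
    |  UNION ALL
    |  SELECT n + 1 FROM t WHERE n < 5
    |)
    |SELECT * FROM t""".stripMargin)
df.explain()   // the physical plan is expected to contain a UnionLoopExec node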
@@ -0,0 +1,230 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution

import scala.collection.mutable

import org.apache.spark.SparkException
import org.apache.spark.rdd.{EmptyRDD, RDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{LocalLimit, LogicalPlan, Union, UnionLoopRef}
import org.apache.spark.sql.classic.Dataset
import org.apache.spark.sql.execution.LogicalRDD.rewriteStatsAndConstraints
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.internal.SQLConf


/**
* The physical node for recursion. Currently only the UNION ALL case is supported.
* For details about the execution, see the comment above the doExecute function.
*
* A simple recursive query:
* {{{
* WITH RECURSIVE t(n) AS (
* SELECT 1
* UNION ALL
* SELECT n+1 FROM t WHERE n < 5)
* SELECT * FROM t;
* }}}
* Corresponding logical plan for the recursive query above:
* {{{
* WithCTE
* :- CTERelationDef 0, false
* : +- SubqueryAlias t
* : +- Project [1#0 AS n#3]
* : +- UnionLoop 0
* : :- Project [1 AS 1#0]
* : : +- OneRowRelation
* : +- Project [(n#1 + 1) AS (n + 1)#2]
* : +- Filter (n#1 < 5)
* : +- SubqueryAlias t
* : +- Project [1#0 AS n#1]
* : +- UnionLoopRef 0, [1#0], false
* +- Project [n#3]
* +- SubqueryAlias t
* +- CTERelationRef 0, true, [n#3], false, false
* }}}
*
* @param loopId The id of the CTERelationDef containing the recursive query. Its value is
* first passed down to UnionLoop when creating it, and then to UnionLoopExec in
* SparkStrategies.
* @param anchor The logical plan of the initial element of the loop.
* @param recursion The logical plan that describes the recursion with a [[UnionLoopRef]] node.
* The CTERelationRef which is marked as recursive gets substituted with a
* [[UnionLoopRef]] in ResolveWithCTE.
* Both anchor and recursion are marked with the @transient annotation, so that they
* are not serialized.
* @param output The output attributes of this loop.
* @param limit If defined, the total number of rows output by this operator will be bounded by
* limit.
* Its value is pushed down to UnionLoop in the Optimizer in case a LocalLimit node is
* present in the logical plan, and then transferred to UnionLoopExec in
* SparkStrategies.
* Note: the limit can only be applied in the main query that calls the recursive CTE,
* not inside the recursive term of the CTE.
*/
case class UnionLoopExec(
loopId: Long,
@transient anchor: LogicalPlan,
@transient recursion: LogicalPlan,
override val output: Seq[Attribute],
limit: Option[Int] = None) extends LeafExecNode {

override def innerChildren: Seq[QueryPlan[_]] = Seq(anchor, recursion)

override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
"numIterations" -> SQLMetrics.createMetric(sparkContext, "number of recursive iterations"))

/**
* This function executes the plan (optionally with an appended limit node), materializes the
* result (currently by forcing a shuffle via repartition()), and returns the materialized
* DataFrame together with its row count.
*/
private def executeAndCacheAndCount(plan: LogicalPlan, currentLimit: Int) = {
// In case limit is defined, we create a (local) limit node above the plan and execute
// the newly created plan.
val planWithLimit = if (limit.isDefined) {
LocalLimit(Literal(currentLimit), plan)
} else {
plan
}
val df = Dataset.ofRows(session, planWithLimit)
val materializedDF = df.repartition()
val count = materializedDF.queryExecution.toRdd.count()
(materializedDF, count)
}

/**
* In the first iteration, the anchor term is executed.
* Then, in each following iteration, the UnionLoopRef node is substituted with the plan from the
* previous iteration, and the resulting plan is executed.
* After every iteration, the DataFrame is materialized.
* The recursion stops when the generated DataFrame is empty, or when either the limit or
* the maximum recursion depth from the config is reached.
*/
override protected def doExecute(): RDD[InternalRow] = {
val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
val numOutputRows = longMetric("numOutputRows")
val numIterations = longMetric("numIterations")
val levelLimit = conf.getConf(SQLConf.CTE_RECURSION_LEVEL_LIMIT)
val rowLimit = conf.getConf(SQLConf.CTE_RECURSION_ROW_LIMIT)

// currentLimit is initialized from the limit argument, and in each step it is decreased by
// the number of rows generated in that step.
// If no limit is passed down, currentLimit is set to -1 and is effectively ignored, since the
// limit-related checks below are guarded by limit.isDefined.
var currentLimit = limit.getOrElse(-1)

val unionChildren = mutable.ArrayBuffer.empty[LogicalRDD]

var (prevDF, prevCount) = executeAndCacheAndCount(anchor, currentLimit)

var currentLevel = 1

var currentNumRows = 0

var limitReached: Boolean = false

val numPartitions = prevDF.queryExecution.toRdd.partitions.length
// Main loop for obtaining the result of the recursive query.
while (prevCount > 0 && !limitReached) {

if (levelLimit != -1 && currentLevel > levelLimit) {
throw new SparkException(
errorClass = "RECURSION_LEVEL_LIMIT_EXCEEDED",
messageParameters = Map("levelLimit" -> levelLimit.toString),
cause = null)
}

// Inherit stats and constraints from the dataset of the previous iteration.
val prevPlan = LogicalRDD.fromDataset(prevDF.queryExecution.toRdd, prevDF, prevDF.isStreaming)
.newInstance()
unionChildren += prevPlan

currentNumRows += prevCount.toInt

if (limit.isDefined) {
currentLimit -= prevCount.toInt
if (currentLimit <= 0) {
limitReached = true
}
}

if (rowLimit != -1 && currentNumRows > rowLimit) {
throw new SparkException(
errorClass = "RECURSION_ROW_LIMIT_EXCEEDED",
messageParameters = Map("rowLimit" -> rowLimit.toString),
cause = null)
}

// Update metrics
numOutputRows += prevCount
numIterations += 1
SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)

if (!limitReached) {
// The new recursion plan is created by substituting the UnionLoopRef node with the
// materialized result of the previous iteration.
// This way only the UNION ALL case is supported. An additional case would be needed for
// UNION (distinct); one possible approach can be seen in the SPARK-24497 PR from Peter Toth.
val newRecursion = recursion.transform {
case r: UnionLoopRef =>
val logicalPlan = prevDF.logicalPlan
val optimizedPlan = prevDF.queryExecution.optimizedPlan
val (stats, constraints) = rewriteStatsAndConstraints(logicalPlan, optimizedPlan)
prevPlan.copy(output = r.output)(prevDF.sparkSession, stats, constraints)
}

val (df, count) = executeAndCacheAndCount(newRecursion, currentLimit)
prevDF = df
prevCount = count

currentLevel += 1
}
}

if (unionChildren.isEmpty) {
new EmptyRDD[InternalRow](sparkContext)
} else {
val df = {
if (unionChildren.length == 1) {
Dataset.ofRows(session, unionChildren.head)
} else {
Dataset.ofRows(session, Union(unionChildren.toSeq))
}
}
val coalescedDF = df.coalesce(numPartitions)
Review comment (Contributor): nit: we don't need to do coalesce for the if (unionChildren.length == 1) branch.

coalescedDF.queryExecution.toRdd
}
}

override def doCanonicalize(): SparkPlan =
super.doCanonicalize().asInstanceOf[UnionLoopExec]
.copy(anchor = anchor.canonicalized, recursion = recursion.canonicalized)

override def verboseStringWithOperatorId(): String = {
s"""
|$formattedNodeName
|Loop id: $loopId
|${QueryPlan.generateFieldString("Output", output)}
|Limit: $limit
|""".stripMargin
}
}
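To make the execution scheme concrete, here is a hedged end-to-end sketch using the sample query from the scaladoc (spark is an assumed SparkSession handle), together with the per-iteration materializations the driver-side loop is expected to produce:

spark.sql(
  """WITH RECURSIVE t(n) AS (
    |  SELECT 1
    |  UNION ALL
    |  SELECT n + 1 FROM t WHERE n < 5
    |)
    |SELECT * FROM t""".stripMargin).show()
// Schematic trace of doExecute (one row per iteration in this example):
//   iteration 1: anchor             -> {1}   (prevCount = 1)
//   iteration 2: recursion over {1} -> {2}
//   iteration 3: recursion over {2} -> {3}
//   iteration 4: recursion over {3} -> {4}
//   iteration 5: recursion over {4} -> {5}
//   iteration 6: recursion over {5} -> {}    (empty, loop terminates)
// The final RDD is the union of the materialized per-iteration results, rows 1 through 5,
// coalesced back to the anchor's partition count.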