Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-49653][SQL] Single join for correlated scalar subqueries #48145

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2653,7 +2653,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
*/
private def resolveSubQueries(plan: LogicalPlan, outer: LogicalPlan): LogicalPlan = {
plan.transformAllExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION), ruleId) {
case s @ ScalarSubquery(sub, _, exprId, _, _, _) if !sub.resolved =>
case s @ ScalarSubquery(sub, _, exprId, _, _, _, _) if !sub.resolved =>
resolveSubQuery(s, outer)(ScalarSubquery(_, _, exprId))
case e @ Exists(sub, _, exprId, _, _) if !sub.resolved =>
resolveSubQuery(e, outer)(Exists(_, _, exprId))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -941,19 +941,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
messageParameters = Map.empty)
}

// SPARK-18504/SPARK-18814: Block cases where GROUP BY columns
// are not part of the correlated columns.

// Collect the inner query expressions that are guaranteed to have a single value for each
// outer row. See comment on getCorrelatedEquivalentInnerExpressions.
val correlatedEquivalentExprs = getCorrelatedEquivalentInnerExpressions(query)
// Grouping expressions, except outer refs and constant expressions - grouping by an
// outer ref or a constant is always ok
val groupByExprs =
ExpressionSet(agg.groupingExpressions.filter(x => !x.isInstanceOf[OuterReference] &&
x.references.nonEmpty))
val nonEquivalentGroupByExprs = groupByExprs -- correlatedEquivalentExprs

val nonEquivalentGroupByExprs = nonEquivalentGroupbyCols(query, agg)
val invalidCols = if (!SQLConf.get.getConf(
SQLConf.LEGACY_SCALAR_SUBQUERY_ALLOW_GROUP_BY_NON_EQUALITY_CORRELATED_PREDICATE)) {
nonEquivalentGroupByExprs
Expand Down Expand Up @@ -1033,23 +1021,25 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
checkOuterReference(plan, expr)

expr match {
case ScalarSubquery(query, outerAttrs, _, _, _, _) =>
case ScalarSubquery(query, outerAttrs, _, _, _, _, _) =>
// Scalar subquery must return one column as output.
if (query.output.size != 1) {
throw QueryCompilationErrors.subqueryReturnMoreThanOneColumn(query.output.size,
expr.origin)
}

if (outerAttrs.nonEmpty) {
cleanQueryInScalarSubquery(query) match {
case a: Aggregate => checkAggregateInScalarSubquery(outerAttrs, query, a)
case Filter(_, a: Aggregate) => checkAggregateInScalarSubquery(outerAttrs, query, a)
case p: LogicalPlan if p.maxRows.exists(_ <= 1) => // Ok
case other =>
expr.failAnalysis(
errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." +
"MUST_AGGREGATE_CORRELATED_SCALAR_SUBQUERY",
messageParameters = Map.empty)
if (!SQLConf.get.getConf(SQLConf.SCALAR_SUBQUERY_USE_SINGLE_JOIN)) {
cleanQueryInScalarSubquery(query) match {
case a: Aggregate => checkAggregateInScalarSubquery(outerAttrs, query, a)
case Filter(_, a: Aggregate) => checkAggregateInScalarSubquery(outerAttrs, query, a)
case p: LogicalPlan if p.maxRows.exists(_ <= 1) => // Ok
case other =>
expr.failAnalysis(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just out of curiosity: why are there two places that check scalar subqueries?

Copy link
Contributor Author

@agubichev agubichev Sep 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The second check (in lines 1060-1080 of the current file) verifies where the scalar subquery is allowed to appear (e.g., it can occur in the project or filter but not in the join predicate).
The first check ensures the properties of the subquery itself (1 column and at most 1 row).

errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." +
"MUST_AGGREGATE_CORRELATED_SCALAR_SUBQUERY",
messageParameters = Map.empty)
}
}

// Only certain operators are allowed to host subquery expression containing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,20 @@ object SubExprUtils extends PredicateHelper {
case _ => ExpressionSet().empty
}
}

/**
 * Computes the grouping expressions of 'aggNode' (the aggregate of a scalar subquery) that
 * have no equivalent column in the outer query, i.e. that are not bound by an equality
 * predicate such as 'col = outer(c)'.
 *
 * The result is used to analyze whether a scalar subquery is guaranteed to return at most
 * 1 row.
 *
 * @param query   the scalar subquery plan from which correlated-equivalent inner expressions
 *                are collected
 * @param aggNode the aggregate whose grouping expressions are examined
 * @return the grouping expressions that are not covered by a correlated equality
 */
def nonEquivalentGroupbyCols(query: LogicalPlan, aggNode: Aggregate): ExpressionSet = {
  val boundToOuter = getCorrelatedEquivalentInnerExpressions(query)
  // Grouping by an outer reference or by a constant (an expression without references) is
  // always ok, so such expressions are dropped before the comparison.
  val candidates = aggNode.groupingExpressions.filterNot { g =>
    g.isInstanceOf[OuterReference] || g.references.isEmpty
  }
  ExpressionSet(candidates) -- boundToOuter
}
}

/**
Expand All @@ -371,14 +385,20 @@ object SubExprUtils extends PredicateHelper {
* case the subquery yields no row at all on empty input to the GROUP BY, which evaluates to NULL.
* It is set to true/false in PullupCorrelatedPredicates; until then its value is None.
* See constructLeftJoins in RewriteCorrelatedScalarSubquery for more details.
*
* 'needSingleJoin' is set to true if we can't guarantee that the correlated scalar subquery
* returns at most 1 row. For such subqueries we use a modification of an outer join called
* LeftSingle join. This value is set in PullupCorrelatedPredicates and used in
* RewriteCorrelatedScalarSubquery.
*/
case class ScalarSubquery(
plan: LogicalPlan,
outerAttrs: Seq[Expression] = Seq.empty,
exprId: ExprId = NamedExpression.newExprId,
joinCond: Seq[Expression] = Seq.empty,
hint: Option[HintInfo] = None,
mayHaveCountBug: Option[Boolean] = None)
mayHaveCountBug: Option[Boolean] = None,
needSingleJoin: Option[Boolean] = None)
extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Unevaluable {
override def dataType: DataType = {
if (!plan.schema.fields.nonEmpty) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
case d: DynamicPruningSubquery => d
case s @ ScalarSubquery(
PhysicalOperation(projections, predicates, a @ Aggregate(group, _, child)),
_, _, _, _, mayHaveCountBug)
_, _, _, _, mayHaveCountBug, _)
if conf.getConf(SQLConf.DECORRELATE_SUBQUERY_PREVENT_CONSTANT_FOLDING_FOR_COUNT_BUG) &&
mayHaveCountBug.nonEmpty && mayHaveCountBug.get =>
// This is a subquery with an aggregate that may suffer from a COUNT bug.
Expand Down Expand Up @@ -1985,7 +1985,8 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
}

private def canPushThrough(joinType: JoinType): Boolean = joinType match {
case _: InnerLike | LeftSemi | RightOuter | LeftOuter | LeftAnti | ExistenceJoin(_) => true
case _: InnerLike | LeftSemi | RightOuter | LeftOuter | LeftSingle |
LeftAnti | ExistenceJoin(_) => true
case _ => false
}

Expand Down Expand Up @@ -2025,7 +2026,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {

(leftFilterConditions ++ commonFilterCondition).
reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin)
case LeftOuter | LeftExistence(_) =>
case LeftOuter | LeftSingle | LeftExistence(_) =>
// push down the left side only `where` condition
val newLeft = leftFilterConditions.
reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
Expand Down Expand Up @@ -2071,6 +2072,8 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
val newJoinCond = (leftJoinConditions ++ commonJoinCondition).reduceLeftOption(And)

Join(newLeft, newRight, joinType, newJoinCond, hint)
// Do not move join predicates of a single join.
case LeftSingle => j

case other =>
throw SparkException.internalError(s"Unexpected join type: $other")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ object ConstantFolding extends Rule[LogicalPlan] {
}

// Don't replace ScalarSubquery if its plan is an aggregate that may suffer from a COUNT bug.
case s @ ScalarSubquery(_, _, _, _, _, mayHaveCountBug)
case s @ ScalarSubquery(_, _, _, _, _, mayHaveCountBug, _)
if conf.getConf(SQLConf.DECORRELATE_SUBQUERY_PREVENT_CONSTANT_FOLDING_FOR_COUNT_BUG) &&
mayHaveCountBug.nonEmpty && mayHaveCountBug.get =>
s
Expand Down Expand Up @@ -1007,7 +1007,7 @@ object FoldablePropagation extends Rule[LogicalPlan] {
replaceFoldable(j.withNewChildren(newChildren).asInstanceOf[Join], foldableMap)
val missDerivedAttrsSet: AttributeSet = AttributeSet(newJoin.joinType match {
case _: InnerLike | LeftExistence(_) => Nil
case LeftOuter => newJoin.right.output
case LeftOuter | LeftSingle => newJoin.right.output
case RightOuter => newJoin.left.output
case FullOuter => newJoin.left.output ++ newJoin.right.output
case _ => Nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,8 @@ trait JoinSelectionHelper extends Logging {
)
}

def getBroadcastNestedLoopJoinBuildSide(hint: JoinHint): Option[BuildSide] = {
if (hintToNotBroadcastAndReplicateLeft(hint)) {
def getBroadcastNestedLoopJoinBuildSide(hint: JoinHint, joinType: JoinType): Option[BuildSide] = {
if (hintToNotBroadcastAndReplicateLeft(hint) || joinType == LeftSingle) {
Some(BuildRight)
} else if (hintToNotBroadcastAndReplicateRight(hint)) {
Some(BuildLeft)
Expand Down Expand Up @@ -375,7 +375,7 @@ trait JoinSelectionHelper extends Logging {

def canBuildBroadcastRight(joinType: JoinType): Boolean = {
joinType match {
case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin => true
case _: InnerLike | LeftOuter | LeftSingle | LeftSemi | LeftAnti | _: ExistenceJoin => true
case _ => false
}
}
Expand All @@ -389,7 +389,7 @@ trait JoinSelectionHelper extends Logging {

def canBuildShuffledHashJoinRight(joinType: JoinType): Boolean = {
joinType match {
case _: InnerLike | LeftOuter | FullOuter | RightOuter |
case _: InnerLike | LeftOuter | LeftSingle | FullOuter | RightOuter |
LeftSemi | LeftAnti | _: ExistenceJoin => true
case _ => false
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,31 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
(newPlan, newCond)
}

// Returns true if 'query' is guaranteed to return at most 1 row.
// The check is conservative: it returns false whenever the bound cannot be established from
// the plan shape. The caller negates the result to decide whether a correlated scalar
// subquery needs a single join.
private def guaranteedToReturnOneRow(query: LogicalPlan): Boolean = {
Copy link
Contributor

@cloud-fan cloud-fan Sep 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we do some refactor to avoid duplicating code between this and CheckAnalysis?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, but the amount of savings is not spectacular, because in CheckAnalysis there is also a "legacy" path (that results in incorrect results) where the check is significantly weaker.

Once the single join is rolled out, we will just remove all extra checks from CheckAnalysis.

// Fast path: the plan's own row bound already proves at most one row.
if (query.maxRows.exists(_ <= 1)) {
return true
}
// Look for an Aggregate at the root of the subquery, or directly beneath a single
// Filter (HAVING), Limit or Project. Any other plan shape yields None.
val aggNode = query match {
case havingPart@Filter(_, aggPart: Aggregate) => Some(aggPart)
case aggPart: Aggregate => Some(aggPart)
// LIMIT 1 is handled above, this is for all other types of LIMITs
case Limit(_, aggPart: Aggregate) => Some(aggPart)
case Project(_, aggPart: Aggregate) => Some(aggPart)
case _: LogicalPlan => None
}
// No aggregate in a recognized position: nothing here bounds the row count.
if (!aggNode.isDefined) {
return false
}
// Gather every aggregate function that appears in the Aggregate node's expressions.
val aggregates = aggNode.get.expressions.flatMap(_.collect {
case a: AggregateExpression => a
})
// An Aggregate node without any aggregate function is not treated as single-row.
// NOTE(review): presumably because such a node only de-duplicates rows - confirm.
if (aggregates.isEmpty) {
return false
}
// The subquery is considered single-row only when every grouping expression (other than
// outer references and constants) is bound to the outer query by an equality predicate.
nonEquivalentGroupbyCols(query, aggNode.get).isEmpty
}

private def rewriteSubQueries(plan: LogicalPlan): LogicalPlan = {
/**
* This function is used as an aid to enforce idempotency of pullUpCorrelatedPredicate rule.
Expand All @@ -481,7 +506,8 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
}

plan.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
case ScalarSubquery(sub, children, exprId, conditions, hint, mayHaveCountBugOld)
case ScalarSubquery(sub, children, exprId, conditions, hint,
mayHaveCountBugOld, needSingleJoinOld)
if children.nonEmpty =>

def mayHaveCountBugAgg(a: Aggregate): Boolean = {
Expand Down Expand Up @@ -527,8 +553,13 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
val (topPart, havingNode, aggNode) = splitSubquery(sub)
(aggNode.isDefined && aggNode.get.groupingExpressions.isEmpty)
}
val needSingleJoin = if (needSingleJoinOld.isDefined) {
needSingleJoinOld.get
} else {
conf.getConf(SQLConf.SCALAR_SUBQUERY_USE_SINGLE_JOIN) && !guaranteedToReturnOneRow(sub)
}
ScalarSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions),
hint, Some(mayHaveCountBug))
hint, Some(mayHaveCountBug), Some(needSingleJoin))
case Exists(sub, children, exprId, conditions, hint) if children.nonEmpty =>
val (newPlan, newCond) = if (SQLConf.get.decorrelateInnerQueryEnabledForExistsIn) {
decorrelate(sub, plan, handleCountBug = true)
Expand Down Expand Up @@ -786,17 +817,22 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
subqueries: ArrayBuffer[ScalarSubquery]): (LogicalPlan, AttributeMap[Attribute]) = {
val subqueryAttrMapping = ArrayBuffer[(Attribute, Attribute)]()
val newChild = subqueries.foldLeft(child) {
case (currentChild, ScalarSubquery(sub, _, _, conditions, subHint, mayHaveCountBug)) =>
case (currentChild, ScalarSubquery(sub, _, _, conditions, subHint, mayHaveCountBug,
needSingleJoin)) =>
val query = DecorrelateInnerQuery.rewriteDomainJoins(currentChild, sub, conditions)
val origOutput = query.output.head
// The subquery appears on the right side of the join, hence add its hint to the right
// of a join hint
val joinHint = JoinHint(None, subHint)

val resultWithZeroTups = evalSubqueryOnZeroTups(query)
val joinType = needSingleJoin match {
case Some(true) => LeftSingle
case _ => LeftOuter
}
lazy val planWithoutCountBug = Project(
currentChild.output :+ origOutput,
Join(currentChild, query, LeftOuter, conditions.reduceOption(And), joinHint))
Join(currentChild, query, joinType, conditions.reduceOption(And), joinHint))

if (Utils.isTesting) {
assert(mayHaveCountBug.isDefined)
Expand Down Expand Up @@ -845,7 +881,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
currentChild.output :+ subqueryResultExpr,
Join(currentChild,
Project(query.output :+ alwaysTrueExpr, query),
LeftOuter, conditions.reduceOption(And), joinHint))
joinType, conditions.reduceOption(And), joinHint))

} else {
// CASE 3: Subquery with HAVING clause. Pull the HAVING clause above the join.
Expand Down Expand Up @@ -877,7 +913,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
currentChild.output :+ caseExpr,
Join(currentChild,
Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot),
LeftOuter, conditions.reduceOption(And), joinHint))
joinType, conditions.reduceOption(And), joinHint))
}
}
}
Expand Down Expand Up @@ -1028,7 +1064,7 @@ object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] {

case p: LogicalPlan => p.transformExpressionsUpWithPruning(
_.containsPattern(SCALAR_SUBQUERY)) {
case s @ ScalarSubquery(OneRowSubquery(p @ Project(_, _: OneRowRelation)), _, _, _, _, _)
case s @ ScalarSubquery(OneRowSubquery(p @ Project(_, _: OneRowRelation)), _, _, _, _, _, _)
if !hasCorrelatedSubquery(s.plan) && s.joinCond.isEmpty =>
assert(p.projectList.size == 1)
stripOuterReferences(p.projectList).head
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ case object LeftAnti extends JoinType {
override def sql: String = "LEFT ANTI"
}

/**
 * Join type described as a modification of an outer join. RewriteCorrelatedScalarSubquery
 * selects it instead of LeftOuter when a correlated scalar subquery is flagged as needing a
 * single join, i.e. when it cannot be guaranteed to return at most 1 row.
 *
 * NOTE(review): expected to raise an error at runtime when more than one right-side row
 * matches a left row (cf. SCALAR_SUBQUERY_TOO_MANY_ROWS) - confirm against the physical join
 * implementations.
 */
case object LeftSingle extends JoinType {
override def sql: String = "LEFT SINGLE"
}

case class ExistenceJoin(exists: Attribute) extends JoinType {
override def sql: String = {
// This join type is only used in the end of optimizer and physical plans, we will not
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -559,12 +559,12 @@ case class Join(

override def maxRows: Option[Long] = {
joinType match {
case Inner | Cross | FullOuter | LeftOuter | RightOuter
case Inner | Cross | FullOuter | LeftOuter | RightOuter | LeftSingle
if left.maxRows.isDefined && right.maxRows.isDefined =>
val leftMaxRows = BigInt(left.maxRows.get)
val rightMaxRows = BigInt(right.maxRows.get)
val minRows = joinType match {
case LeftOuter => leftMaxRows
case LeftOuter | LeftSingle => leftMaxRows
case RightOuter => rightMaxRows
case FullOuter => leftMaxRows + rightMaxRows
case _ => BigInt(0)
Expand All @@ -590,7 +590,7 @@ case class Join(
left.output :+ j.exists
case LeftExistence(_) =>
left.output
case LeftOuter =>
case LeftOuter | LeftSingle =>
left.output ++ right.output.map(_.withNullability(true))
case RightOuter =>
left.output.map(_.withNullability(true)) ++ right.output
Expand Down Expand Up @@ -627,7 +627,7 @@ case class Join(
left.constraints.union(right.constraints)
case LeftExistence(_) =>
left.constraints
case LeftOuter =>
case LeftOuter | LeftSingle =>
left.constraints
case RightOuter =>
right.constraints
Expand Down Expand Up @@ -659,7 +659,7 @@ case class Join(
var patterns = Seq(JOIN)
joinType match {
case _: InnerLike => patterns = patterns :+ INNER_LIKE_JOIN
case LeftOuter | FullOuter | RightOuter => patterns = patterns :+ OUTER_JOIN
case LeftOuter | FullOuter | RightOuter | LeftSingle => patterns = patterns :+ OUTER_JOIN
case LeftSemiOrAnti(_) => patterns = patterns :+ LEFT_SEMI_OR_ANTI_JOIN
case NaturalJoin(_) | UsingJoin(_, _) => patterns = patterns :+ NATURAL_LIKE_JOIN
case _ =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2488,6 +2488,12 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
summary = getSummary(context))
}

/**
 * Builds the runtime exception for the error class SCALAR_SUBQUERY_TOO_MANY_ROWS.
 * The error class takes no message parameters.
 *
 * @return a new [[SparkRuntimeException]]; the caller is responsible for throwing it
 */
def scalarSubqueryReturnsMultipleRows(): SparkRuntimeException =
  new SparkRuntimeException(
    errorClass = "SCALAR_SUBQUERY_TOO_MANY_ROWS",
    messageParameters = Map.empty)

def comparatorReturnsNull(firstValue: String, secondValue: String): Throwable = {
new SparkException(
errorClass = "COMPARATOR_RETURNS_NULL",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5058,6 +5058,15 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

// Internal boolean flag, enabled by default.
// When true, PullupCorrelatedPredicates marks correlated scalar subqueries that are not
// provably limited to one row as needing a single join, and CheckAnalysis skips its
// aggregate-shape check for correlated scalar subqueries.
val SCALAR_SUBQUERY_USE_SINGLE_JOIN =
buildConf("spark.sql.optimizer.scalarSubqueryUseSingleJoin")
.internal()
.doc("When set to true, use LEFT_SINGLE join for correlated scalar subqueries where " +
"optimizer can't prove that only 1 row will be returned")
.version("4.0.0")
.booleanConf
.createWithDefault(true)

val ALLOW_SUBQUERY_EXPRESSIONS_IN_LAMBDAS_AND_HIGHER_ORDER_FUNCTIONS =
buildConf("spark.sql.analyzer.allowSubqueryExpressionsInLambdasOrHigherOrderFunctions")
.internal()
Expand Down
Loading