diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 0164af945ca28..2ff174d88f750 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2653,7 +2653,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor */ private def resolveSubQueries(plan: LogicalPlan, outer: LogicalPlan): LogicalPlan = { plan.transformAllExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION), ruleId) { - case s @ ScalarSubquery(sub, _, exprId, _, _, _) if !sub.resolved => + case s @ ScalarSubquery(sub, _, exprId, _, _, _, _) if !sub.resolved => resolveSubQuery(s, outer)(ScalarSubquery(_, _, exprId)) case e @ Exists(sub, _, exprId, _, _) if !sub.resolved => resolveSubQuery(e, outer)(Exists(_, _, exprId)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index a9fbe548ba39e..90bd80d441dec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -941,19 +941,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB messageParameters = Map.empty) } - // SPARK-18504/SPARK-18814: Block cases where GROUP BY columns - // are not part of the correlated columns. - - // Collect the inner query expressions that are guaranteed to have a single value for each - // outer row. See comment on getCorrelatedEquivalentInnerExpressions. - val correlatedEquivalentExprs = getCorrelatedEquivalentInnerExpressions(query) - // Grouping expressions, except outer refs and constant expressions - grouping by an - // outer ref or a constant is always ok - val groupByExprs = - ExpressionSet(agg.groupingExpressions.filter(x => !x.isInstanceOf[OuterReference] && - x.references.nonEmpty)) - val nonEquivalentGroupByExprs = groupByExprs -- correlatedEquivalentExprs - + val nonEquivalentGroupByExprs = nonEquivalentGroupbyCols(query, agg) val invalidCols = if (!SQLConf.get.getConf( SQLConf.LEGACY_SCALAR_SUBQUERY_ALLOW_GROUP_BY_NON_EQUALITY_CORRELATED_PREDICATE)) { nonEquivalentGroupByExprs @@ -1033,7 +1021,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB checkOuterReference(plan, expr) expr match { - case ScalarSubquery(query, outerAttrs, _, _, _, _) => + case ScalarSubquery(query, outerAttrs, _, _, _, _, _) => // Scalar subquery must return one column as output. if (query.output.size != 1) { throw QueryCompilationErrors.subqueryReturnMoreThanOneColumn(query.output.size, @@ -1041,15 +1029,17 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB } if (outerAttrs.nonEmpty) { - cleanQueryInScalarSubquery(query) match { - case a: Aggregate => checkAggregateInScalarSubquery(outerAttrs, query, a) - case Filter(_, a: Aggregate) => checkAggregateInScalarSubquery(outerAttrs, query, a) - case p: LogicalPlan if p.maxRows.exists(_ <= 1) => // Ok - case other => - expr.failAnalysis( - errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." 
+ - "MUST_AGGREGATE_CORRELATED_SCALAR_SUBQUERY", - messageParameters = Map.empty) + if (!SQLConf.get.getConf(SQLConf.SCALAR_SUBQUERY_USE_SINGLE_JOIN)) { + cleanQueryInScalarSubquery(query) match { + case a: Aggregate => checkAggregateInScalarSubquery(outerAttrs, query, a) + case Filter(_, a: Aggregate) => checkAggregateInScalarSubquery(outerAttrs, query, a) + case p: LogicalPlan if p.maxRows.exists(_ <= 1) => // Ok + case other => + expr.failAnalysis( + errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + + "MUST_AGGREGATE_CORRELATED_SCALAR_SUBQUERY", + messageParameters = Map.empty) + } } // Only certain operators are allowed to host subquery expression containing diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index 174d32c73fc01..0c8253659dd56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -358,6 +358,20 @@ object SubExprUtils extends PredicateHelper { case _ => ExpressionSet().empty } } + + // Returns grouping expressions of 'aggNode' of a scalar subquery that do not have equivalent + // columns in the outer query (bound by equality predicates like 'col = outer(c)'). + // We use it to analyze whether a scalar subquery is guaranteed to return at most 1 row. + def nonEquivalentGroupbyCols(query: LogicalPlan, aggNode: Aggregate): ExpressionSet = { + val correlatedEquivalentExprs = getCorrelatedEquivalentInnerExpressions(query) + // Grouping expressions, except outer refs and constant expressions - grouping by an + // outer ref or a constant is always ok + val groupByExprs = + ExpressionSet(aggNode.groupingExpressions.filter(x => !x.isInstanceOf[OuterReference] && + x.references.nonEmpty)) + val nonEquivalentGroupByExprs = groupByExprs -- correlatedEquivalentExprs + nonEquivalentGroupByExprs + } } /** @@ -371,6 +385,11 @@ object SubExprUtils extends PredicateHelper { * case the subquery yields no row at all on empty input to the GROUP BY, which evaluates to NULL. * It is set in PullupCorrelatedPredicates to true/false, before it is set its value is None. * See constructLeftJoins in RewriteCorrelatedScalarSubquery for more details. + * + * 'needSingleJoin' is set to true if we can't guarantee that the correlated scalar subquery + * returns at most 1 row. For such subqueries we use a modification of an outer join called + * LeftSingle join. This value is set in PullupCorrelatedPredicates and used in + * RewriteCorrelatedScalarSubquery. 
*/ case class ScalarSubquery( plan: LogicalPlan, @@ -378,7 +397,8 @@ case class ScalarSubquery( exprId: ExprId = NamedExpression.newExprId, joinCond: Seq[Expression] = Seq.empty, hint: Option[HintInfo] = None, - mayHaveCountBug: Option[Boolean] = None) + mayHaveCountBug: Option[Boolean] = None, + needSingleJoin: Option[Boolean] = None) extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Unevaluable { override def dataType: DataType = { if (!plan.schema.fields.nonEmpty) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 6ceeeb9bfdf38..1f1bccec2af73 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -338,7 +338,7 @@ abstract class Optimizer(catalogManager: CatalogManager) case d: DynamicPruningSubquery => d case s @ ScalarSubquery( PhysicalOperation(projections, predicates, a @ Aggregate(group, _, child)), - _, _, _, _, mayHaveCountBug) + _, _, _, _, mayHaveCountBug, _) if conf.getConf(SQLConf.DECORRELATE_SUBQUERY_PREVENT_CONSTANT_FOLDING_FOR_COUNT_BUG) && mayHaveCountBug.nonEmpty && mayHaveCountBug.get => // This is a subquery with an aggregate that may suffer from a COUNT bug. @@ -1985,7 +1985,8 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { } private def canPushThrough(joinType: JoinType): Boolean = joinType match { - case _: InnerLike | LeftSemi | RightOuter | LeftOuter | LeftAnti | ExistenceJoin(_) => true + case _: InnerLike | LeftSemi | RightOuter | LeftOuter | LeftSingle | + LeftAnti | ExistenceJoin(_) => true case _ => false } @@ -2025,7 +2026,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { (leftFilterConditions ++ commonFilterCondition). reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin) - case LeftOuter | LeftExistence(_) => + case LeftOuter | LeftSingle | LeftExistence(_) => // push down the left side only `where` condition val newLeft = leftFilterConditions. reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) @@ -2071,6 +2072,8 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { val newJoinCond = (leftJoinConditions ++ commonJoinCondition).reduceLeftOption(And) Join(newLeft, newRight, joinType, newJoinCond, hint) + // Do not move join predicates of a single join. + case LeftSingle => j case other => throw SparkException.internalError(s"Unexpected join type: $other") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 3cdde622d51f7..1601d798283c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -90,7 +90,7 @@ object ConstantFolding extends Rule[LogicalPlan] { } // Don't replace ScalarSubquery if its plan is an aggregate that may suffer from a COUNT bug. 
- case s @ ScalarSubquery(_, _, _, _, _, mayHaveCountBug) + case s @ ScalarSubquery(_, _, _, _, _, mayHaveCountBug, _) if conf.getConf(SQLConf.DECORRELATE_SUBQUERY_PREVENT_CONSTANT_FOLDING_FOR_COUNT_BUG) && mayHaveCountBug.nonEmpty && mayHaveCountBug.get => s @@ -1007,7 +1007,7 @@ object FoldablePropagation extends Rule[LogicalPlan] { replaceFoldable(j.withNewChildren(newChildren).asInstanceOf[Join], foldableMap) val missDerivedAttrsSet: AttributeSet = AttributeSet(newJoin.joinType match { case _: InnerLike | LeftExistence(_) => Nil - case LeftOuter => newJoin.right.output + case LeftOuter | LeftSingle => newJoin.right.output case RightOuter => newJoin.left.output case FullOuter => newJoin.left.output ++ newJoin.right.output case _ => Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 9fc4873c248b5..6802adaa2ea24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -339,8 +339,8 @@ trait JoinSelectionHelper extends Logging { ) } - def getBroadcastNestedLoopJoinBuildSide(hint: JoinHint): Option[BuildSide] = { - if (hintToNotBroadcastAndReplicateLeft(hint)) { + def getBroadcastNestedLoopJoinBuildSide(hint: JoinHint, joinType: JoinType): Option[BuildSide] = { + if (hintToNotBroadcastAndReplicateLeft(hint) || joinType == LeftSingle) { Some(BuildRight) } else if (hintToNotBroadcastAndReplicateRight(hint)) { Some(BuildLeft) @@ -375,7 +375,7 @@ trait JoinSelectionHelper extends Logging { def canBuildBroadcastRight(joinType: JoinType): Boolean = { joinType match { - case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin => true + case _: InnerLike | LeftOuter | LeftSingle | LeftSemi | LeftAnti | _: ExistenceJoin => true case _ => false } } @@ -389,7 +389,7 @@ trait JoinSelectionHelper extends Logging { def canBuildShuffledHashJoinRight(joinType: JoinType): Boolean = { joinType match { - case _: InnerLike | LeftOuter | FullOuter | RightOuter | + case _: InnerLike | LeftOuter | LeftSingle | FullOuter | RightOuter | LeftSemi | LeftAnti | _: ExistenceJoin => true case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 1239a5dde1302..d9795cf338279 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -456,6 +456,31 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper (newPlan, newCond) } + // Returns true if 'query' is guaranteed to return at most 1 row. 
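+  // This holds when the plan's maxRows is known to be at most 1 (e.g. LIMIT 1), or when the
+  // plan is an Aggregate (possibly under a Filter, Project or other Limit) whose grouping
+  // columns are all bound to the outer query by equality predicates, i.e.
+  // nonEquivalentGroupbyCols is empty.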
+ private def guaranteedToReturnOneRow(query: LogicalPlan): Boolean = { + if (query.maxRows.exists(_ <= 1)) { + return true + } + val aggNode = query match { + case havingPart@Filter(_, aggPart: Aggregate) => Some(aggPart) + case aggPart: Aggregate => Some(aggPart) + // LIMIT 1 is handled above, this is for all other types of LIMITs + case Limit(_, aggPart: Aggregate) => Some(aggPart) + case Project(_, aggPart: Aggregate) => Some(aggPart) + case _: LogicalPlan => None + } + if (!aggNode.isDefined) { + return false + } + val aggregates = aggNode.get.expressions.flatMap(_.collect { + case a: AggregateExpression => a + }) + if (aggregates.isEmpty) { + return false + } + nonEquivalentGroupbyCols(query, aggNode.get).isEmpty + } + private def rewriteSubQueries(plan: LogicalPlan): LogicalPlan = { /** * This function is used as a aid to enforce idempotency of pullUpCorrelatedPredicate rule. @@ -481,7 +506,8 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper } plan.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) { - case ScalarSubquery(sub, children, exprId, conditions, hint, mayHaveCountBugOld) + case ScalarSubquery(sub, children, exprId, conditions, hint, + mayHaveCountBugOld, needSingleJoinOld) if children.nonEmpty => def mayHaveCountBugAgg(a: Aggregate): Boolean = { @@ -527,8 +553,13 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper val (topPart, havingNode, aggNode) = splitSubquery(sub) (aggNode.isDefined && aggNode.get.groupingExpressions.isEmpty) } + val needSingleJoin = if (needSingleJoinOld.isDefined) { + needSingleJoinOld.get + } else { + conf.getConf(SQLConf.SCALAR_SUBQUERY_USE_SINGLE_JOIN) && !guaranteedToReturnOneRow(sub) + } ScalarSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions), - hint, Some(mayHaveCountBug)) + hint, Some(mayHaveCountBug), Some(needSingleJoin)) case Exists(sub, children, exprId, conditions, hint) if children.nonEmpty => val (newPlan, newCond) = if (SQLConf.get.decorrelateInnerQueryEnabledForExistsIn) { decorrelate(sub, plan, handleCountBug = true) @@ -786,7 +817,8 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe subqueries: ArrayBuffer[ScalarSubquery]): (LogicalPlan, AttributeMap[Attribute]) = { val subqueryAttrMapping = ArrayBuffer[(Attribute, Attribute)]() val newChild = subqueries.foldLeft(child) { - case (currentChild, ScalarSubquery(sub, _, _, conditions, subHint, mayHaveCountBug)) => + case (currentChild, ScalarSubquery(sub, _, _, conditions, subHint, mayHaveCountBug, + needSingleJoin)) => val query = DecorrelateInnerQuery.rewriteDomainJoins(currentChild, sub, conditions) val origOutput = query.output.head // The subquery appears on the right side of the join, hence add its hint to the right @@ -794,9 +826,13 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe val joinHint = JoinHint(None, subHint) val resultWithZeroTups = evalSubqueryOnZeroTups(query) + val joinType = needSingleJoin match { + case Some(true) => LeftSingle + case _ => LeftOuter + } lazy val planWithoutCountBug = Project( currentChild.output :+ origOutput, - Join(currentChild, query, LeftOuter, conditions.reduceOption(And), joinHint)) + Join(currentChild, query, joinType, conditions.reduceOption(And), joinHint)) if (Utils.isTesting) { assert(mayHaveCountBug.isDefined) @@ -845,7 +881,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe currentChild.output :+ subqueryResultExpr, 
Join(currentChild, Project(query.output :+ alwaysTrueExpr, query), - LeftOuter, conditions.reduceOption(And), joinHint)) + joinType, conditions.reduceOption(And), joinHint)) } else { // CASE 3: Subquery with HAVING clause. Pull the HAVING clause above the join. @@ -877,7 +913,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe currentChild.output :+ caseExpr, Join(currentChild, Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot), - LeftOuter, conditions.reduceOption(And), joinHint)) + joinType, conditions.reduceOption(And), joinHint)) } } } @@ -1028,7 +1064,7 @@ object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] { case p: LogicalPlan => p.transformExpressionsUpWithPruning( _.containsPattern(SCALAR_SUBQUERY)) { - case s @ ScalarSubquery(OneRowSubquery(p @ Project(_, _: OneRowRelation)), _, _, _, _, _) + case s @ ScalarSubquery(OneRowSubquery(p @ Project(_, _: OneRowRelation)), _, _, _, _, _, _) if !hasCorrelatedSubquery(s.plan) && s.joinCond.isEmpty => assert(p.projectList.size == 1) stripOuterReferences(p.projectList).head diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index d9da255eccc9d..41bba99673a2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -95,6 +95,10 @@ case object LeftAnti extends JoinType { override def sql: String = "LEFT ANTI" } +case object LeftSingle extends JoinType { + override def sql: String = "LEFT SINGLE" +} + case class ExistenceJoin(exists: Attribute) extends JoinType { override def sql: String = { // This join type is only used in the end of optimizer and physical plans, we will not diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 926027df4c74b..e82a38a3450cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -559,12 +559,12 @@ case class Join( override def maxRows: Option[Long] = { joinType match { - case Inner | Cross | FullOuter | LeftOuter | RightOuter + case Inner | Cross | FullOuter | LeftOuter | RightOuter | LeftSingle if left.maxRows.isDefined && right.maxRows.isDefined => val leftMaxRows = BigInt(left.maxRows.get) val rightMaxRows = BigInt(right.maxRows.get) val minRows = joinType match { - case LeftOuter => leftMaxRows + case LeftOuter | LeftSingle => leftMaxRows case RightOuter => rightMaxRows case FullOuter => leftMaxRows + rightMaxRows case _ => BigInt(0) @@ -590,7 +590,7 @@ case class Join( left.output :+ j.exists case LeftExistence(_) => left.output - case LeftOuter => + case LeftOuter | LeftSingle => left.output ++ right.output.map(_.withNullability(true)) case RightOuter => left.output.map(_.withNullability(true)) ++ right.output @@ -627,7 +627,7 @@ case class Join( left.constraints.union(right.constraints) case LeftExistence(_) => left.constraints - case LeftOuter => + case LeftOuter | LeftSingle => left.constraints case RightOuter => right.constraints @@ -659,7 +659,7 @@ case class Join( var patterns = Seq(JOIN) joinType match { case _: InnerLike => patterns = patterns :+ INNER_LIKE_JOIN - case 
LeftOuter | FullOuter | RightOuter => patterns = patterns :+ OUTER_JOIN + case LeftOuter | FullOuter | RightOuter | LeftSingle => patterns = patterns :+ OUTER_JOIN case LeftSemiOrAnti(_) => patterns = patterns :+ LEFT_SEMI_OR_ANTI_JOIN case NaturalJoin(_) | UsingJoin(_, _) => patterns = patterns :+ NATURAL_LIKE_JOIN case _ => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 2ab86a5c5f03f..39e6efaff0896 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -2488,6 +2488,12 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE summary = getSummary(context)) } + def scalarSubqueryReturnsMultipleRows(): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "SCALAR_SUBQUERY_TOO_MANY_ROWS", + messageParameters = Map.empty) + } + def comparatorReturnsNull(firstValue: String, secondValue: String): Throwable = { new SparkException( errorClass = "COMPARATOR_RETURNS_NULL", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 094fb8f050bc8..991e503fe6b11 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -5058,6 +5058,15 @@ object SQLConf { .booleanConf .createWithDefault(true) + val SCALAR_SUBQUERY_USE_SINGLE_JOIN = + buildConf("spark.sql.optimizer.scalarSubqueryUseSingleJoin") + .internal() + .doc("When set to true, use LEFT_SINGLE join for correlated scalar subqueries where " + + "optimizer can't prove that only 1 row will be returned") + .version("4.0.0") + .booleanConf + .createWithDefault(true) + val ALLOW_SUBQUERY_EXPRESSIONS_IN_LAMBDAS_AND_HIGHER_ORDER_FUNCTIONS = buildConf("spark.sql.analyzer.allowSubqueryExpressionsInLambdasOrHigherOrderFunctions") .internal() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 6d940a30619fb..41dc729daae85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -269,8 +269,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + def canMerge(joinType: JoinType): Boolean = joinType match { + case LeftSingle => false + case _ => true + } + def createSortMergeJoin() = { - if (RowOrdering.isOrderable(leftKeys)) { + if (canMerge(joinType) && RowOrdering.isOrderable(leftKeys)) { Some(Seq(joins.SortMergeJoinExec( leftKeys, rightKeys, joinType, nonEquiCond, planLater(left), planLater(right)))) } else { @@ -297,7 +302,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // This join could be very slow or OOM // Build the smaller side unless the join requires a particular build side // (e.g. 
NO_BROADCAST_AND_REPLICATION hint) - val requiredBuildSide = getBroadcastNestedLoopJoinBuildSide(hint) + val requiredBuildSide = getBroadcastNestedLoopJoinBuildSide(hint, joinType) val buildSide = requiredBuildSide.getOrElse(getSmallerSide(left, right)) Seq(joins.BroadcastNestedLoopJoinExec( planLater(left), planLater(right), buildSide, joinType, j.condition)) @@ -390,7 +395,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // This join could be very slow or OOM // Build the desired side unless the join requires a particular build side // (e.g. NO_BROADCAST_AND_REPLICATION hint) - val requiredBuildSide = getBroadcastNestedLoopJoinBuildSide(hint) + val requiredBuildSide = getBroadcastNestedLoopJoinBuildSide(hint, joinType) val buildSide = requiredBuildSide.getOrElse(desiredBuildSide) Seq(joins.BroadcastNestedLoopJoinExec( planLater(left), planLater(right), buildSide, joinType, condition)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala index df4d895867586..5f2638655c37c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala @@ -30,7 +30,7 @@ case class PlanAdaptiveSubqueries( def apply(plan: SparkPlan): SparkPlan = { plan.transformAllExpressionsWithPruning( _.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY)) { - case expressions.ScalarSubquery(_, _, exprId, _, _, _) => + case expressions.ScalarSubquery(_, _, exprId, _, _, _, _) => val subquery = SubqueryExec.createForScalarSubquery( s"subquery#${exprId.id}", subqueryMap(exprId.id)) execution.ScalarSubquery(subquery, exprId) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala index 6dd41aca3a5e1..a7292ee1f8fa7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.{CodegenSupport, ExplainUtils, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.ArrayImplicits._ @@ -63,13 +64,15 @@ case class BroadcastNestedLoopJoinExec( override def outputPartitioning: Partitioning = (joinType, buildSide) match { case (_: InnerLike, _) | (LeftOuter, BuildRight) | (RightOuter, BuildLeft) | - (LeftSemi, BuildRight) | (LeftAnti, BuildRight) => streamed.outputPartitioning + (LeftSingle, BuildRight) | (LeftSemi, BuildRight) | (LeftAnti, BuildRight) => + streamed.outputPartitioning case _ => super.outputPartitioning } override def outputOrdering: Seq[SortOrder] = (joinType, buildSide) match { case (_: InnerLike, _) | (LeftOuter, BuildRight) | (RightOuter, BuildLeft) | - (LeftSemi, BuildRight) | (LeftAnti, BuildRight) => streamed.outputOrdering + (LeftSingle, BuildRight) | (LeftSemi, BuildRight) | (LeftAnti, 
BuildRight) => + streamed.outputOrdering case _ => Nil } @@ -87,7 +90,7 @@ case class BroadcastNestedLoopJoinExec( joinType match { case _: InnerLike => left.output ++ right.output - case LeftOuter => + case LeftOuter | LeftSingle => left.output ++ right.output.map(_.withNullability(true)) case RightOuter => left.output.map(_.withNullability(true)) ++ right.output @@ -135,8 +138,14 @@ case class BroadcastNestedLoopJoinExec( * * LeftOuter with BuildRight * RightOuter with BuildLeft + * LeftSingle with BuildRight + * + * For the (LeftSingle, BuildRight) case we pass 'singleJoin' flag that + * makes sure there is at most 1 matching build row per every probe tuple. */ - private def outerJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = { + private def outerJoin( + relation: Broadcast[Array[InternalRow]], + singleJoin: Boolean = false): RDD[InternalRow] = { streamed.execute().mapPartitionsInternal { streamedIter => val buildRows = relation.value val joinedRow = new JoinedRow @@ -167,6 +176,9 @@ case class BroadcastNestedLoopJoinExec( resultRow = joinedRow(streamRow, buildRows(nextIndex)) nextIndex += 1 if (boundCondition(resultRow)) { + if (foundMatch && singleJoin) { + throw QueryExecutionErrors.scalarSubqueryReturnsMultipleRows(); + } foundMatch = true return true } @@ -382,12 +394,18 @@ case class BroadcastNestedLoopJoinExec( innerJoin(broadcastedRelation) case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) => outerJoin(broadcastedRelation) + case (LeftSingle, BuildRight) => + outerJoin(broadcastedRelation, singleJoin = true) case (LeftSemi, _) => leftExistenceJoin(broadcastedRelation, exists = true) case (LeftAnti, _) => leftExistenceJoin(broadcastedRelation, exists = false) case (_: ExistenceJoin, _) => existenceJoin(broadcastedRelation) + case (LeftSingle, BuildLeft) => + throw new IllegalArgumentException( + s"BroadcastNestedLoopJoin should not use the left side as build when " + + s"executing a LeftSingle join") case _ => /** * LeftOuter with BuildLeft @@ -410,7 +428,7 @@ case class BroadcastNestedLoopJoinExec( override def supportCodegen: Boolean = (joinType, buildSide) match { case (_: InnerLike, _) | (LeftOuter, BuildRight) | (RightOuter, BuildLeft) | - (LeftSemi | LeftAnti, BuildRight) => true + (LeftSemi | LeftAnti, BuildRight) | (LeftSingle, BuildRight) => true case _ => false } @@ -428,6 +446,7 @@ case class BroadcastNestedLoopJoinExec( (joinType, buildSide) match { case (_: InnerLike, _) => codegenInner(ctx, input) case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) => codegenOuter(ctx, input) + case (LeftSingle, BuildRight) => codegenOuter(ctx, input) case (LeftSemi, BuildRight) => codegenLeftExistence(ctx, input, exists = true) case (LeftAnti, BuildRight) => codegenLeftExistence(ctx, input, exists = false) case _ => @@ -473,7 +492,9 @@ case class BroadcastNestedLoopJoinExec( """.stripMargin } - private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = { + private def codegenOuter( + ctx: CodegenContext, + input: Seq[ExprCode]): String = { val (buildRowArray, buildRowArrayTerm) = prepareBroadcast(ctx) val (buildRow, checkCondition, _) = getJoinCondition(ctx, input, streamed, broadcast) val buildVars = genOneSideJoinVars(ctx, buildRow, broadcast, setDefaultValue = true) @@ -494,12 +515,23 @@ case class BroadcastNestedLoopJoinExec( |${consume(ctx, resultVars)} """.stripMargin } else { + // For LeftSingle joins, generate the check on the number of matches. 
+ val evaluateSingleCheck = if (joinType == LeftSingle) { + s""" + |if ($foundMatch) { + | throw QueryExecutionErrors.scalarSubqueryReturnsMultipleRows(); + |} + |""".stripMargin + } else { + "" + } s""" |boolean $foundMatch = false; |for (int $arrayIndex = 0; $arrayIndex < $buildRowArrayTerm.length; $arrayIndex++) { | UnsafeRow $buildRow = (UnsafeRow) $buildRowArrayTerm[$arrayIndex]; | boolean $shouldOutputRow = false; | $checkCondition { + | $evaluateSingleCheck | $shouldOutputRow = true; | $foundMatch = true; | } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 5d59a48d544a0..ce7d48babc91e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.{CodegenSupport, ExplainUtils, RowIterator} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types.{BooleanType, IntegralType, LongType} @@ -52,7 +53,7 @@ trait HashJoin extends JoinCodegenSupport { joinType match { case _: InnerLike => left.output ++ right.output - case LeftOuter => + case LeftOuter | LeftSingle => left.output ++ right.output.map(_.withNullability(true)) case RightOuter => left.output.map(_.withNullability(true)) ++ right.output @@ -75,7 +76,7 @@ trait HashJoin extends JoinCodegenSupport { } case BuildRight => joinType match { - case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin => + case _: InnerLike | LeftOuter | LeftSingle | LeftSemi | LeftAnti | _: ExistenceJoin => left.outputPartitioning case x => throw new IllegalArgumentException( @@ -93,7 +94,7 @@ trait HashJoin extends JoinCodegenSupport { } case BuildRight => joinType match { - case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin => + case _: InnerLike | LeftOuter | LeftSingle | LeftSemi | LeftAnti | _: ExistenceJoin => left.outputOrdering case x => throw new IllegalArgumentException( @@ -191,7 +192,8 @@ trait HashJoin extends JoinCodegenSupport { private def outerJoin( streamedIter: Iterator[InternalRow], - hashedRelation: HashedRelation): Iterator[InternalRow] = { + hashedRelation: HashedRelation, + singleJoin: Boolean = false): Iterator[InternalRow] = { val joinedRow = new JoinedRow() val keyGenerator = streamSideKeyGenerator() val nullRow = new GenericInternalRow(buildPlan.output.length) @@ -218,6 +220,9 @@ trait HashJoin extends JoinCodegenSupport { while (buildIter != null && buildIter.hasNext) { val nextBuildRow = buildIter.next() if (boundCondition(joinedRow.withRight(nextBuildRow))) { + if (found && singleJoin) { + throw QueryExecutionErrors.scalarSubqueryReturnsMultipleRows(); + } found = true return true } @@ -329,6 +334,8 @@ trait HashJoin extends JoinCodegenSupport { innerJoin(streamedIter, hashed) case LeftOuter | RightOuter => outerJoin(streamedIter, hashed) + case LeftSingle => + outerJoin(streamedIter, hashed, singleJoin = true) case LeftSemi => semiJoin(streamedIter, hashed) case LeftAnti => @@ -354,7 +361,7 @@ trait HashJoin extends JoinCodegenSupport { override def doConsume(ctx: 
CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { joinType match { case _: InnerLike => codegenInner(ctx, input) - case LeftOuter | RightOuter => codegenOuter(ctx, input) + case LeftOuter | RightOuter | LeftSingle => codegenOuter(ctx, input) case LeftSemi => codegenSemi(ctx, input) case LeftAnti => codegenAnti(ctx, input) case _: ExistenceJoin => codegenExistence(ctx, input) @@ -492,6 +499,17 @@ trait HashJoin extends JoinCodegenSupport { val matches = ctx.freshName("matches") val iteratorCls = classOf[Iterator[UnsafeRow]].getName val found = ctx.freshName("found") + // For LeftSingle joins generate the check on the number of build rows that match every + // probe row. Return an error for >1 matches. + val evaluateSingleCheck = if (joinType == LeftSingle) { + s""" + |if ($found) { + | throw QueryExecutionErrors.scalarSubqueryReturnsMultipleRows(); + |} + |""".stripMargin + } else { + "" + } s""" |// generate join key for stream side @@ -505,6 +523,7 @@ trait HashJoin extends JoinCodegenSupport { | (UnsafeRow) $matches.next() : null; | ${checkCondition.trim} | if ($conditionPassed) { + | $evaluateSingleCheck | $found = true; | $numOutput.add(1); | ${consume(ctx, resultVars)} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala index 7c4628c8576c5..60e5a7769a503 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, InnerLike, LeftExistence, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, InnerLike, LeftExistence, LeftOuter, LeftSingle, RightOuter} import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning, PartitioningCollection, UnknownPartitioning, UnspecifiedDistribution} /** @@ -47,7 +47,7 @@ trait ShuffledJoin extends JoinCodegenSupport { override def outputPartitioning: Partitioning = joinType match { case _: InnerLike => PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) - case LeftOuter => left.outputPartitioning + case LeftOuter | LeftSingle => left.outputPartitioning case RightOuter => right.outputPartitioning case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) case LeftExistence(_) => left.outputPartitioning @@ -60,7 +60,7 @@ trait ShuffledJoin extends JoinCodegenSupport { joinType match { case _: InnerLike => left.output ++ right.output - case LeftOuter => + case LeftOuter | LeftSingle => left.output ++ right.output.map(_.withNullability(true)) case RightOuter => left.output.map(_.withNullability(true)) ++ right.output diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out index bea91e09b0053..01de7beda551d 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out @@ -142,6 +142,12 @@ Project [x1#x, x2#x, scalar-subquery#x 
[x1#x && x2#x] AS scalarsubquery(x1, x2)# +- LocalRelation [col1#x, col2#x] +-- !query +set spark.sql.optimizer.scalarSubqueryUseSingleJoin = false +-- !query analysis +SetCommand (spark.sql.optimizer.scalarSubqueryUseSingleJoin,Some(false)) + + -- !query select * from x where (select count(*) from y where y1 > x1 group by y1) = 1 -- !query analysis @@ -202,24 +208,83 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException } +-- !query +set spark.sql.optimizer.scalarSubqueryUseSingleJoin = true +-- !query analysis +SetCommand (spark.sql.optimizer.scalarSubqueryUseSingleJoin,Some(true)) + + +-- !query +select * from x where (select count(*) from y where y1 > x1 group by y1) = 1 +-- !query analysis +Project [x1#x, x2#x] ++- Filter (scalar-subquery#x [x1#x] = cast(1 as bigint)) + : +- Aggregate [y1#x], [count(1) AS count(1)#xL] + : +- Filter (y1#x > outer(x1#x)) + : +- SubqueryAlias y + : +- View (`y`, [y1#x, y2#x]) + : +- Project [cast(col1#x as int) AS y1#x, cast(col2#x as int) AS y2#x] + : +- LocalRelation [col1#x, col2#x] + +- SubqueryAlias x + +- View (`x`, [x1#x, x2#x]) + +- Project [cast(col1#x as int) AS x1#x, cast(col2#x as int) AS x2#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +select *, (select count(*) from y where y1 + y2 = x1 group by y1) from x +-- !query analysis +Project [x1#x, x2#x, scalar-subquery#x [x1#x] AS scalarsubquery(x1)#xL] +: +- Aggregate [y1#x], [count(1) AS count(1)#xL] +: +- Filter ((y1#x + y2#x) = outer(x1#x)) +: +- SubqueryAlias y +: +- View (`y`, [y1#x, y2#x]) +: +- Project [cast(col1#x as int) AS y1#x, cast(col2#x as int) AS y2#x] +: +- LocalRelation [col1#x, col2#x] ++- SubqueryAlias x + +- View (`x`, [x1#x, x2#x]) + +- Project [cast(col1#x as int) AS x1#x, cast(col2#x as int) AS x2#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +select *, (select count(*) from y where x1 = y1 and y2 + 10 = x1 + 1 group by y2) from x +-- !query analysis +Project [x1#x, x2#x, scalar-subquery#x [x1#x && x1#x] AS scalarsubquery(x1, x1)#xL] +: +- Aggregate [y2#x], [count(1) AS count(1)#xL] +: +- Filter ((outer(x1#x) = y1#x) AND ((y2#x + 10) = (outer(x1#x) + 1))) +: +- SubqueryAlias y +: +- View (`y`, [y1#x, y2#x]) +: +- Project [cast(col1#x as int) AS y1#x, cast(col2#x as int) AS y2#x] +: +- LocalRelation [col1#x, col2#x] ++- SubqueryAlias x + +- View (`x`, [x1#x, x2#x]) + +- Project [cast(col1#x as int) AS x1#x, cast(col2#x as int) AS x2#x] + +- LocalRelation [col1#x, col2#x] + + -- !query select *, (select count(*) from (select * from y where y1 = x1 union all select * from y) sub group by y1) from x -- !query analysis -org.apache.spark.sql.catalyst.ExtendedAnalysisException -{ - "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_CORRELATED_COLUMNS_IN_GROUP_BY", - "sqlState" : "0A000", - "messageParameters" : { - "value" : "y1" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 11, - "stopIndex" : 106, - "fragment" : "(select count(*) from (select * from y where y1 = x1 union all select * from y) sub group by y1)" - } ] -} +Project [x1#x, x2#x, scalar-subquery#x [x1#x] AS scalarsubquery(x1)#xL] +: +- Aggregate [y1#x], [count(1) AS count(1)#xL] +: +- SubqueryAlias sub +: +- Union false, false +: :- Project [y1#x, y2#x] +: : +- Filter (y1#x = outer(x1#x)) +: : +- SubqueryAlias y +: : +- View (`y`, [y1#x, y2#x]) +: : +- Project [cast(col1#x as int) AS y1#x, cast(col2#x as int) AS y2#x] +: : +- LocalRelation [col1#x, col2#x] +: +- Project [y1#x, y2#x] +: +- SubqueryAlias y +: +- View (`y`, [y1#x, y2#x]) +: +- 
Project [cast(col1#x as int) AS y1#x, cast(col2#x as int) AS y2#x] +: +- LocalRelation [col1#x, col2#x] ++- SubqueryAlias x + +- View (`x`, [x1#x, x2#x]) + +- Project [cast(col1#x as int) AS x1#x, cast(col2#x as int) AS x2#x] + +- LocalRelation [col1#x, col2#x] -- !query @@ -227,17 +292,17 @@ select *, (select count(*) from y left join (select * from z where z1 = x1) sub -- !query analysis org.apache.spark.sql.catalyst.ExtendedAnalysisException { - "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_CORRELATED_COLUMNS_IN_GROUP_BY", + "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.ACCESSING_OUTER_QUERY_COLUMN_IS_NOT_ALLOWED", "sqlState" : "0A000", "messageParameters" : { - "value" : "z1" + "treeNode" : "Filter (z1#x = outer(x1#x))\n+- SubqueryAlias z\n +- View (`z`, [z1#x, z2#x])\n +- Project [cast(col1#x as int) AS z1#x, cast(col2#x as int) AS z2#x]\n +- LocalRelation [col1#x, col2#x]\n" }, "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 11, - "stopIndex" : 103, - "fragment" : "(select count(*) from y left join (select * from z where z1 = x1) sub on y2 = z2 group by z1)" + "startIndex" : 46, + "stopIndex" : 74, + "fragment" : "select * from z where z1 = x1" } ] } @@ -248,6 +313,12 @@ set spark.sql.legacy.scalarSubqueryAllowGroupByNonEqualityCorrelatedPredicate = SetCommand (spark.sql.legacy.scalarSubqueryAllowGroupByNonEqualityCorrelatedPredicate,Some(true)) +-- !query +set spark.sql.optimizer.scalarSubqueryUseSingleJoin = false +-- !query analysis +SetCommand (spark.sql.optimizer.scalarSubqueryUseSingleJoin,Some(false)) + + -- !query select * from x where (select count(*) from y where y1 > x1 group by y1) = 1 -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out index e3ce85fe5d209..4ff0222d6e965 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out @@ -1748,3 +1748,21 @@ Project [t1a#x, t1b#x, t1c#x] +- View (`t1`, [t1a#x, t1b#x, t1c#x]) +- Project [cast(col1#x as int) AS t1a#x, cast(col2#x as int) AS t1b#x, cast(col3#x as int) AS t1c#x] +- LocalRelation [col1#x, col2#x, col3#x] + + +-- !query +SELECT * FROM t0 WHERE t0a = (SELECT distinct(t1c) FROM t1 WHERE t1a = t0a) +-- !query analysis +Project [t0a#x, t0b#x] ++- Filter (t0a#x = scalar-subquery#x [t0a#x]) + : +- Distinct + : +- Project [t1c#x] + : +- Filter (t1a#x = outer(t0a#x)) + : +- SubqueryAlias t1 + : +- View (`t1`, [t1a#x, t1b#x, t1c#x]) + : +- Project [cast(col1#x as int) AS t1a#x, cast(col2#x as int) AS t1b#x, cast(col3#x as int) AS t1c#x] + : +- LocalRelation [col1#x, col2#x, col3#x] + +- SubqueryAlias t0 + +- View (`t0`, [t0a#x, t0b#x]) + +- Project [cast(col1#x as int) AS t0a#x, cast(col2#x as int) AS t0b#x] + +- LocalRelation [col1#x, col2#x] diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-group-by.sql index db7cdc97614cb..a23083e9e0e4d 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-group-by.sql +++ 
b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-group-by.sql @@ -22,16 +22,25 @@ select *, (select count(*) from y where x1 = y1 and cast(y2 as double) = x1 + 1 select *, (select count(*) from y where y2 + 1 = x1 + x2 group by y2 + 1) from x; --- Illegal queries +-- Illegal queries (single join disabled) +set spark.sql.optimizer.scalarSubqueryUseSingleJoin = false; select * from x where (select count(*) from y where y1 > x1 group by y1) = 1; select *, (select count(*) from y where y1 + y2 = x1 group by y1) from x; select *, (select count(*) from y where x1 = y1 and y2 + 10 = x1 + 1 group by y2) from x; +-- Same queries, with LeftSingle join +set spark.sql.optimizer.scalarSubqueryUseSingleJoin = true; +select * from x where (select count(*) from y where y1 > x1 group by y1) = 1; +select *, (select count(*) from y where y1 + y2 = x1 group by y1) from x; +select *, (select count(*) from y where x1 = y1 and y2 + 10 = x1 + 1 group by y2) from x; + + -- Certain other operators like OUTER JOIN or UNION between the correlating filter and the group-by also can cause the scalar subquery to return multiple values and hence make the query illegal. select *, (select count(*) from (select * from y where y1 = x1 union all select * from y) sub group by y1) from x; select *, (select count(*) from y left join (select * from z where z1 = x1) sub on y2 = z2 group by z1) from x; -- The correlation below the join is unsupported in Spark anyway, but when we do support it this query should still be disallowed. -- Test legacy behavior conf set spark.sql.legacy.scalarSubqueryAllowGroupByNonEqualityCorrelatedPredicate = true; +set spark.sql.optimizer.scalarSubqueryUseSingleJoin = false; select * from x where (select count(*) from y where y1 > x1 group by y1) = 1; reset spark.sql.legacy.scalarSubqueryAllowGroupByNonEqualityCorrelatedPredicate; diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql index 2823888e6e438..81e0c5f98d82b 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql @@ -529,3 +529,6 @@ FROM t1 WHERE (SELECT max(t2c) FROM t2 WHERE t1b = t2b ) between 1 and 2; + + +SELECT * FROM t0 WHERE t0a = (SELECT distinct(t1c) FROM t1 WHERE t1a = t0a); diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out index 41cba1f43745f..56932edd4e545 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out @@ -112,6 +112,14 @@ struct 2 2 NULL +-- !query +set spark.sql.optimizer.scalarSubqueryUseSingleJoin = false +-- !query schema +struct +-- !query output +spark.sql.optimizer.scalarSubqueryUseSingleJoin false + + -- !query select * from x where (select count(*) from y where y1 > x1 group by y1) = 1 -- !query schema @@ -178,25 +186,56 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException } +-- !query +set spark.sql.optimizer.scalarSubqueryUseSingleJoin = true +-- !query schema +struct +-- !query output 
+spark.sql.optimizer.scalarSubqueryUseSingleJoin true + + +-- !query +select * from x where (select count(*) from y where y1 > x1 group by y1) = 1 +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkRuntimeException +{ + "errorClass" : "SCALAR_SUBQUERY_TOO_MANY_ROWS", + "sqlState" : "21000" +} + + +-- !query +select *, (select count(*) from y where y1 + y2 = x1 group by y1) from x +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkRuntimeException +{ + "errorClass" : "SCALAR_SUBQUERY_TOO_MANY_ROWS", + "sqlState" : "21000" +} + + +-- !query +select *, (select count(*) from y where x1 = y1 and y2 + 10 = x1 + 1 group by y2) from x +-- !query schema +struct +-- !query output +1 1 NULL +2 2 NULL + + -- !query select *, (select count(*) from (select * from y where y1 = x1 union all select * from y) sub group by y1) from x -- !query schema struct<> -- !query output -org.apache.spark.sql.catalyst.ExtendedAnalysisException +org.apache.spark.SparkRuntimeException { - "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_CORRELATED_COLUMNS_IN_GROUP_BY", - "sqlState" : "0A000", - "messageParameters" : { - "value" : "y1" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 11, - "stopIndex" : 106, - "fragment" : "(select count(*) from (select * from y where y1 = x1 union all select * from y) sub group by y1)" - } ] + "errorClass" : "SCALAR_SUBQUERY_TOO_MANY_ROWS", + "sqlState" : "21000" } @@ -207,17 +246,17 @@ struct<> -- !query output org.apache.spark.sql.catalyst.ExtendedAnalysisException { - "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_CORRELATED_COLUMNS_IN_GROUP_BY", + "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.ACCESSING_OUTER_QUERY_COLUMN_IS_NOT_ALLOWED", "sqlState" : "0A000", "messageParameters" : { - "value" : "z1" + "treeNode" : "Filter (z1#x = outer(x1#x))\n+- SubqueryAlias z\n +- View (`z`, [z1#x, z2#x])\n +- Project [cast(col1#x as int) AS z1#x, cast(col2#x as int) AS z2#x]\n +- LocalRelation [col1#x, col2#x]\n" }, "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 11, - "stopIndex" : 103, - "fragment" : "(select count(*) from y left join (select * from z where z1 = x1) sub on y2 = z2 group by z1)" + "startIndex" : 46, + "stopIndex" : 74, + "fragment" : "select * from z where z1 = x1" } ] } @@ -230,6 +269,14 @@ struct spark.sql.legacy.scalarSubqueryAllowGroupByNonEqualityCorrelatedPredicate true +-- !query +set spark.sql.optimizer.scalarSubqueryUseSingleJoin = false +-- !query schema +struct +-- !query output +spark.sql.optimizer.scalarSubqueryUseSingleJoin false + + -- !query select * from x where (select count(*) from y where y1 > x1 group by y1) = 1 -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out index a02f0c70be6da..2460c2452ea56 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out @@ -906,3 +906,11 @@ WHERE (SELECT max(t2c) struct -- !query output + + +-- !query +SELECT * FROM t0 WHERE t0a = (SELECT distinct(t1c) FROM t1 WHERE t1a = t0a) +-- !query schema +struct +-- !query output + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala index 9afba65183974..a892cd4db02b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import org.scalactic.source.Position import org.scalatest.Tag +import org.apache.spark.SparkRuntimeException import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, ExpressionSet} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.Aggregate @@ -554,7 +555,15 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { | FROM (SELECT dept * 2.0 AS id, id + 1 AS id2 FROM $testTable)) > 5 |ORDER BY id |""".stripMargin - withLCAOff { intercept[AnalysisException] { sql(query4) } } + withLCAOff { + val exception = intercept[SparkRuntimeException] { + sql(query4).collect() + } + checkError( + exception, + condition = "SCALAR_SUBQUERY_TOO_MANY_ROWS" + ) + } withLCAOn { val analyzedPlan = sql(query4).queryExecution.analyzed assert(!analyzedPlan.containsPattern(OUTER_REFERENCE)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 23c4d51983bb4..6e160b4407ca8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import scala.collection.mutable.ArrayBuffer +import org.apache.spark.SparkRuntimeException import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, LogicalPlan, Project, Sort, Union} @@ -527,43 +528,30 @@ class SubquerySuite extends QueryTest test("SPARK-18504 extra GROUP BY column in correlated scalar subquery is not permitted") { withTempView("v") { Seq((1, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("v") - - val exception = intercept[AnalysisException] { - sql("select (select sum(-1) from v t2 where t1.c2 = t2.c1 group by t2.c2) sum from v t1") + val exception = intercept[SparkRuntimeException] { + sql("select (select sum(-1) from v t2 where t1.c2 = t2.c1 group by t2.c2) sum from v t1"). + collect() } checkError( exception, - condition = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + - "NON_CORRELATED_COLUMNS_IN_GROUP_BY", - parameters = Map("value" -> "c2"), - sqlState = None, - context = ExpectedContext( - fragment = "(select sum(-1) from v t2 where t1.c2 = t2.c1 group by t2.c2)", - start = 7, stop = 67)) } + condition = "SCALAR_SUBQUERY_TOO_MANY_ROWS" + ) + } } test("non-aggregated correlated scalar subquery") { - val exception1 = intercept[AnalysisException] { - sql("select a, (select b from l l2 where l2.a = l1.a) sum_b from l l1") + val exception1 = intercept[SparkRuntimeException] { + sql("select a, (select b from l l2 where l2.a = l1.a) sum_b from l l1").collect() } checkError( exception1, - condition = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." 
+ - "MUST_AGGREGATE_CORRELATED_SCALAR_SUBQUERY", - parameters = Map.empty, - context = ExpectedContext( - fragment = "(select b from l l2 where l2.a = l1.a)", start = 10, stop = 47)) - val exception2 = intercept[AnalysisException] { - sql("select a, (select b from l l2 where l2.a = l1.a group by 1) sum_b from l l1") - } - checkErrorMatchPVals( - exception2, - condition = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + - "MUST_AGGREGATE_CORRELATED_SCALAR_SUBQUERY", - parameters = Map.empty[String, String], - sqlState = None, - context = ExpectedContext( - fragment = "(select b from l l2 where l2.a = l1.a group by 1)", start = 10, stop = 58)) + condition = "SCALAR_SUBQUERY_TOO_MANY_ROWS" + ) + checkAnswer( + sql("select a, (select b from l l2 where l2.a = l1.a group by 1) sum_b from l l1"), + Row(1, 2.0) :: Row(1, 2.0) :: Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: + Row(null, null) :: Row(null, null) :: Row(6, null) :: Nil + ) } test("non-equal correlated scalar subquery") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SingleJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SingleJoinSuite.scala new file mode 100644 index 0000000000000..a318769af6871 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SingleJoinSuite.scala @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.SparkRuntimeException +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.optimizer.BuildRight +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint, Project} +import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} +import org.apache.spark.sql.execution.exchange.EnsureRequirements +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} + +class SingleJoinSuite extends SparkPlanTest with SharedSparkSession { + import testImplicits.toRichColumn + + private val EnsureRequirements = new EnsureRequirements() + + private lazy val left = spark.createDataFrame( + sparkContext.parallelize(Seq( + Row(1, 2.0), + Row(1, 2.0), + Row(2, 1.0), + Row(2, 1.0), + Row(3, 3.0), + Row(null, null), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("a", IntegerType).add("b", DoubleType)) + + // (a > c && a != 6) + + private lazy val right = spark.createDataFrame( + sparkContext.parallelize(Seq( + Row(2, 3.0), + Row(3, 2.0), + Row(4, 1.0), + Row(4, 2.0), + Row(null, null), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("c", IntegerType).add("d", DoubleType)) + + private lazy val singleConditionEQ = EqualTo(left.col("a").expr, right.col("c").expr) + + private lazy val nonEqualityCond = And(GreaterThan(left.col("a").expr, right.col("c").expr), + Not(EqualTo(left.col("a").expr, Literal(6)))) + + + + private def testSingleJoin( + testName: String, + leftRows: => DataFrame, + rightRows: => DataFrame, + condition: => Option[Expression], + expectedAnswer: Seq[Row], + expectError: Boolean = false): Unit = { + + def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, + Inner, condition, JoinHint.NONE) + ExtractEquiJoinKeys.unapply(join) + } + + def checkSingleJoinError(planFunction: (SparkPlan, SparkPlan) => SparkPlan): Unit = { + val outputPlan = planFunction(leftRows.queryExecution.sparkPlan, + rightRows.queryExecution.sparkPlan) + checkError( + exception = intercept[SparkRuntimeException] { + SparkPlanTest.executePlan(outputPlan, spark.sqlContext) + }, + condition = "SCALAR_SUBQUERY_TOO_MANY_ROWS", + parameters = Map.empty + ) + } + + testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastHashJoin") { _ => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _, _) => + val planFunction = (left: SparkPlan, right: SparkPlan) => + EnsureRequirements.apply(BroadcastHashJoinExec( + leftKeys, rightKeys, LeftSingle, BuildRight, boundCondition, left, right)) + if (expectError) { + checkSingleJoinError(planFunction) + } else { + checkAnswer2(leftRows, rightRows, planFunction, + expectedAnswer, + sortAnswers = true) + } + } + } + testWithWholeStageCodegenOnAndOff(s"$testName using ShuffledHashJoin") { _ => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _, _) => + val planFunction = (left: SparkPlan, right: SparkPlan) => + EnsureRequirements.apply( + ShuffledHashJoinExec( + leftKeys, rightKeys, LeftSingle, BuildRight, boundCondition, left, right)) + if (expectError) { + checkSingleJoinError(planFunction) + } else { + checkAnswer2(leftRows, rightRows, planFunction, + expectedAnswer, + sortAnswers = 
true) + } + } + } + + testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastNestedLoopJoin") { _ => + val planFunction = (left: SparkPlan, right: SparkPlan) => + EnsureRequirements.apply( + BroadcastNestedLoopJoinExec(left, right, BuildRight, LeftSingle, condition)) + if (expectError) { + checkSingleJoinError(planFunction) + } else { + checkAnswer2(leftRows, rightRows, planFunction, + expectedAnswer, + sortAnswers = true) + } + } + } + + testSingleJoin( + "test single condition (equal) for a left single join", + left, + Project(Seq(right.col("c").expr.asInstanceOf[NamedExpression]), right.logicalPlan), + Some(singleConditionEQ), + Seq(Row(1, 2.0, null), + Row(1, 2.0, null), + Row(2, 1.0, 2), + Row(2, 1.0, 2), + Row(3, 3.0, 3), + Row(6, null, 6), + Row(null, 5.0, null), + Row(null, null, null))) + + testSingleJoin( + "test single condition (equal) for a left single join -- multiple matches", + left, + Project(Seq(right.col("d").expr.asInstanceOf[NamedExpression]), right.logicalPlan), + Some(EqualTo(left.col("b").expr, right.col("d").expr)), + Seq.empty, true) + + testSingleJoin( + "test non-equality for a left single join", + left, + Project(Seq(right.col("c").expr.asInstanceOf[NamedExpression]), right.logicalPlan), + Some(nonEqualityCond), + Seq(Row(1, 2.0, null), + Row(1, 2.0, null), + Row(2, 1.0, null), + Row(2, 1.0, null), + Row(3, 3.0, 2), + Row(6, null, null), + Row(null, 5.0, null), + Row(null, null, null))) + + testSingleJoin( + "test non-equality for a left single join -- multiple matches", + left, + Project(Seq(right.col("c").expr.asInstanceOf[NamedExpression]), right.logicalPlan), + Some(GreaterThan(left.col("a").expr, right.col("c").expr)), + Seq.empty, expectError = true) + + private lazy val emptyFrame = spark.createDataFrame( + spark.sparkContext.emptyRDD[Row], new StructType().add("c", IntegerType).add("d", DoubleType)) + + testSingleJoin( + "empty inner (right) side", + left, + Project(Seq(emptyFrame.col("c").expr.asInstanceOf[NamedExpression]), emptyFrame.logicalPlan), + Some(GreaterThan(left.col("a").expr, emptyFrame.col("c").expr)), + Seq(Row(1, 2.0, null), + Row(1, 2.0, null), + Row(2, 1.0, null), + Row(2, 1.0, null), + Row(3, 3.0, null), + Row(6, null, null), + Row(null, 5.0, null), + Row(null, null, null))) + + testSingleJoin( + "empty outer (left) side", + Project(Seq(emptyFrame.col("c").expr.asInstanceOf[NamedExpression]), emptyFrame.logicalPlan), + right, + Some(EqualTo(emptyFrame.col("c").expr, right.col("c").expr)), + Seq.empty) +}
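
As a quick illustration of the behavior change covered by the golden files above (a sketch reusing the existing x/y test tables and the spark.sql.optimizer.scalarSubqueryUseSingleJoin flag introduced in this patch):

-- With the flag enabled (the new default) the query below passes analysis, is planned with a
-- LeftSingle join, and fails at runtime with SCALAR_SUBQUERY_TOO_MANY_ROWS only if some outer
-- row matches more than one group.
set spark.sql.optimizer.scalarSubqueryUseSingleJoin = true;
select * from x where (select count(*) from y where y1 > x1 group by y1) = 1;

-- With the flag disabled the same query is rejected at analysis time with
-- UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_CORRELATED_COLUMNS_IN_GROUP_BY, as before.
set spark.sql.optimizer.scalarSubqueryUseSingleJoin = false;
select * from x where (select count(*) from y where y1 > x1 group by y1) = 1;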