From de925b87284b63fb3482a08cbde452f7bd5ed226 Mon Sep 17 00:00:00 2001 From: Avery Qi Date: Wed, 7 May 2025 17:58:02 -0700 Subject: [PATCH] use boolean flag instead of adding a new argument for SubqueryExpression --- .../sql/catalyst/analysis/Analyzer.scala | 4 +- .../analysis/ValidateSubqueryExpression.scala | 3 +- .../ExpressionResolutionValidator.scala | 6 +- .../resolver/SubqueryExpressionResolver.scala | 6 +- .../catalyst/expressions/DynamicPruning.scala | 8 +- ...ctionTableSubqueryArgumentExpression.scala | 25 ++-- .../sql/catalyst/expressions/subquery.scala | 107 ++++++++++++++---- ...PullOutNestedDataOuterRefExpressions.scala | 2 +- .../RemoveRedundantAliasAndProjectSuite.scala | 6 +- 9 files changed, 119 insertions(+), 48 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7b0f3e37f9649..fb074fe8ff4e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2326,7 +2326,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor private def resolveSubQuery( e: SubqueryExpression, outer: LogicalPlan)( - f: (LogicalPlan, Seq[Expression]) => SubqueryExpression): SubqueryExpression = { + f: (LogicalPlan, Seq[(Expression, Boolean)]) => SubqueryExpression): SubqueryExpression = { val newSubqueryPlan = AnalysisContext.withOuterPlan(outer) { executeSameContext(e.plan) } @@ -2335,7 +2335,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor // them as children of SubqueryExpression. if (newSubqueryPlan.resolved) { // Record the outer references as children of subquery expression. - f(newSubqueryPlan, SubExprUtils.getOuterReferences(newSubqueryPlan)) + f(newSubqueryPlan, SubExprUtils.getOuterReferences(newSubqueryPlan).map((_, true))) } else { e.withNewPlan(newSubqueryPlan) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ValidateSubqueryExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ValidateSubqueryExpression.scala index d6b7a4dccb907..318a115c6dcf1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ValidateSubqueryExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ValidateSubqueryExpression.scala @@ -129,7 +129,8 @@ object ValidateSubqueryExpression checkOuterReference(plan, expr) expr match { - case ScalarSubquery(query, outerAttrs, _, _, _, _, _) => + case ScalarSubquery(query, rawOuterAttrs, _, _, _, _, _) => + val outerAttrs = rawOuterAttrs.map(_._1) // Scalar subquery must return one column as output. if (query.output.size != 1) { throw QueryCompilationErrors.subqueryReturnMoreThanOneColumn(query.output.size, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionValidator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionValidator.scala index e0508e924678a..e23c44ab8be4f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionValidator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionValidator.scala @@ -140,7 +140,7 @@ class ExpressionResolutionValidator(resolutionValidator: ResolutionValidator) { resolutionValidator.validate(scalarSubquery.plan) } - for (outerAttribute <- scalarSubquery.outerAttrs) { + for (outerAttribute <- scalarSubquery.getOuterAttrs) { validate(outerAttribute) } @@ -163,7 +163,7 @@ class ExpressionResolutionValidator(resolutionValidator: ResolutionValidator) { resolutionValidator.validate(listQuery.plan) } - for (outerAttribute <- listQuery.outerAttrs) { + for (outerAttribute <- listQuery.getOuterAttrs) { validate(outerAttribute) } } @@ -173,7 +173,7 @@ class ExpressionResolutionValidator(resolutionValidator: ResolutionValidator) { resolutionValidator.validate(exists.plan) } - for (outerAttribute <- exists.outerAttrs) { + for (outerAttribute <- exists.getOuterAttrs) { validate(outerAttribute) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/SubqueryExpressionResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/SubqueryExpressionResolver.scala index c36024e7269e0..1296b8f887d5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/SubqueryExpressionResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/SubqueryExpressionResolver.scala @@ -63,7 +63,7 @@ class SubqueryExpressionResolver(expressionResolver: ExpressionResolver, resolve val resolvedScalarSubquery = unresolvedScalarSubquery.copy( plan = resolvedSubqueryExpressionPlan.plan, - outerAttrs = resolvedSubqueryExpressionPlan.outerExpressions + outerAttrs = resolvedSubqueryExpressionPlan.outerExpressions.map((_, true)) ) val coercedScalarSubquery = @@ -108,7 +108,7 @@ class SubqueryExpressionResolver(expressionResolver: ExpressionResolver, resolve unresolvedListQuery.copy( plan = resolvedSubqueryExpressionPlan.plan, - outerAttrs = resolvedSubqueryExpressionPlan.outerExpressions, + outerAttrs = resolvedSubqueryExpressionPlan.outerExpressions.map((_, true)), numCols = resolvedSubqueryExpressionPlan.output.size ) } @@ -125,7 +125,7 @@ class SubqueryExpressionResolver(expressionResolver: ExpressionResolver, resolve val resolvedExists = unresolvedExists.copy( plan = resolvedSubqueryExpressionPlan.plan, - outerAttrs = resolvedSubqueryExpressionPlan.outerExpressions + outerAttrs = resolvedSubqueryExpressionPlan.outerExpressions.map((_, true)) ) val coercedExists = typeCoercionResolver.resolve(resolvedExists).asInstanceOf[Exists] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala index b65576403e9d8..e214502a6c31a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala @@ -47,7 +47,7 @@ case class DynamicPruningSubquery( onlyInBroadcast: Boolean, exprId: ExprId = NamedExpression.newExprId, hint: Option[HintInfo] = None) - extends SubqueryExpression(buildQuery, Seq(pruningKey), exprId, Seq.empty, hint) + extends SubqueryExpression(buildQuery, Seq(pruningKey).map((_, true)), exprId, Seq.empty, hint) with DynamicPruning with Unevaluable with UnaryLike[Expression] { @@ -60,10 +60,12 @@ case class DynamicPruningSubquery( override def withNewPlan(plan: LogicalPlan): DynamicPruningSubquery = copy(buildQuery = plan) - override def withNewOuterAttrs(outerAttrs: Seq[Expression]): DynamicPruningSubquery = { + override def withNewOuterAttrs(outerAttrs: Seq[(Expression, Boolean)]): DynamicPruningSubquery = { // Updating outer attrs of DynamicPruningSubquery is unsupported; assert that they match // pruningKey and return a copy without any changes. - assert(outerAttrs.size == 1 && outerAttrs.head.semanticEquals(pruningKey)) + assert(outerAttrs.size == 1) + val (expr, notOuterScope) = outerAttrs.head + assert(expr.semanticEquals(pruningKey) && notOuterScope) copy() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala index bfd3bc8051dff..fdd255e6c0e7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala @@ -44,8 +44,10 @@ import org.apache.spark.sql.types.DataType * * @param plan the logical plan provided as input for the table argument as either a logical * relation or as a more complex logical plan in the event of a table subquery. - * @param outerAttrs outer references of this subquery plan, generally empty since these table - * arguments do not allow correlated references currently + * @param outerAttrs the outer references in the subquery plan and a boolean flag marking whether + * the outer reference can be resolved in its immediate parent plan or not, + * generally empty since these table arguments do not allow correlated + * references currently. * @param exprId expression ID of this subquery expression, generally generated afresh each time * @param partitionByExpressions if non-empty, the TABLE argument included the PARTITION BY clause * to indicate that the input relation should be repartitioned by the @@ -66,7 +68,7 @@ import org.apache.spark.sql.types.DataType */ case class FunctionTableSubqueryArgumentExpression( plan: LogicalPlan, - outerAttrs: Seq[Expression] = Seq.empty, + outerAttrs: Seq[(Expression, Boolean)] = Seq.empty, exprId: ExprId = NamedExpression.newExprId, partitionByExpressions: Seq[Expression] = Seq.empty, withSinglePartition: Boolean = false, @@ -78,19 +80,26 @@ case class FunctionTableSubqueryArgumentExpression( "WITH SINGLE PARTITION is mutually exclusive with PARTITION BY") override def dataType: DataType = plan.schema + override def nullable: Boolean = false + override def withNewPlan(plan: LogicalPlan): FunctionTableSubqueryArgumentExpression = copy(plan = plan) - override def withNewOuterAttrs(outerAttrs: Seq[Expression]) + + override def withNewOuterAttrs(outerAttrs: Seq[(Expression, Boolean)]) : FunctionTableSubqueryArgumentExpression = copy(outerAttrs = outerAttrs) + override def hint: Option[HintInfo] = None + override def withNewHint(hint: Option[HintInfo]): FunctionTableSubqueryArgumentExpression = copy() + override def toString: String = s"table-argument#${exprId.id} $conditionString" + override lazy val canonicalized: Expression = { FunctionTableSubqueryArgumentExpression( plan.canonicalized, - outerAttrs.map(_.canonicalized), + outerAttrs.map {case (expr, b) => (expr.canonicalized, b)}, ExprId(0), partitionByExpressions, withSinglePartition, @@ -98,8 +107,10 @@ case class FunctionTableSubqueryArgumentExpression( } override protected def withNewChildrenInternal( - newChildren: IndexedSeq[Expression]): FunctionTableSubqueryArgumentExpression = - copy(outerAttrs = newChildren) + newChildren: IndexedSeq[Expression]): FunctionTableSubqueryArgumentExpression = { + val newOuterAttrs = newChildren.take(outerAttrs.size).zip(outerAttrs.map(_._2)) + copy(outerAttrs = newOuterAttrs) + } final override def nodePatternsInternal(): Seq[TreePattern] = Seq(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index 210b7f8fb5306..403eb932382ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -66,7 +66,8 @@ abstract class PlanExpression[T <: QueryPlan[_]] extends Expression { * A base interface for expressions that contain a [[LogicalPlan]]. * * @param plan: the subquery plan - * @param outerAttrs: the outer references in the subquery plan + * @param outerAttrs: the outer references in the subquery plan and a boolean flag marking whether + * the outer reference can be resolved in its immediate parent plan or not. * @param exprId: ID of the expression * @param joinCond: the join conditions with the outer query. It contains both inner and outer * query references. @@ -75,18 +76,35 @@ abstract class PlanExpression[T <: QueryPlan[_]] extends Expression { */ abstract class SubqueryExpression( plan: LogicalPlan, - outerAttrs: Seq[Expression], + outerAttrs: Seq[(Expression, Boolean)], exprId: ExprId, joinCond: Seq[Expression], hint: Option[HintInfo]) extends PlanExpression[LogicalPlan] { + override lazy val resolved: Boolean = childrenResolved && plan.resolved + + lazy val outerScopeAttrsWithoutFlags = outerAttrs.filter(!_._2).map(_._1) + + lazy val outerAttrsWithoutFlags = outerAttrs.map(_._1) + override lazy val references: AttributeSet = - AttributeSet.fromAttributeSets(outerAttrs.map(_.references)) - override def children: Seq[Expression] = outerAttrs ++ joinCond + AttributeSet.fromAttributeSets(outerAttrsWithoutFlags.map(_.references)) -- + AttributeSet.fromAttributeSets(outerScopeAttrsWithoutFlags.map(_.references)) + + override def children: Seq[Expression] = outerAttrsWithoutFlags ++ joinCond + override def withNewPlan(plan: LogicalPlan): SubqueryExpression - def withNewOuterAttrs(outerAttrs: Seq[Expression]): SubqueryExpression + + def withNewOuterAttrs(outerAttrs: Seq[(Expression, Boolean)]): SubqueryExpression + + def getOuterAttrs: Seq[Expression] = outerAttrsWithoutFlags + + def getOuterScopeAttrs: Seq[Expression] = outerScopeAttrsWithoutFlags + def isCorrelated: Boolean = outerAttrs.nonEmpty + def hint: Option[HintInfo] + def withNewHint(hint: Option[HintInfo]): SubqueryExpression } @@ -394,13 +412,14 @@ object SubExprUtils extends PredicateHelper { */ case class ScalarSubquery( plan: LogicalPlan, - outerAttrs: Seq[Expression] = Seq.empty, + outerAttrs: Seq[(Expression, Boolean)] = Seq.empty, exprId: ExprId = NamedExpression.newExprId, joinCond: Seq[Expression] = Seq.empty, hint: Option[HintInfo] = None, mayHaveCountBug: Option[Boolean] = None, needSingleJoin: Option[Boolean] = None) extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Unevaluable { + override def dataType: DataType = { if (!plan.schema.fields.nonEmpty) { throw QueryCompilationErrors.subqueryReturnMoreThanOneColumn(plan.schema.fields.length, @@ -408,25 +427,33 @@ case class ScalarSubquery( } plan.schema.fields.head.dataType } + override def nullable: Boolean = true + override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(plan = plan) - override def withNewOuterAttrs(outerAttrs: Seq[Expression]): ScalarSubquery = copy( + + override def withNewOuterAttrs(outerAttrs: Seq[(Expression, Boolean)]): ScalarSubquery = copy( outerAttrs = outerAttrs) + override def withNewHint(hint: Option[HintInfo]): ScalarSubquery = copy(hint = hint) + override def toString: String = s"scalar-subquery#${exprId.id} $conditionString" + override lazy val canonicalized: Expression = { ScalarSubquery( plan.canonicalized, - outerAttrs.map(_.canonicalized), + outerAttrs.map {case (expr, b) => (expr.canonicalized, b)}, ExprId(0), joinCond.map(_.canonicalized)) } override protected def withNewChildrenInternal( - newChildren: IndexedSeq[Expression]): ScalarSubquery = + newChildren: IndexedSeq[Expression]): ScalarSubquery = { + val newOuterAttrs = newChildren.take(outerAttrs.size).zip(outerAttrs.map(_._2)) copy( - outerAttrs = newChildren.take(outerAttrs.size), + outerAttrs = newOuterAttrs, joinCond = newChildren.drop(outerAttrs.size)) + } final override def nodePatternsInternal(): Seq[TreePattern] = Seq(SCALAR_SUBQUERY) } @@ -473,31 +500,40 @@ case class UnresolvedTableArgPlanId( */ case class LateralSubquery( plan: LogicalPlan, - outerAttrs: Seq[Expression] = Seq.empty, + outerAttrs: Seq[(Expression, Boolean)] = Seq.empty, exprId: ExprId = NamedExpression.newExprId, joinCond: Seq[Expression] = Seq.empty, hint: Option[HintInfo] = None) extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Unevaluable { + override def dataType: DataType = plan.output.toStructType + override def nullable: Boolean = true + override def withNewPlan(plan: LogicalPlan): LateralSubquery = copy(plan = plan) - override def withNewOuterAttrs(outerAttrs: Seq[Expression]): LateralSubquery = copy( + + override def withNewOuterAttrs(outerAttrs: Seq[(Expression, Boolean)]): LateralSubquery = copy( outerAttrs = outerAttrs) + override def withNewHint(hint: Option[HintInfo]): LateralSubquery = copy(hint = hint) + override def toString: String = s"lateral-subquery#${exprId.id} $conditionString" + override lazy val canonicalized: Expression = { LateralSubquery( plan.canonicalized, - outerAttrs.map(_.canonicalized), + outerAttrs.map {case (expr, b) => (expr.canonicalized, b)}, ExprId(0), joinCond.map(_.canonicalized)) } override protected def withNewChildrenInternal( - newChildren: IndexedSeq[Expression]): LateralSubquery = + newChildren: IndexedSeq[Expression]): LateralSubquery = { + val newOuterAttrs = newChildren.take(outerAttrs.size).zip(outerAttrs.map(_._2)) copy( - outerAttrs = newChildren.take(outerAttrs.size), + outerAttrs = newOuterAttrs, joinCond = newChildren.drop(outerAttrs.size)) + } final override def nodePatternsInternal(): Seq[TreePattern] = Seq(LATERAL_SUBQUERY) } @@ -516,7 +552,7 @@ case class LateralSubquery( */ case class ListQuery( plan: LogicalPlan, - outerAttrs: Seq[Expression] = Seq.empty, + outerAttrs: Seq[(Expression, Boolean)] = Seq.empty, exprId: ExprId = NamedExpression.newExprId, // The plan of list query may have more columns after de-correlation, and we need to track the // number of the columns of the original plan, to report the data type properly. @@ -524,13 +560,17 @@ case class ListQuery( joinCond: Seq[Expression] = Seq.empty, hint: Option[HintInfo] = None) extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Unevaluable { + def childOutputs: Seq[Attribute] = plan.output.take(numCols) + override def dataType: DataType = if (numCols > 1) { childOutputs.toStructType } else { plan.output.head.dataType } + override lazy val resolved: Boolean = childrenResolved && plan.resolved && numCols != -1 + override def nullable: Boolean = { // ListQuery can't be executed alone so its nullability is not defined. // Consider using ListQuery.childOutputs.exists(_.nullable) @@ -540,24 +580,32 @@ case class ListQuery( } false } + override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan) - override def withNewOuterAttrs(outerAttrs: Seq[Expression]): ListQuery = copy( + + override def withNewOuterAttrs(outerAttrs: Seq[(Expression, Boolean)]): ListQuery = copy( outerAttrs = outerAttrs) + override def withNewHint(hint: Option[HintInfo]): ListQuery = copy(hint = hint) + override def toString: String = s"list#${exprId.id} $conditionString" + override lazy val canonicalized: Expression = { ListQuery( plan.canonicalized, - outerAttrs.map(_.canonicalized), + outerAttrs.map {case (expr, b) => (expr.canonicalized, b)}, ExprId(0), numCols, joinCond.map(_.canonicalized)) } - override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ListQuery = + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): ListQuery = { + val newOuterAttrs = newChildren.take(outerAttrs.size).zip(outerAttrs.map(_._2)) copy( - outerAttrs = newChildren.take(outerAttrs.size), + outerAttrs = newOuterAttrs, joinCond = newChildren.drop(outerAttrs.size)) + } final override def nodePatternsInternal(): Seq[TreePattern] = Seq(LIST_SUBQUERY) } @@ -590,31 +638,40 @@ case class ListQuery( */ case class Exists( plan: LogicalPlan, - outerAttrs: Seq[Expression] = Seq.empty, + outerAttrs: Seq[(Expression, Boolean)] = Seq.empty, exprId: ExprId = NamedExpression.newExprId, joinCond: Seq[Expression] = Seq.empty, hint: Option[HintInfo] = None) extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Predicate with Unevaluable { + override def nullable: Boolean = false + override def withNewPlan(plan: LogicalPlan): Exists = copy(plan = plan) - override def withNewOuterAttrs(outerAttrs: Seq[Expression]): Exists = copy( + + override def withNewOuterAttrs(outerAttrs: Seq[(Expression, Boolean)]): Exists = copy( outerAttrs = outerAttrs) + override def withNewHint(hint: Option[HintInfo]): Exists = copy(hint = hint) + override def toString: String = s"exists#${exprId.id} $conditionString" + override lazy val canonicalized: Expression = { Exists( plan.canonicalized, - outerAttrs.map(_.canonicalized), + outerAttrs.map {case (expr, b) => (expr.canonicalized, b)}, ExprId(0), joinCond.map(_.canonicalized)) } - override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Exists = + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): Exists = { + val newOuterAttrs = newChildren.take(outerAttrs.size).zip(outerAttrs.map(_._2)) copy( - outerAttrs = newChildren.take(outerAttrs.size), + outerAttrs = newOuterAttrs, joinCond = newChildren.drop(outerAttrs.size)) + } final override def nodePatternsInternal(): Seq[TreePattern] = Seq(EXISTS_SUBQUERY) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutNestedDataOuterRefExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutNestedDataOuterRefExpressions.scala index c46e75c3b37b5..0fca6294eebd3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutNestedDataOuterRefExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutNestedDataOuterRefExpressions.scala @@ -113,7 +113,7 @@ object PullOutNestedDataOuterRefExpressions extends Rule[LogicalPlan] { // them from the project. subqueryExpression .withNewPlan(newInnerPlan) - .withNewOuterAttrs(SubExprUtils.getOuterReferences(newInnerPlan)) + .withNewOuterAttrs(SubExprUtils.getOuterReferences(newInnerPlan).map((_, true))) } if (newExprMap.isEmpty) { // Nothing to change diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala index 552a638f6e614..b2fc3825a1fda 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala @@ -146,7 +146,7 @@ class RemoveRedundantAliasAndProjectSuite extends PlanTest { val query = Filter( Exists( LocalRelation(b), - outerAttrs = Seq(a_alias_attr), + outerAttrs = Seq(a_alias_attr).map((_, true)), joinCond = Seq(EqualTo(a_alias_attr, b)) ), Project(Seq(a_alias), LocalRelation(a)) @@ -162,7 +162,7 @@ class RemoveRedundantAliasAndProjectSuite extends PlanTest { val expectedWhenNotExcluded = Filter( Exists( LocalRelation(b), - outerAttrs = Seq(a), + outerAttrs = Seq(a).map((_, true)), joinCond = Seq(EqualTo(a, b)) ), LocalRelation(a) @@ -201,7 +201,7 @@ class RemoveRedundantAliasAndProjectSuite extends PlanTest { CaseWhen(Seq(( Exists( LocalRelation(a), - outerAttrs = Seq(a_alias_attr), + outerAttrs = Seq(a_alias_attr).map((_, true)), joinCond = Seq(EqualTo(a_alias_attr, a)) ), Literal(1))), Some(Literal(2))),