Skip to content

Commit 61632fe

Browse files
committed
handle Project+Aggregate with `mergeExpressions(), add more comment.
1 parent e2db127 commit 61632fe

File tree

1 file changed

+30
-15
lines changed
  • sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer

1 file changed

+30
-15
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,7 +1206,7 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
12061206
newPythonUDFEvalTypesInUpperProjects,
12071207
pythonUDFArrowFallbackOnUDT)) match {
12081208
case p1 @ Project(_, p2: Project) =>
1209-
mergeProjectExpressions(
1209+
mergeExpressions(
12101210
p1.projectList,
12111211
p2.projectList,
12121212
alwaysInline,
@@ -1216,16 +1216,17 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
12161216
case (newUpper, newLower) =>
12171217
p1.copy(projectList = newUpper, child = p2.copy(projectList = newLower))
12181218
}
1219-
case p @ Project(_, agg: Aggregate)
1220-
if canCollapseExpressions(
1221-
p.projectList,
1222-
getAliasMap(agg.aggregateExpressions),
1223-
alwaysInline,
1224-
newPythonUDFEvalTypesInUpperProjects,
1225-
pythonUDFArrowFallbackOnUDT)
1226-
&& canCollapseAggregate(p, agg) =>
1227-
agg.copy(aggregateExpressions = buildCleanedProjectList(
1228-
p.projectList, agg.aggregateExpressions))
1219+
case p @ Project(_, agg: Aggregate) if canCollapseAggregate(p, agg) =>
1220+
mergeExpressions(
1221+
p.projectList,
1222+
agg.aggregateExpressions,
1223+
alwaysInline,
1224+
newPythonUDFEvalTypesInUpperProjects,
1225+
pythonUDFArrowFallbackOnUDT) match {
1226+
case (Seq(), merged) => agg.copy(aggregateExpressions = merged)
1227+
case (newUpper, newLower) =>
1228+
p.copy(projectList = newUpper, child = agg.copy(aggregateExpressions = newLower))
1229+
}
12291230
case Project(l1, g @ GlobalLimit(_, limit @ LocalLimit(_, p2 @ Project(l2, _))))
12301231
if isRenaming(l1, l2) =>
12311232
val newProjectList = buildCleanedProjectList(l1, l2)
@@ -1275,7 +1276,7 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
12751276
case other => isCheap(other)
12761277
}
12771278

1278-
private def mergeProjectExpressions(
1279+
private def mergeExpressions(
12791280
consumers: Seq[NamedExpression],
12801281
producers: Seq[NamedExpression],
12811282
alwaysInline: Boolean,
@@ -1291,9 +1292,19 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
12911292
.groupMap(_._1)(_._2)
12921293
.view.mapValues(v => v.size -> ExpressionSet(v)))
12931294

1295+
// Split the producers from the lower node to 4 categories:
1296+
// - `neverInlines` contains producer expressions that shouldn't be inlined.
1297+
// These include non-deterministic expressions or expensive ones that are referenced multiple
1298+
// times.
1299+
// - `mustInlines` contains expressions with Python UDFs that must be inlined into the upper
1300+
// node to avoid performance issues, or expressions with aggregate nodes.
1301+
// - `maybeInlines` contains expressions that might make sense to inline, such as expressions
1302+
// that are used only once, or are cheap to inline.
1303+
// But we need to take into account the side effect of adding new pass-through attributes to
1304+
// the lover node, which can make the node much wider than it was originally.
1305+
val neverInlines = ListBuffer.empty[NamedExpression]
12941306
val mustInlines = ListBuffer.empty[NamedExpression]
12951307
val maybeInlines = ListBuffer.empty[NamedExpression]
1296-
val neverInlines = ListBuffer.empty[NamedExpression]
12971308
val others = ListBuffer.empty[NamedExpression]
12981309
producers.foreach {
12991310
case a: Alias =>
@@ -1306,7 +1317,7 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
13061317
case _ => false
13071318
}
13081319

1309-
if (!a.deterministic) {
1320+
if (!a.child.deterministic || AggregateExpression.containsAggregate(a.child)) {
13101321
neverInlines += a
13111322
} else if (alwaysInline || containsUDF) {
13121323
mustInlines += a
@@ -1321,11 +1332,15 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
13211332
}
13221333

13231334
if (neverInlines.isEmpty) {
1335+
// If `neverInlines` is empty then we can collapse the nodes into one.
13241336
(Seq.empty, buildCleanedProjectList(consumers, producers))
13251337
} else if (mustInlines.isEmpty) {
1326-
// Let's keep `maybeInlines` in the lower node for now.
1338+
// Otherwise we can't collapse the nodes into one, but if `mustInlines` is empty then we can
1339+
// keep `maybeInlines` in the lower node for now, so there is no change to the nodes.
13271340
(consumers, producers)
13281341
} else {
1342+
// If both `neverInlines` and `mustInlines` are not empty, then inline `mustInlines` and add
1343+
// new pass-through attributes to the lower node.
13291344
val newConsumers = buildCleanedProjectList(consumers, mustInlines)
13301345
val passthroughAttributes = AttributeSet(others.flatMap(_.references))
13311346
val newPassthroughAttributes =

0 commit comments

Comments
 (0)