@@ -1206,7 +1206,7 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
1206
1206
newPythonUDFEvalTypesInUpperProjects,
1207
1207
pythonUDFArrowFallbackOnUDT)) match {
1208
1208
case p1 @ Project (_, p2 : Project ) =>
1209
- mergeProjectExpressions (
1209
+ mergeExpressions (
1210
1210
p1.projectList,
1211
1211
p2.projectList,
1212
1212
alwaysInline,
@@ -1216,16 +1216,17 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
1216
1216
case (newUpper, newLower) =>
1217
1217
p1.copy(projectList = newUpper, child = p2.copy(projectList = newLower))
1218
1218
}
1219
- case p @ Project (_, agg : Aggregate )
1220
- if canCollapseExpressions(
1221
- p.projectList,
1222
- getAliasMap(agg.aggregateExpressions),
1223
- alwaysInline,
1224
- newPythonUDFEvalTypesInUpperProjects,
1225
- pythonUDFArrowFallbackOnUDT)
1226
- && canCollapseAggregate(p, agg) =>
1227
- agg.copy(aggregateExpressions = buildCleanedProjectList(
1228
- p.projectList, agg.aggregateExpressions))
1219
+ case p @ Project (_, agg : Aggregate ) if canCollapseAggregate(p, agg) =>
1220
+ mergeExpressions(
1221
+ p.projectList,
1222
+ agg.aggregateExpressions,
1223
+ alwaysInline,
1224
+ newPythonUDFEvalTypesInUpperProjects,
1225
+ pythonUDFArrowFallbackOnUDT) match {
1226
+ case (Seq (), merged) => agg.copy(aggregateExpressions = merged)
1227
+ case (newUpper, newLower) =>
1228
+ p.copy(projectList = newUpper, child = agg.copy(aggregateExpressions = newLower))
1229
+ }
1229
1230
case Project (l1, g @ GlobalLimit (_, limit @ LocalLimit (_, p2 @ Project (l2, _))))
1230
1231
if isRenaming(l1, l2) =>
1231
1232
val newProjectList = buildCleanedProjectList(l1, l2)
@@ -1275,7 +1276,7 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
1275
1276
case other => isCheap(other)
1276
1277
}
1277
1278
1278
- private def mergeProjectExpressions (
1279
+ private def mergeExpressions (
1279
1280
consumers : Seq [NamedExpression ],
1280
1281
producers : Seq [NamedExpression ],
1281
1282
alwaysInline : Boolean ,
@@ -1291,9 +1292,19 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
1291
1292
.groupMap(_._1)(_._2)
1292
1293
.view.mapValues(v => v.size -> ExpressionSet (v)))
1293
1294
1295
+ // Split the producers from the lower node to 4 categories:
1296
+ // - `neverInlines` contains producer expressions that shouldn't be inlined.
1297
+ // These include non-deterministic expressions or expensive ones that are referenced multiple
1298
+ // times.
1299
+ // - `mustInlines` contains expressions with Python UDFs that must be inlined into the upper
1300
+ // node to avoid performance issues, or expressions with aggregate nodes.
1301
+ // - `maybeInlines` contains expressions that might make sense to inline, such as expressions
1302
+ // that are used only once, or are cheap to inline.
1303
+ // But we need to take into account the side effect of adding new pass-through attributes to
1304
+ // the lover node, which can make the node much wider than it was originally.
1305
+ val neverInlines = ListBuffer .empty[NamedExpression ]
1294
1306
val mustInlines = ListBuffer .empty[NamedExpression ]
1295
1307
val maybeInlines = ListBuffer .empty[NamedExpression ]
1296
- val neverInlines = ListBuffer .empty[NamedExpression ]
1297
1308
val others = ListBuffer .empty[NamedExpression ]
1298
1309
producers.foreach {
1299
1310
case a : Alias =>
@@ -1306,7 +1317,7 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
1306
1317
case _ => false
1307
1318
}
1308
1319
1309
- if (! a.deterministic) {
1320
+ if (! a.child. deterministic || AggregateExpression .containsAggregate(a.child) ) {
1310
1321
neverInlines += a
1311
1322
} else if (alwaysInline || containsUDF) {
1312
1323
mustInlines += a
@@ -1321,11 +1332,15 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
1321
1332
}
1322
1333
1323
1334
if (neverInlines.isEmpty) {
1335
+ // If `neverInlines` is empty then we can collapse the nodes into one.
1324
1336
(Seq .empty, buildCleanedProjectList(consumers, producers))
1325
1337
} else if (mustInlines.isEmpty) {
1326
- // Let's keep `maybeInlines` in the lower node for now.
1338
+ // Otherwise we can't collapse the nodes into one, but if `mustInlines` is empty then we can
1339
+ // keep `maybeInlines` in the lower node for now, so there is no change to the nodes.
1327
1340
(consumers, producers)
1328
1341
} else {
1342
+ // If both `neverInlines` and `mustInlines` are not empty, then inline `mustInlines` and add
1343
+ // new pass-through attributes to the lower node.
1329
1344
val newConsumers = buildCleanedProjectList(consumers, mustInlines)
1330
1345
val passthroughAttributes = AttributeSet (others.flatMap(_.references))
1331
1346
val newPassthroughAttributes =
0 commit comments