@@ -4,9 +4,7 @@ use datafusion_common::alias::AliasGenerator;
44use datafusion_common:: config:: ConfigOptions ;
55use datafusion_common:: tree_node:: { Transformed , TreeNode , TreeNodeRecursion } ;
66use datafusion_common:: { JoinSide , JoinType , Result } ;
7- use datafusion_expr_common:: signature:: Volatility ;
87use datafusion_physical_expr:: expressions:: Column ;
9- use datafusion_physical_expr:: ScalarFunctionExpr ;
108use datafusion_physical_expr_common:: physical_expr:: PhysicalExpr ;
119use datafusion_physical_plan:: coalesce_batches:: CoalesceBatchesExec ;
1210use datafusion_physical_plan:: joins:: utils:: { ColumnIndex , JoinFilter } ;
@@ -18,9 +16,9 @@ use std::sync::Arc;
1816
1917/// Tries to push down projections from join filters that only depend on one side of the join.
2018///
21- /// This can be a crucial optimization for nested loop joins. By pushing these projections
22- /// down, even functions that only depend on one side of the join must be done for all row
23- /// combinations .
19+ /// This optimization is currently only applied to nested loop joins. Without pushing these
20+ /// projections down, functions that only depend on one side of the join must be evaluated for
21+ /// the cartesian product of the two sides.
2422#[ derive( Debug ) ]
2523pub struct NestedLoopJoinProjectionPushDown ;
2624
@@ -234,6 +232,10 @@ fn ensure_batch_size(
234232}
235233
236234/// Creates a new [JoinFilter] and tries to minimize the internal schema.
235+ ///
236+ /// This could eliminate some columns that were only part of a computation that has been pushed
237+ /// down. As this computation is now materialized on one side of the join, the original input
238+ /// columns are not needed anymore.
237239fn minimize_join_filter (
238240 expr : Arc < dyn PhysicalExpr > ,
239241 old_column_indices : Vec < ColumnIndex > ,
@@ -347,7 +349,7 @@ impl<'a> JoinFilterRewriter<'a> {
347349 // Recurse if there is a dependency to both sides or if the entire expression is volatile.
348350 let depends_on_other_side =
349351 self . depends_on_join_side ( & expr, self . join_side . negate ( ) ) ?;
350- let is_volatile = is_volatile ( expr. as_ref ( ) ) ;
352+ let is_volatile = is_volatile_expression_tree ( expr. as_ref ( ) ) ;
351353 if depends_on_other_side || is_volatile {
352354 return expr. map_children ( |expr| self . rewrite ( expr) ) ;
353355 }
@@ -429,31 +431,31 @@ impl<'a> JoinFilterRewriter<'a> {
429431 }
430432}
431433
432- fn is_volatile ( expr : & dyn PhysicalExpr ) -> bool {
433- match expr. as_any ( ) . downcast_ref :: < ScalarFunctionExpr > ( ) {
434- None => expr
435- . children ( )
436- . iter ( )
437- . map ( |expr| is_volatile ( expr. as_ref ( ) ) )
438- . reduce ( |lhs, rhs| lhs || rhs)
439- . unwrap_or ( false ) ,
440- Some ( expr) => expr. fun ( ) . signature ( ) . volatility == Volatility :: Volatile ,
434+ fn is_volatile_expression_tree ( expr : & dyn PhysicalExpr ) -> bool {
435+ if expr. is_volatile_node ( ) {
436+ return true ;
441437 }
438+
439+ expr. children ( )
440+ . iter ( )
441+ . map ( |expr| is_volatile_expression_tree ( expr. as_ref ( ) ) )
442+ . reduce ( |lhs, rhs| lhs || rhs)
443+ . unwrap_or ( false )
442444}
443445
444446#[ cfg( test) ]
445447mod test {
446448 use super :: * ;
447449 use arrow:: datatypes:: { DataType , Field , FieldRef , Schema } ;
448- use datafusion_expr:: { ScalarUDF , ScalarUDFImpl } ;
449450 use datafusion_expr_common:: operator:: Operator ;
451+ use datafusion_functions:: math:: random;
450452 use datafusion_physical_expr:: expressions:: { binary, lit} ;
453+ use datafusion_physical_expr:: ScalarFunctionExpr ;
451454 use datafusion_physical_expr_common:: physical_expr:: PhysicalExpr ;
452455 use datafusion_physical_plan:: displayable;
453456 use datafusion_physical_plan:: empty:: EmptyExec ;
454457 use insta:: assert_snapshot;
455458 use std:: sync:: Arc ;
456- use datafusion_functions:: math:: random;
457459
458460 #[ tokio:: test]
459461 async fn no_computation_does_not_project ( ) -> Result < ( ) > {
0 commit comments