@@ -337,53 +337,165 @@ impl UWheelOptimizer {
            return None;
        };

-        let (wheel_range, _) = extract_filter_expr(&filter.predicate, &self.time_column)?;
+        let (wheel_range, expr_key) =
+            match extract_filter_expr(&filter.predicate, &self.time_column)? {
+                (range, Some(expr)) => (range, maybe_replace_table_name(&expr, &self.name)),
+                (range, None) => (range, STAR_AGGREGATION_ALIAS.to_string()),
+            };

        match group_expr {
            Expr::ScalarFunction(func) if func.name() == "date_trunc" => {
                let interval = func.args.first()?;
                if let Expr::Literal(ScalarValue::Utf8(duration)) = interval {
-                    match duration.as_ref()?.as_str() {
-                        "second" => {
-                            unimplemented!("date_trunc('second') group by is not supported")
-                        }
-                        "minute" => {
-                            unimplemented!("date_trunc('minute') group by is not supported")
-                        }
-                        "hour" => {
-                            unimplemented!("date_trunc('hour') group by is not supported")
-                        }
-                        "day" => {
-                            let res = self
-                                .wheels
-                                .count
-                                .group_by(wheel_range, Duration::DAY)
-                                .unwrap_or_default()
-                                .iter()
-                                .map(|(k, v)| ((*k * 1_000) as i64, *v as i64)) // transform milliseconds to microseconds by multiplying by 1_000
-                                .collect();
-
-                            let schema = Arc::new(plan.schema().clone().as_arrow().clone());
-
-                            return uwheel_group_by_to_table_scan(res, schema).ok();
-                        }
-                        "week" => {
-                            unimplemented!("date_trunc('week') group by is not supported")
-                        }
-                        "month" => {
-                            unimplemented!("date_trunc('month') group by is not supported")
-                        }
-                        "year" => {
-                            unimplemented!("date_trunc('year') group by is not supported")
+                    let group_by_interval = match duration.as_ref()?.as_str() {
+                        "second" => Duration::SECOND,
+                        "minute" => Duration::MINUTE,
+                        "hour" => Duration::HOUR,
+                        "day" => Duration::DAY,
+                        "week" => Duration::WEEK,
+                        _ => return None,
+                    };
+
+                    let mut group_agg_result = Vec::new();
+                    let mut group_col = None;
+
+                    let mut count_idx = None;
+
+                    for (idx, agg) in agg.aggr_expr.iter().enumerate() {
+                        let (agg_type, col) = match agg {
+                            // COUNT(*)
+                            Expr::AggregateFunction(agg)
+                                if is_count_star_aggregate(agg) =>
+                            {
+                                (UWheelAggregate::Count, None)
+                            }
+
+                            // COUNT(*)
+                            Expr::Alias(alias) if alias.name == COUNT_STAR_ALIAS => {
+                                (UWheelAggregate::Count, None)
+                            }
+
+                            Expr::AggregateFunction(agg) => {
+                                if agg.args.len() > 1 {
+                                    return None;
+                                }
+                                let col = match &agg.args[0] {
+                                    Expr::Column(col) => col,
+                                    _ => return None,
+                                };
+                                (func_def_to_aggregate_type(&agg.func)?, Some(col))
+                            }
+
+                            _ => return None,
+                        };
+
+                        let res = match agg_type {
+                            UWheelAggregate::Count => {
+                                count_idx = Some(idx);
+                                self.wheels
+                                    .count
+                                    .group_by(wheel_range, group_by_interval)
+                                    .unwrap_or_default()
+                                    .iter()
+                                    .map(|(k, v)| (*k, *v as f64))
+                                    .collect()
+                            }
+
+                            UWheelAggregate::Avg => {
+                                let wheel_key = format!(
+                                    "{}.{}.{}",
+                                    self.name,
+                                    col.unwrap().name,
+                                    expr_key
+                                );
+                                self.wheels
+                                    .avg
+                                    .lock()
+                                    .unwrap()
+                                    .get(&wheel_key)?
+                                    .group_by(wheel_range, group_by_interval)
+                                    .unwrap_or_default()
+                            }
+
+                            UWheelAggregate::Min => {
+                                let wheel_key = format!(
+                                    "{}.{}.{}",
+                                    self.name,
+                                    col.unwrap().name,
+                                    expr_key
+                                );
+                                self.wheels
+                                    .min
+                                    .lock()
+                                    .unwrap()
+                                    .get(&wheel_key)?
+                                    .group_by(wheel_range, group_by_interval)
+                                    .unwrap_or_default()
+                            }
+
+                            UWheelAggregate::Max => {
+                                let wheel_key = format!(
+                                    "{}.{}.{}",
+                                    self.name,
+                                    col.unwrap().name,
+                                    expr_key
+                                );
+                                self.wheels
+                                    .max
+                                    .lock()
+                                    .unwrap()
+                                    .get(&wheel_key)?
+                                    .group_by(wheel_range, group_by_interval)
+                                    .unwrap_or_default()
+                            }
+
+                            UWheelAggregate::Sum => {
+                                let wheel_key = format!(
+                                    "{}.{}.{}",
+                                    self.name,
+                                    col.unwrap().name,
+                                    expr_key
+                                );
+                                self.wheels
+                                    .sum
+                                    .lock()
+                                    .unwrap()
+                                    .get(&wheel_key)?
+                                    .group_by(wheel_range, group_by_interval)
+                                    .unwrap_or_default()
+                            }
+
+                            _ => return None,
+                        };
+
+                        if group_col.is_none() {
+                            group_col = Some(
+                                res.iter()
+                                    .map(|(k, _)| (*k * 1_000) as i64)
+                                    .collect::<Vec<_>>(),
+                            );
                        }
-                        _ => {}
+
+                        group_agg_result
+                            .push(res.iter().map(|(_, v)| *v).collect::<Vec<_>>());
                    }
+
+                    group_col.as_ref()?;
+
+                    let schema = Arc::new(plan.schema().clone().as_arrow().clone());
+
+                    return uwheel_group_by_to_table_scan(
+                        group_col.unwrap(),
+                        group_agg_result,
+                        count_idx,
+                        schema,
+                    )
+                    .ok();
                }
            }
-            _ => {
-                unimplemented!("We only support scalar function date_trunc for group by expression now")
-            }
+            _ => return None,
        }
+
        None
    }

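Note on units in the hunk above: each wheel group_by call returns (timestamp, value) pairs keyed in milliseconds, while the Arrow timestamp column built from them later is in microseconds, which is what the `*k * 1_000` conversion is for. Below is a minimal standalone Rust sketch of that split into a group column and one aggregate column; the pair values are illustrative and not produced by an actual wheel.

// Standalone sketch: splitting group-by pairs into a microsecond group column
// and a per-aggregate value vector, mirroring the loop above.
// The input pairs are made-up examples, not output from the crate.
fn main() {
    let res: Vec<(u64, f64)> = vec![(1_715_299_200_000, 55.0), (1_715_385_600_000, 42.0)];

    // milliseconds -> microseconds, as done for the group column
    let group_col: Vec<i64> = res.iter().map(|(k, _)| (*k * 1_000) as i64).collect();
    // aggregate values for one output column
    let agg_values: Vec<f64> = res.iter().map(|(_, v)| *v).collect();

    assert_eq!(group_col, vec![1_715_299_200_000_000, 1_715_385_600_000_000]);
    assert_eq!(agg_values, vec![55.0, 42.0]);
}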
@@ -622,16 +734,26 @@ fn uwheel_agg_to_table_scan(result: f64, schema: SchemaRef) -> Result<LogicalPla
// Converts a uwheel group by result to a TableScan with a MemTable as source
// currently only supports timestamp group by
fn uwheel_group_by_to_table_scan(
-    result: Vec<(i64, i64)>,
+    group_col: Vec<i64>,
+    agg_result: Vec<Vec<f64>>,
+    count_idx: Option<usize>,
    schema: SchemaRef,
) -> Result<LogicalPlan> {
-    let group_by =
-        TimestampMicrosecondArray::from(result.iter().map(|(k, _)| *k).collect::<Vec<_>>());
+    let group_by = TimestampMicrosecondArray::from(group_col);

-    let agg = Int64Array::from(result.iter().map(|(_, v)| *v).collect::<Vec<_>>());
+    let mut columns = vec![Arc::new(group_by) as Arc<dyn Array>];
+
+    for (idx, result) in agg_result.into_iter().enumerate() {
+        if count_idx.is_some() && idx == count_idx.unwrap() {
+            let data = Int64Array::from(result.iter().map(|v| *v as i64).collect::<Vec<_>>());
+            columns.push(Arc::new(data) as Arc<dyn Array>);
+        } else {
+            let data = Float64Array::from(result);
+            columns.push(Arc::new(data) as Arc<dyn Array>);
+        }
+    }

-    let record_batch =
-        RecordBatch::try_new(schema.clone(), vec![Arc::new(group_by), Arc::new(agg)])?;
+    let record_batch = RecordBatch::try_new(schema.clone(), columns)?;

    let df_schema = Arc::new(DFSchema::try_from(schema.clone())?);
    let mem_table = MemTable::try_new(schema, vec![vec![record_batch]])?;
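The reworked uwheel_group_by_to_table_scan above builds one Arrow column per aggregate, materializing the COUNT column as Int64 and every other aggregate as Float64, and then backs a MemTable with the resulting RecordBatch. The following self-contained sketch shows the same assembly pattern in isolation; the schema, column names, and values are illustrative assumptions rather than anything defined in this crate.

// Sketch of the column-assembly pattern (assumed schema and values):
// timestamp group column + Float64 aggregate + Int64 count, wrapped in a MemTable.
use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, Float64Array, Int64Array, TimestampMicrosecondArray};
use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::MemTable;

fn main() -> datafusion::error::Result<()> {
    // Hypothetical output schema: one bucket column plus two aggregate columns.
    let schema = Arc::new(Schema::new(vec![
        Field::new("bucket", DataType::Timestamp(TimeUnit::Microsecond, None), false),
        Field::new("sum", DataType::Float64, false),
        Field::new("count", DataType::Int64, false),
    ]));

    let columns: Vec<ArrayRef> = vec![
        // group keys in microseconds
        Arc::new(TimestampMicrosecondArray::from(vec![1_715_299_200_000_000i64])) as ArrayRef,
        // non-count aggregates stay Float64
        Arc::new(Float64Array::from(vec![55.0])) as ArrayRef,
        // the COUNT column is cast to Int64
        Arc::new(Int64Array::from(vec![10i64])) as ArrayRef,
    ];

    let batch = RecordBatch::try_new(schema.clone(), columns)?;
    let _mem_table = MemTable::try_new(schema, vec![vec![batch]])?;
    Ok(())
}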
@@ -1898,4 +2020,136 @@ mod tests {

        Ok(())
    }
+
+    #[tokio::test]
+    async fn group_by_multiple_aggregation_rewrite() -> Result<()> {
+        let optimizer = test_optimizer().await?;
+
+        optimizer
+            .build_index(IndexBuilder::with_col_and_aggregate(
+                "agg_col",
+                UWheelAggregate::Avg,
+            ))
+            .await?;
+
+        optimizer
+            .build_index(IndexBuilder::with_col_and_aggregate(
+                "agg_col",
+                UWheelAggregate::Sum,
+            ))
+            .await?;
+
+        let temporal_filter = col("timestamp")
+            .gt_eq(lit("2024-05-10T00:00:00Z"))
+            .and(col("timestamp").lt(lit("2024-05-10T00:00:10Z")));
+
+        let plan =
+            LogicalPlanBuilder::scan("test", provider_as_source(optimizer.provider()), None)?
+                .filter(temporal_filter)?
+                .aggregate(
+                    vec![date_trunc(lit("day"), col("timestamp"))], // GROUP BY date_trunc('day', timestamp)
+                    vec![sum(col("agg_col")), avg(col("agg_col")), count(wildcard())],
+                )?
+                .project(vec![
+                    date_trunc(lit("day"), col("timestamp")),
+                    sum(col("agg_col")),
+                    avg(col("agg_col")),
+                    count(wildcard()),
+                ])?
+                .build()?;
+
+        // Assert that the original plan is a Projection
+        assert!(matches!(plan, LogicalPlan::Projection(_)));
+
+        let rewritten = optimizer.try_rewrite(&plan).unwrap();
+        // assert it was rewritten to a TableScan
+        assert!(matches!(rewritten, LogicalPlan::TableScan(_)));
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn group_by_multiple_aggregation_exec() -> Result<()> {
+        let optimizer = test_optimizer().await?;
+
+        optimizer
+            .build_index(IndexBuilder::with_col_and_aggregate(
+                "agg_col",
+                UWheelAggregate::Avg,
+            ))
+            .await?;
+
+        optimizer
+            .build_index(IndexBuilder::with_col_and_aggregate(
+                "agg_col",
+                UWheelAggregate::Sum,
+            ))
+            .await?;
+
+        let temporal_filter = col("timestamp")
+            .gt_eq(lit("2024-05-10T00:00:00Z"))
+            .and(col("timestamp").lt(lit("2024-05-10T00:00:10Z")));
+
+        let plan =
+            LogicalPlanBuilder::scan("test", provider_as_source(optimizer.provider()), None)?
+                .filter(temporal_filter)?
+                .aggregate(
+                    vec![date_trunc(lit("day"), col("timestamp"))], // GROUP BY date_trunc('day', timestamp)
+                    vec![sum(col("agg_col")), avg(col("agg_col")), count(wildcard())],
+                )?
+                .project(vec![
+                    date_trunc(lit("day"), col("timestamp")),
+                    sum(col("agg_col")),
+                    avg(col("agg_col")),
+                    count(wildcard()),
+                ])?
+                .build()?;
+
+        let ctx = SessionContext::new();
+        ctx.register_table("test", optimizer.provider().clone())?;
+
+        // Set UWheelOptimizer as optimizer rule
+        let session_state = SessionStateBuilder::new()
+            .with_optimizer_rules(vec![optimizer.clone()])
+            .build();
+        let uwheel_ctx = SessionContext::new_with_state(session_state);
+
+        // Run the query through the ctx that has our OptimizerRule
+        let df = uwheel_ctx.execute_logical_plan(plan).await?;
+        let results = df.collect().await?;
+
+        assert_eq!(results.len(), 1);
+
+        assert_eq!(
+            results[0]
+                .column(1)
+                .as_any()
+                .downcast_ref::<Float64Array>()
+                .unwrap()
+                .value(0),
+            55.0
+        );
+
+        assert_eq!(
+            results[0]
+                .column(2)
+                .as_any()
+                .downcast_ref::<Float64Array>()
+                .unwrap()
+                .value(0),
+            5.5
+        );
+
+        assert_eq!(
+            results[0]
+                .column(3)
+                .as_any()
+                .downcast_ref::<Int64Array>()
+                .unwrap()
+                .value(0),
+            10
+        );
+
+        Ok(())
+    }
}
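For context on the asserted numbers in group_by_multiple_aggregation_exec: sum 55.0, avg 5.5, and count 10 are consistent with the test fixture writing the values 1 through 10 to agg_col across the 10-second window; that is an assumption about test_optimizer(), which is not part of this diff. A quick arithmetic check:

// Verifies only the arithmetic behind the expected test values, under the
// assumption (not shown in this diff) that agg_col holds 1..=10.
fn main() {
    let values: Vec<f64> = (1..=10).map(f64::from).collect();
    let sum: f64 = values.iter().sum();
    let avg = sum / values.len() as f64;
    assert_eq!((sum, avg, values.len()), (55.0, 5.5, 10));
}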