@@ -3,11 +3,10 @@ use std::{
3
3
sync:: { Arc , Mutex } ,
4
4
} ;
5
5
6
- use chrono:: { DateTime , NaiveDate , TimeZone , Utc } ;
7
- use datafusion:: error:: Result ;
6
+ use chrono:: { DateTime , Utc } ;
8
7
use datafusion:: {
9
8
arrow:: {
10
- array:: { ArrayRef , Float64Array , Int64Array , RecordBatch , TimestampMicrosecondArray } ,
9
+ array:: { Float64Array , Int64Array , RecordBatch , TimestampMicrosecondArray } ,
11
10
datatypes:: { DataType , SchemaRef } ,
12
11
} ,
13
12
common:: { tree_node:: Transformed , DFSchema , DFSchemaRef } ,
@@ -21,6 +20,7 @@ use datafusion::{
21
20
scalar:: ScalarValue ,
22
21
sql:: TableReference ,
23
22
} ;
23
+ use datafusion:: { error:: Result , logical_expr:: expr:: AggregateFunction } ;
24
24
use expr:: {
25
25
extract_filter_expr, extract_uwheel_expr, extract_wheel_range, MinMaxFilter , UWheelExpr ,
26
26
} ;
@@ -341,11 +341,11 @@ impl UWheelOptimizer {
341
341
match agg_expr {
342
342
// COUNT(*)
343
343
Expr :: Alias ( alias) if alias. name == COUNT_STAR_ALIAS => {
344
- // extract possible time range filter
345
- let range = extract_wheel_range ( & filter . predicate , & self . time_column ) ? ;
346
- let count = self . count ( range ) ? ; // early return if range is not queryable
347
- let schema = Arc :: new ( plan . schema ( ) . clone ( ) . as_arrow ( ) . clone ( ) ) ;
348
- count_scan ( count , schema ) . ok ( )
344
+ self . try_count_rewrite ( filter, plan )
345
+ }
346
+ // Also match an un-aliased COUNT(*) aggregate function expression
347
+ Expr :: AggregateFunction ( agg ) if is_count_star_aggregate ( agg ) => {
348
+ self . try_count_rewrite ( filter , plan )
349
349
}
350
350
// Single Aggregate Function (e.g., SUM(col))
351
351
Expr :: AggregateFunction ( agg) if agg. args . len ( ) == 1 => {
@@ -398,6 +398,13 @@ impl UWheelOptimizer {
398
398
}
399
399
}
400
400
401
+ fn try_count_rewrite ( & self , filter : & Filter , plan : & LogicalPlan ) -> Option < LogicalPlan > {
402
+ let range = extract_wheel_range ( & filter. predicate , & self . time_column ) ?;
403
+ let count = self . count ( range) ?; // early return if range is not queryable
404
+ let schema = Arc :: new ( plan. schema ( ) . clone ( ) . as_arrow ( ) . clone ( ) ) ;
405
+ count_scan ( count, schema) . ok ( )
406
+ }
407
+
401
408
// Queries the range using the count wheel, returning an empty table scan if the count is 0
402
409
// avoiding the need to generate a regular execution plan.
403
410
fn maybe_count_filter ( & self , range : WheelRange , plan : & LogicalPlan ) -> Option < LogicalPlan > {
@@ -481,9 +488,8 @@ impl UWheelOptimizer {
481
488
}
482
489
483
490
fn count_scan ( count : u32 , schema : SchemaRef ) -> Result < LogicalPlan > {
484
- let name = COUNT_STAR_ALIAS . to_string ( ) ;
485
491
let data = Int64Array :: from ( vec ! [ count as i64 ] ) ;
486
- let record_batch = RecordBatch :: try_from_iter ( vec ! [ ( & name , Arc :: new( data) as ArrayRef ) ] ) ?;
492
+ let record_batch = RecordBatch :: try_new ( schema . clone ( ) , vec ! [ Arc :: new( data) ] ) ?;
487
493
let df_schema = Arc :: new ( DFSchema :: try_from ( schema. clone ( ) ) ?) ;
488
494
let mem_table = MemTable :: try_new ( schema, vec ! [ vec![ record_batch] ] ) ?;
489
495
@@ -551,6 +557,7 @@ fn func_def_to_aggregate_type(func_def: &AggregateFunctionDefinition) -> Option<
551
557
AggregateFunctionDefinition :: BuiltIn ( datafusion:: logical_expr:: AggregateFunction :: Min ) => {
552
558
Some ( AggregateType :: Min )
553
559
}
560
+ AggregateFunctionDefinition :: UDF ( udf) if udf. name ( ) == "avg" => Some ( AggregateType :: Avg ) ,
554
561
AggregateFunctionDefinition :: UDF ( udf) if udf. name ( ) == "sum" => Some ( AggregateType :: Sum ) ,
555
562
AggregateFunctionDefinition :: UDF ( udf) if udf. name ( ) == "count" => {
556
563
Some ( AggregateType :: Count )
@@ -599,6 +606,30 @@ fn mem_table_as_table_scan(table: MemTable, original_schema: DFSchemaRef) -> Res
599
606
Ok ( LogicalPlan :: TableScan ( table_scan) )
600
607
}
601
608
609
+ fn is_wildcard ( expr : & Expr ) -> bool {
610
+ matches ! ( expr, Expr :: Wildcard { qualifier: None } )
611
+ }
612
+
613
+ /// Determines if the given aggregate function is a COUNT(*) aggregate.
614
+ ///
615
+ /// An aggregate function is a COUNT(*) aggregate if its function name is "COUNT" and it either has a single argument that is a wildcard (`*`), or it has no arguments.
616
+ ///
617
+ /// # Arguments
618
+ ///
619
+ /// * `aggregate_function` - The aggregate function to check.
620
+ ///
621
+ /// # Returns
622
+ ///
623
+ /// `true` if the aggregate function is a COUNT(*) aggregate, `false` otherwise.
624
+ fn is_count_star_aggregate ( aggregate_function : & AggregateFunction ) -> bool {
625
+ matches ! ( aggregate_function,
626
+ AggregateFunction {
627
+ func_def,
628
+ args,
629
+ ..
630
+ } if func_def. name( ) == "COUNT" && ( args. len( ) == 1 && is_wildcard( & args[ 0 ] ) || args. is_empty( ) ) )
631
+ }
632
+
602
633
// Helper methods to build the UWheelOptimizer
603
634
604
635
// Uses the provided TableProvider to build the UWheelOptimizer
@@ -666,18 +697,13 @@ async fn build(
666
697
async fn build_min_max_wheel (
667
698
schema : SchemaRef ,
668
699
batches : & [ RecordBatch ] ,
669
- _min_timestamp_ms : u64 ,
700
+ min_timestamp_ms : u64 ,
670
701
max_timestamp_ms : u64 ,
671
702
time_col : & str ,
672
703
min_max_col : & str ,
673
704
haw_conf : & HawConf ,
674
705
) -> Result < ReaderWheel < F64MinMaxAggregator > > {
675
- // TODO: remove hardcoded time
676
- let start = NaiveDate :: from_ymd_opt ( 2022 , 1 , 1 ) . unwrap ( ) ;
677
- let date = Utc . from_utc_datetime ( & start. and_hms_opt ( 0 , 0 , 0 ) . unwrap ( ) ) ;
678
- let start_ms = date. timestamp_millis ( ) as u64 ;
679
-
680
- let conf = haw_conf. with_watermark ( start_ms) ;
706
+ let conf = haw_conf. with_watermark ( min_timestamp_ms) ;
681
707
682
708
let mut wheel: RwWheel < F64MinMaxAggregator > = RwWheel :: with_conf (
683
709
Conf :: default ( )
@@ -794,13 +820,15 @@ where
794
820
wheel. insert ( entry) ;
795
821
}
796
822
}
797
- // Once all data is inserted, advance the wheel to the max timestamp
798
- wheel. advance_to ( end_ms) ;
823
+ // Once all data is inserted, advance the wheel to the max timestamp + 1 second
824
+ wheel. advance_to ( end_ms + 1000 ) ;
799
825
800
826
// convert wheel to index
801
827
wheel. read ( ) . to_simd_wheels ( ) ;
802
828
// TODO: make this configurable
803
- wheel. read ( ) . to_prefix_wheels ( ) ;
829
+ if A :: invertible ( ) {
830
+ wheel. read ( ) . to_prefix_wheels ( ) ;
831
+ }
804
832
} else {
805
833
// TODO: return Datafusion Error?
806
834
panic ! ( "Min/Max column must be a numeric type" ) ;
@@ -827,11 +855,7 @@ fn build_count_wheel(
827
855
. unwrap ( )
828
856
. timestamp_millis ( ) as u64 ;
829
857
830
- let start = NaiveDate :: from_ymd_opt ( 2022 , 1 , 1 ) . unwrap ( ) ;
831
- let date = Utc . from_utc_datetime ( & start. and_hms_opt ( 0 , 0 , 0 ) . unwrap ( ) ) ;
832
- let start_ms = date. timestamp_millis ( ) as u64 ;
833
-
834
- let conf = haw_conf. with_watermark ( start_ms) ;
858
+ let conf = haw_conf. with_watermark ( min_ms) ;
835
859
836
860
let mut count_wheel: RwWheel < U32SumAggregator > = RwWheel :: with_conf (
837
861
Conf :: default ( )
@@ -843,12 +867,13 @@ fn build_count_wheel(
843
867
let timestamp_ms = DateTime :: from_timestamp_micros ( timestamp)
844
868
. unwrap ( )
845
869
. timestamp_millis ( ) as u64 ;
870
+
846
871
// Record a count
847
872
let entry = Entry :: new ( 1 , timestamp_ms) ;
848
873
count_wheel. insert ( entry) ;
849
874
}
850
875
851
- count_wheel. advance_to ( max_ms) ;
876
+ count_wheel. advance_to ( max_ms + 1000 ) ; // + 1 second
852
877
853
878
// convert wheel to index
854
879
count_wheel. read ( ) . to_simd_wheels ( ) ;
@@ -919,3 +944,462 @@ fn scalar_to_timestamp(scalar: &ScalarValue) -> Option<i64> {
919
944
_ => None ,
920
945
}
921
946
}
947
+
948
#[cfg(test)]
mod tests {
    use chrono::Duration;
    use chrono::TimeZone;
    use datafusion::arrow::datatypes::{Field, Schema, TimeUnit};
    use datafusion::functions_aggregate::expr_fn::avg;
    use datafusion::logical_expr::test::function_stub::{count, sum};

    use super::*;
    use builder::Builder;

    /// Lower bound (inclusive) of the time range covered by the test data.
    const DATA_START: &str = "2024-05-10T00:00:00Z";
    /// Upper bound (exclusive) of the time range covered by the test data.
    const DATA_END: &str = "2024-05-10T00:00:10Z";

    /// Builds an in-memory table with ten rows: one row per second starting at
    /// 2024-05-10 00:00:00 UTC, with `agg_col` holding the values 1.0 through 10.0.
    fn create_test_memtable() -> Result<MemTable> {
        let schema = Arc::new(Schema::new(vec![
            Field::new(
                "timestamp",
                DataType::Timestamp(TimeUnit::Microsecond, None),
                false,
            ),
            Field::new("agg_col", DataType::Float64, false),
        ]));

        // Define the start time as 2024-05-10 00:00:00 UTC
        let base_time: DateTime<Utc> = Utc.with_ymd_and_hms(2024, 5, 10, 0, 0, 0).unwrap();
        let timestamps: Vec<i64> = (0..10)
            .map(|i| base_time + Duration::seconds(i))
            .map(|dt| dt.timestamp_micros())
            .collect();

        let agg_values: Vec<f64> = (0..10).map(|i| (i + 1) as f64).collect();

        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(TimestampMicrosecondArray::from(timestamps)),
                Arc::new(Float64Array::from(agg_values)),
            ],
        )?;

        MemTable::try_new(schema, vec![vec![batch]])
    }

    /// Builds an optimizer over the test table, using `timestamp` as the time column.
    async fn test_optimizer() -> Result<Arc<UWheelOptimizer>> {
        let provider = Arc::new(create_test_memtable()?);
        Ok(Arc::new(
            Builder::new("timestamp")
                .with_name("test")
                .build_with_provider(provider)
                .await?,
        ))
    }

    /// Builds an optimizer and additionally indexes `agg_col` with `agg_type`.
    async fn indexed_optimizer(agg_type: AggregateType) -> Result<Arc<UWheelOptimizer>> {
        let optimizer = test_optimizer().await?;
        optimizer
            .build_index(IndexBuilder::with_col_and_aggregate("agg_col", agg_type))
            .await?;
        Ok(optimizer)
    }

    /// Builds the predicate `timestamp >= start AND timestamp < end`.
    fn time_range_filter(start: &'static str, end: &'static str) -> Expr {
        col("timestamp")
            .gt_eq(lit(start))
            .and(col("timestamp").lt(lit(end)))
    }

    /// Builds `Projection(Aggregate(Filter(Scan)))` over the test table, using
    /// `agg_expr` both as the sole aggregate and as the sole projected expression.
    fn aggregate_plan(
        optimizer: &UWheelOptimizer,
        filter: Expr,
        agg_expr: Expr,
    ) -> Result<LogicalPlan> {
        LogicalPlanBuilder::scan("test", provider_as_source(optimizer.provider()), None)?
            .filter(filter)?
            .aggregate(Vec::<Expr>::new(), vec![agg_expr.clone()])?
            .project(vec![agg_expr])?
            .build()
    }

    /// Builds the aggregate plan over the full data range for `agg_expr`.
    fn full_range_plan(optimizer: &UWheelOptimizer, agg_expr: Expr) -> Result<LogicalPlan> {
        aggregate_plan(
            optimizer,
            time_range_filter(DATA_START, DATA_END),
            agg_expr,
        )
    }

    /// Asserts that `plan` starts as a projection and is rewritten into a table scan.
    fn assert_rewritten_to_table_scan(optimizer: &UWheelOptimizer, plan: &LogicalPlan) {
        // Assert that the original plan is a Projection
        assert!(matches!(plan, LogicalPlan::Projection(_)));

        let rewritten = optimizer.try_rewrite(plan).unwrap();
        // assert it was rewritten to a TableScan
        assert!(matches!(rewritten, LogicalPlan::TableScan(_)));
    }

    /// Executes `plan` in a session whose only optimizer rule is `optimizer`.
    async fn execute(
        optimizer: &Arc<UWheelOptimizer>,
        plan: LogicalPlan,
    ) -> Result<Vec<RecordBatch>> {
        let ctx = SessionContext::new();
        ctx.register_table("test", optimizer.provider().clone())?;

        // Set UWheelOptimizer as optimizer rule
        let session_state = ctx.state().with_optimizer_rules(vec![optimizer.clone()]);
        let uwheel_ctx = SessionContext::new_with_state(session_state);

        // Run the query through the ctx that has our OptimizerRule
        let df = uwheel_ctx.execute_logical_plan(plan).await?;
        df.collect().await
    }

    /// Asserts there is exactly one batch and returns its first `Float64` value.
    fn single_f64(results: &[RecordBatch]) -> f64 {
        assert_eq!(results.len(), 1);
        results[0]
            .column(0)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap()
            .value(0)
    }

    #[tokio::test]
    async fn create_optimizer_with_memtable() {
        let provider = Arc::new(create_test_memtable().unwrap());
        assert!(Builder::new("timestamp")
            .with_name("test")
            .build_with_provider(provider)
            .await
            .is_ok());
    }

    #[tokio::test]
    async fn count_star_aggregation_rewrite() -> Result<()> {
        let optimizer = test_optimizer().await?;
        let plan = full_range_plan(&optimizer, count(wildcard()))?;
        assert_rewritten_to_table_scan(&optimizer, &plan);
        Ok(())
    }

    #[tokio::test]
    async fn sum_aggregation_rewrite() -> Result<()> {
        let optimizer = indexed_optimizer(AggregateType::Sum).await?;
        let plan = full_range_plan(&optimizer, sum(col("agg_col")))?;
        assert_rewritten_to_table_scan(&optimizer, &plan);
        Ok(())
    }

    #[tokio::test]
    async fn min_aggregation_rewrite() -> Result<()> {
        let optimizer = indexed_optimizer(AggregateType::Min).await?;
        let plan = full_range_plan(&optimizer, min(col("agg_col")))?;
        assert_rewritten_to_table_scan(&optimizer, &plan);
        Ok(())
    }

    #[tokio::test]
    async fn max_aggregation_rewrite() -> Result<()> {
        let optimizer = indexed_optimizer(AggregateType::Max).await?;
        let plan = full_range_plan(&optimizer, max(col("agg_col")))?;
        assert_rewritten_to_table_scan(&optimizer, &plan);
        Ok(())
    }

    #[tokio::test]
    async fn avg_aggregation_rewrite() -> Result<()> {
        let optimizer = indexed_optimizer(AggregateType::Avg).await?;
        let plan = full_range_plan(&optimizer, avg(col("agg_col")))?;
        assert_rewritten_to_table_scan(&optimizer, &plan);
        Ok(())
    }

    #[tokio::test]
    async fn count_star_aggregation_invalid_rewrite() -> Result<()> {
        let optimizer = test_optimizer().await?;
        // invalid temporal filter: the range lies one day after the test data
        let temporal_filter = time_range_filter("2024-05-11T00:00:00Z", "2024-05-11T00:00:10Z");
        let plan = aggregate_plan(&optimizer, temporal_filter, count(wildcard()))?;

        assert!(optimizer.try_rewrite(&plan).is_none());

        Ok(())
    }

    #[tokio::test]
    async fn count_star_aggregation_exec() -> Result<()> {
        let optimizer = test_optimizer().await?;
        let plan = full_range_plan(&optimizer, count(wildcard()))?;

        let results = execute(&optimizer, plan).await?;

        assert_eq!(results.len(), 1);
        assert_eq!(
            results[0]
                .column(0)
                .as_any()
                .downcast_ref::<Int64Array>()
                .unwrap()
                .value(0),
            10,
        );

        Ok(())
    }

    #[tokio::test]
    async fn sum_aggregation_exec() -> Result<()> {
        let optimizer = indexed_optimizer(AggregateType::Sum).await?;
        let plan = full_range_plan(&optimizer, sum(col("agg_col")))?;

        let results = execute(&optimizer, plan).await?;

        // 1 + 2 + 3 + ... + 9 + 10 = 55
        assert_eq!(single_f64(&results), 55.0);

        Ok(())
    }

    #[tokio::test]
    async fn min_aggregation_exec() -> Result<()> {
        let optimizer = indexed_optimizer(AggregateType::Min).await?;
        let plan = full_range_plan(&optimizer, min(col("agg_col")))?;

        let results = execute(&optimizer, plan).await?;

        assert_eq!(single_f64(&results), 1.0);

        Ok(())
    }

    #[tokio::test]
    async fn max_aggregation_exec() -> Result<()> {
        let optimizer = indexed_optimizer(AggregateType::Max).await?;
        let plan = full_range_plan(&optimizer, max(col("agg_col")))?;

        let results = execute(&optimizer, plan).await?;

        assert_eq!(single_f64(&results), 10.0);

        Ok(())
    }

    #[tokio::test]
    async fn avg_aggregation_exec() -> Result<()> {
        let optimizer = indexed_optimizer(AggregateType::Avg).await?;
        let plan = full_range_plan(&optimizer, avg(col("agg_col")))?;

        let results = execute(&optimizer, plan).await?;

        assert_eq!(single_f64(&results), 5.5);

        Ok(())
    }
}
0 commit comments