Skip to content

Commit 234ae67

Browse files
committedAug 16, 2024
chore: add initial rewrite and exec tests, addresses #9
1 parent 13cd0fd commit 234ae67

File tree

1 file changed

+510
-26
lines changed

1 file changed

+510
-26
lines changed
 

‎datafusion-uwheel/src/lib.rs

+510-26
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@ use std::{
33
sync::{Arc, Mutex},
44
};
55

6-
use chrono::{DateTime, NaiveDate, TimeZone, Utc};
7-
use datafusion::error::Result;
6+
use chrono::{DateTime, Utc};
87
use datafusion::{
98
arrow::{
10-
array::{ArrayRef, Float64Array, Int64Array, RecordBatch, TimestampMicrosecondArray},
9+
array::{Float64Array, Int64Array, RecordBatch, TimestampMicrosecondArray},
1110
datatypes::{DataType, SchemaRef},
1211
},
1312
common::{tree_node::Transformed, DFSchema, DFSchemaRef},
@@ -21,6 +20,7 @@ use datafusion::{
2120
scalar::ScalarValue,
2221
sql::TableReference,
2322
};
23+
use datafusion::{error::Result, logical_expr::expr::AggregateFunction};
2424
use expr::{
2525
extract_filter_expr, extract_uwheel_expr, extract_wheel_range, MinMaxFilter, UWheelExpr,
2626
};
@@ -341,11 +341,11 @@ impl UWheelOptimizer {
341341
match agg_expr {
342342
// COUNT(*)
343343
Expr::Alias(alias) if alias.name == COUNT_STAR_ALIAS => {
344-
// extract possible time range filter
345-
let range = extract_wheel_range(&filter.predicate, &self.time_column)?;
346-
let count = self.count(range)?; // early return if range is not queryable
347-
let schema = Arc::new(plan.schema().clone().as_arrow().clone());
348-
count_scan(count, schema).ok()
344+
self.try_count_rewrite(filter, plan)
345+
}
346+
// Also check
347+
Expr::AggregateFunction(agg) if is_count_star_aggregate(agg) => {
348+
self.try_count_rewrite(filter, plan)
349349
}
350350
// Single Aggregate Function (e.g., SUM(col))
351351
Expr::AggregateFunction(agg) if agg.args.len() == 1 => {
@@ -398,6 +398,13 @@ impl UWheelOptimizer {
398398
}
399399
}
400400

401+
fn try_count_rewrite(&self, filter: &Filter, plan: &LogicalPlan) -> Option<LogicalPlan> {
402+
let range = extract_wheel_range(&filter.predicate, &self.time_column)?;
403+
let count = self.count(range)?; // early return if range is not queryable
404+
let schema = Arc::new(plan.schema().clone().as_arrow().clone());
405+
count_scan(count, schema).ok()
406+
}
407+
401408
// Queries the range using the count wheel, returning a empty table scan if the count is 0
402409
// avoiding the need to generate a regular execution plan..
403410
fn maybe_count_filter(&self, range: WheelRange, plan: &LogicalPlan) -> Option<LogicalPlan> {
@@ -481,9 +488,8 @@ impl UWheelOptimizer {
481488
}
482489

483490
fn count_scan(count: u32, schema: SchemaRef) -> Result<LogicalPlan> {
484-
let name = COUNT_STAR_ALIAS.to_string();
485491
let data = Int64Array::from(vec![count as i64]);
486-
let record_batch = RecordBatch::try_from_iter(vec![(&name, Arc::new(data) as ArrayRef)])?;
492+
let record_batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(data)])?;
487493
let df_schema = Arc::new(DFSchema::try_from(schema.clone())?);
488494
let mem_table = MemTable::try_new(schema, vec![vec![record_batch]])?;
489495

@@ -551,6 +557,7 @@ fn func_def_to_aggregate_type(func_def: &AggregateFunctionDefinition) -> Option<
551557
AggregateFunctionDefinition::BuiltIn(datafusion::logical_expr::AggregateFunction::Min) => {
552558
Some(AggregateType::Min)
553559
}
560+
AggregateFunctionDefinition::UDF(udf) if udf.name() == "avg" => Some(AggregateType::Avg),
554561
AggregateFunctionDefinition::UDF(udf) if udf.name() == "sum" => Some(AggregateType::Sum),
555562
AggregateFunctionDefinition::UDF(udf) if udf.name() == "count" => {
556563
Some(AggregateType::Count)
@@ -599,6 +606,30 @@ fn mem_table_as_table_scan(table: MemTable, original_schema: DFSchemaRef) -> Res
599606
Ok(LogicalPlan::TableScan(table_scan))
600607
}
601608

609+
fn is_wildcard(expr: &Expr) -> bool {
610+
matches!(expr, Expr::Wildcard { qualifier: None })
611+
}
612+
613+
/// Determines if the given aggregate function is a COUNT(*) aggregate.
614+
///
615+
/// An aggregate function is a COUNT(*) aggregate if its function name is "COUNT" and it either has a single argument that is a wildcard (`*`), or it has no arguments.
616+
///
617+
/// # Arguments
618+
///
619+
/// * `aggregate_function` - The aggregate function to check.
620+
///
621+
/// # Returns
622+
///
623+
/// `true` if the aggregate function is a COUNT(*) aggregate, `false` otherwise.
624+
fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool {
625+
matches!(aggregate_function,
626+
AggregateFunction {
627+
func_def,
628+
args,
629+
..
630+
} if func_def.name() == "COUNT" && (args.len() == 1 && is_wildcard(&args[0]) || args.is_empty()))
631+
}
632+
602633
// Helper methods to build the UWheelOptimizer
603634

604635
// Uses the provided TableProvider to build the UWheelOptimizer
@@ -666,18 +697,13 @@ async fn build(
666697
async fn build_min_max_wheel(
667698
schema: SchemaRef,
668699
batches: &[RecordBatch],
669-
_min_timestamp_ms: u64,
700+
min_timestamp_ms: u64,
670701
max_timestamp_ms: u64,
671702
time_col: &str,
672703
min_max_col: &str,
673704
haw_conf: &HawConf,
674705
) -> Result<ReaderWheel<F64MinMaxAggregator>> {
675-
// TODO: remove hardcoded time
676-
let start = NaiveDate::from_ymd_opt(2022, 1, 1).unwrap();
677-
let date = Utc.from_utc_datetime(&start.and_hms_opt(0, 0, 0).unwrap());
678-
let start_ms = date.timestamp_millis() as u64;
679-
680-
let conf = haw_conf.with_watermark(start_ms);
706+
let conf = haw_conf.with_watermark(min_timestamp_ms);
681707

682708
let mut wheel: RwWheel<F64MinMaxAggregator> = RwWheel::with_conf(
683709
Conf::default()
@@ -794,13 +820,15 @@ where
794820
wheel.insert(entry);
795821
}
796822
}
797-
// Once all data is inserted, advance the wheel to the max timestamp
798-
wheel.advance_to(end_ms);
823+
// Once all data is inserted, advance the wheel to the max timestamp + 1 second
824+
wheel.advance_to(end_ms + 1000);
799825

800826
// convert wheel to index
801827
wheel.read().to_simd_wheels();
802828
// TODO: make this configurable
803-
wheel.read().to_prefix_wheels();
829+
if A::invertible() {
830+
wheel.read().to_prefix_wheels();
831+
}
804832
} else {
805833
// TODO: return Datafusion Error?
806834
panic!("Min/Max column must be a numeric type");
@@ -827,11 +855,7 @@ fn build_count_wheel(
827855
.unwrap()
828856
.timestamp_millis() as u64;
829857

830-
let start = NaiveDate::from_ymd_opt(2022, 1, 1).unwrap();
831-
let date = Utc.from_utc_datetime(&start.and_hms_opt(0, 0, 0).unwrap());
832-
let start_ms = date.timestamp_millis() as u64;
833-
834-
let conf = haw_conf.with_watermark(start_ms);
858+
let conf = haw_conf.with_watermark(min_ms);
835859

836860
let mut count_wheel: RwWheel<U32SumAggregator> = RwWheel::with_conf(
837861
Conf::default()
@@ -843,12 +867,13 @@ fn build_count_wheel(
843867
let timestamp_ms = DateTime::from_timestamp_micros(timestamp)
844868
.unwrap()
845869
.timestamp_millis() as u64;
870+
846871
// Record a count
847872
let entry = Entry::new(1, timestamp_ms);
848873
count_wheel.insert(entry);
849874
}
850875

851-
count_wheel.advance_to(max_ms);
876+
count_wheel.advance_to(max_ms + 1000); // + 1 second
852877

853878
// convert wheel to index
854879
count_wheel.read().to_simd_wheels();
@@ -919,3 +944,462 @@ fn scalar_to_timestamp(scalar: &ScalarValue) -> Option<i64> {
919944
_ => None,
920945
}
921946
}
947+
948+
#[cfg(test)]
949+
mod tests {
950+
use chrono::Duration;
951+
use chrono::TimeZone;
952+
use datafusion::arrow::datatypes::{Field, Schema, TimeUnit};
953+
use datafusion::functions_aggregate::expr_fn::avg;
954+
use datafusion::logical_expr::test::function_stub::{count, sum};
955+
956+
use super::*;
957+
use builder::Builder;
958+
959+
fn create_test_memtable() -> Result<MemTable> {
960+
let schema = Arc::new(Schema::new(vec![
961+
Field::new(
962+
"timestamp",
963+
DataType::Timestamp(TimeUnit::Microsecond, None),
964+
false,
965+
),
966+
Field::new("agg_col", DataType::Float64, false),
967+
]));
968+
969+
// Define the start time as 2024-05-10 00:00:00 UTC
970+
let base_time: DateTime<Utc> = Utc.with_ymd_and_hms(2024, 5, 10, 0, 0, 0).unwrap();
971+
let timestamps: Vec<i64> = (0..10)
972+
.map(|i| base_time + Duration::seconds(i))
973+
.map(|dt| dt.timestamp_micros())
974+
.collect();
975+
976+
let agg_values: Vec<f64> = (0..10).map(|i| (i + 1) as f64).collect();
977+
978+
let batch = RecordBatch::try_new(
979+
schema.clone(),
980+
vec![
981+
Arc::new(TimestampMicrosecondArray::from(timestamps)),
982+
Arc::new(Float64Array::from(agg_values)),
983+
],
984+
)?;
985+
986+
MemTable::try_new(schema, vec![vec![batch]])
987+
}
988+
989+
#[tokio::test]
990+
async fn create_optimizer_with_memtable() {
991+
let provider = Arc::new(create_test_memtable().unwrap());
992+
assert!(Builder::new("timestamp")
993+
.with_name("test")
994+
.build_with_provider(provider)
995+
.await
996+
.is_ok());
997+
}
998+
999+
async fn test_optimizer() -> Result<Arc<UWheelOptimizer>> {
1000+
let provider = Arc::new(create_test_memtable()?);
1001+
Ok(Arc::new(
1002+
Builder::new("timestamp")
1003+
.with_name("test")
1004+
.build_with_provider(provider)
1005+
.await?,
1006+
))
1007+
}
1008+
1009+
#[tokio::test]
1010+
async fn count_star_aggregation_rewrite() -> Result<()> {
1011+
let optimizer = test_optimizer().await?;
1012+
let temporal_filter = col("timestamp")
1013+
.gt_eq(lit("2024-05-10T00:00:00Z"))
1014+
.and(col("timestamp").lt(lit("2024-05-10T00:00:10Z")));
1015+
1016+
let plan =
1017+
LogicalPlanBuilder::scan("test", provider_as_source(optimizer.provider()), None)?
1018+
.filter(temporal_filter)?
1019+
.aggregate(Vec::<Expr>::new(), vec![count(wildcard())])?
1020+
.project(vec![count(wildcard())])?
1021+
.build()?;
1022+
1023+
// Assert that the original plan is a Projection
1024+
assert!(matches!(plan, LogicalPlan::Projection(_)));
1025+
1026+
let rewritten = optimizer.try_rewrite(&plan).unwrap();
1027+
// assert it was rewritten to a TableScan
1028+
assert!(matches!(rewritten, LogicalPlan::TableScan(_)));
1029+
1030+
Ok(())
1031+
}
1032+
1033+
#[tokio::test]
1034+
async fn sum_aggregation_rewrite() -> Result<()> {
1035+
let optimizer = test_optimizer().await?;
1036+
1037+
// Build a sum index
1038+
optimizer
1039+
.build_index(IndexBuilder::with_col_and_aggregate(
1040+
"agg_col",
1041+
AggregateType::Sum,
1042+
))
1043+
.await?;
1044+
1045+
let temporal_filter = col("timestamp")
1046+
.gt_eq(lit("2024-05-10T00:00:00Z"))
1047+
.and(col("timestamp").lt(lit("2024-05-10T00:00:10Z")));
1048+
1049+
let plan =
1050+
LogicalPlanBuilder::scan("test", provider_as_source(optimizer.provider()), None)?
1051+
.filter(temporal_filter)?
1052+
.aggregate(Vec::<Expr>::new(), vec![sum(col("agg_col"))])?
1053+
.project(vec![sum(col("agg_col"))])?
1054+
.build()?;
1055+
1056+
// Assert that the original plan is a Projection
1057+
assert!(matches!(plan, LogicalPlan::Projection(_)));
1058+
1059+
let rewritten = optimizer.try_rewrite(&plan).unwrap();
1060+
// assert it was rewritten to a TableScan
1061+
assert!(matches!(rewritten, LogicalPlan::TableScan(_)));
1062+
1063+
Ok(())
1064+
}
1065+
1066+
#[tokio::test]
1067+
async fn min_aggregation_rewrite() -> Result<()> {
1068+
let optimizer = test_optimizer().await?;
1069+
1070+
// Build a min index
1071+
optimizer
1072+
.build_index(IndexBuilder::with_col_and_aggregate(
1073+
"agg_col",
1074+
AggregateType::Min,
1075+
))
1076+
.await?;
1077+
1078+
let temporal_filter = col("timestamp")
1079+
.gt_eq(lit("2024-05-10T00:00:00Z"))
1080+
.and(col("timestamp").lt(lit("2024-05-10T00:00:10Z")));
1081+
1082+
let plan =
1083+
LogicalPlanBuilder::scan("test", provider_as_source(optimizer.provider()), None)?
1084+
.filter(temporal_filter)?
1085+
.aggregate(Vec::<Expr>::new(), vec![min(col("agg_col"))])?
1086+
.project(vec![min(col("agg_col"))])?
1087+
.build()?;
1088+
1089+
// Assert that the original plan is a Projection
1090+
assert!(matches!(plan, LogicalPlan::Projection(_)));
1091+
1092+
let rewritten = optimizer.try_rewrite(&plan).unwrap();
1093+
// assert it was rewritten to a TableScan
1094+
assert!(matches!(rewritten, LogicalPlan::TableScan(_)));
1095+
1096+
Ok(())
1097+
}
1098+
1099+
#[tokio::test]
1100+
async fn max_aggregation_rewrite() -> Result<()> {
1101+
let optimizer = test_optimizer().await?;
1102+
1103+
optimizer
1104+
.build_index(IndexBuilder::with_col_and_aggregate(
1105+
"agg_col",
1106+
AggregateType::Max,
1107+
))
1108+
.await?;
1109+
1110+
let temporal_filter = col("timestamp")
1111+
.gt_eq(lit("2024-05-10T00:00:00Z"))
1112+
.and(col("timestamp").lt(lit("2024-05-10T00:00:10Z")));
1113+
1114+
let plan =
1115+
LogicalPlanBuilder::scan("test", provider_as_source(optimizer.provider()), None)?
1116+
.filter(temporal_filter)?
1117+
.aggregate(Vec::<Expr>::new(), vec![max(col("agg_col"))])?
1118+
.project(vec![max(col("agg_col"))])?
1119+
.build()?;
1120+
1121+
// Assert that the original plan is a Projection
1122+
assert!(matches!(plan, LogicalPlan::Projection(_)));
1123+
1124+
let rewritten = optimizer.try_rewrite(&plan).unwrap();
1125+
// assert it was rewritten to a TableScan
1126+
assert!(matches!(rewritten, LogicalPlan::TableScan(_)));
1127+
1128+
Ok(())
1129+
}
1130+
1131+
#[tokio::test]
1132+
async fn avg_aggregation_rewrite() -> Result<()> {
1133+
let optimizer = test_optimizer().await?;
1134+
1135+
optimizer
1136+
.build_index(IndexBuilder::with_col_and_aggregate(
1137+
"agg_col",
1138+
AggregateType::Avg,
1139+
))
1140+
.await?;
1141+
1142+
let temporal_filter = col("timestamp")
1143+
.gt_eq(lit("2024-05-10T00:00:00Z"))
1144+
.and(col("timestamp").lt(lit("2024-05-10T00:00:10Z")));
1145+
1146+
let plan =
1147+
LogicalPlanBuilder::scan("test", provider_as_source(optimizer.provider()), None)?
1148+
.filter(temporal_filter)?
1149+
.aggregate(Vec::<Expr>::new(), vec![avg(col("agg_col"))])?
1150+
.project(vec![avg(col("agg_col"))])?
1151+
.build()?;
1152+
1153+
// Assert that the original plan is a Projection
1154+
assert!(matches!(plan, LogicalPlan::Projection(_)));
1155+
1156+
let rewritten = optimizer.try_rewrite(&plan).unwrap();
1157+
// assert it was rewritten to a TableScan
1158+
assert!(matches!(rewritten, LogicalPlan::TableScan(_)));
1159+
1160+
Ok(())
1161+
}
1162+
1163+
#[tokio::test]
1164+
async fn count_star_aggregation_invalid_rewrite() -> Result<()> {
1165+
let optimizer = test_optimizer().await?;
1166+
// invalid temporal filter
1167+
let temporal_filter = col("timestamp")
1168+
.gt_eq(lit("2024-05-11T00:00:00Z"))
1169+
.and(col("timestamp").lt(lit("2024-05-11T00:00:10Z")));
1170+
1171+
let plan =
1172+
LogicalPlanBuilder::scan("test", provider_as_source(optimizer.provider()), None)?
1173+
.filter(temporal_filter)?
1174+
.aggregate(Vec::<Expr>::new(), vec![count(wildcard())])?
1175+
.project(vec![count(wildcard())])?
1176+
.build()?;
1177+
1178+
assert!(optimizer.try_rewrite(&plan).is_none());
1179+
1180+
Ok(())
1181+
}
1182+
1183+
#[tokio::test]
1184+
async fn count_star_aggregation_exec() -> Result<()> {
1185+
let optimizer = test_optimizer().await?;
1186+
let temporal_filter = col("timestamp")
1187+
.gt_eq(lit("2024-05-10T00:00:00Z"))
1188+
.and(col("timestamp").lt(lit("2024-05-10T00:00:10Z")));
1189+
1190+
let plan =
1191+
LogicalPlanBuilder::scan("test", provider_as_source(optimizer.provider()), None)?
1192+
.filter(temporal_filter)?
1193+
.aggregate(Vec::<Expr>::new(), vec![count(wildcard())])?
1194+
.project(vec![count(wildcard())])?
1195+
.build()?;
1196+
1197+
let ctx = SessionContext::new();
1198+
ctx.register_table("test", optimizer.provider().clone())?;
1199+
1200+
// Set UWheelOptimizer as optimizer rule
1201+
let session_state = ctx.state().with_optimizer_rules(vec![optimizer.clone()]);
1202+
let uwheel_ctx = SessionContext::new_with_state(session_state);
1203+
1204+
// Run the query through the ctx that has our OptimizerRule
1205+
let df = uwheel_ctx.execute_logical_plan(plan).await?;
1206+
let results = df.collect().await?;
1207+
1208+
assert_eq!(results.len(), 1);
1209+
assert_eq!(
1210+
results[0]
1211+
.column(0)
1212+
.as_any()
1213+
.downcast_ref::<Int64Array>()
1214+
.unwrap()
1215+
.value(0),
1216+
10,
1217+
);
1218+
1219+
Ok(())
1220+
}
1221+
1222+
#[tokio::test]
1223+
async fn sum_aggregation_exec() -> Result<()> {
1224+
let optimizer = test_optimizer().await?;
1225+
optimizer
1226+
.build_index(IndexBuilder::with_col_and_aggregate(
1227+
"agg_col",
1228+
AggregateType::Sum,
1229+
))
1230+
.await?;
1231+
1232+
let temporal_filter = col("timestamp")
1233+
.gt_eq(lit("2024-05-10T00:00:00Z"))
1234+
.and(col("timestamp").lt(lit("2024-05-10T00:00:10Z")));
1235+
1236+
let plan =
1237+
LogicalPlanBuilder::scan("test", provider_as_source(optimizer.provider()), None)?
1238+
.filter(temporal_filter)?
1239+
.aggregate(Vec::<Expr>::new(), vec![sum(col("agg_col"))])?
1240+
.project(vec![sum(col("agg_col"))])?
1241+
.build()?;
1242+
1243+
let ctx = SessionContext::new();
1244+
ctx.register_table("test", optimizer.provider().clone())?;
1245+
1246+
// Set UWheelOptimizer as optimizer rule
1247+
let session_state = ctx.state().with_optimizer_rules(vec![optimizer.clone()]);
1248+
let uwheel_ctx = SessionContext::new_with_state(session_state);
1249+
1250+
// Run the query through the ctx that has our OptimizerRule
1251+
let df = uwheel_ctx.execute_logical_plan(plan).await?;
1252+
let results = df.collect().await?;
1253+
1254+
assert_eq!(results.len(), 1);
1255+
assert_eq!(
1256+
results[0]
1257+
.column(0)
1258+
.as_any()
1259+
.downcast_ref::<Float64Array>()
1260+
.unwrap()
1261+
.value(0),
1262+
55.0 // 1 + 2 +3 + ... + 9 + 10 = 55
1263+
);
1264+
1265+
Ok(())
1266+
}
1267+
1268+
#[tokio::test]
1269+
async fn min_aggregation_exec() -> Result<()> {
1270+
let optimizer = test_optimizer().await?;
1271+
optimizer
1272+
.build_index(IndexBuilder::with_col_and_aggregate(
1273+
"agg_col",
1274+
AggregateType::Min,
1275+
))
1276+
.await?;
1277+
1278+
let temporal_filter = col("timestamp")
1279+
.gt_eq(lit("2024-05-10T00:00:00Z"))
1280+
.and(col("timestamp").lt(lit("2024-05-10T00:00:10Z")));
1281+
1282+
let plan =
1283+
LogicalPlanBuilder::scan("test", provider_as_source(optimizer.provider()), None)?
1284+
.filter(temporal_filter)?
1285+
.aggregate(Vec::<Expr>::new(), vec![min(col("agg_col"))])?
1286+
.project(vec![min(col("agg_col"))])?
1287+
.build()?;
1288+
1289+
let ctx = SessionContext::new();
1290+
ctx.register_table("test", optimizer.provider().clone())?;
1291+
1292+
// Set UWheelOptimizer as optimizer rule
1293+
let session_state = ctx.state().with_optimizer_rules(vec![optimizer.clone()]);
1294+
let uwheel_ctx = SessionContext::new_with_state(session_state);
1295+
1296+
// Run the query through the ctx that has our OptimizerRule
1297+
let df = uwheel_ctx.execute_logical_plan(plan).await?;
1298+
let results = df.collect().await?;
1299+
1300+
assert_eq!(results.len(), 1);
1301+
assert_eq!(
1302+
results[0]
1303+
.column(0)
1304+
.as_any()
1305+
.downcast_ref::<Float64Array>()
1306+
.unwrap()
1307+
.value(0),
1308+
1.0
1309+
);
1310+
1311+
Ok(())
1312+
}
1313+
1314+
#[tokio::test]
1315+
async fn max_aggregation_exec() -> Result<()> {
1316+
let optimizer = test_optimizer().await?;
1317+
optimizer
1318+
.build_index(IndexBuilder::with_col_and_aggregate(
1319+
"agg_col",
1320+
AggregateType::Max,
1321+
))
1322+
.await?;
1323+
1324+
let temporal_filter = col("timestamp")
1325+
.gt_eq(lit("2024-05-10T00:00:00Z"))
1326+
.and(col("timestamp").lt(lit("2024-05-10T00:00:10Z")));
1327+
1328+
let plan =
1329+
LogicalPlanBuilder::scan("test", provider_as_source(optimizer.provider()), None)?
1330+
.filter(temporal_filter)?
1331+
.aggregate(Vec::<Expr>::new(), vec![max(col("agg_col"))])?
1332+
.project(vec![max(col("agg_col"))])?
1333+
.build()?;
1334+
1335+
let ctx = SessionContext::new();
1336+
ctx.register_table("test", optimizer.provider().clone())?;
1337+
1338+
// Set UWheelOptimizer as optimizer rule
1339+
let session_state = ctx.state().with_optimizer_rules(vec![optimizer.clone()]);
1340+
let uwheel_ctx = SessionContext::new_with_state(session_state);
1341+
1342+
// Run the query through the ctx that has our OptimizerRule
1343+
let df = uwheel_ctx.execute_logical_plan(plan).await?;
1344+
let results = df.collect().await?;
1345+
1346+
assert_eq!(results.len(), 1);
1347+
assert_eq!(
1348+
results[0]
1349+
.column(0)
1350+
.as_any()
1351+
.downcast_ref::<Float64Array>()
1352+
.unwrap()
1353+
.value(0),
1354+
10.0
1355+
);
1356+
1357+
Ok(())
1358+
}
1359+
1360+
#[tokio::test]
1361+
async fn avg_aggregation_exec() -> Result<()> {
1362+
let optimizer = test_optimizer().await?;
1363+
optimizer
1364+
.build_index(IndexBuilder::with_col_and_aggregate(
1365+
"agg_col",
1366+
AggregateType::Avg,
1367+
))
1368+
.await?;
1369+
1370+
let temporal_filter = col("timestamp")
1371+
.gt_eq(lit("2024-05-10T00:00:00Z"))
1372+
.and(col("timestamp").lt(lit("2024-05-10T00:00:10Z")));
1373+
1374+
let plan =
1375+
LogicalPlanBuilder::scan("test", provider_as_source(optimizer.provider()), None)?
1376+
.filter(temporal_filter)?
1377+
.aggregate(Vec::<Expr>::new(), vec![avg(col("agg_col"))])?
1378+
.project(vec![avg(col("agg_col"))])?
1379+
.build()?;
1380+
1381+
let ctx = SessionContext::new();
1382+
ctx.register_table("test", optimizer.provider().clone())?;
1383+
1384+
// Set UWheelOptimizer as optimizer rule
1385+
let session_state = ctx.state().with_optimizer_rules(vec![optimizer.clone()]);
1386+
let uwheel_ctx = SessionContext::new_with_state(session_state);
1387+
1388+
// Run the query through the ctx that has our OptimizerRule
1389+
let df = uwheel_ctx.execute_logical_plan(plan).await?;
1390+
let results = df.collect().await?;
1391+
1392+
assert_eq!(results.len(), 1);
1393+
assert_eq!(
1394+
results[0]
1395+
.column(0)
1396+
.as_any()
1397+
.downcast_ref::<Float64Array>()
1398+
.unwrap()
1399+
.value(0),
1400+
5.5,
1401+
);
1402+
1403+
Ok(())
1404+
}
1405+
}

0 commit comments

Comments
 (0)
Please sign in to comment.