Skip to content

Commit 2d9a5bf

Browse files
authored
Merge pull request #26 from LYZJU2019/group-by-multiple-agg
Support group by & multiple aggregations
2 parents 860a735 + e802b9b commit 2d9a5bf

File tree

1 file changed

+297
-43
lines changed

1 file changed

+297
-43
lines changed

datafusion-uwheel/src/lib.rs

+297-43
Original file line numberDiff line numberDiff line change
@@ -337,53 +337,165 @@ impl UWheelOptimizer {
337337
return None;
338338
};
339339

340-
let (wheel_range, _) = extract_filter_expr(&filter.predicate, &self.time_column)?;
340+
let (wheel_range, expr_key) =
341+
match extract_filter_expr(&filter.predicate, &self.time_column)? {
342+
(range, Some(expr)) => (range, maybe_replace_table_name(&expr, &self.name)),
343+
(range, None) => (range, STAR_AGGREGATION_ALIAS.to_string()),
344+
};
341345

342346
match group_expr {
343347
Expr::ScalarFunction(func) if func.name() == "date_trunc" => {
344348
let interval = func.args.first()?;
345349
if let Expr::Literal(ScalarValue::Utf8(duration)) = interval {
346-
match duration.as_ref()?.as_str() {
347-
"second" => {
348-
unimplemented!("date_trunc('second') group by is not supported")
349-
}
350-
"minute" => {
351-
unimplemented!("date_trunc('minute') group by is not supported")
352-
}
353-
"hour" => {
354-
unimplemented!("date_trunc('hour') group by is not supported")
355-
}
356-
"day" => {
357-
let res = self
358-
.wheels
359-
.count
360-
.group_by(wheel_range, Duration::DAY)
361-
.unwrap_or_default()
362-
.iter()
363-
.map(|(k, v)| ((*k * 1_000) as i64, *v as i64)) // transform milliseconds to microseconds by multiplying by 1_000
364-
.collect();
365-
366-
let schema = Arc::new(plan.schema().clone().as_arrow().clone());
367-
368-
return uwheel_group_by_to_table_scan(res, schema).ok();
369-
}
370-
"week" => {
371-
unimplemented!("date_trunc('week') group by is not supported")
372-
}
373-
"month" => {
374-
unimplemented!("date_trunc('month') group by is not supported")
375-
}
376-
"year" => {
377-
unimplemented!("date_trunc('year') group by is not supported")
350+
let group_by_interval = match duration.as_ref()?.as_str() {
351+
"second" => Duration::SECOND,
352+
"minute" => Duration::MINUTE,
353+
"hour" => Duration::HOUR,
354+
"day" => Duration::DAY,
355+
"week" => Duration::WEEK,
356+
_ => return None,
357+
};
358+
359+
let mut group_agg_result = Vec::new();
360+
let mut group_col = None;
361+
362+
let mut count_idx = None;
363+
364+
for (idx, agg) in agg.aggr_expr.iter().enumerate() {
365+
let (agg_type, col) = match agg {
366+
// COUNT(*)
367+
Expr::AggregateFunction(agg)
368+
if is_count_star_aggregate(agg) =>
369+
{
370+
(UWheelAggregate::Count, None)
371+
}
372+
373+
// COUNT(*)
374+
Expr::Alias(alias) if alias.name == COUNT_STAR_ALIAS => {
375+
(UWheelAggregate::Count, None)
376+
}
377+
378+
Expr::AggregateFunction(agg) => {
379+
if agg.args.len() > 1 {
380+
return None;
381+
}
382+
let col = match &agg.args[0] {
383+
Expr::Column(col) => col,
384+
_ => return None,
385+
};
386+
(func_def_to_aggregate_type(&agg.func)?, Some(col))
387+
}
388+
389+
_ => return None,
390+
};
391+
392+
let res = match agg_type {
393+
UWheelAggregate::Count => {
394+
count_idx = Some(idx);
395+
self.wheels
396+
.count
397+
.group_by(wheel_range, group_by_interval)
398+
.unwrap_or_default()
399+
.iter()
400+
.map(|(k, v)| (*k, *v as f64))
401+
.collect()
402+
}
403+
404+
UWheelAggregate::Avg => {
405+
let wheel_key = format!(
406+
"{}.{}.{}",
407+
self.name,
408+
col.unwrap().name,
409+
expr_key
410+
);
411+
self.wheels
412+
.avg
413+
.lock()
414+
.unwrap()
415+
.get(&wheel_key)?
416+
.group_by(wheel_range, group_by_interval)
417+
.unwrap_or_default()
418+
}
419+
420+
UWheelAggregate::Min => {
421+
let wheel_key = format!(
422+
"{}.{}.{}",
423+
self.name,
424+
col.unwrap().name,
425+
expr_key
426+
);
427+
self.wheels
428+
.min
429+
.lock()
430+
.unwrap()
431+
.get(&wheel_key)?
432+
.group_by(wheel_range, group_by_interval)
433+
.unwrap_or_default()
434+
}
435+
436+
UWheelAggregate::Max => {
437+
let wheel_key = format!(
438+
"{}.{}.{}",
439+
self.name,
440+
col.unwrap().name,
441+
expr_key
442+
);
443+
self.wheels
444+
.max
445+
.lock()
446+
.unwrap()
447+
.get(&wheel_key)?
448+
.group_by(wheel_range, group_by_interval)
449+
.unwrap_or_default()
450+
}
451+
452+
UWheelAggregate::Sum => {
453+
let wheel_key = format!(
454+
"{}.{}.{}",
455+
self.name,
456+
col.unwrap().name,
457+
expr_key
458+
);
459+
self.wheels
460+
.sum
461+
.lock()
462+
.unwrap()
463+
.get(&wheel_key)?
464+
.group_by(wheel_range, group_by_interval)
465+
.unwrap_or_default()
466+
}
467+
468+
_ => return None,
469+
};
470+
471+
if group_col.is_none() {
472+
group_col = Some(
473+
res.iter()
474+
.map(|(k, _)| (*k * 1_000) as i64)
475+
.collect::<Vec<_>>(),
476+
);
378477
}
379-
_ => {}
478+
479+
group_agg_result
480+
.push(res.iter().map(|(_, v)| *v).collect::<Vec<_>>());
380481
}
482+
483+
group_col.as_ref()?;
484+
485+
let schema = Arc::new(plan.schema().clone().as_arrow().clone());
486+
487+
return uwheel_group_by_to_table_scan(
488+
group_col.unwrap(),
489+
group_agg_result,
490+
count_idx,
491+
schema,
492+
)
493+
.ok();
381494
}
382495
}
383-
_ => {
384-
unimplemented!("We only support scalar function date_trunc for group by expression now")
385-
}
496+
_ => return None,
386497
}
498+
387499
None
388500
}
389501

@@ -622,16 +734,26 @@ fn uwheel_agg_to_table_scan(result: f64, schema: SchemaRef) -> Result<LogicalPla
622734
// Converts a uwheel group by result to a TableScan with a MemTable as source
623735
// currently only supports timestamp group by
624736
fn uwheel_group_by_to_table_scan(
625-
result: Vec<(i64, i64)>,
737+
group_col: Vec<i64>,
738+
agg_result: Vec<Vec<f64>>,
739+
count_idx: Option<usize>,
626740
schema: SchemaRef,
627741
) -> Result<LogicalPlan> {
628-
let group_by =
629-
TimestampMicrosecondArray::from(result.iter().map(|(k, _)| *k).collect::<Vec<_>>());
742+
let group_by = TimestampMicrosecondArray::from(group_col);
630743

631-
let agg = Int64Array::from(result.iter().map(|(_, v)| *v).collect::<Vec<_>>());
744+
let mut columns = vec![Arc::new(group_by) as Arc<dyn Array>];
745+
746+
for (idx, result) in agg_result.into_iter().enumerate() {
747+
if count_idx.is_some() && idx == count_idx.unwrap() {
748+
let data = Int64Array::from(result.iter().map(|v| *v as i64).collect::<Vec<_>>());
749+
columns.push(Arc::new(data) as Arc<dyn Array>);
750+
} else {
751+
let data = Float64Array::from(result);
752+
columns.push(Arc::new(data) as Arc<dyn Array>);
753+
}
754+
}
632755

633-
let record_batch =
634-
RecordBatch::try_new(schema.clone(), vec![Arc::new(group_by), Arc::new(agg)])?;
756+
let record_batch = RecordBatch::try_new(schema.clone(), columns)?;
635757

636758
let df_schema = Arc::new(DFSchema::try_from(schema.clone())?);
637759
let mem_table = MemTable::try_new(schema, vec![vec![record_batch]])?;
@@ -1898,4 +2020,136 @@ mod tests {
18982020

18992021
Ok(())
19002022
}
2023+
2024+
#[tokio::test]
2025+
async fn group_by_multiple_aggregation_rewrite() -> Result<()> {
2026+
let optimizer = test_optimizer().await?;
2027+
2028+
optimizer
2029+
.build_index(IndexBuilder::with_col_and_aggregate(
2030+
"agg_col",
2031+
UWheelAggregate::Avg,
2032+
))
2033+
.await?;
2034+
2035+
optimizer
2036+
.build_index(IndexBuilder::with_col_and_aggregate(
2037+
"agg_col",
2038+
UWheelAggregate::Sum,
2039+
))
2040+
.await?;
2041+
2042+
let temporal_filter = col("timestamp")
2043+
.gt_eq(lit("2024-05-10T00:00:00Z"))
2044+
.and(col("timestamp").lt(lit("2024-05-10T00:00:10Z")));
2045+
2046+
let plan =
2047+
LogicalPlanBuilder::scan("test", provider_as_source(optimizer.provider()), None)?
2048+
.filter(temporal_filter)?
2049+
.aggregate(
2050+
vec![date_trunc(lit("day"), col("timestamp"))], // GROUP BY date_trunc('day', timestamp)
2051+
vec![sum(col("agg_col")), avg(col("agg_col")), count(wildcard())],
2052+
)?
2053+
.project(vec![
2054+
date_trunc(lit("day"), col("timestamp")),
2055+
sum(col("agg_col")),
2056+
avg(col("agg_col")),
2057+
count(wildcard()),
2058+
])?
2059+
.build()?;
2060+
2061+
// Assert that the original plan is a Projection
2062+
assert!(matches!(plan, LogicalPlan::Projection(_)));
2063+
2064+
let rewritten = optimizer.try_rewrite(&plan).unwrap();
2065+
// assert it was rewritten to a TableScan
2066+
assert!(matches!(rewritten, LogicalPlan::TableScan(_)));
2067+
2068+
Ok(())
2069+
}
2070+
2071+
#[tokio::test]
2072+
async fn group_by_multiple_aggregation_exec() -> Result<()> {
2073+
let optimizer = test_optimizer().await?;
2074+
2075+
optimizer
2076+
.build_index(IndexBuilder::with_col_and_aggregate(
2077+
"agg_col",
2078+
UWheelAggregate::Avg,
2079+
))
2080+
.await?;
2081+
2082+
optimizer
2083+
.build_index(IndexBuilder::with_col_and_aggregate(
2084+
"agg_col",
2085+
UWheelAggregate::Sum,
2086+
))
2087+
.await?;
2088+
2089+
let temporal_filter = col("timestamp")
2090+
.gt_eq(lit("2024-05-10T00:00:00Z"))
2091+
.and(col("timestamp").lt(lit("2024-05-10T00:00:10Z")));
2092+
2093+
let plan =
2094+
LogicalPlanBuilder::scan("test", provider_as_source(optimizer.provider()), None)?
2095+
.filter(temporal_filter)?
2096+
.aggregate(
2097+
vec![date_trunc(lit("day"), col("timestamp"))], // GROUP BY date_trunc('day', timestamp)
2098+
vec![sum(col("agg_col")), avg(col("agg_col")), count(wildcard())],
2099+
)?
2100+
.project(vec![
2101+
date_trunc(lit("day"), col("timestamp")),
2102+
sum(col("agg_col")),
2103+
avg(col("agg_col")),
2104+
count(wildcard()),
2105+
])?
2106+
.build()?;
2107+
2108+
let ctx = SessionContext::new();
2109+
ctx.register_table("test", optimizer.provider().clone())?;
2110+
2111+
// Set UWheelOptimizer as optimizer rule
2112+
let session_state = SessionStateBuilder::new()
2113+
.with_optimizer_rules(vec![optimizer.clone()])
2114+
.build();
2115+
let uwheel_ctx = SessionContext::new_with_state(session_state);
2116+
2117+
// Run the query through the ctx that has our OptimizerRule
2118+
let df = uwheel_ctx.execute_logical_plan(plan).await?;
2119+
let results = df.collect().await?;
2120+
2121+
assert_eq!(results.len(), 1);
2122+
2123+
assert_eq!(
2124+
results[0]
2125+
.column(1)
2126+
.as_any()
2127+
.downcast_ref::<Float64Array>()
2128+
.unwrap()
2129+
.value(0),
2130+
55.0
2131+
);
2132+
2133+
assert_eq!(
2134+
results[0]
2135+
.column(2)
2136+
.as_any()
2137+
.downcast_ref::<Float64Array>()
2138+
.unwrap()
2139+
.value(0),
2140+
5.5
2141+
);
2142+
2143+
assert_eq!(
2144+
results[0]
2145+
.column(3)
2146+
.as_any()
2147+
.downcast_ref::<Int64Array>()
2148+
.unwrap()
2149+
.value(0),
2150+
10
2151+
);
2152+
2153+
Ok(())
2154+
}
19012155
}

0 commit comments

Comments
 (0)