Skip to content

Commit f8e988f

Browse files
codetyri0n and sriram authored
Feat: [datafusion-spark] Migrate avg from comet to datafusion-spark and add tests. (#17871)
* Chore: Migrate avg from comet to datafusion-spark and add a few tests. * CI Fix: Apply cargo format. * CI Fix: Add coerce types function. * Add group by tests to the suite. * Add doc highlighting differences with Spark. * CI assertion error fixes and improved docs. --------- Co-authored-by: sriram <[email protected]>
1 parent 43dafd6 commit f8e988f

File tree

3 files changed

+421
-2
lines changed
  • datafusion

3 files changed

+421
-2
lines changed
Lines changed: 350 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,350 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::ArrowNativeTypeOp;
19+
use arrow::array::{
20+
builder::PrimitiveBuilder,
21+
cast::AsArray,
22+
types::{Float64Type, Int64Type},
23+
Array, ArrayRef, ArrowNumericType, Int64Array, PrimitiveArray,
24+
};
25+
use arrow::compute::sum;
26+
use arrow::datatypes::{DataType, Field, FieldRef};
27+
use datafusion_common::utils::take_function_args;
28+
use datafusion_common::{not_impl_err, Result, ScalarValue};
29+
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
30+
use datafusion_expr::type_coercion::aggregates::coerce_avg_type;
31+
use datafusion_expr::utils::format_state_name;
32+
use datafusion_expr::Volatility::Immutable;
33+
use datafusion_expr::{
34+
type_coercion::aggregates::avg_return_type, Accumulator, AggregateUDFImpl, EmitTo,
35+
GroupsAccumulator, ReversedUDAF, Signature,
36+
};
37+
use std::{any::Any, sync::Arc};
38+
use DataType::*;
39+
40+
/// AVG aggregate expression
41+
/// Spark average aggregate expression. Differs from standard DataFusion average aggregate
42+
/// in that it uses an `i64` for the count (DataFusion version uses `u64`); also there is ANSI mode
43+
/// support planned in the future for Spark version.
44+
45+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
46+
pub struct SparkAvg {
47+
name: String,
48+
signature: Signature,
49+
input_data_type: DataType,
50+
result_data_type: DataType,
51+
}
52+
53+
impl SparkAvg {
54+
/// Implement AVG aggregate function
55+
pub fn new(name: impl Into<String>, data_type: DataType) -> Self {
56+
let result_data_type = avg_return_type("avg", &data_type).unwrap();
57+
58+
Self {
59+
name: name.into(),
60+
signature: Signature::user_defined(Immutable),
61+
input_data_type: data_type,
62+
result_data_type,
63+
}
64+
}
65+
}
66+
67+
impl AggregateUDFImpl for SparkAvg {
68+
fn as_any(&self) -> &dyn Any {
69+
self
70+
}
71+
72+
fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
73+
// instantiate specialized accumulator based for the type
74+
match (&self.input_data_type, &self.result_data_type) {
75+
(Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
76+
_ => not_impl_err!(
77+
"AvgAccumulator for ({} --> {})",
78+
self.input_data_type,
79+
self.result_data_type
80+
),
81+
}
82+
}
83+
84+
fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
85+
Ok(vec![
86+
Arc::new(Field::new(
87+
format_state_name(&self.name, "sum"),
88+
self.input_data_type.clone(),
89+
true,
90+
)),
91+
Arc::new(Field::new(
92+
format_state_name(&self.name, "count"),
93+
Int64,
94+
true,
95+
)),
96+
])
97+
}
98+
99+
fn name(&self) -> &str {
100+
&self.name
101+
}
102+
103+
fn reverse_expr(&self) -> ReversedUDAF {
104+
ReversedUDAF::Identical
105+
}
106+
107+
fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
108+
true
109+
}
110+
111+
fn create_groups_accumulator(
112+
&self,
113+
_args: AccumulatorArgs,
114+
) -> Result<Box<dyn GroupsAccumulator>> {
115+
// instantiate specialized accumulator based for the type
116+
match (&self.input_data_type, &self.result_data_type) {
117+
(Float64, Float64) => {
118+
Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
119+
&self.input_data_type,
120+
|sum: f64, count: i64| Ok(sum / count as f64),
121+
)))
122+
}
123+
124+
_ => not_impl_err!(
125+
"AvgGroupsAccumulator for ({} --> {})",
126+
self.input_data_type,
127+
self.result_data_type
128+
),
129+
}
130+
}
131+
132+
fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
133+
Ok(ScalarValue::Float64(None))
134+
}
135+
136+
fn signature(&self) -> &Signature {
137+
&self.signature
138+
}
139+
140+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
141+
avg_return_type(self.name(), &arg_types[0])
142+
}
143+
144+
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
145+
let [arg] = take_function_args(self.name(), arg_types)?;
146+
coerce_avg_type(self.name(), std::slice::from_ref(arg))
147+
}
148+
}
149+
150+
/// An accumulator to compute the average
151+
#[derive(Debug, Default)]
152+
pub struct AvgAccumulator {
153+
sum: Option<f64>,
154+
count: i64,
155+
}
156+
157+
impl Accumulator for AvgAccumulator {
158+
fn state(&mut self) -> Result<Vec<ScalarValue>> {
159+
Ok(vec![
160+
ScalarValue::Float64(self.sum),
161+
ScalarValue::from(self.count),
162+
])
163+
}
164+
165+
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
166+
let values = values[0].as_primitive::<Float64Type>();
167+
self.count += (values.len() - values.null_count()) as i64;
168+
let v = self.sum.get_or_insert(0.);
169+
if let Some(x) = sum(values) {
170+
*v += x;
171+
}
172+
Ok(())
173+
}
174+
175+
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
176+
// counts are summed
177+
self.count += sum(states[1].as_primitive::<Int64Type>()).unwrap_or_default();
178+
179+
// sums are summed
180+
if let Some(x) = sum(states[0].as_primitive::<Float64Type>()) {
181+
let v = self.sum.get_or_insert(0.);
182+
*v += x;
183+
}
184+
Ok(())
185+
}
186+
187+
fn evaluate(&mut self) -> Result<ScalarValue> {
188+
if self.count == 0 {
189+
// If all input are nulls, count will be 0 and we will get null after the division.
190+
// This is consistent with Spark Average implementation.
191+
Ok(ScalarValue::Float64(None))
192+
} else {
193+
Ok(ScalarValue::Float64(
194+
self.sum.map(|f| f / self.count as f64),
195+
))
196+
}
197+
}
198+
199+
fn size(&self) -> usize {
200+
size_of_val(self)
201+
}
202+
}
203+
204+
/// An accumulator to compute the average of `[PrimitiveArray<T>]`.
205+
/// Stores values as native types, and does overflow checking
206+
///
207+
/// F: Function that calculates the average value from a sum of
208+
/// T::Native and a total count
209+
#[derive(Debug)]
210+
struct AvgGroupsAccumulator<T, F>
211+
where
212+
T: ArrowNumericType + Send,
213+
F: Fn(T::Native, i64) -> Result<T::Native> + Send,
214+
{
215+
/// The type of the returned average
216+
return_data_type: DataType,
217+
218+
/// Count per group (use i64 to make Int64Array)
219+
counts: Vec<i64>,
220+
221+
/// Sums per group, stored as the native type
222+
sums: Vec<T::Native>,
223+
224+
/// Function that computes the final average (value / count)
225+
avg_fn: F,
226+
}
227+
228+
impl<T, F> AvgGroupsAccumulator<T, F>
229+
where
230+
T: ArrowNumericType + Send,
231+
F: Fn(T::Native, i64) -> Result<T::Native> + Send,
232+
{
233+
pub fn new(return_data_type: &DataType, avg_fn: F) -> Self {
234+
Self {
235+
return_data_type: return_data_type.clone(),
236+
counts: vec![],
237+
sums: vec![],
238+
avg_fn,
239+
}
240+
}
241+
}
242+
243+
impl<T, F> GroupsAccumulator for AvgGroupsAccumulator<T, F>
244+
where
245+
T: ArrowNumericType + Send,
246+
F: Fn(T::Native, i64) -> Result<T::Native> + Send,
247+
{
248+
fn update_batch(
249+
&mut self,
250+
values: &[ArrayRef],
251+
group_indices: &[usize],
252+
_opt_filter: Option<&arrow::array::BooleanArray>,
253+
total_num_groups: usize,
254+
) -> Result<()> {
255+
assert_eq!(values.len(), 1, "single argument to update_batch");
256+
let values = values[0].as_primitive::<T>();
257+
let data = values.values();
258+
259+
// increment counts, update sums
260+
self.counts.resize(total_num_groups, 0);
261+
self.sums.resize(total_num_groups, T::default_value());
262+
263+
let iter = group_indices.iter().zip(data.iter());
264+
if values.null_count() == 0 {
265+
for (&group_index, &value) in iter {
266+
let sum = &mut self.sums[group_index];
267+
*sum = (*sum).add_wrapping(value);
268+
self.counts[group_index] += 1;
269+
}
270+
} else {
271+
for (idx, (&group_index, &value)) in iter.enumerate() {
272+
if values.is_null(idx) {
273+
continue;
274+
}
275+
let sum = &mut self.sums[group_index];
276+
*sum = (*sum).add_wrapping(value);
277+
278+
self.counts[group_index] += 1;
279+
}
280+
}
281+
282+
Ok(())
283+
}
284+
285+
fn merge_batch(
286+
&mut self,
287+
values: &[ArrayRef],
288+
group_indices: &[usize],
289+
_opt_filter: Option<&arrow::array::BooleanArray>,
290+
total_num_groups: usize,
291+
) -> Result<()> {
292+
assert_eq!(values.len(), 2, "two arguments to merge_batch");
293+
// first batch is partial sums, second is counts
294+
let partial_sums = values[0].as_primitive::<T>();
295+
let partial_counts = values[1].as_primitive::<Int64Type>();
296+
// update counts with partial counts
297+
self.counts.resize(total_num_groups, 0);
298+
let iter1 = group_indices.iter().zip(partial_counts.values().iter());
299+
for (&group_index, &partial_count) in iter1 {
300+
self.counts[group_index] += partial_count;
301+
}
302+
303+
// update sums
304+
self.sums.resize(total_num_groups, T::default_value());
305+
let iter2 = group_indices.iter().zip(partial_sums.values().iter());
306+
for (&group_index, &new_value) in iter2 {
307+
let sum = &mut self.sums[group_index];
308+
*sum = sum.add_wrapping(new_value);
309+
}
310+
311+
Ok(())
312+
}
313+
314+
fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
315+
let counts = emit_to.take_needed(&mut self.counts);
316+
let sums = emit_to.take_needed(&mut self.sums);
317+
let mut builder = PrimitiveBuilder::<T>::with_capacity(sums.len());
318+
let iter = sums.into_iter().zip(counts);
319+
320+
for (sum, count) in iter {
321+
if count != 0 {
322+
builder.append_value((self.avg_fn)(sum, count)?)
323+
} else {
324+
builder.append_null();
325+
}
326+
}
327+
let array: PrimitiveArray<T> = builder.finish();
328+
329+
Ok(Arc::new(array))
330+
}
331+
332+
// return arrays for sums and counts
333+
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
334+
let counts = emit_to.take_needed(&mut self.counts);
335+
let counts = Int64Array::new(counts.into(), None);
336+
337+
let sums = emit_to.take_needed(&mut self.sums);
338+
let sums = PrimitiveArray::<T>::new(sums.into(), None)
339+
.with_data_type(self.return_data_type.clone());
340+
341+
Ok(vec![
342+
Arc::new(sums) as ArrayRef,
343+
Arc::new(counts) as ArrayRef,
344+
])
345+
}
346+
347+
fn size(&self) -> usize {
348+
self.counts.capacity() * size_of::<i64>() + self.sums.capacity() * size_of::<T>()
349+
}
350+
}

datafusion/spark/src/function/aggregate/mod.rs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,24 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use arrow::datatypes::DataType;
1819
use datafusion_expr::AggregateUDF;
1920
use std::sync::Arc;
2021

21-
pub mod expr_fn {}
22+
pub mod avg;
23+
/// Expression-level helpers for the aggregate functions of this module.
// NOTE(review): presumably `export_functions!` generates an `avg(arg1)` expression builder
// documented with the description string below — confirm against the macro definition.
pub mod expr_fn {
24+
use datafusion_functions::export_functions;
25+
26+
export_functions!((avg, "Returns the average value of a given column", arg1));
27+
}
28+
29+
pub fn avg() -> Arc<AggregateUDF> {
30+
Arc::new(AggregateUDF::new_from_impl(avg::SparkAvg::new(
31+
"avg",
32+
DataType::Float64,
33+
)))
34+
}
2235

2336
/// Returns the aggregate UDFs provided by this module; after this change the list holds
/// only the UDF built by `avg()` (it was empty before).
pub fn functions() -> Vec<Arc<AggregateUDF>> {
24-
vec![]
37+
vec![avg()]
2538
}

0 commit comments

Comments
 (0)