| 
 | 1 | +// Licensed to the Apache Software Foundation (ASF) under one  | 
 | 2 | +// or more contributor license agreements.  See the NOTICE file  | 
 | 3 | +// distributed with this work for additional information  | 
 | 4 | +// regarding copyright ownership.  The ASF licenses this file  | 
 | 5 | +// to you under the Apache License, Version 2.0 (the  | 
 | 6 | +// "License"); you may not use this file except in compliance  | 
 | 7 | +// with the License.  You may obtain a copy of the License at  | 
 | 8 | +//  | 
 | 9 | +//   http://www.apache.org/licenses/LICENSE-2.0  | 
 | 10 | +//  | 
 | 11 | +// Unless required by applicable law or agreed to in writing,  | 
 | 12 | +// software distributed under the License is distributed on an  | 
 | 13 | +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY  | 
 | 14 | +// KIND, either express or implied.  See the License for the  | 
 | 15 | +// specific language governing permissions and limitations  | 
 | 16 | +// under the License.  | 
 | 17 | + | 
 | 18 | +use arrow::array::ArrowNativeTypeOp;  | 
 | 19 | +use arrow::array::{  | 
 | 20 | +    builder::PrimitiveBuilder,  | 
 | 21 | +    cast::AsArray,  | 
 | 22 | +    types::{Float64Type, Int64Type},  | 
 | 23 | +    Array, ArrayRef, ArrowNumericType, Int64Array, PrimitiveArray,  | 
 | 24 | +};  | 
 | 25 | +use arrow::compute::sum;  | 
 | 26 | +use arrow::datatypes::{DataType, Field, FieldRef};  | 
 | 27 | +use datafusion_common::utils::take_function_args;  | 
 | 28 | +use datafusion_common::{not_impl_err, Result, ScalarValue};  | 
 | 29 | +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};  | 
 | 30 | +use datafusion_expr::type_coercion::aggregates::coerce_avg_type;  | 
 | 31 | +use datafusion_expr::utils::format_state_name;  | 
 | 32 | +use datafusion_expr::Volatility::Immutable;  | 
 | 33 | +use datafusion_expr::{  | 
 | 34 | +    type_coercion::aggregates::avg_return_type, Accumulator, AggregateUDFImpl, EmitTo,  | 
 | 35 | +    GroupsAccumulator, ReversedUDAF, Signature,  | 
 | 36 | +};  | 
 | 37 | +use std::{any::Any, sync::Arc};  | 
 | 38 | +use DataType::*;  | 
 | 39 | + | 
 | 40 | +/// AVG aggregate expression  | 
 | 41 | +/// Spark average aggregate expression. Differs from standard DataFusion average aggregate  | 
 | 42 | +/// in that it uses an `i64` for the count (DataFusion version uses `u64`); also there is ANSI mode  | 
 | 43 | +/// support planned in the future for Spark version.  | 
 | 44 | +
  | 
 | 45 | +#[derive(Debug, Clone, PartialEq, Eq, Hash)]  | 
 | 46 | +pub struct SparkAvg {  | 
 | 47 | +    name: String,  | 
 | 48 | +    signature: Signature,  | 
 | 49 | +    input_data_type: DataType,  | 
 | 50 | +    result_data_type: DataType,  | 
 | 51 | +}  | 
 | 52 | + | 
 | 53 | +impl SparkAvg {  | 
 | 54 | +    /// Implement AVG aggregate function  | 
 | 55 | +    pub fn new(name: impl Into<String>, data_type: DataType) -> Self {  | 
 | 56 | +        let result_data_type = avg_return_type("avg", &data_type).unwrap();  | 
 | 57 | + | 
 | 58 | +        Self {  | 
 | 59 | +            name: name.into(),  | 
 | 60 | +            signature: Signature::user_defined(Immutable),  | 
 | 61 | +            input_data_type: data_type,  | 
 | 62 | +            result_data_type,  | 
 | 63 | +        }  | 
 | 64 | +    }  | 
 | 65 | +}  | 
 | 66 | + | 
 | 67 | +impl AggregateUDFImpl for SparkAvg {  | 
 | 68 | +    fn as_any(&self) -> &dyn Any {  | 
 | 69 | +        self  | 
 | 70 | +    }  | 
 | 71 | + | 
 | 72 | +    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {  | 
 | 73 | +        // instantiate specialized accumulator based for the type  | 
 | 74 | +        match (&self.input_data_type, &self.result_data_type) {  | 
 | 75 | +            (Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),  | 
 | 76 | +            _ => not_impl_err!(  | 
 | 77 | +                "AvgAccumulator for ({} --> {})",  | 
 | 78 | +                self.input_data_type,  | 
 | 79 | +                self.result_data_type  | 
 | 80 | +            ),  | 
 | 81 | +        }  | 
 | 82 | +    }  | 
 | 83 | + | 
 | 84 | +    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {  | 
 | 85 | +        Ok(vec![  | 
 | 86 | +            Arc::new(Field::new(  | 
 | 87 | +                format_state_name(&self.name, "sum"),  | 
 | 88 | +                self.input_data_type.clone(),  | 
 | 89 | +                true,  | 
 | 90 | +            )),  | 
 | 91 | +            Arc::new(Field::new(  | 
 | 92 | +                format_state_name(&self.name, "count"),  | 
 | 93 | +                Int64,  | 
 | 94 | +                true,  | 
 | 95 | +            )),  | 
 | 96 | +        ])  | 
 | 97 | +    }  | 
 | 98 | + | 
 | 99 | +    fn name(&self) -> &str {  | 
 | 100 | +        &self.name  | 
 | 101 | +    }  | 
 | 102 | + | 
 | 103 | +    fn reverse_expr(&self) -> ReversedUDAF {  | 
 | 104 | +        ReversedUDAF::Identical  | 
 | 105 | +    }  | 
 | 106 | + | 
 | 107 | +    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {  | 
 | 108 | +        true  | 
 | 109 | +    }  | 
 | 110 | + | 
 | 111 | +    fn create_groups_accumulator(  | 
 | 112 | +        &self,  | 
 | 113 | +        _args: AccumulatorArgs,  | 
 | 114 | +    ) -> Result<Box<dyn GroupsAccumulator>> {  | 
 | 115 | +        // instantiate specialized accumulator based for the type  | 
 | 116 | +        match (&self.input_data_type, &self.result_data_type) {  | 
 | 117 | +            (Float64, Float64) => {  | 
 | 118 | +                Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(  | 
 | 119 | +                    &self.input_data_type,  | 
 | 120 | +                    |sum: f64, count: i64| Ok(sum / count as f64),  | 
 | 121 | +                )))  | 
 | 122 | +            }  | 
 | 123 | + | 
 | 124 | +            _ => not_impl_err!(  | 
 | 125 | +                "AvgGroupsAccumulator for ({} --> {})",  | 
 | 126 | +                self.input_data_type,  | 
 | 127 | +                self.result_data_type  | 
 | 128 | +            ),  | 
 | 129 | +        }  | 
 | 130 | +    }  | 
 | 131 | + | 
 | 132 | +    fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {  | 
 | 133 | +        Ok(ScalarValue::Float64(None))  | 
 | 134 | +    }  | 
 | 135 | + | 
 | 136 | +    fn signature(&self) -> &Signature {  | 
 | 137 | +        &self.signature  | 
 | 138 | +    }  | 
 | 139 | + | 
 | 140 | +    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {  | 
 | 141 | +        avg_return_type(self.name(), &arg_types[0])  | 
 | 142 | +    }  | 
 | 143 | + | 
 | 144 | +    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {  | 
 | 145 | +        let [arg] = take_function_args(self.name(), arg_types)?;  | 
 | 146 | +        coerce_avg_type(self.name(), std::slice::from_ref(arg))  | 
 | 147 | +    }  | 
 | 148 | +}  | 
 | 149 | + | 
 | 150 | +/// An accumulator to compute the average  | 
 | 151 | +#[derive(Debug, Default)]  | 
 | 152 | +pub struct AvgAccumulator {  | 
 | 153 | +    sum: Option<f64>,  | 
 | 154 | +    count: i64,  | 
 | 155 | +}  | 
 | 156 | + | 
 | 157 | +impl Accumulator for AvgAccumulator {  | 
 | 158 | +    fn state(&mut self) -> Result<Vec<ScalarValue>> {  | 
 | 159 | +        Ok(vec![  | 
 | 160 | +            ScalarValue::Float64(self.sum),  | 
 | 161 | +            ScalarValue::from(self.count),  | 
 | 162 | +        ])  | 
 | 163 | +    }  | 
 | 164 | + | 
 | 165 | +    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {  | 
 | 166 | +        let values = values[0].as_primitive::<Float64Type>();  | 
 | 167 | +        self.count += (values.len() - values.null_count()) as i64;  | 
 | 168 | +        let v = self.sum.get_or_insert(0.);  | 
 | 169 | +        if let Some(x) = sum(values) {  | 
 | 170 | +            *v += x;  | 
 | 171 | +        }  | 
 | 172 | +        Ok(())  | 
 | 173 | +    }  | 
 | 174 | + | 
 | 175 | +    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {  | 
 | 176 | +        // counts are summed  | 
 | 177 | +        self.count += sum(states[1].as_primitive::<Int64Type>()).unwrap_or_default();  | 
 | 178 | + | 
 | 179 | +        // sums are summed  | 
 | 180 | +        if let Some(x) = sum(states[0].as_primitive::<Float64Type>()) {  | 
 | 181 | +            let v = self.sum.get_or_insert(0.);  | 
 | 182 | +            *v += x;  | 
 | 183 | +        }  | 
 | 184 | +        Ok(())  | 
 | 185 | +    }  | 
 | 186 | + | 
 | 187 | +    fn evaluate(&mut self) -> Result<ScalarValue> {  | 
 | 188 | +        if self.count == 0 {  | 
 | 189 | +            // If all input are nulls, count will be 0 and we will get null after the division.  | 
 | 190 | +            // This is consistent with Spark Average implementation.  | 
 | 191 | +            Ok(ScalarValue::Float64(None))  | 
 | 192 | +        } else {  | 
 | 193 | +            Ok(ScalarValue::Float64(  | 
 | 194 | +                self.sum.map(|f| f / self.count as f64),  | 
 | 195 | +            ))  | 
 | 196 | +        }  | 
 | 197 | +    }  | 
 | 198 | + | 
 | 199 | +    fn size(&self) -> usize {  | 
 | 200 | +        size_of_val(self)  | 
 | 201 | +    }  | 
 | 202 | +}  | 
 | 203 | + | 
 | 204 | +/// An accumulator to compute the average of `[PrimitiveArray<T>]`.  | 
 | 205 | +/// Stores values as native types, and does overflow checking  | 
 | 206 | +///  | 
 | 207 | +/// F: Function that calculates the average value from a sum of  | 
 | 208 | +/// T::Native and a total count  | 
 | 209 | +#[derive(Debug)]  | 
 | 210 | +struct AvgGroupsAccumulator<T, F>  | 
 | 211 | +where  | 
 | 212 | +    T: ArrowNumericType + Send,  | 
 | 213 | +    F: Fn(T::Native, i64) -> Result<T::Native> + Send,  | 
 | 214 | +{  | 
 | 215 | +    /// The type of the returned average  | 
 | 216 | +    return_data_type: DataType,  | 
 | 217 | + | 
 | 218 | +    /// Count per group (use i64 to make Int64Array)  | 
 | 219 | +    counts: Vec<i64>,  | 
 | 220 | + | 
 | 221 | +    /// Sums per group, stored as the native type  | 
 | 222 | +    sums: Vec<T::Native>,  | 
 | 223 | + | 
 | 224 | +    /// Function that computes the final average (value / count)  | 
 | 225 | +    avg_fn: F,  | 
 | 226 | +}  | 
 | 227 | + | 
 | 228 | +impl<T, F> AvgGroupsAccumulator<T, F>  | 
 | 229 | +where  | 
 | 230 | +    T: ArrowNumericType + Send,  | 
 | 231 | +    F: Fn(T::Native, i64) -> Result<T::Native> + Send,  | 
 | 232 | +{  | 
 | 233 | +    pub fn new(return_data_type: &DataType, avg_fn: F) -> Self {  | 
 | 234 | +        Self {  | 
 | 235 | +            return_data_type: return_data_type.clone(),  | 
 | 236 | +            counts: vec![],  | 
 | 237 | +            sums: vec![],  | 
 | 238 | +            avg_fn,  | 
 | 239 | +        }  | 
 | 240 | +    }  | 
 | 241 | +}  | 
 | 242 | + | 
 | 243 | +impl<T, F> GroupsAccumulator for AvgGroupsAccumulator<T, F>  | 
 | 244 | +where  | 
 | 245 | +    T: ArrowNumericType + Send,  | 
 | 246 | +    F: Fn(T::Native, i64) -> Result<T::Native> + Send,  | 
 | 247 | +{  | 
 | 248 | +    fn update_batch(  | 
 | 249 | +        &mut self,  | 
 | 250 | +        values: &[ArrayRef],  | 
 | 251 | +        group_indices: &[usize],  | 
 | 252 | +        _opt_filter: Option<&arrow::array::BooleanArray>,  | 
 | 253 | +        total_num_groups: usize,  | 
 | 254 | +    ) -> Result<()> {  | 
 | 255 | +        assert_eq!(values.len(), 1, "single argument to update_batch");  | 
 | 256 | +        let values = values[0].as_primitive::<T>();  | 
 | 257 | +        let data = values.values();  | 
 | 258 | + | 
 | 259 | +        // increment counts, update sums  | 
 | 260 | +        self.counts.resize(total_num_groups, 0);  | 
 | 261 | +        self.sums.resize(total_num_groups, T::default_value());  | 
 | 262 | + | 
 | 263 | +        let iter = group_indices.iter().zip(data.iter());  | 
 | 264 | +        if values.null_count() == 0 {  | 
 | 265 | +            for (&group_index, &value) in iter {  | 
 | 266 | +                let sum = &mut self.sums[group_index];  | 
 | 267 | +                *sum = (*sum).add_wrapping(value);  | 
 | 268 | +                self.counts[group_index] += 1;  | 
 | 269 | +            }  | 
 | 270 | +        } else {  | 
 | 271 | +            for (idx, (&group_index, &value)) in iter.enumerate() {  | 
 | 272 | +                if values.is_null(idx) {  | 
 | 273 | +                    continue;  | 
 | 274 | +                }  | 
 | 275 | +                let sum = &mut self.sums[group_index];  | 
 | 276 | +                *sum = (*sum).add_wrapping(value);  | 
 | 277 | + | 
 | 278 | +                self.counts[group_index] += 1;  | 
 | 279 | +            }  | 
 | 280 | +        }  | 
 | 281 | + | 
 | 282 | +        Ok(())  | 
 | 283 | +    }  | 
 | 284 | + | 
 | 285 | +    fn merge_batch(  | 
 | 286 | +        &mut self,  | 
 | 287 | +        values: &[ArrayRef],  | 
 | 288 | +        group_indices: &[usize],  | 
 | 289 | +        _opt_filter: Option<&arrow::array::BooleanArray>,  | 
 | 290 | +        total_num_groups: usize,  | 
 | 291 | +    ) -> Result<()> {  | 
 | 292 | +        assert_eq!(values.len(), 2, "two arguments to merge_batch");  | 
 | 293 | +        // first batch is partial sums, second is counts  | 
 | 294 | +        let partial_sums = values[0].as_primitive::<T>();  | 
 | 295 | +        let partial_counts = values[1].as_primitive::<Int64Type>();  | 
 | 296 | +        // update counts with partial counts  | 
 | 297 | +        self.counts.resize(total_num_groups, 0);  | 
 | 298 | +        let iter1 = group_indices.iter().zip(partial_counts.values().iter());  | 
 | 299 | +        for (&group_index, &partial_count) in iter1 {  | 
 | 300 | +            self.counts[group_index] += partial_count;  | 
 | 301 | +        }  | 
 | 302 | + | 
 | 303 | +        // update sums  | 
 | 304 | +        self.sums.resize(total_num_groups, T::default_value());  | 
 | 305 | +        let iter2 = group_indices.iter().zip(partial_sums.values().iter());  | 
 | 306 | +        for (&group_index, &new_value) in iter2 {  | 
 | 307 | +            let sum = &mut self.sums[group_index];  | 
 | 308 | +            *sum = sum.add_wrapping(new_value);  | 
 | 309 | +        }  | 
 | 310 | + | 
 | 311 | +        Ok(())  | 
 | 312 | +    }  | 
 | 313 | + | 
 | 314 | +    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {  | 
 | 315 | +        let counts = emit_to.take_needed(&mut self.counts);  | 
 | 316 | +        let sums = emit_to.take_needed(&mut self.sums);  | 
 | 317 | +        let mut builder = PrimitiveBuilder::<T>::with_capacity(sums.len());  | 
 | 318 | +        let iter = sums.into_iter().zip(counts);  | 
 | 319 | + | 
 | 320 | +        for (sum, count) in iter {  | 
 | 321 | +            if count != 0 {  | 
 | 322 | +                builder.append_value((self.avg_fn)(sum, count)?)  | 
 | 323 | +            } else {  | 
 | 324 | +                builder.append_null();  | 
 | 325 | +            }  | 
 | 326 | +        }  | 
 | 327 | +        let array: PrimitiveArray<T> = builder.finish();  | 
 | 328 | + | 
 | 329 | +        Ok(Arc::new(array))  | 
 | 330 | +    }  | 
 | 331 | + | 
 | 332 | +    // return arrays for sums and counts  | 
 | 333 | +    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {  | 
 | 334 | +        let counts = emit_to.take_needed(&mut self.counts);  | 
 | 335 | +        let counts = Int64Array::new(counts.into(), None);  | 
 | 336 | + | 
 | 337 | +        let sums = emit_to.take_needed(&mut self.sums);  | 
 | 338 | +        let sums = PrimitiveArray::<T>::new(sums.into(), None)  | 
 | 339 | +            .with_data_type(self.return_data_type.clone());  | 
 | 340 | + | 
 | 341 | +        Ok(vec![  | 
 | 342 | +            Arc::new(sums) as ArrayRef,  | 
 | 343 | +            Arc::new(counts) as ArrayRef,  | 
 | 344 | +        ])  | 
 | 345 | +    }  | 
 | 346 | + | 
 | 347 | +    fn size(&self) -> usize {  | 
 | 348 | +        self.counts.capacity() * size_of::<i64>() + self.sums.capacity() * size_of::<T>()  | 
 | 349 | +    }  | 
 | 350 | +}  | 
0 commit comments