Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions datafusion/functions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -320,3 +320,8 @@ required-features = ["math_expressions"]
harness = false
name = "floor_ceil"
required-features = ["math_expressions"]

[[bench]]
harness = false
name = "round"
required-features = ["math_expressions"]
154 changes: 154 additions & 0 deletions datafusion/functions/benches/round.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

extern crate criterion;

use arrow::datatypes::{DataType, Field, Float32Type, Float64Type};
use arrow::util::bench_util::create_primitive_array;
use criterion::{Criterion, SamplingMode, criterion_group, criterion_main};
use datafusion_common::ScalarValue;
use datafusion_common::config::ConfigOptions;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
use datafusion_functions::math::round;
use std::hint::black_box;
use std::sync::Arc;
use std::time::Duration;

fn criterion_benchmark(c: &mut Criterion) {
let round_fn = round();
let config_options = Arc::new(ConfigOptions::default());

for size in [1024, 4096, 8192] {
let mut group = c.benchmark_group(format!("round size={size}"));
group.sampling_mode(SamplingMode::Flat);
group.sample_size(10);
group.measurement_time(Duration::from_secs(10));

// Float64 array benchmark
let f64_array = Arc::new(create_primitive_array::<Float64Type>(size, 0.1));
let batch_len = f64_array.len();
let f64_args = vec![
ColumnarValue::Array(f64_array),
ColumnarValue::Scalar(ScalarValue::Int32(Some(2))),
];

group.bench_function("round_f64_array", |b| {
b.iter(|| {
let args_cloned = f64_args.clone();
black_box(
round_fn
.invoke_with_args(ScalarFunctionArgs {
args: args_cloned,
arg_fields: vec![
Field::new("a", DataType::Float64, true).into(),
Field::new("b", DataType::Int32, false).into(),
],
number_rows: batch_len,
return_field: Field::new("f", DataType::Float64, true).into(),
config_options: Arc::clone(&config_options),
})
.unwrap(),
)
})
});

// Float32 array benchmark
let f32_array = Arc::new(create_primitive_array::<Float32Type>(size, 0.1));
let f32_args = vec![
ColumnarValue::Array(f32_array),
ColumnarValue::Scalar(ScalarValue::Int32(Some(2))),
];

group.bench_function("round_f32_array", |b| {
b.iter(|| {
let args_cloned = f32_args.clone();
black_box(
round_fn
.invoke_with_args(ScalarFunctionArgs {
args: args_cloned,
arg_fields: vec![
Field::new("a", DataType::Float32, true).into(),
Field::new("b", DataType::Int32, false).into(),
],
number_rows: batch_len,
return_field: Field::new("f", DataType::Float32, true).into(),
config_options: Arc::clone(&config_options),
})
.unwrap(),
)
})
});

// Scalar benchmark (the optimization we added)
let scalar_f64_args = vec![
ColumnarValue::Scalar(ScalarValue::Float64(Some(std::f64::consts::PI))),
ColumnarValue::Scalar(ScalarValue::Int32(Some(2))),
];

group.bench_function("round_f64_scalar", |b| {
b.iter(|| {
let args_cloned = scalar_f64_args.clone();
black_box(
round_fn
.invoke_with_args(ScalarFunctionArgs {
args: args_cloned,
arg_fields: vec![
Field::new("a", DataType::Float64, false).into(),
Field::new("b", DataType::Int32, false).into(),
],
number_rows: 1,
return_field: Field::new("f", DataType::Float64, false)
.into(),
config_options: Arc::clone(&config_options),
})
.unwrap(),
)
})
});

let scalar_f32_args = vec![
ColumnarValue::Scalar(ScalarValue::Float32(Some(std::f32::consts::PI))),
ColumnarValue::Scalar(ScalarValue::Int32(Some(2))),
];

group.bench_function("round_f32_scalar", |b| {
b.iter(|| {
let args_cloned = scalar_f32_args.clone();
black_box(
round_fn
.invoke_with_args(ScalarFunctionArgs {
args: args_cloned,
arg_fields: vec![
Field::new("a", DataType::Float32, false).into(),
Field::new("b", DataType::Int32, false).into(),
],
number_rows: 1,
return_field: Field::new("f", DataType::Float32, false)
.into(),
config_options: Arc::clone(&config_options),
})
.unwrap(),
)
})
});

group.finish();
}
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
131 changes: 130 additions & 1 deletion datafusion/functions/src/math/round.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use arrow::error::ArrowError;
use datafusion_common::types::{
NativeType, logical_float32, logical_float64, logical_int32,
};
use datafusion_common::{Result, ScalarValue, exec_err};
use datafusion_common::{DataFusionError, Result, ScalarValue, exec_err};
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
use datafusion_expr::{
Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
Expand Down Expand Up @@ -141,6 +141,135 @@ impl ScalarUDFImpl for RoundFunc {
&default_decimal_places
};

// Scalar fast path for float and decimal types - avoid array conversion overhead
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

        if let (ColumnarValue::Scalar(value_scalar), ColumnarValue::Scalar(dp_scalar)) =
            (&args.args[0], decimal_places)
        {
            if value_scalar.is_null() || dp_scalar.is_null() {
                return ColumnarValue::Scalar(ScalarValue::Null)
                    .cast_to(args.return_type(), None);
            }

            let dp = if let ScalarValue::Int32(Some(dp)) = dp_scalar {
                *dp
            } else {
                return internal_err!(
                    "Unexpected datatype for decimal_places: {}",
                    dp_scalar.data_type()
                );
            };

            match value_scalar {
                ScalarValue::Float32(Some(v)) => {
                    let rounded = round_float(*v, dp)?;
                    Ok(ColumnarValue::Scalar(ScalarValue::from(rounded)))
                }
                ScalarValue::Float64(Some(v)) => {
                    let rounded = round_float(*v, dp)?;
                    Ok(ColumnarValue::Scalar(ScalarValue::from(rounded)))
                }
                ScalarValue::Decimal128(Some(v), precision, scale) => {
                    let rounded = round_decimal(*v, *scale, dp)?;
                    let scalar =
                        ScalarValue::Decimal128(Some(rounded), *precision, *scale);
                    Ok(ColumnarValue::Scalar(scalar))
                }
                ScalarValue::Decimal256(Some(v), precision, scale) => {
                    let rounded = round_decimal(*v, *scale, dp)?;
                    let scalar =
                        ScalarValue::Decimal256(Some(rounded), *precision, *scale);
                    Ok(ColumnarValue::Scalar(scalar))
                }
                ScalarValue::Decimal64(Some(v), precision, scale) => {
                    let rounded = round_decimal(*v, *scale, dp)?;
                    let scalar =
                        ScalarValue::Decimal64(Some(rounded), *precision, *scale);
                    Ok(ColumnarValue::Scalar(scalar))
                }
                ScalarValue::Decimal32(Some(v), precision, scale) => {
                    let rounded = round_decimal(*v, *scale, dp)?;
                    let scalar =
                        ScalarValue::Decimal32(Some(rounded), *precision, *scale);
                    Ok(ColumnarValue::Scalar(scalar))
                }
                _ => {
                    internal_err!(
                        "Unexpected datatype for value: {}",
                        value_scalar.data_type()
                    )
                }
            }
        } else {
            round_columnar(&args.args[0], decimal_places, args.number_rows)
        }

Cleaner way of doing this

  • Using internal_err which are more appropriate here than exec_err
  • Collapse null handling using ScalarValue::is_null and ColumnarValue::cast_to
  • Don't need to map the error of round_float and round_decimal because using ? does that for us

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, this look much better.

if let (ColumnarValue::Scalar(value_scalar), ColumnarValue::Scalar(dp_scalar)) =
(&args.args[0], decimal_places)
{
// Extract decimal places as i32
// Note: decimal_places is coerced to Int32 by the signature, so the non-Int32
// arm should be unreachable in normal execution.
let dp = match dp_scalar {
ScalarValue::Int32(Some(dp)) => *dp,
ScalarValue::Int32(None) => {
// Return null with correct type for null decimal_places
return match value_scalar {
ScalarValue::Float32(_) => {
Ok(ColumnarValue::Scalar(ScalarValue::Float32(None)))
}
ScalarValue::Decimal128(_, p, s) => Ok(ColumnarValue::Scalar(
ScalarValue::Decimal128(None, *p, *s),
)),
ScalarValue::Decimal256(_, p, s) => Ok(ColumnarValue::Scalar(
ScalarValue::Decimal256(None, *p, *s),
)),
ScalarValue::Decimal64(_, p, s) => Ok(ColumnarValue::Scalar(
ScalarValue::Decimal64(None, *p, *s),
)),
ScalarValue::Decimal32(_, p, s) => Ok(ColumnarValue::Scalar(
ScalarValue::Decimal32(None, *p, *s),
)),
_ => Ok(ColumnarValue::Scalar(ScalarValue::Float64(None))),
};
}
_ => {
return exec_err!(
"Internal error: round decimal_places should be Int32, got {:?}",
dp_scalar
);
}
};

match value_scalar {
ScalarValue::Float64(Some(v)) => {
return round_float(*v, dp)
.map(|r| ColumnarValue::Scalar(ScalarValue::Float64(Some(r))))
.map_err(DataFusionError::from);
}
ScalarValue::Float64(None) => {
return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None)));
}
ScalarValue::Float32(Some(v)) => {
return round_float(*v, dp)
.map(|r| ColumnarValue::Scalar(ScalarValue::Float32(Some(r))))
.map_err(DataFusionError::from);
}
ScalarValue::Float32(None) => {
return Ok(ColumnarValue::Scalar(ScalarValue::Float32(None)));
}
ScalarValue::Decimal128(Some(v), precision, scale) => {
return round_decimal(*v, *scale, dp)
.map(|r| {
ColumnarValue::Scalar(ScalarValue::Decimal128(
Some(r),
*precision,
*scale,
))
})
.map_err(DataFusionError::from);
}
ScalarValue::Decimal128(None, precision, scale) => {
return Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
None, *precision, *scale,
)));
}
ScalarValue::Decimal256(Some(v), precision, scale) => {
return round_decimal(*v, *scale, dp)
.map(|r| {
ColumnarValue::Scalar(ScalarValue::Decimal256(
Some(r),
*precision,
*scale,
))
})
.map_err(DataFusionError::from);
}
ScalarValue::Decimal256(None, precision, scale) => {
return Ok(ColumnarValue::Scalar(ScalarValue::Decimal256(
None, *precision, *scale,
)));
}
ScalarValue::Decimal64(Some(v), precision, scale) => {
return round_decimal(*v, *scale, dp)
.map(|r| {
ColumnarValue::Scalar(ScalarValue::Decimal64(
Some(r),
*precision,
*scale,
))
})
.map_err(DataFusionError::from);
}
ScalarValue::Decimal64(None, precision, scale) => {
return Ok(ColumnarValue::Scalar(ScalarValue::Decimal64(
None, *precision, *scale,
)));
}
ScalarValue::Decimal32(Some(v), precision, scale) => {
return round_decimal(*v, *scale, dp)
.map(|r| {
ColumnarValue::Scalar(ScalarValue::Decimal32(
Some(r),
*precision,
*scale,
))
})
.map_err(DataFusionError::from);
}
ScalarValue::Decimal32(None, precision, scale) => {
return Ok(ColumnarValue::Scalar(ScalarValue::Decimal32(
None, *precision, *scale,
)));
}
// All supported scalar types are handled above
_ => {
return exec_err!(
"Internal error: unexpected scalar type for round: {:?}",
value_scalar
);
}
}
}

round_columnar(&args.args[0], decimal_places, args.number_rows)
}

Expand Down