diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs
index ff1fd0cd4b37..a822710f4275 100644
--- a/datafusion/functions/src/math/log.rs
+++ b/datafusion/functions/src/math/log.rs
@@ -22,7 +22,7 @@ use std::sync::Arc;
 
 use super::power::PowerFunc;
 
-use crate::utils::{calculate_binary_math, decimal128_to_i128};
+use crate::utils::{calculate_binary_math, decimal128_to_i128, decimal256_to_i256};
 use arrow::array::{Array, ArrayRef};
 use arrow::datatypes::{
     DataType, Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int32Type,
@@ -134,15 +134,46 @@ fn log_decimal128(value: i128, scale: i8, base: f64) -> Result<f64, ArrowError>
     }
 }
 
-/// Binary function to calculate an integer logarithm of Decimal128 `value` using `base` base
-/// Returns error if base is invalid or if value is out of bounds of Decimal128
+/// Binary function to calculate an integer logarithm of Decimal256 `value` using `base` base
+///
+/// The result is the floor of the exact logarithm. It is computed with integer
+/// arithmetic only, so exact powers of `base` (e.g. `10^50`) are never rounded
+/// down by floating point error. Non-positive values produce `NaN`.
+///
+/// Returns error if base is invalid or if the scale is not supported
 fn log_decimal256(value: i256, scale: i8, base: f64) -> Result<f64, ArrowError> {
-    match value.to_i128() {
-        Some(value) => log_decimal128(value, scale, base),
-        None => Err(ArrowError::NotYetImplemented(format!(
-            "Log of Decimal256 larger than Decimal128 is not yet supported: {value}"
-        ))),
-    }
+    if !base.is_finite() || base.trunc() != base {
+        return Err(ArrowError::ComputeError(format!(
+            "Log cannot use non-integer base: {base}"
+        )));
+    }
+    if (base as u32) < 2 {
+        return Err(ArrowError::ComputeError(format!(
+            "Log base must be greater than 1: {base}"
+        )));
+    }
+
+    // Values that fit in i128 keep using the native Decimal128 implementation
+    if let Some(value) = value.to_i128() {
+        return log_decimal128(value, scale, base);
+    }
+
+    let mut remaining = decimal256_to_i256(value, scale)?;
+    if !remaining.is_positive() {
+        return Ok(f64::NAN);
+    }
+
+    // `base` is a finite integer >= 2 here; the cast saturates at i128::MAX
+    let base = i256::from_i128(base as i128);
+
+    // Floor of the logarithm: the number of times `base` divides into the value
+    // before the quotient drops below `base`
+    let mut log_value = 0_u32;
+    while remaining >= base {
+        remaining = remaining / base;
+        log_value += 1;
+    }
+    Ok(f64::from(log_value))
 }
 
 impl ScalarUDFImpl for LogFunc {
@@ -1088,24 +1119,108 @@ mod tests {
     }
 
     #[test]
-    fn test_log_decimal256_error() {
-        let arg_field = Field::new("a", DataType::Decimal256(38, 0), false).into();
+    fn test_log_decimal256_large_values() {
+        let arg_field = Field::new("a", DataType::Decimal256(76, 0), false).into();
+        let ten_pow_50 = i256::from_i128(10).checked_pow(50).unwrap();
         let args = ScalarFunctionArgs {
             args: vec![
-                ColumnarValue::Array(Arc::new(Decimal256Array::from(vec![
-                    // Slightly larger than i128
-                    Some(i256::from_i128(i128::MAX) + i256::from(1000)),
-                ]))), // num
+                ColumnarValue::Array(Arc::new(
+                    Decimal256Array::from(vec![
+                        // Slightly larger than i128::MAX
+                        Some(i256::from_i128(i128::MAX) + i256::from(1000)),
+                        // Three orders of magnitude larger than i128::MAX
+                        Some(i256::from_i128(i128::MAX) * i256::from_i128(1000)),
+                        // Exactly 10^50
+                        Some(ten_pow_50),
+                        // 10^50 - 1, the largest value whose log10 is still 49
+                        Some(ten_pow_50 - i256::from_i128(1)),
+                    ])
+                    .with_precision_and_scale(76, 0)
+                    .unwrap(),
+                )), // num
             ],
             arg_fields: vec![arg_field],
-            number_rows: 1,
+            number_rows: 4,
             return_field: Field::new("f", DataType::Float64, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
         };
-        let result = LogFunc::new().invoke_with_args(args);
-        assert!(result.is_err());
-        assert_eq!(result.unwrap_err().to_string().lines().next().unwrap(),
-            "Arrow error: Not yet implemented: Log of Decimal256 larger than Decimal128 is not yet supported: 170141183460469231731687303715884106727"
-        );
+        let result = LogFunc::new()
+            .invoke_with_args(args)
+            .expect("failed to calculate log for large Decimal256 values");
+
+        match result {
+            ColumnarValue::Array(arr) => {
+                let floats = as_float64_array(&arr)
+                    .expect("failed to convert result to a Float64Array");
+
+                assert_eq!(floats.len(), 4);
+                // log10(i128::MAX + 1000) ≈ 38.23 -> floor to 38
+                assert_eq!(floats.value(0), 38.0);
+                // log10(i128::MAX * 1000) ≈ 41.23 -> floor to 41
+                assert_eq!(floats.value(1), 41.0);
+                // log10(10^50) is exactly 50
+                assert_eq!(floats.value(2), 50.0);
+                // log10(10^50 - 1) is just below 50 -> floor to 49
+                assert_eq!(floats.value(3), 49.0);
+            }
+            ColumnarValue::Scalar(_) => {
+                panic!("Expected an array value")
+            }
+        }
+    }
+
+    #[test]
+    fn test_log_decimal256_with_different_bases() {
+        let arg_fields = vec![
+            Field::new("b", DataType::Float64, false).into(),
+            Field::new("x", DataType::Decimal256(76, 0), false).into(),
+        ];
+
+        let power_of_two = |exp: u32| i256::from_i128(2).checked_pow(exp).unwrap();
+
+        let args = ScalarFunctionArgs {
+            args: vec![
+                ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), // base
+                ColumnarValue::Array(Arc::new(
+                    Decimal256Array::from(vec![
+                        // Values that still fit in i128
+                        Some(power_of_two(100)),
+                        Some(i256::from_i128(i128::MAX)),
+                        // Values that only fit in i256
+                        Some(power_of_two(127)),
+                        Some(power_of_two(200)),
+                    ])
+                    .with_precision_and_scale(76, 0)
+                    .unwrap(),
+                )), // num
+            ],
+            arg_fields,
+            number_rows: 4,
+            return_field: Field::new("f", DataType::Float64, true).into(),
+            config_options: Arc::new(ConfigOptions::default()),
+        };
+        let result = LogFunc::new()
+            .invoke_with_args(args)
+            .expect("failed to calculate log for large Decimal256 with base 2");
+
+        match result {
+            ColumnarValue::Array(arr) => {
+                let floats = as_float64_array(&arr)
+                    .expect("failed to convert result to a Float64Array");
+
+                assert_eq!(floats.len(), 4);
+                // log2(2^100) = 100
+                assert_eq!(floats.value(0), 100.0);
+                // log2(i128::MAX) = log2(2^127 - 1) -> floor to 126
+                assert_eq!(floats.value(1), 126.0);
+                // log2(2^127) = 127
+                assert_eq!(floats.value(2), 127.0);
+                // log2(2^200) = 200
+                assert_eq!(floats.value(3), 200.0);
+            }
+            ColumnarValue::Scalar(_) => {
+                panic!("Expected an array value")
+            }
+        }
     }
 }
diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs
index 932d61e8007c..642897e3481e 100644
--- a/datafusion/functions/src/utils.rs
+++ b/datafusion/functions/src/utils.rs
@@ -19,6 +19,7 @@ use arrow::array::{Array, ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray}
 use arrow::compute::try_binary;
 use arrow::datatypes::DataType;
 use arrow::error::ArrowError;
+use arrow_buffer::i256;
 use datafusion_common::{DataFusionError, Result, ScalarValue};
 use datafusion_expr::function::Hint;
 use datafusion_expr::ColumnarValue;
@@ -192,6 +193,27 @@ pub fn decimal128_to_i128(value: i128, scale: i8) -> Result<i128, ArrowError> {
     }
 }
 
+/// Converts Decimal256 components (value and scale) to an unscaled i256
+///
+/// Any fractional digits are dropped by the integer division by `10^scale`.
+/// Returns an error for a negative scale or if `10^scale` does not fit in i256.
+pub fn decimal256_to_i256(value: i256, scale: i8) -> Result<i256, ArrowError> {
+    if scale < 0 {
+        Err(ArrowError::ComputeError(
+            "Negative scale is not supported".into(),
+        ))
+    } else if scale == 0 {
+        Ok(value)
+    } else {
+        let divisor = i256::from_i128(10)
+            .checked_pow(scale as u32)
+            .ok_or_else(|| {
+                ArrowError::ComputeError(format!("Cannot get a power of {scale}"))
+            })?;
+        Ok(value / divisor)
+    }
+}
+
 #[cfg(test)]
 pub mod test {
     /// $FUNC ScalarUDFImpl to test
diff --git a/datafusion/sqllogictest/test_files/decimal.slt b/datafusion/sqllogictest/test_files/decimal.slt
index 502821fcc304..1c491ca5ee56 100644
--- a/datafusion/sqllogictest/test_files/decimal.slt
+++ b/datafusion/sqllogictest/test_files/decimal.slt
@@ -830,8 +830,10 @@ select log(100000000000000000000000000000000000::decimal(76,0));
 35
 
 # log(10^50) for decimal256 for a value larger than i128
-query error Arrow error: Not yet implemented: Log of Decimal256 larger than Decimal128 is not yet supported
+query R
 select log(100000000000000000000000000000000000000000000000000::decimal(76,0));
+----
+50
 
 # log(10^35) for decimal128 with explicit base
 query R
@@ -863,6 +865,30 @@ select log(2.0, 100000000000000000000000000000000000::decimal(38,0));
 ----
 116.267483321058
 
+# log(10^50) for decimal256 with explicit base 2 for a value larger than i128
+query R
+select log(2, 100000000000000000000000000000000000000000000000000::decimal(76,0));
+----
+166
+
+# log(10^50) for decimal256 with explicit float base for a value larger than i128
+query R
+select log(10.0, 100000000000000000000000000000000000000000000000000::decimal(76,0));
+----
+50
+
+# log(10^50 - 1) for decimal256 must floor to 49, not round up to 50
+query R
+select log(10, 99999999999999999999999999999999999999999999999999::decimal(76,0));
+----
+49
+
+# log for a decimal256 value just above i128::MAX
+query R
+select log(10, 170141183460469231731687303715884106000::decimal(76,0));
+----
+38
+
 # null cases
 query R
 select log(null, 100);