From 4b35f827725ddd47f739b3731b6a3ff2a10a3c2d Mon Sep 17 00:00:00 2001 From: Sergio Esteves Date: Thu, 22 Jan 2026 16:41:42 +0000 Subject: [PATCH] fix: correct weight handling in approx_percentile_cont_with_weight The approx_percentile_cont_with_weight function was producing incorrect results due to wrong weight handling in the TDigest implementation. Root cause: In TDigest::new_with_centroid(), the count field was hardcoded to 1 regardless of the actual centroid weight, while the weight was correctly used in the sum calculation. This mismatch caused incorrect percentile calculations since estimate_quantile() uses count to compute the rank. Changes: - Changed TDigest::count from u64 to f64 to properly support fractional weights (consistent with ClickHouse's TDigest implementation) - Fixed new_with_centroid() to use centroid.weight for count - Updated state_fields() in approx_percentile_cont and approx_median to use Float64 for the count field - Added early return in merge_digests() when all centroids have zero weight to prevent panic - Updated test expectations to reflect correct weighted percentile behavior --- .../functions-aggregate-common/src/tdigest.rs | 54 +++++++++---------- .../functions-aggregate/src/approx_median.rs | 2 +- .../src/approx_percentile_cont.rs | 8 +-- .../sqllogictest/test_files/aggregate.slt | 21 ++++---- 4 files changed, 40 insertions(+), 45 deletions(-) diff --git a/datafusion/functions-aggregate-common/src/tdigest.rs b/datafusion/functions-aggregate-common/src/tdigest.rs index 225c61b71939e..a7450f0eb52e9 100644 --- a/datafusion/functions-aggregate-common/src/tdigest.rs +++ b/datafusion/functions-aggregate-common/src/tdigest.rs @@ -49,17 +49,6 @@ macro_rules! cast_scalar_f64 { }; } -// Cast a non-null [`ScalarValue::UInt64`] to an [`u64`], or -// panic. -macro_rules! cast_scalar_u64 { - ($value:expr ) => { - match &$value { - ScalarValue::UInt64(Some(v)) => *v, - v => panic!("invalid type {}", v), - } - }; -} - /// Centroid implementation to the cluster mentioned in the paper. #[derive(Debug, PartialEq, Clone)] pub struct Centroid { @@ -110,7 +99,7 @@ pub struct TDigest { centroids: Vec, max_size: usize, sum: f64, - count: u64, + count: f64, max: f64, min: f64, } @@ -120,8 +109,8 @@ impl TDigest { TDigest { centroids: Vec::new(), max_size, - sum: 0_f64, - count: 0, + sum: 0.0, + count: 0.0, max: f64::NAN, min: f64::NAN, } @@ -133,14 +122,14 @@ impl TDigest { centroids: vec![centroid.clone()], max_size, sum: centroid.mean * centroid.weight, - count: 1, + count: centroid.weight, max: centroid.mean, min: centroid.mean, } } #[inline] - pub fn count(&self) -> u64 { + pub fn count(&self) -> f64 { self.count } @@ -170,8 +159,8 @@ impl Default for TDigest { TDigest { centroids: Vec::new(), max_size: 100, - sum: 0_f64, - count: 0, + sum: 0.0, + count: 0.0, max: f64::NAN, min: f64::NAN, } @@ -216,12 +205,12 @@ impl TDigest { } let mut result = TDigest::new(self.max_size()); - result.count = self.count() + sorted_values.len() as u64; + result.count = self.count() + sorted_values.len() as f64; let maybe_min = *sorted_values.first().unwrap(); let maybe_max = *sorted_values.last().unwrap(); - if self.count() > 0 { + if self.count() > 0.0 { result.min = self.min.min(maybe_min); result.max = self.max.max(maybe_max); } else { @@ -233,7 +222,7 @@ impl TDigest { let mut k_limit: u64 = 1; let mut q_limit_times_count = - Self::k_to_q(k_limit, self.max_size) * result.count() as f64; + Self::k_to_q(k_limit, self.max_size) * result.count(); k_limit += 1; let mut iter_centroids = self.centroids.iter().peekable(); @@ -281,7 +270,7 @@ impl TDigest { compressed.push(curr.clone()); q_limit_times_count = - Self::k_to_q(k_limit, self.max_size) * result.count() as f64; + Self::k_to_q(k_limit, self.max_size) * result.count(); k_limit += 1; curr = next; } @@ -353,7 +342,7 @@ impl TDigest { let mut centroids: Vec = Vec::with_capacity(n_centroids); let mut starts: Vec = Vec::with_capacity(digests.len()); - let mut count = 0; + let mut count = 0.0; let mut min = f64::INFINITY; let mut max = f64::NEG_INFINITY; @@ -362,7 +351,7 @@ impl TDigest { starts.push(start); let curr_count = digest.count(); - if curr_count > 0 { + if curr_count > 0.0 { min = min.min(digest.min); max = max.max(digest.max); count += curr_count; @@ -373,6 +362,11 @@ impl TDigest { } } + // If no centroids were added (all digests had zero count), return default + if centroids.is_empty() { + return TDigest::default(); + } + let mut digests_per_block: usize = 1; while digests_per_block < starts.len() { for i in (0..starts.len()).step_by(digests_per_block * 2) { @@ -397,7 +391,7 @@ impl TDigest { let mut compressed: Vec = Vec::with_capacity(max_size); let mut k_limit = 1; - let mut q_limit_times_count = Self::k_to_q(k_limit, max_size) * count as f64; + let mut q_limit_times_count = Self::k_to_q(k_limit, max_size) * count; let mut iter_centroids = centroids.iter_mut(); let mut curr = iter_centroids.next().unwrap(); @@ -416,7 +410,7 @@ impl TDigest { sums_to_merge = 0_f64; weights_to_merge = 0_f64; compressed.push(curr.clone()); - q_limit_times_count = Self::k_to_q(k_limit, max_size) * count as f64; + q_limit_times_count = Self::k_to_q(k_limit, max_size) * count; k_limit += 1; curr = centroid; } @@ -440,7 +434,7 @@ impl TDigest { return 0.0; } - let rank = q * self.count as f64; + let rank = q * self.count; let mut pos: usize; let mut t; @@ -450,7 +444,7 @@ impl TDigest { } pos = 0; - t = self.count as f64; + t = self.count; for (k, centroid) in self.centroids.iter().enumerate().rev() { t -= centroid.weight(); @@ -563,7 +557,7 @@ impl TDigest { vec![ ScalarValue::UInt64(Some(self.max_size as u64)), ScalarValue::Float64(Some(self.sum)), - ScalarValue::UInt64(Some(self.count)), + ScalarValue::Float64(Some(self.count)), ScalarValue::Float64(Some(self.max)), ScalarValue::Float64(Some(self.min)), ScalarValue::List(arr), @@ -611,7 +605,7 @@ impl TDigest { Self { max_size, sum: cast_scalar_f64!(state[1]), - count: cast_scalar_u64!(&state[2]), + count: cast_scalar_f64!(state[2]), max, min, centroids, diff --git a/datafusion/functions-aggregate/src/approx_median.rs b/datafusion/functions-aggregate/src/approx_median.rs index 739e333b54617..2205b009ecb27 100644 --- a/datafusion/functions-aggregate/src/approx_median.rs +++ b/datafusion/functions-aggregate/src/approx_median.rs @@ -110,7 +110,7 @@ impl AggregateUDFImpl for ApproxMedian { Ok(vec![ Field::new(format_state_name(args.name, "max_size"), UInt64, false), Field::new(format_state_name(args.name, "sum"), Float64, false), - Field::new(format_state_name(args.name, "count"), UInt64, false), + Field::new(format_state_name(args.name, "count"), Float64, false), Field::new(format_state_name(args.name, "max"), Float64, false), Field::new(format_state_name(args.name, "min"), Float64, false), Field::new_list( diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index b1e649ec029ff..392a044d01394 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -259,7 +259,7 @@ impl AggregateUDFImpl for ApproxPercentileCont { ), Field::new( format_state_name(args.name, "count"), - DataType::UInt64, + DataType::Float64, false, ), Field::new( @@ -436,7 +436,7 @@ impl Accumulator for ApproxPercentileAccumulator { } fn evaluate(&mut self) -> Result { - if self.digest.count() == 0 { + if self.digest.count() == 0.0 { return ScalarValue::try_from(self.return_type.clone()); } let q = self.digest.estimate_quantile(self.percentile); @@ -513,8 +513,8 @@ mod tests { ApproxPercentileAccumulator::new_with_max_size(0.5, DataType::Float64, 100); accumulator.merge_digests(&[t1]); - assert_eq!(accumulator.digest.count(), 50_000); + assert_eq!(accumulator.digest.count(), 50_000.0); accumulator.merge_digests(&[t2]); - assert_eq!(accumulator.digest.count(), 100_000); + assert_eq!(accumulator.digest.count(), 100_000.0); } } diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index e911a16be75f5..07b93a16b7c17 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -2029,11 +2029,12 @@ statement ok INSERT INTO t1 VALUES (TRUE); # ISSUE: https://github.com/apache/datafusion/issues/12716 -# This test verifies that approx_percentile_cont_with_weight does not panic when given 'NaN' and returns 'inf' +# This test verifies that approx_percentile_cont_with_weight does not panic when given 'NaN' +# With weight=0, the data point does not contribute, so result is NULL query R SELECT approx_percentile_cont_with_weight(0, 0) WITHIN GROUP (ORDER BY 'NaN'::DOUBLE) FROM t1 WHERE t1.v1; ---- -Infinity +NULL statement ok DROP TABLE t1; @@ -2352,21 +2353,21 @@ e 115 query TI SELECT c1, approx_percentile_cont_with_weight(c2, 0.95) WITHIN GROUP (ORDER BY c3) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- -a 74 +a 65 b 68 -c 123 -d 124 -e 115 +c 122 +d 123 +e 110 # approx_percentile_cont_with_weight with centroids query TI SELECT c1, approx_percentile_cont_with_weight(c2, 0.95, 200) WITHIN GROUP (ORDER BY c3) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- -a 74 +a 65 b 68 -c 123 -d 124 -e 115 +c 122 +d 123 +e 110 # csv_query_sum_crossjoin query TTI