Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 24 additions & 30 deletions datafusion/functions-aggregate-common/src/tdigest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,6 @@ macro_rules! cast_scalar_f64 {
};
}

// Extract the inner [`u64`] from a non-null [`ScalarValue::UInt64`].
//
// The argument expression is borrowed and matched once. Any other value,
// including a null `ScalarValue::UInt64(None)`, panics with the message
// "invalid type <value>".
macro_rules! cast_scalar_u64 {
($value:expr ) => {
match &$value {
// Non-null unsigned 64-bit scalar: copy the integer out of the borrow.
ScalarValue::UInt64(Some(v)) => *v,
// Everything else (other variants or a null UInt64) is rejected.
v => panic!("invalid type {}", v),
}
};
}

/// Centroid implementation of the cluster described in the paper.
#[derive(Debug, PartialEq, Clone)]
pub struct Centroid {
Expand Down Expand Up @@ -110,7 +99,7 @@ pub struct TDigest {
centroids: Vec<Centroid>,
max_size: usize,
sum: f64,
count: u64,
count: f64,
max: f64,
min: f64,
}
Expand All @@ -120,8 +109,8 @@ impl TDigest {
TDigest {
centroids: Vec::new(),
max_size,
sum: 0_f64,
count: 0,
sum: 0.0,
count: 0.0,
max: f64::NAN,
min: f64::NAN,
}
Expand All @@ -133,14 +122,14 @@ impl TDigest {
centroids: vec![centroid.clone()],
max_size,
sum: centroid.mean * centroid.weight,
count: 1,
count: centroid.weight,
max: centroid.mean,
min: centroid.mean,
}
}

#[inline]
pub fn count(&self) -> u64 {
pub fn count(&self) -> f64 {
self.count
}

Expand Down Expand Up @@ -170,8 +159,8 @@ impl Default for TDigest {
TDigest {
centroids: Vec::new(),
max_size: 100,
sum: 0_f64,
count: 0,
sum: 0.0,
count: 0.0,
max: f64::NAN,
min: f64::NAN,
}
Expand Down Expand Up @@ -216,12 +205,12 @@ impl TDigest {
}

let mut result = TDigest::new(self.max_size());
result.count = self.count() + sorted_values.len() as u64;
result.count = self.count() + sorted_values.len() as f64;

let maybe_min = *sorted_values.first().unwrap();
let maybe_max = *sorted_values.last().unwrap();

if self.count() > 0 {
if self.count() > 0.0 {
result.min = self.min.min(maybe_min);
result.max = self.max.max(maybe_max);
} else {
Expand All @@ -233,7 +222,7 @@ impl TDigest {

let mut k_limit: u64 = 1;
let mut q_limit_times_count =
Self::k_to_q(k_limit, self.max_size) * result.count() as f64;
Self::k_to_q(k_limit, self.max_size) * result.count();
k_limit += 1;

let mut iter_centroids = self.centroids.iter().peekable();
Expand Down Expand Up @@ -281,7 +270,7 @@ impl TDigest {

compressed.push(curr.clone());
q_limit_times_count =
Self::k_to_q(k_limit, self.max_size) * result.count() as f64;
Self::k_to_q(k_limit, self.max_size) * result.count();
k_limit += 1;
curr = next;
}
Expand Down Expand Up @@ -353,7 +342,7 @@ impl TDigest {
let mut centroids: Vec<Centroid> = Vec::with_capacity(n_centroids);
let mut starts: Vec<usize> = Vec::with_capacity(digests.len());

let mut count = 0;
let mut count = 0.0;
let mut min = f64::INFINITY;
let mut max = f64::NEG_INFINITY;

Expand All @@ -362,7 +351,7 @@ impl TDigest {
starts.push(start);

let curr_count = digest.count();
if curr_count > 0 {
if curr_count > 0.0 {
min = min.min(digest.min);
max = max.max(digest.max);
count += curr_count;
Expand All @@ -373,6 +362,11 @@ impl TDigest {
}
}

// If no centroids were added (all digests had zero count), return default
if centroids.is_empty() {
return TDigest::default();
}

let mut digests_per_block: usize = 1;
while digests_per_block < starts.len() {
for i in (0..starts.len()).step_by(digests_per_block * 2) {
Expand All @@ -397,7 +391,7 @@ impl TDigest {
let mut compressed: Vec<Centroid> = Vec::with_capacity(max_size);

let mut k_limit = 1;
let mut q_limit_times_count = Self::k_to_q(k_limit, max_size) * count as f64;
let mut q_limit_times_count = Self::k_to_q(k_limit, max_size) * count;

let mut iter_centroids = centroids.iter_mut();
let mut curr = iter_centroids.next().unwrap();
Expand All @@ -416,7 +410,7 @@ impl TDigest {
sums_to_merge = 0_f64;
weights_to_merge = 0_f64;
compressed.push(curr.clone());
q_limit_times_count = Self::k_to_q(k_limit, max_size) * count as f64;
q_limit_times_count = Self::k_to_q(k_limit, max_size) * count;
k_limit += 1;
curr = centroid;
}
Expand All @@ -440,7 +434,7 @@ impl TDigest {
return 0.0;
}

let rank = q * self.count as f64;
let rank = q * self.count;

let mut pos: usize;
let mut t;
Expand All @@ -450,7 +444,7 @@ impl TDigest {
}

pos = 0;
t = self.count as f64;
t = self.count;

for (k, centroid) in self.centroids.iter().enumerate().rev() {
t -= centroid.weight();
Expand Down Expand Up @@ -563,7 +557,7 @@ impl TDigest {
vec![
ScalarValue::UInt64(Some(self.max_size as u64)),
ScalarValue::Float64(Some(self.sum)),
ScalarValue::UInt64(Some(self.count)),
ScalarValue::Float64(Some(self.count)),
ScalarValue::Float64(Some(self.max)),
ScalarValue::Float64(Some(self.min)),
ScalarValue::List(arr),
Expand Down Expand Up @@ -611,7 +605,7 @@ impl TDigest {
Self {
max_size,
sum: cast_scalar_f64!(state[1]),
count: cast_scalar_u64!(&state[2]),
count: cast_scalar_f64!(state[2]),
max,
min,
centroids,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions-aggregate/src/approx_median.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ impl AggregateUDFImpl for ApproxMedian {
Ok(vec![
Field::new(format_state_name(args.name, "max_size"), UInt64, false),
Field::new(format_state_name(args.name, "sum"), Float64, false),
Field::new(format_state_name(args.name, "count"), UInt64, false),
Field::new(format_state_name(args.name, "count"), Float64, false),
Field::new(format_state_name(args.name, "max"), Float64, false),
Field::new(format_state_name(args.name, "min"), Float64, false),
Field::new_list(
Expand Down
4 changes: 2 additions & 2 deletions datafusion/functions-aggregate/src/approx_percentile_cont.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ impl AggregateUDFImpl for ApproxPercentileCont {
),
Field::new(
format_state_name(args.name, "count"),
DataType::UInt64,
DataType::Float64,
false,
),
Field::new(
Expand Down Expand Up @@ -436,7 +436,7 @@ impl Accumulator for ApproxPercentileAccumulator {
}

fn evaluate(&mut self) -> Result<ScalarValue> {
if self.digest.count() == 0 {
if self.digest.count() == 0.0 {
return ScalarValue::try_from(self.return_type.clone());
}
let q = self.digest.estimate_quantile(self.percentile);
Expand Down
21 changes: 11 additions & 10 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -2029,11 +2029,12 @@ statement ok
INSERT INTO t1 VALUES (TRUE);

# ISSUE: https://github.com/apache/datafusion/issues/12716
# This test verifies that approx_percentile_cont_with_weight does not panic when given 'NaN' and returns 'inf'
# This test verifies that approx_percentile_cont_with_weight does not panic when given 'NaN'
# With weight=0, the data point does not contribute, so result is NULL
query R
SELECT approx_percentile_cont_with_weight(0, 0) WITHIN GROUP (ORDER BY 'NaN'::DOUBLE) FROM t1 WHERE t1.v1;
----
Infinity
NULL

statement ok
DROP TABLE t1;
Expand Down Expand Up @@ -2352,21 +2353,21 @@ e 115
query TI
SELECT c1, approx_percentile_cont_with_weight(c2, 0.95) WITHIN GROUP (ORDER BY c3) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1
----
a 74
a 65
b 68
c 123
d 124
e 115
c 122
d 123
e 110

# approx_percentile_cont_with_weight with centroids
query TI
SELECT c1, approx_percentile_cont_with_weight(c2, 0.95, 200) WITHIN GROUP (ORDER BY c3) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1
----
a 74
a 65
b 68
c 123
d 124
e 115
c 122
d 123
e 110

# csv_query_sum_crossjoin
query TTI
Expand Down
Loading