Skip to content

Commit 360bcd0

Browse files
committed
Add bias to Skew on hash_aggregate
1 parent 208fdfb commit 360bcd0

File tree

3 files changed

+26
-16
lines changed

3 files changed

+26
-16
lines changed

cpp/src/arrow/acero/hash_aggregate_test.cc

+11 -4
Original file line number | Diff line number | Diff line change
@@ -1511,6 +1511,9 @@ TEST_P(GroupBy, VarianceOptionsAndSkewOptions) {
15111511
auto skew_keep_nulls_min_count = std::make_shared<SkewOptions>(
15121512
/*skip_nulls=*/false, /*bias=*/true, /*min_count=*/3);
15131513

1514+
auto skew_unbiased = std::make_shared<SkewOptions>(
1515+
/*skip_nulls=*/false, /*bias=*/false, /*min_count=*/0);
1516+
15141517
for (std::string value_column : {"argument", "argument1"}) {
15151518
for (bool use_threads : {false}) {
15161519
SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
@@ -1554,26 +1557,30 @@ TEST_P(GroupBy, VarianceOptionsAndSkewOptions) {
15541557
{"hash_skew", skew_keep_nulls, value_column, "hash_skew"},
15551558
{"hash_skew", skew_min_count, value_column, "hash_skew"},
15561559
{"hash_skew", skew_keep_nulls_min_count, value_column, "hash_skew"},
1560+
{"hash_skew", skew_unbiased, value_column, "hash_skew"},
15571561
{"hash_kurtosis", skew_keep_nulls, value_column, "hash_kurtosis"},
15581562
{"hash_kurtosis", skew_min_count, value_column, "hash_kurtosis"},
15591563
{"hash_kurtosis", skew_keep_nulls_min_count, value_column,
15601564
"hash_kurtosis"},
1565+
{"hash_kurtosis", skew_unbiased, value_column, "hash_kurtosis"},
15611566
},
15621567
use_threads));
15631568
expected = ArrayFromJSON(struct_({
15641569
field("key", int64()),
15651570
field("hash_skew", float64()),
15661571
field("hash_skew", float64()),
15671572
field("hash_skew", float64()),
1573+
field("hash_skew", float64()),
1574+
field("hash_kurtosis", float64()),
15681575
field("hash_kurtosis", float64()),
15691576
field("hash_kurtosis", float64()),
15701577
field("hash_kurtosis", float64()),
15711578
}),
15721579
R"([
1573-
[1, null, 0.707106, null, null, -1.5, null ],
1574-
[2, 0.213833, 0.213833, 0.213833, -1.720164, -1.720164, -1.720164],
1575-
[3, 0.0, null, null, -2.0, null, null ],
1576-
[4, null, 0.707106, null, null, -1.5, null ]
1580+
[1, null, 0.707106, null, null, null, -1.5, null, null ],
1581+
[2, 0.213833, 0.213833, 0.213833, 0.37037, -1.720164, -1.720164, -1.720164, -3.90123],
1582+
[3, 0.0, null, null, -NaN, -2.0, null, null, -NaN ],
1583+
[4, null, 0.707106, null, null, null, -1.5, null, null ]
15771584
])");
15781585
ValidateOutput(actual);
15791586
AssertDatumsApproxEqual(expected, actual, /*verbose=*/true);

cpp/src/arrow/compute/kernels/aggregate_var_std_internal.h

+2 -2
Original file line number | Diff line number | Diff line change
@@ -103,8 +103,8 @@ struct Moments {
103103
result = count * m4 / (m2 * m2) - 3;
104104
} else {
105105
result = 1.0 / (count - 2) / (count - 3) *
106-
((pow(count, 2) - 1.0) * (m4 / count) / pow((m2 / count), 2.0) -
107-
3 * pow((count - 1), 2.0));
106+
((pow(count, 2) - 1.0) * (m4 / count) / pow((m2 / count), 2) -
107+
3 * pow((count - 1), 2));
108108
}
109109
return result;
110110
}

cpp/src/arrow/compute/kernels/hash_aggregate_numeric.cc

+13 -10
Original file line number | Diff line number | Diff line change
@@ -452,35 +452,37 @@ struct GroupedStatisticImpl : public GroupedAggregator {
452452
Status InitInternal(ExecContext* ctx, const KernelInitArgs& args,
453453
StatisticType stat_type, const VarianceOptions& options) {
454454
return InitInternal(ctx, args, stat_type, options.ddof, options.skip_nulls,
455-
options.min_count);
455+
/*bias=*/false, options.min_count);
456456
}
457457

458458
// Init helper for hash_skew and hash_kurtosis
459459
Status InitInternal(ExecContext* ctx, const KernelInitArgs& args,
460460
StatisticType stat_type, const SkewOptions& options) {
461461
return InitInternal(ctx, args, stat_type, /*ddof=*/0, options.skip_nulls,
462-
options.min_count);
462+
options.bias, options.min_count);
463463
}
464464

465465
Status InitInternal(ExecContext* ctx, const KernelInitArgs& args,
466-
StatisticType stat_type, int ddof, bool skip_nulls,
466+
StatisticType stat_type, int ddof, bool skip_nulls, bool bias,
467467
uint32_t min_count) {
468468
if constexpr (is_decimal_type<Type>::value) {
469469
int32_t decimal_scale =
470470
checked_cast<const DecimalType&>(*args.inputs[0].type).scale();
471-
return InitInternal(ctx, stat_type, decimal_scale, ddof, skip_nulls, min_count);
471+
return InitInternal(ctx, stat_type, decimal_scale, ddof, skip_nulls, bias,
472+
min_count);
472473
} else {
473-
return InitInternal(ctx, stat_type, /*decimal_scale=*/0, ddof, skip_nulls,
474+
return InitInternal(ctx, stat_type, /*decimal_scale=*/0, ddof, skip_nulls, bias,
474475
min_count);
475476
}
476477
}
477478

478479
Status InitInternal(ExecContext* ctx, StatisticType stat_type, int32_t decimal_scale,
479-
int ddof, bool skip_nulls, uint32_t min_count) {
480+
int ddof, bool skip_nulls, bool bias, uint32_t min_count) {
480481
stat_type_ = stat_type;
481482
moments_level_ = moments_level_for_statistic(stat_type_);
482483
decimal_scale_ = decimal_scale;
483484
skip_nulls_ = skip_nulls;
485+
bias_ = bias;
484486
min_count_ = min_count;
485487
ddof_ = ddof;
486488
ctx_ = ctx;
@@ -539,7 +541,7 @@ struct GroupedStatisticImpl : public GroupedAggregator {
539541
Status ConsumeGeneric(const ExecSpan& batch) {
540542
GroupedStatisticImpl<Type> state;
541543
RETURN_NOT_OK(state.InitInternal(ctx_, stat_type_, decimal_scale_, ddof_, skip_nulls_,
542-
min_count_));
544+
bias_, min_count_));
543545
RETURN_NOT_OK(state.Resize(num_groups_));
544546
int64_t* counts = state.counts_.mutable_data();
545547
double* means = state.means_.mutable_data();
@@ -612,7 +614,7 @@ struct GroupedStatisticImpl : public GroupedAggregator {
612614
var_std.resize(num_groups_);
613615
GroupedStatisticImpl<Type> state;
614616
RETURN_NOT_OK(state.InitInternal(ctx_, stat_type_, decimal_scale_, ddof_,
615-
skip_nulls_, min_count_));
617+
skip_nulls_, bias_, min_count_));
616618
RETURN_NOT_OK(state.Resize(num_groups_));
617619
int64_t* other_counts = state.counts_.mutable_data();
618620
double* other_means = state.means_.mutable_data();
@@ -749,10 +751,10 @@ struct GroupedStatisticImpl : public GroupedAggregator {
749751
results[i] = moments.Stddev(ddof_);
750752
break;
751753
case StatisticType::Skew:
752-
results[i] = moments.Skew();
754+
results[i] = moments.Skew(bias_);
753755
break;
754756
case StatisticType::Kurtosis:
755-
results[i] = moments.Kurtosis();
757+
results[i] = moments.Kurtosis(bias_);
756758
break;
757759
default:
758760
return Status::NotImplemented("Statistic type ",
@@ -809,6 +811,7 @@ struct GroupedStatisticImpl : public GroupedAggregator {
809811
int moments_level_;
810812
int32_t decimal_scale_;
811813
bool skip_nulls_;
814+
bool bias_;
812815
uint32_t min_count_;
813816
int ddof_;
814817
int64_t num_groups_ = 0;

0 commit comments

Comments
 (0)