Skip to content

Commit 167656f

Browse files
committed
GH-45755: [C++][Compute] Add winsorize function
1 parent c3e399a commit 167656f

20 files changed

+480
-29
lines changed

cpp/src/arrow/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -783,6 +783,7 @@ if(ARROW_COMPUTE)
783783
compute/kernels/vector_run_end_encode.cc
784784
compute/kernels/vector_select_k.cc
785785
compute/kernels/vector_sort.cc
786+
compute/kernels/vector_statistics.cc
786787
compute/kernels/vector_swizzle.cc
787788
compute/key_hash_internal.cc
788789
compute/key_map_internal.cc

cpp/src/arrow/compute/api_vector.cc

+9
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,9 @@ static auto kSortOptionsType = GetFunctionOptionsType<SortOptions>(
142142
static auto kPartitionNthOptionsType = GetFunctionOptionsType<PartitionNthOptions>(
143143
DataMember("pivot", &PartitionNthOptions::pivot),
144144
DataMember("null_placement", &PartitionNthOptions::null_placement));
145+
// Options type descriptor for WinsorizeOptions, built from its two data
// members ("lower_limit" and "upper_limit").
static auto kWinsorizeOptionsType = GetFunctionOptionsType<WinsorizeOptions>(
146+
DataMember("lower_limit", &WinsorizeOptions::lower_limit),
147+
DataMember("upper_limit", &WinsorizeOptions::upper_limit));
145148
static auto kSelectKOptionsType = GetFunctionOptionsType<SelectKOptions>(
146149
DataMember("k", &SelectKOptions::k),
147150
DataMember("sort_keys", &SelectKOptions::sort_keys));
@@ -208,6 +211,11 @@ PartitionNthOptions::PartitionNthOptions(int64_t pivot, NullPlacement null_place
208211
null_placement(null_placement) {}
209212
constexpr char PartitionNthOptions::kTypeName[];
210213

214+
// Construct winsorize options from the given lower and upper quantile limits.
// The limits are stored as-is; no range checking happens in this constructor.
WinsorizeOptions::WinsorizeOptions(double lower_limit, double upper_limit)
215+
: FunctionOptions(internal::kWinsorizeOptionsType),
216+
lower_limit(lower_limit),
217+
upper_limit(upper_limit) {}
218+
211219
SelectKOptions::SelectKOptions(int64_t k, std::vector<SortKey> sort_keys)
212220
: FunctionOptions(internal::kSelectKOptionsType),
213221
k(k),
@@ -275,6 +283,7 @@ void RegisterVectorOptions(FunctionRegistry* registry) {
275283
DCHECK_OK(registry->AddFunctionOptionsType(kListFlattenOptionsType));
276284
DCHECK_OK(registry->AddFunctionOptionsType(kInversePermutationOptionsType));
277285
DCHECK_OK(registry->AddFunctionOptionsType(kScatterOptionsType));
286+
DCHECK_OK(registry->AddFunctionOptionsType(kWinsorizeOptionsType));
278287
}
279288
} // namespace internal
280289

cpp/src/arrow/compute/api_vector.h

+19
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,25 @@ class ARROW_EXPORT PartitionNthOptions : public FunctionOptions {
228228
NullPlacement null_placement;
229229
};
230230

231+
/// \brief Options for the winsorize function
///
/// Both limits are quantiles: the winsorize kernel rejects limits outside
/// [0, 1] and a lower limit greater than the upper limit.
class ARROW_EXPORT WinsorizeOptions : public FunctionOptions {
232+
public:
233+
WinsorizeOptions(double lower_limit, double upper_limit);
234+
/// Default options: lower_limit = 0 and upper_limit = 1.
WinsorizeOptions() : WinsorizeOptions(0, 1) {}
235+
static constexpr char const kTypeName[] = "WinsorizeOptions";
236+
237+
/// The quantile below which all values are replaced with the quantile's value.
238+
///
239+
/// For example, if lower_limit = 0.05, then all values below the 5th percentile
240+
/// will be replaced with the 5th percentile value.
241+
double lower_limit;
242+
243+
/// The quantile above which all values are replaced with the quantile's value.
244+
///
245+
/// For example, if upper_limit = 0.95, then all values above the 95th percentile
246+
/// will be replaced with the 95th percentile value.
247+
double upper_limit;
248+
};
249+
231250
/// \brief Options for cumulative functions
232251
/// \note Also aliased as CumulativeSumOptions for backward compatibility
233252
class ARROW_EXPORT CumulativeOptions : public FunctionOptions {

cpp/src/arrow/compute/kernels/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ add_arrow_compute_test(vector_test
105105
vector_nested_test.cc
106106
vector_replace_test.cc
107107
vector_run_end_encode_test.cc
108+
vector_statistics_test.cc
108109
select_k_test.cc
109110
EXTRA_LINK_LIBS
110111
arrow_compute_kernels_testing

cpp/src/arrow/compute/kernels/aggregate_mode.cc

+3-4
Original file line numberDiff line numberDiff line change
@@ -495,10 +495,9 @@ void RegisterScalarAggregateMode(FunctionRegistry* registry) {
495495
ModeExecutorChunked<StructType, BooleanType>::Exec)));
496496
for (const auto& type : NumericTypes()) {
497497
// TODO(wesm):
498-
DCHECK_OK(func->AddKernel(NewModeKernel(
499-
type, GenerateNumeric<ModeExecutor, StructType>(*type),
500-
GenerateNumeric<ModeExecutorChunked, StructType, VectorKernel::ChunkedExec>(
501-
*type))));
498+
DCHECK_OK(func->AddKernel(
499+
NewModeKernel(type, GenerateNumeric<ModeExecutor, StructType>(*type),
500+
GenerateNumeric<ModeExecutorChunked, StructType>(*type))));
502501
}
503502
// Type parameters are ignored
504503
DCHECK_OK(func->AddKernel(

cpp/src/arrow/compute/kernels/aggregate_quantile.cc

+11-15
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,12 @@ template <typename T>
7676
double DataPointToDouble(T value, const DataType&) {
7777
return static_cast<double>(value);
7878
}
79+
// Convert a Decimal32 data point to double, using the scale of the decimal
// type `ty` (which is cast to DecimalType).
double DataPointToDouble(const Decimal32& value, const DataType& ty) {
80+
return value.ToDouble(checked_cast<const DecimalType&>(ty).scale());
81+
}
82+
// Convert a Decimal64 data point to double, using the scale of the decimal
// type `ty` (which is cast to DecimalType).
double DataPointToDouble(const Decimal64& value, const DataType& ty) {
83+
return value.ToDouble(checked_cast<const DecimalType&>(ty).scale());
84+
}
7985
double DataPointToDouble(const Decimal128& value, const DataType& ty) {
8086
return value.ToDouble(checked_cast<const DecimalType&>(ty).scale());
8187
}
@@ -524,23 +530,13 @@ void AddQuantileKernels(VectorFunction* func) {
524530
base.signature = KernelSignature::Make({InputType(ty)}, OutputType(ResolveOutput));
525531
// output type is determined at runtime, set template argument to nulltype
526532
base.exec = GenerateNumeric<QuantileExecutor, NullType>(*ty);
527-
base.exec_chunked =
528-
GenerateNumeric<QuantileExecutorChunked, NullType, VectorKernel::ChunkedExec>(
529-
*ty);
530-
DCHECK_OK(func->AddKernel(base));
531-
}
532-
{
533-
base.signature =
534-
KernelSignature::Make({InputType(Type::DECIMAL128)}, OutputType(ResolveOutput));
535-
base.exec = QuantileExecutor<NullType, Decimal128Type>::Exec;
536-
base.exec_chunked = QuantileExecutorChunked<NullType, Decimal128Type>::Exec;
533+
base.exec_chunked = GenerateNumeric<QuantileExecutorChunked, NullType>(*ty);
537534
DCHECK_OK(func->AddKernel(base));
538535
}
539-
{
540-
base.signature =
541-
KernelSignature::Make({InputType(Type::DECIMAL256)}, OutputType(ResolveOutput));
542-
base.exec = QuantileExecutor<NullType, Decimal256Type>::Exec;
543-
base.exec_chunked = QuantileExecutorChunked<NullType, Decimal256Type>::Exec;
536+
for (auto type_id : DecimalTypeIds()) {
537+
base.signature = KernelSignature::Make({type_id}, OutputType(ResolveOutput));
538+
base.exec = GenerateDecimal<QuantileExecutor, NullType>(type_id);
539+
base.exec_chunked = GenerateDecimal<QuantileExecutorChunked, NullType>(type_id);
544540
DCHECK_OK(func->AddKernel(base));
545541
}
546542
}

cpp/src/arrow/compute/kernels/aggregate_test.cc

+2-1
Original file line numberDiff line numberDiff line change
@@ -4258,7 +4258,8 @@ TEST(TestQuantileKernel, Decimal) {
42584258
ValidateOutput(*out_array);
42594259
AssertArraysEqual(*expected, *out_array, /*verbose=*/true);
42604260
};
4261-
for (const auto& ty : {decimal128(3, 2), decimal256(3, 2)}) {
4261+
for (const auto& ty :
4262+
{decimal32(3, 2), decimal64(3, 2), decimal128(3, 2), decimal256(3, 2)}) {
42624263
check(ArrayFromJSON(ty, R"(["1.00", "5.00", null])"),
42634264
QuantileOptions(0.5, QuantileOptions::LINEAR),
42644265
ArrayFromJSON(float64(), R"([3.00])"));

cpp/src/arrow/compute/kernels/codegen_internal.h

+6-5
Original file line numberDiff line numberDiff line change
@@ -988,9 +988,9 @@ struct FailFunctor<VectorKernel::ChunkedExec> {
988988
};
989989

990990
// GD for numeric types (integer and floating point)
991-
template <template <typename...> class Generator, typename Type0,
992-
typename KernelType = ArrayKernelExec, typename... Args>
993-
KernelType GenerateNumeric(detail::GetTypeId get_id) {
991+
template <template <typename...> class Generator, typename Type0, typename... Args>
992+
auto GenerateNumeric(detail::GetTypeId get_id) {
993+
using KernelType = decltype(&Generator<Type0, Int8Type, Args...>::Exec);
994994
switch (get_id.id) {
995995
case Type::INT8:
996996
return Generator<Type0, Int8Type, Args...>::Exec;
@@ -1367,7 +1367,8 @@ ArrayKernelExec GenerateTemporal(detail::GetTypeId get_id) {
13671367
//
13681368
// See "Numeric" above for description of the generator functor
13691369
template <template <typename...> class Generator, typename Type0, typename... Args>
1370-
ArrayKernelExec GenerateDecimal(detail::GetTypeId get_id) {
1370+
auto GenerateDecimal(detail::GetTypeId get_id) {
1371+
using KernelType = decltype(&Generator<Type0, Decimal256Type, Args...>::Exec);
13711372
switch (get_id.id) {
13721373
case Type::DECIMAL32:
13731374
return Generator<Type0, Decimal32Type, Args...>::Exec;
@@ -1379,7 +1380,7 @@ ArrayKernelExec GenerateDecimal(detail::GetTypeId get_id) {
13791380
return Generator<Type0, Decimal256Type, Args...>::Exec;
13801381
default:
13811382
DCHECK(false);
1382-
return nullptr;
1383+
return KernelType(nullptr);
13831384
}
13841385
}
13851386

cpp/src/arrow/compute/kernels/scalar_cast_boolean.cc

+1-2
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,7 @@ std::vector<std::shared_ptr<CastFunction>> GetBooleanCasts() {
5454

5555
for (const auto& ty : NumericTypes()) {
5656
ArrayKernelExec exec =
57-
GenerateNumeric<applicator::ScalarUnary, BooleanType, ArrayKernelExec, IsNonZero>(
58-
*ty);
57+
GenerateNumeric<applicator::ScalarUnary, BooleanType, IsNonZero>(*ty);
5958
DCHECK_OK(func->AddKernel(ty->id(), {ty}, boolean(), exec));
6059
}
6160
for (const auto& ty : BaseBinaryTypes()) {

cpp/src/arrow/compute/kernels/scalar_cast_string.cc

+1-2
Original file line numberDiff line numberDiff line change
@@ -683,8 +683,7 @@ void AddNumberToStringCasts(CastFunction* func) {
683683
template <typename OutType>
684684
void AddDecimalToStringCasts(CastFunction* func) {
685685
auto out_ty = TypeTraits<OutType>::type_singleton();
686-
for (const auto& in_tid : std::vector<Type::type>{Type::DECIMAL32, Type::DECIMAL64,
687-
Type::DECIMAL128, Type::DECIMAL256}) {
686+
for (const auto& in_tid : DecimalTypeIds()) {
688687
DCHECK_OK(
689688
func->AddKernel(in_tid, {in_tid}, out_ty,
690689
GenerateDecimal<DecimalToStringCastFunctor, OutType>(in_tid),
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
#include <cstring>
#include <functional>
19+
#include <memory>
20+
#include <optional>
21+
#include <utility>
22+
23+
#include "arrow/compute/api_aggregate.h"
24+
#include "arrow/compute/api_vector.h"
25+
#include "arrow/compute/exec.h"
26+
#include "arrow/compute/function.h"
27+
#include "arrow/compute/kernel.h"
28+
#include "arrow/compute/kernels/codegen_internal.h"
29+
#include "arrow/compute/registry.h"
30+
#include "arrow/result.h"
31+
#include "arrow/scalar.h"
32+
#include "arrow/status.h"
33+
#include "arrow/util/bit_run_reader.h"
34+
#include "arrow/util/checked_cast.h"
35+
#include "arrow/util/logging.h"
36+
37+
namespace arrow::compute::internal {
38+
39+
using ::arrow::internal::checked_cast;
40+
41+
namespace {
42+
43+
// Check that the winsorize limits are usable.
//
// Returns Status::Invalid if either limit lies outside [0, 1] (a NaN limit
// fails this check too) or if the lower limit is greater than the upper limit;
// returns Status::OK otherwise.
Status ValidateOptions(const WinsorizeOptions& options) {
  const auto is_quantile = [](double limit) { return limit >= 0 && limit <= 1; };

  const bool limits_in_range =
      is_quantile(options.lower_limit) && is_quantile(options.upper_limit);
  if (!limits_in_range) {
    return Status::Invalid("winsorize limits must be between 0 and 1");
  }
  if (options.lower_limit > options.upper_limit) {
    return Status::Invalid(
        "winsorize upper limit must be equal or greater than lower limit");
  }
  return Status::OK();
}
54+
55+
using WinsorizeState = internal::OptionsWrapper<WinsorizeOptions>;
56+
57+
// We have a first unused template parameter for compatibility with GenerateNumeric.
58+
template <typename Unused, typename Type>
59+
struct Winsorize {
60+
using ArrayType = typename TypeTraits<Type>::ArrayType;
61+
using CType = typename TypeTraits<Type>::CType;
62+
63+
static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
64+
const auto& options = WinsorizeState::Get(ctx);
65+
RETURN_NOT_OK(ValidateOptions(options));
66+
ARROW_ASSIGN_OR_RAISE(auto maybe_quantiles,
67+
GetQuantileValues(ctx, batch.ToExecBatch(), options));
68+
auto data = batch.values[0].array.ToArrayData();
69+
auto out_data = out->array_data_mutable();
70+
if (!maybe_quantiles.has_value()) {
71+
// Only nulls and NaNs => return input as-is
72+
out_data->null_count = data->null_count.load();
73+
out_data->length = data->length;
74+
out_data->buffers = data->buffers;
75+
return Status::OK();
76+
}
77+
return ClipValues(*data, maybe_quantiles.value(), out_data, ctx);
78+
}
79+
80+
static Status ExecChunked(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
81+
const auto& options = WinsorizeState::Get(ctx);
82+
RETURN_NOT_OK(ValidateOptions(options));
83+
ARROW_ASSIGN_OR_RAISE(auto maybe_quantiles, GetQuantileValues(ctx, batch, options));
84+
const auto& chunked_array = batch.values[0].chunked_array();
85+
if (!maybe_quantiles.has_value()) {
86+
// Only nulls and NaNs => return input as-is
87+
*out = chunked_array;
88+
return Status::OK();
89+
}
90+
ArrayVector out_chunks;
91+
out_chunks.reserve(chunked_array->num_chunks());
92+
for (const auto& chunk : chunked_array->chunks()) {
93+
auto out_data = chunk->data()->Copy();
94+
RETURN_NOT_OK(
95+
ClipValues(*chunk->data(), maybe_quantiles.value(), out_data.get(), ctx));
96+
out_chunks.push_back(MakeArray(out_data));
97+
}
98+
return ChunkedArray::Make(std::move(out_chunks)).Value(out);
99+
}
100+
101+
struct QuantileValues {
102+
CType lower_bound, upper_bound;
103+
};
104+
105+
static Result<std::optional<QuantileValues>> GetQuantileValues(
106+
KernelContext* ctx, const ExecBatch& batch, const WinsorizeOptions& options) {
107+
// We use "nearest" to avoid the conversion of quantile values to double.
108+
QuantileOptions quantile_options(/*q=*/{options.lower_limit, options.upper_limit},
109+
QuantileOptions::NEAREST);
110+
ARROW_ASSIGN_OR_RAISE(
111+
auto quantile,
112+
CallFunction("quantile", batch, &quantile_options, ctx->exec_context()));
113+
auto quantile_array = quantile.array_as<ArrayType>();
114+
DCHECK_EQ(quantile_array->length(), 2);
115+
if (quantile_array->null_count() == 2) {
116+
return std::nullopt;
117+
}
118+
DCHECK_EQ(quantile_array->null_count(), 0);
119+
return QuantileValues{CType(quantile_array->Value(0)),
120+
CType(quantile_array->Value(1))};
121+
}
122+
123+
static Status ClipValues(const ArrayData& data, QuantileValues quantiles,
124+
ArrayData* out, KernelContext* ctx) {
125+
DCHECK_EQ(out->buffers.size(), data.buffers.size());
126+
out->null_count = data.null_count.load();
127+
out->length = data.length;
128+
out->buffers[0] = data.buffers[0];
129+
ARROW_ASSIGN_OR_RAISE(out->buffers[1], ctx->Allocate(out->length * sizeof(CType)));
130+
// Avoid leaving uninitialized memory under null entries
131+
std::memset(out->buffers[1]->mutable_data(), 0, out->length * sizeof(CType));
132+
133+
const CType* in_values = data.GetValues<CType>(1);
134+
CType* out_values = out->GetMutableValues<CType>(1);
135+
136+
auto visit = [&](int64_t position, int64_t length) {
137+
for (int64_t i = position; i < position + length; ++i) {
138+
if (in_values[i] < quantiles.lower_bound) {
139+
out_values[i] = quantiles.lower_bound;
140+
} else if (in_values[i] > quantiles.upper_bound) {
141+
out_values[i] = quantiles.upper_bound;
142+
} else {
143+
// NaNs also fall here
144+
out_values[i] = in_values[i];
145+
}
146+
}
147+
};
148+
arrow::internal::VisitSetBitRunsVoid(data.buffers[0], data.offset, data.length,
149+
visit);
150+
return Status::OK();
151+
}
152+
};
153+
154+
template <typename Unused, typename Type>
155+
struct WinsorizeChunked {
156+
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
157+
return Winsorize<Unused, Type>::ExecChunked(ctx, batch, out);
158+
}
159+
};
160+
161+
// Output type resolver for "winsorize": the output has the same type as the
// single input. `ctx` is not used.
Result<TypeHolder> ResolveWinsorizeOutput(KernelContext* ctx,
                                          const std::vector<TypeHolder>& in_types) {
  DCHECK_EQ(in_types.size(), 1);
  const TypeHolder& input_type = in_types.front();
  return input_type;
}
166+
167+
// Documentation for the "winsorize" compute function: summary, description,
// argument names, options class name and whether options are required.
const FunctionDoc winsorize_doc(
168+
"Winsorize an array",
169+
("This function applies a winsorization transform to the input array\n"
170+
"so as to reduce the influence of potential outliers.\n"
171+
"NaNs and nulls in the input are ignored for the purpose of computing\n"
172+
"the lower and upper quantiles.\n"
173+
"The quantile limits can be changed in WinsorizeOptions."),
174+
{"array"}, "WinsorizeOptions", /*options_required=*/true);
175+
176+
} // namespace
177+
178+
// Register the vector statistics functions in `registry`.
//
// This adds the "winsorize" function, with one kernel for each numeric type
// and one for each decimal type id.
void RegisterVectorStatistics(FunctionRegistry* registry) {
  static const auto kDefaultOptions = WinsorizeOptions();

  auto func = std::make_shared<VectorFunction>("winsorize", Arity::Unary(),
                                               winsorize_doc, &kDefaultOptions);

  // Settings shared by all kernels; the per-type fields are filled in below.
  VectorKernel kernel;
  kernel.init = WinsorizeState::Init;
  kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
  kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
  kernel.can_execute_chunkwise = false;
  // Despite its name, this flag has to be false here: the kernel produces a
  // ChunkedArray by itself, so the function execution logic shouldn't try to
  // wrap the output again.
  kernel.output_chunked = false;

  // Fill in the per-type fields and add the resulting kernel to the function.
  const auto add_kernel = [&](auto type_id, auto exec, auto exec_chunked) {
    kernel.signature = KernelSignature::Make({type_id}, &ResolveWinsorizeOutput);
    kernel.exec = exec;
    kernel.exec_chunked = exec_chunked;
    DCHECK_OK(func->AddKernel(kernel));
  };

  for (const auto& numeric_type : NumericTypes()) {
    const auto type_id = numeric_type->id();
    add_kernel(type_id, GenerateNumeric<Winsorize, /*Unused*/ void>(type_id),
               GenerateNumeric<WinsorizeChunked, /*Unused*/ void>(type_id));
  }
  for (const auto type_id : DecimalTypeIds()) {
    add_kernel(type_id, GenerateDecimal<Winsorize, /*Unused*/ void>(type_id),
               GenerateDecimal<WinsorizeChunked, /*Unused*/ void>(type_id));
  }

  DCHECK_OK(registry->AddFunction(std::move(func)));
}
207+
208+
} // namespace arrow::compute::internal

0 commit comments

Comments
 (0)