Skip to content

Commit 89f0abe

Browse files
authored
Use radix sort for float/double types (#19137)
Use a custom decomposer on the CUB radix sort APIs to influence NaN behavior for sorting floats in libcudf. This is for `cudf::sorted_order()`. A follow-on PR will address `cudf::sort`. Authors: - David Wendt (https://github.com/davidwendt) Approvers: - Bradley Dice (https://github.com/bdice) - Paul Mattione (https://github.com/pmattione-nvidia) URL: #19137
1 parent ba20837 commit 89f0abe

File tree

5 files changed

+82
-33
lines changed

5 files changed

+82
-33
lines changed

cpp/src/sort/faster_sort_column.cu

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,8 @@ void faster_sorted_order<sort_method::UNSTABLE>(column_view const& input,
3232
bool ascending,
3333
rmm::cuda_stream_view stream)
3434
{
35-
auto col_temp = column(input, stream);
36-
auto d_col = col_temp.mutable_view();
37-
thrust::sequence(
38-
rmm::exec_policy_nosync(stream), indices.begin<size_type>(), indices.end<size_type>(), 0);
35+
auto col_temp = column(input, stream);
36+
auto d_col = col_temp.mutable_view();
3937
auto dispatch_fn = faster_sorted_order_fn<sort_method::UNSTABLE>{};
4038
cudf::type_dispatcher<dispatch_storage_type>(
4139
input.type(), dispatch_fn, d_col, indices, ascending, stream);

cpp/src/sort/faster_sort_column_impl.cuh

Lines changed: 75 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,31 @@
3030
namespace cudf {
3131
namespace detail {
3232

33+
template <typename F>
34+
struct float_pair {
35+
size_type s;
36+
F f;
37+
};
38+
39+
template <typename F>
40+
struct float_decomposer {
41+
__device__ cuda::std::tuple<size_type&, F&> operator()(float_pair<F>& key) const
42+
{
43+
return {key.s, key.f};
44+
}
45+
};
46+
47+
template <typename F>
48+
struct float_to_pair_and_seq {
49+
F* fs;
50+
__device__ cuda::std::pair<float_pair<F>, size_type> operator()(cudf::size_type idx)
51+
{
52+
auto const f = fs[idx];
53+
auto const s = (isnan(f) * (idx + 1)); // multiplier helps keep the sort stable for NaNs
54+
return {float_pair<F>{s, f}, idx};
55+
}
56+
};
57+
3358
/**
3459
* @brief Sort indices of a single column.
3560
*
@@ -50,23 +75,7 @@ void faster_sorted_order(column_view const& input,
5075
template <sort_method method>
5176
struct faster_sorted_order_fn {
5277
/**
53-
* @brief Compile time check for allowing faster sort.
54-
*
55-
* Faster sort is defined for fixed-width types where only
56-
* the primitive comparators cuda::std::greater or cuda::std::less
57-
* are needed.
58-
*
59-
* Floating point is removed here for special handling of NaNs
60-
* which require the row-comparator.
61-
*/
62-
template <typename T>
63-
static constexpr bool is_supported()
64-
{
65-
return cudf::is_fixed_width<T>() && !cudf::is_floating_point<T>();
66-
}
67-
68-
/**
69-
* @brief Sorts fixed-width columns using faster thrust sort.
78+
* @brief Sorts fixed-width columns using faster thrust sort
7079
*
7180
* Should not be called if `input.has_nulls()==true`
7281
*
@@ -85,6 +94,8 @@ struct faster_sorted_order_fn {
8594
// For other fixed-width types, thrust may use merge-sort.
8695
// The API sorts inplace so it requires making a copy of the input data
8796
// and creating the input indices sequence.
97+
thrust::sequence(
98+
rmm::exec_policy_nosync(stream), indices.begin<size_type>(), indices.end<size_type>(), 0);
8899

89100
auto const do_sort = [&](auto const comp) {
90101
if constexpr (method == sort_method::STABLE) {
@@ -114,9 +125,43 @@ struct faster_sorted_order_fn {
114125
mutable_column_view& indices,
115126
bool ascending,
116127
rmm::cuda_stream_view stream)
117-
requires(is_supported<T>() and !cudf::is_chrono<T>())
128+
requires(cudf::is_floating_point<T>())
118129
{
119-
faster_sort<T>(input, indices, ascending, stream);
130+
auto pair_in = rmm::device_uvector<float_pair<T>>(input.size(), stream);
131+
auto d_in = pair_in.begin();
132+
// pair_out/d_out is not returned to the caller but used as an intermediate
133+
auto pair_out = rmm::device_uvector<float_pair<T>>(input.size(), stream);
134+
auto d_out = pair_out.begin();
135+
auto vals = rmm::device_uvector<size_type>(indices.size(), stream);
136+
auto dv_in = vals.begin();
137+
auto dv_out = indices.begin<cudf::size_type>();
138+
139+
auto zip_out = thrust::make_zip_iterator(d_in, dv_in);
140+
thrust::transform(rmm::exec_policy_nosync(stream),
141+
thrust::counting_iterator<size_type>(0),
142+
thrust::counting_iterator<size_type>(input.size()),
143+
zip_out,
144+
float_to_pair_and_seq<T>{input.begin<T>()});
145+
146+
auto const decomposer = float_decomposer<T>{};
147+
auto const end_bit = sizeof(float_pair<T>) * 8;
148+
auto const sv = stream.value();
149+
auto const n = input.size();
150+
// cub radix sort implementation is always stable
151+
std::size_t tmp_bytes = 0;
152+
if (ascending) {
153+
cub::DeviceRadixSort::SortPairs(
154+
nullptr, tmp_bytes, d_in, d_out, dv_in, dv_out, n, decomposer, 0, end_bit, sv);
155+
auto tmp_stg = rmm::device_buffer(tmp_bytes, stream);
156+
cub::DeviceRadixSort::SortPairs(
157+
tmp_stg.data(), tmp_bytes, d_in, d_out, dv_in, dv_out, n, decomposer, 0, end_bit, sv);
158+
} else {
159+
cub::DeviceRadixSort::SortPairsDescending(
160+
nullptr, tmp_bytes, d_in, d_out, dv_in, dv_out, n, decomposer, 0, end_bit, sv);
161+
auto tmp_stg = rmm::device_buffer(tmp_bytes, stream);
162+
cub::DeviceRadixSort::SortPairsDescending(
163+
tmp_stg.data(), tmp_bytes, d_in, d_out, dv_in, dv_out, n, decomposer, 0, end_bit, sv);
164+
}
120165
}
121166

122167
template <typename T>
@@ -130,9 +175,19 @@ struct faster_sorted_order_fn {
130175
faster_sort<rep_type>(input, indices, ascending, stream);
131176
}
132177

178+
template <typename T>
179+
void operator()(mutable_column_view& input,
180+
mutable_column_view& indices,
181+
bool ascending,
182+
rmm::cuda_stream_view stream)
183+
requires(cudf::is_fixed_width<T>() and !cudf::is_chrono<T>() and !cudf::is_floating_point<T>())
184+
{
185+
faster_sort<T>(input, indices, ascending, stream);
186+
}
187+
133188
template <typename T>
134189
void operator()(mutable_column_view&, mutable_column_view&, bool, rmm::cuda_stream_view)
135-
requires(not is_supported<T>())
190+
requires(not cudf::is_fixed_width<T>())
136191
{
137192
CUDF_UNREACHABLE("invalid type for faster sort");
138193
}

cpp/src/sort/faster_stable_sort_column.cu

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,8 @@ void faster_sorted_order<sort_method::STABLE>(column_view const& input,
3232
bool ascending,
3333
rmm::cuda_stream_view stream)
3434
{
35-
auto col_temp = column(input, stream);
36-
auto d_col = col_temp.mutable_view();
37-
thrust::sequence(
38-
rmm::exec_policy_nosync(stream), indices.begin<size_type>(), indices.end<size_type>(), 0);
35+
auto col_temp = column(input, stream);
36+
auto d_col = col_temp.mutable_view();
3937
auto dispatch_fn = faster_sorted_order_fn<sort_method::STABLE>{};
4038
cudf::type_dispatcher<dispatch_storage_type>(
4139
input.type(), dispatch_fn, d_col, indices, ascending, stream);

cpp/src/sort/sort_column.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,7 @@ std::unique_ptr<column> sorted_order<sort_method::UNSTABLE>(column_view const& i
4040
auto sorted_indices = cudf::make_numeric_column(
4141
data_type(type_to_id<size_type>()), input.size(), mask_state::UNALLOCATED, stream, mr);
4242
mutable_column_view indices_view = sorted_indices->mutable_view();
43-
if (!input.has_nulls() && cudf::is_fixed_width(input.type()) &&
44-
!cudf::is_floating_point(input.type())) {
43+
if (!input.has_nulls() && cudf::is_fixed_width(input.type())) {
4544
faster_sorted_order<sort_method::UNSTABLE>(
4645
input, indices_view, column_order == order::ASCENDING, stream);
4746
} else {

cpp/src/sort/stable_sort_column.cu

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ namespace detail {
2828

2929
/**
3030
* @copydoc
31-
* sorted_order(column_view&,order,null_order,rmm::cuda_stream_view,rmm::device_async_resource_ref)
31+
* stable_sorted_order(column_view&,order,null_order,rmm::cuda_stream_view,rmm::device_async_resource_ref)
3232
*/
3333
template <>
3434
std::unique_ptr<column> sorted_order<sort_method::STABLE>(column_view const& input,
@@ -40,8 +40,7 @@ std::unique_ptr<column> sorted_order<sort_method::STABLE>(column_view const& inp
4040
auto sorted_indices = cudf::make_numeric_column(
4141
data_type(type_to_id<size_type>()), input.size(), mask_state::UNALLOCATED, stream, mr);
4242
mutable_column_view indices_view = sorted_indices->mutable_view();
43-
if (!input.has_nulls() && cudf::is_fixed_width(input.type()) &&
44-
!cudf::is_floating_point(input.type())) {
43+
if (!input.has_nulls() && cudf::is_fixed_width(input.type())) {
4544
faster_sorted_order<sort_method::STABLE>(
4645
input, indices_view, column_order == order::ASCENDING, stream);
4746
} else {

0 commit comments

Comments
 (0)