Skip to content

Use radix sort for float/double types #19137

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jun 24, 2025
6 changes: 2 additions & 4 deletions cpp/src/sort/faster_sort_column.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,8 @@ void faster_sorted_order<sort_method::UNSTABLE>(column_view const& input,
bool ascending,
rmm::cuda_stream_view stream)
{
auto col_temp = column(input, stream);
auto d_col = col_temp.mutable_view();
thrust::sequence(
rmm::exec_policy_nosync(stream), indices.begin<size_type>(), indices.end<size_type>(), 0);
auto col_temp = column(input, stream);
auto d_col = col_temp.mutable_view();
auto dispatch_fn = faster_sorted_order_fn<sort_method::UNSTABLE>{};
cudf::type_dispatcher<dispatch_storage_type>(
input.type(), dispatch_fn, d_col, indices, ascending, stream);
Expand Down
95 changes: 75 additions & 20 deletions cpp/src/sort/faster_sort_column_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,31 @@
namespace cudf {
namespace detail {

/**
 * @brief Two-member sort key that pairs a value with an integer tag.
 *
 * The member order is significant for the code in this file:
 * `float_to_pair_and_seq` brace-initializes this aggregate as `{s, f}` and
 * `float_decomposer` hands the members out in the same order.
 *
 * @tparam F Type of the value stored in `f`
 */
template <typename F>
struct float_pair {
// Tag written by `float_to_pair_and_seq`: 0 when the value is not NaN,
// otherwise the row index plus one.
size_type s;
// The value read from the input for this row.
F f;
};

/**
 * @brief Decomposer functor that exposes the members of a `float_pair` by reference.
 *
 * An instance of this type is passed as the decomposer argument of the
 * `cub::DeviceRadixSort` calls in this file.
 *
 * @tparam F Value type of the `float_pair` being decomposed
 */
template <typename F>
struct float_decomposer {
  /**
   * @brief Builds a tuple of references to the tag and the value of `entry`.
   *
   * @param entry Key whose members are exposed
   * @return References to `entry.s` and `entry.f`, in that order
   */
  __device__ cuda::std::tuple<size_type&, F&> operator()(float_pair<F>& entry) const
  {
    // NOTE(review): presumably CUB treats earlier tuple members as more
    // significant, making `s` the primary key -- confirm against the CUB docs.
    return cuda::std::tuple<size_type&, F&>{entry.s, entry.f};
  }
};

/**
 * @brief Functor that produces the (key, value) pair for one input row.
 *
 * For row `idx` the key is a `float_pair` holding the row's value together
 * with a tag, and the value is `idx` itself.
 *
 * @tparam F Type of the values pointed to by `fs`
 */
template <typename F>
struct float_to_pair_and_seq {
  F* fs;  ///< Pointer to the input values; dereferenced in device code

  /**
   * @brief Builds the key/value pair for row `idx`.
   *
   * The call operator does not modify the functor, so it is `const`
   * (matching `float_decomposer::operator()`).
   *
   * @param idx Row index to read from `fs`
   * @return Pair of `float_pair{tag, fs[idx]}` and `idx`
   */
  __device__ cuda::std::pair<float_pair<F>, size_type> operator()(cudf::size_type idx) const
  {
    auto const f = fs[idx];
    // Non-NaN values all share tag 0. Every NaN gets the distinct,
    // index-derived tag `idx + 1`, which helps keep the sort stable for NaNs.
    // NOTE(review): under a descending sort the tags are presumably ordered
    // descending as well, which would reverse the relative order of NaN rows
    // -- confirm this is acceptable for the stable descending path.
    size_type const s = isnan(f) ? idx + 1 : 0;
    return {float_pair<F>{s, f}, idx};
  }
};

/**
* @brief Sort indices of a single column.
*
Expand All @@ -50,23 +75,7 @@ void faster_sorted_order(column_view const& input,
template <sort_method method>
struct faster_sorted_order_fn {
/**
* @brief Compile time check for allowing faster sort.
*
* Faster sort is defined for fixed-width types where only
* the primitive comparators cuda::std::greater or cuda::std::less
* are needed.
*
* Floating point is removed here for special handling of NaNs
* which require the row-comparator.
*/
template <typename T>
static constexpr bool is_supported()
{
return cudf::is_fixed_width<T>() && !cudf::is_floating_point<T>();
}

/**
* @brief Sorts fixed-width columns using faster thrust sort.
* @brief Sorts fixed-width columns using faster thrust sort
*
* Should not be called if `input.has_nulls()==true`
*
Expand All @@ -85,6 +94,8 @@ struct faster_sorted_order_fn {
// For other fixed-width types, thrust may use merge-sort.
// The API sorts inplace so it requires making a copy of the input data
// and creating the input indices sequence.
thrust::sequence(
rmm::exec_policy_nosync(stream), indices.begin<size_type>(), indices.end<size_type>(), 0);

auto const do_sort = [&](auto const comp) {
if constexpr (method == sort_method::STABLE) {
Expand Down Expand Up @@ -114,9 +125,43 @@ struct faster_sorted_order_fn {
mutable_column_view& indices,
bool ascending,
rmm::cuda_stream_view stream)
requires(is_supported<T>() and !cudf::is_chrono<T>())
requires(cudf::is_floating_point<T>())
{
faster_sort<T>(input, indices, ascending, stream);
auto pair_in = rmm::device_uvector<float_pair<T>>(input.size(), stream);
auto d_in = pair_in.begin();
// pair_out/d_out is not returned to the caller but used as an intermediate
auto pair_out = rmm::device_uvector<float_pair<T>>(input.size(), stream);
auto d_out = pair_out.begin();
auto vals = rmm::device_uvector<size_type>(indices.size(), stream);
auto dv_in = vals.begin();
auto dv_out = indices.begin<cudf::size_type>();

auto zip_out = thrust::make_zip_iterator(d_in, dv_in);
thrust::transform(rmm::exec_policy_nosync(stream),
thrust::counting_iterator<size_type>(0),
thrust::counting_iterator<size_type>(input.size()),
zip_out,
float_to_pair_and_seq<T>{input.begin<T>()});

auto const decomposer = float_decomposer<T>{};
auto const end_bit = sizeof(float_pair<T>) * 8;
auto const sv = stream.value();
auto const n = input.size();
// cub radix sort implementation is always stable
std::size_t tmp_bytes = 0;
if (ascending) {
cub::DeviceRadixSort::SortPairs(
nullptr, tmp_bytes, d_in, d_out, dv_in, dv_out, n, decomposer, 0, end_bit, sv);
auto tmp_stg = rmm::device_buffer(tmp_bytes, stream);
cub::DeviceRadixSort::SortPairs(
tmp_stg.data(), tmp_bytes, d_in, d_out, dv_in, dv_out, n, decomposer, 0, end_bit, sv);
} else {
cub::DeviceRadixSort::SortPairsDescending(
nullptr, tmp_bytes, d_in, d_out, dv_in, dv_out, n, decomposer, 0, end_bit, sv);
auto tmp_stg = rmm::device_buffer(tmp_bytes, stream);
cub::DeviceRadixSort::SortPairsDescending(
tmp_stg.data(), tmp_bytes, d_in, d_out, dv_in, dv_out, n, decomposer, 0, end_bit, sv);
}
}

template <typename T>
Expand All @@ -130,9 +175,19 @@ struct faster_sorted_order_fn {
faster_sort<rep_type>(input, indices, ascending, stream);
}

/**
 * @brief Dispatch overload for fixed-width types that are neither chrono nor
 * floating-point.
 *
 * Forwards every argument unchanged to `faster_sort<T>`.
 *
 * @tparam T Dispatched element type
 * @param input Column data passed through to `faster_sort`
 * @param indices Index column passed through to `faster_sort`
 * @param ascending Sort direction flag passed through to `faster_sort`
 * @param stream CUDA stream passed through to `faster_sort`
 */
template <typename T>
void operator()(mutable_column_view& input,
                mutable_column_view& indices,
                bool ascending,
                rmm::cuda_stream_view stream)
  requires(cudf::is_fixed_width<T>() and not(cudf::is_chrono<T>() or cudf::is_floating_point<T>()))
{
  faster_sort<T>(input, indices, ascending, stream);
}

template <typename T>
void operator()(mutable_column_view&, mutable_column_view&, bool, rmm::cuda_stream_view)
requires(not is_supported<T>())
requires(not cudf::is_fixed_width<T>())
{
CUDF_UNREACHABLE("invalid type for faster sort");
}
Expand Down
6 changes: 2 additions & 4 deletions cpp/src/sort/faster_stable_sort_column.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,8 @@ void faster_sorted_order<sort_method::STABLE>(column_view const& input,
bool ascending,
rmm::cuda_stream_view stream)
{
auto col_temp = column(input, stream);
auto d_col = col_temp.mutable_view();
thrust::sequence(
rmm::exec_policy_nosync(stream), indices.begin<size_type>(), indices.end<size_type>(), 0);
auto col_temp = column(input, stream);
auto d_col = col_temp.mutable_view();
auto dispatch_fn = faster_sorted_order_fn<sort_method::STABLE>{};
cudf::type_dispatcher<dispatch_storage_type>(
input.type(), dispatch_fn, d_col, indices, ascending, stream);
Expand Down
3 changes: 1 addition & 2 deletions cpp/src/sort/sort_column.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ std::unique_ptr<column> sorted_order<sort_method::UNSTABLE>(column_view const& i
auto sorted_indices = cudf::make_numeric_column(
data_type(type_to_id<size_type>()), input.size(), mask_state::UNALLOCATED, stream, mr);
mutable_column_view indices_view = sorted_indices->mutable_view();
if (!input.has_nulls() && cudf::is_fixed_width(input.type()) &&
!cudf::is_floating_point(input.type())) {
if (!input.has_nulls() && cudf::is_fixed_width(input.type())) {
faster_sorted_order<sort_method::UNSTABLE>(
input, indices_view, column_order == order::ASCENDING, stream);
} else {
Expand Down
5 changes: 2 additions & 3 deletions cpp/src/sort/stable_sort_column.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace detail {

/**
* @copydoc
* sorted_order(column_view&,order,null_order,rmm::cuda_stream_view,rmm::device_async_resource_ref)
* stable_sorted_order(column_view&,order,null_order,rmm::cuda_stream_view,rmm::device_async_resource_ref)
*/
template <>
std::unique_ptr<column> sorted_order<sort_method::STABLE>(column_view const& input,
Expand All @@ -40,8 +40,7 @@ std::unique_ptr<column> sorted_order<sort_method::STABLE>(column_view const& inp
auto sorted_indices = cudf::make_numeric_column(
data_type(type_to_id<size_type>()), input.size(), mask_state::UNALLOCATED, stream, mr);
mutable_column_view indices_view = sorted_indices->mutable_view();
if (!input.has_nulls() && cudf::is_fixed_width(input.type()) &&
!cudf::is_floating_point(input.type())) {
if (!input.has_nulls() && cudf::is_fixed_width(input.type())) {
faster_sorted_order<sort_method::STABLE>(
input, indices_view, column_order == order::ASCENDING, stream);
} else {
Expand Down