30
30
namespace cudf {
31
31
namespace detail {
32
32
33
+ template <typename F>
34
+ struct float_pair {
35
+ size_type s;
36
+ F f;
37
+ };
38
+
39
+ template <typename F>
40
+ struct float_decomposer {
41
+ __device__ cuda::std::tuple<size_type&, F&> operator ()(float_pair<F>& key) const
42
+ {
43
+ return {key.s , key.f };
44
+ }
45
+ };
46
+
47
+ template <typename F>
48
+ struct float_to_pair_and_seq {
49
+ F* fs;
50
+ __device__ cuda::std::pair<float_pair<F>, size_type> operator ()(cudf::size_type idx)
51
+ {
52
+ auto const f = fs[idx];
53
+ auto const s = (isnan (f) * (idx + 1 )); // multiplier helps keep the sort stable for NaNs
54
+ return {float_pair<F>{s, f}, idx};
55
+ }
56
+ };
57
+
33
58
/* *
34
59
* @brief Sort indices of a single column.
35
60
*
@@ -50,23 +75,7 @@ void faster_sorted_order(column_view const& input,
50
75
template <sort_method method>
51
76
struct faster_sorted_order_fn {
52
77
/* *
53
- * @brief Compile time check for allowing faster sort.
54
- *
55
- * Faster sort is defined for fixed-width types where only
56
- * the primitive comparators cuda::std::greater or cuda::std::less
57
- * are needed.
58
- *
59
- * Floating point is removed here for special handling of NaNs
60
- * which require the row-comparator.
61
- */
62
- template <typename T>
63
- static constexpr bool is_supported ()
64
- {
65
- return cudf::is_fixed_width<T>() && !cudf::is_floating_point<T>();
66
- }
67
-
68
- /* *
69
- * @brief Sorts fixed-width columns using faster thrust sort.
78
+ * @brief Sorts fixed-width columns using faster thrust sort
70
79
*
71
80
* Should not be called if `input.has_nulls()==true`
72
81
*
@@ -85,6 +94,8 @@ struct faster_sorted_order_fn {
85
94
// For other fixed-width types, thrust may use merge-sort.
86
95
// The API sorts inplace so it requires making a copy of the input data
87
96
// and creating the input indices sequence.
97
+ thrust::sequence (
98
+ rmm::exec_policy_nosync (stream), indices.begin <size_type>(), indices.end <size_type>(), 0 );
88
99
89
100
auto const do_sort = [&](auto const comp) {
90
101
if constexpr (method == sort_method::STABLE) {
@@ -114,9 +125,43 @@ struct faster_sorted_order_fn {
114
125
mutable_column_view& indices,
115
126
bool ascending,
116
127
rmm::cuda_stream_view stream)
117
- requires(is_supported<T>() and ! cudf::is_chrono <T>())
128
+ requires(cudf::is_floating_point <T>())
118
129
{
119
- faster_sort<T>(input, indices, ascending, stream);
130
+ auto pair_in = rmm::device_uvector<float_pair<T>>(input.size (), stream);
131
+ auto d_in = pair_in.begin ();
132
+ // pair_out/d_out is not returned to the caller but used as an intermediate
133
+ auto pair_out = rmm::device_uvector<float_pair<T>>(input.size (), stream);
134
+ auto d_out = pair_out.begin ();
135
+ auto vals = rmm::device_uvector<size_type>(indices.size (), stream);
136
+ auto dv_in = vals.begin ();
137
+ auto dv_out = indices.begin <cudf::size_type>();
138
+
139
+ auto zip_out = thrust::make_zip_iterator (d_in, dv_in);
140
+ thrust::transform (rmm::exec_policy_nosync (stream),
141
+ thrust::counting_iterator<size_type>(0 ),
142
+ thrust::counting_iterator<size_type>(input.size ()),
143
+ zip_out,
144
+ float_to_pair_and_seq<T>{input.begin <T>()});
145
+
146
+ auto const decomposer = float_decomposer<T>{};
147
+ auto const end_bit = sizeof (float_pair<T>) * 8 ;
148
+ auto const sv = stream.value ();
149
+ auto const n = input.size ();
150
+ // cub radix sort implementation is always stable
151
+ std::size_t tmp_bytes = 0 ;
152
+ if (ascending) {
153
+ cub::DeviceRadixSort::SortPairs (
154
+ nullptr , tmp_bytes, d_in, d_out, dv_in, dv_out, n, decomposer, 0 , end_bit, sv);
155
+ auto tmp_stg = rmm::device_buffer (tmp_bytes, stream);
156
+ cub::DeviceRadixSort::SortPairs (
157
+ tmp_stg.data (), tmp_bytes, d_in, d_out, dv_in, dv_out, n, decomposer, 0 , end_bit, sv);
158
+ } else {
159
+ cub::DeviceRadixSort::SortPairsDescending (
160
+ nullptr , tmp_bytes, d_in, d_out, dv_in, dv_out, n, decomposer, 0 , end_bit, sv);
161
+ auto tmp_stg = rmm::device_buffer (tmp_bytes, stream);
162
+ cub::DeviceRadixSort::SortPairsDescending (
163
+ tmp_stg.data (), tmp_bytes, d_in, d_out, dv_in, dv_out, n, decomposer, 0 , end_bit, sv);
164
+ }
120
165
}
121
166
122
167
template <typename T>
@@ -130,9 +175,19 @@ struct faster_sorted_order_fn {
130
175
faster_sort<rep_type>(input, indices, ascending, stream);
131
176
}
132
177
178
+ template <typename T>
179
+ void operator ()(mutable_column_view& input,
180
+ mutable_column_view& indices,
181
+ bool ascending,
182
+ rmm::cuda_stream_view stream)
183
+ requires(cudf::is_fixed_width<T>() and !cudf::is_chrono<T>() and !cudf::is_floating_point<T>())
184
+ {
185
+ faster_sort<T>(input, indices, ascending, stream);
186
+ }
187
+
133
188
template <typename T>
134
189
void operator ()(mutable_column_view&, mutable_column_view&, bool , rmm::cuda_stream_view)
135
- requires(not is_supported <T>())
190
+ requires(not cudf::is_fixed_width <T>())
136
191
{
137
192
CUDF_UNREACHABLE (" invalid type for faster sort" );
138
193
}
0 commit comments