2222#include < cudf/utilities/type_dispatcher.hpp>
2323
2424#include < cuda/std/limits>
25+ #include < cuda/std/type_traits>
2526#include < thrust/iterator/transform_iterator.h>
2627
2728#include < memory>
@@ -38,6 +39,8 @@ namespace detail::row::hash {
3839template <template <typename > class hash_function , typename Nullate>
3940class element_hasher {
4041 public:
42+ using result_type = cuda::std::invoke_result_t <hash_function<int32_t >, int32_t >;
43+
4144 /* *
4245 * @brief Constructs an element_hasher object.
4346 *
@@ -47,8 +50,8 @@ class element_hasher {
4750 */
4851 __device__ element_hasher (
4952 Nullate nulls,
50- uint32_t seed = DEFAULT_HASH_SEED,
51- hash_value_type null_hash = cuda::std::numeric_limits<hash_value_type >::max()) noexcept
53+ result_type seed = DEFAULT_HASH_SEED,
54+ result_type null_hash = cuda::std::numeric_limits<result_type >::max()) noexcept
5255 : _check_nulls(nulls), _seed(seed), _null_hash(null_hash)
5356 {
5457 }
@@ -62,8 +65,8 @@ class element_hasher {
6265 * @return The hash value of the given element
6366 */
6467 template <typename T>
65- __device__ hash_value_type operator ()(column_device_view const & col,
66- size_type row_index) const noexcept
68+ __device__ result_type operator ()(column_device_view const & col,
69+ size_type row_index) const noexcept
6770 requires(column_device_view::has_element_accessor<T>())
6871 {
6972 if (_check_nulls && col.is_null (row_index)) { return _null_hash; }
@@ -79,16 +82,17 @@ class element_hasher {
7982 * @return The hash value of the given element
8083 */
8184 template <typename T>
82- __device__ hash_value_type operator ()(column_device_view const & col,
83- size_type row_index) const noexcept
85+ __device__ result_type operator ()(column_device_view const & col,
86+ size_type row_index) const noexcept
8487 requires(not column_device_view::has_element_accessor<T>())
8588 {
8689 CUDF_UNREACHABLE (" Unsupported type in hash." );
8790 }
8891
8992 Nullate _check_nulls;
90- uint32_t _seed;
91- hash_value_type _null_hash;
93+ // Assumes seeds are the same as the result type of the hash function
94+ result_type _seed;
95+ result_type _null_hash;
9296};
9397
9498/* *
@@ -102,21 +106,20 @@ class device_row_hasher {
102106 friend class row_hasher ;
103107
104108 public:
109+ using result_type = cuda::std::invoke_result_t <hash_function<int32_t >, int32_t >;
110+
105111 /* *
106112 * @brief Return the hash value of a row in the given table.
107113 *
108114 * @param row_index The row index to compute the hash value of
109115 * @return The hash value of the row
110116 */
111- __device__ auto operator ()(size_type row_index) const noexcept
117+ __device__ result_type operator ()(size_type row_index) const noexcept
112118 {
113119 auto it =
114120 thrust::make_transform_iterator (_table.begin (), [row_index, this ](auto const & column) {
115121 return cudf::type_dispatcher<dispatch_storage_type>(
116- column.type (),
117- element_hasher_adapter<hash_function>{_check_nulls, _seed},
118- column,
119- row_index);
122+ column.type (), element_hasher_adapter{_check_nulls, _seed}, column, row_index);
120123 });
121124
122125 return detail::accumulate (it, it + _table.num_columns (), _seed, [](auto hash, auto h) {
@@ -132,31 +135,30 @@ class device_row_hasher {
132135 * When the column is nested, this uses the element_hasher to hash the shape and values of the
133136 * column.
134137 */
135- template <template <typename > class hash_fn >
136138 class element_hasher_adapter {
137- static constexpr hash_value_type NULL_HASH = cuda::std::numeric_limits<hash_value_type >::max();
138- static constexpr hash_value_type NON_NULL_HASH = 0 ;
139+ static constexpr result_type NULL_HASH = cuda::std::numeric_limits<result_type >::max();
140+ static constexpr result_type NON_NULL_HASH = 0 ;
139141
140142 public:
141- __device__ element_hasher_adapter (Nullate check_nulls, uint32_t seed) noexcept
143+ __device__ element_hasher_adapter (Nullate check_nulls, result_type seed) noexcept
142144 : _element_hasher(check_nulls, seed), _check_nulls(check_nulls)
143145 {
144146 }
145147
146148 template <typename T>
147- __device__ hash_value_type operator ()(column_device_view const & col,
148- size_type row_index) const noexcept
149+ __device__ result_type operator ()(column_device_view const & col,
150+ size_type row_index) const noexcept
149151 requires(not cudf::is_nested<T>())
150152 {
151153 return _element_hasher.template operator ()<T>(col, row_index);
152154 }
153155
154156 template <typename T>
155- __device__ hash_value_type operator ()(column_device_view const & col,
156- size_type row_index) const noexcept
157+ __device__ result_type operator ()(column_device_view const & col,
158+ size_type row_index) const noexcept
157159 requires(cudf::is_nested<T>())
158160 {
159- auto hash = hash_value_type {0 };
161+ auto hash = result_type {0 };
160162 column_device_view curr_col = col.slice (row_index, 1 );
161163 while (curr_col.type ().id () == type_id::STRUCT || curr_col.type ().id () == type_id::LIST) {
162164 if (_check_nulls) {
@@ -175,7 +177,7 @@ class device_row_hasher {
175177 auto list_sizes = make_list_size_iterator (list_col);
176178 hash = detail::accumulate (
177179 list_sizes, list_sizes + list_col.size (), hash, [](auto hash, auto size) {
178- return cudf::hashing::detail::hash_combine (hash, hash_fn <size_type>{}(size));
180+ return cudf::hashing::detail::hash_combine (hash, hash_function <size_type>{}(size));
179181 });
180182 curr_col = list_col.get_sliced_child ();
181183 }
@@ -188,20 +190,21 @@ class device_row_hasher {
188190 return hash;
189191 }
190192
191- element_hasher<hash_fn , Nullate> const _element_hasher;
193+ element_hasher<hash_function , Nullate> const _element_hasher;
192194 Nullate const _check_nulls;
193195 };
194196
195197 CUDF_HOST_DEVICE device_row_hasher (Nullate check_nulls,
196198 table_device_view t,
197- uint32_t seed = DEFAULT_HASH_SEED) noexcept
199+ result_type seed = DEFAULT_HASH_SEED) noexcept
198200 : _check_nulls{check_nulls}, _table{t}, _seed(seed)
199201 {
200202 }
201203
202204 Nullate const _check_nulls;
203205 table_device_view const _table;
204- uint32_t const _seed;
206+ // Assumes seeds are the same as the result type of the hash function
207+ result_type const _seed;
205208};
206209
207210/* *
@@ -236,11 +239,12 @@ class row_hasher {
236239 /* *
237240 * @brief Get the hash operator to use on the device
238241 *
239- * Returns a unary callable, `F`, with signature `hash_function::hash_value_type F(size_type)`.
240- *
241- * `F(i)` returns the hash of row i.
242+ * Returns a unary callable, `F`, where `F(i)` returns the hash value of row i.
242243 *
244+ * @tparam hash_function Hash functor to use for hashing elements
245+ * @tparam DeviceRowHasher The device row hasher type to use
243246 * @tparam Nullate A cudf::nullate type describing whether to check for nulls
247+ *
244248 * @param nullate Indicates if any input column contains nulls
245249 * @param seed The seed to use for the hash function
246250 * @return A hash operator to use on the device
@@ -249,8 +253,9 @@ class row_hasher {
249253 template <typename > class hash_function = cudf::hashing::detail::default_hash,
250254 template <template <typename > class , typename > class DeviceRowHasher = device_row_hasher,
251255 typename Nullate>
252- DeviceRowHasher<hash_function, Nullate> device_hasher (Nullate nullate = {},
253- uint32_t seed = DEFAULT_HASH_SEED) const
256+ DeviceRowHasher<hash_function, Nullate> device_hasher (
257+ Nullate nullate = {},
258+ cuda::std::invoke_result_t <hash_function<int32_t >, int32_t > seed = DEFAULT_HASH_SEED) const
254259 {
255260 return DeviceRowHasher<hash_function, Nullate>(nullate, *d_t , seed);
256261 }
0 commit comments