Skip to content

Commit bdc64ff

Browse files
authored
Fix non default constructible input types test for cub::FindIf (#7447)
* Fix non default constructible input types test for cub::FindIf - make NotDefaultConstructible non default constructible by removing init in input param - make iterator raw so that vectorized path is taken * Use thrust::tabulate to initialized input array and other nits * Make functor non generic * Shortcircuit nigation train
1 parent 427af39 commit bdc64ff

File tree

1 file changed

+22
-15
lines changed

1 file changed

+22
-15
lines changed

cub/test/catch2_test_device_find_if.cu

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <cub/device/device_find.cuh>
77

88
#include <thrust/detail/raw_pointer_cast.h>
9+
#include <thrust/tabulate.h>
910

1011
#include <cuda/iterator>
1112

@@ -207,7 +208,7 @@ struct NotDefaultConstructible
207208
{
208209
int value_;
209210

210-
__host__ __device__ constexpr explicit NotDefaultConstructible(int value = 0)
211+
__host__ __device__ constexpr explicit NotDefaultConstructible(int value)
211212
: value_(value)
212213
{}
213214

@@ -228,18 +229,21 @@ struct NotDefaultConstructible
228229
}
229230
};
230231

231-
struct converter
232+
struct index_to_value
232233
{
233-
__host__ __device__ constexpr NotDefaultConstructible operator()(const int val) const noexcept
234+
__host__ __device__ NotDefaultConstructible operator()(int i)
234235
{
235-
return NotDefaultConstructible{val};
236+
return NotDefaultConstructible{static_cast<int>(i)};
236237
}
237238
};
238239

240+
static_assert(!cuda::std::is_default_constructible_v<NotDefaultConstructible>,
241+
"NotDefaultConstructible should not be default constructible");
242+
239243
C2H_TEST("Device find_if works with non default constructible types", "[device][find_if]")
240244
{
241-
using input_t = int32_t;
242-
using offset_t = int32_t;
245+
using input_t = NotDefaultConstructible;
246+
using offset_t = int;
243247

244248
constexpr offset_t min_items = 1;
245249
constexpr offset_t max_items = 10'000; // 10k items for reasonable test time
@@ -251,16 +255,19 @@ C2H_TEST("Device find_if works with non default constructible types", "[device][
251255
min_items,
252256
max_items,
253257
}));
254-
const input_t val_to_find = static_cast<input_t>(num_items - 1);
258+
const auto val_to_find = static_cast<int>(num_items - 1);
255259

256260
CAPTURE(num_items, val_to_find);
257261

258-
// counting_iterator input
259-
auto c_it = cuda::make_transform_iterator(cuda::make_counting_iterator(input_t{0}), converter{});
260-
{
261-
c2h::device_vector<offset_t> out_result(1, thrust::no_init);
262-
auto predicate = thrust::detail::equal_to_value<NotDefaultConstructible>{NotDefaultConstructible{val_to_find}};
263-
find_if(c_it, thrust::raw_pointer_cast(out_result.data()), predicate, num_items);
264-
REQUIRE(val_to_find == out_result[0]);
265-
}
262+
// raw device iterator to some device vector so that vectorized path is taken
263+
c2h::device_vector<input_t> d_vec(num_items, NotDefaultConstructible(0));
264+
265+
// fill with arbitrary values dont use c2h gen because NotDefaultConstructible is not default constructible
266+
thrust::tabulate(c2h::device_policy, d_vec.begin(), d_vec.end(), index_to_value{});
267+
268+
auto it = thrust::raw_pointer_cast(d_vec.data());
269+
c2h::device_vector<offset_t> out_result(1, thrust::no_init);
270+
auto predicate = thrust::detail::equal_to_value<NotDefaultConstructible>{NotDefaultConstructible{val_to_find}};
271+
find_if(it, thrust::raw_pointer_cast(out_result.data()), predicate, num_items);
272+
REQUIRE(val_to_find == out_result[0]);
266273
}

0 commit comments

Comments
 (0)