66#include < cub/device/device_find.cuh>
77
88#include < thrust/detail/raw_pointer_cast.h>
9+ #include < thrust/tabulate.h>
910
1011#include < cuda/iterator>
1112
@@ -207,7 +208,7 @@ struct NotDefaultConstructible
207208{
208209 int value_;
209210
210- __host__ __device__ constexpr explicit NotDefaultConstructible (int value = 0 )
211+ __host__ __device__ constexpr explicit NotDefaultConstructible (int value)
211212 : value_(value)
212213 {}
213214
@@ -228,18 +229,21 @@ struct NotDefaultConstructible
228229 }
229230};
230231
231- struct converter
232+ struct index_to_value
232233{
233- __host__ __device__ constexpr NotDefaultConstructible operator ()(const int val) const noexcept
234+ __host__ __device__ NotDefaultConstructible operator ()(int i)
234235 {
235- return NotDefaultConstructible{val };
236+ return NotDefaultConstructible{static_cast < int >(i) };
236237 }
237238};
238239
240+ static_assert (!cuda::std::is_default_constructible_v<NotDefaultConstructible>,
241+ " NotDefaultConstructible should not be default constructible" );
242+
239243C2H_TEST (" Device find_if works with non default constructible types" , " [device][find_if]" )
240244{
241- using input_t = int32_t ;
242- using offset_t = int32_t ;
245+ using input_t = NotDefaultConstructible ;
246+ using offset_t = int ;
243247
244248 constexpr offset_t min_items = 1 ;
245249 constexpr offset_t max_items = 10'000 ; // 10k items for reasonable test time
@@ -251,16 +255,19 @@ C2H_TEST("Device find_if works with non default constructible types", "[device][
251255 min_items,
252256 max_items,
253257 }));
254- const input_t val_to_find = static_cast <input_t >(num_items - 1 );
258+ const auto val_to_find = static_cast <int >(num_items - 1 );
255259
256260 CAPTURE (num_items, val_to_find);
257261
258- // counting_iterator input
259- auto c_it = cuda::make_transform_iterator (cuda::make_counting_iterator (input_t {0 }), converter{});
260- {
261- c2h::device_vector<offset_t > out_result (1 , thrust::no_init);
262- auto predicate = thrust::detail::equal_to_value<NotDefaultConstructible>{NotDefaultConstructible{val_to_find}};
263- find_if (c_it, thrust::raw_pointer_cast (out_result.data ()), predicate, num_items);
264- REQUIRE (val_to_find == out_result[0 ]);
265- }
262+ // raw device iterator to some device vector so that vectorized path is taken
263+ c2h::device_vector<input_t > d_vec (num_items, NotDefaultConstructible (0 ));
264+
265+ // fill with arbitrary values dont use c2h gen because NotDefaultConstructible is not default constructible
266+ thrust::tabulate (c2h::device_policy, d_vec.begin (), d_vec.end (), index_to_value{});
267+
268+ auto it = thrust::raw_pointer_cast (d_vec.data ());
269+ c2h::device_vector<offset_t > out_result (1 , thrust::no_init);
270+ auto predicate = thrust::detail::equal_to_value<NotDefaultConstructible>{NotDefaultConstructible{val_to_find}};
271+ find_if (it, thrust::raw_pointer_cast (out_result.data ()), predicate, num_items);
272+ REQUIRE (val_to_find == out_result[0 ]);
266273}
0 commit comments