@@ -44,25 +44,26 @@ void deterministic_sum(nvbench::state& state, nvbench::type_list<T>)
4444 state.add_global_memory_writes <T>(out.size ());
4545
4646 std::size_t temp_storage_bytes{};
47- #if !TUNE_BASE
48- cub::detail::rfa::dispatch<input_it_t , output_it_t , int , init_t , transform_t , accum_t , policy_selector_t >(
49- nullptr , temp_storage_bytes, d_in, d_out, elements, {}, 0 );
50- #else
47+ // we explicitly provide template arguments to override accum_t
5148 cub::detail::rfa::dispatch<input_it_t , output_it_t , int , init_t , transform_t , accum_t >(
52- nullptr , temp_storage_bytes, d_in, d_out, elements, {}, 0 );
53- #endif
49+ nullptr , temp_storage_bytes, d_in, d_out, elements, init_t {}, /* stream */ 0 );
5450
5551 thrust::device_vector<nvbench::uint8_t > temp_storage (temp_storage_bytes);
5652 auto * d_temp_storage = thrust::raw_pointer_cast (temp_storage.data ());
5753
5854 state.exec (nvbench::exec_tag::no_batch | nvbench::exec_tag::sync, [&](nvbench::launch& launch) {
55+ cub::detail::rfa::dispatch<
56+ input_it_t ,
57+ output_it_t ,
58+ int ,
59+ init_t ,
60+ transform_t ,
61+ accum_t
5962#if !TUNE_BASE
60- cub::detail::rfa::dispatch<input_it_t , output_it_t , int , init_t , transform_t , accum_t , policy_selector_t >(
61- d_temp_storage, temp_storage_bytes, d_in, d_out, elements, {}, launch.get_stream ());
62- #else
63- cub::detail::rfa::dispatch<input_it_t , output_it_t , int , init_t , transform_t , accum_t >(
64- d_temp_storage, temp_storage_bytes, d_in, d_out, elements, {}, launch.get_stream ());
65- #endif
63+ ,
64+ policy_selector_t
65+ #endif // !TUNE_BASE
66+ >(d_temp_storage, temp_storage_bytes, d_in, d_out, elements, {}, launch.get_stream ());
6667 });
6768}
6869
0 commit comments