Skip to content

Commit e57dad5

Browse files
Refactor and fix
1 parent adcb7fd commit e57dad5

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

cub/benchmarks/bench/reduce/deterministic.cu

Lines changed: 13 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -44,25 +44,26 @@ void deterministic_sum(nvbench::state& state, nvbench::type_list<T>)
44 44
state.add_global_memory_writes<T>(out.size());
45 45

46 46
std::size_t temp_storage_bytes{};
47-
#if !TUNE_BASE
48-
cub::detail::rfa::dispatch<input_it_t, output_it_t, int, init_t, transform_t, accum_t, policy_selector_t>(
49-
nullptr, temp_storage_bytes, d_in, d_out, elements, {}, 0);
50-
#else
47+
// Template arguments are spelled out explicitly so that accum_t can be overridden.
51 48
cub::detail::rfa::dispatch<input_it_t, output_it_t, int, init_t, transform_t, accum_t>(
52-
nullptr, temp_storage_bytes, d_in, d_out, elements, {}, 0);
53-
#endif
49+
nullptr, temp_storage_bytes, d_in, d_out, elements, init_t{}, /* stream */ 0);
54 50

55 51
thrust::device_vector<nvbench::uint8_t> temp_storage(temp_storage_bytes);
56 52
auto* d_temp_storage = thrust::raw_pointer_cast(temp_storage.data());
57 53

58 54
state.exec(nvbench::exec_tag::no_batch | nvbench::exec_tag::sync, [&](nvbench::launch& launch) {
55+
cub::detail::rfa::dispatch<
56+
input_it_t,
57+
output_it_t,
58+
int,
59+
init_t,
60+
transform_t,
61+
accum_t
59 62
#if !TUNE_BASE
60-
cub::detail::rfa::dispatch<input_it_t, output_it_t, int, init_t, transform_t, accum_t, policy_selector_t>(
61-
d_temp_storage, temp_storage_bytes, d_in, d_out, elements, {}, launch.get_stream());
62-
#else
63-
cub::detail::rfa::dispatch<input_it_t, output_it_t, int, init_t, transform_t, accum_t>(
64-
d_temp_storage, temp_storage_bytes, d_in, d_out, elements, {}, launch.get_stream());
65-
#endif
63+
,
64+
policy_selector_t
65+
#endif // !TUNE_BASE
66+
>(d_temp_storage, temp_storage_bytes, d_in, d_out, elements, {}, launch.get_stream());
66 67
});
67 68
}
68 69

0 commit comments

Comments
 (0)