Skip to content

Commit 251847f

Browse files
Add a test for cub::DispatchSegmentedReduce (#7311)
1 parent 70ced93 commit 251847f

File tree

1 file changed

+87
-0
lines changed

1 file changed

+87
-0
lines changed
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
3+
4+
#include "insert_nested_NVTX_range_guard.h"
5+
6+
#include <cub/device/dispatch/dispatch_segmented_reduce.cuh>
7+
8+
#include <thrust/detail/raw_pointer_cast.h>
9+
10+
#include "catch2_test_device_reduce.cuh"
11+
#include <c2h/catch2_test_helper.h>
12+
13+
using namespace cub;
14+
15+
// TODO(bgruber): drop this test with CCCL 4.0 when we drop the segmented reduce dispatcher after publishing the
16+
// tuning API
17+
18+
// Custom policy hub handed to DispatchSegmentedReduce by the test in this file.
//
// Only AccumT takes part in the policy definition; OffsetT and ReductionOpT are
// accepted but not used by anything inside the hub.
template <typename AccumT, typename OffsetT, typename ReductionOpT>
struct my_policy_hub
{
  // Single agent configuration shared by every policy alias below.
  // Values are taken from Policy500 of the CUB segmented reduce tunings.
  using agent_policy_t =
    AgentReducePolicy<256,
                      20,
                      AccumT,
                      4,
                      BLOCK_REDUCE_WARP_REDUCTIONS,
                      LOAD_LDG>;

  // The one and only policy of this hub: it names itself as both the policy
  // type and the previous policy in the chain.
  struct MaxPolicy : ChainedPolicy<500, MaxPolicy, MaxPolicy>
  {
    using SegmentedReducePolicy = agent_policy_t;
    using SingleTilePolicy      = agent_policy_t;
    using ReducePolicy          = agent_policy_t;
  };
};
29+
30+
// Runs cub::DispatchSegmentedReduce directly with a user-provided policy hub
// and compares the per-segment sums against a host-computed reference.
//
// Both Dispatch calls return a cudaError_t. The statuses are checked so that a
// failing size query or a failing launch aborts the test at the point of
// failure instead of surfacing later as a mismatch against output memory that
// was allocated with thrust::no_init.
C2H_TEST("DispatchSegmentedReduce::Dispatch: custom policy hub", "[segmented][reduce][device]")
{
  using input_t     = int;
  using output_t    = int;
  using offset_t    = int;
  using reduction_t = ::cuda::std::plus<>;
  using accum_t     = ::cuda::std::__accumulator_t<reduction_t, input_t, output_t>;

  // Six offsets describe five segments; the second one, [3, 3), is empty, and
  // only the first 15 of the 20 input items are covered by any segment.
  c2h::device_vector<offset_t> offsets{0, 3, 3, 7, 9, 15};
  c2h::device_vector<input_t> in_items{
    8, 6, 7, 5, 3, 0, 9, 25, 24, 6, 7, 2, 46, 8, 123, 2, 5, 3, 76, 48,
  };
  const auto num_segments = static_cast<::cuda::std::int64_t>(offsets.size() - 1);

  c2h::device_vector<output_t> out_result(num_segments, thrust::no_init);

  c2h::host_vector<output_t> expected_result(num_segments, thrust::no_init);
  compute_segmented_problem_reference(in_items, offsets, reduction_t{}, accum_t{}, expected_result.begin());

  using policy_hub_t = my_policy_hub<accum_t, offset_t, reduction_t>;
  using dispatch_t   = DispatchSegmentedReduce<
      input_t*,
      output_t*,
      const offset_t*,
      const offset_t*,
      offset_t,
      reduction_t,
      output_t,
      accum_t,
      policy_hub_t>;

  // Raw device pointers shared by both Dispatch invocations. Segment i spans
  // [d_begin_offsets[i], d_end_offsets[i]), i.e. the end offsets are the begin
  // offsets shifted by one element.
  input_t* const d_in                   = thrust::raw_pointer_cast(in_items.data());
  output_t* const d_out                 = thrust::raw_pointer_cast(out_result.data());
  const offset_t* const d_begin_offsets = thrust::raw_pointer_cast(offsets.data());
  const offset_t* const d_end_offsets   = d_begin_offsets + 1;

  // First call: null temporary storage, so only temp_size is filled in.
  size_t temp_size = 0;
  const cudaError_t query_status = dispatch_t::Dispatch(
    nullptr,
    temp_size,
    d_in,
    d_out,
    num_segments,
    d_begin_offsets,
    d_end_offsets,
    reduction_t{},
    output_t{},
    /* stream */ nullptr);
  REQUIRE(query_status == cudaSuccess);

  // Second call: same arguments, now with the requested temporary storage.
  c2h::device_vector<unsigned char> temp_storage(temp_size, thrust::no_init);
  const cudaError_t run_status = dispatch_t::Dispatch(
    thrust::raw_pointer_cast(temp_storage.data()),
    temp_size,
    d_in,
    d_out,
    num_segments,
    d_begin_offsets,
    d_end_offsets,
    reduction_t{},
    output_t{},
    /* stream */ nullptr);
  REQUIRE(run_status == cudaSuccess);

  REQUIRE(out_result == expected_result);
}

0 commit comments

Comments
 (0)