Skip to content

Commit 69976bc

Browse files
Test passing a custom policy to DispatchAdjacentDifference, DispatchMergeSort, DispatchScan, DispatchBatchMemcpy (#7289)
1 parent e9f0971 commit 69976bc

4 files changed

+335
-0
lines changed
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
3+
4+
#include "insert_nested_NVTX_range_guard.h"
5+
6+
#include <cub/device/device_adjacent_difference.cuh>
7+
8+
#include <thrust/detail/raw_pointer_cast.h>
9+
10+
#include <cuda/std/functional>
11+
#include <cuda/std/numeric>
12+
13+
#include <c2h/catch2_test_helper.h>
14+
15+
using namespace cub;
16+
17+
// TODO(bgruber): drop this test with CCCL 4.0 when we drop the adjacent difference dispatcher after publishing the
18+
// tuning API
19+
20+
template <typename InputIteratorT>
21+
struct my_policy_hub
22+
{
23+
using ValueT = cub::detail::it_value_t<InputIteratorT>;
24+
25+
// from Policy500 of the CUB adjacent difference tunings
26+
struct MaxPolicy : ChainedPolicy<500, MaxPolicy, MaxPolicy>
27+
{
28+
using AdjacentDifferencePolicy =
29+
AgentAdjacentDifferencePolicy<128,
30+
Nominal8BItemsToItems<ValueT>(7),
31+
BLOCK_LOAD_WARP_TRANSPOSE,
32+
LOAD_LDG,
33+
BLOCK_STORE_WARP_TRANSPOSE>;
34+
};
35+
};
36+
37+
// Runs DispatchAdjacentDifference with my_policy_hub as the policy hub and compares the device output against
// cuda::std::adjacent_difference evaluated on a host copy of the same input.
C2H_TEST("DispatchAdjacentDifference::Dispatch: custom policy hub", "[device][adjacent_difference]")
{
  using item_t  = int;
  using index_t = unsigned;
  using op_t    = cuda::std::minus<>;
  using hub_t   = my_policy_hub<item_t*>;
  using dispatcher_t =
    cub::DispatchAdjacentDifference<item_t*, item_t*, op_t, index_t, cub::MayAlias::No, cub::ReadOption::Left, hub_t>;

  const index_t count = 12345;

  // Device-side input (randomly generated) and output
  c2h::device_vector<item_t> d_input(count);
  c2h::device_vector<item_t> out_items(count);
  c2h::gen(C2H_SEED(1), d_input);

  // Host-side reference result
  c2h::host_vector<item_t> h_input(d_input);
  c2h::host_vector<item_t> expected(count);
  cuda::std::adjacent_difference(h_input.begin(), h_input.end(), expected.begin(), cuda::std::minus<item_t>{});

  item_t* const d_in_ptr  = thrust::raw_pointer_cast(d_input.data());
  item_t* const d_out_ptr = thrust::raw_pointer_cast(out_items.data());

  // First call: null temporary storage, the dispatcher receives temp_bytes by reference
  size_t temp_bytes = 0;
  dispatcher_t::Dispatch(nullptr, temp_bytes, d_in_ptr, d_out_ptr, count, op_t{}, /* stream */ nullptr);

  // Second call: same arguments, now with a buffer of temp_bytes bytes
  c2h::device_vector<std::uint8_t> d_temp(temp_bytes, thrust::no_init);
  dispatcher_t::Dispatch(
    thrust::raw_pointer_cast(d_temp.data()), temp_bytes, d_in_ptr, d_out_ptr, count, op_t{}, /* stream */ nullptr);

  REQUIRE(out_items == expected);
}
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
3+
4+
#include "insert_nested_NVTX_range_guard.h"
5+
6+
#include <cub/device/dispatch/dispatch_batch_memcpy.cuh>
7+
8+
#include <thrust/detail/raw_pointer_cast.h>
9+
10+
#include <cuda/std/array>
11+
#include <cuda/std/cstdint>
12+
13+
#include <c2h/catch2_test_helper.h>
14+
15+
using namespace cub;
16+
17+
// TODO(bgruber): drop this test with CCCL 4.0 when we drop the batch memcpy dispatcher after publishing the tuning API
18+
19+
template <class BufferOffsetT, class BlockOffsetT>
20+
struct my_policy_hub
21+
{
22+
static constexpr uint32_t BLOCK_THREADS = 128U;
23+
static constexpr uint32_t BUFFERS_PER_THREAD = 4U;
24+
static constexpr uint32_t TLEV_BYTES_PER_THREAD = 8U;
25+
26+
static constexpr uint32_t LARGE_BUFFER_BLOCK_THREADS = 256U;
27+
static constexpr uint32_t LARGE_BUFFER_BYTES_PER_THREAD = 32U;
28+
29+
static constexpr uint32_t WARP_LEVEL_THRESHOLD = 128;
30+
static constexpr uint32_t BLOCK_LEVEL_THRESHOLD = 8 * 1024;
31+
32+
using buff_delay_constructor_t = cub::detail::default_delay_constructor_t<BufferOffsetT>;
33+
using block_delay_constructor_t = cub::detail::default_delay_constructor_t<BlockOffsetT>;
34+
35+
// from Policy500 of the CUB batch memcpy tunings
36+
struct MaxPolicy : ChainedPolicy<500, MaxPolicy, MaxPolicy>
37+
{
38+
using AgentSmallBufferPolicyT = cub::detail::batch_memcpy::AgentBatchMemcpyPolicy<
39+
BLOCK_THREADS,
40+
BUFFERS_PER_THREAD,
41+
TLEV_BYTES_PER_THREAD,
42+
/* PREFER_POW2_BITS */ true,
43+
LARGE_BUFFER_BLOCK_THREADS * LARGE_BUFFER_BYTES_PER_THREAD,
44+
WARP_LEVEL_THRESHOLD,
45+
BLOCK_LEVEL_THRESHOLD,
46+
buff_delay_constructor_t,
47+
block_delay_constructor_t>;
48+
49+
using AgentLargeBufferPolicyT =
50+
cub::detail::batch_memcpy::agent_large_buffer_policy<LARGE_BUFFER_BLOCK_THREADS, LARGE_BUFFER_BYTES_PER_THREAD>;
51+
};
52+
};
53+
54+
// Runs DispatchBatchMemcpy with my_policy_hub as the policy hub over five buffers of different sizes and checks
// that every output buffer equals its input buffer.
C2H_TEST("DispatchBatchMemcpy::Dispatch: custom policy hub", "[device][memcpy]")
{
  using value_t         = cuda::std::uint8_t;
  using buffer_size_t   = cuda::std::uint32_t;
  using block_offset_t  = cuda::std::uint32_t;
  using buffer_offset_t = cub::detail::batch_memcpy::per_invocation_buffer_offset_t;
  using policy_hub_t    = my_policy_hub<buffer_offset_t, block_offset_t>;
  using dispatch_t      = cub::detail::
    DispatchBatchMemcpy<value_t**, value_t**, buffer_size_t*, block_offset_t, cub::CopyAlg::Memcpy, policy_hub_t>;

  // NOTE(review): the sizes lie on both sides of the hub's WARP_LEVEL_THRESHOLD (128) and BLOCK_LEVEL_THRESHOLD
  // (8 * 1024), presumably to reach each copy path; confirm against the agent implementation.
  const cuda::std::array<buffer_size_t, 5> buffer_sizes{3, 128, 512, 4096, 9000};
  const auto num_buffers = buffer_sizes.size();

  // One device allocation per buffer, plus host-side tables of their addresses and sizes
  c2h::host_vector<c2h::device_vector<value_t>> in_buffers(num_buffers);
  c2h::host_vector<c2h::device_vector<value_t>> out_buffers(num_buffers);
  c2h::host_vector<value_t*> src_ptrs(num_buffers);
  c2h::host_vector<value_t*> dst_ptrs(num_buffers);
  c2h::host_vector<buffer_size_t> byte_counts(num_buffers);

  for (size_t idx = 0; idx < num_buffers; ++idx)
  {
    const buffer_size_t bytes = buffer_sizes[idx];

    in_buffers[idx].resize(bytes);
    out_buffers[idx].resize(bytes);
    c2h::gen(C2H_SEED(1), in_buffers[idx]);

    src_ptrs[idx]    = thrust::raw_pointer_cast(in_buffers[idx].data());
    dst_ptrs[idx]    = thrust::raw_pointer_cast(out_buffers[idx].data());
    byte_counts[idx] = bytes;
  }

  // Copy the address and size tables to the device
  c2h::device_vector<value_t*> d_src_ptrs(src_ptrs);
  c2h::device_vector<value_t*> d_dst_ptrs(dst_ptrs);
  c2h::device_vector<buffer_size_t> d_byte_counts(byte_counts);

  value_t** const d_src         = thrust::raw_pointer_cast(d_src_ptrs.data());
  value_t** const d_dst         = thrust::raw_pointer_cast(d_dst_ptrs.data());
  buffer_size_t* const d_counts = thrust::raw_pointer_cast(d_byte_counts.data());
  const auto buffer_count       = static_cast<cuda::std::int64_t>(num_buffers);

  // First call: null temporary storage, the dispatcher receives temp_bytes by reference
  size_t temp_bytes = 0;
  dispatch_t::Dispatch(nullptr, temp_bytes, d_src, d_dst, d_counts, buffer_count, /* stream */ nullptr);

  // Second call: same arguments, now with a buffer of temp_bytes bytes
  c2h::device_vector<::cuda::std::uint8_t> d_temp(temp_bytes, thrust::no_init);
  dispatch_t::Dispatch(
    thrust::raw_pointer_cast(d_temp.data()), temp_bytes, d_src, d_dst, d_counts, buffer_count, /* stream */ nullptr);

  // Compare each destination buffer with its source on the host
  for (size_t idx = 0; idx < num_buffers; ++idx)
  {
    c2h::host_vector<value_t> host_in(in_buffers[idx]);
    c2h::host_vector<value_t> host_out(out_buffers[idx]);
    REQUIRE(host_out == host_in);
  }
}
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
3+
4+
#include "insert_nested_NVTX_range_guard.h"
5+
6+
#include <cub/device/device_merge_sort.cuh>
7+
8+
#include <thrust/detail/raw_pointer_cast.h>
9+
10+
#include <algorithm>
11+
12+
#include "catch2_test_device_merge_sort_common.cuh"
13+
#include <c2h/catch2_test_helper.h>
14+
15+
using namespace cub;
16+
17+
// TODO(bgruber): drop this test with CCCL 4.0 when we drop the merge sort dispatcher after publishing the tuning API
18+
19+
template <typename KeyIteratorT>
20+
struct my_policy_hub
21+
{
22+
using KeyT = cub::detail::it_value_t<KeyIteratorT>;
23+
24+
// from Policy500 of the CUB merge sort tunings
25+
struct MaxPolicy : ChainedPolicy<500, MaxPolicy, MaxPolicy>
26+
{
27+
using MergeSortPolicy =
28+
AgentMergeSortPolicy<256,
29+
Nominal4BItemsToItems<KeyT>(11),
30+
BLOCK_LOAD_WARP_TRANSPOSE,
31+
LOAD_LDG,
32+
BLOCK_STORE_WARP_TRANSPOSE>;
33+
};
34+
};
35+
36+
// Runs DispatchMergeSort (keys only, separate input and output) with my_policy_hub as the policy hub and compares
// the device output against std::stable_sort applied to a host copy of the input.
C2H_TEST("DispatchMergeSort::Dispatch: custom policy hub", "[merge][sort][device]")
{
  using sort_key_t = int;
  using index_t    = unsigned;
  using compare_t  = custom_less_op_t;
  using hub_t      = my_policy_hub<sort_key_t*>;
  using dispatcher_t =
    cub::DispatchMergeSort<sort_key_t*, cub::NullType*, sort_key_t*, cub::NullType*, index_t, compare_t, hub_t>;

  const index_t count = 12345;

  c2h::device_vector<sort_key_t> d_unsorted(count);
  c2h::device_vector<sort_key_t> out_keys(count);
  c2h::gen(C2H_SEED(1), d_unsorted);

  sort_key_t* const d_in_ptr  = thrust::raw_pointer_cast(d_unsorted.data());
  sort_key_t* const d_out_ptr = thrust::raw_pointer_cast(out_keys.data());

  // Null pointer passed for both the input and the output item arguments
  cub::NullType* const no_items = nullptr;

  // First call: null temporary storage, the dispatcher receives temp_bytes by reference
  size_t temp_bytes = 0;
  dispatcher_t::Dispatch(
    nullptr, temp_bytes, d_in_ptr, no_items, d_out_ptr, no_items, count, compare_t{}, /* stream */ nullptr);

  // Second call: same arguments, now with a buffer of temp_bytes bytes
  c2h::device_vector<uint8_t> d_temp(temp_bytes, thrust::no_init);
  dispatcher_t::Dispatch(
    thrust::raw_pointer_cast(d_temp.data()),
    temp_bytes,
    d_in_ptr,
    no_items,
    d_out_ptr,
    no_items,
    count,
    compare_t{},
    /* stream */ nullptr);

  // Host reference: stable sort of the same input with the same comparison operator
  c2h::host_vector<sort_key_t> ref_keys = d_unsorted;
  std::stable_sort(ref_keys.begin(), ref_keys.end(), compare_t{});
  REQUIRE(ref_keys == out_keys);
}
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
3+
4+
#include "insert_nested_NVTX_range_guard.h"
5+
6+
#include <cub/device/device_scan.cuh>
7+
8+
#include <thrust/detail/raw_pointer_cast.h>
9+
10+
#include <cuda/std/functional>
11+
12+
#include "catch2_test_device_scan.cuh"
13+
#include <c2h/catch2_test_helper.h>
14+
15+
using namespace cub;
16+
17+
// TODO(bgruber): drop this test with CCCL 4.0 when we drop the scan dispatcher after publishing the tuning API
18+
19+
template <typename InputValueT, typename OutputValueT, typename AccumT, typename OffsetT, typename ScanOpT>
20+
struct my_policy_hub
21+
{
22+
// from Policy500 of the CUB scan tunings
23+
struct MaxPolicy : ChainedPolicy<500, MaxPolicy, MaxPolicy>
24+
{
25+
using ScanPolicyT =
26+
AgentScanPolicy<128, 12, AccumT, BLOCK_LOAD_DIRECT, LOAD_CA, BLOCK_STORE_WARP_TRANSPOSE_TIMESLICED, BLOCK_SCAN_RAKING>;
27+
};
28+
};
29+
30+
// Runs DispatchScan with my_policy_hub as the policy hub, passing NullType as the initial value, and compares the
// device output against compute_inclusive_scan_reference evaluated on a host copy of the same input.
C2H_TEST("DispatchScan::Dispatch: custom policy hub", "[scan][device]")
{
  using item_t  = int;
  using index_t = unsigned;
  using op_t    = cuda::std::plus<>;
  using acc_t   = cuda::std::__accumulator_t<op_t, item_t, item_t>;
  using hub_t   = my_policy_hub<item_t, item_t, acc_t, index_t, op_t>;
  using dispatcher_t =
    cub::DispatchScan<item_t*, item_t*, op_t, cub::NullType, index_t, acc_t, cub::ForceInclusive::No, hub_t>;

  const index_t count = 12345;

  // Device-side input (randomly generated) and output (left uninitialized)
  c2h::device_vector<item_t> d_input(count);
  c2h::device_vector<item_t> out_items(count, thrust::no_init);
  c2h::gen(C2H_SEED(1), d_input);

  // Host-side reference result
  c2h::host_vector<item_t> expected(count);
  c2h::host_vector<item_t> h_input(d_input);
  compute_inclusive_scan_reference(h_input.cbegin(), h_input.cend(), expected.begin(), op_t{}, item_t{});

  item_t* const d_in_ptr  = thrust::raw_pointer_cast(d_input.data());
  item_t* const d_out_ptr = thrust::raw_pointer_cast(out_items.data());

  // First call: null temporary storage, the dispatcher receives temp_bytes by reference
  size_t temp_bytes = 0;
  dispatcher_t::Dispatch(
    nullptr, temp_bytes, d_in_ptr, d_out_ptr, op_t{}, cub::NullType{}, count, /* stream */ nullptr);

  // Second call: same arguments, now with a buffer of temp_bytes bytes
  c2h::device_vector<uint8_t> d_temp(temp_bytes, thrust::no_init);
  dispatcher_t::Dispatch(
    thrust::raw_pointer_cast(d_temp.data()),
    temp_bytes,
    d_in_ptr,
    d_out_ptr,
    op_t{},
    cub::NullType{},
    count,
    /* stream */ nullptr);

  REQUIRE(out_items == expected);
}

0 commit comments

Comments
 (0)