// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "insert_nested_NVTX_range_guard.h"

#include <cub/device/dispatch/dispatch_batch_memcpy.cuh>

#include <thrust/detail/raw_pointer_cast.h>

#include <cuda/std/array>
#include <cuda/std/cstdint>

#include <c2h/catch2_test_helper.h>

using namespace cub;

// TODO(bgruber): drop this test with CCCL 4.0 when we drop the batch memcpy dispatcher after publishing the tuning API
// Hand-rolled policy hub handed to DispatchBatchMemcpy by the test below, proving that a
// user-supplied hub is honored in place of CUB's built-in batch-memcpy tunings.
// BufferOffsetT / BlockOffsetT parameterize the delay constructors used by the two agents.
template <class BufferOffsetT, class BlockOffsetT>
struct my_policy_hub
{
  // Tuning knobs for the small-buffer agent
  static constexpr uint32_t BLOCK_THREADS         = 128U;
  static constexpr uint32_t BUFFERS_PER_THREAD    = 4U;
  static constexpr uint32_t TLEV_BYTES_PER_THREAD = 8U;

  // Tuning knobs for the large-buffer agent
  static constexpr uint32_t LARGE_BUFFER_BLOCK_THREADS   = 256U;
  static constexpr uint32_t LARGE_BUFFER_BYTES_PER_THREAD = 32U;

  // Byte-size cutoffs separating the copy strategies; fed into AgentBatchMemcpyPolicy below.
  // NOTE(review): the exact semantics of each threshold are interpreted by the agent — confirm there.
  static constexpr uint32_t WARP_LEVEL_THRESHOLD  = 128;
  static constexpr uint32_t BLOCK_LEVEL_THRESHOLD = 8 * 1024;

  // Delay constructors for the decoupled look-back of the buffer and block scans, respectively
  using buff_delay_constructor_t  = cub::detail::default_delay_constructor_t<BufferOffsetT>;
  using block_delay_constructor_t = cub::detail::default_delay_constructor_t<BlockOffsetT>;

  // from Policy500 of the CUB batch memcpy tunings
  // Single-entry policy chain: MaxPolicy is its own predecessor, so it applies to all architectures.
  struct MaxPolicy : ChainedPolicy<500, MaxPolicy, MaxPolicy>
  {
    // Policy for buffers handled at thread/warp level
    using AgentSmallBufferPolicyT = cub::detail::batch_memcpy::AgentBatchMemcpyPolicy<
      BLOCK_THREADS,
      BUFFERS_PER_THREAD,
      TLEV_BYTES_PER_THREAD,
      /* PREFER_POW2_BITS */ true,
      LARGE_BUFFER_BLOCK_THREADS * LARGE_BUFFER_BYTES_PER_THREAD,
      WARP_LEVEL_THRESHOLD,
      BLOCK_LEVEL_THRESHOLD,
      buff_delay_constructor_t,
      block_delay_constructor_t>;

    // Policy for buffers large enough to be copied by whole thread blocks
    using AgentLargeBufferPolicyT =
      cub::detail::batch_memcpy::agent_large_buffer_policy<LARGE_BUFFER_BLOCK_THREADS, LARGE_BUFFER_BYTES_PER_THREAD>;
  };
};
| 53 | + |
C2H_TEST("DispatchBatchMemcpy::Dispatch: custom policy hub", "[device][memcpy]")
{
  // Runs a batch memcpy through DispatchBatchMemcpy with the custom policy hub above and
  // checks every output buffer matches its input buffer.
  using byte_t          = cuda::std::uint8_t;
  using size_type       = cuda::std::uint32_t;
  using block_offset_t  = cuda::std::uint32_t;
  using buffer_offset_t = cub::detail::batch_memcpy::per_invocation_buffer_offset_t;

  // Mix of buffer sizes around the policy's 128-byte and 8 KiB thresholds
  const cuda::std::array<size_type, 5> sizes{3, 128, 512, 4096, 9000};
  const size_t num_buffers = sizes.size();

  c2h::host_vector<c2h::device_vector<byte_t>> src_buffers(num_buffers);
  c2h::host_vector<c2h::device_vector<byte_t>> dst_buffers(num_buffers);

  // Host-side staging for the pointer/size arrays the dispatcher consumes
  c2h::host_vector<byte_t*> host_src_ptrs(num_buffers);
  c2h::host_vector<byte_t*> host_dst_ptrs(num_buffers);
  c2h::host_vector<size_type> host_sizes(num_buffers);

  for (size_t i = 0; i < num_buffers; ++i)
  {
    const size_type num_bytes = sizes[i];
    src_buffers[i].resize(num_bytes);
    dst_buffers[i].resize(num_bytes);
    c2h::gen(C2H_SEED(1), src_buffers[i]); // fill the source with random bytes

    host_src_ptrs[i] = thrust::raw_pointer_cast(src_buffers[i].data());
    host_dst_ptrs[i] = thrust::raw_pointer_cast(dst_buffers[i].data());
    host_sizes[i]    = num_bytes;
  }

  // Move the pointer/size arrays to the device
  c2h::device_vector<byte_t*> dev_src_ptrs   = host_src_ptrs;
  c2h::device_vector<byte_t*> dev_dst_ptrs   = host_dst_ptrs;
  c2h::device_vector<size_type> dev_sizes    = host_sizes;

  using policy_hub_t = my_policy_hub<buffer_offset_t, block_offset_t>;
  using dispatch_t =
    cub::detail::DispatchBatchMemcpy<byte_t**, byte_t**, size_type*, block_offset_t, CopyAlg::Memcpy, policy_hub_t>;

  // First call: query the temporary storage requirement (null storage pointer)
  size_t temp_size = 0;
  dispatch_t::Dispatch(
    nullptr,
    temp_size,
    thrust::raw_pointer_cast(dev_src_ptrs.data()),
    thrust::raw_pointer_cast(dev_dst_ptrs.data()),
    thrust::raw_pointer_cast(dev_sizes.data()),
    static_cast<cuda::std::int64_t>(num_buffers),
    /* stream */ nullptr);
  c2h::device_vector<::cuda::std::uint8_t> temp_storage(temp_size, thrust::no_init);
  // Second call: perform the actual batched copy
  dispatch_t::Dispatch(
    thrust::raw_pointer_cast(temp_storage.data()),
    temp_size,
    thrust::raw_pointer_cast(dev_src_ptrs.data()),
    thrust::raw_pointer_cast(dev_dst_ptrs.data()),
    thrust::raw_pointer_cast(dev_sizes.data()),
    static_cast<cuda::std::int64_t>(num_buffers),
    /* stream */ nullptr);

  // Copy each pair back to the host and compare byte-for-byte
  for (size_t i = 0; i < num_buffers; ++i)
  {
    c2h::host_vector<byte_t> expected(src_buffers[i]);
    c2h::host_vector<byte_t> actual(dst_buffers[i]);
    REQUIRE(actual == expected);
  }
}