|
| 1 | +/* |
| 2 | + * SPDX-FileCopyrightText: Copyright (c) 2025 DeepSeek |
| 3 | + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 4 | + * |
| 5 | + * This file incorporates material from the DeepSeek project, licensed under the MIT License. |
| 6 | + * The modifications made by NVIDIA are licensed under the Apache License, Version 2.0. |
| 7 | + * |
| 8 | + * SPDX-License-Identifier: MIT AND Apache-2.0 |
| 9 | + * |
| 10 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 11 | + * you may not use this file except in compliance with the License. |
| 12 | + * You may obtain a copy of the License at |
| 13 | + * |
| 14 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 15 | + * |
| 16 | + * Unless required by applicable law or agreed to in writing, software |
| 17 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 18 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 19 | + * See the License for the specific language governing permissions and |
| 20 | + * limitations under the License. |
| 21 | + */ |
| 22 | + |
| 23 | +#pragma once |
| 24 | + |
| 25 | +#include "kernels/api.cuh" |
| 26 | +#include "kernels/exception.cuh" |
| 27 | + |
| 28 | +namespace nixl_ep { |
| 29 | + |
// Integer division of `a` by `b`, rounded up by biasing the numerator with `b - 1`.
// NOTES: the rounding-up reading holds for non-negative `a` and positive `b`; `a + b - 1` is not guarded against overflow
template <typename dtype_t>
dtype_t ceil_div(dtype_t a, dtype_t b) {
    // Keep the intermediate in its promoted type so the division sees the same operands as a single expression would
    const auto biased_numerator = a + b - 1;
    return biased_numerator / b;
}
| 34 | + |
| 35 | +template <typename dtype_t> |
| 36 | +dtype_t align(dtype_t a, dtype_t b) { |
| 37 | + return ceil_div<dtype_t>(a, b) * b; |
| 38 | +} |
| 39 | + |
| 40 | +struct EPBuffer { |
| 41 | + int num_clean_int = 0; |
| 42 | + |
| 43 | + void* dispatch_rdma_send_buffer = nullptr; |
| 44 | + void* dispatch_rdma_recv_data_buffer = nullptr; |
| 45 | + int* dispatch_rdma_recv_count_buffer = nullptr; |
| 46 | + |
| 47 | + void* combine_rdma_send_buffer = nullptr; |
| 48 | + void* combine_rdma_recv_data_buffer = nullptr; |
| 49 | + int* combine_rdma_recv_flag_buffer = nullptr; |
| 50 | + |
| 51 | + void* combine_rdma_send_buffer_data_start = nullptr; |
| 52 | + size_t num_bytes_per_combine_msg = 0; |
| 53 | + |
| 54 | + std::pair<int*, int> clean_meta() { |
| 55 | + EP_HOST_ASSERT(dispatch_rdma_recv_count_buffer == combine_rdma_recv_flag_buffer); |
| 56 | + return {dispatch_rdma_recv_count_buffer, num_clean_int}; |
| 57 | + } |
| 58 | +}; |
| 59 | + |
| 60 | +struct EPLayout { |
| 61 | + size_t total_bytes = 0; |
| 62 | + EPBuffer buffers[2]; |
| 63 | + |
| 64 | + template <typename out_ptr_t = void*, typename count_ptr_t = uint8_t*, typename in_ptr_t = void*> |
| 65 | + out_ptr_t advance(const in_ptr_t& ptr, size_t count) { |
| 66 | + return reinterpret_cast<out_ptr_t>(reinterpret_cast<count_ptr_t>(ptr) + count); |
| 67 | + } |
| 68 | + |
| 69 | + EPLayout(void* rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) { |
| 70 | + const int num_scales = hidden / 128; |
| 71 | + |
| 72 | + // Dispatch and combine layout: |
| 73 | + // - 2 symmetric odd/even send buffer |
| 74 | + // - 2 symmetric odd/even receive buffers |
| 75 | + // - 2 symmetric odd/even signaling buffers |
| 76 | + |
| 77 | + // Message sizes |
| 78 | + // NOTES: you should add a control `int4` for combine messages if you want to do data transformation |
| 79 | + // NOTES: `num_scales * sizeof(nv_bfloat162)` means the per-128-channel min/max |
| 80 | + EP_HOST_ASSERT(num_scales * sizeof(float) <= hidden); |
| 81 | + size_t num_bytes_per_dispatch_msg = sizeof(int4) + std::max(hidden * sizeof(nv_bfloat16), hidden + num_scales * sizeof(float)); |
| 82 | + size_t num_bytes_per_combine_msg = num_scales * sizeof(nv_bfloat162) + hidden * sizeof(nv_bfloat16); |
| 83 | + |
| 84 | + // Send buffer |
| 85 | + size_t dispatch_send_buffer_bytes = num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg; |
| 86 | + size_t combine_send_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg; |
| 87 | + size_t send_buffer_bytes = std::max(dispatch_send_buffer_bytes, combine_send_buffer_bytes); |
| 88 | + EP_HOST_ASSERT(send_buffer_bytes % sizeof(int4) == 0); |
| 89 | + total_bytes += send_buffer_bytes * 2; |
| 90 | + |
| 91 | + // Symmetric receive buffers |
| 92 | + // TODO: optimize memory usages |
| 93 | + size_t dispatch_recv_data_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg; |
| 94 | + size_t combine_recv_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg; |
| 95 | + size_t recv_buffer_bytes = std::max(dispatch_recv_data_buffer_bytes, combine_recv_buffer_bytes); |
| 96 | + EP_HOST_ASSERT(recv_buffer_bytes % sizeof(int4) == 0); |
| 97 | + total_bytes += recv_buffer_bytes * 2; |
| 98 | + |
| 99 | + // Symmetric signaling buffers |
| 100 | + size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int); |
| 101 | + size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes; |
| 102 | + size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes); |
| 103 | + size_t signaling_buffer_bytes_aligned = align<size_t>(signaling_buffer_bytes, 128); |
| 104 | + total_bytes += signaling_buffer_bytes_aligned * 2; |
| 105 | + |
| 106 | + // Assign pointers |
| 107 | + // NOTES: we still leave some space for distinguishing dispatch/combine buffer, |
| 108 | + // so you may see some parameters are duplicated |
| 109 | + for (int i = 0; i < 2; ++ i) { |
| 110 | + buffers[i] = { |
| 111 | + static_cast<int>(signaling_buffer_bytes / sizeof(int)), |
| 112 | + advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i), |
| 113 | + advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 + recv_buffer_bytes * i), |
| 114 | + advance<int*>(rdma_buffer, signaling_buffer_bytes_aligned * i), |
| 115 | + advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i), |
| 116 | + advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 + recv_buffer_bytes * i), |
| 117 | + advance<int*>(rdma_buffer, signaling_buffer_bytes_aligned * i), |
| 118 | + advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i), |
| 119 | + num_bytes_per_combine_msg |
| 120 | + }; |
| 121 | + } |
| 122 | + } |
| 123 | +}; |
| 124 | + |
| 125 | +size_t get_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) { |
| 126 | + auto num_bytes = EPLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts).total_bytes; |
| 127 | + return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) * NUM_BUFFER_ALIGNMENT_BYTES; |
| 128 | +} |
| 129 | + |
| 130 | +} // namespace nixl_ep |
0 commit comments