diff --git a/paddle/fluid/distributed/collective/deep_ep_xpu/deep_ep.cpp b/paddle/fluid/distributed/collective/deep_ep_xpu/deep_ep.cpp index 0f1b9c55db0f1c..da6383273dce59 100644 --- a/paddle/fluid/distributed/collective/deep_ep_xpu/deep_ep.cpp +++ b/paddle/fluid/distributed/collective/deep_ep_xpu/deep_ep.cpp @@ -240,34 +240,44 @@ Buffer::intranode_dispatch( const std::optional& num_tokens_per_rank, const deep_ep::detail::Tensor& is_token_in_rank, const std::optional& num_tokens_per_expert, - int cached_num_recv_tokens, - const std::optional& cached_rank_prefix_matrix, - const std::optional& cached_channel_prefix_matrix, + int cached_num_recv_tokens, // num_experts in cache mode + const std::optional& + cached_rank_prefix_matrix, // topk_idx in cache mode + const std::optional& + cached_channel_prefix_matrix, // topk_weights in cache mode int expert_alignment, const Config& config, std::optional& previous_event, // NOLINT bool async, bool allocate_on_comm_stream) { + int curr_num_experts; + std::optional curr_topk_idx; + std::optional curr_topk_weights; + if (topk_idx.has_value()) { - EP_HOST_ASSERT(topk_idx.has_value() && topk_weights.has_value() && - num_tokens_per_rank.has_value() && + EP_HOST_ASSERT(topk_weights.has_value() && num_tokens_per_expert.has_value()); - last_topk_idx = ConvertPaddleTensorToDetailTensor( + curr_topk_idx = ConvertPaddleTensorToDetailTensor( cast_ad_func(topk_idx->raw_tensor(), phi::DataType::INT32)); - last_topk_weights = ConvertPaddleTensorToDetailTensor( + curr_topk_weights = ConvertPaddleTensorToDetailTensor( assign_ad_func(topk_weights->raw_tensor())); - last_num_experts = static_cast(num_tokens_per_expert->size(0)); + curr_num_experts = static_cast(num_tokens_per_expert->size(0)); + } else { // cache mode - EP_HOST_ASSERT(last_topk_idx.has_value() && last_topk_weights.has_value() && - last_num_experts != 0); + EP_HOST_ASSERT(cached_rank_prefix_matrix.has_value() && + cached_channel_prefix_matrix.has_value()); + curr_topk_idx = 
cached_rank_prefix_matrix; + curr_topk_weights = cached_channel_prefix_matrix; + curr_num_experts = cached_num_recv_tokens; } + EP_HOST_ASSERT(curr_num_experts != 0); // Shape and contiguous checks EP_HOST_ASSERT(x.dim() == 2 && x.is_contiguous()); auto num_tokens = static_cast(x.size(0)); int hidden_size = static_cast(x.size(1)); - int num_topk = static_cast(last_topk_idx->size(1)); - auto num_local_experts = last_num_experts / num_ranks; + int num_topk = static_cast(curr_topk_idx->size(1)); + auto num_local_experts = curr_num_experts / num_ranks; int ret = 0; // For int8 dispatch, the corresponding combine would be bf16, @@ -307,17 +317,14 @@ Buffer::intranode_dispatch( {num_local_experts}, phi::DataType::INT32, x.place())); auto h_num_recv_tokens_per_expert_list = std::vector(num_local_experts, 0); - auto rank_prefix_matrix = - ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {num_ranks, num_ranks}, phi::DataType::INT32, x.place())); - auto channel_prefix_matrix = - ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {num_ranks, 12}, phi::DataType::INT32, x.place())); + auto rank_prefix_matrix = curr_topk_idx; + auto channel_prefix_matrix = curr_topk_weights; auto recv_channel_prefix_matrix = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( {num_ranks, 12}, phi::DataType::INT32, x.place())); - auto recv_src_idx = ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({10}, phi::DataType::INT32, x.place())); + auto recv_src_idx = + ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( + {curr_num_experts}, phi::DataType::INT32, x.place())); auto send_head = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( {num_tokens, num_ranks}, phi::DataType::INT32, x.place())); @@ -326,13 +333,13 @@ Buffer::intranode_dispatch( bkcl_notify_dispatch_standard_with_num_recv_tokens_per_expert_list_cpu( comm_ctx->GetBKCLComm(), x.data_ptr(), - last_topk_idx->data_ptr(), - last_topk_weights->data_ptr(), + 
curr_topk_idx->data_ptr(), + curr_topk_weights->data_ptr(), num_scales, hidden_size, num_tokens, num_topk, - last_num_experts, + curr_num_experts, d_num_recv_tokens_per_expert_list .data_ptr(), // should not be nullptr h_num_recv_tokens_per_expert_list.data(), @@ -349,13 +356,13 @@ Buffer::intranode_dispatch( std::optional recv_topk_idx = ConvertPaddleTensorToDetailTensor( paddle::experimental::empty({num_recv_tokens, num_topk}, - last_topk_idx->dtype(), - last_topk_idx->place())); + curr_topk_idx->dtype(), + curr_topk_idx->place())); std::optional recv_topk_weights = ConvertPaddleTensorToDetailTensor( paddle::experimental::empty({num_recv_tokens, num_topk}, - last_topk_weights->dtype(), - last_topk_weights->place())); + curr_topk_weights->dtype(), + curr_topk_weights->place())); auto recv_x_scales = std::optional(); float* x_scales_ptr = nullptr; @@ -372,18 +379,18 @@ Buffer::intranode_dispatch( VLOG(3) << "DeepEP intranode_dispatch num_local_experts " << num_local_experts << " num_scales " << num_scales << " hidden_size " << hidden_size - << " num_tokens " << num_tokens << " last_num_experts " - << last_num_experts << " num_recv_tokens " << num_recv_tokens; + << " num_tokens " << num_tokens << " curr_num_experts " + << curr_num_experts << " num_recv_tokens " << num_recv_tokens; VLOG(3) << "DeepEP intranode_dispatch x dim " << x.dim() - << " last_topk_idx dim " << last_topk_idx->dim() - << " last_topk_weights dim " << last_topk_weights->dim(); + << " curr_topk_idx dim " << curr_topk_idx->dim() + << " curr_topk_weights dim " << curr_topk_weights->dim(); ret = bkcl_normal_dispatch_standard( comm_ctx->GetBKCLComm(), x.data_ptr(), // sendbuf x_scales_ptr, - last_topk_idx->data_ptr(), - last_topk_weights->data_ptr(), + curr_topk_idx->data_ptr(), + curr_topk_weights->data_ptr(), recv_x.data_ptr(), recv_x_scales_ptr, recv_topk_idx->data_ptr(), @@ -393,7 +400,7 @@ Buffer::intranode_dispatch( hidden_size, num_tokens, num_topk, - last_num_experts, + curr_num_experts, 
ToBKCLDataType(x.dtype()), use_int8, async ? reinterpret_cast(comm_stream) @@ -433,10 +440,10 @@ Buffer::intranode_dispatch( recv_topk_idx, recv_topk_weights, h_num_recv_tokens_per_expert_list, - rank_prefix_matrix, - channel_prefix_matrix, + rank_prefix_matrix.value(), // topk_idx in cache mode + channel_prefix_matrix.value(), // topk_weights in cache mode recv_channel_prefix_matrix, - recv_src_idx, + recv_src_idx, // num_experts in cache mode send_head, event}; } @@ -577,41 +584,48 @@ Buffer::internode_dispatch( const std::optional& num_tokens_per_rdma_rank, const deep_ep::detail::Tensor& is_token_in_rank, const std::optional& num_tokens_per_expert, - int cached_num_recv_tokens, + int cached_num_recv_tokens, // num_experts in cache mode int cached_num_rdma_recv_tokens, const std::optional& cached_rdma_channel_prefix_matrix, const std::optional& - cached_recv_rdma_rank_prefix_sum, + cached_recv_rdma_rank_prefix_sum, // topk_weights in cache mode const std::optional& cached_gbl_channel_prefix_matrix, const std::optional& - cached_recv_gbl_rank_prefix_sum, + cached_recv_gbl_rank_prefix_sum, // topk_idx in cache mode int expert_alignment, const Config& config, std::optional& previous_event, // NOLINT bool async, bool allocate_on_comm_stream) { + int curr_num_experts; + std::optional curr_topk_idx; + std::optional curr_topk_weights; + if (topk_idx.has_value()) { - EP_HOST_ASSERT(topk_idx.has_value() && topk_weights.has_value() && - num_tokens_per_rank.has_value() && + EP_HOST_ASSERT(topk_weights.has_value() && num_tokens_per_expert.has_value()); - last_topk_idx = ConvertPaddleTensorToDetailTensor( + curr_topk_idx = ConvertPaddleTensorToDetailTensor( cast_ad_func(topk_idx->raw_tensor(), phi::DataType::INT32)); - last_topk_weights = ConvertPaddleTensorToDetailTensor( + curr_topk_weights = ConvertPaddleTensorToDetailTensor( assign_ad_func(topk_weights->raw_tensor())); - last_num_experts = static_cast(num_tokens_per_expert->size(0)); + curr_num_experts = 
static_cast(num_tokens_per_expert->size(0)); } else { // cache mode - EP_HOST_ASSERT(last_topk_idx.has_value() && last_topk_weights.has_value() && - last_num_experts != 0); + EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum.has_value() && + cached_recv_rdma_rank_prefix_sum.has_value()); + curr_topk_idx = cached_recv_gbl_rank_prefix_sum; + curr_topk_weights = cached_recv_rdma_rank_prefix_sum; + curr_num_experts = cached_num_recv_tokens; } + EP_HOST_ASSERT(curr_num_experts != 0); // Shape and contiguous checks EP_HOST_ASSERT(x.dim() == 2 && x.is_contiguous()); auto num_tokens = static_cast(x.size(0)); int hidden_size = static_cast(x.size(1)); - int num_topk = static_cast(last_topk_idx->size(1)); - auto num_local_experts = last_num_experts / num_ranks; + int num_topk = static_cast(curr_topk_idx->size(1)); + auto num_local_experts = curr_num_experts / num_ranks; int ret = 0; // For int8 dispatch, the corresponding combine would be bf16, @@ -657,22 +671,20 @@ Buffer::internode_dispatch( paddle::experimental::empty({10}, phi::DataType::INT32, x.place())); auto gbl_channel_prefix_matrix = ConvertPaddleTensorToDetailTensor( paddle::experimental::empty({10}, phi::DataType::INT32, x.place())); - auto recv_rdma_rank_prefix_sum = ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({10}, phi::DataType::INT32, x.place())); - auto recv_gbl_rank_prefix_sum = ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({10}, phi::DataType::INT32, x.place())); + auto recv_rdma_rank_prefix_sum = curr_topk_weights; + auto recv_gbl_rank_prefix_sum = curr_topk_idx; int num_recv_tokens = bkcl_notify_dispatch_standard_with_num_recv_tokens_per_expert_list_cpu( comm_ctx->GetBKCLComm(), x.data_ptr(), // x - last_topk_idx->data_ptr(), - last_topk_weights->data_ptr(), // topk_weight + curr_topk_idx->data_ptr(), + curr_topk_weights->data_ptr(), // topk_weight num_scales, hidden_size, num_tokens, num_topk, - last_num_experts, + curr_num_experts, d_num_recv_tokens_per_expert_list 
.data_ptr(), // should not be nullptr h_num_recv_tokens_per_expert_list.data(), @@ -686,32 +698,32 @@ Buffer::internode_dispatch( std::optional recv_rdma_channel_prefix_matrix = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {1, 1}, phi::DataType::INT32, last_topk_idx->place())); + {1, 1}, phi::DataType::INT32, curr_topk_idx->place())); std::optional recv_gbl_channel_prefix_matrix = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {1, 1}, phi::DataType::INT32, last_topk_idx->place())); + {1, 1}, phi::DataType::INT32, curr_topk_idx->place())); std::optional recv_src_meta = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {num_recv_tokens, 1}, phi::DataType::INT32, last_topk_idx->place())); + {curr_num_experts, 1}, phi::DataType::INT32, curr_topk_idx->place())); std::optional send_rdma_head = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {1, 1}, phi::DataType::INT32, last_topk_idx->place())); + {1, 1}, phi::DataType::INT32, curr_topk_idx->place())); std::optional send_nvl_head = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {1, 1}, phi::DataType::INT32, last_topk_idx->place())); + {1, 1}, phi::DataType::INT32, curr_topk_idx->place())); auto recv_x = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( {num_recv_tokens, hidden_size}, x.dtype(), x.place())); std::optional recv_topk_idx = ConvertPaddleTensorToDetailTensor( paddle::experimental::empty({num_recv_tokens, num_topk}, - last_topk_idx->dtype(), - last_topk_idx->place())); + curr_topk_idx->dtype(), + curr_topk_idx->place())); std::optional recv_topk_weights = ConvertPaddleTensorToDetailTensor( paddle::experimental::empty({num_recv_tokens, num_topk}, - last_topk_weights->dtype(), - last_topk_weights->place())); + curr_topk_weights->dtype(), + curr_topk_weights->place())); auto recv_x_scales = std::optional(); float* x_scales_ptr = nullptr; @@ -728,15 +740,15 @@ Buffer::internode_dispatch( VLOG(3) << "DeepEP 
internode_dispatch num_local_experts " << num_local_experts << " num_scales " << num_scales << " hidden_size " << hidden_size - << " num_tokens " << num_tokens << " last_num_experts " - << last_num_experts << " num_recv_tokens " << num_recv_tokens; + << " num_tokens " << num_tokens << " curr_num_experts " + << curr_num_experts << " num_recv_tokens " << num_recv_tokens; ret = bkcl_normal_dispatch_standard( comm_ctx->GetBKCLComm(), x.data_ptr(), // sendbuf x_scales_ptr, - last_topk_idx->data_ptr(), - last_topk_weights->data_ptr(), + curr_topk_idx->data_ptr(), + curr_topk_weights->data_ptr(), recv_x.data_ptr(), recv_x_scales_ptr, recv_topk_idx->data_ptr(), @@ -746,7 +758,7 @@ Buffer::internode_dispatch( hidden_size, num_tokens, num_topk, - last_num_experts, + curr_num_experts, ToBKCLDataType(x.dtype()), use_int8, async ? reinterpret_cast(comm_stream) @@ -789,10 +801,10 @@ Buffer::internode_dispatch( rdma_channel_prefix_matrix, gbl_channel_prefix_matrix, recv_rdma_channel_prefix_matrix, - recv_rdma_rank_prefix_sum, + recv_rdma_rank_prefix_sum.value(), // topk_weights in cache mode recv_gbl_channel_prefix_matrix, - recv_gbl_rank_prefix_sum, - recv_src_meta, + recv_gbl_rank_prefix_sum.value(), // topk_idx in cache mode + recv_src_meta, // num_experts in cache mode send_rdma_head, send_nvl_head, event}; diff --git a/paddle/fluid/distributed/collective/deep_ep_xpu/deep_ep.hpp b/paddle/fluid/distributed/collective/deep_ep_xpu/deep_ep.hpp index e310544956a794..9d13ad58f73f74 100644 --- a/paddle/fluid/distributed/collective/deep_ep_xpu/deep_ep.hpp +++ b/paddle/fluid/distributed/collective/deep_ep_xpu/deep_ep.hpp @@ -105,10 +105,6 @@ struct Buffer { volatile int* moe_recv_rdma_counter = nullptr; int* moe_recv_rdma_counter_mapped = nullptr; - std::optional last_topk_idx = std::nullopt; - std::optional last_topk_weights = std::nullopt; - int last_num_experts = 0; - public: Buffer(int rank, int num_ranks,