paddle/fluid/distributed/collective/deep_ep_xpu/deep_ep.cpp (85 additions, 73 deletions)
@@ -240,34 +240,44 @@ Buffer::intranode_dispatch(
     const std::optional<deep_ep::detail::Tensor>& num_tokens_per_rank,
     const deep_ep::detail::Tensor& is_token_in_rank,
     const std::optional<deep_ep::detail::Tensor>& num_tokens_per_expert,
-    int cached_num_recv_tokens,
-    const std::optional<deep_ep::detail::Tensor>& cached_rank_prefix_matrix,
-    const std::optional<deep_ep::detail::Tensor>& cached_channel_prefix_matrix,
+    int cached_num_recv_tokens,  // num_experts in cache mode
+    const std::optional<deep_ep::detail::Tensor>&
+        cached_rank_prefix_matrix,  // topk_idx in cache mode
+    const std::optional<deep_ep::detail::Tensor>&
+        cached_channel_prefix_matrix,  // topk_weights in cache mode
     int expert_alignment,
     const Config& config,
     std::optional<EventHandle>& previous_event,  // NOLINT
     bool async,
     bool allocate_on_comm_stream) {
+  int curr_num_experts;
+  std::optional<deep_ep::detail::Tensor> curr_topk_idx;
+  std::optional<deep_ep::detail::Tensor> curr_topk_weights;
+
   if (topk_idx.has_value()) {
-    EP_HOST_ASSERT(topk_idx.has_value() && topk_weights.has_value() &&
-                   num_tokens_per_rank.has_value() &&
+    EP_HOST_ASSERT(topk_weights.has_value() &&
                    num_tokens_per_expert.has_value());
-    last_topk_idx = ConvertPaddleTensorToDetailTensor(
+    curr_topk_idx = ConvertPaddleTensorToDetailTensor(
         cast_ad_func(topk_idx->raw_tensor(), phi::DataType::INT32));
-    last_topk_weights = ConvertPaddleTensorToDetailTensor(
+    curr_topk_weights = ConvertPaddleTensorToDetailTensor(
         assign_ad_func(topk_weights->raw_tensor()));
-    last_num_experts = static_cast<int>(num_tokens_per_expert->size(0));
+    curr_num_experts = static_cast<int>(num_tokens_per_expert->size(0));

   } else {  // cache mode
-    EP_HOST_ASSERT(last_topk_idx.has_value() && last_topk_weights.has_value() &&
-                   last_num_experts != 0);
+    EP_HOST_ASSERT(cached_rank_prefix_matrix.has_value() &&
+                   cached_channel_prefix_matrix.has_value());
+    curr_topk_idx = cached_rank_prefix_matrix;
+    curr_topk_weights = cached_channel_prefix_matrix;
+    curr_num_experts = cached_num_recv_tokens;
   }
+  EP_HOST_ASSERT(curr_num_experts != 0);

   // Shape and contiguous checks
   EP_HOST_ASSERT(x.dim() == 2 && x.is_contiguous());
   auto num_tokens = static_cast<int>(x.size(0));
   int hidden_size = static_cast<int>(x.size(1));
-  int num_topk = static_cast<int>(last_topk_idx->size(1));
-  auto num_local_experts = last_num_experts / num_ranks;
+  int num_topk = static_cast<int>(curr_topk_idx->size(1));
+  auto num_local_experts = curr_num_experts / num_ranks;
   int ret = 0;

   // For int8 dispatch, the corresponding combine would be bf16,
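Note on the pattern above: the persistent `last_topk_idx` / `last_topk_weights` / `last_num_experts` members are replaced by per-call `curr_*` locals, and in cache mode the `cached_rank_prefix_matrix` / `cached_channel_prefix_matrix` parameter slots are repurposed to carry `topk_idx` / `topk_weights`, with `cached_num_recv_tokens` carrying `num_experts`. A minimal standalone sketch of this resolution logic (placeholder `Tensor` type; not the actual Paddle API):

```cpp
#include <cassert>
#include <optional>

// Placeholder for deep_ep::detail::Tensor; only the expert-count
// dimension matters for this sketch.
struct Tensor {
  int rows = 0;
};

struct Resolved {
  Tensor topk_idx;
  Tensor topk_weights;
  int num_experts = 0;
};

// Mirrors the branch above: normal mode takes fresh topk tensors; cache
// mode reads them out of the repurposed cached_* slots.
Resolved ResolveDispatchInputs(
    const std::optional<Tensor>& topk_idx,
    const std::optional<Tensor>& topk_weights,
    const std::optional<Tensor>& num_tokens_per_expert,
    int cached_num_recv_tokens,                  // num_experts in cache mode
    const std::optional<Tensor>& cached_idx,     // topk_idx in cache mode
    const std::optional<Tensor>& cached_weights  // topk_weights in cache mode
) {
  Resolved r;
  if (topk_idx.has_value()) {  // normal mode
    assert(topk_weights.has_value() && num_tokens_per_expert.has_value());
    r.topk_idx = *topk_idx;
    r.topk_weights = *topk_weights;
    r.num_experts = num_tokens_per_expert->rows;
  } else {  // cache mode
    assert(cached_idx.has_value() && cached_weights.has_value());
    r.topk_idx = *cached_idx;
    r.topk_weights = *cached_weights;
    r.num_experts = cached_num_recv_tokens;
  }
  assert(r.num_experts != 0);
  return r;
}

int main() {
  // Normal mode: num_experts comes from num_tokens_per_expert's size.
  Resolved a = ResolveDispatchInputs(
      Tensor{8}, Tensor{8}, Tensor{64}, 0, std::nullopt, std::nullopt);
  assert(a.num_experts == 64);
  // Cache mode: num_experts rides in cached_num_recv_tokens.
  Resolved b = ResolveDispatchInputs(
      std::nullopt, std::nullopt, std::nullopt, 64, Tensor{8}, Tensor{8});
  assert(b.num_experts == 64);
}
```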
@@ -307,17 +317,14 @@ Buffer::intranode_dispatch(
           {num_local_experts}, phi::DataType::INT32, x.place()));
   auto h_num_recv_tokens_per_expert_list =
       std::vector<int>(num_local_experts, 0);
-  auto rank_prefix_matrix =
-      ConvertPaddleTensorToDetailTensor(paddle::experimental::empty(
-          {num_ranks, num_ranks}, phi::DataType::INT32, x.place()));
-  auto channel_prefix_matrix =
-      ConvertPaddleTensorToDetailTensor(paddle::experimental::empty(
-          {num_ranks, 12}, phi::DataType::INT32, x.place()));
+  auto rank_prefix_matrix = curr_topk_idx;
+  auto channel_prefix_matrix = curr_topk_weights;
   auto recv_channel_prefix_matrix =
       ConvertPaddleTensorToDetailTensor(paddle::experimental::empty(
           {num_ranks, 12}, phi::DataType::INT32, x.place()));
-  auto recv_src_idx = ConvertPaddleTensorToDetailTensor(
-      paddle::experimental::empty({10}, phi::DataType::INT32, x.place()));
+  auto recv_src_idx =
+      ConvertPaddleTensorToDetailTensor(paddle::experimental::empty(
+          {curr_num_experts}, phi::DataType::INT32, x.place()));
   auto send_head =
       ConvertPaddleTensorToDetailTensor(paddle::experimental::empty(
           {num_tokens, num_ranks}, phi::DataType::INT32, x.place()));
@@ -326,13 +333,13 @@
   bkcl_notify_dispatch_standard_with_num_recv_tokens_per_expert_list_cpu(
       comm_ctx->GetBKCLComm(),
       x.data_ptr(),
-      last_topk_idx->data_ptr<int>(),
-      last_topk_weights->data_ptr<float>(),
+      curr_topk_idx->data_ptr<int>(),
+      curr_topk_weights->data_ptr<float>(),
       num_scales,
       hidden_size,
       num_tokens,
       num_topk,
-      last_num_experts,
+      curr_num_experts,
       d_num_recv_tokens_per_expert_list
           .data_ptr<int>(),  // should not be nullptr
       h_num_recv_tokens_per_expert_list.data(),
@@ -349,13 +356,13 @@
   std::optional<deep_ep::detail::Tensor> recv_topk_idx =
       ConvertPaddleTensorToDetailTensor(
           paddle::experimental::empty({num_recv_tokens, num_topk},
-                                      last_topk_idx->dtype(),
-                                      last_topk_idx->place()));
+                                      curr_topk_idx->dtype(),
+                                      curr_topk_idx->place()));
   std::optional<deep_ep::detail::Tensor> recv_topk_weights =
       ConvertPaddleTensorToDetailTensor(
           paddle::experimental::empty({num_recv_tokens, num_topk},
-                                      last_topk_weights->dtype(),
-                                      last_topk_weights->place()));
+                                      curr_topk_weights->dtype(),
+                                      curr_topk_weights->place()));

   auto recv_x_scales = std::optional<deep_ep::detail::Tensor>();
   float* x_scales_ptr = nullptr;
@@ -372,18 +379,18 @@

   VLOG(3) << "DeepEP intranode_dispatch num_local_experts " << num_local_experts
           << " num_scales " << num_scales << " hidden_size " << hidden_size
-          << " num_tokens " << num_tokens << " last_num_experts "
-          << last_num_experts << " num_recv_tokens " << num_recv_tokens;
+          << " num_tokens " << num_tokens << " curr_num_experts "
+          << curr_num_experts << " num_recv_tokens " << num_recv_tokens;
   VLOG(3) << "DeepEP intranode_dispatch x dim " << x.dim()
-          << " last_topk_idx dim " << last_topk_idx->dim()
-          << " last_topk_weights dim " << last_topk_weights->dim();
+          << " curr_topk_idx dim " << curr_topk_idx->dim()
+          << " curr_topk_weights dim " << curr_topk_weights->dim();

   ret = bkcl_normal_dispatch_standard(
       comm_ctx->GetBKCLComm(),
       x.data_ptr(),  // sendbuf
       x_scales_ptr,
-      last_topk_idx->data_ptr<int>(),
-      last_topk_weights->data_ptr<float>(),
+      curr_topk_idx->data_ptr<int>(),
+      curr_topk_weights->data_ptr<float>(),
       recv_x.data_ptr(),
       recv_x_scales_ptr,
       recv_topk_idx->data_ptr<int>(),
@@ -393,7 +400,7 @@
       hidden_size,
       num_tokens,
       num_topk,
-      last_num_experts,
+      curr_num_experts,
       ToBKCLDataType(x.dtype()),
       use_int8,
       async ? reinterpret_cast<XPUStream>(comm_stream)
@@ -433,10 +440,10 @@
           recv_topk_idx,
           recv_topk_weights,
           h_num_recv_tokens_per_expert_list,
-          rank_prefix_matrix,
-          channel_prefix_matrix,
+          rank_prefix_matrix.value(),     // topk_idx in cache mode
+          channel_prefix_matrix.value(),  // topk_weights in cache mode
           recv_channel_prefix_matrix,
-          recv_src_idx,
+          recv_src_idx,  // num_experts in cache mode
           send_head,
           event};
 }
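A small C++ point about the new `.value()` calls in the returned tuple: `std::optional::value()` throws `std::bad_optional_access` when the optional is empty. Both branches above guarantee the optionals are engaged (the cache-mode `EP_HOST_ASSERT` and the normal-mode assignments), so the calls are safe; a minimal illustration:

```cpp
#include <iostream>
#include <optional>

int main() {
  std::optional<int> engaged = 42;
  std::optional<int> empty;
  std::cout << engaged.value() << "\n";  // fine: prints 42
  try {
    (void)empty.value();  // value() on an empty optional throws
  } catch (const std::bad_optional_access& e) {
    std::cout << "caught: " << e.what() << "\n";
  }
}
```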
@@ -577,41 +584,48 @@ Buffer::internode_dispatch(
     const std::optional<deep_ep::detail::Tensor>& num_tokens_per_rdma_rank,
     const deep_ep::detail::Tensor& is_token_in_rank,
     const std::optional<deep_ep::detail::Tensor>& num_tokens_per_expert,
-    int cached_num_recv_tokens,
+    int cached_num_recv_tokens,  // num_experts in cache mode
     int cached_num_rdma_recv_tokens,
     const std::optional<deep_ep::detail::Tensor>&
         cached_rdma_channel_prefix_matrix,
     const std::optional<deep_ep::detail::Tensor>&
-        cached_recv_rdma_rank_prefix_sum,
+        cached_recv_rdma_rank_prefix_sum,  // topk_weights in cache mode
     const std::optional<deep_ep::detail::Tensor>&
         cached_gbl_channel_prefix_matrix,
     const std::optional<deep_ep::detail::Tensor>&
-        cached_recv_gbl_rank_prefix_sum,
+        cached_recv_gbl_rank_prefix_sum,  // topk_idx in cache mode
     int expert_alignment,
     const Config& config,
     std::optional<EventHandle>& previous_event,  // NOLINT
     bool async,
     bool allocate_on_comm_stream) {
+  int curr_num_experts;
+  std::optional<deep_ep::detail::Tensor> curr_topk_idx;
+  std::optional<deep_ep::detail::Tensor> curr_topk_weights;
+
   if (topk_idx.has_value()) {
-    EP_HOST_ASSERT(topk_idx.has_value() && topk_weights.has_value() &&
-                   num_tokens_per_rank.has_value() &&
+    EP_HOST_ASSERT(topk_weights.has_value() &&
                    num_tokens_per_expert.has_value());
-    last_topk_idx = ConvertPaddleTensorToDetailTensor(
+    curr_topk_idx = ConvertPaddleTensorToDetailTensor(
        cast_ad_func(topk_idx->raw_tensor(), phi::DataType::INT32));
-    last_topk_weights = ConvertPaddleTensorToDetailTensor(
+    curr_topk_weights = ConvertPaddleTensorToDetailTensor(
        assign_ad_func(topk_weights->raw_tensor()));
-    last_num_experts = static_cast<int>(num_tokens_per_expert->size(0));
+    curr_num_experts = static_cast<int>(num_tokens_per_expert->size(0));
   } else {  // cache mode
-    EP_HOST_ASSERT(last_topk_idx.has_value() && last_topk_weights.has_value() &&
-                   last_num_experts != 0);
+    EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum.has_value() &&
+                   cached_recv_rdma_rank_prefix_sum.has_value());
+    curr_topk_idx = cached_recv_gbl_rank_prefix_sum;
+    curr_topk_weights = cached_recv_rdma_rank_prefix_sum;
+    curr_num_experts = cached_num_recv_tokens;
   }
+  EP_HOST_ASSERT(curr_num_experts != 0);

   // Shape and contiguous checks
   EP_HOST_ASSERT(x.dim() == 2 && x.is_contiguous());
   auto num_tokens = static_cast<int>(x.size(0));
   int hidden_size = static_cast<int>(x.size(1));
-  int num_topk = static_cast<int>(last_topk_idx->size(1));
-  auto num_local_experts = last_num_experts / num_ranks;
+  int num_topk = static_cast<int>(curr_topk_idx->size(1));
+  auto num_local_experts = curr_num_experts / num_ranks;
   int ret = 0;

   // For int8 dispatch, the corresponding combine would be bf16,
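Both dispatch paths compute `num_local_experts = curr_num_experts / num_ranks` with integer division, and in cache mode `curr_num_experts` now comes straight from the `cached_num_recv_tokens` argument. A hypothetical caller-side guard (not in this patch) that would catch a non-divisible value before it silently truncates:

```cpp
#include <cassert>

// Hypothetical helper, not part of the patch: mirrors the division above
// and rejects expert counts that are not a multiple of the rank count.
inline int LocalExpertCount(int num_experts, int num_ranks) {
  assert(num_ranks > 0);
  assert(num_experts % num_ranks == 0 &&
         "num_experts must be divisible by num_ranks");
  return num_experts / num_ranks;
}
```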
@@ -657,22 +671,20 @@ Buffer::internode_dispatch(
       paddle::experimental::empty({10}, phi::DataType::INT32, x.place()));
   auto gbl_channel_prefix_matrix = ConvertPaddleTensorToDetailTensor(
       paddle::experimental::empty({10}, phi::DataType::INT32, x.place()));
-  auto recv_rdma_rank_prefix_sum = ConvertPaddleTensorToDetailTensor(
-      paddle::experimental::empty({10}, phi::DataType::INT32, x.place()));
-  auto recv_gbl_rank_prefix_sum = ConvertPaddleTensorToDetailTensor(
-      paddle::experimental::empty({10}, phi::DataType::INT32, x.place()));
+  auto recv_rdma_rank_prefix_sum = curr_topk_weights;
+  auto recv_gbl_rank_prefix_sum = curr_topk_idx;

   int num_recv_tokens =
       bkcl_notify_dispatch_standard_with_num_recv_tokens_per_expert_list_cpu(
           comm_ctx->GetBKCLComm(),
           x.data_ptr(),  // x
-          last_topk_idx->data_ptr<int>(),
-          last_topk_weights->data_ptr<float>(),  // topk_weight
+          curr_topk_idx->data_ptr<int>(),
+          curr_topk_weights->data_ptr<float>(),  // topk_weight
           num_scales,
           hidden_size,
           num_tokens,
           num_topk,
-          last_num_experts,
+          curr_num_experts,
           d_num_recv_tokens_per_expert_list
               .data_ptr<int>(),  // should not be nullptr
           h_num_recv_tokens_per_expert_list.data(),
@@ -686,32 +698,32 @@

   std::optional<deep_ep::detail::Tensor> recv_rdma_channel_prefix_matrix =
       ConvertPaddleTensorToDetailTensor(paddle::experimental::empty(
-          {1, 1}, phi::DataType::INT32, last_topk_idx->place()));
+          {1, 1}, phi::DataType::INT32, curr_topk_idx->place()));
   std::optional<deep_ep::detail::Tensor> recv_gbl_channel_prefix_matrix =
       ConvertPaddleTensorToDetailTensor(paddle::experimental::empty(
-          {1, 1}, phi::DataType::INT32, last_topk_idx->place()));
+          {1, 1}, phi::DataType::INT32, curr_topk_idx->place()));
   std::optional<deep_ep::detail::Tensor> recv_src_meta =
       ConvertPaddleTensorToDetailTensor(paddle::experimental::empty(
-          {num_recv_tokens, 1}, phi::DataType::INT32, last_topk_idx->place()));
+          {curr_num_experts, 1}, phi::DataType::INT32, curr_topk_idx->place()));
   std::optional<deep_ep::detail::Tensor> send_rdma_head =
       ConvertPaddleTensorToDetailTensor(paddle::experimental::empty(
-          {1, 1}, phi::DataType::INT32, last_topk_idx->place()));
+          {1, 1}, phi::DataType::INT32, curr_topk_idx->place()));
   std::optional<deep_ep::detail::Tensor> send_nvl_head =
       ConvertPaddleTensorToDetailTensor(paddle::experimental::empty(
-          {1, 1}, phi::DataType::INT32, last_topk_idx->place()));
+          {1, 1}, phi::DataType::INT32, curr_topk_idx->place()));

   auto recv_x = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty(
       {num_recv_tokens, hidden_size}, x.dtype(), x.place()));
   std::optional<deep_ep::detail::Tensor> recv_topk_idx =
       ConvertPaddleTensorToDetailTensor(
           paddle::experimental::empty({num_recv_tokens, num_topk},
-                                      last_topk_idx->dtype(),
-                                      last_topk_idx->place()));
+                                      curr_topk_idx->dtype(),
+                                      curr_topk_idx->place()));
   std::optional<deep_ep::detail::Tensor> recv_topk_weights =
       ConvertPaddleTensorToDetailTensor(
           paddle::experimental::empty({num_recv_tokens, num_topk},
-                                      last_topk_weights->dtype(),
-                                      last_topk_weights->place()));
+                                      curr_topk_weights->dtype(),
+                                      curr_topk_weights->place()));

   auto recv_x_scales = std::optional<deep_ep::detail::Tensor>();
   float* x_scales_ptr = nullptr;
@@ -728,15 +740,15 @@

   VLOG(3) << "DeepEP internode_dispatch num_local_experts " << num_local_experts
           << " num_scales " << num_scales << " hidden_size " << hidden_size
-          << " num_tokens " << num_tokens << " last_num_experts "
-          << last_num_experts << " num_recv_tokens " << num_recv_tokens;
+          << " num_tokens " << num_tokens << " curr_num_experts "
+          << curr_num_experts << " num_recv_tokens " << num_recv_tokens;

   ret = bkcl_normal_dispatch_standard(
       comm_ctx->GetBKCLComm(),
       x.data_ptr(),  // sendbuf
       x_scales_ptr,
-      last_topk_idx->data_ptr<int>(),
-      last_topk_weights->data_ptr<float>(),
+      curr_topk_idx->data_ptr<int>(),
+      curr_topk_weights->data_ptr<float>(),
       recv_x.data_ptr(),
       recv_x_scales_ptr,
       recv_topk_idx->data_ptr<int>(),
@@ -746,7 +758,7 @@
       hidden_size,
       num_tokens,
       num_topk,
-      last_num_experts,
+      curr_num_experts,
       ToBKCLDataType(x.dtype()),
       use_int8,
       async ? reinterpret_cast<XPUStream>(comm_stream)
@@ -789,10 +801,10 @@
           rdma_channel_prefix_matrix,
           gbl_channel_prefix_matrix,
           recv_rdma_channel_prefix_matrix,
-          recv_rdma_rank_prefix_sum,
+          recv_rdma_rank_prefix_sum.value(),  // topk_weights in cache mode
           recv_gbl_channel_prefix_matrix,
-          recv_gbl_rank_prefix_sum,
-          recv_src_meta,
+          recv_gbl_rank_prefix_sum.value(),  // topk_idx in cache mode
+          recv_src_meta,  // num_experts in cache mode
           send_rdma_head,
           send_nvl_head,
           event};
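Taken together, the slot comments describe a feedback loop: a normal-mode call returns `topk_idx` / `topk_weights` / `num_experts` through the prefix-sum and `recv_src_*` output slots, and a later cache-mode call receives them through the `cached_*` parameter slots. A sketch of that wiring for the intranode path, with assumed caller-side names (`IntranodeOutputs`, `MakeCachedArgs`) and a placeholder `Tensor`; the real call sites are not shown in this diff:

```cpp
#include <optional>

struct Tensor {};  // stand-in for deep_ep::detail::Tensor

// What a normal-mode intranode_dispatch hands back (per the comments above).
struct IntranodeOutputs {
  Tensor rank_prefix_matrix;     // carries topk_idx (cast to int32)
  Tensor channel_prefix_matrix;  // carries topk_weights
  int num_experts = 0;           // carried via the recv_src_idx slot
};

// What a later cache-mode call expects in its cached_* parameter slots.
struct IntranodeCachedArgs {
  int cached_num_recv_tokens = 0;                      // num_experts
  std::optional<Tensor> cached_rank_prefix_matrix;     // topk_idx
  std::optional<Tensor> cached_channel_prefix_matrix;  // topk_weights
};

IntranodeCachedArgs MakeCachedArgs(const IntranodeOutputs& prev) {
  return {prev.num_experts,
          prev.rank_prefix_matrix,
          prev.channel_prefix_matrix};
}
```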
paddle/fluid/distributed/collective/deep_ep_xpu/deep_ep.hpp (0 additions, 4 deletions)
@@ -105,10 +105,6 @@ struct Buffer {
   volatile int* moe_recv_rdma_counter = nullptr;
   int* moe_recv_rdma_counter_mapped = nullptr;

-  std::optional<deep_ep::detail::Tensor> last_topk_idx = std::nullopt;
-  std::optional<deep_ep::detail::Tensor> last_topk_weights = std::nullopt;
-  int last_num_experts = 0;
-
  public:
   Buffer(int rank,
          int num_ranks,