diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py index fb03afefb70..b09cea5f1a4 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py @@ -185,8 +185,9 @@ def moe_topk_select( probs_for_choice.reshape([seq_length, n_group, -1]).topk(2, axis=-1)[0].sum(axis=-1) ) # [seq_len, n_group] group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1] # [seq_len, topk_group] - group_mask = paddle.zeros_like(group_scores).put_along_axis( - group_idx, paddle.to_tensor(1.0, dtype=group_scores.dtype), axis=-1 + group_mask = paddle.sum( + paddle.nn.functional.one_hot(group_idx, num_classes=n_group).cast(group_scores.dtype), + axis=1, # Sum over topk_group dimension -> [seq_len, n_group] ) score_mask = ( group_mask.unsqueeze(-1).expand([seq_length, n_group, n_experts // n_group]).reshape([seq_length, -1]) @@ -194,7 +195,7 @@ def moe_topk_select( probs_for_choice = probs_for_choice.masked_fill(~score_mask.astype(paddle.bool), float("-inf")) _, topk_ids = paddle.topk(probs_for_choice, top_k, axis=-1) - topk_weights = paddle.take_along_axis(gate_probs, topk_ids, axis=-1) + topk_weights = paddle.index_sample(gate_probs, topk_ids) # normalize combine weights if renormalize: