
Fix moe topk select bug in cudagraph #7069

Merged
zoooo0820 merged 1 commit into PaddlePaddle:develop from zhangbo9674:dev/fix_topk
Mar 30, 2026

Conversation

@zhangbo9674
Contributor

Motivation

Fix the MoE top-k selection bug under CUDA Graph.

💡 If this PR is a Cherry Pick, the PR title needs to follow the format by adding the [Cherry-Pick] label at the very beginning and appending the original PR ID at the end. For example, [Cherry-Pick][CI] Add check trigger and logic(#5191)


Modifications

Usage or Command

Accuracy Tests

Checklist

  • Add at least one tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code; run pre-commit before committing.
  • Add unit tests; if none are added, explain why in this PR.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

Copilot AI review requested due to automatic review settings March 30, 2026 02:44
@paddle-bot

paddle-bot bot commented Mar 30, 2026

Thanks for your contribution!

Contributor

Copilot AI left a comment


Pull request overview

Fixes a MoE top-k selection issue when running under CUDA Graph by adjusting how the group mask is built and how top-k weights are gathered.

Changes:

  • Replace put_along_axis-based group_mask construction with a one_hot + sum approach.
  • Replace take_along_axis with index_sample to gather topk_weights.
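The two replacements above can be modeled in plain Python (no Paddle; toy shapes and values, illustrative only). Because top-k returns distinct indices within each row, summing one-hot encodings over the topk_group dimension produces the same [seq_len, n_group] mask as scattering ones into a zero-initialized buffer:

```python
# Toy model of the group_mask construction (pure Python, illustrative only).
# group_idx[i] lists the topk_group group indices selected for token i.

def mask_via_one_hot_sum(group_idx, n_group):
    # New approach: one-hot each index, then sum over the topk_group dimension.
    # Note: equals the scatter version only when indices in a row are distinct,
    # which top-k guarantees.
    return [
        [sum(1 for g in row if g == j) for j in range(n_group)]
        for row in group_idx
    ]

def mask_via_scatter(group_idx, n_group):
    # Old approach: write 1s directly into a zero [seq_len, n_group] buffer.
    mask = [[0] * n_group for _ in group_idx]
    for i, row in enumerate(group_idx):
        for g in row:
            mask[i][g] = 1
    return mask

group_idx = [[0, 2], [1, 3], [3, 0]]  # seq_len=3, topk_group=2, n_group=4
assert mask_via_one_hot_sum(group_idx, 4) == mask_via_scatter(group_idx, 4)
```

The one_hot + sum form trades the in-place scatter for pure functional ops, which tends to be friendlier to CUDA Graph capture.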

Comment on lines +248 to 251
group_mask = paddle.sum(
    paddle.nn.functional.one_hot(group_idx, num_classes=n_group).cast(group_scores.dtype),
    axis=1,  # Sum over topk_group dimension -> [seq_len, n_group]
)

Copilot AI Mar 30, 2026


Building group_mask via one_hot(group_idx, num_classes=n_group) materializes a dense tensor of shape [seq_len, topk_group, n_group], which can significantly increase memory traffic (and potentially hurt latency) for large seq_len/n_group. Consider using a scatter/put-along-axis style update that writes ones directly into a [seq_len, n_group] buffer (e.g., initialize zeros then scatter indices with a ones tensor shaped like group_idx), which avoids the large intermediate.

Suggested change
group_mask = paddle.sum(
    paddle.nn.functional.one_hot(group_idx, num_classes=n_group).cast(group_scores.dtype),
    axis=1,  # Sum over topk_group dimension -> [seq_len, n_group]
)
# Build group_mask of shape [seq_len, n_group] without materializing a large one-hot tensor
group_mask = paddle.zeros([seq_length, n_group], dtype=group_scores.dtype)
updates = paddle.ones_like(group_idx, dtype=group_scores.dtype)
group_mask = paddle.put_along_axis(group_mask, group_idx, updates, axis=1, reduce="assign")


_, topk_ids = paddle.topk(probs_for_choice, top_k, axis=-1)
topk_weights = paddle.take_along_axis(gate_probs, topk_ids, axis=-1)
topk_weights = paddle.index_sample(gate_probs, topk_ids)

Copilot AI Mar 30, 2026


paddle.index_sample has a narrower contract than take_along_axis: it samples along a fixed dimension (commonly axis=1) and (in many Paddle versions) assumes x is 2-D [N, M] and index is [N, K]. If gate_probs can ever be higher-rank or if the intended gather axis is not the second dimension, this will return incorrect results or error at runtime. If you need general axis=-1 semantics, consider switching to an index construction using gather_nd (batch indices + topk_ids) or explicitly reshape/validate gate_probs to 2-D before calling index_sample.

Suggested change
topk_weights = paddle.index_sample(gate_probs, topk_ids)
# Use paddle.index_sample with its 2-D [N, M] / [N, K] contract by flattening
# all leading dimensions into a single batch dimension, and gather along the
# last axis. This is equivalent to take_along_axis(..., axis=-1) but robust
# to higher-rank gate_probs/topk_ids.
last_expert_dim = gate_probs.shape[-1]
flat_batch = paddle.numel(gate_probs) // last_expert_dim
gate_probs_2d = gate_probs.reshape([flat_batch, last_expert_dim])
topk_last_dim = topk_ids.shape[-1]
topk_ids_2d = topk_ids.reshape([flat_batch, topk_last_dim])
topk_weights_2d = paddle.index_sample(gate_probs_2d, topk_ids_2d)
topk_weights = topk_weights_2d.reshape(topk_ids.shape)
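The flatten-then-reshape wrapper preserves axis=-1 gather semantics; a small pure-Python model (no Paddle, toy data, illustrative only) shows that per-row gathering on the flattened 2-D view, which mirrors index_sample's [N, M] / [N, K] contract, matches a direct last-axis gather on a 3-D input:

```python
# Pure-Python model: gathering along the last axis of a [B, S, E] tensor
# equals flattening to [B*S, E], gathering row by row (the index_sample
# contract), and reshaping back. Illustrative only.

def take_along_last_axis(x3d, idx3d):
    # Direct axis=-1 gather, like paddle.take_along_axis(x, idx, axis=-1).
    return [
        [[row[j] for j in idx_row] for row, idx_row in zip(mat, idx_mat)]
        for mat, idx_mat in zip(x3d, idx3d)
    ]

def flatten_gather_reshape(x3d, idx3d):
    B, S = len(x3d), len(x3d[0])
    flat_x = [row for mat in x3d for row in mat]      # [B*S, E]
    flat_idx = [row for mat in idx3d for row in mat]  # [B*S, K]
    flat_out = [                                      # per-row 2-D gather
        [row[j] for j in idx_row]
        for row, idx_row in zip(flat_x, flat_idx)
    ]
    return [[flat_out[b * S + s] for s in range(S)] for b in range(B)]

x = [[[10, 11, 12], [20, 21, 22]], [[30, 31, 32], [40, 41, 42]]]
idx = [[[2, 0], [1, 1]], [[0, 2], [2, 1]]]
assert take_along_last_axis(x, idx) == flatten_gather_reshape(x, idx)
```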

Collaborator

@zoooo0820 zoooo0820 left a comment


LGTM

@codecov-commenter

Codecov Report

✅ All modified and coverable lines are covered by tests.
⚠️ Please upload report for BASE (develop@a7cbe3f). Learn more about missing BASE report.

Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #7069   +/-   ##
==========================================
  Coverage           ?   73.83%           
==========================================
  Files              ?      399           
  Lines              ?    56412           
  Branches           ?     8919           
==========================================
  Hits               ?    41651           
  Misses             ?    11809           
  Partials           ?     2952           
Flag   Coverage Δ
GPU    73.83% <100.00%> (?)


@zoooo0820 zoooo0820 merged commit 5c60e2f into PaddlePaddle:develop Mar 30, 2026
37 of 41 checks passed
mattheliu pushed a commit to mattheliu/FastDeploy that referenced this pull request Apr 1, 2026


4 participants