Fix moe topk select bug in cudagraph #7069
Conversation

Thanks for your contribution!
Pull request overview
Fixes a MoE top-k selection issue when running under CUDA Graph by adjusting how the group mask is built and how top-k weights are gathered.
Changes:
- Replace the `put_along_axis`-based `group_mask` construction with a `one_hot` + `sum` approach.
- Replace `take_along_axis` with `index_sample` to gather `topk_weights`.
```python
group_mask = paddle.sum(
    paddle.nn.functional.one_hot(group_idx, num_classes=n_group).cast(group_scores.dtype),
    axis=1,  # Sum over topk_group dimension -> [seq_len, n_group]
)
```
Building group_mask via one_hot(group_idx, num_classes=n_group) materializes a dense tensor of shape [seq_len, topk_group, n_group], which can significantly increase memory traffic (and potentially hurt latency) for large seq_len/n_group. Consider using a scatter/put-along-axis style update that writes ones directly into a [seq_len, n_group] buffer (e.g., initialize zeros then scatter indices with a ones tensor shaped like group_idx), which avoids the large intermediate.
Suggested change:

```diff
-group_mask = paddle.sum(
-    paddle.nn.functional.one_hot(group_idx, num_classes=n_group).cast(group_scores.dtype),
-    axis=1,  # Sum over topk_group dimension -> [seq_len, n_group]
-)
+# Build group_mask of shape [seq_len, n_group] without materializing a large one-hot tensor
+group_mask = paddle.zeros([seq_length, n_group], dtype=group_scores.dtype)
+updates = paddle.ones_like(group_idx, dtype=group_scores.dtype)
+group_mask = paddle.put_along_axis(group_mask, group_idx, updates, axis=1, reduce="assign")
```
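The reviewer's point — that `one_hot` + `sum` and a scatter of ones build the same mask, with the scatter avoiding the `[seq_len, topk_group, n_group]` intermediate — can be sanity-checked in NumPy. This is an illustrative sketch, not the PR's Paddle code: the sizes are made up, and `np.put_along_axis` stands in for Paddle's `put_along_axis`. The equivalence holds because top-k indices are distinct within each row, so the one-hot sum never exceeds 1.

```python
import numpy as np

# Illustrative sizes; group_idx has shape [seq_len, topk_group] with
# distinct group ids per row, as top-k selection guarantees.
seq_len, topk_group, n_group = 4, 2, 8
rng = np.random.default_rng(0)
group_idx = np.stack(
    [rng.choice(n_group, size=topk_group, replace=False) for _ in range(seq_len)]
)

# one_hot + sum: materializes a [seq_len, topk_group, n_group] intermediate.
one_hot = np.eye(n_group)[group_idx]      # [seq_len, topk_group, n_group]
mask_via_one_hot = one_hot.sum(axis=1)    # [seq_len, n_group]

# Scatter-style: write ones directly into a [seq_len, n_group] buffer.
mask_via_scatter = np.zeros((seq_len, n_group))
np.put_along_axis(mask_via_scatter, group_idx, 1.0, axis=1)

assert np.array_equal(mask_via_one_hot, mask_via_scatter)
```

The trade-off the comment describes is purely about the intermediate: both paths produce the same `[seq_len, n_group]` mask, but the scatter path never allocates the rank-3 one-hot tensor.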
```diff
 _, topk_ids = paddle.topk(probs_for_choice, top_k, axis=-1)
-topk_weights = paddle.take_along_axis(gate_probs, topk_ids, axis=-1)
+topk_weights = paddle.index_sample(gate_probs, topk_ids)
```
paddle.index_sample has a narrower contract than take_along_axis: it samples along a fixed dimension (commonly axis=1) and (in many Paddle versions) assumes x is 2-D [N, M] and index is [N, K]. If gate_probs can ever be higher-rank or if the intended gather axis is not the second dimension, this will return incorrect results or error at runtime. If you need general axis=-1 semantics, consider switching to an index construction using gather_nd (batch indices + topk_ids) or explicitly reshape/validate gate_probs to 2-D before calling index_sample.
Suggested change:

```diff
-topk_weights = paddle.index_sample(gate_probs, topk_ids)
+# Use paddle.index_sample with its 2-D [N, M] / [N, K] contract by flattening
+# all leading dimensions into a single batch dimension, and gather along the
+# last axis. This is equivalent to take_along_axis(..., axis=-1) but robust
+# to higher-rank gate_probs/topk_ids.
+last_expert_dim = gate_probs.shape[-1]
+flat_batch = paddle.numel(gate_probs) // last_expert_dim
+gate_probs_2d = gate_probs.reshape([flat_batch, last_expert_dim])
+topk_last_dim = topk_ids.shape[-1]
+topk_ids_2d = topk_ids.reshape([flat_batch, topk_last_dim])
+topk_weights_2d = paddle.index_sample(gate_probs_2d, topk_ids_2d)
+topk_weights = topk_weights_2d.reshape(topk_ids.shape)
```
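The flatten-then-gather pattern in the suggestion can be checked against the `take_along_axis(..., axis=-1)` semantics it is meant to preserve. A NumPy sketch (shapes and names are illustrative, not the PR's actual tensors): collapsing all leading dimensions into one batch dimension and gathering per row yields the same result as gathering along the last axis directly, because row-major reshape keeps each `(batch, seq)` pair aligned between values and indices.

```python
import numpy as np

# Illustrative stand-ins for gate_probs [batch, seq, n_experts]
# and topk_ids [batch, seq, top_k].
rng = np.random.default_rng(0)
gate_probs = rng.random((2, 3, 8))
topk_ids = np.argsort(-gate_probs, axis=-1)[..., :4]  # top-4 expert indices

# Reference semantics: gather along the last axis directly.
expected = np.take_along_axis(gate_probs, topk_ids, axis=-1)

# Flatten-and-gather, mirroring the index_sample workaround: collapse all
# leading dims into one batch dim, gather row-wise, then restore the shape.
n_experts = gate_probs.shape[-1]
flat = gate_probs.reshape(-1, n_experts)              # [N, M]
ids_2d = topk_ids.reshape(-1, topk_ids.shape[-1])     # [N, K]
gathered = flat[np.arange(flat.shape[0])[:, None], ids_2d]
topk_weights = gathered.reshape(topk_ids.shape)

assert np.array_equal(topk_weights, expected)
```

The row-wise fancy-index line plays the role of `paddle.index_sample` on the flattened 2-D views; the final reshape restores the original `[batch, seq, top_k]` layout.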
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@ Coverage Diff @@
##           develop    #7069   +/-   ##
==========================================
  Coverage         ?   73.83%
==========================================
  Files            ?      399
  Lines            ?    56412
  Branches         ?     8919
==========================================
  Hits             ?    41651
  Misses           ?    11809
  Partials         ?     2952
```

Flags with carried forward coverage won't be shown.
Motivation
Fix moe topk select bug in cudagraph.
Modifications
Usage or Command
Accuracy Tests
Checklist
- Add at least one tag from: [FDConfig], [APIServer], [Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax].
- Run `pre-commit` before commit.
- If submitting to a release branch, make sure the PR has been submitted to the develop branch first, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.