[Cherry-Pick] Fix moe topk select bug in cudagraph(#7069)#7070
[Cherry-Pick] Fix moe topk select bug in cudagraph(#7069)#7070zoooo0820 merged 1 commit intoPaddlePaddle:release/2.5from
Conversation
|
Thanks for your contribution! |
There was a problem hiding this comment.
Pull request overview
该 PR 从 develop 分支的 #7069 进行 cherry-pick,目标是修复 CUDA Graph(cudagraph)场景下 MoE 的 topk/group 选择逻辑在特定实现上的兼容性问题,避免路由选择在 cudagraph capture/replay 下出错。
Changes:
- 使用
one_hot + sum生成 group mask,替代put_along_axis路径。 - 使用
paddle.index_sample替代paddle.take_along_axis获取 topk 权重。
| group_mask = paddle.sum( | ||
| paddle.nn.functional.one_hot(group_idx, num_classes=n_group).cast(group_scores.dtype), | ||
| axis=1, # Sum over topk_group dimension -> [seq_len, n_group] |
There was a problem hiding this comment.
当前 PR 标题不符合仓库 Cherry-Pick 约定:需要在 [Cherry-Pick] 后包含至少一个标签(如 [BugFix]),并在标题末尾追加原 develop PR 号(此处应为 (#7069)),否则可能触发 CI 的 Cherry-Pick 校验失败。建议按模板格式调整标题。
| paddle.nn.functional.one_hot(group_idx, num_classes=n_group).cast(group_scores.dtype), | ||
| axis=1, # Sum over topk_group dimension -> [seq_len, n_group] | ||
| ) |
There was a problem hiding this comment.
PR 描述里 Modifications / Usage or Command / Accuracy Tests 等关键字段为空。由于该改动是修复 cudagraph 下的路由选择 bug,建议补充:复现方式/影响范围、修复原理,以及至少一个可运行的验证命令或准确性/回归结果,便于 release 分支风险评估。
| _, topk_ids = paddle.topk(probs_for_choice, top_k, axis=-1) | ||
| topk_weights = paddle.take_along_axis(gate_probs, topk_ids, axis=-1) | ||
| topk_weights = paddle.index_sample(gate_probs, topk_ids) |
There was a problem hiding this comment.
这里的修改是为了解决 cudagraph 场景下的 topk/group mask 选择问题,但当前单测(例如 tests/operators/test_noaux_tc_redundant.py)只覆盖了数值正确性,未覆盖 CUDA Graph capture/replay。建议新增或扩展单测:在 paddle.device.cuda.graphs.CUDAGraph 的 capture/replay 中运行 moe_topk_select(包含 n_group>1 && topk_group<n_group 分支),以避免该类回归再次出现。
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## release/2.5 #7070 +/- ##
==============================================
Coverage ? 68.97%
==============================================
Files ? 390
Lines ? 54086
Branches ? 8518
==============================================
Hits ? 37306
Misses ? 14079
Partials ? 2701
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Motivation
[Cherry-Pick] #7069
Modifications
Usage or Command
Accuracy Tests
Checklist
[FDConfig],[APIServer],[Engine],[Scheduler],[PD Disaggregation],[Executor],[Graph Optimization],[Speculative Decoding],[RL],[Models],[Quantization],[Loader],[OP],[KVCache],[DataProcessor],[BugFix],[Docs],[CI],[Optimization],[Feature],[Benchmark],[Others],[XPU],[HPU],[GCU],[DCU],[Iluvatar],[Metax]]pre-commitbefore commit.releasebranch, make sure the PR has been submitted to thedevelopbranch, then cherry-pick it to thereleasebranch with the[Cherry-Pick]PR tag.