[BugFix] Fix zero workspace returned by CUB size query under CUDA Graph in MoE dispatch #5087
+12
−2
Motivation
Under CUDA Graph capture, CUB’s workspace size query for radix sort in MoE dispatch may return 0 even when num_items > 0. This makes the allocated workspace size 0, and the subsequent self-check in run() fails with:
This PR adds a defensive guard so that the first (size-query) call never produces a 0-size workspace in this scenario, preventing the error.
Modifications
In custom_ops/gpu_ops/moe/moe_dispatch.cu, we call sorter_.getWorkspaceSize(moe_topk * num_rows) (line 77) before allocating the CUB radix sort workspace. We observed that, when CUDA Graph is enabled, temp_storage_bytes may remain 0 after this first size query, while a later size query (performed inside run() as a sanity check) may report 1. The allocated workspace is then smaller than the size run() expects, so the check fails.
This behavior is consistent with using CUB’s DeviceRadixSort::SortPairs size query with all data pointers set to nullptr. While this often works, it’s not guaranteed across CUDA/CUB versions or under graph capture; some paths may not write temp_storage_bytes in that case.
Neither enqueuing getWorkspaceSize onto the stream nor passing actual non-null device pointers resolves the issue. The evidence so far indicates that cub::DeviceRadixSort::SortPairs is not writing 0 to required_storage; rather, it is not updating required_storage at all, so the caller's initial value of 0 survives the call.
We therefore initialize required_storage to 1. When SortPairs exhibits the anomalous behavior (which, based on current evidence, only happens in the case where SortPairs should have returned 1), the code path can still proceed safely. When SortPairs behaves normally, the initial value does not matter, because SortPairs always overwrites required_storage with the correct size.
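The effect of the initial value can be illustrated with a small host-only sketch. This is not the PR's code and does not call CUB: `fake_size_query`, `get_workspace_size`, and `self_check_passes` are hypothetical names, the size formula is a placeholder, and the form of the self-check (queried size must not exceed the allocated size) is an assumption about what run() verifies.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical stand-in for a CUB-style size query. It is NOT the CUB API.
// `skip_write` models the anomalous path described above, in which the
// out-parameter is left untouched instead of being overwritten.
inline void fake_size_query(std::size_t& temp_storage_bytes,
                            std::size_t num_items,
                            bool skip_write) {
    assert(num_items > 0);  // the sketch only models the num_items > 0 case
    if (skip_write) {
        return;  // anomalous path: the caller's initial value survives
    }
    // Placeholder size formula, for illustration only.
    temp_storage_bytes = num_items * sizeof(int) * 2;
}

// Mirrors the shape of a getWorkspaceSize-style helper. `initial_value` is
// what required_storage holds before the query: 0 before the fix, 1 after.
inline std::size_t get_workspace_size(std::size_t num_items,
                                      bool skip_write,
                                      std::size_t initial_value) {
    std::size_t required_storage = initial_value;
    fake_size_query(required_storage, num_items, skip_write);
    return required_storage;
}

// Assumed form of the sanity check in run(): the re-queried size must fit
// inside the workspace that was allocated from the first query.
inline bool self_check_passes(std::size_t allocated_bytes,
                              std::size_t requeried_bytes) {
    return requeried_bytes <= allocated_bytes;
}
```

In this model, a skipped write with an initial value of 0 yields a 0-byte workspace, which fails the check against a later query of 1, whereas an initial value of 1 passes it. On the normal path both initial values give the same result, since the query overwrites the variable.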
Usage or Command
Pass
Accuracy Tests
The failure occurred intermittently; with this change it has been verified as fixed.
Checklist
PR title tags: [FDConfig], [APIServer], [Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]
Run pre-commit before commit.
If the PR targets the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.