diff --git a/custom_ops/gpu_ops/moe/fused_moe_imp_op.h b/custom_ops/gpu_ops/moe/fused_moe_imp_op.h
index 254f80e670d..aec98d2592d 100644
--- a/custom_ops/gpu_ops/moe/fused_moe_imp_op.h
+++ b/custom_ops/gpu_ops/moe/fused_moe_imp_op.h
@@ -16,8 +16,8 @@
  */
 
 #pragma once
-#include
 #include
+#include
 #include "cub/cub.cuh"
 
 namespace phi {
@@ -45,7 +45,10 @@ class CubKeyValueSorter {
   size_t getWorkspaceSize(const size_t num_key_value_pairs,
                           bool descending = false) {
     num_key_value_pairs_ = num_key_value_pairs;
-    size_t required_storage = 0;
+    // Initialize to 1 as a workaround: under CUDA Graph capture, CUB may not
+    // write to required_storage, and 1 is the minimum expected size in that
+    // scenario.
+    size_t required_storage = 1;
     int* null_int = nullptr;
     if (descending) {
       cub::DeviceRadixSort::SortPairsDescending(NULL,
diff --git a/custom_ops/gpu_ops/moe/moe_dispatch.cu b/custom_ops/gpu_ops/moe/moe_dispatch.cu
index 6bdb4c73e8f..7ec887a3091 100644
--- a/custom_ops/gpu_ops/moe/moe_dispatch.cu
+++ b/custom_ops/gpu_ops/moe/moe_dispatch.cu
@@ -87,6 +87,13 @@ void MoeDispatchKernel(
   int8_t *sorter_ws_ptr = reinterpret_cast<int8_t *>(ws_ptr + bytes);
   int *permuted_experts_ =
       reinterpret_cast<int *>(sorter_ws_ptr + sorter_ws_size_bytes);
+  // If expected_ws_size > workspace_size ever occurs in sorter_.run (which
+  // should be practically impossible), there is a contiguous, currently unused
+  // region (permuted_experts_) right after sorter_ws_ptr. In practice, this
+  // region is larger than what cub::DeviceRadixSort::SortPairs requires.
+  // However, relying on this to "work" after removing the assertion is unsafe:
+  // it constitutes undefined behavior, and there is no guarantee it will remain
+  // correct across inputs, CUDA/CUB versions, or architectures.
   int *permuted_rows_ = permuted_experts_ + num_moe_inputs;
 
   int *topk_idx_ptr = topk_idx->data<int>();
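
For reference, the required_storage = 1 change relies on CUB's two-phase convention for cub::DeviceRadixSort::SortPairs: a first call with a null d_temp_storage pointer only writes the required byte count. Below is a minimal standalone sketch of that size-query pattern, assuming int keys and values as in the sorter above; the QuerySortPairsWorkspace helper name is hypothetical and not part of this patch.

    #include <cstddef>
    #include "cub/cub.cuh"

    // Hypothetical helper for illustration; mirrors the ascending-sort path
    // of CubKeyValueSorter::getWorkspaceSize in fused_moe_imp_op.h.
    inline size_t QuerySortPairsWorkspace(size_t num_key_value_pairs) {
      // Start at 1 rather than 0: per the patch comment, the size-query call
      // may leave this untouched under CUDA Graph capture, and 1 is the
      // smallest valid allocation size in that scenario.
      size_t required_storage = 1;
      int* null_int = nullptr;
      // Passing a null d_temp_storage makes CUB write the required byte
      // count into required_storage without performing any sorting.
      cub::DeviceRadixSort::SortPairs(nullptr,
                                      required_storage,
                                      null_int,  // d_keys_in
                                      null_int,  // d_keys_out
                                      null_int,  // d_values_in
                                      null_int,  // d_values_out
                                      static_cast<int>(num_key_value_pairs));
      return required_storage;
    }

The caller then allocates at least required_storage bytes and invokes SortPairs a second time with the real workspace pointer and data, which is the pattern sorter_.run in moe_dispatch.cu depends on.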