Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions custom_ops/gpu_ops/moe/fused_moe_imp_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
*/

#pragma once
#include <string>
#include <sstream>
#include <string>
#include "cub/cub.cuh"

namespace phi {
Expand Down Expand Up @@ -45,7 +45,7 @@ class CubKeyValueSorter {
size_t getWorkspaceSize(const size_t num_key_value_pairs,
bool descending = false) {
num_key_value_pairs_ = num_key_value_pairs;
size_t required_storage = 0;
size_t required_storage = 1;
int* null_int = nullptr;
if (descending) {
cub::DeviceRadixSort::SortPairsDescending(NULL,
Expand Down
7 changes: 7 additions & 0 deletions custom_ops/gpu_ops/moe/moe_dispatch.cu
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@ void MoeDispatchKernel(
int8_t *sorter_ws_ptr = reinterpret_cast<int8_t *>(ws_ptr + bytes);
int *permuted_experts_ =
reinterpret_cast<int *>(sorter_ws_ptr + sorter_ws_size_bytes);
// If sorter_.run ever sees expected_ws_size > workspace_size (which should be
// practically impossible), the sort would overrun into the contiguous,
// currently unused region (permuted_experts_) that starts right after
// sorter_ws_ptr. In practice, that region is larger than what
// cub::DeviceRadixSort::SortPairs requires. However, relying on this to
// "work" after removing the assertion is unsafe: it is undefined behavior,
// and there is no guarantee it will remain correct across inputs, CUDA/CUB
// versions, or architectures.
int *permuted_rows_ = permuted_experts_ + num_moe_inputs;

int *topk_idx_ptr = topk_idx->data<int>();
Expand Down
Loading