17 changes: 13 additions & 4 deletions paddle/phi/infermeta/multiary.cc
@@ -6165,6 +6165,7 @@ void MoePermuteInferMeta(const MetaTensor& X,
const std::vector<int>& tokens_per_expert,
const int padding_alignment,
const bool do_gather,
const bool using_ue8m0_scale,
MetaTensor* X_unzipped,
MetaTensor* zipped_expertwise_rowmap,
MetaTensor* token_prob_unzipped,
@@ -6188,10 +6189,18 @@ void MoePermuteInferMeta(const MetaTensor& X,
common::errors::InvalidArgument(
"Input expert_prob_topk's dtype should be FLOAT32"));
if (XScale && do_gather) {
PADDLE_ENFORCE_EQ(XScale.dtype(),
DataType::FLOAT32,
common::errors::InvalidArgument(
"Input XScale's dtype should be FLOAT32"));
if (using_ue8m0_scale) {
PADDLE_ENFORCE_EQ(XScale.dtype(),
DataType::INT32,
common::errors::InvalidArgument(
"Input XScale's dtype should be INT32 if "
"using_ue8m0_scale is True"));
} else {
PADDLE_ENFORCE_EQ(XScale.dtype(),
DataType::FLOAT32,
common::errors::InvalidArgument(
"Input XScale's dtype should be FLOAT32"));
}
const int64_t quanted_cols = XScale.dims()[1];
XScale_unzipped->set_dims({-1, quanted_cols});
XScale_unzipped->set_dtype(XScale.dtype());
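Note: the INT32 check above reflects the packing convention for ue8m0 scales: each int32 word carries four 8-bit, exponent-only scale factors. A minimal C++ sketch of that convention, assuming the OCP E8M0 encoding (a byte E decodes to 2^(E - 127)); the helper names are illustrative and not part of this PR:

#include <cmath>
#include <cstdint>

// Pack four ue8m0 exponent bytes into one int32 word (e0 in the low byte).
inline int32_t pack_ue8m0x4(uint8_t e0, uint8_t e1, uint8_t e2, uint8_t e3) {
  return static_cast<int32_t>(e0) | (static_cast<int32_t>(e1) << 8) |
         (static_cast<int32_t>(e2) << 16) | (static_cast<int32_t>(e3) << 24);
}

// Decode the i-th factor (i in [0, 3]) from a packed word, assuming the
// OCP E8M0 convention: byte E represents the power of two 2^(E - 127).
inline float unpack_ue8m0(int32_t packed, int i) {
  const uint8_t e = static_cast<uint8_t>((packed >> (8 * i)) & 0xFF);
  return std::ldexp(1.0f, static_cast<int>(e) - 127);
}

The permute kernel itself only moves the packed words and never decodes them, which is why it can treat a scale row as opaque int32 data.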
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/multiary.h
@@ -569,6 +569,7 @@ PADDLE_API void MoePermuteInferMeta(const MetaTensor& X,
const std::vector<int>& tokens_per_expert,
const int padding_alignment,
const bool do_gather,
const bool using_ue8m0_scale,
MetaTensor* X_unzipped,
MetaTensor* zipped_expertwise_rowmap,
MetaTensor* token_prob_unzipped,
@@ -876,6 +877,7 @@ PADDLE_API void MoePermuteInferMeta(const MetaTensor& X,
const std::vector<int>& tokens_per_expert,
const int padding_alignment,
const bool do_gather,
const bool using_ue8m0_scale,
MetaTensor* X_unzipped,
MetaTensor* zipped_expertwise_rowmap,
MetaTensor* token_prob_unzipped,
98 changes: 60 additions & 38 deletions paddle/phi/kernels/gpu/moe_permute_kernel.cu
@@ -46,18 +46,19 @@ struct expert_infos {
template <typename X_T,
typename routemap_T,
typename probs_T,
typename scale_T,
bool has_scale,
bool do_gather>
__global__ __launch_bounds__(512) void tokens_unzip_stable_kernel(
const X_T *__restrict__ X,
const routemap_T *__restrict__ routemap_topk,
const probs_T *__restrict__ probs_topk,
const float *__restrict__ XScale,
const scale_T *__restrict__ XScale,
const int *__restrict__ expert_base_offset,
X_T *__restrict__ X_unzipped,
int *__restrict__ zipped_expertwise_rowmap,
probs_T *__restrict__ probs_unzipped,
float *__restrict__ XScale_unzipped,
scale_T *__restrict__ XScale_unzipped,
int *global_expertwise_block_cumsum,
const int total_zipped_tokens_num,
const int token_length,
@@ -137,10 +138,11 @@ __global__ __launch_bounds__(512) void tokens_unzip_stable_kernel(
if constexpr (do_gather) {
// vec copy
if constexpr (has_scale) {
vectorized_memcpy(&XScale[(int64_t)row * (int64_t)scale_length],
&XScale_unzipped[(int64_t)proposed_row_idx *
(int64_t)scale_length],
scale_length);
// src or dst may not be 128-bit aligned
try_vectorized_memcpy(&XScale[(int64_t)row * (int64_t)scale_length],
&XScale_unzipped[(int64_t)proposed_row_idx *
(int64_t)scale_length],
scale_length);
}
vectorized_memcpy(
&X[(int64_t)row * (int64_t)token_length],
@@ -167,42 +169,50 @@ void dispatch_tokens_unzip_stable(const Context &dev_ctx,
const int topk, // deprecated
const int num_experts,
const int scale_length,
const bool do_gather) {
const bool do_gather,
const bool using_ue8m0_scale) {
dim3 grid, block;
grid.x =
(total_zipped_tokens_num + CUMSUM_BLOCK_SIZE - 1) / CUMSUM_BLOCK_SIZE;
block.x = 512;
#define DTYPE_CASE(dtype, type) dtype == phi::DataType::type
#define GET_DATA(tensor, type) tensor.data<type>()
#define GET_PTR_DATA(tensor, type) tensor->data<type>()
#define DISPATCH_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE, DO_GATHER) \
auto kernel = tokens_unzip_stable_kernel<TOKEN_T, \
INT_T, \
PROB_T, \
HAS_SCALE, \
DO_GATHER>; \
kernel<<<grid, block, 0, dev_ctx.stream()>>>( \
GET_DATA(X, TOKEN_T), \
GET_DATA(expert_routemap_topk, INT_T), \
GET_DATA(expert_prob_topk, PROB_T), \
XScale ? XScale.get_ptr()->data<float>() : nullptr, \
GET_DATA(expert_offsets, int), \
GET_PTR_DATA(X_unzipped, TOKEN_T), \
GET_PTR_DATA(zipped_expertwise_rowmap, INT_T), \
GET_PTR_DATA(token_prob_unzipped, PROB_T), \
XScale_unzipped->data<float>(), \
global_expertwise_block_cumsum->data<int>(), \
total_zipped_tokens_num, \
token_length, \
scale_length, \
num_experts, \
#define DISPATCH_CASE(TOKEN_T, PROB_T, INT_T, SCALE_T, HAS_SCALE, DO_GATHER) \
auto kernel = tokens_unzip_stable_kernel<TOKEN_T, \
INT_T, \
PROB_T, \
SCALE_T, \
HAS_SCALE, \
DO_GATHER>; \
kernel<<<grid, block, 0, dev_ctx.stream()>>>( \
GET_DATA(X, TOKEN_T), \
GET_DATA(expert_routemap_topk, INT_T), \
GET_DATA(expert_prob_topk, PROB_T), \
XScale ? GET_PTR_DATA(XScale.get_ptr(), SCALE_T) : nullptr, \
GET_DATA(expert_offsets, int), \
GET_PTR_DATA(X_unzipped, TOKEN_T), \
GET_PTR_DATA(zipped_expertwise_rowmap, INT_T), \
GET_PTR_DATA(token_prob_unzipped, PROB_T), \
GET_PTR_DATA(XScale_unzipped, SCALE_T), \
global_expertwise_block_cumsum->data<int>(), \
total_zipped_tokens_num, \
token_length, \
scale_length, \
num_experts, \
topk);

#define HANDLE_GATHER_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE) \
if (do_gather) { \
DISPATCH_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE, true) \
} else { \
DISPATCH_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE, false) \
#define HANDLE_SCALE_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE, DO_GATHER) \
if (using_ue8m0_scale) { \
DISPATCH_CASE(TOKEN_T, PROB_T, INT_T, int32_t, HAS_SCALE, DO_GATHER) \
} else { \
DISPATCH_CASE(TOKEN_T, PROB_T, INT_T, float, HAS_SCALE, DO_GATHER) \
}
#define HANDLE_GATHER_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE) \
if (do_gather) { \
HANDLE_SCALE_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE, true) \
} else { \
HANDLE_SCALE_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE, false) \
}

#define HANDLE_TOKEN_TYPE(PROB_T, INT_T) \
@@ -241,6 +251,7 @@ void MoePermuteKernel(const Context &dev_ctx,
const std::vector<int> &tokens_per_expert,
const int padding_multiplex,
const bool do_gather,
const bool using_ue8m0_scale,
DenseTensor *X_unzipped,
DenseTensor *zipped_expertwise_rowmap,
DenseTensor *token_prob_unzipped,
@@ -317,14 +328,22 @@
}
}
dev_ctx.template Alloc<T>(X_unzipped);
dev_ctx.template Alloc<float>(XScale_unzipped);
dev_ctx.template Alloc<int>(zipped_expertwise_rowmap);
dev_ctx.template Alloc<float>(token_prob_unzipped);
auto X_unzipped_ptr = reinterpret_cast<void *>(X_unzipped->data<T>());
auto token_prob_unzipped_ptr =
reinterpret_cast<void *>(token_prob_unzipped->data<float>());
auto XScale_unzipped_ptr =
reinterpret_cast<void *>(XScale_unzipped->data<float>());
void *XScale_unzipped_ptr = nullptr;
if (using_ue8m0_scale) {
// if using the ue8m0 scale, four ue8m0 scales are packed into one int32
dev_ctx.template Alloc<int32_t>(XScale_unzipped);
XScale_unzipped_ptr =
reinterpret_cast<void *>(XScale_unzipped->data<int32_t>());
} else {
dev_ctx.template Alloc<float>(XScale_unzipped);
XScale_unzipped_ptr =
reinterpret_cast<void *>(XScale_unzipped->data<float>());
}

// -------- Memset all padding area to zero, with regard to do_gather
auto memset_invalid_rows =
@@ -345,7 +364,9 @@ if (do_gather) { // no gather, no memset
if (do_gather) { // no gather, no memset
memset_invalid_rows(X_unzipped_ptr, sizeof(T), cols);
if (XScale) {
memset_invalid_rows(XScale_unzipped_ptr, sizeof(float), quanted_cols);
memset_invalid_rows(XScale_unzipped_ptr,
using_ue8m0_scale ? sizeof(int32_t) : sizeof(float),
quanted_cols);
}
}
// Probs will be memset to zero whatsoever
@@ -377,7 +398,8 @@
static_cast<int>(topk),
num_experts,
static_cast<int>(quanted_cols),
do_gather);
do_gather,
using_ue8m0_scale);
}
#undef CUMSUM_BLOCK_SIZE
#undef CUMSUM_INVALID_TAG
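Note: the dispatch macros above peel each runtime flag into a compile-time template parameter, so one fully specialized kernel is instantiated per combination of scale type, has_scale, and do_gather. A macro-free C++ sketch of the same pattern (argument lists elided; the function names are illustrative, not part of this PR):

template <typename ScaleT, bool HasScale, bool DoGather>
void launch_unzip_kernel() {
  // would instantiate tokens_unzip_stable_kernel<..., ScaleT, HasScale, DoGather>
}

template <typename ScaleT, bool HasScale>
void dispatch_gather(bool do_gather) {
  if (do_gather) {
    launch_unzip_kernel<ScaleT, HasScale, true>();
  } else {
    launch_unzip_kernel<ScaleT, HasScale, false>();
  }
}

template <bool HasScale>
void dispatch_scale(bool using_ue8m0_scale, bool do_gather) {
  if (using_ue8m0_scale) {
    dispatch_gather<int32_t, HasScale>(do_gather);  // packed ue8m0 words
  } else {
    dispatch_gather<float, HasScale>(do_gather);    // plain float32 scales
  }
}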
23 changes: 22 additions & 1 deletion paddle/phi/kernels/gpu/moe_permute_utils.h
@@ -73,11 +73,20 @@ struct alignas(16) VectorType<uint8_t, 16> {
uint8_t data[16];
};

template <typename T>
__device__ __forceinline__ void unrolled_memcpy(const T* src,
T* dst,
const int num_elements) {
#pragma unroll
for (int idx = threadIdx.x; idx < num_elements; idx += blockDim.x) {
dst[idx] = src[idx];
}
}
// Helper function to perform vectorized memory copy
template <typename T>
__device__ __forceinline__ void vectorized_memcpy(const T* src,
T* dst,
int num_elements) {
const int num_elements) {
constexpr int vector_size_in_bytes = 16;
const int elements_per_vector = vector_size_in_bytes / sizeof(T);

@@ -100,5 +109,17 @@ __device__ __forceinline__ void vectorized_memcpy(const T* src,
}
}
}
template <typename T>
__device__ __forceinline__ void try_vectorized_memcpy(const T* src,
T* dst,
const int num_elements) {
bool is_aligned_128bit =
((uintptr_t)src & 0xF) == 0 && ((uintptr_t)dst & 0xF) == 0;
if (is_aligned_128bit) {
vectorized_memcpy(src, dst, num_elements);
} else {
unrolled_memcpy(src, dst, num_elements);
}
}

} // namespace phi
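Note: vectorized_memcpy moves data with 128-bit vector types, which require 16-byte-aligned addresses on CUDA; try_vectorized_memcpy guards that precondition and falls back to the scalar unrolled_memcpy when either pointer is misaligned. A small host-side illustration of the alignment predicate (the buffer and the scale_length value are hypothetical):

#include <cstdint>
#include <cstdio>

int main() {
  alignas(16) int32_t buf[8];
  const int scale_length = 3;  // hypothetical packed-row width in int32 words
  for (int row = 0; row < 2; ++row) {
    const int32_t* p = buf + row * scale_length;
    // Same predicate as try_vectorized_memcpy: low four address bits clear.
    const bool aligned = (reinterpret_cast<uintptr_t>(p) & 0xF) == 0;
    std::printf("row %d -> %s copy\n", row, aligned ? "vectorized" : "unrolled");
  }
  return 0;
}

Here row 0 starts on a 16-byte boundary and row 1 starts 12 bytes in, so the two rows take the vectorized and unrolled paths respectively.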
1 change: 1 addition & 0 deletions paddle/phi/kernels/xpu/moe_permute_kernel.cc
@@ -132,6 +132,7 @@ void MoePermuteKernel(const Context &dev_ctx,
const std::vector<int> &tokens_per_expert,
const int padding_multiplex,
const bool do_gather,
const bool using_ue8m0_scale,
DenseTensor *X_unzipped,
DenseTensor *zipped_expertwise_rowmap,
DenseTensor *token_prob_unzipped,
2 changes: 1 addition & 1 deletion paddle/phi/ops/yaml/ops.yaml
@@ -3895,7 +3895,7 @@
backward : moe_gate_dispatch_permute_grad

- op : moe_permute
args : (Tensor hidden_states, Tensor scale, Tensor expert_routemap_topk, Tensor expert_prob_topk, int num_experts, int[] tokens_per_expert, int padding_alignment, bool do_gather)
args : (Tensor hidden_states, Tensor scale, Tensor expert_routemap_topk, Tensor expert_prob_topk, int num_experts, int[] tokens_per_expert, int padding_alignment, bool do_gather, bool using_ue8m0_scale = false)
output : Tensor(hidden_states_unzipped), Tensor(zipped_expertwise_rowmap), Tensor(token_prob_unzipped), Tensor(scale_unzipped)
infer_meta:
func : MoePermuteInferMeta
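Note: the new attribute defaults to false in the op signature, so existing callers of moe_permute that do not pass using_ue8m0_scale keep the previous float32-scale behavior.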
12 changes: 8 additions & 4 deletions python/paddle/nn/functional/moe_permute.py
@@ -32,6 +32,7 @@ def moe_permute(
tokens_per_expert: list,
padding_alignment: int,
do_gather: bool = True,
using_ue8m0_scale: bool = False,
name: str | None = None,
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
r"""
@@ -46,15 +47,16 @@
3. The padding_alignment parameter affects memory efficiency but not correctness.
4. Any output token can find an exact match in the original input tokens.
5. This permute function has overcome the aadiff issue and is deterministic.
6. If using_ue8m0_scale is True, the data type of scale must be int32, and each int32 is packed from 4 ue8m0 scaling factors.

Args:
hidden_states (Tensor): The input tensor containing tokens to be permuted, stored in row-major layout.
Supported data types: bfloat16 or float8_e4m3fn.
Shape: [sequence_length, token_dimension]
scale (Tensor|None): Scaling factors required when hidden_states is of float8 type.
For float8 inputs, this tensor provides the scaling factors for dequantization.
Shape: [sequence_length, ceil(token_dimension / 128)]
Data type: float32
Shape: [sequence_length, ceil(token_dimension / 128)]. If using_ue8m0_scale is True, the shape is [sequence_length, ceil(ceil(token_dimension / 128) / 4)].
Data type: float32, or int32 when using_ue8m0_scale is True (each int32 packs four ue8m0 scaling factors).
expert_routemap_topk (Tensor): Tensor indicating expert assignments for each token (top-k experts).
Each value represents the expert index the token is assigned to (-1 indicates not assigned).
Shape: [sequence_length, top_k_experts]
@@ -69,6 +71,7 @@
padding_alignment (int): Tokens alignment requirement for expert buffers (in bytes).
Must be a power of 2. Typical values are 16, 32 or 64 for optimal memory access.
do_gather (bool): Whether to perform the actual token gather operation. Default: True.
using_ue8m0_scale (bool): Whether to use ue8m0 scaling for float8 inputs. Default: False.
name (str|None, optional): Name prefix for the operation (optional).
Default: None

@@ -84,8 +87,8 @@
Shape: [total_tokens_after_broadcast, 1]
Data type: float32
- scale_unzipped (Tensor): Broadcasted scale tensor (only valid for float8 inputs).
Shape: [total_tokens_after_broadcast, ceil(token_dimension / 128)]
Data type: float32
Shape: [total_tokens_after_broadcast, scale.shape[-1]]
Data type: float32 or int32, matching the dtype of scale.

Examples:
.. code-block:: python
@@ -136,6 +139,7 @@
tokens_per_expert,
padding_alignment,
do_gather,
using_ue8m0_scale,
)
return (
hidden_states_unzipped,