diff --git a/paddle/fluid/pir/serialize_deserialize/CMakeLists.txt b/paddle/fluid/pir/serialize_deserialize/CMakeLists.txt
index 0d3d180d97c823..432c8a700c642b 100644
--- a/paddle/fluid/pir/serialize_deserialize/CMakeLists.txt
+++ b/paddle/fluid/pir/serialize_deserialize/CMakeLists.txt
@@ -13,7 +13,7 @@ endif()
 file(GLOB_RECURSE YAML_PATCH_FILES "*.yaml")
 
 # change pir version when new patches are added
-add_definitions(-DDEVELOP_VERSION=4)
+add_definitions(-DDEVELOP_VERSION=0)
 add_definitions(-DRELEASE_VERSION=4)
 set(TEMPLATE_FILE ${CMAKE_CURRENT_SOURCE_DIR}/patch/template.h.in)
 set(PATCH_HEADER ${CMAKE_CURRENT_BINARY_DIR}/patch/patch.h)
diff --git a/paddle/fluid/pir/serialize_deserialize/patch/0.yaml b/paddle/fluid/pir/serialize_deserialize/patch/0.yaml
new file mode 100644
index 00000000000000..1c5d8ba0134148
--- /dev/null
+++ b/paddle/fluid/pir/serialize_deserialize/patch/0.yaml
@@ -0,0 +1,7 @@
+op_patches:
+  - op_name : pd_op.moe_permute
+    actions:
+      - action : add_attr
+        object : using_ue8m0_scale
+        type : pir::BoolAttribute
+        data : "false"
diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc
index dc5e905f499908..6d687b0e889618 100644
--- a/paddle/phi/infermeta/multiary.cc
+++ b/paddle/phi/infermeta/multiary.cc
@@ -6165,6 +6165,7 @@ void MoePermuteInferMeta(const MetaTensor& X,
                          const std::vector<int>& tokens_per_expert,
                          const int padding_alignment,
                          const bool do_gather,
+                         const bool using_ue8m0_scale,
                          MetaTensor* X_unzipped,
                          MetaTensor* zipped_expertwise_rowmap,
                          MetaTensor* token_prob_unzipped,
@@ -6188,10 +6189,18 @@ void MoePermuteInferMeta(const MetaTensor& X,
       common::errors::InvalidArgument(
           "Input expert_prob_topk's dtype should be FLOAT32"));
   if (XScale && do_gather) {
-    PADDLE_ENFORCE_EQ(XScale.dtype(),
-                      DataType::FLOAT32,
-                      common::errors::InvalidArgument(
-                          "Input XScale's dtype should be FLOAT32"));
+    if (using_ue8m0_scale) {
+      PADDLE_ENFORCE_EQ(XScale.dtype(),
+                        DataType::INT32,
+                        common::errors::InvalidArgument(
+                            "Input XScale's dtype should be INT32 if "
+                            "using_ue8m0_scale is True"));
+    } else {
+      PADDLE_ENFORCE_EQ(XScale.dtype(),
+                        DataType::FLOAT32,
+                        common::errors::InvalidArgument(
+                            "Input XScale's dtype should be FLOAT32"));
+    }
     const int64_t quanted_cols = XScale.dims()[1];
     XScale_unzipped->set_dims({-1, quanted_cols});
     XScale_unzipped->set_dtype(XScale.dtype());
diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h
index c288c330b437db..d26085b56d9707 100644
--- a/paddle/phi/infermeta/multiary.h
+++ b/paddle/phi/infermeta/multiary.h
@@ -569,6 +569,7 @@ PADDLE_API void MoePermuteInferMeta(const MetaTensor& X,
                                     const std::vector<int>& tokens_per_expert,
                                     const int padding_alignment,
                                     const bool do_gather,
+                                    const bool using_ue8m0_scale,
                                     MetaTensor* X_unzipped,
                                     MetaTensor* zipped_expertwise_rowmap,
                                     MetaTensor* token_prob_unzipped,
@@ -876,6 +877,7 @@ PADDLE_API void MoePermuteInferMeta(const MetaTensor& X,
                                     const std::vector<int>& tokens_per_expert,
                                     const int padding_alignment,
                                     const bool do_gather,
+                                    const bool using_ue8m0_scale,
                                     MetaTensor* X_unzipped,
                                     MetaTensor* zipped_expertwise_rowmap,
                                     MetaTensor* token_prob_unzipped,
diff --git a/paddle/phi/kernels/gpu/moe_permute_kernel.cu b/paddle/phi/kernels/gpu/moe_permute_kernel.cu
index d66820c20e9549..6f7f1909b974e4 100644
--- a/paddle/phi/kernels/gpu/moe_permute_kernel.cu
+++ b/paddle/phi/kernels/gpu/moe_permute_kernel.cu
@@ -46,18 +46,19 @@ struct expert_infos {
 template <typename X_T,
           typename routemap_T,
           typename probs_T,
+          typename scale_T,
           bool has_scale,
           bool do_gather>
 __global__ __launch_bounds__(512) void tokens_unzip_stable_kernel(
     const X_T *__restrict__ X,
     const routemap_T *__restrict__ routemap_topk,
     const probs_T *__restrict__ probs_topk,
-    const float *__restrict__ XScale,
+    const scale_T *__restrict__ XScale,
     const int *__restrict__ expert_base_offset,
     X_T *__restrict__ X_unzipped,
     int *__restrict__ zipped_expertwise_rowmap,
     probs_T *__restrict__ probs_unzipped,
-    float *__restrict__ XScale_unzipped,
+    scale_T *__restrict__ XScale_unzipped,
     int *global_expertwise_block_cumsum,
     const int total_zipped_tokens_num,
     const int token_length,
@@ -137,10 +138,11 @@ __global__ __launch_bounds__(512) void tokens_unzip_stable_kernel(
   if constexpr (do_gather) {  // vec copy
     if constexpr (has_scale) {
-      vectorized_memcpy(&XScale[(int64_t)row * (int64_t)scale_length],
-                        &XScale_unzipped[(int64_t)proposed_row_idx *
-                                         (int64_t)scale_length],
-                        scale_length);
+      // src or dst may not be 128-bit aligned
+      try_vectorized_memcpy(&XScale[(int64_t)row * (int64_t)scale_length],
+                            &XScale_unzipped[(int64_t)proposed_row_idx *
+                                             (int64_t)scale_length],
+                            scale_length);
     }
     vectorized_memcpy(
         &X[(int64_t)row * (int64_t)token_length],
@@ -167,7 +169,8 @@ void dispatch_tokens_unzip_stable(const Context &dev_ctx,
                                   const int topk,  // deprecated
                                   const int num_experts,
                                   const int scale_length,
-                                  const bool do_gather) {
+                                  const bool do_gather,
+                                  const bool using_ue8m0_scale) {
   dim3 grid, block;
   grid.x =
       (total_zipped_tokens_num + CUMSUM_BLOCK_SIZE - 1) / CUMSUM_BLOCK_SIZE;
@@ -175,34 +178,41 @@ void dispatch_tokens_unzip_stable(const Context &dev_ctx,
 
 #define DTYPE_CASE(dtype, type) dtype == phi::DataType::type
 #define GET_DATA(tensor, type) tensor.data<type>()
 #define GET_PTR_DATA(tensor, type) tensor->data<type>()
-#define DISPATCH_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE, DO_GATHER)          \
-  auto kernel = tokens_unzip_stable_kernel<TOKEN_T, INT_T, PROB_T,           \
-                                           HAS_SCALE, DO_GATHER>;            \
-  kernel<<<grid, block, 0, dev_ctx.stream()>>>(                              \
-      GET_DATA(X, TOKEN_T),                                                  \
-      GET_DATA(expert_routemap_topk, INT_T),                                 \
-      GET_DATA(expert_prob_topk, PROB_T),                                    \
-      XScale ? XScale.get_ptr()->data<float>() : nullptr,                    \
-      GET_DATA(expert_offsets, int),                                         \
-      GET_PTR_DATA(X_unzipped, TOKEN_T),                                     \
-      GET_PTR_DATA(zipped_expertwise_rowmap, INT_T),                         \
-      GET_PTR_DATA(token_prob_unzipped, PROB_T),                             \
-      XScale_unzipped->data<float>(),                                        \
-      global_expertwise_block_cumsum->data<int>(),                           \
-      total_zipped_tokens_num,                                               \
-      token_length,                                                          \
-      scale_length,                                                          \
-      num_experts,                                                           \
+#define DISPATCH_CASE(TOKEN_T, PROB_T, INT_T, SCALE_T, HAS_SCALE, DO_GATHER) \
+  auto kernel = tokens_unzip_stable_kernel<TOKEN_T, INT_T, PROB_T, SCALE_T,  \
+                                           HAS_SCALE, DO_GATHER>;            \
+  kernel<<<grid, block, 0, dev_ctx.stream()>>>(                              \
+      GET_DATA(X, TOKEN_T),                                                  \
+      GET_DATA(expert_routemap_topk, INT_T),                                 \
+      GET_DATA(expert_prob_topk, PROB_T),                                    \
+      XScale ? GET_PTR_DATA(XScale.get_ptr(), SCALE_T) : nullptr,            \
+      GET_DATA(expert_offsets, int),                                         \
+      GET_PTR_DATA(X_unzipped, TOKEN_T),                                     \
+      GET_PTR_DATA(zipped_expertwise_rowmap, INT_T),                         \
+      GET_PTR_DATA(token_prob_unzipped, PROB_T),                             \
+      GET_PTR_DATA(XScale_unzipped, SCALE_T),                                \
+      global_expertwise_block_cumsum->data<int>(),                           \
+      total_zipped_tokens_num,                                               \
+      token_length,                                                          \
+      scale_length,                                                          \
+      num_experts,                                                           \
       topk);
-#define HANDLE_GATHER_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE) \
-  if (do_gather) {                                            \
-    DISPATCH_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE, true)    \
-  } else {                                                    \
-    DISPATCH_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE, false)   \
+#define HANDLE_SCALE_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE, DO_GATHER)  \
+  if (using_ue8m0_scale) {                                               \
+    DISPATCH_CASE(TOKEN_T, PROB_T, INT_T, int32_t, HAS_SCALE, DO_GATHER) \
+  } else {                                                               \
+    DISPATCH_CASE(TOKEN_T, PROB_T, INT_T, float, HAS_SCALE, DO_GATHER)   \
+  }
+#define HANDLE_GATHER_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE)   \
+  if (do_gather) {                                              \
+    HANDLE_SCALE_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE, true)  \
+  } else {                                                      \
+    HANDLE_SCALE_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE, false) \
   }
 #define HANDLE_TOKEN_TYPE(PROB_T, INT_T) \
@@ -241,6 +251,7 @@ void MoePermuteKernel(const Context &dev_ctx,
                       const std::vector<int> &tokens_per_expert,
                       const int padding_multiplex,
                       const bool do_gather,
+                      const bool using_ue8m0_scale,
                       DenseTensor *X_unzipped,
                       DenseTensor *zipped_expertwise_rowmap,
                       DenseTensor *token_prob_unzipped,
@@ -317,14 +328,22 @@ void MoePermuteKernel(const Context &dev_ctx,
     }
   }
   dev_ctx.template Alloc<T>(X_unzipped);
-  dev_ctx.template Alloc<float>(XScale_unzipped);
   dev_ctx.template Alloc<int>(zipped_expertwise_rowmap);
   dev_ctx.template Alloc<float>(token_prob_unzipped);
   auto X_unzipped_ptr = reinterpret_cast<void *>(X_unzipped->data<T>());
   auto token_prob_unzipped_ptr =
       reinterpret_cast<void *>(token_prob_unzipped->data<float>());
-  auto XScale_unzipped_ptr =
-      reinterpret_cast<void *>(XScale_unzipped->data<float>());
+  void *XScale_unzipped_ptr = nullptr;
+  if (using_ue8m0_scale) {
+    // when using the ue8m0 scale, four ue8m0 scales are packed into one int32
+    dev_ctx.template Alloc<int32_t>(XScale_unzipped);
+    XScale_unzipped_ptr =
+        reinterpret_cast<void *>(XScale_unzipped->data<int32_t>());
+  } else {
+    dev_ctx.template Alloc<float>(XScale_unzipped);
+    XScale_unzipped_ptr =
+        reinterpret_cast<void *>(XScale_unzipped->data<float>());
+  }
 
   // -------- Memset all padding area to zero, with regard to do_gather
   auto memset_invalid_rows =
@@ -345,7 +364,9 @@ void MoePermuteKernel(const Context &dev_ctx,
   if (do_gather) {  // no gather, no memset
     memset_invalid_rows(X_unzipped_ptr, sizeof(T), cols);
     if (XScale) {
-      memset_invalid_rows(XScale_unzipped_ptr, sizeof(float), quanted_cols);
+      memset_invalid_rows(XScale_unzipped_ptr,
+                          using_ue8m0_scale ? sizeof(int32_t) : sizeof(float),
+                          quanted_cols);
     }
   }
   // Probs will be memset to zero whatsoever
@@ -377,7 +398,8 @@ void MoePermuteKernel(const Context &dev_ctx,
       static_cast<int>(topk),
       num_experts,
       static_cast<int>(quanted_cols),
-      do_gather);
+      do_gather,
+      using_ue8m0_scale);
 }
 #undef CUMSUM_BLOCK_SIZE
 #undef CUMSUM_INVALID_TAG
diff --git a/paddle/phi/kernels/gpu/moe_permute_utils.h b/paddle/phi/kernels/gpu/moe_permute_utils.h
index 9f1a41f7f83c4c..8e4219f31e704d 100644
--- a/paddle/phi/kernels/gpu/moe_permute_utils.h
+++ b/paddle/phi/kernels/gpu/moe_permute_utils.h
@@ -73,11 +73,20 @@ struct alignas(16) VectorType {
   uint8_t data[16];
 };
 
+template <typename T>
+__device__ __forceinline__ void unrolled_memcpy(const T* src,
+                                                T* dst,
+                                                const int num_elements) {
+#pragma unroll
+  for (int idx = threadIdx.x; idx < num_elements; idx += blockDim.x) {
+    dst[idx] = src[idx];
+  }
+}
 // Helper function to perform vectorized memory copy
 template <typename T>
 __device__ __forceinline__ void vectorized_memcpy(const T* src,
                                                   T* dst,
-                                                  int num_elements) {
+                                                  const int num_elements) {
   constexpr int vector_size_in_bytes = 16;
   const int elements_per_vector = vector_size_in_bytes / sizeof(T);
 
@@ -100,5 +109,17 @@ __device__ __forceinline__ void vectorized_memcpy(const T* src,
     }
   }
 }
+template <typename T>
+__device__ __forceinline__ void try_vectorized_memcpy(const T* src,
+                                                      T* dst,
+                                                      const int num_elements) {
+  bool is_aligned_128bit =
+      ((uintptr_t)src & 0xF) == 0 && ((uintptr_t)dst & 0xF) == 0;
+  if (is_aligned_128bit) {
+    vectorized_memcpy(src, dst, num_elements);
+  } else {
+    unrolled_memcpy(src, dst, num_elements);
+  }
+}
 
 }  // namespace phi
diff --git a/paddle/phi/kernels/xpu/moe_permute_kernel.cc b/paddle/phi/kernels/xpu/moe_permute_kernel.cc
index b2d6aa195817ee..dd98731ca5a0bc 100644
--- a/paddle/phi/kernels/xpu/moe_permute_kernel.cc
+++ b/paddle/phi/kernels/xpu/moe_permute_kernel.cc
@@ -132,6 +132,7 @@ void MoePermuteKernel(const Context &dev_ctx,
                       const std::vector<int> &tokens_per_expert,
                       const int padding_multiplex,
                       const bool do_gather,
+                      const bool using_ue8m0_scale,
                       DenseTensor *X_unzipped,
                       DenseTensor *zipped_expertwise_rowmap,
                       DenseTensor *token_prob_unzipped,
diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml
index 0200bc487a55b8..8c8595d0f4fb5d 100644
--- a/paddle/phi/ops/yaml/ops.yaml
+++ b/paddle/phi/ops/yaml/ops.yaml
@@ -3895,7 +3895,7 @@
   backward : moe_gate_dispatch_permute_grad
 
 - op : moe_permute
-  args : (Tensor hidden_states, Tensor scale, Tensor expert_routemap_topk, Tensor expert_prob_topk, int num_experts, int[] tokens_per_expert, int padding_alignment, bool do_gather)
+  args : (Tensor hidden_states, Tensor scale, Tensor expert_routemap_topk, Tensor expert_prob_topk, int num_experts, int[] tokens_per_expert, int padding_alignment, bool do_gather, bool using_ue8m0_scale = false)
   output : Tensor(hidden_states_unzipped), Tensor(zipped_expertwise_rowmap), Tensor(token_prob_unzipped), Tensor(scale_unzipped)
   infer_meta:
     func : MoePermuteInferMeta
diff --git a/python/paddle/nn/functional/moe_permute.py b/python/paddle/nn/functional/moe_permute.py
index 5809e8af90c046..aea04536f03a71 100644
--- a/python/paddle/nn/functional/moe_permute.py
+++ b/python/paddle/nn/functional/moe_permute.py
@@ -32,6 +32,7 @@ def moe_permute(
     tokens_per_expert: list,
     padding_alignment: int,
     do_gather: bool = True,
+    using_ue8m0_scale: bool = False,
    name: str | None = None,
 ) -> tuple[Tensor, Tensor, Tensor, Tensor]:
     r"""
@@ -46,6 +47,7 @@ def moe_permute(
     3. The padding_alignment parameter affects memory efficiency but not correctness.
     4. Any output tokens can find an exact-match in the original input tokens.
     5. This permute function has overcomed the aadiff issue, is deterministic.
+    6. If using_ue8m0_scale is True, the data type of scale must be int32, and each int32 packs four ue8m0 scaling factors.
 
     Args:
         hidden_states (Tensor): The input tensor containing tokens to be permuted, stored in row-major layout.
@@ -53,8 +55,8 @@ def moe_permute(
             Shape: [sequence_length, token_dimension]
         scale (Tensor|None): Scaling factors required when hidden_states is of float8 type.
             For float8 inputs, this tensor provides the scaling factors for dequantization.
-            Shape: [sequence_length, ceil(token_dimension / 128)]
-            Data type: float32
+            Shape: [sequence_length, ceil(token_dimension / 128)]; if using_ue8m0_scale is True, the shape is [sequence_length, ceil(ceil(token_dimension / 128) / 4)].
+            Data type: float32, or int32 when using_ue8m0_scale is True (each int32 packs four ue8m0 scaling factors).
         expert_routemap_topk (Tensor): Tensor indicating expert assignments for each token (top-k experts).
             Each value represents the expert index the token is assigned to (-1 indicates not assigned).
             Shape: [sequence_length, top_k_experts]
@@ -69,6 +71,7 @@ def moe_permute(
        padding_alignment (int): Tokens alignment requirement for expert buffers (in bytes).
             Must be a power of 2. Typical values are 16, 32 or 64 for optimal memory access.
         do_gather(bool): Decide whether do actual tokens gather operation or not, default is True.
+        using_ue8m0_scale (bool): Whether to use ue8m0 scaling for float8 inputs. Default is False.
         name (str|None, optional): Name prefix for the operation (optional).
             Default: None
 
@@ -84,8 +87,8 @@ def moe_permute(
             Shape: [total_tokens_after_broadcast, 1]
             Data type: float32
         - scale_unzipped (Tensor): Broadcasted scale tensor (only valid for float8 inputs).
-            Shape: [total_tokens_after_broadcast, ceil(token_dimension / 128)]
-            Data type: float32
+            Shape: [total_tokens_after_broadcast, scale.shape[-1]]
+            Data type: float32 or int32, matching the dtype of scale.
 
     Examples:
         .. code-block:: python
@@ -136,6 +139,7 @@ def moe_permute(
         tokens_per_expert,
         padding_alignment,
         do_gather,
+        using_ue8m0_scale,
     )
     return (
         hidden_states_unzipped,
diff --git a/test/legacy_test/test_moe_permute_unpermute.py b/test/legacy_test/test_moe_permute_unpermute.py
index 6e2378adc60805..f73acaead6b876 100644
--- a/test/legacy_test/test_moe_permute_unpermute.py
+++ b/test/legacy_test/test_moe_permute_unpermute.py
@@ -28,14 +28,23 @@ def fabricate_dispatch_result(
     num_experts,
     data_type="bfloat16",
     broadcast_ratio=0.5,
+    using_ue8m0_scale=False,
 ):
     """Helper function to generate test data."""
 
     hidden_states = paddle.randn([seqlen, token_length]).astype(data_type)
     scale = paddle.empty([0])
     if data_type == "float8_e4m3fn":
-        scale_cols = (token_length + 127) // 128
-        scale = paddle.randn([seqlen, scale_cols], dtype="float32")
+        if using_ue8m0_scale:
+            scale_cols = (token_length + 127) // 128
+            # when using_ue8m0_scale, four ue8m0 scales are packed into one int32
+            scale_cols = (scale_cols + 3) // 4
+            scale = paddle.randn([seqlen, scale_cols], dtype="float32").astype(
+                paddle.int32
+            )
+        else:
+            scale_cols = (token_length + 127) // 128
+            scale = paddle.randn([seqlen, scale_cols], dtype="float32")
 
     # Calculate expert counts with normal distribution
     expected_experts = max(1, min(broadcast_ratio * num_experts, topk))
@@ -201,6 +210,65 @@
                     err_msg="no_gather's unzipped_probs do not match",
                 )
 
+    def test_permute_unpermute_consistency_for_ue8m0_scale(self):
+        """Test that permuted ue8m0 scales match the original scales."""
+        DTYPES = ["float8_e4m3fn"]
+        EXPERT_NUMS = [4, 8, 16]
+        TOPKS = [4, 8, 16]
+        for dt, expert_num, topk in itertools.product(
+            DTYPES, EXPERT_NUMS, TOPKS
+        ):
+            with self.subTest(dtype=dt, expert_num=expert_num, topk=topk):
+                (
+                    hidden_states,
+                    scale,
+                    expert_routemap_topk,
+                    expert_prob_topk,
+                    tokens_per_expert,
+                ) = fabricate_dispatch_result(
+                    self.SEQLEN,
+                    self.TOKEN_LEN,
+                    topk,
+                    expert_num,
+                    data_type=dt,
+                    broadcast_ratio=0.5,
+                    using_ue8m0_scale=True,
+                )
+                if dt == "bfloat16":
+                    scale = None
+
+                # Permute step
+                (
+                    unzipped_tokens,
+                    zipped_expertwise_rowmap,
+                    unzipped_probs,
+                    unzipped_scales,
+                ) = moe_permute(
+                    hidden_states,
+                    scale,
+                    expert_routemap_topk,
+                    expert_prob_topk,
+                    num_experts=expert_num,
+                    tokens_per_expert=tokens_per_expert,
+                    padding_alignment=128,
+                    using_ue8m0_scale=True,
+                )
+                # check that the unzipped_scales rows are correct
+                zipped_expertwise_rowmap_np = zipped_expertwise_rowmap.numpy()
+                scale_np = scale.numpy()
+                unzipped_scales_np = unzipped_scales.numpy()
+                assert zipped_expertwise_rowmap_np.ndim == 2
+                for i in range(zipped_expertwise_rowmap_np.shape[0]):
+                    valid_indices = zipped_expertwise_rowmap_np[i][
+                        zipped_expertwise_rowmap_np[i] != -1
+                    ]
+                    for index in valid_indices:
+                        np.testing.assert_equal(
+                            scale_np[i],
+                            unzipped_scales_np[index],
+                            err_msg=f"unzipped_scales[{index}] does not match scale[{i}]",
+                        )
+
 
 if __name__ == "__main__":
     unittest.main()
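Note for reviewers: the sketch below is illustrative only and is not part of the patch. It shows, with plain numpy, how four one-byte ue8m0 exponents per 128-channel group could be packed into the int32 scale columns that moe_permute expects when using_ue8m0_scale=True, matching the documented shape [sequence_length, ceil(ceil(token_dimension / 128) / 4)]. The little-endian byte order and the 2**(e - 127) decoding follow the common ue8m0 convention and are assumptions here, not something this diff specifies.

```python
# Illustrative ue8m0 packing sketch (assumptions: little-endian byte order,
# scale decoded as 2**(e - 127)); not code from this PR.
import numpy as np

seqlen, token_dim = 4, 1100
scale_cols = (token_dim + 127) // 128  # ceil(token_dim / 128) == 9
packed_cols = (scale_cols + 3) // 4    # ceil(scale_cols / 4)  == 3

# One 8-bit biased exponent per 128-channel group (ue8m0: exponent only).
exponents = np.random.randint(112, 144, size=(seqlen, scale_cols), dtype=np.uint8)

# Pad the column count to a multiple of 4, then reinterpret each group of
# 4 bytes as one int32 -- the dtype required when using_ue8m0_scale=True.
padded = np.zeros((seqlen, packed_cols * 4), dtype=np.uint8)
padded[:, :scale_cols] = exponents
packed = padded.reshape(seqlen, packed_cols, 4).view(np.int32).squeeze(-1)
assert packed.shape == (seqlen, packed_cols)

# Round-trip: unpack the bytes and decode to float scaling factors.
unpacked = packed.view(np.uint8).reshape(seqlen, -1)[:, :scale_cols]
assert (unpacked == exponents).all()
scales = 2.0 ** (unpacked.astype(np.int32) - 127)
```

A tensor built from `packed` (e.g. via paddle.to_tensor) is what would be passed as scale together with using_ue8m0_scale=True. Since the kernel only moves the int32 columns verbatim (via try_vectorized_memcpy above) and never decodes them, the packing survives the permute and scale_unzipped comes back with the same packed int32 layout.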