From 90e3933022d96d428ff016a4a36688f7434a7062 Mon Sep 17 00:00:00 2001 From: Bingoo <1575938147@qq.com> Date: Thu, 19 Mar 2026 14:09:28 +0800 Subject: [PATCH 01/11] conflict --- custom_ops/gpu_ops/cpp_extensions.cc | 7 + custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu | 166 +++++++++++++ custom_ops/setup_ops.py | 2 + .../layers/moe/fused_cast_sigmoid_bias.py | 44 ++++ .../layers/moe/fused_moe_cutlass_backend.py | 3 +- fastdeploy/model_executor/layers/moe/moe.py | 18 +- tests/layers/test_fused_cast_sigmoid_bias.py | 224 ++++++++++++++++++ 7 files changed, 461 insertions(+), 3 deletions(-) create mode 100644 custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu create mode 100644 fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py create mode 100644 tests/layers/test_fused_cast_sigmoid_bias.py diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 0f8343bd24b..0d01fab6b4e 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -691,6 +691,9 @@ std::vector NoauxTc(paddle::Tensor& scores, bool renormalize, float routed_scaling_factor); +std::vector FusedCastSigmoidBias(const paddle::Tensor& input, + const paddle::Tensor& bias); + std::vector NoauxTcRedundant( paddle::Tensor& scores, paddle::Tensor& scores_with_bias, @@ -1667,6 +1670,10 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("noaux_tc", &NoauxTc, "noaux_tc for Deepseekv3 MoE compute"); + m.def("fused_cast_sigmoid_bias", + &FusedCastSigmoidBias, + "Fused cast+sigmoid+bias for MoE gating scores"); + m.def("noaux_tc_redundant", &NoauxTcRedundant, "noaux_tc_redundant for MoE compute"); diff --git a/custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu b/custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu new file mode 100644 index 00000000000..392788582ff --- /dev/null +++ b/custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu @@ -0,0 +1,166 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "helper.h" + +// Fused kernel: cast(input, fp32) -> sigmoid -> scores, scores + bias -> +// scores_with_bias +// +// For each element (token i, expert j): +// scores[i][j] = sigmoid(float(input[i][j])) +// scores_with_bias[i][j] = scores[i][j] + bias[j] +// +// Input: input [num_tokens, num_experts] bf16/fp16/fp32 +// bias [num_experts] or [1, num_experts] fp32 +// Output: scores [num_tokens, num_experts] fp32 +// scores_with_bias [num_tokens, num_experts] fp32 + +template +__global__ void fused_cast_sigmoid_bias_kernel( + const InT* __restrict__ input, + const float* __restrict__ bias, + float* __restrict__ scores, + float* __restrict__ scores_with_bias, + const int num_experts) { + const int64_t token_idx = blockIdx.x; + const int64_t offset = token_idx * num_experts; + + for (int j = threadIdx.x; j < num_experts; j += blockDim.x) { + float val = static_cast(input[offset + j]); + // sigmoid: 1 / (1 + exp(-x)) + float s = 1.0f / (1.0f + expf(-val)); + scores[offset + j] = s; + scores_with_bias[offset + j] = s + bias[j]; + } +} + +// Vectorized version for better memory throughput +template +__global__ void fused_cast_sigmoid_bias_vec_kernel( + const InT* __restrict__ input, + const float* __restrict__ bias, + float* __restrict__ scores, + float* __restrict__ scores_with_bias, + const int num_experts) { + const int64_t token_idx = blockIdx.x; + const int64_t offset = token_idx * num_experts; + + using in_vec_t = AlignedVector; + using out_vec_t = AlignedVector; + + const int vec_count = num_experts / kVecSize; + for (int idx = threadIdx.x; idx < vec_count; idx += blockDim.x) { + const int base = idx * kVecSize; + in_vec_t in_vec; + out_vec_t bias_vec; + Load(input + offset + base, &in_vec); + Load(bias + base, &bias_vec); + + out_vec_t s_vec, sb_vec; +#pragma unroll + for (int i = 0; i < kVecSize; ++i) { + float val = static_cast(in_vec[i]); + float s = 1.0f / (1.0f + expf(-val)); + s_vec[i] = s; + sb_vec[i] = s + bias_vec[i]; + } + + Store(s_vec, scores + offset + base); + Store(sb_vec, scores_with_bias + offset + base); + } + + // Handle remaining elements + const int remaining_start = vec_count * kVecSize; + for (int j = remaining_start + threadIdx.x; j < num_experts; + j += blockDim.x) { + float val = static_cast(input[offset + j]); + float s = 1.0f / (1.0f + expf(-val)); + scores[offset + j] = s; + scores_with_bias[offset + j] = s + bias[j]; + } +} + +std::vector FusedCastSigmoidBias(const paddle::Tensor& input, + const paddle::Tensor& bias) { + auto input_shape = input.shape(); + PD_CHECK(input_shape.size() == 2, + "input must be 2D [num_tokens, num_experts]"); + auto bias_shape = bias.shape(); + // Support both [num_experts] and [1, num_experts] bias shapes + PD_CHECK( + bias_shape.size() == 1 || (bias_shape.size() == 2 && bias_shape[0] == 1), + "bias must be 1D [num_experts] or 2D [1, num_experts]"); + + int64_t num_tokens = input_shape[0]; + int64_t num_experts = input_shape[1]; + int64_t bias_numel = (bias_shape.size() == 1) ? bias_shape[0] : bias_shape[1]; + PD_CHECK(bias_numel == num_experts, "bias size must match num_experts"); + + auto place = input.place(); + auto stream = input.stream(); + + auto scores = paddle::empty( + {num_tokens, num_experts}, paddle::DataType::FLOAT32, place); + auto scores_with_bias = paddle::empty( + {num_tokens, num_experts}, paddle::DataType::FLOAT32, place); + + if (num_tokens == 0) { + return {scores, scores_with_bias}; + } + + dim3 grid(num_tokens); + int block_size = std::min(static_cast(1024), num_experts); + // Round up to warp size + block_size = ((block_size + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + dim3 block(block_size); + + DISPATCH_FLOAT_FP6_DTYPE(input.dtype(), scalar_t, { + constexpr int kVecSize = 16 / sizeof(scalar_t); + if (num_experts % kVecSize == 0 && num_experts >= kVecSize) { + fused_cast_sigmoid_bias_vec_kernel + <<>>(input.data(), + bias.data(), + scores.data(), + scores_with_bias.data(), + num_experts); + } else { + fused_cast_sigmoid_bias_kernel + <<>>(input.data(), + bias.data(), + scores.data(), + scores_with_bias.data(), + num_experts); + } + }); + + return {scores, scores_with_bias}; +} + +std::vector FusedCastSigmoidBiasInferDtype( + const paddle::DataType& input_dtype, const paddle::DataType& bias_dtype) { + return {paddle::DataType::FLOAT32, paddle::DataType::FLOAT32}; +} + +std::vector> FusedCastSigmoidBiasInferShape( + const std::vector& input_shape, + const std::vector& bias_shape) { + return {input_shape, input_shape}; +} + +PD_BUILD_STATIC_OP(fused_cast_sigmoid_bias) + .Inputs({"input", "bias"}) + .Outputs({"scores", "scores_with_bias"}) + .SetKernelFn(PD_KERNEL(FusedCastSigmoidBias)) + .SetInferShapeFn(PD_INFER_SHAPE(FusedCastSigmoidBiasInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(FusedCastSigmoidBiasInferDtype)); diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index 2e53012bb3b..958e50d313c 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -304,6 +304,7 @@ def find_end_files(directory, end_str): "gpu_ops/fused_rotary_position_encoding.cu", "gpu_ops/noaux_tc.cu", "gpu_ops/noaux_tc_redundant.cu", + "gpu_ops/fused_cast_sigmoid_bias.cu", "gpu_ops/custom_all_reduce/all_reduce.cu", "gpu_ops/merge_prefill_decode_output.cu", "gpu_ops/limit_thinking_content_length.cu", @@ -639,6 +640,7 @@ def find_end_files(directory, end_str): "gpu_ops/recover_decode_task.cu", "gpu_ops/noaux_tc.cu", "gpu_ops/noaux_tc_redundant.cu", + "gpu_ops/fused_cast_sigmoid_bias.cu", "gpu_ops/fused_rotary_position_encoding.cu", "gpu_ops/text_image_gather_scatter.cu", "gpu_ops/text_image_index_out.cu", diff --git a/fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py b/fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py new file mode 100644 index 00000000000..45953f9c34b --- /dev/null +++ b/fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py @@ -0,0 +1,44 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import paddle + +from fastdeploy.model_executor.ops.gpu import ( + fused_cast_sigmoid_bias as _fused_cast_sigmoid_bias_cuda, +) + + +def fused_cast_sigmoid_bias( + gate_out: paddle.Tensor, + e_score_correction_bias: paddle.Tensor, +) -> tuple: + """ + Fused operation: cast gate_out to float32, apply sigmoid, and add bias. + + This fuses three separate operations: + 1. gate_out = gate_out.cast("float32") + 2. scores = sigmoid(gate_out) + 3. scores_with_bias = scores + e_score_correction_bias + + Args: + gate_out: [num_tokens, num_experts], bf16/fp16/fp32 - raw gate output + e_score_correction_bias: [num_experts], fp32 - correction bias + + Returns: + scores: [num_tokens, num_experts], fp32 - sigmoid(gate_out) + scores_with_bias: [num_tokens, num_experts], fp32 - scores + bias + """ + return _fused_cast_sigmoid_bias_cuda(gate_out, e_score_correction_bias) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index 0c86270c630..d53501d3f98 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -275,7 +275,6 @@ def apply_tp( Paddle Cutlass compute Fused MoE. """ gate_out = gate(x) - gate_out = gate_out.cast("float32") if layer.topk_method == "noaux_tc": gate_out, topk_weights, topk_idx = get_moe_scores( gate_out, @@ -285,6 +284,7 @@ def apply_tp( layer.routed_scaling_factor, layer.gate_correction_bias, getattr(layer, "renormalize", True), + use_fused_cast=True, ) ( permute_input, @@ -308,6 +308,7 @@ def apply_tp( topk_only_mode=True, ) else: + gate_out = gate_out.cast("float32") ( permute_input, token_nums_per_expert, diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index 1fecaec6cd4..34f9a58985d 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -39,6 +39,16 @@ from fastdeploy.model_executor.ops.gpu import noaux_tc, noaux_tc_redundant except: logger.warning("import noaux_tc Failed!") + +try: + from fastdeploy.model_executor.layers.moe.fused_cast_sigmoid_bias import ( + fused_cast_sigmoid_bias, + ) + + _FUSED_CAST_SIGMOID_BIAS_AVAILABLE = True +except Exception: + _FUSED_CAST_SIGMOID_BIAS_AVAILABLE = False + import numpy as np @@ -90,13 +100,17 @@ def get_moe_scores( expert_in_rank_num_list: paddle.Tensor = None, tokens_per_expert_stats_list: paddle.Tensor = None, redundant_ep_rank_num_plus_one: int = 1, + use_fused_cast: bool = False, ) -> paddle.Tensor: """ compute moe scores using e_score_correction_bias. """ - scores = paddle.nn.functional.sigmoid(gating_output) assert e_score_correction_bias is not None, "e_score_correction_bias is none!" - scores_with_bias = scores + e_score_correction_bias + if use_fused_cast and _FUSED_CAST_SIGMOID_BIAS_AVAILABLE: + scores, scores_with_bias = fused_cast_sigmoid_bias(gating_output, e_score_correction_bias) + else: + scores = paddle.nn.functional.sigmoid(gating_output) + scores_with_bias = scores + e_score_correction_bias if expert_id_to_ep_rank_array is None: scores, topk_values, topk_idx = noaux_tc( scores, diff --git a/tests/layers/test_fused_cast_sigmoid_bias.py b/tests/layers/test_fused_cast_sigmoid_bias.py new file mode 100644 index 00000000000..7f2b993fe7f --- /dev/null +++ b/tests/layers/test_fused_cast_sigmoid_bias.py @@ -0,0 +1,224 @@ +""" +Test for fused_cast_sigmoid_bias CUDA custom op. +Tests: functionality, accuracy, and performance. + +Usage: + conda activate fd_py12_bin + export PYTHONPATH="/workspace2/bingoo/code/FastDeploy" + python /workspace2/bingoo/code/FastDeploy/tests/model_executor/layers/moe/test_fused_cast_sigmoid_bias.py +""" + + +import paddle +import paddle.nn.functional as F + +from fastdeploy.model_executor.layers.moe.fused_cast_sigmoid_bias import ( + fused_cast_sigmoid_bias, +) + + +def reference_cast_sigmoid_bias(gate_out, bias): + """Reference implementation: 3 separate ops.""" + gate_fp32 = gate_out.cast("float32") + scores = F.sigmoid(gate_fp32) + scores_with_bias = scores + bias + return scores, scores_with_bias + + +def test_functionality(): + """Test basic functionality: correct shapes and dtypes.""" + print("=" * 60) + print("Test 1: Functionality") + print("=" * 60) + + for dtype_name in ["float16", "bfloat16", "float32"]: + for num_tokens in [1, 7, 128, 1024]: + for num_experts in [8, 64, 128, 256]: + gate_out = paddle.randn([num_tokens, num_experts], dtype=dtype_name) + bias = paddle.randn([num_experts], dtype="float32") + + scores, scores_with_bias = fused_cast_sigmoid_bias(gate_out, bias) + + assert scores.shape == [ + num_tokens, + num_experts, + ], f"scores shape mismatch: {scores.shape} vs {[num_tokens, num_experts]}" + assert scores_with_bias.shape == [ + num_tokens, + num_experts, + ], f"scores_with_bias shape mismatch: {scores_with_bias.shape}" + assert scores.dtype == paddle.float32, f"scores dtype mismatch: {scores.dtype}" + assert ( + scores_with_bias.dtype == paddle.float32 + ), f"scores_with_bias dtype mismatch: {scores_with_bias.dtype}" + + # Sigmoid output should be in [0, 1] + assert paddle.all(scores >= 0.0) and paddle.all(scores <= 1.0), "scores out of [0,1] range" + + print(f" [PASS] dtype={dtype_name}") + + print(" All functionality tests passed.\n") + + +def test_accuracy(): + """Test numerical accuracy against reference implementation.""" + print("=" * 60) + print("Test 2: Accuracy") + print("=" * 60) + + test_cases = [ + ("float16", 1, 8), + ("float16", 128, 256), + ("float16", 1024, 256), + ("bfloat16", 1, 8), + ("bfloat16", 128, 256), + ("bfloat16", 1024, 256), + ("float32", 1, 8), + ("float32", 128, 256), + ("float32", 1024, 256), + ] + + for dtype_name, num_tokens, num_experts in test_cases: + gate_out = paddle.randn([num_tokens, num_experts], dtype=dtype_name) + bias = paddle.randn([num_experts], dtype="float32") + + # Fused kernel + fused_scores, fused_scores_with_bias = fused_cast_sigmoid_bias(gate_out, bias) + + # Reference + ref_scores, ref_scores_with_bias = reference_cast_sigmoid_bias(gate_out, bias) + + # Compare + scores_diff = paddle.abs(fused_scores - ref_scores).max().item() + scores_bias_diff = paddle.abs(fused_scores_with_bias - ref_scores_with_bias).max().item() + + atol = 1e-6 if dtype_name == "float32" else 1e-3 + passed = scores_diff < atol and scores_bias_diff < atol + + status = "PASS" if passed else "FAIL" + print( + f" [{status}] dtype={dtype_name}, tokens={num_tokens}, experts={num_experts} | " + f"scores_max_diff={scores_diff:.2e}, scores_with_bias_max_diff={scores_bias_diff:.2e}" + ) + + if not passed: + raise AssertionError( + f"Accuracy test failed for dtype={dtype_name}, tokens={num_tokens}, experts={num_experts}. " + f"scores_diff={scores_diff}, scores_bias_diff={scores_bias_diff}, atol={atol}" + ) + + print(" All accuracy tests passed.\n") + + +def test_accuracy_extreme_values(): + """Test accuracy with extreme input values.""" + print("=" * 60) + print("Test 3: Accuracy with extreme values") + print("=" * 60) + + num_tokens, num_experts = 64, 256 + + for dtype_name in ["float16", "bfloat16"]: + # Large positive values -> sigmoid ~ 1.0 + gate_out = paddle.full([num_tokens, num_experts], 10.0, dtype=dtype_name) + bias = paddle.zeros([num_experts], dtype="float32") + fused_scores, _ = fused_cast_sigmoid_bias(gate_out, bias) + ref_scores, _ = reference_cast_sigmoid_bias(gate_out, bias) + diff = paddle.abs(fused_scores - ref_scores).max().item() + print(f" [{'PASS' if diff < 1e-5 else 'FAIL'}] dtype={dtype_name}, large positive: max_diff={diff:.2e}") + + # Large negative values -> sigmoid ~ 0.0 + gate_out = paddle.full([num_tokens, num_experts], -10.0, dtype=dtype_name) + fused_scores, _ = fused_cast_sigmoid_bias(gate_out, bias) + ref_scores, _ = reference_cast_sigmoid_bias(gate_out, bias) + diff = paddle.abs(fused_scores - ref_scores).max().item() + print(f" [{'PASS' if diff < 1e-5 else 'FAIL'}] dtype={dtype_name}, large negative: max_diff={diff:.2e}") + + # Zero values -> sigmoid = 0.5 + gate_out = paddle.zeros([num_tokens, num_experts], dtype=dtype_name) + fused_scores, _ = fused_cast_sigmoid_bias(gate_out, bias) + ref_scores, _ = reference_cast_sigmoid_bias(gate_out, bias) + diff = paddle.abs(fused_scores - ref_scores).max().item() + assert diff < 1e-6, f"Zero input test failed: diff={diff}" + print(f" [PASS] dtype={dtype_name}, zeros: max_diff={diff:.2e}") + + print(" All extreme value tests passed.\n") + + +def test_performance(): + """Benchmark fused kernel vs reference implementation using CUDA events.""" + print("=" * 60) + print("Test 4: Performance (CUDA event timing)") + print("=" * 60) + + configs = [ + ("bfloat16", 1, 256), # single token decode + ("bfloat16", 8, 256), # small batch decode + ("bfloat16", 64, 256), # medium batch + ("bfloat16", 256, 256), # typical DeepSeek-V3 config + ("bfloat16", 1024, 256), # large prefill + ("bfloat16", 4096, 256), # very large prefill + ] + + warmup_iters = 100 + bench_iters = 500 + + for dtype_name, num_tokens, num_experts in configs: + gate_out = paddle.randn([num_tokens, num_experts], dtype=dtype_name) + bias = paddle.randn([num_experts], dtype="float32") + + # Warmup fused + for _ in range(warmup_iters): + fused_cast_sigmoid_bias(gate_out, bias) + paddle.device.synchronize() + + # Benchmark fused with CUDA events + start_event = paddle.device.cuda.Event(enable_timing=True) + end_event = paddle.device.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(bench_iters): + fused_cast_sigmoid_bias(gate_out, bias) + end_event.record() + paddle.device.synchronize() + fused_time = start_event.elapsed_time(end_event) / bench_iters * 1e3 # us + + # Warmup reference + for _ in range(warmup_iters): + reference_cast_sigmoid_bias(gate_out, bias) + paddle.device.synchronize() + + # Benchmark reference with CUDA events + start_event = paddle.device.cuda.Event(enable_timing=True) + end_event = paddle.device.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(bench_iters): + reference_cast_sigmoid_bias(gate_out, bias) + end_event.record() + paddle.device.synchronize() + ref_time = start_event.elapsed_time(end_event) / bench_iters * 1e3 # us + + speedup = ref_time / fused_time if fused_time > 0 else float("inf") + print( + f" tokens={num_tokens:5d}, experts={num_experts:3d} | " + f"ref={ref_time:8.1f}us, fused={fused_time:8.1f}us, speedup={speedup:.2f}x" + ) + + print() + print(" Note: The CUDA custom op fuses cast+sigmoid+bias into a single kernel,") + print(" eliminating 2 intermediate tensors and reducing kernel launches from 3 to 1.") + print(" Expected speedup: ~3x over the reference 3-op implementation.") + print(" Performance benchmark complete.\n") + + +if __name__ == "__main__": + paddle.set_device("gpu") + print("Running fused_cast_sigmoid_bias tests...\n") + + test_functionality() + test_accuracy() + test_accuracy_extreme_values() + test_performance() + + print("=" * 60) + print("All tests passed!") + print("=" * 60) From 4588d1d940c33893d10b455f241447aa7b01df10 Mon Sep 17 00:00:00 2001 From: Bingoo <1575938147@qq.com> Date: Mon, 16 Mar 2026 11:15:08 +0800 Subject: [PATCH 02/11] add cast_sgmoid_add fusion and enable in glm4.5 --- tests/layers/test_fused_cast_sigmoid_bias.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/layers/test_fused_cast_sigmoid_bias.py b/tests/layers/test_fused_cast_sigmoid_bias.py index 7f2b993fe7f..f9528b2cbd8 100644 --- a/tests/layers/test_fused_cast_sigmoid_bias.py +++ b/tests/layers/test_fused_cast_sigmoid_bias.py @@ -8,7 +8,6 @@ python /workspace2/bingoo/code/FastDeploy/tests/model_executor/layers/moe/test_fused_cast_sigmoid_bias.py """ - import paddle import paddle.nn.functional as F From 49e4323c98d79003427f16b60ec557ff07e1bbe7 Mon Sep 17 00:00:00 2001 From: Bingoo <1575938147@qq.com> Date: Tue, 17 Mar 2026 11:28:27 +0800 Subject: [PATCH 03/11] support more cast type --- custom_ops/gpu_ops/cpp_extensions.cc | 8 +- custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu | 102 +++++---- .../layers/moe/fused_cast_sigmoid_bias.py | 18 +- tests/layers/test_fused_cast_sigmoid_bias.py | 193 ++++++++++++++++-- 4 files changed, 254 insertions(+), 67 deletions(-) diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 0d01fab6b4e..269d4b0520a 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -692,7 +692,8 @@ std::vector NoauxTc(paddle::Tensor& scores, float routed_scaling_factor); std::vector FusedCastSigmoidBias(const paddle::Tensor& input, - const paddle::Tensor& bias); + const paddle::Tensor& bias, + std::string cast_type); std::vector NoauxTcRedundant( paddle::Tensor& scores, @@ -1672,7 +1673,10 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("fused_cast_sigmoid_bias", &FusedCastSigmoidBias, - "Fused cast+sigmoid+bias for MoE gating scores"); + "Fused cast+sigmoid+bias for MoE gating scores", + py::arg("input"), + py::arg("bias"), + py::arg("cast_type") = std::string("float32")); m.def("noaux_tc_redundant", &NoauxTcRedundant, diff --git a/custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu b/custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu index 392788582ff..bd835faa3bf 100644 --- a/custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu +++ b/custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu @@ -14,24 +14,24 @@ #include "helper.h" -// Fused kernel: cast(input, fp32) -> sigmoid -> scores, scores + bias -> +// Fused kernel: cast(input, cast_type) -> sigmoid -> scores, scores + bias -> // scores_with_bias // // For each element (token i, expert j): -// scores[i][j] = sigmoid(float(input[i][j])) -// scores_with_bias[i][j] = scores[i][j] + bias[j] +// scores[i][j] = OutT(sigmoid(float(input[i][j]))) +// scores_with_bias[i][j] = OutT(sigmoid(float(input[i][j])) + bias[j]) // // Input: input [num_tokens, num_experts] bf16/fp16/fp32 // bias [num_experts] or [1, num_experts] fp32 -// Output: scores [num_tokens, num_experts] fp32 -// scores_with_bias [num_tokens, num_experts] fp32 +// Output: scores [num_tokens, num_experts] cast_type (fp32/fp16/bf16) +// scores_with_bias [num_tokens, num_experts] cast_type (fp32/fp16/bf16) -template +template __global__ void fused_cast_sigmoid_bias_kernel( const InT* __restrict__ input, const float* __restrict__ bias, - float* __restrict__ scores, - float* __restrict__ scores_with_bias, + OutT* __restrict__ scores, + OutT* __restrict__ scores_with_bias, const int num_experts) { const int64_t token_idx = blockIdx.x; const int64_t offset = token_idx * num_experts; @@ -40,30 +40,31 @@ __global__ void fused_cast_sigmoid_bias_kernel( float val = static_cast(input[offset + j]); // sigmoid: 1 / (1 + exp(-x)) float s = 1.0f / (1.0f + expf(-val)); - scores[offset + j] = s; - scores_with_bias[offset + j] = s + bias[j]; + scores[offset + j] = static_cast(s); + scores_with_bias[offset + j] = static_cast(s + bias[j]); } } // Vectorized version for better memory throughput -template +template __global__ void fused_cast_sigmoid_bias_vec_kernel( const InT* __restrict__ input, const float* __restrict__ bias, - float* __restrict__ scores, - float* __restrict__ scores_with_bias, + OutT* __restrict__ scores, + OutT* __restrict__ scores_with_bias, const int num_experts) { const int64_t token_idx = blockIdx.x; const int64_t offset = token_idx * num_experts; using in_vec_t = AlignedVector; - using out_vec_t = AlignedVector; + using out_vec_t = AlignedVector; + using bias_vec_t = AlignedVector; const int vec_count = num_experts / kVecSize; for (int idx = threadIdx.x; idx < vec_count; idx += blockDim.x) { const int base = idx * kVecSize; in_vec_t in_vec; - out_vec_t bias_vec; + bias_vec_t bias_vec; Load(input + offset + base, &in_vec); Load(bias + base, &bias_vec); @@ -72,8 +73,8 @@ __global__ void fused_cast_sigmoid_bias_vec_kernel( for (int i = 0; i < kVecSize; ++i) { float val = static_cast(in_vec[i]); float s = 1.0f / (1.0f + expf(-val)); - s_vec[i] = s; - sb_vec[i] = s + bias_vec[i]; + s_vec[i] = static_cast(s); + sb_vec[i] = static_cast(s + bias_vec[i]); } Store(s_vec, scores + offset + base); @@ -86,13 +87,22 @@ __global__ void fused_cast_sigmoid_bias_vec_kernel( j += blockDim.x) { float val = static_cast(input[offset + j]); float s = 1.0f / (1.0f + expf(-val)); - scores[offset + j] = s; - scores_with_bias[offset + j] = s + bias[j]; + scores[offset + j] = static_cast(s); + scores_with_bias[offset + j] = static_cast(s + bias[j]); } } +static paddle::DataType ParseCastType(const std::string& cast_type) { + if (cast_type == "float32") return paddle::DataType::FLOAT32; + if (cast_type == "float16") return paddle::DataType::FLOAT16; + if (cast_type == "bfloat16") return paddle::DataType::BFLOAT16; + PD_THROW("Unsupported cast_type: " + cast_type + + ". Only float32, float16, bfloat16 are supported."); +} + std::vector FusedCastSigmoidBias(const paddle::Tensor& input, - const paddle::Tensor& bias) { + const paddle::Tensor& bias, + std::string cast_type) { auto input_shape = input.shape(); PD_CHECK(input_shape.size() == 2, "input must be 2D [num_tokens, num_experts]"); @@ -109,11 +119,11 @@ std::vector FusedCastSigmoidBias(const paddle::Tensor& input, auto place = input.place(); auto stream = input.stream(); + auto out_dtype = ParseCastType(cast_type); - auto scores = paddle::empty( - {num_tokens, num_experts}, paddle::DataType::FLOAT32, place); - auto scores_with_bias = paddle::empty( - {num_tokens, num_experts}, paddle::DataType::FLOAT32, place); + auto scores = paddle::empty({num_tokens, num_experts}, out_dtype, place); + auto scores_with_bias = + paddle::empty({num_tokens, num_experts}, out_dtype, place); if (num_tokens == 0) { return {scores, scores_with_bias}; @@ -125,31 +135,36 @@ std::vector FusedCastSigmoidBias(const paddle::Tensor& input, block_size = ((block_size + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; dim3 block(block_size); - DISPATCH_FLOAT_FP6_DTYPE(input.dtype(), scalar_t, { - constexpr int kVecSize = 16 / sizeof(scalar_t); - if (num_experts % kVecSize == 0 && num_experts >= kVecSize) { - fused_cast_sigmoid_bias_vec_kernel - <<>>(input.data(), - bias.data(), - scores.data(), - scores_with_bias.data(), - num_experts); - } else { - fused_cast_sigmoid_bias_kernel - <<>>(input.data(), - bias.data(), - scores.data(), - scores_with_bias.data(), - num_experts); - } + DISPATCH_FLOAT_FP6_DTYPE(input.dtype(), in_scalar_t, { + DISPATCH_FLOAT_FP6_DTYPE(out_dtype, out_scalar_t, { + constexpr int kVecSize = 16 / sizeof(in_scalar_t); + if (num_experts % kVecSize == 0 && num_experts >= kVecSize) { + fused_cast_sigmoid_bias_vec_kernel + <<>>(input.data(), + bias.data(), + scores.data(), + scores_with_bias.data(), + num_experts); + } else { + fused_cast_sigmoid_bias_kernel + <<>>(input.data(), + bias.data(), + scores.data(), + scores_with_bias.data(), + num_experts); + } + }); }); return {scores, scores_with_bias}; } std::vector FusedCastSigmoidBiasInferDtype( - const paddle::DataType& input_dtype, const paddle::DataType& bias_dtype) { - return {paddle::DataType::FLOAT32, paddle::DataType::FLOAT32}; + const paddle::DataType& input_dtype, + const paddle::DataType& bias_dtype, + std::string cast_type) { + auto out_dtype = ParseCastType(cast_type); + return {out_dtype, out_dtype}; } std::vector> FusedCastSigmoidBiasInferShape( @@ -161,6 +176,7 @@ std::vector> FusedCastSigmoidBiasInferShape( PD_BUILD_STATIC_OP(fused_cast_sigmoid_bias) .Inputs({"input", "bias"}) .Outputs({"scores", "scores_with_bias"}) + .Attrs({"cast_type: std::string"}) .SetKernelFn(PD_KERNEL(FusedCastSigmoidBias)) .SetInferShapeFn(PD_INFER_SHAPE(FusedCastSigmoidBiasInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(FusedCastSigmoidBiasInferDtype)); diff --git a/fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py b/fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py index 45953f9c34b..6d01138c06c 100644 --- a/fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py +++ b/fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py @@ -24,21 +24,23 @@ def fused_cast_sigmoid_bias( gate_out: paddle.Tensor, e_score_correction_bias: paddle.Tensor, + cast_type: str = "float32", ) -> tuple: """ - Fused operation: cast gate_out to float32, apply sigmoid, and add bias. + 融合操作:将gate_out转换为指定类型,应用sigmoid函数,并添加偏置。 - This fuses three separate operations: - 1. gate_out = gate_out.cast("float32") + 该函数融合了以下三个独立操作: + 1. gate_out = gate_out.cast(cast_type) 2. scores = sigmoid(gate_out) 3. scores_with_bias = scores + e_score_correction_bias Args: - gate_out: [num_tokens, num_experts], bf16/fp16/fp32 - raw gate output - e_score_correction_bias: [num_experts], fp32 - correction bias + gate_out: [num_tokens, num_experts],bf16/fp16/fp32类型 - 原始gate输出 + e_score_correction_bias: [num_experts],fp32类型 - 修正偏置 + cast_type: 输出数据类型字符串,支持"float32"、"float16"、"bfloat16" Returns: - scores: [num_tokens, num_experts], fp32 - sigmoid(gate_out) - scores_with_bias: [num_tokens, num_experts], fp32 - scores + bias + scores: [num_tokens, num_experts],cast_type类型 - sigmoid(gate_out)的结果 + scores_with_bias: [num_tokens, num_experts],cast_type类型 - 加上偏置后的分数 """ - return _fused_cast_sigmoid_bias_cuda(gate_out, e_score_correction_bias) + return _fused_cast_sigmoid_bias_cuda(gate_out, e_score_correction_bias, cast_type) diff --git a/tests/layers/test_fused_cast_sigmoid_bias.py b/tests/layers/test_fused_cast_sigmoid_bias.py index f9528b2cbd8..2609036a7f7 100644 --- a/tests/layers/test_fused_cast_sigmoid_bias.py +++ b/tests/layers/test_fused_cast_sigmoid_bias.py @@ -3,31 +3,57 @@ Tests: functionality, accuracy, and performance. Usage: - conda activate fd_py12_bin - export PYTHONPATH="/workspace2/bingoo/code/FastDeploy" - python /workspace2/bingoo/code/FastDeploy/tests/model_executor/layers/moe/test_fused_cast_sigmoid_bias.py + conda activate fd_fused_cast_test + python /ssd2/bingoo/code/fastdeploy/FastDeploy/tests/layers/test_fused_cast_sigmoid_bias.py """ +import os + import paddle import paddle.nn.functional as F - -from fastdeploy.model_executor.layers.moe.fused_cast_sigmoid_bias import ( - fused_cast_sigmoid_bias, +from paddle.utils.cpp_extension import load + +DTYPE_MAP = { + "float16": paddle.float16, + "bfloat16": paddle.bfloat16, + "float32": paddle.float32, +} + +# Load the custom op directly via paddle JIT compilation +_basedir = os.path.join(os.path.dirname(__file__), "../../custom_ops") +_basedir = os.path.abspath(_basedir) +_ops = load( + name="fused_cast_sigmoid_bias_op", + sources=[os.path.join(_basedir, "gpu_ops/fused_cast_sigmoid_bias.cu")], + extra_include_paths=[ + os.path.join(_basedir, "gpu_ops"), + os.path.join(_basedir, "third_party/nlohmann_json/include"), + os.path.join(_basedir, "third_party/cutlass/include"), + ], + extra_cuda_cflags=["-gencode", "arch=compute_80,code=sm_80", "-DPADDLE_DEV"], + build_directory="/tmp/fused_cast_build_test", ) -def reference_cast_sigmoid_bias(gate_out, bias): - """Reference implementation: 3 separate ops.""" +def fused_cast_sigmoid_bias(gate_out, bias, cast_type="float32"): + """Wrapper for the custom op.""" + return _ops.static_op_fused_cast_sigmoid_bias(gate_out, bias, cast_type) + + +def reference_cast_sigmoid_bias(gate_out, bias, cast_type="float32"): + """Reference implementation: compute in fp32, cast output to cast_type.""" gate_fp32 = gate_out.cast("float32") - scores = F.sigmoid(gate_fp32) - scores_with_bias = scores + bias + scores_fp32 = F.sigmoid(gate_fp32) + scores_with_bias_fp32 = scores_fp32 + bias + scores = scores_fp32.cast(cast_type) + scores_with_bias = scores_with_bias_fp32.cast(cast_type) return scores, scores_with_bias def test_functionality(): - """Test basic functionality: correct shapes and dtypes.""" + """Test basic functionality: correct shapes and dtypes (default cast_type=float32).""" print("=" * 60) - print("Test 1: Functionality") + print("Test 1: Functionality (default cast_type=float32)") print("=" * 60) for dtype_name in ["float16", "bfloat16", "float32"]: @@ -59,10 +85,43 @@ def test_functionality(): print(" All functionality tests passed.\n") +def test_functionality_cast_types(): + """Test functionality with different cast_type values.""" + print("=" * 60) + print("Test 1b: Functionality with different cast_type") + print("=" * 60) + + for input_dtype in ["float16", "bfloat16", "float32"]: + for cast_type in ["float16", "bfloat16", "float32"]: + expected_paddle_dtype = DTYPE_MAP[cast_type] + for num_tokens in [1, 64, 256]: + for num_experts in [8, 64, 256]: + gate_out = paddle.randn([num_tokens, num_experts], dtype=input_dtype) + bias = paddle.randn([num_experts], dtype="float32") + + scores, scores_with_bias = fused_cast_sigmoid_bias(gate_out, bias, cast_type) + + assert scores.shape == [num_tokens, num_experts], f"scores shape mismatch: {scores.shape}" + assert scores_with_bias.shape == [ + num_tokens, + num_experts, + ], f"scores_with_bias shape mismatch: {scores_with_bias.shape}" + assert ( + scores.dtype == expected_paddle_dtype + ), f"scores dtype mismatch: got {scores.dtype}, expected {expected_paddle_dtype}" + assert ( + scores_with_bias.dtype == expected_paddle_dtype + ), f"scores_with_bias dtype mismatch: got {scores_with_bias.dtype}, expected {expected_paddle_dtype}" + + print(f" [PASS] input_dtype={input_dtype}, cast_type={cast_type}") + + print(" All cast_type functionality tests passed.\n") + + def test_accuracy(): - """Test numerical accuracy against reference implementation.""" + """Test numerical accuracy against reference implementation (default cast_type=float32).""" print("=" * 60) - print("Test 2: Accuracy") + print("Test 2: Accuracy (default cast_type=float32)") print("=" * 60) test_cases = [ @@ -109,6 +168,76 @@ def test_accuracy(): print(" All accuracy tests passed.\n") +def test_accuracy_cast_types(): + """Test numerical accuracy with different cast_type values.""" + print("=" * 60) + print("Test 2b: Accuracy with different cast_type") + print("=" * 60) + + # (input_dtype, cast_type, num_tokens, num_experts) + test_cases = [ + # cast to float32 (original behavior) + ("float16", "float32", 128, 256), + ("bfloat16", "float32", 128, 256), + ("float32", "float32", 128, 256), + # cast to float16 + ("float16", "float16", 128, 256), + ("bfloat16", "float16", 128, 256), + ("float32", "float16", 128, 256), + # cast to bfloat16 + ("float16", "bfloat16", 128, 256), + ("bfloat16", "bfloat16", 128, 256), + ("float32", "bfloat16", 128, 256), + # different shapes + ("bfloat16", "float16", 1, 8), + ("bfloat16", "float16", 1024, 256), + ("float16", "bfloat16", 1, 8), + ("float16", "bfloat16", 1024, 256), + ] + + for input_dtype, cast_type, num_tokens, num_experts in test_cases: + gate_out = paddle.randn([num_tokens, num_experts], dtype=input_dtype) + bias = paddle.randn([num_experts], dtype="float32") + + # Fused kernel + fused_scores, fused_scores_with_bias = fused_cast_sigmoid_bias(gate_out, bias, cast_type) + + # Reference + ref_scores, ref_scores_with_bias = reference_cast_sigmoid_bias(gate_out, bias, cast_type) + + # Compare in float32 for stable diff computation + scores_diff = paddle.abs(fused_scores.cast("float32") - ref_scores.cast("float32")).max().item() + scores_bias_diff = ( + paddle.abs(fused_scores_with_bias.cast("float32") - ref_scores_with_bias.cast("float32")).max().item() + ) + + # Tolerance depends on cast_type precision + if cast_type == "float32": + atol = 1e-6 + elif cast_type == "bfloat16": + atol = 1e-2 # bfloat16 has fewer mantissa bits + else: # float16 + atol = 1e-3 + + passed = scores_diff < atol and scores_bias_diff < atol + + status = "PASS" if passed else "FAIL" + print( + f" [{status}] input={input_dtype}, cast_type={cast_type}, " + f"tokens={num_tokens}, experts={num_experts} | " + f"scores_diff={scores_diff:.2e}, bias_diff={scores_bias_diff:.2e}" + ) + + if not passed: + raise AssertionError( + f"Accuracy test failed for input={input_dtype}, cast_type={cast_type}, " + f"tokens={num_tokens}, experts={num_experts}. " + f"scores_diff={scores_diff}, bias_diff={scores_bias_diff}, atol={atol}" + ) + + print(" All cast_type accuracy tests passed.\n") + + def test_accuracy_extreme_values(): """Test accuracy with extreme input values.""" print("=" * 60) @@ -144,6 +273,39 @@ def test_accuracy_extreme_values(): print(" All extreme value tests passed.\n") +def test_accuracy_extreme_values_cast_types(): + """Test accuracy with extreme values across different cast_type values.""" + print("=" * 60) + print("Test 3b: Accuracy with extreme values + different cast_type") + print("=" * 60) + + num_tokens, num_experts = 64, 256 + + for input_dtype in ["float16", "bfloat16"]: + for cast_type in ["float16", "bfloat16", "float32"]: + bias = paddle.zeros([num_experts], dtype="float32") + + # Large positive + gate_out = paddle.full([num_tokens, num_experts], 10.0, dtype=input_dtype) + fused_scores, _ = fused_cast_sigmoid_bias(gate_out, bias, cast_type) + ref_scores, _ = reference_cast_sigmoid_bias(gate_out, bias, cast_type) + diff = paddle.abs(fused_scores.cast("float32") - ref_scores.cast("float32")).max().item() + atol = 1e-2 if cast_type == "bfloat16" else 1e-5 + status = "PASS" if diff < atol else "FAIL" + print(f" [{status}] input={input_dtype}, cast={cast_type}, " f"large positive: diff={diff:.2e}") + + # Zero values + gate_out = paddle.zeros([num_tokens, num_experts], dtype=input_dtype) + fused_scores, _ = fused_cast_sigmoid_bias(gate_out, bias, cast_type) + ref_scores, _ = reference_cast_sigmoid_bias(gate_out, bias, cast_type) + diff = paddle.abs(fused_scores.cast("float32") - ref_scores.cast("float32")).max().item() + atol = 1e-2 if cast_type == "bfloat16" else 1e-5 + assert diff < atol, f"Zero input test failed: input={input_dtype}, cast={cast_type}, diff={diff}" + print(f" [PASS] input={input_dtype}, cast={cast_type}, " f"zeros: diff={diff:.2e}") + + print(" All extreme value cast_type tests passed.\n") + + def test_performance(): """Benchmark fused kernel vs reference implementation using CUDA events.""" print("=" * 60) @@ -214,8 +376,11 @@ def test_performance(): print("Running fused_cast_sigmoid_bias tests...\n") test_functionality() + test_functionality_cast_types() test_accuracy() + test_accuracy_cast_types() test_accuracy_extreme_values() + test_accuracy_extreme_values_cast_types() test_performance() print("=" * 60) From 7743979c3cb0ab4dbee8c7e5ab64e1a5dcb3f072 Mon Sep 17 00:00:00 2001 From: Bingoo <1575938147@qq.com> Date: Tue, 17 Mar 2026 14:12:12 +0800 Subject: [PATCH 04/11] modify test --- tests/layers/test_fused_cast_sigmoid_bias.py | 46 ++++++++------------ 1 file changed, 17 insertions(+), 29 deletions(-) diff --git a/tests/layers/test_fused_cast_sigmoid_bias.py b/tests/layers/test_fused_cast_sigmoid_bias.py index 2609036a7f7..9a1c16a42e8 100644 --- a/tests/layers/test_fused_cast_sigmoid_bias.py +++ b/tests/layers/test_fused_cast_sigmoid_bias.py @@ -1,17 +1,25 @@ """ -Test for fused_cast_sigmoid_bias CUDA custom op. -Tests: functionality, accuracy, and performance. - -Usage: - conda activate fd_fused_cast_test - python /ssd2/bingoo/code/fastdeploy/FastDeploy/tests/layers/test_fused_cast_sigmoid_bias.py +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ -import os - import paddle import paddle.nn.functional as F -from paddle.utils.cpp_extension import load + +from fastdeploy.model_executor.layers.moe.fused_cast_sigmoid_bias import ( + fused_cast_sigmoid_bias, +) DTYPE_MAP = { "float16": paddle.float16, @@ -19,26 +27,6 @@ "float32": paddle.float32, } -# Load the custom op directly via paddle JIT compilation -_basedir = os.path.join(os.path.dirname(__file__), "../../custom_ops") -_basedir = os.path.abspath(_basedir) -_ops = load( - name="fused_cast_sigmoid_bias_op", - sources=[os.path.join(_basedir, "gpu_ops/fused_cast_sigmoid_bias.cu")], - extra_include_paths=[ - os.path.join(_basedir, "gpu_ops"), - os.path.join(_basedir, "third_party/nlohmann_json/include"), - os.path.join(_basedir, "third_party/cutlass/include"), - ], - extra_cuda_cflags=["-gencode", "arch=compute_80,code=sm_80", "-DPADDLE_DEV"], - build_directory="/tmp/fused_cast_build_test", -) - - -def fused_cast_sigmoid_bias(gate_out, bias, cast_type="float32"): - """Wrapper for the custom op.""" - return _ops.static_op_fused_cast_sigmoid_bias(gate_out, bias, cast_type) - def reference_cast_sigmoid_bias(gate_out, bias, cast_type="float32"): """Reference implementation: compute in fp32, cast output to cast_type.""" From 6d9459747e280bcd1d8e4aad1b72f1b9715a2f2b Mon Sep 17 00:00:00 2001 From: Bingoo <1575938147@qq.com> Date: Wed, 18 Mar 2026 10:09:42 +0800 Subject: [PATCH 05/11] add type check --- custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu | 3 +++ 1 file changed, 3 insertions(+) diff --git a/custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu b/custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu index bd835faa3bf..820992c7091 100644 --- a/custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu +++ b/custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu @@ -116,6 +116,9 @@ std::vector FusedCastSigmoidBias(const paddle::Tensor& input, int64_t num_experts = input_shape[1]; int64_t bias_numel = (bias_shape.size() == 1) ? bias_shape[0] : bias_shape[1]; PD_CHECK(bias_numel == num_experts, "bias size must match num_experts"); + PD_CHECK(bias.dtype() == paddle::DataType::FLOAT32, + "bias must be float32, got ", + bias.dtype()); auto place = input.place(); auto stream = input.stream(); From 62efce98a1ee7953dc8cf389ce43f87741b49320 Mon Sep 17 00:00:00 2001 From: Bingoo <1575938147@qq.com> Date: Thu, 19 Mar 2026 10:29:35 +0800 Subject: [PATCH 06/11] fix config issues --- tests/layers/test_fused_moe_cutlass_backend.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/layers/test_fused_moe_cutlass_backend.py b/tests/layers/test_fused_moe_cutlass_backend.py index 2e8ea281daa..0f91ac323a2 100644 --- a/tests/layers/test_fused_moe_cutlass_backend.py +++ b/tests/layers/test_fused_moe_cutlass_backend.py @@ -388,7 +388,9 @@ def combine(self, ffn_out, topk_idx, topk_weights, handle, quant_group_size=-1): np.testing.assert_allclose(out.numpy(), np.full((1, 2), 5.0)) def test_apply_tp_with_dispatch_and_reduce(self, monkeypatch): - def fake_get_moe_scores(gate_out, n_group, topk_group, top_k, routed_scaling_factor, bias, renormalize): + def fake_get_moe_scores( + gate_out, n_group, topk_group, top_k, routed_scaling_factor, bias, renormalize, **kwargs + ): return gate_out, paddle.to_tensor([[0.6, 0.4]]), paddle.to_tensor([[0, 1]]) def fake_dispatch(*args, **kwargs): From 77529febd4c2c8a28229b0832cf063ecd6d9d3b8 Mon Sep 17 00:00:00 2001 From: Bingoo <1575938147@qq.com> Date: Wed, 25 Mar 2026 22:41:35 +0800 Subject: [PATCH 07/11] enable more backend --- .../layers/moe/fused_moe_cutlass_backend.py | 5 ++++- .../layers/moe/fused_moe_deepgemm_backend.py | 8 ++++++-- .../layers/moe/fused_moe_marlin_backend.py | 6 +++++- .../layers/moe/fused_moe_triton_backend.py | 6 +++++- fastdeploy/model_executor/layers/moe/moe.py | 16 ++++++---------- 5 files changed, 26 insertions(+), 15 deletions(-) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index d53501d3f98..5eaba781a6a 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -276,6 +276,9 @@ def apply_tp( """ gate_out = gate(x) if layer.topk_method == "noaux_tc": + if layer.dynamic_load_weight: + gate_out = gate_out.cast("float32") + use_fused = not layer.dynamic_load_weight gate_out, topk_weights, topk_idx = get_moe_scores( gate_out, layer.n_group, @@ -284,7 +287,7 @@ def apply_tp( layer.routed_scaling_factor, layer.gate_correction_bias, getattr(layer, "renormalize", True), - use_fused_cast=True, + use_fused_cast=use_fused, ) ( permute_input, diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py index cf38ec57b0d..06a5cc88fca 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py @@ -816,11 +816,12 @@ def apply_tp( below is TP compute method. """ gate_out = gate(x) - gate_out = gate_out.cast("float32") - if layer.topk_method == "noaux_tc": if not fastdeploy.envs.FD_USE_PHI_MOE_TOPK: + if layer.dynamic_load_weight: + gate_out = gate_out.cast("float32") + use_fused = not layer.dynamic_load_weight _, topk_weights, topk_ids = fastdeploy.model_executor.layers.moe.moe.get_moe_scores( gate_out, layer.n_group, @@ -829,8 +830,10 @@ def apply_tp( layer.routed_scaling_factor, layer.gate_correction_bias, getattr(layer, "renormalize", True), + use_fused_cast=use_fused, ) else: + gate_out = gate_out.cast("float32") topk_weights, topk_ids = moe_topk_select( gate_out, layer.n_group, @@ -842,6 +845,7 @@ def apply_tp( ) else: + gate_out = gate_out.cast("float32") topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( gate_out, layer.gate_correction_bias, diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py index 4e4101e9726..d1c26b59e81 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py @@ -248,7 +248,6 @@ def apply( Marlin compute Fused MoE. """ gate_out = gate(x) - gate_out = gate_out.cast("float32") token_num = x.shape[0] top_k = layer.top_k top_k = layer.top_k @@ -260,6 +259,9 @@ def apply( if topk_method == "noaux_tc": from fastdeploy.model_executor.layers.moe.moe import get_moe_scores + if layer.dynamic_load_weight: + gate_out = gate_out.cast("float32") + use_fused = not layer.dynamic_load_weight _, topk_weights, topk_ids = get_moe_scores( gate_out, layer.n_group, @@ -268,8 +270,10 @@ def apply( layer.routed_scaling_factor, layer.gate_correction_bias, getattr(layer, "renormalize", True), + use_fused_cast=use_fused, ) else: + gate_out = gate_out.cast("float32") topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( gate_out, layer.gate_correction_bias, diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py index d1db43a3241..58a0c1a1e91 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py @@ -299,7 +299,6 @@ def apply( if token_num == 0: return paddle.zeros([token_num, layer.hidden_size], dtype=x.dtype) gate_out = gate(x) - gate_out = gate_out.cast("float32") top_k = layer.top_k num_local_experts = layer.num_local_experts top_k = layer.top_k @@ -307,6 +306,9 @@ def apply( hidden_size = layer.hidden_size if layer.topk_method == "noaux_tc": + if layer.dynamic_load_weight: + gate_out = gate_out.cast("float32") + use_fused = not layer.dynamic_load_weight gate_out, topk_weights, topk_ids = get_moe_scores( gate_out, layer.n_group, @@ -315,8 +317,10 @@ def apply( layer.routed_scaling_factor, layer.gate_correction_bias, getattr(layer, "renormalize", True), + use_fused_cast=use_fused, ) else: + gate_out = gate_out.cast("float32") topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( gate_out, layer.gate_correction_bias, diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index bfd1e550de1..76790152b69 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -40,17 +40,12 @@ except: logger.warning("import noaux_tc Failed!") -try: - from fastdeploy.model_executor.layers.moe.fused_cast_sigmoid_bias import ( - fused_cast_sigmoid_bias, - ) - - _FUSED_CAST_SIGMOID_BIAS_AVAILABLE = True -except Exception: - _FUSED_CAST_SIGMOID_BIAS_AVAILABLE = False - import numpy as np +from fastdeploy.model_executor.layers.moe.fused_cast_sigmoid_bias import ( + fused_cast_sigmoid_bias, +) + def get_moe_method(layer=None): """ @@ -106,7 +101,7 @@ def get_moe_scores( compute moe scores using e_score_correction_bias. """ assert e_score_correction_bias is not None, "e_score_correction_bias is none!" - if use_fused_cast and _FUSED_CAST_SIGMOID_BIAS_AVAILABLE: + if use_fused_cast: scores, scores_with_bias = fused_cast_sigmoid_bias(gating_output, e_score_correction_bias) else: scores = paddle.nn.functional.sigmoid(gating_output) @@ -177,6 +172,7 @@ def __init__( super().__init__() self.fd_config = fd_config + self.dynamic_load_weight = fd_config.load_config.dynamic_load_weight self.layer_idx = layer_idx self.reduce_results = reduce_results self.renormalize = renormalize From 68079ad5bc9f1bdeb74b425a00fa61f45aa3434f Mon Sep 17 00:00:00 2001 From: Bingoo <1575938147@qq.com> Date: Wed, 25 Mar 2026 22:48:35 +0800 Subject: [PATCH 08/11] modify 2025->2026 --- custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu | 2 +- fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu b/custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu index 820992c7091..19c2da14ec7 100644 --- a/custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu +++ b/custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py b/fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py index 6d01138c06c..a573c2aa486 100644 --- a/fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py +++ b/fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py @@ -1,5 +1,5 @@ """ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 218e50f3f5bcd9253aa89490c3ac50537fe8715e Mon Sep 17 00:00:00 2001 From: Bingoo <1575938147@qq.com> Date: Thu, 26 Mar 2026 21:12:43 +0800 Subject: [PATCH 09/11] only support gpu backend and fix test issues --- .../model_executor/layers/moe/fused_moe_cutlass_backend.py | 4 ++-- .../model_executor/layers/moe/fused_moe_deepgemm_backend.py | 4 ++-- .../model_executor/layers/moe/fused_moe_marlin_backend.py | 6 +++--- .../model_executor/layers/moe/fused_moe_triton_backend.py | 5 +++-- tests/layers/test_deepgemm_fused_moe.py | 1 + tests/layers/test_fused_moe_cutlass_backend.py | 1 + tests/layers/test_fused_moe_triton_backend.py | 1 + 7 files changed, 13 insertions(+), 9 deletions(-) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index 5eaba781a6a..913a905c34c 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -276,9 +276,9 @@ def apply_tp( """ gate_out = gate(x) if layer.topk_method == "noaux_tc": - if layer.dynamic_load_weight: + use_fused = not layer.dynamic_load_weight and current_platform.is_cuda() + if not use_fused: gate_out = gate_out.cast("float32") - use_fused = not layer.dynamic_load_weight gate_out, topk_weights, topk_idx = get_moe_scores( gate_out, layer.n_group, diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py index 06a5cc88fca..7bea64a8c5d 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py @@ -819,9 +819,9 @@ def apply_tp( if layer.topk_method == "noaux_tc": if not fastdeploy.envs.FD_USE_PHI_MOE_TOPK: - if layer.dynamic_load_weight: + use_fused = not layer.dynamic_load_weight and current_platform.is_cuda() + if not use_fused: gate_out = gate_out.cast("float32") - use_fused = not layer.dynamic_load_weight _, topk_weights, topk_ids = fastdeploy.model_executor.layers.moe.moe.get_moe_scores( gate_out, layer.n_group, diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py index d1c26b59e81..c7d6dd1d53f 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py @@ -20,6 +20,7 @@ from paddle import nn import fastdeploy +from fastdeploy.platforms import current_platform from fastdeploy.model_executor.ops.gpu import ( MoeWna16MarlinGemmApi, tritonmoe_preprocess_func, @@ -258,10 +259,9 @@ def apply( if topk_method == "noaux_tc": from fastdeploy.model_executor.layers.moe.moe import get_moe_scores - - if layer.dynamic_load_weight: + use_fused = not layer.dynamic_load_weight and current_platform.is_cuda() + if not use_fused: gate_out = gate_out.cast("float32") - use_fused = not layer.dynamic_load_weight _, topk_weights, topk_ids = get_moe_scores( gate_out, layer.n_group, diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py index 58a0c1a1e91..7c64ffc1323 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py @@ -20,6 +20,7 @@ from paddle import nn import fastdeploy +from fastdeploy.platforms import current_platform from fastdeploy.model_executor.layers.utils import get_tensor from fastdeploy.model_executor.utils import ( TensorTracker, @@ -306,9 +307,9 @@ def apply( hidden_size = layer.hidden_size if layer.topk_method == "noaux_tc": - if layer.dynamic_load_weight: + use_fused = not layer.dynamic_load_weight and current_platform.is_cuda() + if not use_fused: gate_out = gate_out.cast("float32") - use_fused = not layer.dynamic_load_weight gate_out, topk_weights, topk_ids = get_moe_scores( gate_out, layer.n_group, diff --git a/tests/layers/test_deepgemm_fused_moe.py b/tests/layers/test_deepgemm_fused_moe.py index 5381ee866a3..a30f67e21e0 100644 --- a/tests/layers/test_deepgemm_fused_moe.py +++ b/tests/layers/test_deepgemm_fused_moe.py @@ -132,6 +132,7 @@ def __init__(self): self.topk_group = 1 self.routed_scaling_factor = 1.0 self.renormalize = True + self.dynamic_load_weight = False self.gate_correction_bias = paddle.zeros([E], dtype="float32") self.topk_method = "noaux_tc" self.fd_config = DummyFDConfig() diff --git a/tests/layers/test_fused_moe_cutlass_backend.py b/tests/layers/test_fused_moe_cutlass_backend.py index 0f91ac323a2..fbf355984d7 100644 --- a/tests/layers/test_fused_moe_cutlass_backend.py +++ b/tests/layers/test_fused_moe_cutlass_backend.py @@ -84,6 +84,7 @@ def __init__( self.routed_scaling_factor = 1.0 self.gate_correction_bias = None self.is_quantized = False + self.dynamic_load_weight = False self.moe_quant_config = types.SimpleNamespace(moe_dynamic_quant=False, hadamard_block_size=128) self.weight_key_map = { "up_gate_proj_expert_weight_key": "up_gate_{}", diff --git a/tests/layers/test_fused_moe_triton_backend.py b/tests/layers/test_fused_moe_triton_backend.py index 1140cf72b16..0fd16dc3109 100644 --- a/tests/layers/test_fused_moe_triton_backend.py +++ b/tests/layers/test_fused_moe_triton_backend.py @@ -102,6 +102,7 @@ def __init__( } self._up_weights = None self._down_weights = None + self.dynamic_load_weight = False def extract_moe_ffn_weights(self, state_dict): return self._up_weights, self._down_weights, None, None From 91fa928f2eea04a5aa08f7f14425600ee8798dc6 Mon Sep 17 00:00:00 2001 From: Bingoo <1575938147@qq.com> Date: Thu, 26 Mar 2026 21:20:36 +0800 Subject: [PATCH 10/11] support gpu backend --- .../layers/moe/fused_cast_sigmoid_bias.py | 10 ++++++---- fastdeploy/model_executor/layers/moe/moe.py | 9 +++++---- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py b/fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py index a573c2aa486..606fba5d384 100644 --- a/fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py +++ b/fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py @@ -16,10 +16,12 @@ import paddle -from fastdeploy.model_executor.ops.gpu import ( - fused_cast_sigmoid_bias as _fused_cast_sigmoid_bias_cuda, -) - +try: + from fastdeploy.model_executor.ops.gpu import ( + fused_cast_sigmoid_bias as _fused_cast_sigmoid_bias_cuda, + ) +except: + assert False, "fused_cast_sigmoid_bias not support!" def fused_cast_sigmoid_bias( gate_out: paddle.Tensor, diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index 76790152b69..b84492b80aa 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -42,9 +42,10 @@ import numpy as np -from fastdeploy.model_executor.layers.moe.fused_cast_sigmoid_bias import ( - fused_cast_sigmoid_bias, -) +if current_platform.is_cuda(): + from fastdeploy.model_executor.layers.moe.fused_cast_sigmoid_bias import ( + fused_cast_sigmoid_bias, + ) def get_moe_method(layer=None): @@ -101,7 +102,7 @@ def get_moe_scores( compute moe scores using e_score_correction_bias. """ assert e_score_correction_bias is not None, "e_score_correction_bias is none!" - if use_fused_cast: + if use_fused_cast and current_platform.is_cuda(): scores, scores_with_bias = fused_cast_sigmoid_bias(gating_output, e_score_correction_bias) else: scores = paddle.nn.functional.sigmoid(gating_output) From 8fc7e00714ecc466b25162196790a7837255fcef Mon Sep 17 00:00:00 2001 From: Bingoo <1575938147@qq.com> Date: Thu, 26 Mar 2026 22:05:45 +0800 Subject: [PATCH 11/11] modify format --- .../model_executor/layers/moe/fused_cast_sigmoid_bias.py | 1 + .../model_executor/layers/moe/fused_moe_marlin_backend.py | 3 ++- .../model_executor/layers/moe/fused_moe_triton_backend.py | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py b/fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py index 606fba5d384..6568655a329 100644 --- a/fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py +++ b/fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py @@ -23,6 +23,7 @@ except: assert False, "fused_cast_sigmoid_bias not support!" + def fused_cast_sigmoid_bias( gate_out: paddle.Tensor, e_score_correction_bias: paddle.Tensor, diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py index c7d6dd1d53f..692840fe4cb 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py @@ -20,11 +20,11 @@ from paddle import nn import fastdeploy -from fastdeploy.platforms import current_platform from fastdeploy.model_executor.ops.gpu import ( MoeWna16MarlinGemmApi, tritonmoe_preprocess_func, ) +from fastdeploy.platforms import current_platform from ..quantization.quant_base import QuantMethodBase @@ -259,6 +259,7 @@ def apply( if topk_method == "noaux_tc": from fastdeploy.model_executor.layers.moe.moe import get_moe_scores + use_fused = not layer.dynamic_load_weight and current_platform.is_cuda() if not use_fused: gate_out = gate_out.cast("float32") diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py index 7c64ffc1323..03fbc71584a 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py @@ -20,7 +20,6 @@ from paddle import nn import fastdeploy -from fastdeploy.platforms import current_platform from fastdeploy.model_executor.layers.utils import get_tensor from fastdeploy.model_executor.utils import ( TensorTracker, @@ -29,6 +28,7 @@ set_weight_attrs, weight_fully_copied, ) +from fastdeploy.platforms import current_platform from fastdeploy.utils import ceil_div, register_custom_python_op from ..quantization.quant_base import QuantMethodBase