-
Notifications
You must be signed in to change notification settings - Fork 732
[Optimization] Elemenwise fusion #6880
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
BingooYang
wants to merge
13
commits into
PaddlePaddle:develop
Choose a base branch
from
BingooYang:ele_fusion
base: develop
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 7 commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
90e3933
conflict
BingooYang 4588d1d
add cast_sgmoid_add fusion and enable in glm4.5
BingooYang 49e4323
support more cast type
BingooYang 7743979
modify test
BingooYang 6d94597
add type check
BingooYang 62efce9
fix config issues
BingooYang 31177a9
Merge branch 'develop' into ele_fusion
BingooYang 77529fe
enable more backend
BingooYang 68079ad
modify 2025->2026
BingooYang 218e50f
only support gpu backend and fix test issues
BingooYang 91fa928
support gpu backend
BingooYang 8fc7e00
modify format
BingooYang 043b748
Merge branch 'develop' into ele_fusion
EmmonsCurse File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,185 @@ | ||
| // Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
| // You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, software | ||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
|
|
||
| #include "helper.h" | ||
|
|
||
| // Fused kernel: cast(input, cast_type) -> sigmoid -> scores, scores + bias -> | ||
| // scores_with_bias | ||
| // | ||
| // For each element (token i, expert j): | ||
| // scores[i][j] = OutT(sigmoid(float(input[i][j]))) | ||
| // scores_with_bias[i][j] = OutT(sigmoid(float(input[i][j])) + bias[j]) | ||
| // | ||
| // Input: input [num_tokens, num_experts] bf16/fp16/fp32 | ||
| // bias [num_experts] or [1, num_experts] fp32 | ||
| // Output: scores [num_tokens, num_experts] cast_type (fp32/fp16/bf16) | ||
| // scores_with_bias [num_tokens, num_experts] cast_type (fp32/fp16/bf16) | ||
|
|
||
| template <typename InT, typename OutT> | ||
| __global__ void fused_cast_sigmoid_bias_kernel( | ||
| const InT* __restrict__ input, | ||
| const float* __restrict__ bias, | ||
| OutT* __restrict__ scores, | ||
| OutT* __restrict__ scores_with_bias, | ||
| const int num_experts) { | ||
| const int64_t token_idx = blockIdx.x; | ||
| const int64_t offset = token_idx * num_experts; | ||
|
|
||
| for (int j = threadIdx.x; j < num_experts; j += blockDim.x) { | ||
| float val = static_cast<float>(input[offset + j]); | ||
| // sigmoid: 1 / (1 + exp(-x)) | ||
| float s = 1.0f / (1.0f + expf(-val)); | ||
| scores[offset + j] = static_cast<OutT>(s); | ||
| scores_with_bias[offset + j] = static_cast<OutT>(s + bias[j]); | ||
| } | ||
| } | ||
|
|
||
| // Vectorized version for better memory throughput | ||
| template <typename InT, typename OutT, int kVecSize> | ||
| __global__ void fused_cast_sigmoid_bias_vec_kernel( | ||
| const InT* __restrict__ input, | ||
| const float* __restrict__ bias, | ||
| OutT* __restrict__ scores, | ||
| OutT* __restrict__ scores_with_bias, | ||
| const int num_experts) { | ||
| const int64_t token_idx = blockIdx.x; | ||
| const int64_t offset = token_idx * num_experts; | ||
|
|
||
| using in_vec_t = AlignedVector<InT, kVecSize>; | ||
| using out_vec_t = AlignedVector<OutT, kVecSize>; | ||
| using bias_vec_t = AlignedVector<float, kVecSize>; | ||
|
|
||
| const int vec_count = num_experts / kVecSize; | ||
| for (int idx = threadIdx.x; idx < vec_count; idx += blockDim.x) { | ||
| const int base = idx * kVecSize; | ||
| in_vec_t in_vec; | ||
| bias_vec_t bias_vec; | ||
| Load(input + offset + base, &in_vec); | ||
| Load(bias + base, &bias_vec); | ||
|
|
||
| out_vec_t s_vec, sb_vec; | ||
| #pragma unroll | ||
| for (int i = 0; i < kVecSize; ++i) { | ||
| float val = static_cast<float>(in_vec[i]); | ||
| float s = 1.0f / (1.0f + expf(-val)); | ||
| s_vec[i] = static_cast<OutT>(s); | ||
| sb_vec[i] = static_cast<OutT>(s + bias_vec[i]); | ||
| } | ||
|
|
||
| Store(s_vec, scores + offset + base); | ||
| Store(sb_vec, scores_with_bias + offset + base); | ||
| } | ||
|
|
||
| // Handle remaining elements | ||
| const int remaining_start = vec_count * kVecSize; | ||
| for (int j = remaining_start + threadIdx.x; j < num_experts; | ||
| j += blockDim.x) { | ||
| float val = static_cast<float>(input[offset + j]); | ||
| float s = 1.0f / (1.0f + expf(-val)); | ||
| scores[offset + j] = static_cast<OutT>(s); | ||
| scores_with_bias[offset + j] = static_cast<OutT>(s + bias[j]); | ||
| } | ||
| } | ||
|
|
||
| static paddle::DataType ParseCastType(const std::string& cast_type) { | ||
| if (cast_type == "float32") return paddle::DataType::FLOAT32; | ||
| if (cast_type == "float16") return paddle::DataType::FLOAT16; | ||
| if (cast_type == "bfloat16") return paddle::DataType::BFLOAT16; | ||
| PD_THROW("Unsupported cast_type: " + cast_type + | ||
| ". Only float32, float16, bfloat16 are supported."); | ||
| } | ||
|
|
||
| std::vector<paddle::Tensor> FusedCastSigmoidBias(const paddle::Tensor& input, | ||
| const paddle::Tensor& bias, | ||
| std::string cast_type) { | ||
| auto input_shape = input.shape(); | ||
| PD_CHECK(input_shape.size() == 2, | ||
| "input must be 2D [num_tokens, num_experts]"); | ||
| auto bias_shape = bias.shape(); | ||
| // Support both [num_experts] and [1, num_experts] bias shapes | ||
| PD_CHECK( | ||
| bias_shape.size() == 1 || (bias_shape.size() == 2 && bias_shape[0] == 1), | ||
| "bias must be 1D [num_experts] or 2D [1, num_experts]"); | ||
|
|
||
| int64_t num_tokens = input_shape[0]; | ||
| int64_t num_experts = input_shape[1]; | ||
| int64_t bias_numel = (bias_shape.size() == 1) ? bias_shape[0] : bias_shape[1]; | ||
| PD_CHECK(bias_numel == num_experts, "bias size must match num_experts"); | ||
| PD_CHECK(bias.dtype() == paddle::DataType::FLOAT32, | ||
| "bias must be float32, got ", | ||
| bias.dtype()); | ||
|
|
||
| auto place = input.place(); | ||
| auto stream = input.stream(); | ||
| auto out_dtype = ParseCastType(cast_type); | ||
|
|
||
| auto scores = paddle::empty({num_tokens, num_experts}, out_dtype, place); | ||
| auto scores_with_bias = | ||
| paddle::empty({num_tokens, num_experts}, out_dtype, place); | ||
|
|
||
| if (num_tokens == 0) { | ||
| return {scores, scores_with_bias}; | ||
| } | ||
|
|
||
| dim3 grid(num_tokens); | ||
| int block_size = std::min(static_cast<int64_t>(1024), num_experts); | ||
| // Round up to warp size | ||
| block_size = ((block_size + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; | ||
| dim3 block(block_size); | ||
|
|
||
| DISPATCH_FLOAT_FP6_DTYPE(input.dtype(), in_scalar_t, { | ||
| DISPATCH_FLOAT_FP6_DTYPE(out_dtype, out_scalar_t, { | ||
| constexpr int kVecSize = 16 / sizeof(in_scalar_t); | ||
| if (num_experts % kVecSize == 0 && num_experts >= kVecSize) { | ||
| fused_cast_sigmoid_bias_vec_kernel<in_scalar_t, out_scalar_t, kVecSize> | ||
| <<<grid, block, 0, stream>>>(input.data<in_scalar_t>(), | ||
| bias.data<float>(), | ||
| scores.data<out_scalar_t>(), | ||
| scores_with_bias.data<out_scalar_t>(), | ||
| num_experts); | ||
| } else { | ||
| fused_cast_sigmoid_bias_kernel<in_scalar_t, out_scalar_t> | ||
| <<<grid, block, 0, stream>>>(input.data<in_scalar_t>(), | ||
| bias.data<float>(), | ||
| scores.data<out_scalar_t>(), | ||
| scores_with_bias.data<out_scalar_t>(), | ||
| num_experts); | ||
| } | ||
| }); | ||
| }); | ||
|
|
||
| return {scores, scores_with_bias}; | ||
| } | ||
|
|
||
| std::vector<paddle::DataType> FusedCastSigmoidBiasInferDtype( | ||
| const paddle::DataType& input_dtype, | ||
| const paddle::DataType& bias_dtype, | ||
| std::string cast_type) { | ||
| auto out_dtype = ParseCastType(cast_type); | ||
| return {out_dtype, out_dtype}; | ||
| } | ||
|
|
||
| std::vector<std::vector<int64_t>> FusedCastSigmoidBiasInferShape( | ||
| const std::vector<int64_t>& input_shape, | ||
| const std::vector<int64_t>& bias_shape) { | ||
| return {input_shape, input_shape}; | ||
| } | ||
|
|
||
| PD_BUILD_STATIC_OP(fused_cast_sigmoid_bias) | ||
| .Inputs({"input", "bias"}) | ||
| .Outputs({"scores", "scores_with_bias"}) | ||
| .Attrs({"cast_type: std::string"}) | ||
| .SetKernelFn(PD_KERNEL(FusedCastSigmoidBias)) | ||
| .SetInferShapeFn(PD_INFER_SHAPE(FusedCastSigmoidBiasInferShape)) | ||
| .SetInferDtypeFn(PD_INFER_DTYPE(FusedCastSigmoidBiasInferDtype)); | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
46 changes: 46 additions & 0 deletions
46
fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,46 @@ | ||
| """ | ||
| # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| """ | ||
|
|
||
| import paddle | ||
|
|
||
| from fastdeploy.model_executor.ops.gpu import ( | ||
| fused_cast_sigmoid_bias as _fused_cast_sigmoid_bias_cuda, | ||
| ) | ||
|
|
||
|
|
||
| def fused_cast_sigmoid_bias( | ||
| gate_out: paddle.Tensor, | ||
| e_score_correction_bias: paddle.Tensor, | ||
| cast_type: str = "float32", | ||
| ) -> tuple: | ||
| """ | ||
| 融合操作:将gate_out转换为指定类型,应用sigmoid函数,并添加偏置。 | ||
|
|
||
| 该函数融合了以下三个独立操作: | ||
| 1. gate_out = gate_out.cast(cast_type) | ||
| 2. scores = sigmoid(gate_out) | ||
| 3. scores_with_bias = scores + e_score_correction_bias | ||
|
|
||
| Args: | ||
| gate_out: [num_tokens, num_experts],bf16/fp16/fp32类型 - 原始gate输出 | ||
| e_score_correction_bias: [num_experts],fp32类型 - 修正偏置 | ||
| cast_type: 输出数据类型字符串,支持"float32"、"float16"、"bfloat16" | ||
|
|
||
| Returns: | ||
| scores: [num_tokens, num_experts],cast_type类型 - sigmoid(gate_out)的结果 | ||
| scores_with_bias: [num_tokens, num_experts],cast_type类型 - 加上偏置后的分数 | ||
| """ | ||
| return _fused_cast_sigmoid_bias_cuda(gate_out, e_score_correction_bias, cast_type) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.