-
Notifications
You must be signed in to change notification settings - Fork 744
[Optimization] enable trtllm_all_reduce fusion kernel in glm model #6660
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
942fe2a
22c9356
fc50534
0f36a83
ed77444
5612a65
aae0b1d
c0790e1
b2be3f9
df0c96e
bf7df5c
dc0499d
155c363
9df022b
25e6615
4b462ea
23c5838
335527c
4edd889
11a1cab
771e5ad
912aab4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,209 @@ | ||
| """ | ||
| # Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| """ | ||
|
|
||
| from typing import Optional, Tuple | ||
|
|
||
| import paddle | ||
| import paddle.distributed as dist | ||
|
|
||
| from fastdeploy.config import FDConfig | ||
| from fastdeploy.model_executor.utils import has_flashinfer | ||
| from fastdeploy.utils import get_logger | ||
|
|
||
# Logger for this module, created via the project's get_logger helper
# (the second argument is presumably the log file name - TODO confirm).
logger = get_logger("flashinfer", "flashinfer.log")

# Cache for the lazily imported flashinfer.comm module; None until
# _get_flashinfer_comm() imports it successfully.
_flashinfer_comm = None
# Placeholder only: this name is rebound to a FlashInferWorkspaceManager
# instance once that class has been defined further down in this module.
_workspace_manager = None
|
|
||
|
|
||
def _get_flashinfer_comm():
    """Return the lazily imported ``flashinfer.comm`` module, or ``None``.

    The import is deferred to the first call so that loading this module has
    no flashinfer side effects. The outcome of the first attempt is
    remembered: a successful import is cached in the module-level
    ``_flashinfer_comm``; an unsuccessful attempt is recorded on the function
    object so that later calls neither retry the import nor repeat the
    warning.

    Returns:
        The ``flashinfer.comm`` module, or ``None`` when ``has_flashinfer()``
        is false or the import raised ``ImportError``.
    """
    global _flashinfer_comm
    if _flashinfer_comm is not None:
        return _flashinfer_comm

    # The fused all-reduce entry point calls this getter on every invocation;
    # without this flag an unavailable flashinfer would be re-probed and the
    # warning re-logged each time.
    if getattr(_get_flashinfer_comm, "_import_attempted", False):
        return None

    if has_flashinfer():
        try:
            with paddle.use_compat_guard(enable=True, scope={"flashinfer"}):
                import flashinfer.comm as comm
        except ImportError:
            logger.warning("flashinfer.comm is not available, falling back to standard implementation")
        else:
            _flashinfer_comm = comm

    # Only reached when no unexpected exception propagated, so a transient
    # non-ImportError failure can still be retried on a later call.
    _get_flashinfer_comm._import_attempted = True
    return _flashinfer_comm
|
|
||
|
|
||
class FlashInferWorkspaceManager:
    """Owns the IPC workspace used by FlashInfer's fused all-reduce kernel.

    The workspace is created through
    ``flashinfer.comm.trtllm_create_ipc_workspace_for_all_reduce_fusion`` and
    is specific to the parameters it was created with. The manager therefore
    records those parameters and recreates the workspace when a later request
    is not covered by the existing one.
    """

    def __init__(self):
        self.workspace_tensor = None
        self.ipc_handles = None
        self.world_size = None
        self.rank = None
        self.initialized = False
        # Parameters of the currently allocated workspace; used to decide
        # whether a later initialize() call can reuse it.
        self.max_token_num = None
        self.hidden_dim = None
        self.use_fp32_lamport = None
        # Communication group the workspace was created with (None means the
        # caller did not pass one); cleanup() must use the same group.
        self.group = None

    def _covers(self, world_size: int, max_token_num: int, hidden_dim: int, use_fp32_lamport: bool) -> bool:
        """Return True when the existing workspace satisfies the request."""
        return (
            self.initialized
            and self.world_size == world_size
            and self.use_fp32_lamport == use_fp32_lamport
            and self.max_token_num is not None
            and self.hidden_dim is not None
            and max_token_num <= self.max_token_num
            and hidden_dim <= self.hidden_dim
        )

    def initialize(
        self,
        world_size: int,
        rank: int,
        max_token_num: int,
        hidden_dim: int,
        group=None,
        use_fp32_lamport: bool = False,
    ):
        """Create the workspace, or keep the current one if it is sufficient.

        Args:
            world_size: Number of ranks taking part in the all-reduce.
            rank: This process's rank, forwarded to flashinfer.
            max_token_num: Largest token count the workspace must support.
            hidden_dim: Largest hidden size the workspace must support.
            group: Communication group forwarded to flashinfer; ``None``
                leaves the choice to flashinfer.
            use_fp32_lamport: Forwarded to flashinfer's workspace creation.

        Returns:
            None. Check ``self.initialized`` afterwards; it stays unchanged
            when ``flashinfer.comm`` is unavailable.
        """
        if self._covers(world_size, max_token_num, hidden_dim, use_fp32_lamport):
            return

        comm = _get_flashinfer_comm()
        if comm is None:
            logger.warning("FlashInfer comm not available, skipping workspace initialization")
            return

        if self.initialized and self.world_size == world_size:
            # Grow-only for an unchanged world size: never allocate less than
            # what is already there, so alternating request sizes do not
            # trigger a reallocation on every call.
            max_token_num = max(max_token_num, self.max_token_num or 0)
            hidden_dim = max(hidden_dim, self.hidden_dim or 0)

        # Release the previous allocation (with the group it was created
        # with) before creating the replacement.
        self.cleanup()

        self.ipc_handles, self.workspace_tensor = comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
            rank,
            world_size,
            max_token_num,
            hidden_dim,
            group=group,
            use_fp32_lamport=use_fp32_lamport,
        )

        self.world_size = world_size
        self.rank = rank
        self.max_token_num = max_token_num
        self.hidden_dim = hidden_dim
        self.use_fp32_lamport = use_fp32_lamport
        self.group = group
        self.initialized = True

        logger.info(f"FlashInfer workspace initialized for rank {rank}, world_size {world_size}")

    def cleanup(self):
        """Destroy the workspace (best effort) and reset the manager state.

        Failures while destroying are logged and swallowed so that cleanup
        can be called unconditionally; the state is reset either way.
        """
        try:
            if self.initialized and self.ipc_handles is not None:
                comm = _get_flashinfer_comm()
                if comm is not None:
                    # Prefer the destroy function that pairs with the
                    # *_fusion create call; fall back to the generic one for
                    # flashinfer builds that do not export it.
                    destroy = getattr(comm, "trtllm_destroy_ipc_workspace_for_all_reduce_fusion", None)
                    if destroy is None:
                        destroy = comm.trtllm_destroy_ipc_workspace_for_all_reduce
                    # Use the group the workspace was created with; only fall
                    # back to the default group when none was supplied.
                    group = self.group if self.group is not None else dist.get_group()
                    destroy(self.ipc_handles, group=group)
        except Exception as e:
            logger.warning(f"Failed to cleanup FlashInfer workspace: {e}")
        finally:
            self.workspace_tensor = None
            self.ipc_handles = None
            self.initialized = False
            self.group = None
|
|
||
|
|
||
# Module-level manager instance shared by the helper functions in this module.
_workspace_manager = FlashInferWorkspaceManager()
|
|
||
|
|
||
def ensure_workspace_initialized(
    fd_config: FDConfig, max_token_num: int = 2048, hidden_dim: int = 4096, use_fp32_lamport: bool = False
):
    """Make sure the shared FlashInfer all-reduce workspace exists.

    Args:
        fd_config: Configuration object; ``parallel_config`` supplies the
            tensor-parallel size and, when available, the tensor-parallel
            rank and group.
        max_token_num: Largest token count the workspace must support.
        hidden_dim: Hidden size the workspace must support.
        use_fp32_lamport: Forwarded to the workspace manager.

    Returns:
        bool: True when a workspace is initialized after the call; False when
        ``flashinfer.comm`` is unavailable, tensor parallelism is not in use,
        or initialization did not succeed.

    Raises:
        ValueError: If ``fd_config`` is None.
    """
    # A non-None module already implies flashinfer is installed, so no
    # separate has_flashinfer() probe is needed here.
    if _get_flashinfer_comm() is None:
        return False

    if fd_config is None:
        raise ValueError("fd_config must not be None")

    parallel_config = fd_config.parallel_config
    world_size = parallel_config.tensor_parallel_size
    if world_size <= 1:
        return False

    # world_size is the tensor-parallel size, so rank and group must refer to
    # the tensor-parallel group as well; the global rank only matches when
    # that group spans every process. Fall back to the global rank and
    # flashinfer's default group if the config does not expose them.
    tp_group = getattr(parallel_config, "tp_group", None)
    tp_rank = getattr(parallel_config, "tensor_parallel_rank", None)
    if tp_group is None or tp_rank is None:
        tp_group = None
        tp_rank = dist.get_rank()

    # Always delegate: initialize() returns immediately when the existing
    # workspace is reusable, so the reuse decision lives in one place.
    _workspace_manager.initialize(
        world_size=world_size,
        rank=tp_rank,
        max_token_num=max_token_num,
        hidden_dim=hidden_dim,
        group=tp_group,
        use_fp32_lamport=use_fp32_lamport,
    )

    return _workspace_manager.initialized
|
|
||
|
|
||
def flashinfer_allreduce_residual_rmsnorm(
    fd_config: FDConfig,
    input_tensor: paddle.Tensor,
    residual: paddle.Tensor,
    weight: paddle.Tensor,
    eps: float = 1e-6,
    # NOTE(review): per the PR author's reply, this default follows sglang.
    max_token_num: int = 2048,
    use_oneshot: Optional[bool] = None,
    trigger_completion_at_end: bool = False,
    fp32_acc: bool = False,
) -> Tuple[Optional[paddle.Tensor], Optional[paddle.Tensor]]:
    """Run FlashInfer's fused all-reduce + residual add + RMSNorm kernel.

    Args:
        fd_config: Configuration object; supplies the tensor-parallel size.
        input_tensor: 2-D tensor (token_num, hidden_dim) to be all-reduced.
        residual: Residual input passed to the kernel as ``residual_in``.
        weight: RMSNorm weight passed to the kernel as ``rms_gamma``.
        eps: RMSNorm epsilon.
        max_token_num: Upper bound on ``input_tensor.shape[0]``; also used to
            size the shared workspace.
        use_oneshot: Forwarded to the kernel unchanged.
        trigger_completion_at_end: Forwarded to the kernel unchanged.
        fp32_acc: Forwarded to the kernel unchanged.

    Returns:
        ``(norm_out, residual_out)`` on success. ``(None, None)`` when the
        fused path cannot be used (flashinfer missing, tensor-parallel size
        <= 1, no residual given, or workspace unavailable). Callers must then
        perform the all-reduce and normalization themselves.

    Raises:
        ValueError: If ``fd_config`` is None or the input has more tokens
            than ``max_token_num``.
    """
    comm = _get_flashinfer_comm()
    if comm is None:
        logger.debug("FlashInfer not available, falling back to standard implementation")
        return None, None

    if fd_config is None:
        raise ValueError("fd_config must not be None")

    world_size = fd_config.parallel_config.tensor_parallel_size
    if world_size <= 1:
        logger.debug("Single GPU, no need for allreduce fusion")
        return None, None

    if residual is None:
        # The kernel is invoked with the residual + RMSNorm pattern and
        # cannot run without a residual input.
        logger.debug("No residual provided, allreduce fusion not applicable")
        return None, None

    token_num, hidden_dim = input_tensor.shape

    # The workspace is sized for max_token_num tokens. Enforce the bound with
    # a real check: an assert would be stripped under ``python -O``.
    if token_num > max_token_num:
        raise ValueError(f"token_num ({token_num}) exceeds max_token_num ({max_token_num}) of the fusion workspace")

    if not ensure_workspace_initialized(
        fd_config=fd_config,
        max_token_num=max_token_num,
        hidden_dim=hidden_dim,
        use_fp32_lamport=(input_tensor.dtype == paddle.float32),
    ):
        logger.debug("FlashInfer workspace not available")
        return None, None

    residual_out = paddle.empty_like(residual)
    norm_out = paddle.empty_like(input_tensor)
    # Support empty tensors: nothing to reduce, return the empty outputs.
    if token_num == 0:
        return norm_out, residual_out

    # Use the rank the workspace was created with so the kernel and the
    # workspace always agree.
    world_rank = _workspace_manager.rank
    if world_rank is None:
        world_rank = dist.get_rank()

    comm.trtllm_allreduce_fusion(
        allreduce_in=input_tensor,
        world_size=world_size,
        world_rank=world_rank,
        token_num=token_num,
        hidden_dim=hidden_dim,
        workspace_ptrs=_workspace_manager.workspace_tensor,
        launch_with_pdl=True,
        use_oneshot=use_oneshot,
        trigger_completion_at_end=trigger_completion_at_end,
        fp32_acc=fp32_acc,
        pattern_code=(comm.AllReduceFusionPattern.kARResidualRMSNorm),
        allreduce_out=None,
        residual_in=residual,
        residual_out=residual_out,
        norm_out=norm_out,
        quant_out=None,
        scale_out=None,
        rms_gamma=weight,
        rms_eps=eps,
        scale_factor=None,
        layout_code=None,
    )

    return norm_out, residual_out
|
|
||
|
|
||
def cleanup_flashinfer_workspace():
    """Release the module-level FlashInfer workspace, if a manager exists.

    NOTE(review): reviewers on this change pointed out that nothing visible
    here calls this function and suggested invoking it at shutdown; the
    author replied that sglang does not clean up either.
    """
    manager = _workspace_manager
    if manager is None:
        return
    manager.cleanup()
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -854,6 +854,7 @@ def __init__( | |
| skip_quant: bool = False, | ||
| weight_dtype: str = "", | ||
| layer_id: int = -1, | ||
| enable_all_reduce_fusion: bool = None, | ||
| ): | ||
| """ | ||
| Initialize a linear layer with additional parameters for inference and quantization. | ||
|
|
@@ -865,9 +866,17 @@ def __init__( | |
| input_size (int): Number of input features. Defaults to None. | ||
| output_size (int): Number of output features. Defaults to None. | ||
| with_bias (bool): Whether to include bias or not. Defaults to False. | ||
| skip_quant (bool): Whether to skip quantization. Defaults to False. | ||
| skip_quant (bool): Whether to skip quantization or not. Defaults to False. | ||
| enable_all_reduce_fusion (bool, optional): Whether to enable all-reduce fusion. | ||
| If None, it is determined by the config flag and prefix. Defaults to None. | ||
| """ | ||
| self.fd_config = fd_config | ||
| if enable_all_reduce_fusion is None: | ||
| self.enable_all_reduce_fusion = False | ||
| else: | ||
| self.enable_all_reduce_fusion = ( | ||
| fd_config.parallel_config.enable_flashinfer_allreduce_fusion and enable_all_reduce_fusion | ||
| ) | ||
| self.ep_size = fd_config.parallel_config.expert_parallel_size | ||
| self.tp_size = fd_config.parallel_config.tensor_parallel_size | ||
| self.tp_group = fd_config.parallel_config.tp_group | ||
|
|
@@ -945,7 +954,10 @@ def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor: | |
|
|
||
| out = self.quant_method.apply(self, x) | ||
|
|
||
| if self.reduce_results and self.tp_size > 1: | ||
| need_tp_all_reduce = ( | ||
| self.reduce_results and self.tp_size > 1 and not (self.enable_all_reduce_fusion and out.shape[0] <= 2048) | ||
This comment was marked as outdated.
Sorry, something went wrong.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
最好增加对 2048 这一限制的解释。
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
我会在下一个 PR 里加上这里的注释。 |
||
| ) | ||
| if need_tp_all_reduce: | ||
| out = tensor_model_parallel_all_reduce(out, self.tp_group) | ||
|
|
||
| return out | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -35,6 +35,7 @@ | |
| is_batch_invariant_mode_enabled, | ||
| rms_norm_batch_invariant, | ||
| ) | ||
| from .flashinfer_comm_fusion import flashinfer_allreduce_residual_rmsnorm | ||
| from .utils import get_tensor, modules_to_convert | ||
|
|
||
|
|
||
|
|
@@ -122,6 +123,10 @@ def __init__( | |
| self.tp_rank = self.fd_config.parallel_config.tensor_parallel_rank | ||
| self.tp_group = self.fd_config.parallel_config.tp_group | ||
| is_input_norm = prefix.endswith(".input_layernorm") | ||
| self.enable_all_reduce_fusion = ( | ||
| fd_config.parallel_config.enable_flashinfer_allreduce_fusion and "post_attention_layernorm" in prefix | ||
| ) | ||
|
|
||
| self.is_last_norm = prefix.endswith(".norm") | ||
| self.split_x = ( | ||
| self.fd_config.parallel_config.use_sequence_parallel_moe | ||
|
|
@@ -240,6 +245,12 @@ def forward( | |
| norm_out = rms_norm(x, self.weight, self.eps) | ||
| return norm_out.astype(x_dtype), residual_out | ||
| norm_out = self.norm_func(x, residual_input, self.weight, self.eps) | ||
| # enable trtllm all reduce fusion | ||
| elif self.enable_all_reduce_fusion and x.shape[0] <= 2048: | ||
This comment was marked as outdated.
Sorry, something went wrong. |
||
| norm_out = flashinfer_allreduce_residual_rmsnorm( | ||
| fd_config=self.fd_config, input_tensor=x, residual=residual_input, weight=self.weight, eps=self.eps | ||
| ) | ||
| assert norm_out[0] is not None, "Trtllm-all-reduce fusion failed!" | ||
This comment was marked as outdated.
Sorry, something went wrong.
This comment was marked as outdated.
Sorry, something went wrong.
BingooYang marked this conversation as resolved.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
🔴 Bug:fusion 失败时缺少降级机制。当 flashinfer 不可用、workspace 初始化失败等情况时,flashinfer_allreduce_residual_rmsnorm 会返回 (None, None),此处的 assert 会直接触发并中断推理。建议修改为优雅降级:
elif self.enable_all_reduce_fusion and x.shape[0] <= 2048:
norm_out, residual_out = flashinfer_allreduce_residual_rmsnorm(
fd_config=self.fd_config, input_tensor=x, residual=residual_input, weight=self.weight, eps=self.eps
)
if norm_out is None or residual_out is None:
# Fallback to standard all-reduce + RMSNorm
if is_batch_invariant_mode_enabled():
if residual_input is not None:
x = x + residual_input
norm_out = rms_norm_batch_invariant(x, self.weight, self.eps), x
else:
norm_out = self.norm_func(
x,
norm_weight=self.weight,
norm_bias=None,
epsilon=self.eps,
begin_norm_axis=self.begin_norm_axis,
bias=self.bias,
residual=residual_input,
quant_scale=(-1 if self.quant_scale is None else self.quant_scale),
quant_round_type=self.quant_round_type,
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
) |
||
| else: | ||
| if is_batch_invariant_mode_enabled(): | ||
| # M-invariant path: per-row Triton kernel, no cross-row reduction | ||
|
|
||
This comment was marked as outdated.
Sorry, something went wrong.
Uh oh!
There was an error while loading. Please reload this page.