Skip to content

Commit ae5dac1

Browse files
authored
[Cherry-Pick][Optimization] enable trtllm_all_reduce fusion kernel in glm model (#6660) (#7228)
* enable trtllm_all_reduce fusion kernel in glm model * update flashinfer paddle version * format update modify test modify test support empty tensor and modify test fix test_linear config issues modify test name add edge test case modify format fix conflict modify default max token num in trtllm_allreduce_fusion add max token num branch for trtllm_allreduce_fusion fix format fix rmsnorm config issue modify 2025 to 2026 enable trtllm_allreduce fusion Revert "[Cherry-Pick][CI] Use GPU-Build-RL runner for _build_linux_rl.yml (#7186) (#7195)" This reverts commit ca2f38b. Revert "[Cherry-Pick][BugFix] prevent requests from entering running state without a slot(#7141) (#7181)" This reverts commit 80f4a72. clean flashinfer cache and modify test fix dumpy patch issue fix some issues * remove redundent * enable moe reduce fusion * fix test * fix cuda context issue * update flashinfer version
1 parent fae4a8b commit ae5dac1

16 files changed

Lines changed: 903 additions & 14 deletions

fastdeploy/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -675,6 +675,7 @@ def __init__(
675675
self.pod_ip: str = None
676676
# enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).
677677
self.disable_custom_all_reduce: bool = False
678+
self.enable_flashinfer_allreduce_fusion: bool = False
678679
for key, value in args.items():
679680
if hasattr(self, key):
680681
setattr(self, key, value)

fastdeploy/engine/args_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,11 @@ class EngineArgs:
274274
Flag to disable the custom all-reduce kernel.
275275
"""
276276

277+
enable_flashinfer_allreduce_fusion: bool = False
278+
"""
279+
Flag to enable all reduce fusion kernel in flashinfer.
280+
"""
281+
277282
use_internode_ll_two_stage: bool = False
278283
"""
279284
Flag to use the internode_ll_two_stage kernel.
@@ -1000,6 +1005,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
10001005
default=EngineArgs.disable_custom_all_reduce,
10011006
help="Flag to disable custom all-reduce.",
10021007
)
1008+
parallel_group.add_argument(
1009+
"--enable-flashinfer-allreduce-fusion",
1010+
action="store_true",
1011+
default=EngineArgs.enable_flashinfer_allreduce_fusion,
1012+
help="Flag to enable all reduce fusion kernel in flashinfer.",
1013+
)
10031014
parallel_group.add_argument(
10041015
"--use-internode-ll-two-stage",
10051016
action="store_true",

fastdeploy/engine/common_engine.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2542,6 +2542,7 @@ def _start_worker_service(self):
25422542
"enable_entropy": self.cfg.model_config.enable_entropy,
25432543
"enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule,
25442544
"enable_keep_sampling_mask": self.cfg.model_config.enable_keep_sampling_mask,
2545+
"enable_flashinfer_allreduce_fusion": self.cfg.parallel_config.enable_flashinfer_allreduce_fusion,
25452546
}
25462547
for worker_flag, value in worker_store_true_flag.items():
25472548
if value:

fastdeploy/engine/engine.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -667,6 +667,7 @@ def _start_worker_service(self):
667667
"ep_prefill_use_worst_num_tokens": self.cfg.parallel_config.ep_prefill_use_worst_num_tokens,
668668
"enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule,
669669
"enable_keep_sampling_mask": self.cfg.model_config.enable_keep_sampling_mask,
670+
"enable_flashinfer_allreduce_fusion": self.cfg.parallel_config.enable_flashinfer_allreduce_fusion,
670671
}
671672
for worker_flag, value in worker_store_true_flag.items():
672673
if value:
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
"""
2+
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""
16+
17+
from typing import Optional, Tuple
18+
19+
import paddle
20+
import paddle.distributed as dist
21+
22+
from fastdeploy.config import FDConfig
23+
from fastdeploy.model_executor.utils import has_flashinfer
24+
from fastdeploy.utils import get_logger
25+
26+
logger = get_logger("flashinfer", "flashinfer.log")
27+
28+
_flashinfer_comm = None
29+
_workspace_manager = None
30+
31+
32+
def _get_flashinfer_comm():
33+
"""Lazily import flashinfer.comm to avoid side effects at module load time."""
34+
global _flashinfer_comm
35+
if _flashinfer_comm is not None:
36+
return _flashinfer_comm
37+
if has_flashinfer():
38+
try:
39+
with paddle.use_compat_guard(enable=True, scope={"flashinfer"}):
40+
import flashinfer.comm as comm
41+
42+
_flashinfer_comm = comm
43+
except ImportError:
44+
logger.warning("flashinfer.comm is not available, falling back to standard " "implementation")
45+
return _flashinfer_comm
46+
47+
48+
class FlashInferWorkspaceManager:
49+
def __init__(self):
50+
self.workspace_tensor = None
51+
self.ipc_handles = None
52+
self.world_size = None
53+
self.rank = None
54+
self.initialized = False
55+
56+
def initialize(
57+
self,
58+
world_size: int,
59+
rank: int,
60+
max_token_num: int,
61+
hidden_dim: int,
62+
group=None,
63+
use_fp32_lamport: bool = False,
64+
):
65+
"""Initialize workspace"""
66+
if self.initialized and self.world_size == world_size:
67+
return
68+
69+
comm = _get_flashinfer_comm()
70+
if comm is None:
71+
logger.warning("FlashInfer comm not available, skipping workspace " "initialization")
72+
return
73+
74+
self.cleanup()
75+
76+
self.ipc_handles, self.workspace_tensor = comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
77+
rank,
78+
world_size,
79+
max_token_num,
80+
hidden_dim,
81+
group=group,
82+
use_fp32_lamport=use_fp32_lamport,
83+
)
84+
85+
self.world_size = world_size
86+
self.rank = rank
87+
self.initialized = True
88+
89+
logger.info(f"FlashInfer workspace initialized for rank {rank}, " f"world_size {world_size}")
90+
91+
def cleanup(self):
92+
"""Clean up workspace"""
93+
if self.initialized and self.ipc_handles is not None:
94+
try:
95+
comm = _get_flashinfer_comm()
96+
if comm is not None:
97+
comm.trtllm_destroy_ipc_workspace_for_all_reduce(self.ipc_handles, group=dist.get_group())
98+
except Exception as e:
99+
logger.warning(f"Failed to cleanup FlashInfer workspace: {e}")
100+
finally:
101+
self.workspace_tensor = None
102+
self.ipc_handles = None
103+
self.initialized = False
104+
105+
106+
_workspace_manager = FlashInferWorkspaceManager()
107+
108+
109+
def ensure_workspace_initialized(
110+
fd_config: FDConfig, max_token_num: int = 2048, hidden_dim: int = 4096, use_fp32_lamport: bool = False
111+
):
112+
"""Ensure workspace is initialized"""
113+
comm = _get_flashinfer_comm()
114+
if not has_flashinfer() or comm is None:
115+
return False
116+
117+
assert fd_config is not None
118+
world_size = fd_config.parallel_config.tensor_parallel_size
119+
if world_size <= 1:
120+
return False
121+
122+
rank = dist.get_rank()
123+
124+
if not _workspace_manager.initialized or _workspace_manager.world_size != world_size:
125+
_workspace_manager.initialize(
126+
world_size=world_size,
127+
rank=rank,
128+
max_token_num=max_token_num,
129+
hidden_dim=hidden_dim,
130+
use_fp32_lamport=use_fp32_lamport,
131+
)
132+
133+
return _workspace_manager.initialized
134+
135+
136+
def flashinfer_allreduce_residual_rmsnorm(
137+
fd_config: FDConfig,
138+
input_tensor: paddle.Tensor,
139+
residual: paddle.Tensor,
140+
weight: paddle.Tensor,
141+
eps: float = 1e-6,
142+
max_token_num: int = 2048,
143+
use_oneshot: Optional[bool] = None,
144+
trigger_completion_at_end: bool = False,
145+
fp32_acc: bool = False,
146+
) -> Tuple[paddle.Tensor, paddle.Tensor]:
147+
"""
148+
Use FlashInfer's fused allreduce + residual + RMS norm operation
149+
"""
150+
comm = _get_flashinfer_comm()
151+
if not has_flashinfer() or comm is None:
152+
logger.debug("FlashInfer not available, falling back to standard " "implementation")
153+
return None, None
154+
155+
assert fd_config is not None
156+
world_size = fd_config.parallel_config.tensor_parallel_size
157+
if world_size <= 1:
158+
logger.debug("Single GPU, no need for allreduce fusion")
159+
return None, None
160+
161+
assert input_tensor.shape[0] <= max_token_num
162+
163+
if not ensure_workspace_initialized(
164+
fd_config=fd_config,
165+
max_token_num=max_token_num,
166+
hidden_dim=input_tensor.shape[-1],
167+
use_fp32_lamport=(input_tensor.dtype == paddle.float32),
168+
):
169+
logger.debug("FlashInfer workspace not available")
170+
return None, None
171+
172+
token_num, hidden_dim = input_tensor.shape
173+
174+
residual_out = paddle.empty_like(residual)
175+
norm_out = paddle.empty_like(input_tensor)
176+
# support empty tensor
177+
if input_tensor.shape[0] == 0:
178+
return norm_out, residual_out
179+
comm.trtllm_allreduce_fusion(
180+
allreduce_in=input_tensor,
181+
world_size=world_size,
182+
world_rank=dist.get_rank(),
183+
token_num=token_num,
184+
hidden_dim=hidden_dim,
185+
workspace_ptrs=_workspace_manager.workspace_tensor,
186+
launch_with_pdl=True,
187+
use_oneshot=use_oneshot,
188+
trigger_completion_at_end=trigger_completion_at_end,
189+
fp32_acc=fp32_acc,
190+
pattern_code=(comm.AllReduceFusionPattern.kARResidualRMSNorm),
191+
allreduce_out=None,
192+
residual_in=residual,
193+
residual_out=residual_out,
194+
norm_out=norm_out,
195+
quant_out=None,
196+
scale_out=None,
197+
rms_gamma=weight,
198+
rms_eps=eps,
199+
scale_factor=None,
200+
layout_code=None,
201+
)
202+
203+
return norm_out, residual_out
204+
205+
206+
def cleanup_flashinfer_workspace():
207+
global _workspace_manager
208+
if _workspace_manager is not None:
209+
_workspace_manager.cleanup()

fastdeploy/model_executor/layers/linear.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -853,6 +853,7 @@ def __init__(
853853
skip_quant: bool = False,
854854
weight_dtype: str = "",
855855
layer_id: int = -1,
856+
enable_all_reduce_fusion: bool = None,
856857
):
857858
"""
858859
Initialize a linear layer with additional parameters for inference and quantization.
@@ -864,9 +865,17 @@ def __init__(
864865
input_size (int): Number of input features. Defaults to None.
865866
output_size (int): Number of output features. Defaults to None.
866867
with_bias (bool): Whether to include bias or not. Defaults to False.
867-
skip_quant (bool): Whether to skip quantization. Defaults to False.
868+
skip_quant (bool): Whether to skip quantization or not. Defaults to False.
869+
enable_all_reduce_fusion (bool, optional): Whether to enable all-reduce fusion.
870+
If None, it is determined by the config flag and prefix. Defaults to None.
868871
"""
869872
self.fd_config = fd_config
873+
if enable_all_reduce_fusion is None:
874+
self.enable_all_reduce_fusion = False
875+
else:
876+
self.enable_all_reduce_fusion = (
877+
fd_config.parallel_config.enable_flashinfer_allreduce_fusion and enable_all_reduce_fusion
878+
)
870879
self.ep_size = fd_config.parallel_config.expert_parallel_size
871880
self.tp_size = fd_config.parallel_config.tensor_parallel_size
872881
self.tp_group = fd_config.parallel_config.tp_group
@@ -944,7 +953,10 @@ def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
944953

945954
out = self.quant_method.apply(self, x)
946955

947-
if self.reduce_results and self.tp_size > 1:
956+
need_tp_all_reduce = (
957+
self.reduce_results and self.tp_size > 1 and not (self.enable_all_reduce_fusion and out.shape[0] <= 2048)
958+
)
959+
if need_tp_all_reduce:
948960
out = tensor_model_parallel_all_reduce(out, self.tp_group)
949961

950962
return out

fastdeploy/model_executor/layers/normalization.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
is_batch_invariant_mode_enabled,
3636
rms_norm_batch_invariant,
3737
)
38+
from .flashinfer_comm_fusion import flashinfer_allreduce_residual_rmsnorm
3839
from .utils import get_tensor, modules_to_convert
3940

4041

@@ -122,6 +123,10 @@ def __init__(
122123
self.tp_rank = self.fd_config.parallel_config.tensor_parallel_rank
123124
self.tp_group = self.fd_config.parallel_config.tp_group
124125
is_input_norm = prefix.endswith(".input_layernorm")
126+
self.enable_all_reduce_fusion = fd_config.parallel_config.enable_flashinfer_allreduce_fusion and (
127+
("post_attention_layernorm" in prefix) or (("input_layernorm" in prefix and layer_id != 0))
128+
)
129+
125130
self.is_last_norm = prefix.endswith(".norm")
126131
self.split_x = (
127132
self.fd_config.parallel_config.use_sequence_parallel_moe
@@ -240,6 +245,12 @@ def forward(
240245
norm_out = rms_norm(x, self.weight, self.eps)
241246
return norm_out.astype(x_dtype), residual_out
242247
norm_out = self.norm_func(x, residual_input, self.weight, self.eps)
248+
# enable trtllm all reduce fusion
249+
elif self.enable_all_reduce_fusion and x.shape[0] <= 2048:
250+
norm_out = flashinfer_allreduce_residual_rmsnorm(
251+
fd_config=self.fd_config, input_tensor=x, residual=residual_input, weight=self.weight, eps=self.eps
252+
)
253+
assert norm_out[0] is not None, "Trtllm-all-reduce fusion failed!"
243254
else:
244255
if is_batch_invariant_mode_enabled():
245256
# M-invariant path: per-row Triton kernel, no cross-row reduction

fastdeploy/model_executor/layers/quantization/mxfp4.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
# limitations under the License.
1515
"""
1616

17-
import importlib
18-
import importlib.util
1917
import math
2018
from enum import Enum
2119
from typing import Callable, Optional
@@ -25,11 +23,12 @@
2523

2624
from fastdeploy import envs
2725
from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase
28-
from fastdeploy.model_executor.utils import set_weight_attrs
26+
from fastdeploy.model_executor.utils import has_flashinfer, set_weight_attrs
2927
from fastdeploy.platforms import current_platform
3028

3129
if current_platform.is_cuda():
3230
from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch
31+
3332
from fastdeploy.utils import get_logger
3433

3534
from ..moe import FusedMoE
@@ -59,10 +58,6 @@ def check_device_capability(num):
5958
return False
6059

6160

62-
def has_flashinfer():
63-
return importlib.util.find_spec("flashinfer") is not None
64-
65-
6661
def round_up(a, b):
6762
return ((a + b - 1) // b) * b
6863

0 commit comments

Comments
 (0)