Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions fastdeploy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,7 @@ def __init__(
self.pod_ip: str = None
# enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).
self.disable_custom_all_reduce: bool = False
self.enable_flashinfer_allreduce_fusion: bool = False
for key, value in args.items():
if hasattr(self, key):
setattr(self, key, value)
Expand Down
11 changes: 11 additions & 0 deletions fastdeploy/engine/args_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,11 @@ class EngineArgs:
Flag to disable the custom all-reduce kernel.
"""

enable_flashinfer_allreduce_fusion: bool = False
"""
Flag to enable all reduce fusion kernel in flashinfer.
"""

use_internode_ll_two_stage: bool = False
"""
Flag to use the internode_ll_two_stage kernel.
Expand Down Expand Up @@ -977,6 +982,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=EngineArgs.disable_custom_all_reduce,
help="Flag to disable custom all-reduce.",
)
parallel_group.add_argument(
"--enable-flashinfer-allreduce-fusion",
action="store_true",
default=EngineArgs.enable_flashinfer_allreduce_fusion,
help="Flag to enable all reduce fusion kernel in flashinfer.",
)
parallel_group.add_argument(
"--use-internode-ll-two-stage",
action="store_true",
Expand Down
1 change: 1 addition & 0 deletions fastdeploy/engine/common_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2301,6 +2301,7 @@ def _start_worker_service(self):
"moe_gate_fp32": self.cfg.model_config.moe_gate_fp32,
"enable_entropy": self.cfg.model_config.enable_entropy,
"enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule,
"enable_flashinfer_allreduce_fusion": self.cfg.parallel_config.enable_flashinfer_allreduce_fusion,
}
for worker_flag, value in worker_store_true_flag.items():
if value:
Expand Down
1 change: 1 addition & 0 deletions fastdeploy/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,7 @@ def _start_worker_service(self):
"enable_entropy": self.cfg.model_config.enable_entropy,
"ep_prefill_use_worst_num_tokens": self.cfg.parallel_config.ep_prefill_use_worst_num_tokens,
"enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule,
"enable_flashinfer_allreduce_fusion": self.cfg.parallel_config.enable_flashinfer_allreduce_fusion,
}
for worker_flag, value in worker_store_true_flag.items():
if value:
Expand Down
212 changes: 212 additions & 0 deletions fastdeploy/model_executor/layers/flashinfer_comm_fusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
from typing import Optional, Tuple

import paddle
import paddle.distributed as dist

# from sglang.srt.distributed import get_tensor_model_parallel_world_size
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.utils import has_flashinfer
from fastdeploy.utils import get_logger

logger = get_logger("flashinfer", "flashinfer.log")

# Lazily-resolved handle to flashinfer.comm; stays None when flashinfer is
# unavailable so callers can detect it and fall back to the standard path.
_flashinfer_comm = None
# Shared workspace singleton; re-bound below once the manager class is defined.
_workspace_manager = None

# NOTE(review): world size is read from fd_config.parallel_config.tensor_parallel_size
# at call time (see ensure_workspace_initialized below).

if has_flashinfer():
    try:
        # flashinfer is a torch-based package; enable paddle's torch proxy for
        # its import scope so it loads inside a paddle process.
        paddle.compat.enable_torch_proxy(scope={"flashinfer"})
        import flashinfer.comm as comm

        _flashinfer_comm = comm
    except ImportError:
        logger.warning("flashinfer.comm is not available, falling back to standard " "implementation")


class FlashInferWorkspaceManager:
    """Owns the IPC workspace used by FlashInfer's fused all-reduce kernels.

    A single module-level instance (``_workspace_manager``) is shared by all
    callers. The workspace is created lazily by :meth:`initialize` and torn
    down by :meth:`cleanup`; both are safe to call repeatedly.
    """

    def __init__(self):
        # IPC workspace tensor and handles returned by flashinfer; both stay
        # None until initialize() succeeds.
        self.workspace_tensor = None
        self.ipc_handles = None
        self.world_size = None
        self.rank = None
        self.initialized = False

    def initialize(
        self,
        world_size: int,
        rank: int,
        max_token_num: int,
        hidden_dim: int,
        group=None,
        use_fp32_lamport: bool = False,
    ):
        """Create (or re-create) the IPC workspace for this rank.

        Args:
            world_size: Tensor-parallel world size.
            rank: This process's rank within the group.
            max_token_num: Maximum number of tokens the workspace must hold.
            hidden_dim: Hidden dimension of the tensors being reduced.
            group: Optional communication group forwarded to flashinfer.
            use_fp32_lamport: Whether the workspace uses fp32 lamport buffers.
        """
        # Already set up for this world size: nothing to do.
        if self.initialized and self.world_size == world_size:
            return

        if _flashinfer_comm is None:
            logger.warning("FlashInfer comm not available, skipping workspace " "initialization")
            return

        # Drop any stale workspace (e.g. after a world-size change).
        self.cleanup()

        # Use the null-checked module handle rather than the bare `comm` alias,
        # which is only bound when the guarded import at module scope succeeded.
        self.ipc_handles, self.workspace_tensor = _flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
            rank,
            world_size,
            max_token_num,
            hidden_dim,
            group=group,
            use_fp32_lamport=use_fp32_lamport,
        )

        self.world_size = world_size
        self.rank = rank
        self.initialized = True

        logger.info(f"FlashInfer workspace initialized for rank {rank}, " f"world_size {world_size}")

    def cleanup(self):
        """Destroy the IPC workspace if one exists; always reset local state."""
        if self.initialized and self.ipc_handles is not None:
            try:
                _flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(self.ipc_handles, group=dist.get_group())
            except Exception as e:
                logger.warning(f"Failed to cleanup FlashInfer workspace: {e}")
            finally:
                self.workspace_tensor = None
                self.ipc_handles = None
                self.initialized = False


_workspace_manager = FlashInferWorkspaceManager()


def ensure_workspace_initialized(
    fd_config: FDConfig, max_token_num: int = 2048, hidden_dim: int = 4096, use_fp32_lamport: bool = False
) -> bool:
    """Lazily create the shared FlashInfer workspace; return whether it is ready.

    Returns False immediately when flashinfer is unavailable or when running
    on a single GPU (tensor_parallel_size <= 1).
    """
    if _flashinfer_comm is None or not has_flashinfer():
        return False

    assert fd_config is not None
    tp_size = fd_config.parallel_config.tensor_parallel_size
    if tp_size <= 1:
        return False

    cur_rank = dist.get_rank()

    # (Re)build the workspace when it has never been created or the world
    # size no longer matches the current configuration.
    workspace_stale = (not _workspace_manager.initialized) or (_workspace_manager.world_size != tp_size)
    if workspace_stale:
        _workspace_manager.initialize(
            world_size=tp_size,
            rank=cur_rank,
            max_token_num=max_token_num,
            hidden_dim=hidden_dim,
            use_fp32_lamport=use_fp32_lamport,
        )

    return _workspace_manager.initialized


def flashinfer_allreduce_residual_rmsnorm(
    fd_config: FDConfig,
    input_tensor: paddle.Tensor,
    residual: paddle.Tensor,
    weight: paddle.Tensor,
    eps: float = 1e-6,
    max_token_num: int = 2048,
    use_oneshot: Optional[bool] = None,
    trigger_completion_at_end: bool = False,
    fp32_acc: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
    """
    Use FlashInfer's fused allreduce + residual + RMS norm operation.

    Args:
        fd_config: FastDeploy config (tensor_parallel_size is read from it)
        input_tensor: Input tensor that needs allreduce
        residual: Residual tensor
        weight: RMS norm weight (gamma)
        eps: RMS norm epsilon
        max_token_num: Maximum token number the workspace supports
        use_oneshot: Whether to use oneshot mode
        trigger_completion_at_end: Whether to trigger completion at end
        fp32_acc: Whether to accumulate in fp32

    Returns:
        Tuple[paddle.Tensor, paddle.Tensor]: (norm_output, residual_output),
        or (None, None) when the fused path cannot be used — callers treat
        (None, None) as the signal to fall back to the standard implementation.
    """
    if not has_flashinfer() or _flashinfer_comm is None:
        logger.debug("FlashInfer not available, falling back to standard " "implementation")
        return None, None

    # Explicit checks instead of `assert`: asserts are stripped under -O, and
    # every other unusable-path in this function signals fallback via (None, None).
    if fd_config is None:
        logger.warning("fd_config is None, falling back to standard implementation")
        return None, None

    world_size = fd_config.parallel_config.tensor_parallel_size
    if world_size <= 1:
        logger.debug("Single GPU, no need for allreduce fusion")
        return None, None

    if input_tensor.shape[0] > max_token_num:
        logger.debug(
            "Token count %s exceeds workspace capacity %s, falling back", input_tensor.shape[0], max_token_num
        )
        return None, None

    if not ensure_workspace_initialized(
        fd_config=fd_config,
        max_token_num=max_token_num,
        hidden_dim=input_tensor.shape[-1],
        use_fp32_lamport=(input_tensor.dtype == paddle.float32),
    ):
        logger.debug("FlashInfer workspace not available")
        return None, None

    token_num, hidden_dim = input_tensor.shape

    residual_out = paddle.empty_like(residual)
    norm_out = paddle.empty_like(input_tensor)
    # Support empty input: nothing to reduce/normalize, return empty outputs.
    if input_tensor.shape[0] == 0:
        return norm_out, residual_out
    _flashinfer_comm.trtllm_allreduce_fusion(
        allreduce_in=input_tensor,
        world_size=world_size,
        world_rank=dist.get_rank(),
        token_num=token_num,
        hidden_dim=hidden_dim,
        workspace_ptrs=_workspace_manager.workspace_tensor,
        launch_with_pdl=True,
        use_oneshot=use_oneshot,
        trigger_completion_at_end=trigger_completion_at_end,
        fp32_acc=fp32_acc,
        pattern_code=(_flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm),
        allreduce_out=None,
        residual_in=residual,
        residual_out=residual_out,
        norm_out=norm_out,
        quant_out=None,
        scale_out=None,
        rms_gamma=weight,
        rms_eps=eps,
        scale_factor=None,
        layout_code=None,
    )

    return norm_out, residual_out


def fake_flashinfer_allreduce_residual_rmsnorm(
    input_tensor: paddle.Tensor,
    residual: paddle.Tensor,
    weight: paddle.Tensor,
    eps: float = 1e-6,
    max_token_num: int = 16384,
    use_oneshot: Optional[bool] = None,
    trigger_completion_at_end: bool = False,
    fp32_acc: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
    """Shape-only stand-in for the fused allreduce + residual + RMS norm op.

    Allocates outputs with the same shapes/dtypes as the real kernel would
    produce, without performing any communication or arithmetic. Only the
    ``input_tensor`` and ``residual`` arguments influence the result; the
    remaining parameters exist to mirror the real op's signature.
    """
    norm_out = paddle.empty_like(input_tensor)
    residual_out = paddle.empty_like(residual)
    return norm_out, residual_out


def cleanup_flashinfer_workspace():
    """Tear down the shared FlashInfer IPC workspace (safe to call repeatedly)."""
    # No `global` declaration needed: the module-level singleton is only read
    # here, never rebound.
    if _workspace_manager is not None:
        _workspace_manager.cleanup()
8 changes: 7 additions & 1 deletion fastdeploy/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,6 +860,9 @@ def __init__(
skip_quant (bool): Whether to skip quantization. Defaults to False.
"""
self.fd_config = fd_config
self.enable_all_reduce_fusion = (
fd_config.parallel_config.enable_flashinfer_allreduce_fusion and "enable_all_reduce" in prefix
)
self.ep_size = fd_config.parallel_config.expert_parallel_size
self.tp_size = fd_config.parallel_config.tensor_parallel_size
self.tp_group = fd_config.parallel_config.tp_group
Expand Down Expand Up @@ -937,7 +940,10 @@ def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:

out = self.quant_method.apply(self, x)

if self.reduce_results and self.tp_size > 1:
need_tp_all_reduce = (
self.reduce_results and self.tp_size > 1 and not (self.enable_all_reduce_fusion and out.shape[0] <= 2048)
)
if need_tp_all_reduce:
out = tensor_model_parallel_all_reduce(out, self.tp_group)

return out
Expand Down
10 changes: 10 additions & 0 deletions fastdeploy/model_executor/layers/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
is_batch_invariant_mode_enabled,
rms_norm_batch_invariant,
)
from .flashinfer_comm_fusion import flashinfer_allreduce_residual_rmsnorm
from .utils import get_tensor, modules_to_convert


Expand Down Expand Up @@ -122,6 +123,10 @@ def __init__(
self.tp_rank = self.fd_config.parallel_config.tensor_parallel_rank
self.tp_group = self.fd_config.parallel_config.tp_group
is_input_norm = prefix.endswith(".input_layernorm")
self.enable_all_reduce_fusion = (
fd_config.parallel_config.enable_flashinfer_allreduce_fusion and "post_attention_layernorm" in prefix
)

self.is_last_norm = prefix.endswith(".norm")
self.split_x = (
self.fd_config.parallel_config.use_sequence_parallel_moe
Expand Down Expand Up @@ -240,6 +245,11 @@ def forward(
norm_out = rms_norm(x, self.weight, self.eps)
return norm_out.astype(x_dtype), residual_out
norm_out = self.norm_func(x, residual_input, self.weight, self.eps)
# enable trtllm all reduce fusion
elif self.enable_all_reduce_fusion and x.shape[0] <= 2048:
norm_out = flashinfer_allreduce_residual_rmsnorm(
fd_config=self.fd_config, input_tensor=x, residual=residual_input, weight=self.weight, eps=self.eps
)
else:
if is_batch_invariant_mode_enabled():
# M-invariant path: per-row Triton kernel, no cross-row reduction
Expand Down
9 changes: 2 additions & 7 deletions fastdeploy/model_executor/layers/quantization/mxfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
# limitations under the License.
"""

import importlib
import importlib.util
import math
from enum import Enum
from typing import Callable, Optional
Expand All @@ -25,11 +23,12 @@

from fastdeploy import envs
from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase
from fastdeploy.model_executor.utils import set_weight_attrs
from fastdeploy.model_executor.utils import has_flashinfer, set_weight_attrs
from fastdeploy.platforms import current_platform

if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch

from fastdeploy.utils import get_logger

from ..moe import FusedMoE
Expand Down Expand Up @@ -59,10 +58,6 @@ def check_device_capability(num):
return False


def has_flashinfer():
return importlib.util.find_spec("flashinfer") is not None


def round_up(a, b):
return ((a + b - 1) // b) * b

Expand Down
8 changes: 3 additions & 5 deletions fastdeploy/model_executor/models/glm4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@ def __init__(
self.tensor_parallel_size = fd_config.parallel_config.tensor_parallel_size
self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
self.tp_group = fd_config.parallel_config.tp_group

self.use_ep = self.expert_parallel_size > 1
self.use_tp = self.tensor_parallel_size > 1

Expand Down Expand Up @@ -189,7 +188,6 @@ def forward(self, x, forward_meta: ForwardMeta = None):
if self.n_shared_experts > 0:
shared_experts_out = self.shared_experts(x)
out = out + shared_experts_out

return out


Expand All @@ -213,7 +211,7 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None

self.o_proj = RowParallelLinear(
fd_config,
prefix=f"{prefix}.o_proj",
prefix=f"{prefix}.enable_all_reduce.o_proj",
input_size=fd_config.model_config.num_attention_heads * fd_config.model_config.head_dim,
output_size=fd_config.model_config.hidden_size,
layer_id=layer_id,
Expand Down Expand Up @@ -288,15 +286,15 @@ def __init__(
fd_config,
hidden_size=fd_config.model_config.hidden_size,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.input_layernorm",
prefix=f"{prefix}.enable_all_reduce_fusion.input_layernorm",
layer_id=layer_id,
)

self.post_attention_layernorm = RMSNorm(
fd_config,
hidden_size=fd_config.model_config.hidden_size,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.post_attention_layernorm",
prefix=f"{prefix}.enable_all_reduce_fusion.post_attention_layernorm",
layer_id=layer_id,
)

Expand Down
Loading
Loading