diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 50e0ba08a37..9939c7e3680 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -671,6 +671,7 @@ def __init__( self.pod_ip: str = None # enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce). self.disable_custom_all_reduce: bool = False + self.enable_flashinfer_allreduce_fusion: bool = False for key, value in args.items(): if hasattr(self, key): setattr(self, key, value) diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index afb7095a449..21f032423d7 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -274,6 +274,11 @@ class EngineArgs: Flag to disable the custom all-reduce kernel. """ + enable_flashinfer_allreduce_fusion: bool = False + """ + Flag to enable all reduce fusion kernel in flashinfer. + """ + use_internode_ll_two_stage: bool = False """ Flag to use the internode_ll_two_stage kernel. @@ -977,6 +982,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.disable_custom_all_reduce, help="Flag to disable custom all-reduce.", ) + parallel_group.add_argument( + "--enable-flashinfer-allreduce-fusion", + action="store_true", + default=EngineArgs.enable_flashinfer_allreduce_fusion, + help="Flag to enable all reduce fusion kernel in flashinfer.", + ) parallel_group.add_argument( "--use-internode-ll-two-stage", action="store_true", diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 28776b53ede..7bc7cb5845c 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -2301,6 +2301,7 @@ def _start_worker_service(self): "moe_gate_fp32": self.cfg.model_config.moe_gate_fp32, "enable_entropy": self.cfg.model_config.enable_entropy, "enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule, + "enable_flashinfer_allreduce_fusion": 
"""FlashInfer TensorRT-LLM fused allreduce + residual + RMSNorm helpers.

Wraps ``flashinfer.comm``'s ``trtllm_allreduce_fusion`` kernel behind a
lazily initialized IPC workspace so tensor-parallel ranks can fuse the TP
all-reduce with the residual-add + RMSNorm that follows it. Every entry
point degrades gracefully (returns ``False`` / ``(None, None)``) when
flashinfer is unavailable or when running on a single GPU, so callers can
fall back to the unfused path.
"""

from typing import Optional, Tuple

import paddle
import paddle.distributed as dist

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.utils import has_flashinfer
from fastdeploy.utils import get_logger

logger = get_logger("flashinfer", "flashinfer.log")

# Module handle for flashinfer.comm. Stays None when the wheel is absent or
# the import fails, which is what every fallback check below keys off.
_flashinfer_comm = None

if has_flashinfer():
    try:
        # flashinfer ships torch-flavored APIs; proxy them onto paddle tensors.
        paddle.compat.enable_torch_proxy(scope={"flashinfer"})
        import flashinfer.comm as comm

        _flashinfer_comm = comm
    except ImportError:
        logger.warning("flashinfer.comm is not available, falling back to standard " "implementation")


class FlashInferWorkspaceManager:
    """Owns the shared IPC workspace required by the fused allreduce kernel.

    A single module-level instance (``_workspace_manager``) is reused across
    calls; the workspace is rebuilt only when the world size changes.
    """

    def __init__(self):
        self.workspace_tensor = None
        self.ipc_handles = None
        self.world_size = None
        self.rank = None
        self.initialized = False

    def initialize(
        self,
        world_size: int,
        rank: int,
        max_token_num: int,
        hidden_dim: int,
        group=None,
        use_fp32_lamport: bool = False,
    ):
        """Create (or re-create) the IPC workspace for this rank.

        No-op when already initialized for the same ``world_size`` or when
        flashinfer.comm is unavailable. Any previously held workspace is
        destroyed before a new one is created.
        """
        if self.initialized and self.world_size == world_size:
            return

        if _flashinfer_comm is None:
            logger.warning("FlashInfer comm not available, skipping workspace " "initialization")
            return

        self.cleanup()

        # Use the canonical module handle (not the raw `comm` import name) so
        # this method stays correct if the import block above is refactored.
        self.ipc_handles, self.workspace_tensor = _flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
            rank,
            world_size,
            max_token_num,
            hidden_dim,
            group=group,
            use_fp32_lamport=use_fp32_lamport,
        )

        self.world_size = world_size
        self.rank = rank
        self.initialized = True

        logger.info(f"FlashInfer workspace initialized for rank {rank}, " f"world_size {world_size}")

    def cleanup(self):
        """Destroy the IPC workspace. Never raises; always resets state."""
        if self.initialized and self.ipc_handles is not None:
            try:
                # NOTE(review): dist.get_group() is assumed to return the same
                # group the workspace was created under — confirm against the
                # `group` argument passed to initialize().
                _flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(self.ipc_handles, group=dist.get_group())
            except Exception as e:
                logger.warning(f"Failed to cleanup FlashInfer workspace: {e}")
            finally:
                self.workspace_tensor = None
                self.ipc_handles = None
                self.initialized = False


_workspace_manager = FlashInferWorkspaceManager()


def ensure_workspace_initialized(
    fd_config: FDConfig, max_token_num: int = 2048, hidden_dim: int = 4096, use_fp32_lamport: bool = False
):
    """Ensure the shared workspace exists for the current TP group.

    Returns:
        bool: True when the workspace is ready; False when flashinfer is
        unavailable or tensor parallelism is not in use (world_size <= 1).
    """
    if not has_flashinfer() or _flashinfer_comm is None:
        return False

    assert fd_config is not None
    world_size = fd_config.parallel_config.tensor_parallel_size
    if world_size <= 1:
        return False

    rank = dist.get_rank()

    # Rebuild when uninitialized or when the TP world size changed.
    if not _workspace_manager.initialized or _workspace_manager.world_size != world_size:
        _workspace_manager.initialize(
            world_size=world_size,
            rank=rank,
            max_token_num=max_token_num,
            hidden_dim=hidden_dim,
            use_fp32_lamport=use_fp32_lamport,
        )

    return _workspace_manager.initialized


def flashinfer_allreduce_residual_rmsnorm(
    fd_config: FDConfig,
    input_tensor: paddle.Tensor,
    residual: paddle.Tensor,
    weight: paddle.Tensor,
    eps: float = 1e-6,
    max_token_num: int = 2048,
    use_oneshot: Optional[bool] = None,
    trigger_completion_at_end: bool = False,
    fp32_acc: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
    """
    Use FlashInfer's fused allreduce + residual + RMS norm operation.

    Args:
        fd_config: FastDeploy config (supplies tensor_parallel_size).
        input_tensor: Input tensor that needs allreduce, shape [tokens, hidden].
        residual: Residual tensor, same shape as input_tensor.
        weight: RMS norm weight, shape [hidden].
        eps: RMS norm epsilon.
        max_token_num: Maximum token number the workspace supports.
        use_oneshot: Whether to use oneshot mode (None lets the kernel decide).
        trigger_completion_at_end: Whether to trigger completion at end.
        fp32_acc: Whether to accumulate in fp32.

    Returns:
        Tuple[paddle.Tensor, paddle.Tensor]: (norm_output, residual_output),
        or (None, None) when the fused path is unavailable — callers must
        fall back to the unfused implementation in that case.
    """
    if not has_flashinfer() or _flashinfer_comm is None:
        logger.debug("FlashInfer not available, falling back to standard " "implementation")
        return None, None

    assert fd_config is not None
    world_size = fd_config.parallel_config.tensor_parallel_size
    if world_size <= 1:
        logger.debug("Single GPU, no need for allreduce fusion")
        return None, None

    # The workspace is sized for max_token_num rows; larger batches must not
    # reach this path (the caller gates on the same bound).
    assert input_tensor.shape[0] <= max_token_num

    if not ensure_workspace_initialized(
        fd_config=fd_config,
        max_token_num=max_token_num,
        hidden_dim=input_tensor.shape[-1],
        use_fp32_lamport=(input_tensor.dtype == paddle.float32),
    ):
        logger.debug("FlashInfer workspace not available")
        return None, None

    token_num, hidden_dim = input_tensor.shape

    residual_out = paddle.empty_like(residual)
    norm_out = paddle.empty_like(input_tensor)
    # Support empty input: nothing to reduce/normalize, skip the kernel.
    if input_tensor.shape[0] == 0:
        return norm_out, residual_out
    _flashinfer_comm.trtllm_allreduce_fusion(
        allreduce_in=input_tensor,
        world_size=world_size,
        world_rank=dist.get_rank(),
        token_num=token_num,
        hidden_dim=hidden_dim,
        workspace_ptrs=_workspace_manager.workspace_tensor,
        launch_with_pdl=True,
        use_oneshot=use_oneshot,
        trigger_completion_at_end=trigger_completion_at_end,
        fp32_acc=fp32_acc,
        pattern_code=(_flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm),
        allreduce_out=None,
        residual_in=residual,
        residual_out=residual_out,
        norm_out=norm_out,
        quant_out=None,
        scale_out=None,
        rms_gamma=weight,
        rms_eps=eps,
        scale_factor=None,
        layout_code=None,
    )

    return norm_out, residual_out


def fake_flashinfer_allreduce_residual_rmsnorm(
    input_tensor: paddle.Tensor,
    residual: paddle.Tensor,
    weight: paddle.Tensor,
    eps: float = 1e-6,
    max_token_num: int = 16384,
    use_oneshot: Optional[bool] = None,
    trigger_completion_at_end: bool = False,
    fp32_acc: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
    """Shape-only stand-in for the fused op (e.g. for tracing/compilation)."""
    residual_out = paddle.empty_like(residual)
    norm_out = paddle.empty_like(input_tensor)
    return norm_out, residual_out


def cleanup_flashinfer_workspace():
    """Tear down the module-level workspace (safe to call multiple times)."""
    global _workspace_manager
    if _workspace_manager is not None:
        _workspace_manager.cleanup()
""" self.fd_config = fd_config + self.enable_all_reduce_fusion = ( + fd_config.parallel_config.enable_flashinfer_allreduce_fusion and "enable_all_reduce" in prefix + ) self.ep_size = fd_config.parallel_config.expert_parallel_size self.tp_size = fd_config.parallel_config.tensor_parallel_size self.tp_group = fd_config.parallel_config.tp_group @@ -937,7 +940,10 @@ def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor: out = self.quant_method.apply(self, x) - if self.reduce_results and self.tp_size > 1: + need_tp_all_reduce = ( + self.reduce_results and self.tp_size > 1 and not (self.enable_all_reduce_fusion and out.shape[0] <= 2048) + ) + if need_tp_all_reduce: out = tensor_model_parallel_all_reduce(out, self.tp_group) return out diff --git a/fastdeploy/model_executor/layers/normalization.py b/fastdeploy/model_executor/layers/normalization.py index 14e248e0a72..a830575b76c 100644 --- a/fastdeploy/model_executor/layers/normalization.py +++ b/fastdeploy/model_executor/layers/normalization.py @@ -35,6 +35,7 @@ is_batch_invariant_mode_enabled, rms_norm_batch_invariant, ) +from .flashinfer_comm_fusion import flashinfer_allreduce_residual_rmsnorm from .utils import get_tensor, modules_to_convert @@ -122,6 +123,10 @@ def __init__( self.tp_rank = self.fd_config.parallel_config.tensor_parallel_rank self.tp_group = self.fd_config.parallel_config.tp_group is_input_norm = prefix.endswith(".input_layernorm") + self.enable_all_reduce_fusion = ( + fd_config.parallel_config.enable_flashinfer_allreduce_fusion and "post_attention_layernorm" in prefix + ) + self.is_last_norm = prefix.endswith(".norm") self.split_x = ( self.fd_config.parallel_config.use_sequence_parallel_moe @@ -240,6 +245,11 @@ def forward( norm_out = rms_norm(x, self.weight, self.eps) return norm_out.astype(x_dtype), residual_out norm_out = self.norm_func(x, residual_input, self.weight, self.eps) + # enable trtllm all reduce fusion + elif self.enable_all_reduce_fusion and x.shape[0] <= 2048: + norm_out = 
flashinfer_allreduce_residual_rmsnorm( + fd_config=self.fd_config, input_tensor=x, residual=residual_input, weight=self.weight, eps=self.eps + ) else: if is_batch_invariant_mode_enabled(): # M-invariant path: per-row Triton kernel, no cross-row reduction diff --git a/fastdeploy/model_executor/layers/quantization/mxfp4.py b/fastdeploy/model_executor/layers/quantization/mxfp4.py index 9fa02866210..508a5f891fc 100644 --- a/fastdeploy/model_executor/layers/quantization/mxfp4.py +++ b/fastdeploy/model_executor/layers/quantization/mxfp4.py @@ -14,8 +14,6 @@ # limitations under the License. """ -import importlib -import importlib.util import math from enum import Enum from typing import Callable, Optional @@ -25,11 +23,12 @@ from fastdeploy import envs from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase -from fastdeploy.model_executor.utils import set_weight_attrs +from fastdeploy.model_executor.utils import has_flashinfer, set_weight_attrs from fastdeploy.platforms import current_platform if current_platform.is_cuda(): from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch + from fastdeploy.utils import get_logger from ..moe import FusedMoE @@ -59,10 +58,6 @@ def check_device_capability(num): return False -def has_flashinfer(): - return importlib.util.find_spec("flashinfer") is not None - - def round_up(a, b): return ((a + b - 1) // b) * b diff --git a/fastdeploy/model_executor/models/glm4_moe.py b/fastdeploy/model_executor/models/glm4_moe.py index 3f45e9df614..a468f344775 100644 --- a/fastdeploy/model_executor/models/glm4_moe.py +++ b/fastdeploy/model_executor/models/glm4_moe.py @@ -128,7 +128,6 @@ def __init__( self.tensor_parallel_size = fd_config.parallel_config.tensor_parallel_size self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank self.tp_group = fd_config.parallel_config.tp_group - self.use_ep = self.expert_parallel_size > 1 self.use_tp = self.tensor_parallel_size > 1 @@ -189,7 +188,6 @@ def 
forward(self, x, forward_meta: ForwardMeta = None): if self.n_shared_experts > 0: shared_experts_out = self.shared_experts(x) out = out + shared_experts_out - return out @@ -213,7 +211,7 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None self.o_proj = RowParallelLinear( fd_config, - prefix=f"{prefix}.o_proj", + prefix=f"{prefix}.enable_all_reduce.o_proj", input_size=fd_config.model_config.num_attention_heads * fd_config.model_config.head_dim, output_size=fd_config.model_config.hidden_size, layer_id=layer_id, @@ -288,7 +286,7 @@ def __init__( fd_config, hidden_size=fd_config.model_config.hidden_size, eps=fd_config.model_config.rms_norm_eps, - prefix=f"{prefix}.input_layernorm", + prefix=f"{prefix}.enable_all_reduce_fusion.input_layernorm", layer_id=layer_id, ) @@ -296,7 +294,7 @@ def __init__( fd_config, hidden_size=fd_config.model_config.hidden_size, eps=fd_config.model_config.rms_norm_eps, - prefix=f"{prefix}.post_attention_layernorm", + prefix=f"{prefix}.enable_all_reduce_fusion.post_attention_layernorm", layer_id=layer_id, ) diff --git a/fastdeploy/model_executor/utils.py b/fastdeploy/model_executor/utils.py index b03fa480d3e..2f8a56e5db4 100644 --- a/fastdeploy/model_executor/utils.py +++ b/fastdeploy/model_executor/utils.py @@ -14,6 +14,8 @@ # limitations under the License. 
""" +import importlib +import importlib.util import os import re from collections.abc import Mapping @@ -552,6 +554,10 @@ def fn(loaded_weight_name, is_moe): return fn +def has_flashinfer(): + return importlib.util.find_spec("flashinfer") is not None + + @cache def get_sm_version(): if paddle.cuda.is_available(): diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 9e41428e42b..73cd27974ef 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -840,6 +840,12 @@ def parse_args(): default=None, help="Configuration of SpeculativeConfig.", ) + parser.add_argument( + "--enable_flashinfer_allreduce_fusion", + action="store_true", + default=False, + help="Flag to enable all reduce fusion kernel in flashinfer.", + ) parser.add_argument( "--max_num_batched_tokens", type=int, diff --git a/requirements.txt b/requirements.txt index 66ed714045b..8ad10fa48e6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -46,7 +46,7 @@ setproctitle aistudio_sdk p2pstore py-cpuinfo -flashinfer-python-paddle +flashinfer-python-paddle @ https://xly-devops.bj.bcebos.com/flashinfer/flashinfer_python_paddle-0.4.1.2-py3-none-any.whl flash_mask @ https://paddle-qa.bj.bcebos.com/ernie/flash_mask-4.0.post20260128-py3-none-any.whl arctic_inference @ https://paddle-qa.bj.bcebos.com/ernie/arctic_inference-0.1.3-cp310-cp310-linux_x86_64.whl transformers>=4.55.1,<5.0.0 diff --git a/tests/batch_invariant/test_rmsnorm_layer_batch_invariant.py b/tests/batch_invariant/test_rmsnorm_layer_batch_invariant.py index 121e74ee4b9..54ce40ca5d4 100644 --- a/tests/batch_invariant/test_rmsnorm_layer_batch_invariant.py +++ b/tests/batch_invariant/test_rmsnorm_layer_batch_invariant.py @@ -31,6 +31,7 @@ def _make_minimal_rmsnorm(hidden_size, eps=1e-5, dtype="float32"): layer.bias = None layer.split_x = False layer.allgather_out = False + layer.enable_all_reduce_fusion = False return layer diff --git 
"""
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import os
import subprocess
import sys
import unittest
from unittest.mock import Mock, patch

import paddle


def test_run_distributed():
    """Launch multi-GPU distributed test via paddle.distributed.launch as subprocess."""
    current_dir = os.path.dirname(os.path.abspath(__file__))
    run_script = os.path.join(current_dir, "trtllm_allreduce_rms_fusion.py")
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
    command = [
        sys.executable,
        "-m",
        "paddle.distributed.launch",
        "--gpus",
        "0,1",
        run_script,
    ]

    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

    try:
        stdout, stderr = process.communicate(timeout=400)
        return_code = process.returncode
    except subprocess.TimeoutExpired:
        process.kill()
        stdout, stderr = process.communicate()
        return_code = -1
    assert return_code in (0, 250), f"Process exited with code {return_code}"


class TestFlashInferWorkspaceManagerEdgeCases(unittest.TestCase):
    """Test FlashInferWorkspaceManager edge cases and fallback paths."""

    def setUp(self):
        """Initialize test fixtures."""
        # Patch before importing to test fallback paths.
        self.patcher_has_flashinfer = patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion.has_flashinfer")
        self.mock_has_flashinfer = self.patcher_has_flashinfer.start()

    def tearDown(self):
        """Clean up patches."""
        self.patcher_has_flashinfer.stop()

    def test_initialization_early_return_when_already_initialized(self):
        """Early return when already initialized with the same world_size."""
        # Patch _flashinfer_comm to be available.
        with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm") as mock_comm:
            from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
                FlashInferWorkspaceManager,
            )

            manager = FlashInferWorkspaceManager()

            # Simulate a prior initialization.
            manager.initialized = True
            manager.world_size = 2

            # Mock the comm functions.
            mock_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion = Mock(return_value=(Mock(), Mock()))

            # Second initialization with same world_size - should return early.
            manager.initialize(
                world_size=2,
                rank=0,
                max_token_num=2048,
                hidden_dim=4096,
            )

    def test_initialization_warning_when_comm_none(self):
        """Warning (no crash) when _flashinfer_comm is None."""
        # Patch to ensure _flashinfer_comm is None.
        with patch(
            "fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm",
            None,
        ):
            from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
                FlashInferWorkspaceManager,
            )

            manager = FlashInferWorkspaceManager()

            # Should not raise, just log warning and return.
            manager.initialize(
                world_size=2,
                rank=0,
                max_token_num=2048,
                hidden_dim=4096,
            )

            # Verify not initialized.
            self.assertFalse(manager.initialized)

    def test_cleanup_with_exception(self):
        """Cleanup swallows kernel-side exceptions and resets state."""
        with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm") as mock_comm:
            from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
                FlashInferWorkspaceManager,
            )

            manager = FlashInferWorkspaceManager()
            manager.initialized = True
            manager.ipc_handles = Mock()
            manager.workspace_tensor = Mock()

            # Mock the destroy function to raise exception.
            mock_comm.trtllm_destroy_ipc_workspace_for_all_reduce = Mock(side_effect=RuntimeError("Cleanup error"))

            # Should not raise, just log warning.
            manager.cleanup()

            # Verify cleanup happened.
            self.assertFalse(manager.initialized)
            self.assertIsNone(manager.workspace_tensor)
            self.assertIsNone(manager.ipc_handles)

    def test_cleanup_without_initialization(self):
        """Cleanup is a no-op when not initialized."""
        from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
            FlashInferWorkspaceManager,
        )

        manager = FlashInferWorkspaceManager()
        manager.initialized = False

        # Should not raise.
        manager.cleanup()

        # Verify state.
        self.assertFalse(manager.initialized)


class TestEnsureWorkspaceInitialized(unittest.TestCase):
    """Test ensure_workspace_initialized fallback paths."""

    def setUp(self):
        """Initialize test fixtures."""
        self.patcher_has_flashinfer = patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion.has_flashinfer")
        self.mock_has_flashinfer = self.patcher_has_flashinfer.start()

    def tearDown(self):
        """Clean up patches."""
        self.patcher_has_flashinfer.stop()

    def test_ensure_workspace_when_flashinfer_not_available(self):
        """Early False return when flashinfer is not installed."""
        self.mock_has_flashinfer.return_value = False

        from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
            ensure_workspace_initialized,
        )

        fd_config = Mock()
        fd_config.parallel_config = Mock()
        fd_config.parallel_config.tensor_parallel_size = 2

        result = ensure_workspace_initialized(fd_config)

        # Should return False (not initialized).
        self.assertFalse(result)

    def test_ensure_workspace_when_comm_none(self):
        """ensure_workspace_initialized returns False when _flashinfer_comm is None."""
        self.mock_has_flashinfer.return_value = True

        with patch(
            "fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm",
            None,
        ):
            from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
                ensure_workspace_initialized,
            )

            fd_config = Mock()
            fd_config.parallel_config = Mock()
            fd_config.parallel_config.tensor_parallel_size = 2

            result = ensure_workspace_initialized(fd_config)

            # Should return False.
            self.assertFalse(result)

    def test_ensure_workspace_single_gpu(self):
        """Early False return when world_size <= 1."""
        self.mock_has_flashinfer.return_value = True

        with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm"):
            from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
                ensure_workspace_initialized,
            )

            fd_config = Mock()
            fd_config.parallel_config = Mock()
            fd_config.parallel_config.tensor_parallel_size = 1

            with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion.dist.get_rank", return_value=0):
                result = ensure_workspace_initialized(fd_config)

            # Should return False for single GPU.
            self.assertFalse(result)


class TestFlashInferAllReduceResidualRMSNormFallbacks(unittest.TestCase):
    """Test flashinfer_allreduce_residual_rmsnorm fallback paths."""

    def setUp(self):
        """Initialize test fixtures."""
        self.patcher_has_flashinfer = patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion.has_flashinfer")
        self.mock_has_flashinfer = self.patcher_has_flashinfer.start()

    def tearDown(self):
        """Clean up patches."""
        self.patcher_has_flashinfer.stop()

    def test_flashinfer_not_available_fallback(self):
        """Fallback to (None, None) when flashinfer is not available."""
        self.mock_has_flashinfer.return_value = False

        from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
            flashinfer_allreduce_residual_rmsnorm,
        )

        fd_config = Mock()
        fd_config.parallel_config = Mock()
        fd_config.parallel_config.tensor_parallel_size = 2

        input_tensor = paddle.randn([128, 768])
        residual = paddle.randn([128, 768])
        weight = paddle.randn([768])

        norm_out, residual_out = flashinfer_allreduce_residual_rmsnorm(
            fd_config=fd_config,
            input_tensor=input_tensor,
            residual=residual,
            weight=weight,
            eps=1e-6,
            max_token_num=2048,
        )

        # Should return None, None when flashinfer not available.
        self.assertIsNone(norm_out)
        self.assertIsNone(residual_out)

    def test_single_gpu_fallback(self):
        """Fallback to (None, None) for single GPU."""
        self.mock_has_flashinfer.return_value = True

        with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm"):
            from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
                flashinfer_allreduce_residual_rmsnorm,
            )

            fd_config = Mock()
            fd_config.parallel_config = Mock()
            fd_config.parallel_config.tensor_parallel_size = 1

            input_tensor = paddle.randn([128, 768])
            residual = paddle.randn([128, 768])
            weight = paddle.randn([768])

            norm_out, residual_out = flashinfer_allreduce_residual_rmsnorm(
                fd_config=fd_config,
                input_tensor=input_tensor,
                residual=residual,
                weight=weight,
                eps=1e-6,
                max_token_num=2048,
            )

            # Should return None, None for single GPU.
            self.assertIsNone(norm_out)
            self.assertIsNone(residual_out)

    def test_empty_tensor_handling(self):
        """Empty (0-token) input skips the fused kernel entirely."""
        self.mock_has_flashinfer.return_value = True

        with (
            patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm") as mock_comm,
            patch(
                "fastdeploy.model_executor.layers.flashinfer_comm_fusion.ensure_workspace_initialized",
                return_value=True,
            ),
        ):
            from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
                flashinfer_allreduce_residual_rmsnorm,
            )

            fd_config = Mock()
            fd_config.parallel_config = Mock()
            fd_config.parallel_config.tensor_parallel_size = 2

            # Empty tensor (0 tokens).
            input_tensor = paddle.zeros([0, 768])
            residual = paddle.zeros([0, 768])
            weight = paddle.randn([768])

            # Mock the trtllm_allreduce_fusion to not be called.
            mock_comm.trtllm_allreduce_fusion = Mock()

            norm_out, residual_out = flashinfer_allreduce_residual_rmsnorm(
                fd_config=fd_config,
                input_tensor=input_tensor,
                residual=residual,
                weight=weight,
                eps=1e-6,
                max_token_num=2048,
            )

            # Should return empty tensors, not call flashinfer.
            self.assertEqual(norm_out.shape[0], 0)
            self.assertEqual(residual_out.shape[0], 0)
            mock_comm.trtllm_allreduce_fusion.assert_not_called()


class TestFakeFlashInferFunction(unittest.TestCase):
    """Test fake_flashinfer_allreduce_residual_rmsnorm function."""

    def test_fake_function_basic(self):
        """Fake function returns tensors of matching shapes."""
        from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
            fake_flashinfer_allreduce_residual_rmsnorm,
        )

        input_tensor = paddle.randn([128, 768])
        residual = paddle.randn([128, 768])
        weight = paddle.randn([768])

        norm_out, residual_out = fake_flashinfer_allreduce_residual_rmsnorm(
            input_tensor=input_tensor,
            residual=residual,
            weight=weight,
            eps=1e-6,
            max_token_num=16384,
            use_oneshot=None,
            trigger_completion_at_end=False,
            fp32_acc=False,
        )

        # Should return empty-like tensors.
        self.assertEqual(norm_out.shape, input_tensor.shape)
        self.assertEqual(residual_out.shape, residual.shape)


class TestCleanupFlashInferWorkspace(unittest.TestCase):
    """Test cleanup_flashinfer_workspace function."""

    def test_cleanup_workspace_function(self):
        """cleanup_flashinfer_workspace delegates to the module-level manager."""
        with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._workspace_manager") as mock_manager:
            from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
                cleanup_flashinfer_workspace,
            )

            mock_manager.cleanup = Mock()

            cleanup_flashinfer_workspace()

            mock_manager.cleanup.assert_called_once()


if __name__ == "__main__":
    # Run the distributed launcher only when executed directly — a bare
    # module-level call would spawn the 400s subprocess at import/collection
    # time AND again when pytest runs the collected test.
    test_run_distributed()
    unittest.main()
stability + self.warmup_iterations = 20 # Increase warmup + self.test_iterations = 200 # Increase test iterations + + def tearDown(self): + """Clean up resources""" + if paddle.is_compiled_with_cuda(): + paddle.device.cuda.empty_cache() + paddle.device.cuda.synchronize() + + def create_test_tensors(self): + """Create test tensors""" + input_tensor = paddle.randn([self.token_num, self.hidden_dim], dtype=self.dtype) + residual = paddle.randn([self.token_num, self.hidden_dim], dtype=self.dtype) + weight = paddle.randn([self.hidden_dim], dtype=self.dtype) + return input_tensor, residual, weight + + def compute_reference_output(self, input_tensor, residual, weight, eps): + """Reference implementation: manually compute AllReduce + Residual + RMSNorm""" + # # Step 1: AllReduce (identity on single device) + # allreduce_out = input_tensor.clone() + # Apply all reduce operator + dist.all_reduce(input_tensor, op=dist.ReduceOp.SUM) + # Step 2: Add residual + residual_out = input_tensor + residual + + # Step 3: RMSNorm + variance = residual_out.pow(2).mean(axis=-1, keepdim=True) + norm_out = residual_out * paddle.rsqrt(variance + eps) + norm_out = norm_out * weight + + # dist.all_reduce(residual_out, op=dist.ReduceOp.SUM) + return norm_out, residual_out + + def paddle_rms_fuse(self, input_tensor, residual, weight, eps): + from paddle.incubate.nn.functional import fused_rms_norm + + # Apply all reduce operator + dist.all_reduce(input_tensor, op=dist.ReduceOp.SUM) + out_fused = fused_rms_norm( + input_tensor, + norm_weight=weight, + norm_bias=None, + epsilon=eps, + begin_norm_axis=self.begin_norm_axis, + bias=None, + residual=residual, + ) + + return out_fused[0], out_fused[1] + + def flashinfer_rms_fuse(self, input_tensor, residual, weight, eps): + """FlashInfer fused operator""" + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + flashinfer_allreduce_residual_rmsnorm, + ) + + norm_out, residual_out = flashinfer_allreduce_residual_rmsnorm( + 
fd_config=self.fd_config, + input_tensor=input_tensor, + residual=residual, + weight=weight, + eps=eps, + max_token_num=self.max_token_num, + use_oneshot=False, + ) + return norm_out, residual_out + + def benchmark_function(self, func, *args, name="", **kwargs): + """ + Improved performance benchmark + - Wait for GPU frequency stabilization + - Use median instead of mean (more stable) + - Filter outliers + """ + # Force GPU frequency stabilization + if paddle.is_compiled_with_cuda(): + for _ in range(5): + paddle.device.cuda.synchronize() + time.sleep(0.01) + + # Warmup - thorough warm-up + for _ in range(self.warmup_iterations): + result = func(*args, **kwargs) + if paddle.is_compiled_with_cuda(): + paddle.device.cuda.synchronize() + + # Extra wait to ensure GPU stability + if paddle.is_compiled_with_cuda(): + paddle.device.cuda.synchronize() + time.sleep(0.1) + + # Benchmark run + times = [] + for i in range(self.test_iterations): + if paddle.is_compiled_with_cuda(): + paddle.device.cuda.synchronize() + + start = time.perf_counter() + result = func(*args, **kwargs) + + if paddle.is_compiled_with_cuda(): + paddle.device.cuda.synchronize() + + end = time.perf_counter() + elapsed = (end - start) * 1000 # Convert to milliseconds + times.append(elapsed) + + times = np.array(times) + + # Filter outliers using IQR method + q1, q3 = np.percentile(times, [25, 75]) + iqr = q3 - q1 + lower_bound = q1 - 1.5 * iqr + upper_bound = q3 + 1.5 * iqr + filtered_times = times[(times >= lower_bound) & (times <= upper_bound)] + + # Fall back to raw data if too many samples filtered out + if len(filtered_times) < self.test_iterations * 0.5: + filtered_times = times + + # Statistics + avg_time = np.mean(filtered_times) + median_time = np.median(filtered_times) + std_time = np.std(filtered_times) + min_time = np.min(filtered_times) + max_time = np.max(filtered_times) + cv = (std_time / avg_time) * 100 # Coefficient of variation (%) + + print(f"\n{'='*70}") + print(f"Performance 
Benchmark: {name}") + print(f"{'='*70}") + print(f"Iterations: {len(filtered_times)}/{self.test_iterations} (after {self.warmup_iterations} warmup)") + print(f"Median: {median_time:.4f} ms (most stable metric)") + print(f"Average: {avg_time:.4f} ms") + print(f"Std Dev: {std_time:.4f} ms (CV: {cv:.2f}%)") + print(f"Min: {min_time:.4f} ms") + print(f"Max: {max_time:.4f} ms") + print(f"{'='*70}\n") + + # Return median (more stable) and result + return median_time, result + + def test_accuracy_fused_vs_reference(self): + """Test accuracy of fused operator vs reference implementation""" + input_tensor, residual, weight = self.create_test_tensors() + reference_output, ref_res = self.compute_reference_output( + input_tensor.clone(), residual.clone(), weight.clone(), self.eps + ) + fused_output, paddle_res = self.paddle_rms_fuse( + input_tensor.clone(), residual.clone(), weight.clone(), self.eps + ) + flashinfer_output, flashinfer_res = self.flashinfer_rms_fuse( + input_tensor.clone(), residual.clone(), weight.clone(), self.eps + ) + # Verify results + np.testing.assert_allclose(fused_output.numpy(), reference_output.numpy(), rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(ref_res.numpy(), paddle_res.numpy(), rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(flashinfer_output.numpy(), reference_output.numpy(), rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(ref_res.numpy(), flashinfer_res.numpy(), rtol=1e-5, atol=1e-5) + + +class TestFlashInferWorkspaceManager(unittest.TestCase): + """Test FlashInferWorkspaceManager""" + + def setUp(self): + """Initialize""" + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + FlashInferWorkspaceManager, + ) + + self.manager = FlashInferWorkspaceManager() + + def test_initialization(self): + """Test initialization state""" + self.assertIsNone(self.manager.workspace_tensor) + self.assertIsNone(self.manager.ipc_handles) + self.assertIsNone(self.manager.world_size) + self.assertIsNone(self.manager.rank) + 
self.assertFalse(self.manager.initialized) + + def test_cleanup(self): + """Test cleanup functionality""" + self.manager.cleanup() + self.assertFalse(self.manager.initialized) + self.assertIsNone(self.manager.workspace_tensor) + + +if __name__ == "__main__": + """Run tests directly (called by subprocess after distributed launch)""" + unittest.main(verbosity=2) diff --git a/tests/model_executor/test_linear.py b/tests/model_executor/test_linear.py index 13f2bbe245e..aba98479303 100644 --- a/tests/model_executor/test_linear.py +++ b/tests/model_executor/test_linear.py @@ -58,6 +58,7 @@ def make_fd_config( expert_parallel_size=1, tp_group=None, use_sequence_parallel_moe=use_sequence_parallel_moe, + enable_flashinfer_allreduce_fusion=False, ), scheduler_config=SimpleNamespace(splitwise_role=splitwise_role, max_num_seqs=1), load_config=SimpleNamespace(