diff --git a/.github/workflows/_build_linux.yml b/.github/workflows/_build_linux.yml
index d17f210a12c..06d55801e27 100644
--- a/.github/workflows/_build_linux.yml
+++ b/.github/workflows/_build_linux.yml
@@ -164,7 +164,7 @@ jobs:
             python -m pip install -r requirements.txt
             python -m pip install wheel
             # Compile RDMA
-            export ENABLE_FD_RDMA=1
+            export FD_ENABLE_RDMA_COMPILE=1
             bash build.sh 1 python false [${COMPILE_ARCH}]
             ls ./dist/*.whl
             '
diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 83af4ebdd50..e7d26417d3d 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -902,6 +902,12 @@ def _set_cudagraph_sizes(self, max_capture_size: int = 0):
             draft_capture_sizes.append(max_capture_size)
         self.cudagraph_capture_sizes = sorted(draft_capture_sizes)
 
+    def filter_capture_size(self, tp_size: int = 1):
+        """When TSP is used, capture size must be divisible by tp size."""
+        self.cudagraph_capture_sizes = [
+            draft_size for draft_size in self.cudagraph_capture_sizes if (draft_size % tp_size == 0)
+        ]
+
     def to_json_string(self):
         """
         Convert speculative_config to json string.
@@ -1628,7 +1634,15 @@ def postprocess(self):
         if self.device_config is not None and self.device_config.device_type != "cuda":
             self.graph_opt_config.use_cudagraph = False
             logger.info(f"CUDAGraph only support on GPU, current device type is {self.device_config.device_type}!")
-
+        if self.parallel_config.use_sequence_parallel_moe and self.graph_opt_config.use_cudagraph:
+            if self.scheduler_config.max_num_seqs < self.parallel_config.tensor_parallel_size:
+                self.parallel_config.use_sequence_parallel_moe = False
+                logger.info(
+                    "Warning: sequence parallel moe does not support max_num_seqs < tensor_parallel_size when cudagraph is enabled. We set use_sequence_parallel_moe to False."
+                )
+            else:
+                # It will hang when real batch_size < tp_size
+                self.graph_opt_config.filter_capture_size(tp_size=self.parallel_config.tensor_parallel_size)
         if self.model_config.enable_mm and self.graph_opt_config.use_cudagraph:
             self.cache_config.enable_prefix_caching = False
             logger.info("Multi-modal models do not support prefix caching when using CUDAGraph!")
diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index 23812e966a6..17a9944ffc8 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -512,8 +512,10 @@ def __post_init__(self):
                 raise ValueError(
                     "Please set --rdma_comm_ports argument when using " "rdma cache transfer protocol."
                 )
-            if len(self.rdma_comm_ports) != self.tensor_parallel_size:
-                raise ValueError("The number of rdma comm ports must be equal to tensor parallel size.")
+            if len(self.rdma_comm_ports) != self.tensor_parallel_size * self.data_parallel_size:
+                raise ValueError(
+                    f"The number of rdma comm ports must be equal to number of ranks ({self.data_parallel_size=} * {self.tensor_parallel_size=} = {self.data_parallel_size * self.tensor_parallel_size}), but got {len(self.rdma_comm_ports)}."
+                )
 
         if envs.ENABLE_V1_KVCACHE_SCHEDULER == 1:
             if "ipc" in self.cache_transfer_protocol:
diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py
index a291db0e9a5..026d8c7d736 100644
--- a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py
+++ b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py
@@ -570,10 +570,11 @@ def __init__(self, fd_config: FDConfig):
         self.ernie = Ernie4_5_VLModel(fd_config=fd_config)
 
         # Persistent buffers for CUDA graphs.
-        self._input_embeddings = paddle.zeros(
-            [fd_config.model_config.max_model_len, fd_config.model_config.hidden_size],
-            dtype=fd_config.model_config.dtype,
-        )
+        if fd_config.graph_opt_config.use_cudagraph:
+            self._decoder_input_embeddings = paddle.zeros(
+                [fd_config.graph_opt_config.max_capture_size, fd_config.model_config.hidden_size],
+                dtype=fd_config.model_config.dtype,
+            )
 
         self.ori_vocab_size = fd_config.model_config.ori_vocab_size
 
@@ -783,10 +784,13 @@ def forward(
                 image_features=image_features,
                 image_token_num=vl_moe_meta.num_image_patch_id.item(),
             )
-            self._input_embeddings.copy_(input_embeddings, False)
+
+            if forward_meta.step_use_cudagraph:
+                self._decoder_input_embeddings.copy_(input_embeddings, False)
+                input_embeddings = self._decoder_input_embeddings
 
         hidden_states = self.ernie(
-            input_embeddings=self._input_embeddings,
+            input_embeddings=input_embeddings,
             ids_remove_padding=ids_remove_padding,
             forward_meta=forward_meta,
             vl_moe_meta=vl_moe_meta,
diff --git a/fastdeploy/model_executor/models/ernie_vl_rm.py b/fastdeploy/model_executor/models/ernie_vl_rm.py
index 86cddcb42c2..cfa29c84512 100644
--- a/fastdeploy/model_executor/models/ernie_vl_rm.py
+++ b/fastdeploy/model_executor/models/ernie_vl_rm.py
@@ -59,10 +59,11 @@ def __init__(self, fd_config: FDConfig):
         self.head_dtype = paddle.bfloat16
 
         # Persistent buffers for CUDA graphs.
-        self._input_embeddings = paddle.zeros(
-            [fd_config.parallel_config.max_model_len, fd_config.model_config.hidden_size],
-            dtype=fd_config.model_config.dtype,
-        )
+        if fd_config.graph_opt_config.use_cudagraph:
+            self._decoder_input_embeddings = paddle.zeros(
+                [fd_config.graph_opt_config.max_capture_size, fd_config.model_config.hidden_size],
+                dtype=fd_config.model_config.dtype,
+            )
 
         self.rm_head = nn.Sequential(
             (
@@ -112,10 +113,13 @@ def forward(
                 image_features=image_features,
                 image_token_num=vl_moe_meta.image_token_num.item(),
             )
-            self._input_embeddings.copy_(input_embeddings, False)
+
+            if forward_meta.step_use_cudagraph:
+                self._decoder_input_embeddings.copy_(input_embeddings, False)
+                input_embeddings = self._decoder_input_embeddings
 
         hidden_states = self.ernie(
-            input_embeddings=self._input_embeddings,
+            input_embeddings=input_embeddings,
             ids_remove_padding=ids_remove_padding,
             forward_meta=forward_meta,
             vl_moe_meta=vl_moe_meta,
diff --git a/fastdeploy/model_executor/models/paddleocr_vl/paddleocr_vl.py b/fastdeploy/model_executor/models/paddleocr_vl/paddleocr_vl.py
index 13afbe3c985..754af55dd87 100644
--- a/fastdeploy/model_executor/models/paddleocr_vl/paddleocr_vl.py
+++ b/fastdeploy/model_executor/models/paddleocr_vl/paddleocr_vl.py
@@ -132,10 +132,11 @@ def __init__(self, fd_config):
         )
 
         # Persistent buffers for CUDA graphs.
-        self._decoder_input_embeddings = paddle.zeros(
-            [fd_config.scheduler_config.max_num_seqs, fd_config.model_config.hidden_size],
-            dtype=fd_config.model_config.dtype,
-        )
+        if fd_config.graph_opt_config.use_cudagraph:
+            self._decoder_input_embeddings = paddle.zeros(
+                [fd_config.graph_opt_config.max_capture_size, fd_config.model_config.hidden_size],
+                dtype=fd_config.model_config.dtype,
+            )
 
     @paddle.no_grad()
     def load_weights(self, weights_iterator) -> None:
@@ -242,15 +243,11 @@ def forward(
         if forward_meta.step_use_cudagraph:
             self._decoder_input_embeddings.copy_(input_embeddings, False)
+            input_embeddings = self._decoder_input_embeddings
 
-            hidden_states = self.model(
-                input_embeddings=self._decoder_input_embeddings,
-                forward_meta=forward_meta,
-            )
-        else:
-            hidden_states = self.model(
-                input_embeddings=input_embeddings,
-                forward_meta=forward_meta,
-            )
+        hidden_states = self.model(
+            input_embeddings=input_embeddings,
+            forward_meta=forward_meta,
+        )
 
         return hidden_states
diff --git a/fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py b/fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py
index 531f530c449..0f17ec08f58 100644
--- a/fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py
+++ b/fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py
@@ -152,10 +152,11 @@ def __init__(self, fd_config: FDConfig):
         self.model = Qwen2_5_VLModel(fd_config=fd_config)
 
         # Persistent buffers for CUDA graphs.
-        self._input_embeddings = paddle.zeros(
-            [fd_config.model_config.max_model_len, fd_config.model_config.hidden_size],
-            dtype=fd_config.model_config.dtype,
-        )
+        if fd_config.graph_opt_config.use_cudagraph:
+            self._decoder_input_embeddings = paddle.zeros(
+                [fd_config.graph_opt_config.max_capture_size, fd_config.model_config.hidden_size],
+                dtype=fd_config.model_config.dtype,
+            )
 
         self.ori_vocab_size = fd_config.model_config.ori_vocab_size
 
@@ -290,10 +291,13 @@ def forward(
             input_embeddings = self.get_input_embeddings(
                 ids_remove_padding=ids_remove_padding, image_features=image_features
             )
-            self._input_embeddings.copy_(input_embeddings, False)
+
+            if forward_meta.step_use_cudagraph:
+                self._decoder_input_embeddings.copy_(input_embeddings, False)
+                input_embeddings = self._decoder_input_embeddings
 
         hidden_states = self.model(
-            input_embeddings=self._input_embeddings,
+            input_embeddings=input_embeddings,
             ids_remove_padding=ids_remove_padding,
             image_features=image_features,
             forward_meta=forward_meta,
diff --git a/setup.py b/setup.py
index c40b006670e..94cbdc8178b 100644
--- a/setup.py
+++ b/setup.py
@@ -14,10 +14,12 @@
 # limitations under the License.
""" +import glob import os import re import subprocess import sys +from functools import lru_cache from pathlib import Path import paddle @@ -180,6 +182,68 @@ def get_device_type(): return "cpu" +def check_header(header_path): + return os.path.exists(header_path) + + +def check_library(lib_name): + # search /usr/lib /usr/lib64 /lib /lib64 .etc + paths = [ + "/usr/lib", + "/usr/lib32", + "/usr/lib64", + "/usr/lib/x86_64-linux-gnu", + "/lib", + "/lib32", + "/lib64", + "/usr/local/lib", + "/usr/local/lib64", + ] + for p in paths: + if glob.glob(os.path.join(p, lib_name)): + return True + return False + + +def check_rdma_packages(): + results = {} + + # libibverbs-dev + results["libibverbs header"] = check_header("/usr/include/infiniband/verbs.h") + results["libibverbs library"] = check_library("libibverbs.so*") or check_library("libibverbs.so") + + # librdmacm-dev + results["librdmacm header"] = check_header("/usr/include/rdma/rdma_cma.h") + results["librdmacm library"] = check_library("librdmacm.so*") or check_library("librdmacm.so") + + print("===== RDMA Library Check Results =====") + for k, v in results.items(): + status = "FOUND" if v else "NOT FOUND" + print(f"{k:25}: {status}") + + print("\n== Summary ==") + if all(results.values()): + print("All required RDMA libraries are installed.") + return True + else: + print("Some RDMA libraries are missing. Suggested commands:") + print("\nUbuntu/Debian:") + print(" sudo apt-get install -y libibverbs-dev librdmacm-dev") + print("\nCentOS/RHEL:") + print(" sudo yum install -y libibverbs-devel librdmacm-devel") + return False + + +@lru_cache(maxsize=1) +def rdma_comm_supported(): + supported = ( + get_device_type() in ["gpu", "xpu"] + and check_rdma_packages() + and os.getenv("FD_ENABLE_RDMA_COMPILE", "1") == "1" + ) + return supported + + def get_name(): """get package name""" return "fastdeploy-" + get_device_type() @@ -237,10 +301,10 @@ def write_version_to_file(): version=None, ) ] - if os.getenv("ENABLE_FD_RDMA", "0") == "1" + if rdma_comm_supported() else [] ), - cmdclass=cmdclass_dict if os.getenv("ENABLE_FD_RDMA", "0") == "1" else {}, + cmdclass=cmdclass_dict if rdma_comm_supported() else {}, zip_safe=False, classifiers=[ "Programming Language :: Python :: 3",