Commit 7ac2593

[Optimization] default compile rdma, reduce cudagraph buffer size in mm, fix some config bugs (#5121)
* default compile rdma, reduce cudagraph buffer size in mm, fix some config logic
* update
* update
* fix bug
* enhance rdma compile
* fix
1 parent 6fa3410 commit 7ac2593

File tree

8 files changed: +126 −37 lines changed


.github/workflows/_build_linux.yml

Lines changed: 1 addition & 1 deletion
@@ -164,7 +164,7 @@ jobs:
         python -m pip install -r requirements.txt
         python -m pip install wheel
         # Compile RDMA
-        export ENABLE_FD_RDMA=1
+        export FD_ENABLE_RDMA_COMPILE=1
         bash build.sh 1 python false [${COMPILE_ARCH}]
         ls ./dist/*.whl
         '

fastdeploy/config.py

Lines changed: 15 additions & 1 deletion
@@ -902,6 +902,12 @@ def _set_cudagraph_sizes(self, max_capture_size: int = 0):
             draft_capture_sizes.append(max_capture_size)
         self.cudagraph_capture_sizes = sorted(draft_capture_sizes)

+    def filter_capture_size(self, tp_size: int = 1):
+        """When TSP is used, capture size must be divisible by tp size."""
+        self.cudagraph_capture_sizes = [
+            draft_size for draft_size in self.cudagraph_capture_sizes if (draft_size % tp_size == 0)
+        ]
+
     def to_json_string(self):
         """
         Convert speculative_config to json string.

@@ -1628,7 +1634,15 @@ def postprocess(self):
         if self.device_config is not None and self.device_config.device_type != "cuda":
             self.graph_opt_config.use_cudagraph = False
             logger.info(f"CUDAGraph only support on GPU, current device type is {self.device_config.device_type}!")
-
+        if self.parallel_config.use_sequence_parallel_moe and self.graph_opt_config.use_cudagraph:
+            if self.scheduler_config.max_num_seqs < self.parallel_config.tensor_parallel_size:
+                self.parallel_config.use_sequence_parallel_moe = False
+                logger.info(
+                    "Warning: sequence parallel moe does not support max_num_seqs < tensor_parallel_size when cudagraph is enabled. Setting use_sequence_parallel_moe to False."
+                )
+            else:
+                # It will hang when real batch_size < tp_size
+                self.graph_opt_config.filter_capture_size(tp_size=self.parallel_config.tensor_parallel_size)
         if self.model_config.enable_mm and self.graph_opt_config.use_cudagraph:
             self.cache_config.enable_prefix_caching = False
             logger.info("Multi-modal models do not support prefix caching when using CUDAGraph!")

fastdeploy/engine/args_utils.py

Lines changed: 4 additions & 2 deletions
@@ -512,8 +512,10 @@ def __post_init__(self):
                 raise ValueError(
                     "Please set --rdma_comm_ports argument when using " "rdma cache transfer protocol."
                 )
-            if len(self.rdma_comm_ports) != self.tensor_parallel_size:
-                raise ValueError("The number of rdma comm ports must be equal to tensor parallel size.")
+            if len(self.rdma_comm_ports) != self.tensor_parallel_size * self.data_parallel_size:
+                raise ValueError(
+                    f"The number of rdma comm ports must be equal to number of ranks ({self.data_parallel_size=} * {self.tensor_parallel_size=} = {self.data_parallel_size * self.tensor_parallel_size}), but got {len(self.rdma_comm_ports)}."
+                )

         if envs.ENABLE_V1_KVCACHE_SCHEDULER == 1:
             if "ipc" in self.cache_transfer_protocol:

fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py

Lines changed: 10 additions & 6 deletions
@@ -570,10 +570,11 @@ def __init__(self, fd_config: FDConfig):
         self.ernie = Ernie4_5_VLModel(fd_config=fd_config)

         # Persistent buffers for CUDA graphs.
-        self._input_embeddings = paddle.zeros(
-            [fd_config.model_config.max_model_len, fd_config.model_config.hidden_size],
-            dtype=fd_config.model_config.dtype,
-        )
+        if fd_config.graph_opt_config.use_cudagraph:
+            self._decoder_input_embeddings = paddle.zeros(
+                [fd_config.graph_opt_config.max_capture_size, fd_config.model_config.hidden_size],
+                dtype=fd_config.model_config.dtype,
+            )

         self.ori_vocab_size = fd_config.model_config.ori_vocab_size

@@ -783,10 +784,13 @@ def forward(
                 image_features=image_features,
                 image_token_num=vl_moe_meta.num_image_patch_id.item(),
             )
-        self._input_embeddings.copy_(input_embeddings, False)
+
+        if forward_meta.step_use_cudagraph:
+            self._decoder_input_embeddings.copy_(input_embeddings, False)
+            input_embeddings = self._decoder_input_embeddings

         hidden_states = self.ernie(
-            input_embeddings=self._input_embeddings,
+            input_embeddings=input_embeddings,
             ids_remove_padding=ids_remove_padding,
             forward_meta=forward_meta,
             vl_moe_meta=vl_moe_meta,
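
The same persistent-buffer pattern recurs in the VL models below. CUDA graph replay requires stable tensor addresses, so decode-step embeddings are copied in place into a fixed buffer; sizing that buffer by max_capture_size (the largest captured decode batch) instead of max_model_len is what shrinks it. A minimal sketch with made-up sizes, not FastDeploy's real classes:

import paddle

hidden_size = 1024
max_capture_size = 256  # largest decode batch captured in a CUDA graph

# Allocated once; its address stays fixed across graph replays.
decoder_input_embeddings = paddle.zeros([max_capture_size, hidden_size], dtype="bfloat16")

def prepare_decoder_input(input_embeddings, step_use_cudagraph):
    if step_use_cudagraph:
        # In-place copy: the graph was captured against this buffer's address.
        # Assumes input_embeddings is already padded to the capture size.
        decoder_input_embeddings.copy_(input_embeddings, False)
        return decoder_input_embeddings
    return input_embeddings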

fastdeploy/model_executor/models/ernie_vl_rm.py

Lines changed: 10 additions & 6 deletions
@@ -59,10 +59,11 @@ def __init__(self, fd_config: FDConfig):
         self.head_dtype = paddle.bfloat16

         # Persistent buffers for CUDA graphs.
-        self._input_embeddings = paddle.zeros(
-            [fd_config.parallel_config.max_model_len, fd_config.model_config.hidden_size],
-            dtype=fd_config.model_config.dtype,
-        )
+        if fd_config.graph_opt_config.use_cudagraph:
+            self._decoder_input_embeddings = paddle.zeros(
+                [fd_config.graph_opt_config.max_capture_size, fd_config.model_config.hidden_size],
+                dtype=fd_config.model_config.dtype,
+            )

         self.rm_head = nn.Sequential(
             (

@@ -112,10 +113,13 @@ def forward(
                 image_features=image_features,
                 image_token_num=vl_moe_meta.image_token_num.item(),
             )
-        self._input_embeddings.copy_(input_embeddings, False)
+
+        if forward_meta.step_use_cudagraph:
+            self._decoder_input_embeddings.copy_(input_embeddings, False)
+            input_embeddings = self._decoder_input_embeddings

         hidden_states = self.ernie(
-            input_embeddings=self._input_embeddings,
+            input_embeddings=input_embeddings,
             ids_remove_padding=ids_remove_padding,
             forward_meta=forward_meta,
             vl_moe_meta=vl_moe_meta,

fastdeploy/model_executor/models/paddleocr_vl/paddleocr_vl.py

Lines changed: 10 additions & 13 deletions
@@ -132,10 +132,11 @@ def __init__(self, fd_config):
         )

         # Persistent buffers for CUDA graphs.
-        self._decoder_input_embeddings = paddle.zeros(
-            [fd_config.scheduler_config.max_num_seqs, fd_config.model_config.hidden_size],
-            dtype=fd_config.model_config.dtype,
-        )
+        if fd_config.graph_opt_config.use_cudagraph:
+            self._decoder_input_embeddings = paddle.zeros(
+                [fd_config.graph_opt_config.max_capture_size, fd_config.model_config.hidden_size],
+                dtype=fd_config.model_config.dtype,
+            )

     @paddle.no_grad()
     def load_weights(self, weights_iterator) -> None:

@@ -242,15 +243,11 @@ def forward(

         if forward_meta.step_use_cudagraph:
             self._decoder_input_embeddings.copy_(input_embeddings, False)
+            input_embeddings = self._decoder_input_embeddings

-            hidden_states = self.model(
-                input_embeddings=self._decoder_input_embeddings,
-                forward_meta=forward_meta,
-            )
-        else:
-            hidden_states = self.model(
-                input_embeddings=input_embeddings,
-                forward_meta=forward_meta,
-            )
+        hidden_states = self.model(
+            input_embeddings=input_embeddings,
+            forward_meta=forward_meta,
+        )

         return hidden_states
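
Beyond shrinking the buffer, this file also de-duplicates the call site: the buffer swap now just rebinds input_embeddings, leaving a single self.model(...) call instead of an if/else pair. A generic sketch of that refactor (stand-in names, not the real module):

def forward(input_embeddings, step_use_cudagraph, buffer, model):
    # Choose the input once, then keep a single call site for the decoder.
    if step_use_cudagraph:
        buffer.copy_(input_embeddings, False)  # stand-in for the persistent-buffer copy
        input_embeddings = buffer
    return model(input_embeddings=input_embeddings)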

fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py

Lines changed: 10 additions & 6 deletions
@@ -152,10 +152,11 @@ def __init__(self, fd_config: FDConfig):
         self.model = Qwen2_5_VLModel(fd_config=fd_config)

         # Persistent buffers for CUDA graphs.
-        self._input_embeddings = paddle.zeros(
-            [fd_config.model_config.max_model_len, fd_config.model_config.hidden_size],
-            dtype=fd_config.model_config.dtype,
-        )
+        if fd_config.graph_opt_config.use_cudagraph:
+            self._decoder_input_embeddings = paddle.zeros(
+                [fd_config.graph_opt_config.max_capture_size, fd_config.model_config.hidden_size],
+                dtype=fd_config.model_config.dtype,
+            )

         self.ori_vocab_size = fd_config.model_config.ori_vocab_size

@@ -290,10 +291,13 @@ def forward(
         input_embeddings = self.get_input_embeddings(
             ids_remove_padding=ids_remove_padding, image_features=image_features
         )
-        self._input_embeddings.copy_(input_embeddings, False)
+
+        if forward_meta.step_use_cudagraph:
+            self._decoder_input_embeddings.copy_(input_embeddings, False)
+            input_embeddings = self._decoder_input_embeddings

         hidden_states = self.model(
-            input_embeddings=self._input_embeddings,
+            input_embeddings=input_embeddings,
             ids_remove_padding=ids_remove_padding,
             image_features=image_features,
             forward_meta=forward_meta,

setup.py

Lines changed: 66 additions & 2 deletions
@@ -14,10 +14,12 @@
 # limitations under the License.
 """

+import glob
 import os
 import re
 import subprocess
 import sys
+from functools import lru_cache
 from pathlib import Path

 import paddle

@@ -180,6 +182,68 @@ def get_device_type():
     return "cpu"


+def check_header(header_path):
+    return os.path.exists(header_path)
+
+
+def check_library(lib_name):
+    # search /usr/lib, /usr/lib64, /lib, /lib64, etc.
+    paths = [
+        "/usr/lib",
+        "/usr/lib32",
+        "/usr/lib64",
+        "/usr/lib/x86_64-linux-gnu",
+        "/lib",
+        "/lib32",
+        "/lib64",
+        "/usr/local/lib",
+        "/usr/local/lib64",
+    ]
+    for p in paths:
+        if glob.glob(os.path.join(p, lib_name)):
+            return True
+    return False
+
+
+def check_rdma_packages():
+    results = {}
+
+    # libibverbs-dev
+    results["libibverbs header"] = check_header("/usr/include/infiniband/verbs.h")
+    results["libibverbs library"] = check_library("libibverbs.so*") or check_library("libibverbs.so")
+
+    # librdmacm-dev
+    results["librdmacm header"] = check_header("/usr/include/rdma/rdma_cma.h")
+    results["librdmacm library"] = check_library("librdmacm.so*") or check_library("librdmacm.so")
+
+    print("===== RDMA Library Check Results =====")
+    for k, v in results.items():
+        status = "FOUND" if v else "NOT FOUND"
+        print(f"{k:25}: {status}")
+
+    print("\n== Summary ==")
+    if all(results.values()):
+        print("All required RDMA libraries are installed.")
+        return True
+    else:
+        print("Some RDMA libraries are missing. Suggested commands:")
+        print("\nUbuntu/Debian:")
+        print("  sudo apt-get install -y libibverbs-dev librdmacm-dev")
+        print("\nCentOS/RHEL:")
+        print("  sudo yum install -y libibverbs-devel librdmacm-devel")
+        return False
+
+
+@lru_cache(maxsize=1)
+def rdma_comm_supported():
+    supported = (
+        get_device_type() in ["gpu", "xpu"]
+        and check_rdma_packages()
+        and os.getenv("FD_ENABLE_RDMA_COMPILE", "1") == "1"
+    )
+    return supported
+
+
 def get_name():
     """get package name"""
     return "fastdeploy-" + get_device_type()

@@ -237,10 +301,10 @@ def write_version_to_file():
                 version=None,
             )
         ]
-        if os.getenv("ENABLE_FD_RDMA", "0") == "1"
+        if rdma_comm_supported()
         else []
     ),
-    cmdclass=cmdclass_dict if os.getenv("ENABLE_FD_RDMA", "0") == "1" else {},
+    cmdclass=cmdclass_dict if rdma_comm_supported() else {},
     zip_safe=False,
     classifiers=[
         "Programming Language :: Python :: 3",
