Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions fastdeploy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1633,10 +1633,14 @@ def postprocess(self):
logger.info("Multi-modal models do not support prefix caching when using CUDAGraph!")

if self.scheduler_config.splitwise_role == "mixed":
if self.graph_opt_config.use_cudagraph:
self.parallel_config.use_sequence_parallel_moe = False
self.model_config.moe_phase = MoEPhase(phase="prefill")
elif self.scheduler_config.splitwise_role == "prefill":
self.model_config.moe_phase = MoEPhase(phase="prefill")
elif self.scheduler_config.splitwise_role == "decode":
if self.graph_opt_config.use_cudagraph:
self.parallel_config.use_sequence_parallel_moe = False
self.model_config.moe_phase = MoEPhase(phase="decode")
else:
raise NotImplementedError
Expand Down
6 changes: 4 additions & 2 deletions fastdeploy/engine/args_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,8 +507,10 @@ def __post_init__(self):
raise ValueError(
"Please set --rdma_comm_ports argument when using " "rdma cache transfer protocol."
)
if len(self.rdma_comm_ports) != self.tensor_parallel_size:
raise ValueError("The number of rdma comm ports must be equal to tensor parallel size.")
if len(self.rdma_comm_ports) != self.tensor_parallel_size * self.data_parallel_size:
raise ValueError(
f"The number of rdma comm ports must be equal to number of ranks(dp_size: {self.data_parallel_size} * tp_size: {self.tensor_parallel_size})."
)

if envs.ENABLE_V1_KVCACHE_SCHEDULER == 1:
if "ipc" in self.cache_transfer_protocol:
Expand Down
11 changes: 7 additions & 4 deletions fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,8 +570,8 @@ def __init__(self, fd_config: FDConfig):
self.ernie = Ernie4_5_VLModel(fd_config=fd_config)

# Persistent buffers for CUDA graphs.
self._input_embeddings = paddle.zeros(
[fd_config.model_config.max_model_len, fd_config.model_config.hidden_size],
self._decoder_input_embeddings = paddle.zeros(
[fd_config.graph_opt_config.max_capture_size, fd_config.model_config.hidden_size],
dtype=fd_config.model_config.dtype,
)

Expand Down Expand Up @@ -783,10 +783,13 @@ def forward(
image_features=image_features,
image_token_num=vl_moe_meta.num_image_patch_id.item(),
)
self._input_embeddings.copy_(input_embeddings, False)

if forward_meta.step_use_cudagraph:
self._decoder_input_embeddings.copy_(input_embeddings, False)
input_embeddings = self._decoder_input_embeddings

hidden_states = self.ernie(
input_embeddings=self._input_embeddings,
input_embeddings=input_embeddings,
ids_remove_padding=ids_remove_padding,
forward_meta=forward_meta,
vl_moe_meta=vl_moe_meta,
Expand Down
11 changes: 7 additions & 4 deletions fastdeploy/model_executor/models/ernie_vl_rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ def __init__(self, fd_config: FDConfig):
self.head_dtype = paddle.bfloat16

# Persistent buffers for CUDA graphs.
self._input_embeddings = paddle.zeros(
[fd_config.parallel_config.max_model_len, fd_config.model_config.hidden_size],
self._decoder_input_embeddings = paddle.zeros(
[fd_config.graph_opt_config.max_capture_size, fd_config.model_config.hidden_size],
dtype=fd_config.model_config.dtype,
)

Expand Down Expand Up @@ -112,10 +112,13 @@ def forward(
image_features=image_features,
image_token_num=vl_moe_meta.image_token_num.item(),
)
self._input_embeddings.copy_(input_embeddings, False)

if forward_meta.step_use_cudagraph:
self._decoder_input_embeddings.copy_(input_embeddings, False)
input_embeddings = self._decoder_input_embeddings

hidden_states = self.ernie(
input_embeddings=self._input_embeddings,
input_embeddings=input_embeddings,
ids_remove_padding=ids_remove_padding,
forward_meta=forward_meta,
vl_moe_meta=vl_moe_meta,
Expand Down
16 changes: 6 additions & 10 deletions fastdeploy/model_executor/models/paddleocr_vl/paddleocr_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def __init__(self, fd_config):

# Persistent buffers for CUDA graphs.
self._decoder_input_embeddings = paddle.zeros(
[fd_config.scheduler_config.max_num_seqs, fd_config.model_config.hidden_size],
[fd_config.graph_opt_config.max_capture_size, fd_config.model_config.hidden_size],
dtype=fd_config.model_config.dtype,
)

Expand Down Expand Up @@ -242,15 +242,11 @@ def forward(

if forward_meta.step_use_cudagraph:
self._decoder_input_embeddings.copy_(input_embeddings, False)
input_embeddings = self._decoder_input_embeddings

hidden_states = self.model(
input_embeddings=self._decoder_input_embeddings,
forward_meta=forward_meta,
)
else:
hidden_states = self.model(
input_embeddings=input_embeddings,
forward_meta=forward_meta,
)
hidden_states = self.model(
input_embeddings=input_embeddings,
forward_meta=forward_meta,
)

return hidden_states
11 changes: 7 additions & 4 deletions fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,8 @@ def __init__(self, fd_config: FDConfig):
self.model = Qwen2_5_VLModel(fd_config=fd_config)

# Persistent buffers for CUDA graphs.
self._input_embeddings = paddle.zeros(
[fd_config.model_config.max_model_len, fd_config.model_config.hidden_size],
self._decoder_input_embeddings = paddle.zeros(
[fd_config.graph_opt_config.max_capture_size, fd_config.model_config.hidden_size],
dtype=fd_config.model_config.dtype,
)

Expand Down Expand Up @@ -290,10 +290,13 @@ def forward(
input_embeddings = self.get_input_embeddings(
ids_remove_padding=ids_remove_padding, image_features=image_features
)
self._input_embeddings.copy_(input_embeddings, False)

if forward_meta.step_use_cudagraph:
self._decoder_input_embeddings.copy_(input_embeddings, False)
input_embeddings = self._decoder_input_embeddings

hidden_states = self.model(
input_embeddings=self._input_embeddings,
input_embeddings=input_embeddings,
ids_remove_padding=ids_remove_padding,
image_features=image_features,
forward_meta=forward_meta,
Expand Down
4 changes: 1 addition & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,10 +237,8 @@ def write_version_to_file():
version=None,
)
]
if os.getenv("ENABLE_FD_RDMA", "0") == "1"
else []
),
cmdclass=cmdclass_dict if os.getenv("ENABLE_FD_RDMA", "0") == "1" else {},
cmdclass=cmdclass_dict,
zip_safe=False,
classifiers=[
"Programming Language :: Python :: 3",
Expand Down
Loading