File tree Expand file tree Collapse file tree 6 files changed +23
-17
lines changed Expand file tree Collapse file tree 6 files changed +23
-17
lines changed Original file line number Diff line number Diff line change @@ -1633,12 +1633,14 @@ def postprocess(self):
16331633 logger .info ("Multi-modal models do not support prefix caching when using CUDAGraph!" )
16341634
16351635 if self .scheduler_config .splitwise_role == "mixed" :
1636+ # Sequence parallel MoE is incompatible with CUDA graph now. It will hang.
16361637 if self .graph_opt_config .use_cudagraph :
16371638 self .parallel_config .use_sequence_parallel_moe = False
16381639 self .model_config .moe_phase = MoEPhase (phase = "prefill" )
16391640 elif self .scheduler_config .splitwise_role == "prefill" :
16401641 self .model_config .moe_phase = MoEPhase (phase = "prefill" )
16411642 elif self .scheduler_config .splitwise_role == "decode" :
1643+ # Sequence parallel MoE is incompatible with CUDA graph now. It will hang.
16421644 if self .graph_opt_config .use_cudagraph :
16431645 self .parallel_config .use_sequence_parallel_moe = False
16441646 self .model_config .moe_phase = MoEPhase (phase = "decode" )
Original file line number Diff line number Diff line change @@ -509,7 +509,7 @@ def __post_init__(self):
509509 )
510510 if len (self .rdma_comm_ports ) != self .tensor_parallel_size * self .data_parallel_size :
511511 raise ValueError (
512- f"The number of rdma comm ports must be equal to number of ranks(dp_size: {self.data_parallel_size} * tp_size: {self.tensor_parallel_size})."
512+ f"The number of rdma comm ports must be equal to number of ranks ({self.data_parallel_size=} * {self.tensor_parallel_size=} = {self.data_parallel_size * self.tensor_parallel_size}), but got {len(self.rdma_comm_ports)}."
513513 )
514514
515515 if envs .ENABLE_V1_KVCACHE_SCHEDULER == 1 :
Original file line number Diff line number Diff line change @@ -570,10 +570,11 @@ def __init__(self, fd_config: FDConfig):
570570 self .ernie = Ernie4_5_VLModel (fd_config = fd_config )
571571
572572 # Persistent buffers for CUDA graphs.
573- self ._decoder_input_embeddings = paddle .zeros (
574- [fd_config .graph_opt_config .max_capture_size , fd_config .model_config .hidden_size ],
575- dtype = fd_config .model_config .dtype ,
576- )
573+ if fd_config .graph_opt_config .use_cudagraph :
574+ self ._decoder_input_embeddings = paddle .zeros (
575+ [fd_config .graph_opt_config .max_capture_size , fd_config .model_config .hidden_size ],
576+ dtype = fd_config .model_config .dtype ,
577+ )
577578
578579 self .ori_vocab_size = fd_config .model_config .ori_vocab_size
579580
Original file line number Diff line number Diff line change @@ -59,10 +59,11 @@ def __init__(self, fd_config: FDConfig):
5959 self .head_dtype = paddle .bfloat16
6060
6161 # Persistent buffers for CUDA graphs.
62- self ._decoder_input_embeddings = paddle .zeros (
63- [fd_config .graph_opt_config .max_capture_size , fd_config .model_config .hidden_size ],
64- dtype = fd_config .model_config .dtype ,
65- )
62+ if fd_config .graph_opt_config .use_cudagraph :
63+ self ._decoder_input_embeddings = paddle .zeros (
64+ [fd_config .graph_opt_config .max_capture_size , fd_config .model_config .hidden_size ],
65+ dtype = fd_config .model_config .dtype ,
66+ )
6667
6768 self .rm_head = nn .Sequential (
6869 (
Original file line number Diff line number Diff line change @@ -132,10 +132,11 @@ def __init__(self, fd_config):
132132 )
133133
134134 # Persistent buffers for CUDA graphs.
135- self ._decoder_input_embeddings = paddle .zeros (
136- [fd_config .graph_opt_config .max_capture_size , fd_config .model_config .hidden_size ],
137- dtype = fd_config .model_config .dtype ,
138- )
135+ if fd_config .graph_opt_config .use_cudagraph :
136+ self ._decoder_input_embeddings = paddle .zeros (
137+ [fd_config .graph_opt_config .max_capture_size , fd_config .model_config .hidden_size ],
138+ dtype = fd_config .model_config .dtype ,
139+ )
139140
140141 @paddle .no_grad ()
141142 def load_weights (self , weights_iterator ) -> None :
Original file line number Diff line number Diff line change @@ -152,10 +152,11 @@ def __init__(self, fd_config: FDConfig):
152152 self .model = Qwen2_5_VLModel (fd_config = fd_config )
153153
154154 # Persistent buffers for CUDA graphs.
155- self ._decoder_input_embeddings = paddle .zeros (
156- [fd_config .graph_opt_config .max_capture_size , fd_config .model_config .hidden_size ],
157- dtype = fd_config .model_config .dtype ,
158- )
155+ if fd_config .graph_opt_config .use_cudagraph :
156+ self ._decoder_input_embeddings = paddle .zeros (
157+ [fd_config .graph_opt_config .max_capture_size , fd_config .model_config .hidden_size ],
158+ dtype = fd_config .model_config .dtype ,
159+ )
159160
160161 self .ori_vocab_size = fd_config .model_config .ori_vocab_size
161162
You can’t perform that action at this time.
0 commit comments