@@ -900,6 +900,12 @@ def _set_cudagraph_sizes(self, max_capture_size: int = 0):
         draft_capture_sizes.append(max_capture_size)
         self.cudagraph_capture_sizes = sorted(draft_capture_sizes)

+    def filter_capture_size(self, tp_size: int = 1):
+        """When TSP is used, capture size must be divisible by tp size."""
+        self.cudagraph_capture_sizes = [
+            draft_size for draft_size in self.cudagraph_capture_sizes if (draft_size % tp_size == 0)
+        ]
+
     def to_json_string(self):
         """
         Convert speculative_config to json string.
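The new `filter_capture_size` drops every CUDA graph capture size that is not a multiple of the tensor-parallel degree, so captured batches can be split evenly across ranks when sequence-parallel MoE is active. Below is a minimal, self-contained sketch of the same filtering behavior; the `GraphOptConfig` stub is hypothetical, and only the list-comprehension filter mirrors the diff:

```python
class GraphOptConfig:
    """Hypothetical stand-in for the real config object; only the capture-size filter is modeled."""

    def __init__(self, cudagraph_capture_sizes):
        self.cudagraph_capture_sizes = sorted(cudagraph_capture_sizes)

    def filter_capture_size(self, tp_size: int = 1):
        # Keep only capture sizes that divide evenly across tensor-parallel ranks.
        self.cudagraph_capture_sizes = [
            size for size in self.cudagraph_capture_sizes if size % tp_size == 0
        ]


cfg = GraphOptConfig([1, 2, 4, 8, 16, 32])
cfg.filter_capture_size(tp_size=4)
print(cfg.cudagraph_capture_sizes)  # [4, 8, 16, 32]
```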
@@ -1617,6 +1623,8 @@ def postprocess(self):
             self.cache_config.max_encoder_cache = 0

         # Adjustment GraphOptConfig
+        if self.parallel_config.use_sequence_parallel_moe:
+            self.graph_opt_config.filter_capture_size(tp_size=self.parallel_config.tensor_parallel_size)
         if self.scheduler_config is not None and self.scheduler_config.splitwise_role == "prefill":
             self.graph_opt_config.use_cudagraph = self.graph_opt_config.cudagraph_only_prefill
         if self.load_config is not None and self.load_config.dynamic_load_weight is True:
@@ -1633,16 +1641,10 @@ def postprocess(self):
             logger.info("Multi-modal models do not support prefix caching when using CUDAGraph!")

         if self.scheduler_config.splitwise_role == "mixed":
-            # Sequence parallel MoE is incompatible with CUDA graph now. It will hang.
-            if self.graph_opt_config.use_cudagraph:
-                self.parallel_config.use_sequence_parallel_moe = False
             self.model_config.moe_phase = MoEPhase(phase="prefill")
         elif self.scheduler_config.splitwise_role == "prefill":
             self.model_config.moe_phase = MoEPhase(phase="prefill")
         elif self.scheduler_config.splitwise_role == "decode":
-            # Sequence parallel MoE is incompatible with CUDA graph now. It will hang.
-            if self.graph_opt_config.use_cudagraph:
-                self.parallel_config.use_sequence_parallel_moe = False
             self.model_config.moe_phase = MoEPhase(phase="decode")
         else:
             raise NotImplementedError
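With this change, `postprocess` no longer disables `use_sequence_parallel_moe` inside the role-specific branches when CUDA graphs are enabled; the interaction is handled once, up front, by pruning the capture sizes. A condensed, hypothetical view of the resulting control flow (attribute and method names beyond those shown in the diff are assumptions, not the full method):

```python
def postprocess(self):
    # Condensed sketch of the new ordering, not the full FDConfig.postprocess.
    # 1) Handle sequence-parallel MoE once: shrink capture sizes to multiples of TP.
    if self.parallel_config.use_sequence_parallel_moe:
        self.graph_opt_config.filter_capture_size(tp_size=self.parallel_config.tensor_parallel_size)

    # 2) The role branches no longer toggle use_sequence_parallel_moe;
    #    they only pick the MoE phase.
    role = self.scheduler_config.splitwise_role
    if role in ("mixed", "prefill"):
        self.model_config.moe_phase = MoEPhase(phase="prefill")
    elif role == "decode":
        self.model_config.moe_phase = MoEPhase(phase="decode")
    else:
        raise NotImplementedError
```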