diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index a89a8a9246e..5f62b3f2ea3 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -1301,7 +1301,6 @@ def triton_mrope( return q, k -@torch._dynamo.disable() def triton_mrope_wrapper( query, key, @@ -1428,15 +1427,18 @@ def _forward_native( dim=-1, ) + seq_len_q = query.shape[0] query_shape = query.shape - query = query.view(num_tokens, -1, self.head_size) + query = query.view(seq_len_q, -1, self.head_size) + query_rot = query[..., : self.rotary_dim] query_pass = query[..., self.rotary_dim :] query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + seq_len_k = key.shape[0] key_shape = key.shape - key = key.view(num_tokens, -1, self.head_size) + key = key.view(seq_len_k, -1, self.head_size) key_rot = key[..., : self.rotary_dim] key_pass = key[..., self.rotary_dim :] key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) @@ -1467,7 +1469,6 @@ def forward( else: return self._forward_native(positions, query, key) - @torch.compile(dynamic=True, backend=get_compiler_backend()) def _forward_triton( self, positions: torch.Tensor, @@ -1502,7 +1503,9 @@ def _forward_triton( return q.reshape(query_shape), k.reshape(key_shape) - query = query.view(num_tokens, -1, self.head_size) + seq_len_q = query.shape[0] + query = query.view(seq_len_q, -1, self.head_size) + query_rot = query[..., : self.rotary_dim] query_pass = query[..., self.rotary_dim :] query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py index 98b9371ab17..2c75a16f4d3 100644 --- a/python/sglang/srt/managers/mm_utils.py +++ b/python/sglang/srt/managers/mm_utils.py @@ -494,6 +494,7 @@ def get_embedding_and_mask( def embed_mm_inputs( + forward_batch: ForwardBatch, mm_inputs_list: List[MultimodalInputs], extend_prefix_lens: List[int], extend_seq_lens: List[int], @@ -503,8 +504,6 @@ def embed_mm_inputs( data_embedding_func_mapping: Dict[ Modality, Callable[[List[MultimodalDataItem]], torch.Tensor] ] = None, - placeholder_tokens: dict[Modality, List[int]] = None, - use_deepstack: Dict[Modality, bool] = {}, ) -> Optional[torch.Tensor]: """ Embed multimodal inputs and integrate them with text token embeddings. @@ -515,12 +514,10 @@ def embed_mm_inputs( extend_seq_lens: Sequence lengths for each request input_ids: Input token IDs tensor input_embedding: Embedding layer for text tokens - placeholder_tokens: Token IDs for multimodal placeholders (uses pad_values if None) Returns: Combined embedding tensor with multimodal content integrated """ - other_info = {} if mm_inputs_list is None: return None @@ -530,8 +527,7 @@ def embed_mm_inputs( for mm_inputs in mm_inputs_list: item_flatten_list += [item for item in mm_inputs.mm_items if item is not None] - # deepstack_embeddings: per-modality - modalities, embeddings, masks, deepstack_embeddings = [], [], [], [] + modalities, embeddings, masks = [], [], [] # 2. 
Get multimodal embedding separately # Try get mm embedding if any @@ -580,11 +576,6 @@ def embed_mm_inputs( items_offset_list=items_offsets, ) - if use_deepstack.get(modality, None) and embedding is not None: - embedding, deepstack_embedding = ( - multimodal_model.separate_deepstack_embeds(embedding) - ) - deepstack_embeddings += [deepstack_embedding] modalities += [modality] embeddings += [embedding] masks += [mask] @@ -598,37 +589,29 @@ def embed_mm_inputs( input_ids.clamp_(min=0, max=vocab_size - 1) inputs_embeds = input_embedding(input_ids) - # deepstack embedding - if use_deepstack: - num_deepstack_embeddings = len(multimodal_model.deepstack_visual_indexes) + indices = [] + for mask in masks: + if mask is not None: + indices.append(torch.where(mask.squeeze(dim=-1))[0]) + else: + indices.append(None) - deepstack_embedding_shape = inputs_embeds.shape[:-1] + ( - inputs_embeds.shape[-1] * num_deepstack_embeddings, + # only for qwen3vl right now, replace the original use_deepstack with this method. + if hasattr(multimodal_model, "post_process"): + embeddings, forward_batch = multimodal_model.post_process( + inputs_embeds, modalities, embeddings, indices, forward_batch ) - # a zero-filled embedding, with the same length of inputs_embeds, but different hidden_size - input_deepstack_embeds = torch.zeros( - deepstack_embedding_shape, - device=inputs_embeds.device, - dtype=inputs_embeds.dtype, - ) - - other_info["input_deepstack_embeds"] = input_deepstack_embeds # 4. scatter embeddings into input embedding - for i, modality, embedding, mask in zip( - range(len(embeddings)), modalities, embeddings, masks + for i, modality, embedding, index in zip( + range(len(embeddings)), modalities, embeddings, indices ): - if embedding is None or mask is None: + if embedding is None or index is None: continue # in-place update - indices = torch.where(mask.squeeze(dim=-1))[0] - inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype) - if use_deepstack.get(modality, None): - input_deepstack_embeds[indices] = deepstack_embeddings[i].to( - inputs_embeds.device, inputs_embeds.dtype - ) + inputs_embeds[index] = embedding.to(inputs_embeds.device, inputs_embeds.dtype) - return inputs_embeds, other_info + return inputs_embeds, forward_batch def general_mm_embed_routine( @@ -639,8 +622,7 @@ def general_mm_embed_routine( data_embedding_funcs: Dict[ Modality, Callable[[List[MultimodalDataItem]], torch.Tensor] ] = None, - placeholder_tokens: Optional[dict[Modality, List[int]]] = None, - use_deepstack: Dict[Modality, bool] = {}, + skip_llm_forward: bool = False, **kwargs, ) -> torch.Tensor: """ @@ -651,8 +633,6 @@ def general_mm_embed_routine( forward_batch: Batch information for model forward pass language_model: Base language model to use data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function. 
- placeholder_tokens: Token IDs for multimodal placeholders - use_deepstack: Whether to use deepstack embeddings for each modality, default False **kwargs: Additional arguments passed to language model Returns: @@ -679,7 +659,9 @@ def general_mm_embed_routine( for i, seq_len in enumerate(forward_batch.extend_seq_lens_cpu) if forward_batch.mm_inputs[i] is not None ] - inputs_embeds, other_info = embed_mm_inputs( + + inputs_embeds, forward_batch = embed_mm_inputs( + forward_batch=forward_batch, mm_inputs_list=mm_inputs_list, extend_prefix_lens=extend_prefix_lens, extend_seq_lens=extend_seq_lens, @@ -687,12 +669,7 @@ def general_mm_embed_routine( multimodal_model=multimodal_model, input_embedding=embed_tokens, data_embedding_func_mapping=data_embedding_funcs, - placeholder_tokens=placeholder_tokens, - use_deepstack=use_deepstack, ) - # add for qwen3_vl deepstack - if use_deepstack: - kwargs["input_deepstack_embeds"] = other_info["input_deepstack_embeds"] # once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models # just being defensive here forward_batch.mm_inputs = None @@ -701,6 +678,9 @@ def general_mm_embed_routine( else: inputs_embeds = None + if skip_llm_forward: + return inputs_embeds + hidden_states = language_model( input_ids=None, forward_batch=forward_batch, @@ -816,3 +796,144 @@ def hash_feature(f): reconstruct_t = f.reconstruct_on_target_device(torch.cuda.current_device()) return tensor_hash([reconstruct_t]) return data_hash(f) + + +def resolve_language_model(multimodal_model: nn.Module) -> Optional[nn.Module]: + # Qwen2-VL / Qwen3-VL Style + if hasattr(multimodal_model, "model"): + lm = getattr(multimodal_model, "model") + if hasattr(lm, "get_input_embeddings"): + return lm + + # Llava / OneVision Style + if hasattr(multimodal_model, "language_model"): + lm = getattr(multimodal_model, "language_model") + if hasattr(lm, "get_input_embeddings"): + return lm + + if hasattr(multimodal_model, "get_input_embeddings"): + return multimodal_model + + return None + + +def should_use_external_mm_preprocess(multimodal_model: nn.Module) -> bool: + """Decide whether we should use our generic "multimodal_preprocess_routine". + + We are adapting VLM for piecewise CUDA graph. Since the encoder's forward + pass cannot be executed within the model's forward pass, we need to + precompute image embeddings using the encoder within the model runner. + For models that have already been adjusted, there is a member called + should_use_external_mm_preprocess, which is set to True. In practice, + the multimodal_preprocess_routine function will be called in the + model_runner.forward_extend to handle multimodal inputs. + + For models that have not yet been adapted, the general_mm_embed_routine + will still be called in the model class's forward function for processing. + + Current strategy: + - Llava family (models with vision_tower + multi_modal_projector): + Their forward already calls general_mm_embed_routine and includes + built-in multimodal processing. If we run it again in ModelRunner, + it will conflict with the internal logic, so we skip it here. + - Others (such as Qwen2-VL / Qwen2.5-VL / Qwen3-VL): use the + multimodal preprocessing. 
+ """ + + cls_name = multimodal_model.__class__.__name__ + + qwen_vl_classes = { + "Qwen2VLForConditionalGeneration", + "Qwen2_5_VLForConditionalGeneration", + "Qwen3VLForConditionalGeneration", + "Qwen3VLMoeForConditionalGeneration", + "Qwen3OmniMoeForConditionalGeneration", + } + + return cls_name in qwen_vl_classes + + +def multimodal_preprocess_routine( + forward_batch: ForwardBatch, + multimodal_model: Optional[nn.Module] = None, + data_embedding_funcs: Dict[ + Modality, Callable[[List[MultimodalDataItem]], torch.Tensor] + ] = None, +) -> torch.Tensor: + """ + Process multimodal inputs and forward through language model. + Args: + input_ids: Input token IDs tensor + forward_batch: Batch information for model forward pass + data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function. + **kwargs: Additional arguments passed to language model + Returns: + Hidden states from language model forward pass + """ + if not should_use_external_mm_preprocess(multimodal_model): + return forward_batch + + language_model = resolve_language_model(multimodal_model) + if language_model is None: + raise ValueError( + f"Cannot resolve language model from {type(multimodal_model).__name__}. " + f"Please ensure the model has 'model' or 'language_model' attribute." + ) + + assert hasattr(language_model, "get_input_embeddings") + embed_tokens = language_model.get_input_embeddings() + if not hasattr(language_model, "pp_group") or language_model.pp_group.is_first_rank: + + input_ids = forward_batch.input_ids + if ( + not forward_batch.forward_mode.is_decode() + and not forward_batch.forward_mode.is_target_verify() + and forward_batch.contains_mm_inputs() + ): + mm_inputs_list = [ + mm_input for mm_input in forward_batch.mm_inputs if mm_input is not None + ] + extend_prefix_lens = [ + prefix_len + for i, prefix_len in enumerate(forward_batch.extend_prefix_lens_cpu) + if forward_batch.mm_inputs[i] is not None + ] + extend_seq_lens = [ + seq_len + for i, seq_len in enumerate(forward_batch.extend_seq_lens_cpu) + if forward_batch.mm_inputs[i] is not None + ] + input_embeds, forward_batch = embed_mm_inputs( + forward_batch=forward_batch, + mm_inputs_list=mm_inputs_list, + extend_prefix_lens=extend_prefix_lens, + extend_seq_lens=extend_seq_lens, + input_ids=forward_batch.input_ids, + multimodal_model=multimodal_model, + input_embedding=embed_tokens, + data_embedding_func_mapping=data_embedding_funcs, + ) + # once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models + # just being defensive here + forward_batch.mm_inputs = None + else: + # NOTE: This may reduce the performance for only-text inputs. + # Using a fixed-address buffer might be better, though it could be a bit dirty. 
+ input_embeds = embed_tokens(input_ids) + # only for qwen3vl + if getattr(multimodal_model, "use_deepstack", False): + forward_batch.input_deepstack_embeds = torch.zeros( + ( + len(input_ids), + multimodal_model.config.hidden_size + * len(multimodal_model.deepstack_visual_indexes), + ), + device=input_embeds.device, + dtype=input_embeds.dtype, + ) + + forward_batch.input_embeds = input_embeds + else: + forward_batch.input_embeds = None + + return forward_batch diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 2c3da398f3e..a95a4400c29 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -318,6 +318,10 @@ def __init__(self, model_runner: ModelRunner): # Graph inputs with torch.device(self.device): self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64) + self.input_embeds = torch.zeros( + (self.max_num_token, self.model_runner.model_config.hidden_size), + dtype=self.model_runner.model_config.dtype, + ) self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32) self.seq_lens = torch.full( (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32 diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index a4f2e7025c0..95766138b24 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -343,6 +343,9 @@ class ForwardBatch: # For Qwen2-VL mrope_positions: torch.Tensor = None + # For Qwen3-VL + input_deepstack_embeds: Optional[torch.Tensor] = None + # For two-batch overlap tbo_split_seq_index: Optional[int] = None tbo_parent_token_range: Optional[Tuple[int, int]] = None diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 1dfc36e9227..88c85ccc95b 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -92,6 +92,7 @@ from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model from sglang.srt.lora.lora_manager import LoRAManager from sglang.srt.lora.lora_registry import LoRARef +from sglang.srt.managers.mm_utils import multimodal_preprocess_routine from sglang.srt.mem_cache.allocator import ( BaseTokenToKVPoolAllocator, PagedTokenToKVPoolAllocator, @@ -2138,6 +2139,13 @@ def forward_extend( skip_attn_backend_init: bool = False, pp_proxy_tensors=None, ) -> Union[LogitsProcessorOutput, PPProxyTensors]: + + if self.is_multimodal: + forward_batch = multimodal_preprocess_routine( + forward_batch=forward_batch, + multimodal_model=self.model, + ) + kwargs = {} if self.support_pp: kwargs["pp_proxy_tensors"] = pp_proxy_tensors @@ -2274,6 +2282,7 @@ def _forward_raw( skip_attn_backend_init=skip_attn_backend_init, pp_proxy_tensors=pp_proxy_tensors, ) + forward_batch.input_deepstack_embeds = None elif forward_batch.forward_mode.is_idle(): ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors) else: diff --git a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py index 2322ebb22f3..c1f1aa807ae 100644 --- a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py @@ -166,9 +166,18 @@ def __init__(self, model_runner: ModelRunner): self.max_num_tokens = 
max(self.capture_num_tokens) + self.use_input_embeds = model_runner.is_multimodal + + # The following is just for qwen3vl, maybe not ideal to place it here. + self.use_deepstack = getattr(model_runner.model, "use_deepstack", False) + # Graph inputs with torch.device(self.device): self.input_ids = torch.zeros((self.max_num_tokens,), dtype=torch.int64) + self.input_embeds = torch.zeros( + (self.max_num_tokens, self.model_runner.model_config.hidden_size), + dtype=self.model_runner.dtype, + ) self.out_cache_loc = torch.zeros( (self.max_num_tokens,), dtype=self._cache_loc_dtype() ) @@ -176,6 +185,18 @@ def __init__(self, model_runner: ModelRunner): (self.max_num_tokens,), dtype=self._cache_loc_dtype() ) self.positions = torch.zeros((self.max_num_tokens,), dtype=torch.int64) + self.mrope_positions = torch.zeros( + (3, self.max_num_tokens), dtype=torch.int64 + ) + if self.use_deepstack: + self.input_deepstack_embeds = torch.zeros( + ( + self.max_num_tokens, + self.model_runner.model_config.hidden_size + * len(self.model_runner.model.deepstack_visual_indexes), + ), + dtype=self.model_runner.dtype, + ) self.tbo_plugin = TboCudaGraphRunnerPlugin() self.attention_layers = self.model_runner.attention_layers @@ -216,7 +237,21 @@ def warmup_and_capture(self): forward_batch = ForwardBatch( forward_mode=ForwardMode.EXTEND, batch_size=1, - input_ids=torch.randint(0, 100, (num_tokens,), device=self.device), + input_ids=( + torch.randint(0, 100, (num_tokens,), device=self.device) + if not self.use_input_embeds + else None + ), + input_embeds=( + torch.randn( + num_tokens, + self.model_runner.model_config.hidden_size, + dtype=self.model_runner.dtype, + device=self.device, + ) + if self.use_input_embeds + else None + ), req_pool_indices=torch.arange(1, device=self.device), seq_lens=torch.tensor([num_tokens], device=self.device), next_token_logits_buffer=None, @@ -246,13 +281,25 @@ def warmup_and_capture(self): global_num_tokens_for_logprob_gpu=None, dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(), global_dp_buffer_len=None, - mrope_positions=None, + mrope_positions=self.mrope_positions[:, :num_tokens], spec_algorithm=None, spec_info=None, capture_hidden_mode=CaptureHiddenMode.NULL, num_token_non_padded=None, global_forward_mode=ForwardMode.EXTEND, lora_ids=None, + input_deepstack_embeds=( + torch.zeros( + ( + num_tokens, + self.model_runner.model_config.hidden_size + * len(self.model_runner.model.deepstack_visual_indexes), + ), + dtype=self.model_runner.dtype, + ) + if self.use_deepstack + else None + ), ) # Attention backend @@ -326,10 +373,21 @@ def capture_one_batch_size(self, num_tokens: int): bs = 1 # Graph inputs - input_ids = self.input_ids[:num_tokens] + if self.use_input_embeds: + input_ids = None + input_embeds = self.input_embeds[:num_tokens] + else: + input_ids = self.input_ids[:num_tokens] + input_embeds = None + + input_deepstack_embeds = None + if self.use_deepstack: + input_deepstack_embeds = self.input_deepstack_embeds[:num_tokens] + out_cache_loc = self.out_cache_loc[:num_tokens] out_cache_loc_swa = self.out_cache_loc_swa[:num_tokens] positions = self.positions[:num_tokens] + mrope_positions = self.mrope_positions[:, :num_tokens] # pipeline parallelism if self.pp_size > 1: @@ -351,6 +409,7 @@ def capture_one_batch_size(self, num_tokens: int): forward_mode=ForwardMode.EXTEND, batch_size=bs, input_ids=input_ids, + input_embeds=input_embeds, req_pool_indices=torch.arange(bs, device=self.device), seq_lens=torch.tensor([num_tokens], device=self.device), 
next_token_logits_buffer=None, @@ -376,13 +435,14 @@ def capture_one_batch_size(self, num_tokens: int): global_num_tokens_for_logprob_gpu=None, dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(), global_dp_buffer_len=None, - mrope_positions=None, + mrope_positions=mrope_positions, spec_algorithm=None, spec_info=None, capture_hidden_mode=CaptureHiddenMode.NULL, num_token_non_padded=None, global_forward_mode=ForwardMode.EXTEND, lora_ids=None, + input_deepstack_embeds=input_deepstack_embeds, ) self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens) @@ -428,7 +488,14 @@ def replay_prepare( forward_batch: ForwardBatch, **kwargs, ): - num_tokens = len(forward_batch.input_ids) + if self.use_input_embeds: + num_tokens = forward_batch.input_embeds.shape[0] + else: + num_tokens = len(forward_batch.input_ids) + + if self.use_deepstack: + self.input_deepstack_embeds.zero_() # may be removed. + index = bisect.bisect_left(self.capture_num_tokens, num_tokens) static_num_tokens = self.capture_num_tokens[index] self.raw_num_tokens = num_tokens @@ -437,7 +504,11 @@ def replay_prepare( self.out_cache_loc_swa.zero_() bs = forward_batch.batch_size - self.input_ids[:num_tokens].copy_(forward_batch.input_ids) + if self.use_input_embeds: + self.input_embeds[:num_tokens].copy_(forward_batch.input_embeds) + else: + self.input_ids[:num_tokens].copy_(forward_batch.input_ids) + self.positions[:num_tokens].copy_(forward_batch.positions) self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc) if forward_batch.out_cache_loc_swa is not None: @@ -452,13 +523,39 @@ def replay_prepare( else None ) + if forward_batch.mrope_positions is not None: + self.mrope_positions[:, :num_tokens].copy_(forward_batch.mrope_positions) + + if self.use_input_embeds: + input_ids = None + input_embeds = self.input_embeds[:static_num_tokens] + else: + input_ids = self.input_ids[:static_num_tokens] + input_embeds = None + + positions = self.positions[:static_num_tokens] + out_cache_loc = self.out_cache_loc[:static_num_tokens] + + mrope_positions = ( + self.mrope_positions[:, :static_num_tokens] + if forward_batch.mrope_positions is not None + else None + ) + next_token_logits_buffer = None - mrope_positions = None + + input_deepstack_embeds = None + if self.use_deepstack: + self.input_deepstack_embeds[:num_tokens].copy_( + forward_batch.input_deepstack_embeds + ) + input_deepstack_embeds = self.input_deepstack_embeds[:static_num_tokens] static_forward_batch = ForwardBatch( forward_mode=forward_batch.forward_mode, batch_size=bs, input_ids=input_ids, + input_embeds=input_embeds, req_pool_indices=forward_batch.req_pool_indices, seq_lens=forward_batch.seq_lens, next_token_logits_buffer=next_token_logits_buffer, @@ -498,6 +595,7 @@ def replay_prepare( temperature=forward_batch.temperature, top_p_normalized_logprobs=forward_batch.top_p_normalized_logprobs, top_p=forward_batch.top_p, + input_deepstack_embeds=input_deepstack_embeds, ) return static_forward_batch diff --git a/python/sglang/srt/models/qwen2_5_vl.py b/python/sglang/srt/models/qwen2_5_vl.py index 63c0a42f2a4..6b39825cb52 100644 --- a/python/sglang/srt/models/qwen2_5_vl.py +++ b/python/sglang/srt/models/qwen2_5_vl.py @@ -57,11 +57,12 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.utils import PPMissingLayer, get_layer_id from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead -from sglang.srt.managers.mm_utils import ( - MultiModalityDataPaddingPatternMultimodalTokens, - 
general_mm_embed_routine, +from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens +from sglang.srt.managers.schedule_batch import ( + Modality, + MultimodalDataItem, + MultimodalInputs, ) -from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.qwen2 import Qwen2Model @@ -566,6 +567,25 @@ def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw) return video_embeds + def post_process( + self, + inputs_embeds, + modalities: List[Modality], + embeddings: List[torch.Tensor], + indices: List[torch.Tensor], + forward_batch: ForwardBatch, + ) -> torch.Tensor: + # Placeholder for post_process + new_embeddings = [] + for i, (modality, embedding, index) in enumerate( + zip(modalities, embeddings, indices) + ): + if embedding is None or index is None: + continue + + new_embeddings.append(embedding) + return new_embeddings, forward_batch + def get_input_embeddings(self): return self.model.embed_tokens @@ -575,6 +595,7 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, forward_batch: ForwardBatch, + input_embeds=None, get_embedding: bool = False, pp_proxy_tensors: Optional[PPProxyTensors] = None, ): @@ -603,11 +624,21 @@ def forward( f"(3, seq_len) positions, but got {positions.size()}" ) - hidden_states = general_mm_embed_routine( + input_embeds = forward_batch.input_embeds + # It may seem strange to assign input_embeds again even after passing it as an argument. + # This is for compatibility considerations. + # In the 'extend' scenario, this forward function is called from two places: + # 1. model_runner calls forward directly, + # 2. piece_wise_cuda_graph_runner calls forward and replay. + + # Currently, + # In 'extend', input_embeds is passed in. + # In 'decode', input_ids is passed in. + + hidden_states = self.model( input_ids=input_ids, forward_batch=forward_batch, - language_model=self.model, - multimodal_model=self, + input_embeds=input_embeds, positions=positions, pp_proxy_tensors=pp_proxy_tensors, ) diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py index 88d6d5bc9c2..84c15ee776b 100644 --- a/python/sglang/srt/models/qwen2_vl.py +++ b/python/sglang/srt/models/qwen2_vl.py @@ -39,10 +39,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead -from sglang.srt.managers.mm_utils import ( - MultiModalityDataPaddingPatternMultimodalTokens, - general_mm_embed_routine, -) +from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -509,6 +506,7 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, forward_batch: ForwardBatch, + input_embeds=None, get_embedding: bool = False, ): """Run forward pass for Qwen2-VL. 
@@ -535,11 +533,22 @@ def forward( "multimodal section rotary embedding requires " f"(3, seq_len) positions, but got {positions.size()}" ) - hidden_states = general_mm_embed_routine( + + input_embeds = forward_batch.input_embeds + # It may seem strange to assign input_embeds again even after passing it as an argument. + # This is for compatibility considerations. + # In the 'extend' scenario, this forward function is called from two places: + # 1. model_runner calls forward directly, + # 2. piece_wise_cuda_graph_runner calls forward and replay. + + # Currently, + # In 'extend', input_embeds is passed in. + # In 'decode', input_ids is passed in. + + hidden_states = self.model( input_ids=input_ids, forward_batch=forward_batch, - language_model=self.model, - multimodal_model=self, + input_embeds=input_embeds, positions=positions, ) diff --git a/python/sglang/srt/models/qwen3_omni_moe.py b/python/sglang/srt/models/qwen3_omni_moe.py index 8663e5ac5a0..c4f1d564cf6 100644 --- a/python/sglang/srt/models/qwen3_omni_moe.py +++ b/python/sglang/srt/models/qwen3_omni_moe.py @@ -481,6 +481,14 @@ def __init__( self.enable_talker = False self.pad_input_ids = self.thinker.pad_input_ids self.forward = self.thinker.forward + self.get_audio_feature = self.thinker.get_audio_feature + self.get_image_feature = self.thinker.get_image_feature + self.get_video_feature = self.thinker.get_video_feature + self.get_input_embeddings = self.thinker.get_input_embeddings + self.post_process = self.thinker.post_process + + def get_input_embeddings(self): + return self.thinker.model.embed_tokens def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py index c4d9456bc9e..71ee491f325 100644 --- a/python/sglang/srt/models/qwen3_vl.py +++ b/python/sglang/srt/models/qwen3_vl.py @@ -34,10 +34,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead -from sglang.srt.managers.mm_utils import ( - MultiModalityDataPaddingPatternMultimodalTokens, - general_mm_embed_routine, -) +from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens from sglang.srt.managers.schedule_batch import ( Modality, MultimodalDataItem, @@ -667,6 +664,48 @@ def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw) return video_embeds + def post_process( + self, + inputs_embeds, + modalities: List[Modality], + embeddings: List[torch.Tensor], + indices: List[torch.Tensor], + forward_batch: ForwardBatch, + ) -> torch.Tensor: + if not self.use_deepstack: + return embeddings, forward_batch + deepstack_embeddings = [] + new_embeddings = [] + + num_deepstack_embeddings = len(self.deepstack_visual_indexes) + deepstack_embedding_shape = inputs_embeds.shape[:-1] + ( + inputs_embeds.shape[-1] * num_deepstack_embeddings, + ) + input_deepstack_embeds = torch.zeros( + deepstack_embedding_shape, + device=inputs_embeds.device, + dtype=inputs_embeds.dtype, + ) + + for i, (modality, embedding, index) in enumerate( + zip(modalities, embeddings, indices) + ): + if embedding is None or index is None: + continue + if self.use_deepstack.get(modality, False): + embedding, deepstack_embedding = self.separate_deepstack_embeds( + embedding + ) + if index is not None: + input_deepstack_embeds[index] = 
deepstack_embedding.to( + inputs_embeds.device, inputs_embeds.dtype + ) + + new_embeddings.append(embedding) + + forward_batch.input_deepstack_embeds = input_deepstack_embeds + return new_embeddings, forward_batch + def get_input_embeddings(self): return self.model.embed_tokens @@ -682,6 +721,7 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, forward_batch: ForwardBatch, + input_embeds=None, get_embedding: bool = False, ): """Run forward pass for Qwen3-VL. @@ -709,13 +749,23 @@ def forward( f"(3, seq_len) positions, but got {positions.size()}" ) - hidden_states = general_mm_embed_routine( + input_embeds = forward_batch.input_embeds + # It may seem strange to assign input_embeds again even after passing it as an argument. + # This is for compatibility considerations. + # In the 'extend' scenario, this forward function is called from two places: + # 1. model_runner calls forward directly, + # 2. piece_wise_cuda_graph_runner calls forward and replay. + + # Currently, + # In 'extend', input_embeds is passed in. + # In 'decode', input_ids is passed in. + + hidden_states = self.model( input_ids=input_ids, forward_batch=forward_batch, - language_model=self.model, - multimodal_model=self, + input_embeds=input_embeds, positions=positions, - use_deepstack=self.use_deepstack, + input_deepstack_embeds=forward_batch.input_deepstack_embeds, ) if not get_embedding: diff --git a/test/srt/nightly/test_vlms_piecewise_cuda_graph.py b/test/srt/nightly/test_vlms_piecewise_cuda_graph.py new file mode 100644 index 00000000000..0001b917a4c --- /dev/null +++ b/test/srt/nightly/test_vlms_piecewise_cuda_graph.py @@ -0,0 +1,266 @@ +import argparse +import glob +import json +import os +import random +import subprocess +import sys +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + is_in_ci, + popen_launch_server, +) + +MODELS = [ + SimpleNamespace(model="Qwen/Qwen2.5-VL-7B-Instruct", mmmu_accuracy=0.60), +] + + +# Set default mem_fraction_static to 0.8 +DEFAULT_MEM_FRACTION_STATIC = 0.8 + + +class TestVLMPiecewiseCudaGraph(CustomTestCase): + parsed_args = None # Class variable to store args + + @classmethod + def setUpClass(cls): + # Removed argument parsing from here + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + cls.time_out = DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH + + if cls.parsed_args is None: + cls.parsed_args = SimpleNamespace( + mem_fraction_static=DEFAULT_MEM_FRACTION_STATIC + ) + + # Set OpenAI API key and base URL environment variables. Needed for lmm-evals to work. + os.environ["OPENAI_API_KEY"] = cls.api_key + os.environ["OPENAI_API_BASE"] = f"{cls.base_url}/v1" + + def run_mmmu_eval( + self, + model_version: str, + output_path: str, + *, + env: dict | None = None, + ): + """ + Evaluate a VLM on the MMMU validation set with lmms‑eval. + Only `model_version` (checkpoint) and `chat_template` vary; + We are focusing only on the validation set due to resource constraints. 
+ """ + # -------- fixed settings -------- + model = "openai_compatible" + tp = 1 + tasks = "mmmu_val" + batch_size = 32 + log_suffix = "openai_compatible" + os.makedirs(output_path, exist_ok=True) + + # -------- compose --model_args -------- + model_args = f'model_version="{model_version}",' f"tp={tp}" + + # -------- build command list -------- + cmd = [ + "python3", + "-m", + "lmms_eval", + "--model", + model, + "--model_args", + model_args, + "--tasks", + tasks, + "--batch_size", + str(batch_size), + "--output_path", + str(output_path), + ] + + subprocess.run( + cmd, + check=True, + timeout=3600, + ) + + def _run_vlm_mmmu_test( + self, + model, + output_path, + test_name="", + custom_env=None, + log_level="info", + capture_output=False, + ): + """ + Common method to run VLM MMMU benchmark test. + Args: + model: Model to test + output_path: Path for output logs + test_name: Optional test name for logging + custom_env: Optional custom environment variables + log_level: Log level for server (default: "info") + capture_output: Whether to capture server stdout/stderr + """ + print(f"\nTesting model: {model.model}{test_name}") + + process = None + mmmu_accuracy = 0 # Initialize to handle potential exceptions + server_output = "" + + try: + # Prepare environment variables + process_env = os.environ.copy() + if custom_env: + process_env.update(custom_env) + # if test vlm with cuda_ipc feature, open this env_var + process_env["SGLANG_USE_CUDA_IPC_TRANSPORT"] = "1" + + # Prepare stdout/stderr redirection if needed + stdout_file = None + stderr_file = None + if capture_output: + stdout_file = open("/tmp/server_stdout.log", "w") + stderr_file = open("/tmp/server_stderr.log", "w") + + # Launch server for testing + process = popen_launch_server( + model.model, + base_url=self.base_url, + timeout=self.time_out, + api_key=self.api_key, + other_args=[ + "--trust-remote-code", + "--piecewise-cuda-graph-max-tokens", + "8192", + "--enable-piecewise-cuda-graph", + "--tp=8", + "--piecewise-cuda-graph-compiler=eager", + "--disable-radix-cache", + "--log-level", + log_level, + ], + env=process_env, + return_stdout_stderr=( + (stdout_file, stderr_file) if capture_output else None + ), + ) + + # Run evaluation + self.run_mmmu_eval(model.model, output_path) + + # Get the result file + # Search recursively for JSON result files (lmms-eval v0.4.1+ creates subdirectories) + result_files = glob.glob(f"{output_path}/**/*.json", recursive=True) + if not result_files: + result_files = glob.glob(f"{output_path}/*.json") + + if not result_files: + raise FileNotFoundError(f"No JSON result files found in {output_path}") + + result_file_path = result_files[0] + + with open(result_file_path, "r") as f: + result = json.load(f) + print(f"Result{test_name}\n: {result}") + + # Process the result + mmmu_accuracy = result["results"]["mmmu_val"]["mmmu_acc,none"] + print( + f"Model {model.model} achieved accuracy{test_name}: {mmmu_accuracy:.4f}" + ) + + # Capture server output if requested + if capture_output and process: + server_output = self._read_output_from_files() + + # Assert performance meets expected threshold + self.assertGreaterEqual( + mmmu_accuracy, + model.mmmu_accuracy, + f"Model {model.model} accuracy ({mmmu_accuracy:.4f}) below expected threshold ({model.mmmu_accuracy:.4f}){test_name}", + ) + + return server_output + + except Exception as e: + print(f"Error testing {model.model}{test_name}: {e}") + self.fail(f"Test failed for {model.model}{test_name}: {e}") + + finally: + # Ensure process cleanup happens regardless 
of success/failure + if process is not None and process.poll() is None: + print(f"Cleaning up process {process.pid}") + try: + kill_process_tree(process.pid) + except Exception as e: + print(f"Error killing process: {e}") + + # clean up temporary files + if capture_output: + if stdout_file: + stdout_file.close() + if stderr_file: + stderr_file.close() + for filename in ["/tmp/server_stdout.log", "/tmp/server_stderr.log"]: + try: + if os.path.exists(filename): + os.remove(filename) + except Exception as e: + print(f"Error removing {filename}: {e}") + + def _read_output_from_files(self): + output_lines = [] + + log_files = [ + ("/tmp/server_stdout.log", "[STDOUT]"), + ("/tmp/server_stderr.log", "[STDERR]"), + ] + for filename, tag in log_files: + try: + if os.path.exists(filename): + with open(filename, "r") as f: + for line in f: + output_lines.append(f"{tag} {line.rstrip()}") + except Exception as e: + print(f"Error reading {tag.lower()} file: {e}") + + return "\n".join(output_lines) + + def test_vlm_mmmu_benchmark(self): + """Test VLM models against MMMU benchmark.""" + models_to_test = MODELS + + if is_in_ci(): + models_to_test = [random.choice(MODELS)] + + for model in models_to_test: + self._run_vlm_mmmu_test(model, "./logs") + + +if __name__ == "__main__": + # Define and parse arguments here, before unittest.main + parser = argparse.ArgumentParser(description="Test VLM models") + parser.add_argument( + "--mem-fraction-static", + type=float, + help="Static memory fraction for the model", + default=DEFAULT_MEM_FRACTION_STATIC, + ) + + # Parse args intended for unittest + args = parser.parse_args() + + # Store the parsed args object on the class + TestVLMPiecewiseCudaGraph.parsed_args = args + + # Pass args to unittest + unittest.main(argv=[sys.argv[0]]) diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 56336f60b6d..e936222114a 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -94,7 +94,7 @@ TestFile("test_original_logprobs.py", 41), TestFile("test_page_size.py", 60), TestFile("test_penalty.py", 82), - TestFile("test_piecewise_cuda_graph.py", 600), + TestFile("test_piecewise_cuda_graph.py", 750), TestFile("test_priority_scheduling.py", 130), TestFile("test_pytorch_sampling_backend.py", 66), TestFile("test_radix_attention.py", 105), @@ -337,6 +337,7 @@ TestFile("nightly/test_text_models_perf.py"), TestFile("nightly/test_vlms_mmmu_eval.py"), TestFile("nightly/test_vlms_perf.py"), + TestFile("nightly/test_vlms_piecewise_cuda_graph.py"), TestFile("test_openai_adapter.py"), TestFile("test_openai_function_calling.py"), TestFile("test_openai_server.py"), diff --git a/test/srt/test_piecewise_cuda_graph.py b/test/srt/test_piecewise_cuda_graph.py index 3dc410ebf52..b1f07433d37 100644 --- a/test/srt/test_piecewise_cuda_graph.py +++ b/test/srt/test_piecewise_cuda_graph.py @@ -248,5 +248,48 @@ def test_mgsm_accuracy(self): print(f"MGSM Accuracy: {metrics['score']:.3f}") +class TestPiecewiseCudaGraphQwen25VL(CustomTestCase): + """Test piecewise CUDA graph with Qwen2.5-VL-7B-Instruct model""" + + @classmethod + def setUpClass(cls): + cls.model = "Qwen/Qwen2.5-VL-7B-Instruct" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--enable-piecewise-cuda-graph", + "--piecewise-cuda-graph-compiler", + "eager", + "--disable-radix-cache", + "--pp-size", + "2", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def 
test_mgsm_accuracy(self):
+        """Test MGSM (English) accuracy with piecewise CUDA graph enabled"""
+        num_examples = 2000
+
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mgsm_en",
+            num_examples=num_examples,
+            num_threads=min(num_examples, 1024),
+        )
+
+        metrics = run_eval(args)
+        print(f"MGSM Accuracy: {metrics['score']:.3f}")
+
+        self.assertGreaterEqual(metrics["score"], 0.70)
+
+
 if __name__ == "__main__":
     unittest.main()
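
The diff above splits multimodal inference into two phases: an eager preprocessing step (multimodal_preprocess_routine fills forward_batch.input_embeds and, for deepstack models, forward_batch.input_deepstack_embeds in ModelRunner.forward_extend) and a graph-capturable language-model forward that only consumes dense tensors. The minimal sketch below mirrors that split outside SGLang; ToyBatch, ToyVLM, preprocess, and encode_images are illustrative stand-ins under those assumptions, not SGLang APIs.

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ToyBatch:
    # Mirrors the ForwardBatch fields this patch reads/writes.
    input_ids: torch.Tensor                      # (num_tokens,)
    input_embeds: Optional[torch.Tensor] = None  # (num_tokens, hidden)
    input_deepstack_embeds: Optional[torch.Tensor] = None


class ToyVLM(torch.nn.Module):
    def __init__(self, vocab_size=128, hidden=16, num_deepstack=3):
        super().__init__()
        self.embed_tokens = torch.nn.Embedding(vocab_size, hidden)
        self.lm = torch.nn.Linear(hidden, hidden)
        self.num_deepstack = num_deepstack

    def encode_images(self, num_image_tokens, hidden):
        # Stand-in for the vision encoder, which must run outside the CUDA graph.
        return torch.randn(num_image_tokens, hidden)

    @torch.no_grad()
    def preprocess(self, batch: ToyBatch, image_token_id: int = 7) -> ToyBatch:
        # Analogue of multimodal_preprocess_routine: compute all embeddings eagerly
        # and stash them on the batch before the graph-captured forward runs.
        embeds = self.embed_tokens(batch.input_ids)
        mask = batch.input_ids == image_token_id
        if mask.any():
            embeds[mask] = self.encode_images(int(mask.sum()), embeds.shape[-1])
        batch.input_embeds = embeds
        # Deepstack-style models carry extra per-token visual features as well.
        batch.input_deepstack_embeds = torch.zeros(
            embeds.shape[0], embeds.shape[1] * self.num_deepstack
        )
        return batch

    def forward(self, batch: ToyBatch) -> torch.Tensor:
        # Graph-friendly part: pure tensor ops on precomputed embeddings,
        # no data-dependent encoder call inside.
        return self.lm(batch.input_embeds)


if __name__ == "__main__":
    model = ToyVLM()
    batch = ToyBatch(input_ids=torch.tensor([1, 7, 7, 4]))
    batch = model.preprocess(batch)  # runs eagerly in the model runner
    hidden = model(batch)            # the part a piecewise CUDA graph would replay
    print(hidden.shape)              # torch.Size([4, 16])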