13 changes: 8 additions & 5 deletions python/sglang/srt/layers/rotary_embedding.py
@@ -1301,7 +1301,6 @@ def triton_mrope(
return q, k


@torch._dynamo.disable()
def triton_mrope_wrapper(
query,
key,
@@ -1428,15 +1427,18 @@ def _forward_native(
dim=-1,
)

seq_len_q = query.shape[0]
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query = query.view(seq_len_q, -1, self.head_size)

query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

seq_len_k = key.shape[0]
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key = key.view(seq_len_k, -1, self.head_size)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
@@ -1467,7 +1469,6 @@ def forward(
else:
return self._forward_native(positions, query, key)

@torch.compile(dynamic=True, backend=get_compiler_backend())
def _forward_triton(
self,
positions: torch.Tensor,
Expand Down Expand Up @@ -1502,7 +1503,9 @@ def _forward_triton(

return q.reshape(query_shape), k.reshape(key_shape)

query = query.view(num_tokens, -1, self.head_size)
seq_len_q = query.shape[0]
query = query.view(seq_len_q, -1, self.head_size)

query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
209 changes: 165 additions & 44 deletions python/sglang/srt/managers/mm_utils.py
@@ -494,6 +494,7 @@ def get_embedding_and_mask(


def embed_mm_inputs(
forward_batch: ForwardBatch,
mm_inputs_list: List[MultimodalInputs],
extend_prefix_lens: List[int],
extend_seq_lens: List[int],
@@ -503,8 +504,6 @@ def embed_mm_inputs(
data_embedding_func_mapping: Dict[
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None,
placeholder_tokens: dict[Modality, List[int]] = None,
use_deepstack: Dict[Modality, bool] = {},
) -> Optional[torch.Tensor]:
"""
Embed multimodal inputs and integrate them with text token embeddings.
@@ -515,12 +514,10 @@ def embed_mm_inputs(
extend_seq_lens: Sequence lengths for each request
input_ids: Input token IDs tensor
input_embedding: Embedding layer for text tokens
placeholder_tokens: Token IDs for multimodal placeholders (uses pad_values if None)

Returns:
Combined embedding tensor with multimodal content integrated
"""
other_info = {}
if mm_inputs_list is None:
return None

@@ -530,8 +527,7 @@
for mm_inputs in mm_inputs_list:
item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]

# deepstack_embeddings: per-modality
modalities, embeddings, masks, deepstack_embeddings = [], [], [], []
modalities, embeddings, masks = [], [], []

# 2. Get multimodal embedding separately
# Try get mm embedding if any
@@ -580,11 +576,6 @@
items_offset_list=items_offsets,
)

if use_deepstack.get(modality, None) and embedding is not None:
embedding, deepstack_embedding = (
multimodal_model.separate_deepstack_embeds(embedding)
)
deepstack_embeddings += [deepstack_embedding]
modalities += [modality]
embeddings += [embedding]
masks += [mask]
@@ -598,37 +589,29 @@
input_ids.clamp_(min=0, max=vocab_size - 1)
inputs_embeds = input_embedding(input_ids)

# deepstack embedding
if use_deepstack:
num_deepstack_embeddings = len(multimodal_model.deepstack_visual_indexes)
indices = []
for mask in masks:
if mask is not None:
indices.append(torch.where(mask.squeeze(dim=-1))[0])
else:
indices.append(None)

deepstack_embedding_shape = inputs_embeds.shape[:-1] + (
inputs_embeds.shape[-1] * num_deepstack_embeddings,
# Only for Qwen3-VL right now; this replaces the original use_deepstack path.
if hasattr(multimodal_model, "post_process"):
embeddings, forward_batch = multimodal_model.post_process(
inputs_embeds, modalities, embeddings, indices, forward_batch
)
# a zero-filled embedding, with the same length of inputs_embeds, but different hidden_size
input_deepstack_embeds = torch.zeros(
deepstack_embedding_shape,
device=inputs_embeds.device,
dtype=inputs_embeds.dtype,
)

other_info["input_deepstack_embeds"] = input_deepstack_embeds

# 4. scatter embeddings into input embedding
for i, modality, embedding, mask in zip(
range(len(embeddings)), modalities, embeddings, masks
for i, modality, embedding, index in zip(
range(len(embeddings)), modalities, embeddings, indices
):
if embedding is None or mask is None:
if embedding is None or index is None:
continue
# in-place update
indices = torch.where(mask.squeeze(dim=-1))[0]
inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
if use_deepstack.get(modality, None):
input_deepstack_embeds[indices] = deepstack_embeddings[i].to(
inputs_embeds.device, inputs_embeds.dtype
)
inputs_embeds[index] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)

return inputs_embeds, other_info
return inputs_embeds, forward_batch


def general_mm_embed_routine(
@@ -639,8 +622,7 @@ def general_mm_embed_routine(
data_embedding_funcs: Dict[
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None,
placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
use_deepstack: Dict[Modality, bool] = {},
skip_llm_forward: bool = False,
**kwargs,
) -> torch.Tensor:
"""
@@ -651,8 +633,6 @@
forward_batch: Batch information for model forward pass
language_model: Base language model to use
data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
placeholder_tokens: Token IDs for multimodal placeholders
use_deepstack: Whether to use deepstack embeddings for each modality, default False
**kwargs: Additional arguments passed to language model

Returns:
@@ -679,20 +659,17 @@
for i, seq_len in enumerate(forward_batch.extend_seq_lens_cpu)
if forward_batch.mm_inputs[i] is not None
]
inputs_embeds, other_info = embed_mm_inputs(

inputs_embeds, forward_batch = embed_mm_inputs(
forward_batch=forward_batch,
mm_inputs_list=mm_inputs_list,
extend_prefix_lens=extend_prefix_lens,
extend_seq_lens=extend_seq_lens,
input_ids=input_ids,
multimodal_model=multimodal_model,
input_embedding=embed_tokens,
data_embedding_func_mapping=data_embedding_funcs,
placeholder_tokens=placeholder_tokens,
use_deepstack=use_deepstack,
)
# add for qwen3_vl deepstack
if use_deepstack:
kwargs["input_deepstack_embeds"] = other_info["input_deepstack_embeds"]
# once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models
# just being defensive here
forward_batch.mm_inputs = None
@@ -701,6 +678,9 @@
else:
inputs_embeds = None

if skip_llm_forward:
return inputs_embeds
Comment on lines +681 to +682

issue (bug_risk): Returning only 'inputs_embeds' when 'skip_llm_forward' is True may break downstream code expecting a tuple.

To prevent runtime errors, ensure the return type remains consistent, or update documentation to reflect the change.
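
One way to keep both paths type-consistent (a hypothetical sketch, not part of this PR; the wrapper name is invented for illustration) is to return a small result object from both branches:

from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class MMEmbedRoutineOutput:
    # Hypothetical wrapper so callers always unpack the same type,
    # whether or not the language-model forward pass was skipped.
    hidden_states: Optional[torch.Tensor] = None
    inputs_embeds: Optional[torch.Tensor] = None

# Sketch of how general_mm_embed_routine could use it:
# if skip_llm_forward:
#     return MMEmbedRoutineOutput(inputs_embeds=inputs_embeds)
# hidden_states = language_model(...)
# return MMEmbedRoutineOutput(hidden_states=hidden_states)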


hidden_states = language_model(
input_ids=None,
forward_batch=forward_batch,
@@ -816,3 +796,144 @@ def hash_feature(f):
reconstruct_t = f.reconstruct_on_target_device(torch.cuda.current_device())
return tensor_hash([reconstruct_t])
return data_hash(f)


def resolve_language_model(multimodal_model: nn.Module) -> Optional[nn.Module]:
# Qwen2-VL / Qwen3-VL Style
if hasattr(multimodal_model, "model"):
lm = getattr(multimodal_model, "model")
if hasattr(lm, "get_input_embeddings"):
return lm

# Llava / OneVision Style
if hasattr(multimodal_model, "language_model"):
lm = getattr(multimodal_model, "language_model")
if hasattr(lm, "get_input_embeddings"):
return lm

if hasattr(multimodal_model, "get_input_embeddings"):
return multimodal_model

return None


def should_use_external_mm_preprocess(multimodal_model: nn.Module) -> bool:
"""Decide whether we should use our generic "multimodal_preprocess_routine".

We are adapting VLMs for piecewise CUDA graphs. Since the encoder's forward
pass cannot run inside the model's forward pass, image embeddings must be
precomputed by the encoder in the model runner. For models that have
already been adapted, this helper returns True, and
multimodal_preprocess_routine is then called in
model_runner.forward_extend to handle the multimodal inputs.

For models that have not yet been adapted, the general_mm_embed_routine
will still be called in the model class's forward function for processing.

Current strategy:
- Llava family (models with vision_tower + multi_modal_projector):
Their forward already calls general_mm_embed_routine and includes
built-in multimodal processing. If we run it again in ModelRunner,
it will conflict with the internal logic, so we skip it here.
- Others (such as Qwen2-VL / Qwen2.5-VL / Qwen3-VL): use the
multimodal preprocessing.
"""

cls_name = multimodal_model.__class__.__name__

qwen_vl_classes = {
"Qwen2VLForConditionalGeneration",
"Qwen2_5_VLForConditionalGeneration",
"Qwen3VLForConditionalGeneration",
"Qwen3VLMoeForConditionalGeneration",
"Qwen3OmniMoeForConditionalGeneration",
}

return cls_name in qwen_vl_classes
Comment on lines +820 to +853


medium

Using a hardcoded set of class names in should_use_external_mm_preprocess makes it difficult to extend this functionality to new models in the future. A more robust and maintainable approach would be to use feature detection on the model object itself. For instance, you could check for the existence of a specific attribute (e.g., model.supports_external_mm_preprocess = True) or a method (e.g., hasattr(model, "post_process")). This would make the logic more generic and decoupled from specific model implementations.
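
A minimal sketch of such feature detection (hypothetical; it assumes models either set a supports_external_mm_preprocess flag or expose the post_process hook used above):

from torch import nn

def should_use_external_mm_preprocess(multimodal_model: nn.Module) -> bool:
    # Hypothetical variant: rely on what the model exposes instead of a
    # hardcoded class-name allowlist.
    return bool(
        getattr(multimodal_model, "supports_external_mm_preprocess", False)
        or hasattr(multimodal_model, "post_process")
    )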



def multimodal_preprocess_routine(
forward_batch: ForwardBatch,
multimodal_model: Optional[nn.Module] = None,
data_embedding_funcs: Dict[
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None,
) -> torch.Tensor:


medium

The return type hint for this function is torch.Tensor, but it actually returns a ForwardBatch object on lines 874 and 939. Please update the type hint to -> ForwardBatch: for correctness and clarity.

Suggested change:
- ) -> torch.Tensor:
+ ) -> ForwardBatch:

"""
Precompute multimodal embeddings and attach them to the forward batch.
Args:
forward_batch: Batch information for model forward pass
multimodal_model: Multimodal model whose encoder produces the embeddings
data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
Returns:
The forward batch, with input_embeds (and, where applicable, auxiliary embeddings) populated
"""
if not should_use_external_mm_preprocess(multimodal_model):
return forward_batch

language_model = resolve_language_model(multimodal_model)
if language_model is None:
raise ValueError(
f"Cannot resolve language model from {type(multimodal_model).__name__}. "
f"Please ensure the model has 'model' or 'language_model' attribute."
)

assert hasattr(language_model, "get_input_embeddings")
embed_tokens = language_model.get_input_embeddings()
if not hasattr(language_model, "pp_group") or language_model.pp_group.is_first_rank:

input_ids = forward_batch.input_ids
if (
not forward_batch.forward_mode.is_decode()
and not forward_batch.forward_mode.is_target_verify()
and forward_batch.contains_mm_inputs()
):
mm_inputs_list = [
mm_input for mm_input in forward_batch.mm_inputs if mm_input is not None
]
extend_prefix_lens = [
prefix_len
for i, prefix_len in enumerate(forward_batch.extend_prefix_lens_cpu)
if forward_batch.mm_inputs[i] is not None
]
extend_seq_lens = [
seq_len
for i, seq_len in enumerate(forward_batch.extend_seq_lens_cpu)
if forward_batch.mm_inputs[i] is not None
]
input_embeds, forward_batch = embed_mm_inputs(
forward_batch=forward_batch,
mm_inputs_list=mm_inputs_list,
extend_prefix_lens=extend_prefix_lens,
extend_seq_lens=extend_seq_lens,
input_ids=forward_batch.input_ids,
multimodal_model=multimodal_model,
input_embedding=embed_tokens,
data_embedding_func_mapping=data_embedding_funcs,
)
# once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models
# just being defensive here
forward_batch.mm_inputs = None
else:
# NOTE: This may reduce performance for text-only inputs.
# Using a fixed-address buffer might be better, though it could be a bit dirty.
input_embeds = embed_tokens(input_ids)
# only for qwen3vl
if getattr(multimodal_model, "use_deepstack", False):
forward_batch.input_deepstack_embeds = torch.zeros(
(
len(input_ids),
multimodal_model.config.hidden_size
* len(multimodal_model.deepstack_visual_indexes),
),
device=input_embeds.device,
dtype=input_embeds.dtype,
)

forward_batch.input_embeds = input_embeds
else:
forward_batch.input_embeds = None

return forward_batch
4 changes: 4 additions & 0 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -318,6 +318,10 @@ def __init__(self, model_runner: ModelRunner):
# Graph inputs
with torch.device(self.device):
self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
self.input_embeds = torch.zeros(
(self.max_num_token, self.model_runner.model_config.hidden_size),
dtype=self.model_runner.model_config.dtype,
)
self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
self.seq_lens = torch.full(
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
3 changes: 3 additions & 0 deletions python/sglang/srt/model_executor/forward_batch_info.py
@@ -343,6 +343,9 @@ class ForwardBatch:
# For Qwen2-VL
mrope_positions: torch.Tensor = None

# For Qwen3-VL
input_deepstack_embeds: Optional[torch.Tensor] = None

# For two-batch overlap
tbo_split_seq_index: Optional[int] = None
tbo_parent_token_range: Optional[Tuple[int, int]] = None