[DO NOT MERGE] Support Piecewise CUDA Graph for Qwen2.5-VL #17

base: main
Changes from all commits: 0c334a2, 471a51f, fd1cd1a, db5c290, 4e1a121, 79a2da2, b0b5ea4, 4ffe08a, 7702eb0, bc54e4c, 1184918
```diff
@@ -494,6 +494,7 @@ def get_embedding_and_mask(
 
 
 def embed_mm_inputs(
+    forward_batch: ForwardBatch,
     mm_inputs_list: List[MultimodalInputs],
     extend_prefix_lens: List[int],
     extend_seq_lens: List[int],
```
```diff
@@ -503,8 +504,6 @@ def embed_mm_inputs(
     data_embedding_func_mapping: Dict[
         Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
     ] = None,
-    placeholder_tokens: dict[Modality, List[int]] = None,
-    use_deepstack: Dict[Modality, bool] = {},
 ) -> Optional[torch.Tensor]:
     """
     Embed multimodal inputs and integrate them with text token embeddings.
```
```diff
@@ -515,12 +514,10 @@
         extend_seq_lens: Sequence lengths for each request
         input_ids: Input token IDs tensor
         input_embedding: Embedding layer for text tokens
-        placeholder_tokens: Token IDs for multimodal placeholders (uses pad_values if None)
 
     Returns:
         Combined embedding tensor with multimodal content integrated
     """
-    other_info = {}
     if mm_inputs_list is None:
         return None
 
```
```diff
@@ -530,8 +527,7 @@
     for mm_inputs in mm_inputs_list:
         item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]
 
-    # deepstack_embeddings: per-modality
-    modalities, embeddings, masks, deepstack_embeddings = [], [], [], []
+    modalities, embeddings, masks = [], [], []
 
     # 2. Get multimodal embedding separately
     # Try get mm embedding if any
```
```diff
@@ -580,11 +576,6 @@
                 items_offset_list=items_offsets,
             )
 
-            if use_deepstack.get(modality, None) and embedding is not None:
-                embedding, deepstack_embedding = (
-                    multimodal_model.separate_deepstack_embeds(embedding)
-                )
-                deepstack_embeddings += [deepstack_embedding]
             modalities += [modality]
             embeddings += [embedding]
             masks += [mask]
```
```diff
@@ -598,37 +589,29 @@
     input_ids.clamp_(min=0, max=vocab_size - 1)
     inputs_embeds = input_embedding(input_ids)
 
-    # deepstack embedding
-    if use_deepstack:
-        num_deepstack_embeddings = len(multimodal_model.deepstack_visual_indexes)
+    indices = []
+    for mask in masks:
+        if mask is not None:
+            indices.append(torch.where(mask.squeeze(dim=-1))[0])
+        else:
+            indices.append(None)
 
-        deepstack_embedding_shape = inputs_embeds.shape[:-1] + (
-            inputs_embeds.shape[-1] * num_deepstack_embeddings,
+    # only for qwen3vl right now, replace the original use_deepstack with this method.
+    if hasattr(multimodal_model, "post_process"):
+        embeddings, forward_batch = multimodal_model.post_process(
+            inputs_embeds, modalities, embeddings, indices, forward_batch
         )
-        # a zero-filled embedding, with the same length of inputs_embeds, but different hidden_size
-        input_deepstack_embeds = torch.zeros(
-            deepstack_embedding_shape,
-            device=inputs_embeds.device,
-            dtype=inputs_embeds.dtype,
-        )
 
-        other_info["input_deepstack_embeds"] = input_deepstack_embeds
 
     # 4. scatter embeddings into input embedding
-    for i, modality, embedding, mask in zip(
-        range(len(embeddings)), modalities, embeddings, masks
+    for i, modality, embedding, index in zip(
+        range(len(embeddings)), modalities, embeddings, indices
     ):
-        if embedding is None or mask is None:
+        if embedding is None or index is None:
             continue
-        # in-place update
-        indices = torch.where(mask.squeeze(dim=-1))[0]
-        inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
-        if use_deepstack.get(modality, None):
-            input_deepstack_embeds[indices] = deepstack_embeddings[i].to(
-                inputs_embeds.device, inputs_embeds.dtype
-            )
+        inputs_embeds[index] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
 
-    return inputs_embeds, other_info
+    return inputs_embeds, forward_batch
 
 
 def general_mm_embed_routine(
```
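The hunk above removes the inline deepstack handling from `embed_mm_inputs` and defers it to an optional `post_process` hook on the model, which per the diff comment only Qwen3-VL provides so far. The hook itself is not shown in this diff, so the following is only a hypothetical sketch of what such a hook could look like; the class name, the layer indexes, and the buffer layout are assumptions, not the actual Qwen3-VL implementation.

```python
import torch


class Qwen3VLPostProcessSketch:
    """Hypothetical stand-in showing one way a model could implement post_process."""

    # Illustrative deepstack layer indexes; the real model reads these from its config.
    deepstack_visual_indexes = [5, 11, 17]

    def separate_deepstack_embeds(self, embedding: torch.Tensor):
        # Assume the vision encoder packs [main | deepstack_0 | ... | deepstack_{k-1}]
        # along the hidden dimension, so the main slice is the first 1/(k+1) chunk.
        hidden = embedding.shape[-1] // (1 + len(self.deepstack_visual_indexes))
        return embedding[..., :hidden], embedding[..., hidden:]

    def post_process(self, inputs_embeds, modalities, embeddings, indices, forward_batch):
        # Zero-filled per-token buffer, same length as inputs_embeds but a wider
        # hidden size, mirroring what the old embed_mm_inputs allocated inline.
        forward_batch.input_deepstack_embeds = torch.zeros(
            inputs_embeds.shape[0],
            inputs_embeds.shape[-1] * len(self.deepstack_visual_indexes),
            device=inputs_embeds.device,
            dtype=inputs_embeds.dtype,
        )
        new_embeddings = []
        for embedding, index in zip(embeddings, indices):
            if embedding is None or index is None:
                new_embeddings.append(embedding)
                continue
            main, deepstack = self.separate_deepstack_embeds(embedding)
            # Scatter the deepstack slice here; the caller scatters the main slice
            # into inputs_embeds right after this hook returns.
            forward_batch.input_deepstack_embeds[index] = deepstack.to(
                inputs_embeds.device, inputs_embeds.dtype
            )
            new_embeddings.append(main)
        return new_embeddings, forward_batch
```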
```diff
@@ -639,8 +622,7 @@ def general_mm_embed_routine(
     data_embedding_funcs: Dict[
         Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
     ] = None,
-    placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
-    use_deepstack: Dict[Modality, bool] = {},
+    skip_llm_forward: bool = False,
     **kwargs,
 ) -> torch.Tensor:
     """
```
```diff
@@ -651,8 +633,6 @@
         forward_batch: Batch information for model forward pass
         language_model: Base language model to use
         data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
-        placeholder_tokens: Token IDs for multimodal placeholders
-        use_deepstack: Whether to use deepstack embeddings for each modality, default False
         **kwargs: Additional arguments passed to language model
 
     Returns:
```
```diff
@@ -679,20 +659,17 @@
             for i, seq_len in enumerate(forward_batch.extend_seq_lens_cpu)
             if forward_batch.mm_inputs[i] is not None
         ]
-        inputs_embeds, other_info = embed_mm_inputs(
+
+        inputs_embeds, forward_batch = embed_mm_inputs(
+            forward_batch=forward_batch,
             mm_inputs_list=mm_inputs_list,
             extend_prefix_lens=extend_prefix_lens,
             extend_seq_lens=extend_seq_lens,
             input_ids=input_ids,
             multimodal_model=multimodal_model,
             input_embedding=embed_tokens,
             data_embedding_func_mapping=data_embedding_funcs,
-            placeholder_tokens=placeholder_tokens,
-            use_deepstack=use_deepstack,
         )
-        # add for qwen3_vl deepstack
-        if use_deepstack:
-            kwargs["input_deepstack_embeds"] = other_info["input_deepstack_embeds"]
         # once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models
         # just being defensive here
         forward_batch.mm_inputs = None
```
```diff
@@ -701,6 +678,9 @@
     else:
         inputs_embeds = None
 
+    if skip_llm_forward:
+        return inputs_embeds
+
     hidden_states = language_model(
         input_ids=None,
         forward_batch=forward_batch,
```
```diff
@@ -816,3 +796,144 @@ def hash_feature(f):
         reconstruct_t = f.reconstruct_on_target_device(torch.cuda.current_device())
         return tensor_hash([reconstruct_t])
     return data_hash(f)
+
+
+def resolve_language_model(multimodal_model: nn.Module) -> Optional[nn.Module]:
+    # Qwen2-VL / Qwen3-VL style
+    if hasattr(multimodal_model, "model"):
+        lm = getattr(multimodal_model, "model")
+        if hasattr(lm, "get_input_embeddings"):
+            return lm
+
+    # Llava / OneVision style
+    if hasattr(multimodal_model, "language_model"):
+        lm = getattr(multimodal_model, "language_model")
+        if hasattr(lm, "get_input_embeddings"):
+            return lm
+
+    if hasattr(multimodal_model, "get_input_embeddings"):
+        return multimodal_model
+
+    return None
+
+
+def should_use_external_mm_preprocess(multimodal_model: nn.Module) -> bool:
+    """Decide whether we should use our generic "multimodal_preprocess_routine".
+
+    We are adapting VLMs for piecewise CUDA graph. Since the encoder's forward
+    pass cannot be executed within the model's forward pass, we need to
+    precompute image embeddings using the encoder within the model runner.
+    For models that have already been adapted, this function returns True and
+    multimodal_preprocess_routine is called in model_runner.forward_extend to
+    handle multimodal inputs.
+
+    For models that have not yet been adapted, general_mm_embed_routine is
+    still called in the model class's forward function for processing.
+
+    Current strategy:
+    - Llava family (models with vision_tower + multi_modal_projector):
+      Their forward already calls general_mm_embed_routine and includes
+      built-in multimodal processing. Running it again in ModelRunner would
+      conflict with the internal logic, so we skip it here.
+    - Others (such as Qwen2-VL / Qwen2.5-VL / Qwen3-VL): use the external
+      multimodal preprocessing.
+    """
+
+    cls_name = multimodal_model.__class__.__name__
+
+    qwen_vl_classes = {
+        "Qwen2VLForConditionalGeneration",
+        "Qwen2_5_VLForConditionalGeneration",
+        "Qwen3VLForConditionalGeneration",
+        "Qwen3VLMoeForConditionalGeneration",
+        "Qwen3OmniMoeForConditionalGeneration",
+    }
+
+    return cls_name in qwen_vl_classes
```
Comment on lines +820 to +853: Using a hardcoded set of class names in
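One way to address this concern, sketched below under the assumption of a hypothetical opt-in attribute (say `supports_external_mm_preprocess`, which does not exist in this PR), would be to let each adapted model declare support itself instead of growing a hardcoded allowlist:

```python
from torch import nn


def should_use_external_mm_preprocess(multimodal_model: nn.Module) -> bool:
    # Hypothetical alternative: adapted models set a class attribute instead of
    # being enumerated in a hardcoded set of class names.
    return getattr(multimodal_model, "supports_external_mm_preprocess", False)
```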
```diff
+
+
+def multimodal_preprocess_routine(
+    forward_batch: ForwardBatch,
+    multimodal_model: Optional[nn.Module] = None,
+    data_embedding_funcs: Dict[
+        Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
+    ] = None,
+) -> ForwardBatch:
+    """
+    Precompute multimodal embeddings before the language-model forward pass.
+
+    Args:
+        forward_batch: Batch information for the model forward pass
+        multimodal_model: The multimodal model whose encoder produces the embeddings
+        data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
+
+    Returns:
+        The forward_batch with input_embeds (and input_deepstack_embeds for deepstack models) populated
+    """
+    if not should_use_external_mm_preprocess(multimodal_model):
+        return forward_batch
+
+    language_model = resolve_language_model(multimodal_model)
+    if language_model is None:
+        raise ValueError(
+            f"Cannot resolve language model from {type(multimodal_model).__name__}. "
+            f"Please ensure the model has a 'model' or 'language_model' attribute."
+        )
+
+    assert hasattr(language_model, "get_input_embeddings")
+    embed_tokens = language_model.get_input_embeddings()
+    if not hasattr(language_model, "pp_group") or language_model.pp_group.is_first_rank:
+        input_ids = forward_batch.input_ids
+        if (
+            not forward_batch.forward_mode.is_decode()
+            and not forward_batch.forward_mode.is_target_verify()
+            and forward_batch.contains_mm_inputs()
+        ):
+            mm_inputs_list = [
+                mm_input for mm_input in forward_batch.mm_inputs if mm_input is not None
+            ]
+            extend_prefix_lens = [
+                prefix_len
+                for i, prefix_len in enumerate(forward_batch.extend_prefix_lens_cpu)
+                if forward_batch.mm_inputs[i] is not None
+            ]
+            extend_seq_lens = [
+                seq_len
+                for i, seq_len in enumerate(forward_batch.extend_seq_lens_cpu)
+                if forward_batch.mm_inputs[i] is not None
+            ]
+            input_embeds, forward_batch = embed_mm_inputs(
+                forward_batch=forward_batch,
+                mm_inputs_list=mm_inputs_list,
+                extend_prefix_lens=extend_prefix_lens,
+                extend_seq_lens=extend_seq_lens,
+                input_ids=forward_batch.input_ids,
+                multimodal_model=multimodal_model,
+                input_embedding=embed_tokens,
+                data_embedding_func_mapping=data_embedding_funcs,
+            )
+            # once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models
+            # just being defensive here
+            forward_batch.mm_inputs = None
+        else:
+            # NOTE: This may reduce the performance for text-only inputs.
+            # Using a fixed-address buffer might be better, though it could be a bit dirty.
+            input_embeds = embed_tokens(input_ids)
+            # only for qwen3vl
+            if getattr(multimodal_model, "use_deepstack", False):
+                forward_batch.input_deepstack_embeds = torch.zeros(
+                    (
+                        len(input_ids),
+                        multimodal_model.config.hidden_size
+                        * len(multimodal_model.deepstack_visual_indexes),
+                    ),
+                    device=input_embeds.device,
+                    dtype=input_embeds.dtype,
+                )
+
+        forward_batch.input_embeds = input_embeds
+    else:
+        forward_batch.input_embeds = None
+
+    return forward_batch
```
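The docstring of `should_use_external_mm_preprocess` says this routine is meant to be called from `model_runner.forward_extend`, but that call site is not part of this diff. The following is a rough, hypothetical sketch of how such an integration might look; the `model_runner.model` attribute, the `positions` field, and the forward signature are assumptions about the surrounding runner code:

```python
def forward_extend(model_runner, forward_batch):
    # Hypothetical call site inside the model runner (names assumed): run the
    # multimodal encoder eagerly here, outside the captured region, so the
    # remaining LLM forward can be replayed as a piecewise CUDA graph.
    forward_batch = multimodal_preprocess_routine(
        forward_batch=forward_batch,
        multimodal_model=model_runner.model,   # assumed attribute
        data_embedding_funcs=None,             # model-specific mapping, omitted here
    )
    return model_runner.model.forward(
        input_ids=None,
        positions=forward_batch.positions,     # assumed field name
        forward_batch=forward_batch,
        input_embeds=forward_batch.input_embeds,
    )
```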
issue (bug_risk): Returning only `inputs_embeds` when `skip_llm_forward` is True may break downstream code expecting a tuple. To prevent runtime errors, ensure the return type remains consistent, or update the documentation to reflect the change.
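If the consistent-return suggestion is taken, one possible shape (a sketch only; `MMEmbedResult` is a hypothetical name, not part of this PR) is to wrap both paths in the same lightweight result type so callers never have to branch on what was returned:

```python
from typing import NamedTuple, Optional

import torch


class MMEmbedResult(NamedTuple):
    # Hypothetical wrapper so the skip_llm_forward path and the normal path
    # hand back the same type.
    inputs_embeds: Optional[torch.Tensor]
    hidden_states: Optional[torch.Tensor]


# Sketch of how general_mm_embed_routine could use it:
#     if skip_llm_forward:
#         return MMEmbedResult(inputs_embeds=inputs_embeds, hidden_states=None)
#     hidden_states = language_model(...)
#     return MMEmbedResult(inputs_embeds=inputs_embeds, hidden_states=hidden_states)
```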