diff --git a/tools/llm/run_vlm.py b/tools/llm/run_vlm.py new file mode 100644 index 0000000000..f6bd62624f --- /dev/null +++ b/tools/llm/run_vlm.py @@ -0,0 +1,387 @@ +""" +.. _run_vlm: + +Running VLM inference with Torch-TensorRT +========================================================== + +This script mirrors the style and structure of *run_llm.py*, illustrating a +Torch-TensorRT (dynamo backend) workflow for Visual-Language Models (VLMs). +""" + +import argparse +import copy +import os +import sys +from contextlib import nullcontext +from typing import Tuple + +import requests +import torch +import torch_tensorrt +from PIL import Image +from torchtrt_ext import register_sdpa +from transformers import AutoModel, AutoProcessor +from utils import ( + generate_mm, + generate_mm_with_static_cache, + record_stats, + time_generate_mm, +) + +# -----------------------------------------------------------------------------# +# Global configuration +# -----------------------------------------------------------------------------# +DEVICE = torch.device("cuda:0") + +# Register SDPA as a standalone operator. Converter & lowering pass are defined +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) +import transformers.models.qwen2.modeling_qwen2 as mq # noqa: E402 + +mq.ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = mq.ALL_ATTENTION_FUNCTIONS["sdpa"] + +# -----------------------------------------------------------------------------# +# Model loading helpers +# -----------------------------------------------------------------------------# + + +def _load_eagle2(device: torch.device, torch_dtype: torch.dtype): + """ + Load Eagle2 model and processor. + + Returns + ------- + tuple[torch.nn.Module, transformers.AutoProcessor, torch.nn.Embedding] + The model, its processor and the language-model input embedding layer. + """ + model_id = "nvidia/Eagle2-2B" + with torch.no_grad(): + model = ( + AutoModel.from_pretrained( + model_id, trust_remote_code=True, torch_dtype=torch_dtype + ) + .eval() + .to(device) + ) + + processor = AutoProcessor.from_pretrained( + model_id, trust_remote_code=True, use_fast=True + ) + if hasattr(processor, "tokenizer"): + processor.tokenizer.padding_side = "left" + + emb_layer = model.language_model.get_input_embeddings().to(torch_dtype).to(device) + return model, processor, emb_layer + + +def _load_model( + model_name: str, device: torch.device, torch_dtype: torch.dtype +) -> Tuple[torch.nn.Module, AutoProcessor, torch.nn.Embedding]: + """Dispatch helper for supported VLMs.""" + if model_name.lower() == "eagle2": + return _load_eagle2(device, torch_dtype) + msg = f"Unsupported model: {model_name}" + raise ValueError(msg) + + +# -----------------------------------------------------------------------------# +# Torch-TensorRT compilation helpers +# -----------------------------------------------------------------------------# + + +class _LMNoCache(torch.nn.Module): + """ + Thin wrapper that exposes a language model via ``inputs_embeds`` without KV-cache. + """ + + def __init__(self, lm): + super().__init__() + self.lm = lm + + def forward(self, inputs_embeds, position_ids): + out = self.lm(inputs_embeds=inputs_embeds, position_ids=position_ids) + return out.logits if hasattr(out, "logits") else out + + +def _compile_eagle2_lm( + language_model: torch.nn.Module, + input_embeds: torch.Tensor, + args: argparse.Namespace, +) -> torch.nn.Module: + """ + Compile Eagle2 language model with Torch-TensorRT. + + The function follows the same precision-specific flag logic used in + *run_llm.py* for consistency. + """ + lm_wrap = _LMNoCache(language_model).to(DEVICE).eval() + max_seq_len = input_embeds.shape[1] + args.num_tokens + + S = torch.export.Dim("seq", min=1, max=max_seq_len) + position_ids = torch.arange(input_embeds.shape[1]).unsqueeze(0).to(DEVICE) + dyn_shapes = {"inputs_embeds": {1: S}, "position_ids": {1: S}} + + # Precision-specific flags --------------------------------------------------# + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + else: # FP32 + enabled_precisions = {torch.float32} + + with torch.inference_mode(): + exported = torch.export.export( + lm_wrap, + (input_embeds, position_ids), + dynamic_shapes=dyn_shapes, + strict=False, + ) + + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_mod = torch_tensorrt.dynamo.compile( + exported, + inputs=[input_embeds, position_ids], + enabled_precisions=enabled_precisions, + use_explicit_typing=use_explicit_typing, + use_fp32_acc=use_fp32_acc, + device=DEVICE, + disable_tf32=True, + use_python_runtime=True, + debug=args.debug, + offload_module_to_cpu=True, + min_block_size=args.min_block_size, + ) + return trt_mod + + +def compile_torchtrt( + model: torch.nn.Module, args: argparse.Namespace +) -> torch.nn.Module: + """ + Front-end dispatcher mirroring *run_llm.py*’s `compile_torchtrt`. + + Depending on the target VLM, delegates to the appropriate compile routine. + """ + torch_dtype = { + "FP16": torch.float16, + "BF16": torch.bfloat16, + }.get(args.precision, torch.float32) + + example_embeds = torch.randn( + 1, + 2560, + model.language_model.config.hidden_size, + dtype=torch_dtype, + device=DEVICE, + ) + + if args.model.lower() == "eagle2": + return _compile_eagle2_lm(model.language_model, example_embeds, args) + + msg = f"Unsupported model for compilation: {args.model}" + raise ValueError(msg) + + +# -----------------------------------------------------------------------------# +# Utility helpers +# -----------------------------------------------------------------------------# + + +def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): + """Pretty-print generated text for comparison.""" + print(f"========= {backend_name} =========") + print( + f"{backend_name} model generated text: ", + tokenizer.decode(gen_tokens[0], skip_special_tokens=True), + ) + print("===================================") + + +# -----------------------------------------------------------------------------# +# Main driver +# -----------------------------------------------------------------------------# +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Run VLM inference (PyTorch & TensorRT back-ends)" + ) + parser.add_argument("--model", default="eagle2", help="VLM model name") + parser.add_argument("--prompt", default="Describe this image.", help="Prompt text") + parser.add_argument( + "--precision", + default="FP16", + choices=["FP16", "BF16", "FP32"], + help="Computation precision", + ) + parser.add_argument("--iterations", type=int, default=5, help="# iterations") + parser.add_argument("--min_block_size", type=int, default=1, help="Min block size") + parser.add_argument("--num_tokens", type=int, default=128, help="# new tokens") + parser.add_argument("--batch_size", type=int, default=1, help="Batch size") + parser.add_argument("--isl", type=int, default=2048, help="Input seq length") + parser.add_argument( + "--enable_pytorch_run", + action="store_true", + help="Run the PyTorch baseline as well", + ) + parser.add_argument( + "--cache", + default="", + choices=["", "static_v1"], + help="KV-cache variant to use", + ) + parser.add_argument( + "--debug", action="store_true", help="Enable Torch-TensorRT debug logs" + ) + parser.add_argument( + "--benchmark", action="store_true", help="Enable benchmarking mode" + ) + + args = parser.parse_args() + + # -------------------------------------------------------------------------# + # 1. Model / processor / embeddings + # -------------------------------------------------------------------------# + dtype = { + "FP16": torch.float16, + "BF16": torch.bfloat16, + }.get(args.precision, torch.float32) + + model, processor, emb_layer = _load_model(args.model, DEVICE, dtype) + + # -------------------------------------------------------------------------# + # 2. Input construction (image + text prompt) + # -------------------------------------------------------------------------# + url = "https://cdn.pixabay.com/photo/2019/08/08/23/33/car-4393990_1280.jpg" + image = Image.open(requests.get(url, stream=True).raw) + + if args.benchmark: + prompt_len = args.isl - 1792 - 26 + prompt_txt = " ".join(["token"] * max(prompt_len, 0)) + else: + prompt_txt = args.prompt + + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": prompt_txt}, + ], + } + ] + + txt = [ + processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + ] + img_in, vid_in = processor.process_vision_info(messages) + inputs = processor( + text=txt, images=img_in, videos=vid_in, return_tensors="pt", padding=True + ).to(DEVICE) + + max_output_len = inputs["input_ids"].shape[1] + args.num_tokens + + # -------------------------------------------------------------------------# + # 3. Optional: PyTorch baseline + # -------------------------------------------------------------------------# + pyt_gen_tokens = pyt_timings = pyt_stats = None + if args.enable_pytorch_run: + pyt_gen_tokens = generate_mm( + model, + inputs["pixel_values"], + inputs["input_ids"], + max_output_len, + processor.tokenizer.eos_token_id, + emb_layer, + ) + if args.benchmark: + pyt_timings = time_generate_mm( + generate_mm, + model, + inputs["pixel_values"].clone(), + inputs["input_ids"].clone(), + max_output_len, + processor.tokenizer.eos_token_id, + emb_layer, + iterations=args.iterations, + ) + pyt_stats = record_stats( + "PyTorch", + pyt_timings, + args.precision, + batch_size=args.batch_size, + compile_time_s=None, + ) + + # Register static cache lowering passes if requested + if args.cache == "static_v1": + import static_cache_v1 # noqa: F401 + + # -------------------------------------------------------------------------# + # 4. Torch-TensorRT compile & run + # -------------------------------------------------------------------------# + trt_lm = compile_torchtrt(model, args) + trt_model = copy.deepcopy(model) + trt_model.language_model = trt_lm + + emb_layer = emb_layer.to(DEVICE) + + if args.cache == "static_v1": + trt_generate = generate_mm_with_static_cache + else: + trt_generate = generate_mm + + trt_gen_tokens = trt_generate( + trt_model, + inputs["pixel_values"], + inputs["input_ids"], + max_output_len, + processor.tokenizer.eos_token_id, + emb_layer, + DEVICE if args.cache == "static_v1" else None, # device arg only for static_v1 + ) + + if args.benchmark: + trt_timings = time_generate_mm( + trt_generate, + trt_model, + inputs["pixel_values"].clone(), + inputs["input_ids"].clone(), + max_output_len, + processor.tokenizer.eos_token_id, + emb_layer, + iterations=args.iterations, + device=DEVICE if args.cache == "static_v1" else None, + ) + trt_stats = record_stats( + "TensorRT", + trt_timings, + args.precision, + batch_size=args.batch_size, + compile_time_s=None, + ) + + # -------------------------------------------------------------------------# + # 5. Reporting + # -------------------------------------------------------------------------# + if not args.benchmark: + if args.enable_pytorch_run: + print_outputs("PyTorch", pyt_gen_tokens, processor.tokenizer) + print_outputs("TensorRT", trt_gen_tokens, processor.tokenizer) + + if args.enable_pytorch_run: + print( + f"PyTorch and TensorRT outputs match: " + f"{torch.equal(pyt_gen_tokens, trt_gen_tokens)}" + ) + + if args.benchmark: + if args.enable_pytorch_run: + print("========= PyTorch PERFORMANCE =========\n") + print(pyt_stats) + print("=====================\n") + print("========= TensorRT PERFORMANCE =========\n") + print(trt_stats) diff --git a/tools/llm/utils.py b/tools/llm/utils.py index 2c3434b0ed..5e188f0e8b 100644 --- a/tools/llm/utils.py +++ b/tools/llm/utils.py @@ -242,3 +242,464 @@ def record_stats(backend, timings, precision, batch_size=1, compile_time_s=None) "Compile Time(s)": compile_time_s, } return stats + + +def generate_mm( + model, + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, + max_output_seq_length: int, + eos_token_id: int, + emb_layer: torch.nn.Embedding, +): + """Greedy decode for Eagle2-style VLM. + + Parameters + ---------- + model : nn.Module + Must expose vision_model, mlp1, language_model, pixel_shuffle, downsample_ratio, image_token_index. + pixel_values : Tensor | None + Input image batch (B,C,H,W) or None. + input_ids : LongTensor (B, N_prompt) + Text prompt token ids including [IMG] placeholder(s). + max_output_seq_length : int + Maximum tokens to generate **in addition to** the prompt. + eos_token_id : int + Stop generation when all sequences emit EOS. + emb_layer : nn.Embedding + Embedding layer for input_ids. + """ + + vit_embeds = None + + if pixel_values is not None: + # --- Vision encoder timing --- + vis_s = torch.cuda.Event(enable_timing=True) + vis_e = torch.cuda.Event(enable_timing=True) + vis_s.record() + vit_out = model.vision_model(pixel_values) + vis_e.record() + torch.cuda.synchronize() + + vit_embeds = ( + vit_out.last_hidden_state + if hasattr(vit_out, "last_hidden_state") + else vit_out + ) + + h = w = int(vit_embeds.shape[1] ** 0.5) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) + vit_embeds = model.pixel_shuffle( + vit_embeds, scale_factor=model.downsample_ratio + ) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) + vit_embeds = model.mlp1(vit_embeds) + + # 2) Text token embeddings + seq_tokens = input_ids.clone() + seq_embeds = emb_layer(seq_tokens) + + if vit_embeds is not None: + B, N, C = seq_embeds.shape + flat_emb = seq_embeds.view(B * N, C) + + mask = seq_tokens.view(B * N) == model.image_token_index + try: + flat_emb[mask] = vit_embeds.reshape(-1, C).to(flat_emb.dtype)[: mask.sum()] + except Exception: + # Fallback in unlikely size-mismatch cases + flat_emb[mask] = vit_embeds.reshape(-1, C)[: mask.sum()].to(flat_emb.dtype) + seq_embeds = flat_emb.view(B, N, C) + + # ───────────────────────────────── Greedy loop ─────────────────────────────────────────────────── + isl = seq_tokens.shape[1] + osl = max_output_seq_length - isl + + generated = 0 + + while generated < osl: + cur_embeds = seq_embeds # full seq first step or cache off + position_ids = ( + torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) + ) + with torch.no_grad(): + logits = model.language_model(inputs_embeds=cur_embeds, position_ids=position_ids) + if hasattr(logits, "logits"): + logits = logits.logits + + next_tok = torch.argmax(logits[:, -1, :], dim=-1) # (B,) + # append token & embed + seq_tokens = torch.cat([seq_tokens, next_tok[:, None]], dim=-1) + seq_embeds = torch.cat([seq_embeds, emb_layer(next_tok)[:, None, :]], dim=1) + + generated += 1 + if (next_tok == eos_token_id).all(): + break + + return seq_tokens[:, input_ids.shape[1] :] + + +@torch.inference_mode() +def generate_mm_with_static_cache( + model, # Complete VLM module + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, # (B, N_prompt) + max_output_seq_length: int, + eos_token_id: int, + emb_layer: torch.nn.Embedding, + device: str = "cuda:0", +) -> torch.LongTensor: # (B, N_prompt + new) + """ + Greedy Decoder for multimodal VLM (using static KV-cache v1). + Basic structure is identical to LM version (generate_with_static_cache) but + * Input is `inputs_embeds` + * Vision tokens are sent together only in the first step + """ + + # ───────────────────── Vision encoding ───────────────────── + vit_embeds = None + if pixel_values is not None: + vit_latent = model.vision_model(pixel_values) + vit_embeds = ( + vit_latent.last_hidden_state + if hasattr(vit_latent, "last_hidden_state") + else vit_latent + ) + h = w = int(vit_embeds.shape[1] ** 0.5) + vit_embeds = vit_embeds.view(vit_embeds.size(0), h, w, -1) + vit_embeds = model.pixel_shuffle(vit_embeds, model.downsample_ratio) + vit_embeds = vit_embeds.view(vit_embeds.size(0), -1, vit_embeds.size(-1)) + vit_embeds = model.mlp1(vit_embeds) # (B, N_img, C) + + # ───────────────────── Text embedding & [IMG] replacement ───────────── + seq_tokens = input_ids.clone() # (B, N_txt) + seq_embeds = emb_layer(seq_tokens) # (B, N_txt, C) + + if vit_embeds is not None: + B, N, C = seq_embeds.shape + flat = seq_embeds.view(B * N, C) + mask = seq_tokens.view(B * N) == model.image_token_index + flat[mask] = vit_embeds.reshape(-1, C).to(flat.dtype)[: mask.sum()] + seq_embeds = flat.view(B, N, C) + + # ───────────────────── KV-cache initialization ───────────────────── + kv_cache = get_zeroed_static_cache_inputs( + model.language_model + ) + start_idx = 0 # First token index + end_idx = seq_embeds.size(1) # Prompt length + generated = 0 + max_total_len = max_output_seq_length + output_tokens = seq_tokens.clone() + + # ───────────────────── Greedy loop ─────────────────────── + while output_tokens.size(1) < max_total_len: + + # When using static cache: + # - First step: Use full prompt embedding + # - Subsequent steps: Use only new token embedding (KV cache remembers previous tokens) + cur_embeds = seq_embeds if generated == 0 else seq_embeds[:, -1:, :] + + # position_ids: Same pattern as generate_with_static_cache + # - First step: Position of entire sequence + # - Subsequent steps: Position of current token only + if generated == 0: + position_ids = ( + torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) + ) + else: + position_ids = torch.tensor([[start_idx]], dtype=torch.int64).to( + cur_embeds.device + ) + + # is_causal = True if cur_embeds.shape[1] > 1 else False + input_signature = ( + cur_embeds, + position_ids, + *kv_cache, + start_idx, + end_idx, + # is_causal, + ) + + logits_and_kv = model.language_model(*input_signature) + logits, kv_cache = logits_and_kv[0], logits_and_kv[1:] + + next_tok = logits[:, -1, :].argmax(dim=-1) # (B,) + output_tokens = torch.cat([output_tokens, next_tok[:, None]], dim=-1) + + # Prepare for next step - Static cache only needs new token + next_embed = emb_layer(next_tok)[:, None, :] # (B, 1, C) + seq_embeds = next_embed # Next step uses only new token + + generated += 1 + start_idx = end_idx + end_idx += 1 + # is_causal = True # Causal mask active from now on + + if (next_tok == eos_token_id).all(): + break + + return output_tokens + + +def generate_mm_with_timing( + model, + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, + max_output_seq_length: int, + eos_token_id: int, + emb_layer: torch.nn.Embedding, + use_cache: bool = False, +): + # Create timing events + overall_start = torch.cuda.Event(enable_timing=True) + overall_end = torch.cuda.Event(enable_timing=True) + vision_start = torch.cuda.Event(enable_timing=True) + vision_end = torch.cuda.Event(enable_timing=True) + mlp_start = torch.cuda.Event(enable_timing=True) + mlp_end = torch.cuda.Event(enable_timing=True) + lm_start = torch.cuda.Event(enable_timing=True) + lm_end = torch.cuda.Event(enable_timing=True) + + overall_start.record() + + vit_embeds = None + if pixel_values is not None: + vision_start.record() + vit_out = model.vision_model(pixel_values) + vision_end.record() + torch.cuda.synchronize() + vision_time = vision_start.elapsed_time(vision_end) + + vit_embeds = ( + vit_out.last_hidden_state + if hasattr(vit_out, "last_hidden_state") + else vit_out + ) + + mlp_start.record() + h = w = int(vit_embeds.shape[1] ** 0.5) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) + vit_embeds = model.pixel_shuffle( + vit_embeds, scale_factor=model.downsample_ratio + ) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) + vit_embeds = model.mlp1(vit_embeds) + mlp_end.record() + torch.cuda.synchronize() + mlp_time = mlp_start.elapsed_time(mlp_end) + + seq_tokens = input_ids.clone() + seq_embeds = emb_layer(seq_tokens) + + if vit_embeds is not None: + B, N, C = seq_embeds.shape + flat_emb = seq_embeds.view(B * N, C) + mask = seq_tokens.view(B * N) == model.image_token_index + flat_emb[mask] = vit_embeds.reshape(-1, C).to(flat_emb.dtype)[: mask.sum()] + seq_embeds = flat_emb.view(B, N, C) + + step_times = [] + generated = 0 + past_key_values = None + + while generated < max_output_seq_length: + lm_start.record() + cur_embeds = seq_embeds + position_ids = ( + torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) + ) + with torch.no_grad(): + logits = model.language_model( + inputs_embeds=cur_embeds, position_ids=position_ids + ) + if hasattr(logits, "logits"): + logits = logits.logits + + next_tok = torch.argmax(logits[:, -1, :], dim=-1) + lm_end.record() + torch.cuda.synchronize() + step_times.append(lm_start.elapsed_time(lm_end)) + + seq_tokens = torch.cat([seq_tokens, next_tok[:, None]], dim=-1) + seq_embeds = torch.cat([seq_embeds, emb_layer(next_tok)[:, None, :]], dim=1) + + generated += 1 + if (next_tok == eos_token_id).all(): + break + + overall_end.record() + torch.cuda.synchronize() + overall_time = overall_start.elapsed_time(overall_end) + + return seq_tokens, step_times, overall_time, vision_time, mlp_time + + +@torch.inference_mode() +def generate_mm_with_static_cache_timing( + model, # Complete VLM module + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, # (B, N_prompt) + eos_token_id: int, + emb_layer: torch.nn.Embedding, + max_new_tokens: int = 64, + device: str = "cuda:0", +) -> tuple: # (seq_tokens, step_times, overall_time, vision_time, mlp_time) + """ + Greedy Decoder for multimodal VLM (using static KV-cache v1) + detailed timing measurement. + + Returns: + seq_tokens: Generated token sequence + step_times: Language model inference time for each step (ms) + overall_time: Total execution time (ms) + vision_time: Vision encoding time (ms) + mlp_time: MLP processing time (ms) + """ + + # ───────────────────── Create timing events ───────────────────── + overall_start = torch.cuda.Event(enable_timing=True) + overall_end = torch.cuda.Event(enable_timing=True) + vision_start = torch.cuda.Event(enable_timing=True) + vision_end = torch.cuda.Event(enable_timing=True) + mlp_start = torch.cuda.Event(enable_timing=True) + mlp_end = torch.cuda.Event(enable_timing=True) + lm_start = torch.cuda.Event(enable_timing=True) + lm_end = torch.cuda.Event(enable_timing=True) + + overall_start.record() + + # ───────────────────── Vision encoding ───────────────────── + vit_embeds = None + vision_time = 0.0 + mlp_time = 0.0 + + if pixel_values is not None: + vision_start.record() + vit_latent = model.vision_model(pixel_values) + vision_end.record() + torch.cuda.synchronize() + vision_time = vision_start.elapsed_time(vision_end) + + vit_embeds = ( + vit_latent.last_hidden_state + if hasattr(vit_latent, "last_hidden_state") + else vit_latent + ) + + mlp_start.record() + h = w = int(vit_embeds.shape[1] ** 0.5) + vit_embeds = vit_embeds.view(vit_embeds.size(0), h, w, -1) + vit_embeds = model.pixel_shuffle(vit_embeds, model.downsample_ratio) + vit_embeds = vit_embeds.view(vit_embeds.size(0), -1, vit_embeds.size(-1)) + vit_embeds = model.mlp1(vit_embeds) # (B, N_img, C) + mlp_end.record() + torch.cuda.synchronize() + mlp_time = mlp_start.elapsed_time(mlp_end) + + # ───────────────────── Text embedding & [IMG] replacement ───────────── + seq_tokens = input_ids.clone() # (B, N_txt) + seq_embeds = emb_layer(seq_tokens) # (B, N_txt, C) + + if vit_embeds is not None: + B, N, C = seq_embeds.shape + flat = seq_embeds.view(B * N, C) + mask = seq_tokens.view(B * N) == model.image_token_index + flat[mask] = vit_embeds.reshape(-1, C).to(flat.dtype)[: mask.sum()] + seq_embeds = flat.view(B, N, C) + + # ───────────────────── KV-cache initialization ───────────────────── + kv_cache = get_zeroed_static_cache_inputs( + model.language_model + ) + start_idx = 0 # First token index + end_idx = seq_embeds.size(1) # Prompt length + generated = 0 + max_total_len = end_idx + max_new_tokens + output_tokens = seq_tokens.clone() + step_times = [] # Timing for each step + + # ───────────────────── Greedy loop ─────────────────────── + while output_tokens.size(1) < max_total_len: + lm_start.record() + + # When using static cache: + # - First step: Use full prompt embedding + # - Subsequent steps: Use only new token embedding (KV cache remembers previous tokens) + cur_embeds = seq_embeds if generated == 0 else seq_embeds[:, -1:, :] + + # position_ids: Same pattern as generate_with_static_cache + # - First step: Position of entire sequence + # - Subsequent steps: Position of current token only + if generated == 0: + position_ids = ( + torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) + ) + else: + position_ids = torch.tensor([[start_idx]], dtype=torch.int64).to( + cur_embeds.device + ) + + # is_causal = True if cur_embeds.shape[1] > 1 else False + input_signature = ( + cur_embeds, + position_ids, + *kv_cache, + start_idx, + end_idx, + # is_causal, + ) + + logits_and_kv = model.language_model(*input_signature) + logits, kv_cache = logits_and_kv[0], logits_and_kv[1:] + + next_tok = logits[:, -1, :].argmax(dim=-1) # (B,) + output_tokens = torch.cat([output_tokens, next_tok[:, None]], dim=-1) + + # Prepare for next step - Static cache only needs new token + next_embed = emb_layer(next_tok)[:, None, :] # (B, 1, C) + seq_embeds = next_embed # Next step uses only new token + + generated += 1 + start_idx = end_idx + end_idx += 1 + + lm_end.record() + torch.cuda.synchronize() + step_times.append(lm_start.elapsed_time(lm_end)) + + if (next_tok == eos_token_id).all(): + break + + overall_end.record() + torch.cuda.synchronize() + overall_time = overall_start.elapsed_time(overall_end) + + return output_tokens, step_times, overall_time, vision_time, mlp_time + + +def time_generate_mm( + generate_fn, + model, + pixel_values, + input_ids, + output_seq_length, + eos_token_id, + emb_layer, + iterations=10, + device="cuda:0", +): + """ + Measure the time for generating a sentence over certain number of iterations + """ + timings = [] + for _ in range(iterations): + start_time = timeit.default_timer() + _ = generate_fn( + model, pixel_values, input_ids, output_seq_length, eos_token_id, emb_layer + ) + torch.cuda.synchronize() + end_time = timeit.default_timer() + timings.append(end_time - start_time) + + return timings