[kv_cache] integrated vlm code for benchmark (Stacked on #3527) #3652

Draft · wants to merge 1 commit into base: kv_cache
387 changes: 387 additions & 0 deletions tools/llm/run_vlm.py
@@ -0,0 +1,387 @@
"""
.. _run_vlm:

Running VLM inference with Torch-TensorRT
==========================================================

This script mirrors the style and structure of *run_llm.py*, illustrating a
Torch-TensorRT (dynamo backend) workflow for Vision-Language Models (VLMs).
"""

import argparse
import copy
import os
import sys
from contextlib import nullcontext
from typing import Tuple

import requests
import torch
import torch_tensorrt
from PIL import Image
from torchtrt_ext import register_sdpa
from transformers import AutoModel, AutoProcessor
from utils import (
    generate_mm,
    generate_mm_with_static_cache,
    record_stats,
    time_generate_mm,
)

# -----------------------------------------------------------------------------#
# Global configuration
# -----------------------------------------------------------------------------#
DEVICE = torch.device("cuda:0")

# Register SDPA as a standalone operator; converter & lowering pass are defined in register_sdpa.
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
import transformers.models.qwen2.modeling_qwen2 as mq # noqa: E402

mq.ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = mq.ALL_ATTENTION_FUNCTIONS["sdpa"]
Collaborator: Can you try to modify the model config after it is initialized instead of modifying the entries?

model.config._attn_implementation = "sdpa"
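A minimal sketch of that suggestion (assuming the remote Eagle2 code honors an _attn_implementation value set after from_pretrained):

    # Reviewer's suggestion, sketched: select SDPA on the initialized model's
    # config instead of patching mq.ALL_ATTENTION_FUNCTIONS.
    model = AutoModel.from_pretrained(
        model_id, trust_remote_code=True, torch_dtype=torch_dtype
    )
    model.config._attn_implementation = "sdpa"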


# -----------------------------------------------------------------------------#
# Model loading helpers
# -----------------------------------------------------------------------------#


def _load_eagle2(device: torch.device, torch_dtype: torch.dtype):
    """
    Load Eagle2 model and processor.

    Returns
    -------
    tuple[torch.nn.Module, transformers.AutoProcessor, torch.nn.Embedding]
        The model, its processor, and the language-model input embedding layer.
    """
    model_id = "nvidia/Eagle2-2B"
    with torch.no_grad():
        model = (
            AutoModel.from_pretrained(
                model_id, trust_remote_code=True, torch_dtype=torch_dtype
            )
            .eval()
            .to(device)
        )

    processor = AutoProcessor.from_pretrained(
        model_id, trust_remote_code=True, use_fast=True
    )
    if hasattr(processor, "tokenizer"):
        processor.tokenizer.padding_side = "left"

    emb_layer = model.language_model.get_input_embeddings().to(torch_dtype).to(device)
    return model, processor, emb_layer


def _load_model(
    model_name: str, device: torch.device, torch_dtype: torch.dtype
) -> Tuple[torch.nn.Module, AutoProcessor, torch.nn.Embedding]:
    """Dispatch helper for supported VLMs."""
    if model_name.lower() == "eagle2":
        return _load_eagle2(device, torch_dtype)
    msg = f"Unsupported model: {model_name}"
Collaborator: Consider modifying the message to: Encountered model: {model_name} is not supported. Supported models include nvidia/Eagle2-2B, PaliGemma2 (and others, if they are supported).
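Sketched, the suggested message might read (the supported-model tuple is illustrative):

    SUPPORTED_VLMS = ("nvidia/Eagle2-2B",)  # illustrative; extend as models land
    msg = (
        f"Encountered model: {model_name} is not supported. "
        f"Supported models include: {', '.join(SUPPORTED_VLMS)}"
    )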

    raise ValueError(msg)


# -----------------------------------------------------------------------------#
# Torch-TensorRT compilation helpers
# -----------------------------------------------------------------------------#


class _LMNoCache(torch.nn.Module):
    """
    Thin wrapper that exposes a language model via ``inputs_embeds`` without KV-cache.
    """

    def __init__(self, lm):
        super().__init__()
        self.lm = lm

    def forward(self, inputs_embeds, position_ids):
        out = self.lm(inputs_embeds=inputs_embeds, position_ids=position_ids)
        return out.logits if hasattr(out, "logits") else out


def _compile_eagle2_lm(
    language_model: torch.nn.Module,
    input_embeds: torch.Tensor,
    args: argparse.Namespace,
) -> torch.nn.Module:
    """
    Compile Eagle2 language model with Torch-TensorRT.

    The function follows the same precision-specific flag logic used in
    *run_llm.py* for consistency.
    """
    lm_wrap = _LMNoCache(language_model).to(DEVICE).eval()
    max_seq_len = input_embeds.shape[1] + args.num_tokens

    S = torch.export.Dim("seq", min=1, max=max_seq_len)
Collaborator: Consider renaming S to seq_len.
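I.e., a pure rename carried through to the dynamic-shape spec:

    seq_len = torch.export.Dim("seq", min=1, max=max_seq_len)
    dyn_shapes = {"inputs_embeds": {1: seq_len}, "position_ids": {1: seq_len}}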

    position_ids = torch.arange(input_embeds.shape[1]).unsqueeze(0).to(DEVICE)
    dyn_shapes = {"inputs_embeds": {1: S}, "position_ids": {1: S}}

    # Precision-specific flags ------------------------------------------------#
    use_fp32_acc = False
    use_explicit_typing = False
    if args.precision == "FP16":
        enabled_precisions = {torch.float32}
        use_fp32_acc = True
        use_explicit_typing = True
    elif args.precision == "BF16":
        enabled_precisions = {torch.bfloat16}
    else:  # FP32
        enabled_precisions = {torch.float32}

    with torch.inference_mode():
        exported = torch.export.export(
Collaborator: Minor comment: consider renaming exported to exported_program here.
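I.e.:

    exported_program = torch.export.export(
        lm_wrap, (input_embeds, position_ids), dynamic_shapes=dyn_shapes, strict=False
    )

with exported_program then passed to torch_tensorrt.dynamo.compile below.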

            lm_wrap,
            (input_embeds, position_ids),
            dynamic_shapes=dyn_shapes,
            strict=False,
        )

    with torch_tensorrt.logging.debug() if args.debug else nullcontext():
        trt_mod = torch_tensorrt.dynamo.compile(
            exported,
            inputs=[input_embeds, position_ids],
            enabled_precisions=enabled_precisions,
            use_explicit_typing=use_explicit_typing,
            use_fp32_acc=use_fp32_acc,
            device=DEVICE,
            disable_tf32=True,
            use_python_runtime=True,
            debug=args.debug,
            offload_module_to_cpu=True,
            min_block_size=args.min_block_size,
        )
    return trt_mod


def compile_torchtrt(
    model: torch.nn.Module, args: argparse.Namespace
) -> torch.nn.Module:
    """
    Front-end dispatcher mirroring *run_llm.py*'s ``compile_torchtrt``.

    Depending on the target VLM, delegates to the appropriate compile routine.
Collaborator (on lines +164 to +166): Consider making this docstring more informative. It looks like this function only compiles the LLM part of the VLM. Please mention that.
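One possible wording, sketched (the dispatch below does pass only model.language_model to the compile routine):

    Compile the language-model component of the given VLM with Torch-TensorRT.
    Only model.language_model is compiled; the vision encoder stays in PyTorch.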

"""
torch_dtype = {
"FP16": torch.float16,
"BF16": torch.bfloat16,
}.get(args.precision, torch.float32)

example_embeds = torch.randn(
1,
2560,
Collaborator: Consider declaring this in a map, e.g. image_tokens_map = {model_name: 2560}, or as a named constant, e.g. EAGLE_IMG_TOKENS = 2560.
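A sketch of the map variant (the key format is illustrative):

    # Hypothetical lookup: number of multimodal embedding tokens per model
    image_tokens_map = {"eagle2": 2560}
    example_embeds = torch.randn(
        1,
        image_tokens_map[args.model.lower()],
        model.language_model.config.hidden_size,
        dtype=torch_dtype,
        device=DEVICE,
    )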

        model.language_model.config.hidden_size,
        dtype=torch_dtype,
        device=DEVICE,
    )

    if args.model.lower() == "eagle2":
        return _compile_eagle2_lm(model.language_model, example_embeds, args)

    msg = f"Unsupported model for compilation: {args.model}"
Collaborator: Consider changing this message similarly to the comment above (by including the list of supported models).

    raise ValueError(msg)


# -----------------------------------------------------------------------------#
# Utility helpers
# -----------------------------------------------------------------------------#


def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer):
    """Pretty-print generated text for comparison."""
    print(f"========= {backend_name} =========")
    print(
        f"{backend_name} model generated text: ",
        tokenizer.decode(gen_tokens[0], skip_special_tokens=True),
    )
    print("===================================")


# -----------------------------------------------------------------------------#
# Main driver
# -----------------------------------------------------------------------------#
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run VLM inference (PyTorch & TensorRT back-ends)"
    )
    parser.add_argument("--model", default="eagle2", help="VLM model name")
Collaborator: Let the model names be exactly the same as the HF model names instead of using short forms. E.g., the default should be nvidia/Eagle2-2B.
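Under that convention, the flag might become (sketch):

    parser.add_argument(
        "--model", default="nvidia/Eagle2-2B", help="HF model id of the VLM"
    )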

parser.add_argument("--prompt", default="Describe this image.", help="Prompt text")
parser.add_argument(
"--precision",
default="FP16",
choices=["FP16", "BF16", "FP32"],
help="Computation precision",
)
parser.add_argument("--iterations", type=int, default=5, help="# iterations")
parser.add_argument("--min_block_size", type=int, default=1, help="Min block size")
parser.add_argument("--num_tokens", type=int, default=128, help="# new tokens")
parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
parser.add_argument("--isl", type=int, default=2048, help="Input seq length")
parser.add_argument(
"--enable_pytorch_run",
action="store_true",
help="Run the PyTorch baseline as well",
)
parser.add_argument(
"--cache",
default="",
choices=["", "static_v1"],
help="KV-cache variant to use",
)
parser.add_argument(
"--debug", action="store_true", help="Enable Torch-TensorRT debug logs"
)
parser.add_argument(
"--benchmark", action="store_true", help="Enable benchmarking mode"
)

args = parser.parse_args()

    # -------------------------------------------------------------------------#
    # 1. Model / processor / embeddings
    # -------------------------------------------------------------------------#
    dtype = {
        "FP16": torch.float16,
        "BF16": torch.bfloat16,
    }.get(args.precision, torch.float32)

    model, processor, emb_layer = _load_model(args.model, DEVICE, dtype)

    # -------------------------------------------------------------------------#
    # 2. Input construction (image + text prompt)
    # -------------------------------------------------------------------------#
    url = "https://cdn.pixabay.com/photo/2019/08/08/23/33/car-4393990_1280.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
Collaborator (on lines +256 to +257): Move this to a load image function.
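A sketch of such a helper (the name _load_image is hypothetical):

    def _load_image(url: str) -> Image.Image:
        """Fetch the sample image used in the prompt."""
        return Image.open(requests.get(url, stream=True).raw)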


    if args.benchmark:
        prompt_len = args.isl - 1792 - 26
Collaborator: What are 1792 and 26? Please specify them as variables and add comments indicating what they are.
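A sketch with the magic numbers named; reading 1792 as image tokens and 26 as chat-template overhead is an assumption the author should confirm:

    # Assumptions (to be confirmed): 1792 tokens come from the image,
    # 26 from the chat template; both are subtracted so the total matches --isl.
    EAGLE2_IMAGE_TOKENS = 1792
    CHAT_TEMPLATE_OVERHEAD = 26
    prompt_len = args.isl - EAGLE2_IMAGE_TOKENS - CHAT_TEMPLATE_OVERHEAD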

prompt_txt = " ".join(["token"] * max(prompt_len, 0))
else:
prompt_txt = args.prompt

messages = [
{
"role": "user",
"content": [
{"type": "image", "image": image},
{"type": "text", "text": prompt_txt},
],
}
]

txt = [
processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
]
img_in, vid_in = processor.process_vision_info(messages)
inputs = processor(
text=txt, images=img_in, videos=vid_in, return_tensors="pt", padding=True
).to(DEVICE)

max_output_len = inputs["input_ids"].shape[1] + args.num_tokens

    # -------------------------------------------------------------------------#
    # 3. Optional: PyTorch baseline
    # -------------------------------------------------------------------------#
    pyt_gen_tokens = pyt_timings = pyt_stats = None
    if args.enable_pytorch_run:
        pyt_gen_tokens = generate_mm(
            model,
            inputs["pixel_values"],
            inputs["input_ids"],
            max_output_len,
            processor.tokenizer.eos_token_id,
            emb_layer,
        )
        if args.benchmark:
            pyt_timings = time_generate_mm(
                generate_mm,
                model,
                inputs["pixel_values"].clone(),
                inputs["input_ids"].clone(),
                max_output_len,
                processor.tokenizer.eos_token_id,
                emb_layer,
                iterations=args.iterations,
            )
            pyt_stats = record_stats(
                "PyTorch",
                pyt_timings,
                args.precision,
                batch_size=args.batch_size,
                compile_time_s=None,
            )

    # Register static cache lowering passes if requested
    if args.cache == "static_v1":
Collaborator: Is static_v2 not working?

        import static_cache_v1  # noqa: F401

    # -------------------------------------------------------------------------#
    # 4. Torch-TensorRT compile & run
    # -------------------------------------------------------------------------#
    trt_lm = compile_torchtrt(model, args)
    trt_model = copy.deepcopy(model)
    trt_model.language_model = trt_lm

    emb_layer = emb_layer.to(DEVICE)

    if args.cache == "static_v1":
        trt_generate = generate_mm_with_static_cache
    else:
        trt_generate = generate_mm

    trt_gen_tokens = trt_generate(
        trt_model,
        inputs["pixel_values"],
        inputs["input_ids"],
        max_output_len,
        processor.tokenizer.eos_token_id,
        emb_layer,
        DEVICE if args.cache == "static_v1" else None,  # device arg only for static_v1
    )

    if args.benchmark:
        trt_timings = time_generate_mm(
            trt_generate,
            trt_model,
            inputs["pixel_values"].clone(),
            inputs["input_ids"].clone(),
            max_output_len,
            processor.tokenizer.eos_token_id,
            emb_layer,
            iterations=args.iterations,
            device=DEVICE if args.cache == "static_v1" else None,
        )
        trt_stats = record_stats(
            "TensorRT",
            trt_timings,
            args.precision,
            batch_size=args.batch_size,
            compile_time_s=None,
        )

    # -------------------------------------------------------------------------#
    # 5. Reporting
    # -------------------------------------------------------------------------#
    if not args.benchmark:
        if args.enable_pytorch_run:
            print_outputs("PyTorch", pyt_gen_tokens, processor.tokenizer)
        print_outputs("TensorRT", trt_gen_tokens, processor.tokenizer)

        if args.enable_pytorch_run:
            print(
                "PyTorch and TensorRT outputs match: "
                f"{torch.equal(pyt_gen_tokens, trt_gen_tokens)}"
            )

    if args.benchmark:
        if args.enable_pytorch_run:
            print("========= PyTorch PERFORMANCE =========\n")
            print(pyt_stats)
            print("=====================\n")
        print("========= TensorRT PERFORMANCE =========\n")
        print(trt_stats)