[kv_cache] integrated vlm code for benchmark (Stacked on #3527) #3652
base: kv_cache
Conversation
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/tools/llm/utils.py 2025-07-03 01:46:24.189295+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/llm/utils.py 2025-07-03 01:46:57.661981+00:00
@@ -318,14 +318,16 @@
generated = 0
while generated < osl:
cur_embeds = seq_embeds # full seq first step or cache off
position_ids = (
- torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device)
- )
+ torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device)
+ )
with torch.no_grad():
- logits = model.language_model(inputs_embeds=cur_embeds, position_ids=position_ids)
+ logits = model.language_model(
+ inputs_embeds=cur_embeds, position_ids=position_ids
+ )
if hasattr(logits, "logits"):
logits = logits.logits
next_tok = torch.argmax(logits[:, -1, :], dim=-1) # (B,)
# append token & embed
@@ -381,13 +383,11 @@
mask = seq_tokens.view(B * N) == model.image_token_index
flat[mask] = vit_embeds.reshape(-1, C).to(flat.dtype)[: mask.sum()]
seq_embeds = flat.view(B, N, C)
# ───────────────────── KV-cache initialization ─────────────────────
- kv_cache = get_zeroed_static_cache_inputs(
- model.language_model
- )
+ kv_cache = get_zeroed_static_cache_inputs(model.language_model)
start_idx = 0 # First token index
end_idx = seq_embeds.size(1) # Prompt length
generated = 0
max_total_len = max_output_seq_length
output_tokens = seq_tokens.clone()
@@ -607,13 +607,11 @@
mask = seq_tokens.view(B * N) == model.image_token_index
flat[mask] = vit_embeds.reshape(-1, C).to(flat.dtype)[: mask.sum()]
seq_embeds = flat.view(B, N, C)
# ───────────────────── KV-cache initialization ─────────────────────
- kv_cache = get_zeroed_static_cache_inputs(
- model.language_model
- )
+ kv_cache = get_zeroed_static_cache_inputs(model.language_model)
start_idx = 0 # First token index
end_idx = seq_embeds.size(1) # Prompt length
generated = 0
max_total_len = end_idx + max_new_tokens
output_tokens = seq_tokens.clone()
Added an initial set of review comments. I'll try to run this example and add more comments later. Where is the vision model being compiled here?
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
import transformers.models.qwen2.modeling_qwen2 as mq  # noqa: E402

mq.ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = mq.ALL_ATTENTION_FUNCTIONS["sdpa"]
Can you try to modify the model config after it is initialized instead of patching the ALL_ATTENTION_FUNCTIONS entries? For example:
model.config._attn_implementation = "sdpa"
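A minimal sketch of that approach, assuming the model is loaded via `AutoModel.from_pretrained` (the exact loader call and arguments in this PR may differ):

```python
import torch
from transformers import AutoModel

# Hypothetical load call; the real loader in this PR may differ.
model = AutoModel.from_pretrained(
    "nvidia/Eagle2-2B",
    trust_remote_code=True,
    torch_dtype=torch.float16,
)
# Switch the attention backend on the config instead of globally patching
# mq.ALL_ATTENTION_FUNCTIONS.
model.config._attn_implementation = "sdpa"
```

Newer transformers versions also accept `attn_implementation="sdpa"` directly as a `from_pretrained` keyword, which avoids touching the private attribute.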
"""Dispatch helper for supported VLMs.""" | ||
if model_name.lower() == "eagle2": | ||
return _load_eagle2(device, torch_dtype) | ||
msg = f"Unsupported model: {model_name}" |
Consider modifying the message to something like "Encountered model: {model_name} is not supported. Supported models include nvidia/Eagle2-2B, PaliGemma2" (and any others that are actually supported).
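For example, slotting into the dispatch helper quoted above (the `raise` is an assumption; only the `msg` assignment is visible in the snippet):

```python
# Illustrative wording only; extend the list as more VLMs are supported.
msg = (
    f"Encountered model: {model_name} is not supported. "
    "Supported models include: nvidia/Eagle2-2B"
)
raise ValueError(msg)
```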
lm_wrap = _LMNoCache(language_model).to(DEVICE).eval()
max_seq_len = input_embeds.shape[1] + args.num_tokens

S = torch.export.Dim("seq", min=1, max=max_seq_len)
Consider renaming `S` to `seq_len`.
enabled_precisions = {torch.float32}

with torch.inference_mode():
    exported = torch.export.export(
Minor comment: consider renaming `exported` to `exported_program` here.
example_embeds = torch.randn(
    1,
    2560,
Consider declaring this in a map, e.g. `image_tokens_map = {model_name: 2560}`, or at least as a named constant such as `EAGLE_IMG_TOKENS = 2560`.
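A sketch of that suggestion; the 2560 value is the one hard-coded above, while the map key and hidden size are assumptions:

```python
import torch

# Image-placeholder embedding counts per supported VLM. Only the Eagle2 value
# is taken from this PR; any further entries would need to be verified.
IMAGE_TOKENS_MAP = {"nvidia/Eagle2-2B": 2560}

model_name = "nvidia/Eagle2-2B"
hidden_size = 1536  # assumption: replace with the embedding dim the real code uses

example_embeds = torch.randn(1, IMAGE_TOKENS_MAP[model_name], hidden_size)
```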
Front-end dispatcher mirroring *run_llm.py*'s `compile_torchtrt`.

Depending on the target VLM, delegates to the appropriate compile routine.
Consider making this docstring more informative. It looks like this function only compiles the LLM part of the VLM. Please mention that.
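Something along these lines, for instance (the signature and wording are illustrative, not the PR's actual code):

```python
def compile_torchtrt(model, input_embeds, args):
    """Compile the language-model component of a VLM with Torch-TensorRT.

    Only ``model.language_model`` is lowered to TensorRT; the vision encoder
    and any projection layers remain as eager PyTorch modules. Dispatches to a
    model-specific compile routine based on ``args.model``.
    """
    ...
```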
parser = argparse.ArgumentParser(
    description="Run VLM inference (PyTorch & TensorRT back-ends)"
)
parser.add_argument("--model", default="eagle2", help="VLM model name")
Let the model names be exactly the same as the HF model names instead of using short forms, e.g. the default should be nvidia/Eagle2-2B.
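For example (a sketch; the choices list would grow as more VLMs are added):

```python
import argparse

parser = argparse.ArgumentParser(
    description="Run VLM inference (PyTorch & TensorRT back-ends)"
)
parser.add_argument(
    "--model",
    default="nvidia/Eagle2-2B",
    choices=["nvidia/Eagle2-2B"],
    help="Hugging Face model id of the VLM",
)
```

Using `choices` also gives users an immediate argparse error for unsupported names instead of a failure later in the loader.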
url = "https://cdn.pixabay.com/photo/2019/08/08/23/33/car-4393990_1280.jpg"
image = Image.open(requests.get(url, stream=True).raw)
Move this into a load-image helper function.
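A possible shape for that helper (the function name and timeout are illustrative):

```python
import requests
from PIL import Image


def load_image(url: str) -> Image.Image:
    """Download an image over HTTP and return it as a PIL image."""
    response = requests.get(url, stream=True, timeout=30)
    response.raise_for_status()
    return Image.open(response.raw)


image = load_image(
    "https://cdn.pixabay.com/photo/2019/08/08/23/33/car-4393990_1280.jpg"
)
```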
image = Image.open(requests.get(url, stream=True).raw)

if args.benchmark:
    prompt_len = args.isl - 1792 - 26
What are 1792 and 26? Please declare them as named variables and add comments indicating what they represent.
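One way to make this self-documenting, slotting into the snippet above; the names below are only guesses at what 1792 and 26 stand for and would need to be confirmed by the author:

```python
# Assumption: 1792 is the number of image-placeholder tokens the Eagle2
# processor injects, and 26 is the fixed prompt/chat-template overhead.
# Both names are guesses and should be verified against the actual tokenization.
NUM_IMAGE_TOKENS = 1792
PROMPT_TEMPLATE_OVERHEAD = 26

if args.benchmark:
    prompt_len = args.isl - NUM_IMAGE_TOKENS - PROMPT_TEMPLATE_OVERHEAD
```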
)

# Register static cache lowering passes if requested
if args.cache == "static_v1":
Is static_v2 not working?
Description
Base branch: kv_cache (PR #3527)
Type of change
Please delete options that are not relevant and/or add your own.
Checklist: