
Optimize internvit #3316

Draft: wants to merge 5 commits into dev

Conversation

@caikun-pjlab commented Mar 24, 2025

Introduction

Optimize InternViT.

Serving benchmark results on A100 using 4 GPUs, obtained with python benchmark/profile_restful_api.py --backend lmdeploy --base-url http://0.0.0.0:23333 --dataset-name sharegpt --dataset-path /workspace/caikun/benchmark/ShareGPT_V3_unfiltered_cleaned_split.json --sharegpt-output-len 4 --num-prompts 500 --model OpenGVLab/InternVL2_5-78B

The benchmark input image is https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg.

Roughly a 13% improvement in the prefill phase (input token throughput: 305.35 → 346.02 tok/s).

Optimization methods

  1. Use torch.compile, which fuses small kernels and uses Triton to generate efficient mm and conv operators (this also saves runtime memory); see the first sketch after this list.
  2. Split the batch into micro-batches so communication overlaps with computation; see the second sketch after this list.
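A minimal sketch of the first method: compile the vision encoder once and mark input dimensions dynamic so shape changes between requests do not trigger recompiles. The wrapper class and all names in it are illustrative assumptions, not this PR's actual code:

import torch

class CompiledViT(torch.nn.Module):
    # Hypothetical wrapper around the vision model (illustration only).

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.compiled = False

    def forward(self, **kwargs):
        if not self.compiled:
            # Inductor fuses small elementwise kernels and emits Triton
            # code for mm/conv, which also reduces runtime memory.
            self.model = torch.compile(self.model)
            for tensor in kwargs.values():
                if isinstance(tensor, torch.Tensor):
                    # Mark dim 0 dynamic so a varying number of image
                    # tiles does not force a recompile.
                    torch._dynamo.mark_dynamic(tensor, 0)
            self.compiled = True
        return self.model(**kwargs)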
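And a sketch of the second method: under tensor parallelism, splitting the batch lets the all-reduce of one chunk run while the next chunk is still computing. block, hidden_states, and micro_batch_size are assumed names for illustration; an initialized process group is also assumed:

import torch
import torch.distributed as dist

def forward_micro_batched(block, hidden_states, micro_batch_size=1):
    # Illustration only: overlap each chunk's all-reduce with the
    # compute of the following chunk.
    outputs, handles = [], []
    for chunk in hidden_states.split(micro_batch_size, dim=0):
        out = block(chunk)  # compute for this chunk
        # async_op=True returns a work handle immediately, so the next
        # chunk's compute overlaps with this chunk's communication.
        handles.append(dist.all_reduce(out, async_op=True))
        outputs.append(out)
    for handle in handles:
        handle.wait()
    return torch.cat(outputs, dim=0)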

Performance before optimization

============ Serving Benchmark Result ============
Backend:                                 lmdeploy
Traffic request rate:                    inf
Successful requests:                     500
Benchmark duration (s):                  394.35
Total input tokens:                      120414
Total generated tokens:                  2000
Total generated tokens (retokenized):    2000
Request throughput (req/s):              1.27
Input token throughput (tok/s):          305.35
Output token throughput (tok/s):         5.07
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   248414.83
Median E2E Latency (ms):                 221409.16
---------------Time to First Token----------------
Mean TTFT (ms):                          199967.16
Median TTFT (ms):                        200245.25
P99 TTFT (ms):                           392125.15
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          16149.22
Median TPOT (ms):                        15288.87
P99 TPOT (ms):                           35295.47
---------------Inter-token Latency----------------
Mean ITL (ms):                           26046.99
Median ITL (ms):                         3548.60
P99 ITL (ms):                            103363.55
==================================================

Performance after optimization

============ Serving Benchmark Result ============
Backend:                                 lmdeploy
Traffic request rate:                    inf
Successful requests:                     500
Benchmark duration (s):                  347.99
Total input tokens:                      120414
Total generated tokens:                  2000
Total generated tokens (retokenized):    2000
Request throughput (req/s):              1.44
Input token throughput (tok/s):          346.02
Output token throughput (tok/s):         5.75
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   214875.82
Median E2E Latency (ms):                 189645.27
---------------Time to First Token----------------
Mean TTFT (ms):                          172406.40
Median TTFT (ms):                        170648.01
P99 TTFT (ms):                           344925.42
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          14156.47
Median TPOT (ms):                        13720.74
P99 TPOT (ms):                           30652.67
---------------Inter-token Latency----------------
Mean ITL (ms):                           22857.41
Median ITL (ms):                         3782.86
P99 ITL (ms):                            90382.94
==================================================

caikun-pjlab marked this pull request as draft on March 24, 2025, 09:27
if tensor is None:
    continue
# Mark these dims dynamic so shape changes between requests do not
# force torch.compile to recompile the graph.
torch._dynamo.mark_dynamic(tensor, dynamic_dims)
self.compiled = True
Collaborator

The code is a little confusing (TorchCompile mode means not enable_graph, yet it enters the not enable_graph branch). I think the following is more readable:

if self.compiled:
    self.model(**kwargs)  # or something like self.compiled_model(**kwargs)
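In context, the dispatch might then read (a sketch assuming a compiled flag set during initialization; compiled_model is a hypothetical attribute):

def forward(self, **kwargs):
    if self.compiled:
        return self.compiled_model(**kwargs)
    return self.model(**kwargs)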

):
    """forward."""
    hidden_states = hidden_states + self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1
def enable_micro_batch(func):
Collaborator


Move this function to a common file, like utils.py?
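For reference, one guess at what such a decorator could look like: split the first tensor argument along the batch dimension, run the wrapped forward per chunk, and concatenate. This is a hypothetical sketch (micro_batch_size is an assumed attribute), not the PR's implementation:

from functools import wraps

import torch

def enable_micro_batch(func):
    # Hypothetical sketch: run func per micro-batch and concatenate.
    @wraps(func)
    def wrapper(self, x, *args, **kwargs):
        size = getattr(self, 'micro_batch_size', None)
        if size is None or x.size(0) <= size:
            return func(self, x, *args, **kwargs)
        chunks = [func(self, c, *args, **kwargs)
                  for c in x.split(size, dim=0)]
        return torch.cat(chunks, dim=0)
    return wrapper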

caikun-pjlab changed the base branch from main to dev on March 26, 2025, 11:43