Optimize internvit #3316

Open · wants to merge 12 commits into base: dev
1 change: 1 addition & 0 deletions .github/workflows/unit-test.yml
@@ -72,6 +72,7 @@ jobs:
-DUSE_NVTX=ON \
-DSM=80 \
-DCMAKE_CUDA_ARCHITECTURES=80 \
-DCMAKE_POLICY_VERSION_MINIMUM=3.5 \
-DBUILD_TEST=OFF
make -j$(nproc) && make install
- name: Install lmdeploy
1 change: 1 addition & 0 deletions generate.sh
@@ -14,4 +14,5 @@ cmake ${builder} .. \
-DBUILD_PY_FFI=ON \
-DBUILD_MULTI_GPU=ON \
-DCMAKE_CUDA_FLAGS="-lineinfo" \
-DCMAKE_POLICY_VERSION_MINIMUM=3.5 \
-DUSE_NVTX=ON
15 changes: 15 additions & 0 deletions lmdeploy/pytorch/backends/cuda/graph_runner.py
@@ -3,6 +3,7 @@

import torch

from lmdeploy.pytorch.backends.selector import get_backend
from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
from lmdeploy.pytorch.model_inputs import StepContext
from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMeta
@@ -116,6 +117,7 @@ def __init__(self, model: torch.nn.Module, model_config: ModelConfig, cache_conf

self.graph_pool_handle = torch.cuda.graph_pool_handle()
self._runner_map: Dict[Any, CUDASingleGraphRunner] = dict()
self.has_try_compile_model: bool = False

def check_enable_graph(self):
"""check enable graph."""
@@ -124,6 +126,16 @@ def check_enable_graph(self):

return getattr(self.model, 'support_cuda_graph', _false)

def _try_compile_model_once(self):
if self.has_try_compile_model:
return

if hasattr(self.model, 'compile_model'):
method = getattr(self.model, 'compile_model')
method()

self.has_try_compile_model = True

def get_graph_key(self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: List,
attn_metadata: Any, inputs_embeds: torch.Tensor, **kwargs):
"""get graph key."""
@@ -135,6 +147,9 @@

def __call__(self, **kwargs):
"""call."""
if not self.backend_config.eager_mode and get_backend().get_name() == 'cuda':
self._try_compile_model_once()

enable_graph = self.enable_graph(**kwargs)

if not enable_graph:
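For orientation, the following is a minimal sketch of the lazy one-shot compile pattern this hook introduces. `ToyModel` and `LazyCompileRunner` are illustrative stand-ins, not code from this PR: a model opts in by exposing an optional `compile_model()` method, and the runner calls it at most once, on the first call that runs outside eager mode (the real runner additionally requires the active backend to be CUDA, via `get_backend().get_name() == 'cuda'`).

import torch


class ToyModel(torch.nn.Module):
    """Opts in to lazy compilation by exposing a `compile_model` hook."""

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def compile_model(self):
        # Compile the hot path once; later calls reuse the compiled artifact.
        self.forward = torch.compile(self.forward)

    def forward(self, x):
        return self.proj(x)


class LazyCompileRunner:
    """Stand-in for the runner-side logic: compile on the first non-eager call."""

    def __init__(self, model, eager_mode=False):
        self.model = model
        self.eager_mode = eager_mode
        self.has_try_compile_model = False

    def _try_compile_model_once(self):
        if self.has_try_compile_model:
            return
        if hasattr(self.model, 'compile_model'):
            self.model.compile_model()
        self.has_try_compile_model = True

    def __call__(self, x):
        if not self.eager_mode:
            self._try_compile_model_once()
        return self.model(x)


runner = LazyCompileRunner(ToyModel())
out = runner(torch.randn(2, 8))  # the first call triggers compile_model() exactly once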
50 changes: 44 additions & 6 deletions lmdeploy/pytorch/models/internvl.py
@@ -4,11 +4,14 @@

import torch
import torch.nn.functional as F
from packaging import version
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.distributed import get_world_rank
from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.models.utils.micro_batch import enable_micro_batch, split_batch
from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor
from lmdeploy.pytorch.nn import LayerNorm, RMSNorm
from lmdeploy.pytorch.nn.linear import build_colwise_linear, build_qkv_proj, build_rowwise_linear
@@ -205,15 +208,22 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
self.ls1 = nn.Parameter(torch.empty(self.embed_dim, dtype=dtype, device=device))
self.ls2 = nn.Parameter(torch.empty(self.embed_dim, dtype=dtype, device=device))

def forward(
self,
hidden_states: torch.Tensor,
):
"""forward."""
hidden_states = hidden_states + self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1
@enable_micro_batch(param_name='hidden_states', index=0)
def _attn(self, hidden_states):
hidden_states = hidden_states + self.attn(self.norm1(hidden_states).to(hidden_states[0].dtype)) * self.ls1
return hidden_states

@enable_micro_batch(param_name='hidden_states', index=0)
def _mlp(self, hidden_states):
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2
return hidden_states

def forward(
self,
hidden_states,
):
hidden_states = self._attn(hidden_states)
hidden_states = self._mlp(hidden_states)
return hidden_states


@@ -306,6 +316,33 @@ def __init__(self,

self.input_processor = InternVLInputProcessor(self.config, dtype)

self.compile_vit = False

def compile_model(self):
torch_version = version.parse(torch.__version__)
if torch_version < version.parse('2.5.0'):
return

world_size, _ = get_world_rank()
if torch_version >= version.parse('2.6.0') and world_size > 1:
torch._inductor.config.reorder_for_compute_comm_overlap = True
if isinstance(self.vision_model, InternVisionModel):
self.vision_model.encoder.forward = split_batch(self.vision_model.encoder.forward,
'inputs_embeds',
index=0)

self.extract_feature = torch.compile(self.extract_feature, mode='max-autotune')
self.compile_vit = True
self.has_compiled_vit = False

def _mark_dynamic_once(self, pixel_values, dims):
"""call torch._dynamo.mark_dynamic to avoid recompile."""
if not self.compile_vit or self.has_compiled_vit or pixel_values is None:
return

torch._dynamo.mark_dynamic(pixel_values, dims)
self.has_compiled_vit = True

def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size()
# N, W, H, C --> N, W, H * scale, C // scale
@@ -350,6 +387,7 @@ def forward(
):
if inputs_embeds is None and pixel_values is not None:
# extract feature
self._mark_dynamic_once(pixel_values, [0])
vit_embeds = self.extract_feature(pixel_values)
lang_embeds = self.language_model.get_input_embeddings()(input_ids)
lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds)
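To illustrate what `compile_model` and `_mark_dynamic_once` buy, here is a hedged sketch with a toy `extract_feature`, not the real ViT path. `torch.compile` specializes on input shapes by default, so marking the tile/batch dimension of `pixel_values` as dynamic before the first compiled call helps avoid recompiling whenever a request carries a different number of image tiles.

import torch


@torch.compile
def extract_feature(pixel_values: torch.Tensor) -> torch.Tensor:
    # Toy stand-in for the vision tower + pixel shuffle + MLP projector.
    return pixel_values.flatten(1).mean(dim=-1, keepdim=True)


pixel_values = torch.randn(4, 3, 448, 448)      # 4 image tiles
torch._dynamo.mark_dynamic(pixel_values, [0])   # tile count (dim 0) varies across requests
extract_feature(pixel_values)                   # compiles once, with a symbolic batch dimension
extract_feature(torch.randn(7, 3, 448, 448))    # different tile count, typically no retrace

The `has_compiled_vit` flag applies the hint only to the first batch, which is the point where the dynamic-dimension decision matters; `compile_model` itself is skipped below torch 2.5, and on torch >= 2.6 with more than one GPU it additionally enables compute/communication overlap reordering and micro-batch splitting of the vision encoder.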
61 changes: 61 additions & 0 deletions lmdeploy/pytorch/models/utils/micro_batch.py
@@ -0,0 +1,61 @@
# Copyright (c) OpenMMLab. All rights reserved.
import functools

import torch


def enable_micro_batch(param_name, index=-1):
"""Decorator factory to enable micro-batch computation."""

def decorator(func):

@functools.wraps(func)
def wrapper(self, *args, **kwargs):
if index != -1 and len(args) > index:
inputs = args[index]
else:
inputs = kwargs.get(param_name, None)

if isinstance(inputs, list):
# Apply forward computation to each micro-batch
results = []
for input in inputs:
if index != -1 and len(args) > index:
args = args[0:index] + (input, ) + args[index + 1:]
else:
kwargs[param_name] = input
result = func(self, *args, **kwargs)
results.append(result)
return results
else:
# If not a list, directly apply the forward computation
return func(self, *args, **kwargs)

return wrapper

return decorator


def split_batch(func, param_name, index=-1, num_splits=2):
"""Decorator to split along the 0th dimension into a specified number of
chunks."""

def wrapper(*args, **kwargs):
if index != -1 and len(args) > index:
inputs = args[index]
else:
inputs = kwargs.get(param_name, None)

if inputs is not None:
split_inputs = list(torch.chunk(inputs, num_splits, dim=0))
if index != -1 and len(args) > index:
args = args[0:index] + (split_inputs, ) + args[index + 1:]
else:
kwargs[param_name] = split_inputs

results = func(*args, **kwargs)
return torch.cat(results, dim=0)
else:
return func(*args, **kwargs)

return wrapper
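A usage sketch of the two helpers, assuming this PR's module path; `ToyLayer` and `ToyEncoder` are made-up modules for illustration. `split_batch` chunks the chosen argument along dim 0 into `num_splits` micro-batches (2 by default), each method decorated with `enable_micro_batch` then runs once per chunk and returns a list of results, and `split_batch` concatenates those results back into a single tensor.

import torch
from torch import nn

from lmdeploy.pytorch.models.utils.micro_batch import enable_micro_batch, split_batch


class ToyLayer(nn.Module):

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)

    @enable_micro_batch(param_name='hidden_states', index=0)
    def _mlp(self, hidden_states):
        return hidden_states + self.proj(hidden_states)

    def forward(self, hidden_states):
        # `hidden_states` may be a tensor or a list of micro-batches;
        # the decorator handles both cases.
        return self._mlp(hidden_states)


class ToyEncoder(nn.Module):

    def __init__(self):
        super().__init__()
        self.layer = ToyLayer()

    def forward(self, inputs_embeds):
        return self.layer(inputs_embeds)


encoder = ToyEncoder()
# Split `inputs_embeds` into 2 chunks along dim 0, mirroring how the PR wraps
# the vision encoder's forward.
encoder.forward = split_batch(encoder.forward, 'inputs_embeds', index=0)

out = encoder(torch.randn(8, 16))  # runs as two micro-batches of 4
assert out.shape == (8, 16)

In the PR this wrapping is applied to `self.vision_model.encoder.forward` on `inputs_embeds`, and only when running with torch >= 2.6 and more than one GPU (the same branch that turns on `reorder_for_compute_comm_overlap`).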
6 changes: 3 additions & 3 deletions requirements/runtime_cuda.txt
@@ -16,8 +16,8 @@ safetensors
sentencepiece
shortuuid
tiktoken
-torch<=2.5.1,>=2.0.0
-torchvision<=0.20.1,>=0.15.0
+torch<=2.6.0,>=2.0.0
+torchvision<=0.21.0,>=0.15.0
transformers
-triton<=3.1.0,>=3.0.0; sys_platform == "linux"
+triton<=3.2.0,>=3.0.0; sys_platform == "linux"
uvicorn