1 change: 1 addition & 0 deletions docs/en/index.rst
@@ -80,6 +80,7 @@ Documentation
quantization/w4a16.md
quantization/w8a8.md
quantization/kv_quant.md
quantization/blocked_fp8.md

.. _benchmark:
.. toctree::
59 changes: 59 additions & 0 deletions docs/en/quantization/blocked_fp8.md
@@ -0,0 +1,59 @@
# Blocked FP8 Quantization

LMDeploy supports a weight-only blocked FP8 quantization method. This approach quantizes the weights of a model to 8-bit floating-point numbers in a blocked format, which can reduce the model's memory footprint while maintaining good performance on supported hardware.

Before proceeding, please ensure that lmdeploy is installed by following the [installation guide](../get_started/installation.md). A typical installation command is:

```shell
pip install lmdeploy[all]
```

## Quantization

A single command is all that is needed to perform blocked FP8 quantization. The script will load the model, quantize the linear layers to blocked FP8, and save the resulting model and configuration to the specified working directory.

The command for this is `lmdeploy lite blocked_fp8`.

Here is an example of how to quantize `OpenGVLab/InternVL3_5-8B`:

```shell
export HF_MODEL=OpenGVLab/InternVL3_5-8B
export WORK_DIR=OpenGVLab/InternVL3_5-8B-FP8

lmdeploy lite blocked_fp8 $HF_MODEL \
--work-dir $WORK_DIR \
--quant-dtype fp8 \
--block-size 128
```
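
The CLI subcommand is a thin wrapper around the `blocked_fp8` function added in `lmdeploy/lite/apis/blocked_fp8.py`, so the same step can also be scripted from Python. A minimal sketch mirroring the CLI flags above:

```python
from lmdeploy.lite.apis.blocked_fp8 import blocked_fp8

# Equivalent to the CLI invocation above; the arguments mirror the command-line flags.
blocked_fp8('OpenGVLab/InternVL3_5-8B',
            work_dir='OpenGVLab/InternVL3_5-8B-FP8',
            quant_dtype='float8_e4m3fn',
            block_size=128)
```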

Key arguments for the command:

- `--work-dir`: The directory where the quantized model weights and updated configuration are saved.
- `--quant-dtype`: The target FP8 format, either `float8_e4m3fn` (equivalent to passing `fp8`; recommended) or `float8_e5m2`.
- `--block-size`: The block size for quantization. The default of `128` is generally a good choice; the sketch below illustrates what this parameter controls.
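
For intuition about what `--block-size` controls, here is a minimal, illustrative sketch of per-block FP8 quantization. It is not LMDeploy's `quant_blocked_fp8` implementation; it assumes a PyTorch build with `torch.float8_e4m3fn` support and a 2-D weight tensor:

```python
import torch
import torch.nn.functional as F


def blocked_fp8_sketch(weight: torch.Tensor, block_size: int = 128):
    """Toy per-block quantization: one scale per (block_size x block_size) tile."""
    out_f, in_f = weight.shape
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # Pad so both dimensions divide evenly into blocks.
    w = F.pad(weight.float(), (0, (-in_f) % block_size, 0, (-out_f) % block_size))
    tiles = w.view(w.shape[0] // block_size, block_size, w.shape[1] // block_size, block_size)
    # One scale per tile: map the tile's largest magnitude onto the FP8 range.
    amax = tiles.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12)
    scales = amax / fp8_max
    q = (tiles / scales).to(torch.float8_e4m3fn)
    # Return the quantized weight (cropped back to the original shape) and the per-tile scales.
    return q.view_as(w)[:out_f, :in_f], scales[:, 0, :, 0]
```

In the actual pipeline, the per-block scales produced by `quant_blocked_fp8` are stored on the replacement `QLinear` modules as `weight_scale_inv` (see `lmdeploy/lite/apis/blocked_fp8.py` in this PR).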

## Inference

You can perform batched offline inference with the quantized model using both the `turbomind` and `pytorch` backends.

Here is a simple code example:

```python
from lmdeploy import pipeline

pipe = pipeline("OpenGVLab/InternVL3_5-8B-FP8")
response = pipe(["Hi, pls intro yourself", "Shanghai is"])
print(response)
```
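
The backend is chosen automatically, but you can pin it explicitly by passing a backend config. A minimal sketch using the PyTorch engine (`tp=1` is only an example value):

```python
from lmdeploy import PytorchEngineConfig, pipeline

# Explicitly select the PyTorch engine for the blocked FP8 checkpoint.
pipe = pipeline('OpenGVLab/InternVL3_5-8B-FP8',
                backend_config=PytorchEngineConfig(tp=1))
print(pipe(['Hi, pls intro yourself']))
```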

## Service

LMDeploy's `api_server` can be used to serve the blocked FP8 model.

```shell
lmdeploy serve api_server OpenGVLab/InternVL3_5-8B-FP8
```

The default port for the `api_server` is `23333`.

You can view the available API endpoints through the Swagger UI at `http://0.0.0.0:23333`. For more details on the API, please refer to the [API Server documentation](../llm/api_server.md).
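
Once the server is running, any OpenAI-compatible client can talk to it. A minimal sketch, assuming the server above is running locally and no API key was configured:

```python
from openai import OpenAI

client = OpenAI(base_url='http://0.0.0.0:23333/v1', api_key='none')
# Query the served model name, then send a chat request.
model_name = client.models.list().data[0].id
resp = client.chat.completions.create(
    model=model_name,
    messages=[{'role': 'user', 'content': 'Hi, pls intro yourself'}],
)
print(resp.choices[0].message.content)
```
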
1 change: 1 addition & 0 deletions docs/zh_cn/index.rst
@@ -81,6 +81,7 @@ LMDeploy 工具箱提供以下核心功能:
quantization/w4a16.md
quantization/w8a8.md
quantization/kv_quant.md
quantization/blocked_fp8.md

.. _测试基准:
.. toctree::
59 changes: 59 additions & 0 deletions docs/zh_cn/quantization/blocked_fp8.md
@@ -0,0 +1,59 @@
# Blocked FP8 Quantization

LMDeploy supports a weight-only blocked FP8 quantization method. It quantizes the model's weights to 8-bit floating-point numbers in a blocked layout, which reduces the model's memory footprint while maintaining good performance on supported hardware.

Before quantization and inference, please make sure lmdeploy is installed by following the [installation guide](../get_started/installation.md):

```shell
pip install lmdeploy[all]
```

## Quantization

A single command completes the quantization. The script loads the model, quantizes its linear layers to blocked FP8, and saves the resulting model and configuration to the specified working directory.

The command is `lmdeploy lite blocked_fp8`.

Here is an example of quantizing `OpenGVLab/InternVL3_5-8B`:

```shell
export HF_MODEL=OpenGVLab/InternVL3_5-8B
export WORK_DIR=OpenGVLab/InternVL3_5-8B-FP8

lmdeploy lite blocked_fp8 $HF_MODEL \
--work-dir $WORK_DIR \
--quant-dtype fp8 \
--block-size 128
```

Key command-line arguments:

- `--work-dir`: The working directory where the quantized weights and configuration are saved.
- `--quant-dtype`: The target FP8 format, either `float8_e4m3fn` (equivalent to passing `fp8`; recommended) or `float8_e5m2`.
- `--block-size`: The block size used for quantization. The default of `128` is generally a good choice.

## Inference

You can run batched offline inference with the quantized model using both the `turbomind` and `pytorch` backends.

Here is a simple code example:

```python
from lmdeploy import pipeline

pipe = pipeline("OpenGVLab/InternVL3_5-8B-FP8")
response = pipe(["Hi, pls intro yourself", "Shanghai is"])
print(response)
```

## Service

LMDeploy's `api_server` can be used to serve the blocked FP8 model.

```shell
lmdeploy serve api_server OpenGVLab/InternVL3_5-8B-FP8
```

The default port of the `api_server` is `23333`.

You can browse and try out the `api_server` endpoints through the Swagger UI at `http://0.0.0.0:23333`, or refer to the [API Server documentation](../llm/api_server.md) for the definition and usage of each endpoint.
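
To verify the service from the command line, here is a minimal request sketch against the OpenAI-compatible chat endpoint; the `model` field is assumed to match the name returned by `GET /v1/models`:

```shell
curl http://0.0.0.0:23333/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
        "model": "OpenGVLab/InternVL3_5-8B-FP8",
        "messages": [{"role": "user", "content": "Hi, pls intro yourself"}]
      }'
```
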
30 changes: 30 additions & 0 deletions lmdeploy/cli/lite.py
@@ -103,6 +103,28 @@ def add_parser_smooth_quant():
        ArgumentHelper.revision(parser)
        ArgumentHelper.download_dir(parser)

    @staticmethod
    def add_parser_blocked_fp8():
        """Add parser for blocked_fp8 command."""
        parser = SubCliLite.subparsers.add_parser('blocked_fp8',
                                                  formatter_class=DefaultsAndTypesHelpFormatter,
                                                  description=SubCliLite.blocked_fp8.__doc__,
                                                  help=SubCliLite.blocked_fp8.__doc__)
        parser.set_defaults(run=SubCliLite.blocked_fp8)
        parser.add_argument('model', type=str, help='The name or path of the model to be loaded')
        parser.add_argument('--work-dir',
                            type=str,
                            default='./work_dir',
                            help='The working directory for outputs. defaults to "./work_dir"')
        parser.add_argument('--quant-dtype',
                            type=str,
                            default='float8_e4m3fn',
                            choices=['fp8', 'float8_e4m3fn', 'float8_e5m2'],
                            help='The quantization data type for weight')
        parser.add_argument('--block-size', type=int, default=128, help='Block size for blocked-fp8 quantization')
        ArgumentHelper.revision(parser)
        ArgumentHelper.download_dir(parser)

    @staticmethod
    def auto_awq(args):
        """Perform weight quantization using AWQ algorithm."""
@@ -117,6 +139,13 @@ def auto_gptq(args):
        kwargs = convert_args(args)
        auto_gptq(**kwargs)

    @staticmethod
    def blocked_fp8(args):
        """Perform weight quantization to blocked fp8 format."""
        from lmdeploy.lite.apis.blocked_fp8 import blocked_fp8
        kwargs = convert_args(args)
        blocked_fp8(**kwargs)

    @staticmethod
    def calibrate(args):
        """Perform calibration on a given dataset."""
@@ -138,3 +167,4 @@ def add_parsers():
        SubCliLite.add_parser_auto_gptq()
        SubCliLite.add_parser_calibrate()
        SubCliLite.add_parser_smooth_quant()
        SubCliLite.add_parser_blocked_fp8()
105 changes: 105 additions & 0 deletions lmdeploy/lite/apis/blocked_fp8.py
@@ -0,0 +1,105 @@
# Copyright (c) OpenMMLab. All rights reserved.

import os
import os.path as osp
from typing import Literal

import fire
import torch
from torch import nn
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from lmdeploy.lite.quantization.weight.quant_utils import quant_blocked_fp8
from lmdeploy.lite.utils import collect_target_modules
from lmdeploy.pytorch.models import QLinear


def blocked_fp8(model: str,
                work_dir: str = './work_dir',
                quant_dtype: Literal['fp8', 'float8_e4m3fn', 'float8_e5m2'] = 'float8_e4m3fn',
                block_size: int = 128,
                revision: str = None,
                download_dir: str = None):
    """Quantize the linear layers of a model to blocked FP8 and save the result."""
    if quant_dtype == 'fp8':
        quant_dtype = 'float8_e4m3fn'

    q_dtype = getattr(torch, quant_dtype, None)
    assert q_dtype is not None

    if not osp.exists(model):
        print(f'can\'t find model from local_path {model}, '
              'try to download from remote')
        from lmdeploy.utils import get_model
        model_path = get_model(model, revision=revision, download_dir=download_dir)
    else:
        model_path = model

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, dtype=torch.bfloat16)
    model = model.eval().cuda()

    # collect all linear layers
    fcs = collect_target_modules(model, nn.Linear)
    skip_patterns = [
        'lm_head',
        'embed_tokens',
        'mlp.gate',  # sparse MoE router gate
        'vision_model',  # non-HF InternVL, vision part
        'mlp1',  # non-HF InternVL, projector
        'mlp2',  # non-HF InternVL-Flash, projector
        'vision_tower',  # HF InternVL, vision part
        'multi_modal_projector',  # HF InternVL, projector
    ]
    modules_to_not_convert = []
Comment on lines +44 to +54

Collaborator: These configurations are model-specific. We should adopt a more maintainable approach.

Collaborator (author): I checked the vLLM FP8 compressor example, and noticed that the ignored patterns are indeed model-specific. Currently, these patterns are passed as an input argument named `ignore` in the quantization recipe.

https://github.com/vllm-project/llm-compressor/tree/main/examples/quantization_w8a8_fp8

https://github.com/vllm-project/llm-compressor/blob/main/examples/quantization_w8a8_fp8/qwen2vl_example.py#L20

How about we also expose this as a configurable input argument, allowing users to define their own ignore patterns as needed?

Collaborator: @RunningLeon As discussed with @CUHKSZzxy, we propose a new `--skip-pattern` / config.py option for custom skip patterns, alongside lmdeploy's internal defaults. What's your opinion?

Collaborator: Personally, if only passing skip patterns, a config file is not necessary.

    # quantize and replace linear layers
    for name, linear in tqdm(fcs.items(), desc='Quantizing'):
        # skip modules that should not be converted
        if any([x in name for x in skip_patterns]):
            modules_to_not_convert.append(name)
            continue

        linear.to('cuda')
        # quantize weight
        q_weight, scales = quant_blocked_fp8(weight=linear.weight, fp8_dtype=q_dtype, block_size=block_size)

        # create and replace with QLinear
        q_linear = QLinear.from_float(linear, quant_dtype=q_dtype, initialization=False)
        q_linear.weight.data = q_weight
        q_linear.weight_scale_inv.data = scales
        if linear.bias is not None:
            q_linear.bias.data = linear.bias.detach()
        parent_name, _, child_name = name.rpartition('.')
        parent = model.get_submodule(parent_name)
        setattr(parent, child_name, q_linear)

        # move original layer to CPU to free GPU memory
        linear.to('cpu')
        torch.cuda.empty_cache()

    model.to('cpu')

    # update model config
    if quant_dtype == 'float8_e4m3fn':
        fmt = 'e4m3'
    elif quant_dtype == 'float8_e5m2':
        fmt = 'e5m2'
    quant_config = dict(activation_scheme='dynamic',
                        modules_to_not_convert=modules_to_not_convert,
                        fmt=fmt,
                        quant_method='fp8',
                        weight_block_size=[block_size, block_size])
    model.config.update(dict(quantization_config=quant_config))

    # save model and tokenizer
    if not osp.exists(work_dir):
        os.makedirs(work_dir)
    print('Saving the quantized model ...')
    model.save_pretrained(work_dir, safe_serialization=True)
    tokenizer.save_pretrained(work_dir)
    print(f'Blocked FP8 model successfully saved to {work_dir}')


if __name__ == '__main__':
    fire.Fire(blocked_fp8)
5 changes: 4 additions & 1 deletion lmdeploy/pytorch/configurations/internvl.py
@@ -12,7 +12,10 @@ def condition(cls, hf_config):

    @classmethod
    def build(cls, hf_config, model_path: str = None, **kwargs):
        """Build llava hf."""
        """Build internvl hf."""
        # hack quantization_config
        if hasattr(hf_config, 'quantization_config') and not hasattr(hf_config.llm_config, 'quantization_config'):
            setattr(hf_config.llm_config, 'quantization_config', hf_config.quantization_config)
        cfg = DefaultModelConfigBuilder.build(hf_config.llm_config, model_path, **kwargs)
        cfg.hf_config = hf_config
        return cfg
14 changes: 7 additions & 7 deletions lmdeploy/pytorch/models/q_modules.py
@@ -16,7 +16,7 @@ class QTensor:
    This class wraps around a regular Pytorch tensor and adds quantization-specific parameters.
    """
    tensor: torch.Tensor
    scale: torch.Tensor
    weight_scale_inv: torch.Tensor
Collaborator: Changing `scale` to `weight_scale_inv` might affect w8a8 quantized model inference.

Collaborator: @RunningLeon @grimoire any good ideas?

    zero_point: torch.Tensor = None

    def __post_init__(self):
@@ -58,7 +58,7 @@ def forward(self, hidden_states):
        """Defines the computation performed at every call.

        Performs RMS normalization followed by dynamic quantization on hidden_states. Returns a QTensor which wraps the
        quantized tensor along with its scale factor.
        quantized tensor along with its weight_scale_inv factor.
        """
        hidden_states_quant, rms_scale = rms_norm_dynamic_quant(hidden_states,
                                                                self.weight,
@@ -91,7 +91,7 @@ def __init__(self,
        self.out_features = out_features
        self.quant_dtype = quant_dtype
        self.register_buffer('weight', torch.empty((out_features, in_features), device=device, dtype=quant_dtype))
        self.register_buffer('scale', torch.empty((out_features, 1), device=device, dtype=torch.float32))
        self.register_buffer('weight_scale_inv', torch.empty((out_features, 1), device=device, dtype=torch.float32))
        if bias:
            self.register_buffer('bias', torch.empty(out_features, **factory_kwargs))
        else:
@@ -112,9 +112,9 @@ def from_float(cls, mod: nn.Module, initialization: bool = True, quant_dtype=tor
                    quant_dtype=quant_dtype)

        if initialization:
            weight_quant, scale = per_channel_quant(mod.weight.detach(), quant_dtype)
            weight_quant, weight_scale_inv = per_channel_quant(mod.weight.detach(), quant_dtype)
            q_mod.weight.data = weight_quant
            q_mod.scale = scale
            q_mod.weight_scale_inv = weight_scale_inv

            if mod.bias is not None:
                q_mod.bias.data = mod.bias.detach()
@@ -132,12 +132,12 @@ def forward(self, input):
            input_quant, input_scale = per_token_quant_int8(input, 1e-7, quant_dtype=self.quant_dtype)
        else:
            assert isinstance(input, QTensor)
            input_quant, input_scale = input.tensor, input.scale
            input_quant, input_scale = input.tensor, input.weight_scale_inv

        out = matmul_kernel_dynamic_quant(input_quant,
                                          self.weight,
                                          input_scale,
                                          self.scale,
                                          self.weight_scale_inv,
                                          output_dtype=torch.float16,
                                          bias=self.bias)
Expand Down