quant blocked fp8 #4018

Open

CUHKSZzxy wants to merge 6 commits into InternLM:main from CUHKSZzxy:quant-fp8.

Commits (6):
- `b72cc9d` quant blocked fp8 (CUHKSZzxy)
- `d4e151b` align config formats, optimize (CUHKSZzxy)
- `444cf9a` add docs (CUHKSZzxy)
- `28c3f56` fix moe quant, fix for hf models (CUHKSZzxy)
- `c327c01` fix for internvl flash (CUHKSZzxy)
- `291a450` update docs (CUHKSZzxy)
@@ -0,0 +1,59 @@
# Blocked FP8 Quantization

LMDeploy supports a weight-only blocked FP8 quantization method. This approach quantizes the weights of a model to 8-bit floating-point numbers in a blocked format, which can reduce the model's memory footprint while maintaining good performance on supported hardware.
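Conceptually, each weight matrix is split into fixed-size tiles and one scale is stored per tile alongside the FP8 values. The snippet below is a minimal, illustrative sketch of that idea in plain PyTorch; it is not the routine used by `lmdeploy lite blocked_fp8` (which relies on lmdeploy's internal `quant_blocked_fp8` helper), and the function name, padding behavior, and returned scale layout are assumptions made for illustration only.

```python
import torch


def blocked_fp8_quant_sketch(weight: torch.Tensor,
                             block_size: int = 128,
                             fp8_dtype: torch.dtype = torch.float8_e4m3fn):
    """Illustrative per-block FP8 quantization: one scale per block_size x block_size tile."""
    out_f, in_f = weight.shape
    fp8_max = torch.finfo(fp8_dtype).max
    # Pad both dimensions up to a multiple of the block size.
    pad_r = (block_size - out_f % block_size) % block_size
    pad_c = (block_size - in_f % block_size) % block_size
    w = torch.nn.functional.pad(weight.float(), (0, pad_c, 0, pad_r))
    rows, cols = w.shape
    # View as (row_blocks, block, col_blocks, block) tiles.
    blocks = w.view(rows // block_size, block_size, cols // block_size, block_size)
    # One scale per tile, chosen so the tile's largest magnitude maps to the FP8 max.
    amax = blocks.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12)
    scales = amax / fp8_max
    q = (blocks / scales).clamp(-fp8_max, fp8_max).to(fp8_dtype)
    q_weight = q.reshape(rows, cols)[:out_f, :in_f]
    # Dequantization multiplies each tile of q_weight by its stored scale.
    return q_weight, scales.squeeze(3).squeeze(1)
```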
Before proceeding, please ensure that lmdeploy is installed by following the [installation guide](../get_started/installation.md). A typical installation command is:

```shell
pip install lmdeploy[all]
```

## Quantization

A single command is all that is needed to perform blocked FP8 quantization. The script loads the model, quantizes its linear layers to blocked FP8, and saves the resulting model and configuration to the specified working directory.

The command for this is `lmdeploy lite blocked_fp8`.

Here is an example of how to quantize `OpenGVLab/InternVL3_5-8B`:

```shell
export HF_MODEL=OpenGVLab/InternVL3_5-8B
export WORK_DIR=OpenGVLab/InternVL3_5-8B-FP8

lmdeploy lite blocked_fp8 $HF_MODEL \
    --work-dir $WORK_DIR \
    --quant-dtype fp8 \
    --block-size 128
```

Key arguments for the command:

- `--work-dir`: The directory where the quantized model weights and configuration will be saved (an example of the recorded quantization config is shown after this list).
- `--quant-dtype`: The target FP8 format. Can be `float8_e4m3fn` (equivalent to passing `fp8`; recommended) or `float8_e5m2`.
- `--block-size`: The block size for quantization. The default of `128` is generally a good choice.
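For reference, the quantization script in this PR records these settings in the saved model's `config.json` under a `quantization_config` entry. Based on that script, the recorded fields look roughly like the following; the exact `modules_to_not_convert` list depends on the model architecture, and the entry shown here is a placeholder:

```python
# Approximate shape of the quantization_config recorded by the blocked_fp8 script,
# shown for --quant-dtype fp8 (float8_e4m3fn) and --block-size 128.
quantization_config = {
    'quant_method': 'fp8',
    'fmt': 'e4m3',
    'activation_scheme': 'dynamic',
    'weight_block_size': [128, 128],
    # Filled per model with the layers left unquantized, e.g. lm_head or vision modules.
    'modules_to_not_convert': ['lm_head'],
}
```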
## Inference

You can perform batched offline inference with the quantized model using either the `turbomind` or the `pytorch` backend.

Here is a simple code example:

```python
from lmdeploy import pipeline

pipe = pipeline("OpenGVLab/InternVL3_5-8B-FP8")
response = pipe(["Hi, pls intro yourself", "Shanghai is"])
print(response)
```
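If you want to pin the backend explicitly, an engine config can be passed to the pipeline. A minimal sketch, assuming the quantized model was saved to the directory used above; swap in `TurbomindEngineConfig` to select the turbomind engine instead:

```python
from lmdeploy import pipeline, PytorchEngineConfig

# Select the PyTorch engine explicitly; the path is the --work-dir used during quantization.
pipe = pipeline("OpenGVLab/InternVL3_5-8B-FP8",
                backend_config=PytorchEngineConfig())
print(pipe(["Hi, pls intro yourself"]))
```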
## Service

LMDeploy's `api_server` can be used to serve the blocked FP8 model:

```shell
lmdeploy serve api_server OpenGVLab/InternVL3_5-8B-FP8
```

The default port for the `api_server` is `23333`.

You can view the available API endpoints through the Swagger UI at `http://0.0.0.0:23333`. For more details on the API, please refer to the [API Server documentation](../llm/api_server.md).
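Because the server exposes an OpenAI-compatible API, it can also be queried with the `openai` Python client. A minimal sketch, assuming the server started above is running locally on the default port; the `api_key` value is a placeholder since no key is configured here:

```python
from openai import OpenAI

# Point the client at the local api_server started above.
client = OpenAI(base_url="http://0.0.0.0:23333/v1", api_key="none")
model_name = client.models.list().data[0].id  # the model being served
resp = client.chat.completions.create(
    model=model_name,
    messages=[{"role": "user", "content": "Hi, pls intro yourself"}],
)
print(resp.choices[0].message.content)
```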
@@ -0,0 +1,59 @@
# Blocked FP8 Model Quantization

LMDeploy supports a weight-only blocked FP8 quantization method. It quantizes the model's weights to 8-bit floating-point numbers in a blocked format, which effectively reduces GPU memory usage while maintaining good performance on supported hardware.

Before quantization and inference, please make sure lmdeploy is installed by following the [installation guide](../get_started/installation.md).

```shell
pip install lmdeploy[all]
```

## Model Quantization

A single command completes the quantization. The script loads the model, quantizes its linear layers to the blocked FP8 format, and saves the resulting model and configuration files to the specified working directory.

The command to use is `lmdeploy lite blocked_fp8`.

Here is an example of how to quantize `OpenGVLab/InternVL3_5-8B`:

```shell
export HF_MODEL=OpenGVLab/InternVL3_5-8B
export WORK_DIR=OpenGVLab/InternVL3_5-8B-FP8

lmdeploy lite blocked_fp8 $HF_MODEL \
    --work-dir $WORK_DIR \
    --quant-dtype fp8 \
    --block-size 128
```

Key command-line arguments:

- `--work-dir`: The working directory where the quantized model weights and configuration are saved.
- `--quant-dtype`: The target FP8 format. Can be `float8_e4m3fn` (equivalent to passing `fp8`; recommended) or `float8_e5m2`.
- `--block-size`: The block size used for quantization. The default of `128` is usually a good choice.

## Model Inference

You can run batched offline inference on the quantized model with either the `turbomind` or the `pytorch` backend.

Here is a simple code example:

```python
from lmdeploy import pipeline

pipe = pipeline("OpenGVLab/InternVL3_5-8B-FP8")
response = pipe(["Hi, pls intro yourself", "Shanghai is"])
print(response)
```

## Inference Service

LMDeploy's `api_server` can be used to deploy the blocked FP8 model as a service:

```shell
lmdeploy serve api_server OpenGVLab/InternVL3_5-8B-FP8
```

The default service port is `23333`.

You can browse and try out the `api_server` endpoints through the Swagger UI at `http://0.0.0.0:23333`, or refer to the [documentation](../llm/api_server.md) for the definition and usage of each endpoint.
@@ -0,0 +1,105 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
from typing import Literal

import fire
import torch
from torch import nn
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from lmdeploy.lite.quantization.weight.quant_utils import quant_blocked_fp8
from lmdeploy.lite.utils import collect_target_modules
from lmdeploy.pytorch.models import QLinear


def blocked_fp8(model: str,
                work_dir: str = './work_dir',
                quant_dtype: Literal['fp8', 'float8_e4m3fn', 'float8_e5m2'] = 'float8_e4m3fn',
                block_size: int = 128,
                revision: str = None,
                download_dir: str = None):
    if quant_dtype == 'fp8':
        quant_dtype = 'float8_e4m3fn'

    q_dtype = getattr(torch, quant_dtype, None)
    assert q_dtype is not None

    if not osp.exists(model):
        print(f'can\'t find model from local_path {model}, '
              'try to download from remote')
        from lmdeploy.utils import get_model
        model_path = get_model(model, revision=revision, download_dir=download_dir)
    else:
        model_path = model

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, dtype=torch.bfloat16)
    model = model.eval().cuda()

    # collect all linear layers
    fcs = collect_target_modules(model, nn.Linear)
    skip_patterns = [
        'lm_head',
        'embed_tokens',
        'mlp.gate',  # sparse MoE router gate
        'vision_model',  # non-HF InternVL, vision part
        'mlp1',  # non-HF InternVL, projector
        'mlp2',  # non-HF InternVL-Flash, projector
        'vision_tower',  # HF InternVL, vision part
        'multi_modal_projector',  # HF InternVL, projector
    ]
    modules_to_not_convert = []

    # quantize and replace linear layers
    for name, linear in tqdm(fcs.items(), desc='Quantizing'):
        # skip modules that should not be converted
        if any(x in name for x in skip_patterns):
            modules_to_not_convert.append(name)
            continue

        linear.to('cuda')
        # quantize weight
        q_weight, scales = quant_blocked_fp8(weight=linear.weight, fp8_dtype=q_dtype, block_size=block_size)

        # create and replace with QLinear
        q_linear = QLinear.from_float(linear, quant_dtype=q_dtype, initialization=False)
        q_linear.weight.data = q_weight
        q_linear.weight_scale_inv.data = scales
        if linear.bias is not None:
            q_linear.bias.data = linear.bias.detach()
        parent_name, _, child_name = name.rpartition('.')
        parent = model.get_submodule(parent_name)
        setattr(parent, child_name, q_linear)

        # move the original layer to CPU to free GPU memory
        linear.to('cpu')
        torch.cuda.empty_cache()

    model.to('cpu')

    # update model config
    if quant_dtype == 'float8_e4m3fn':
        fmt = 'e4m3'
    elif quant_dtype == 'float8_e5m2':
        fmt = 'e5m2'
    quant_config = dict(activation_scheme='dynamic',
                        modules_to_not_convert=modules_to_not_convert,
                        fmt=fmt,
                        quant_method='fp8',
                        weight_block_size=[block_size, block_size])
    model.config.update(dict(quantization_config=quant_config))

    # save model and tokenizer
    if not osp.exists(work_dir):
        os.makedirs(work_dir)
    print('Saving the quantized model ...')
    model.save_pretrained(work_dir, safe_serialization=True)
    tokenizer.save_pretrained(work_dir)
    print(f'Blocked FP8 model successfully saved to {work_dir}')


if __name__ == '__main__':
    fire.Fire(blocked_fp8)
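After quantization, it can be useful to sanity-check the exported checkpoint. The rough sketch below is not part of this PR; it scans the saved safetensors shards for FP8 weights and verifies that each has a matching `weight_scale_inv` tensor, assuming the `safetensors` package is available and that a weight and its scale land in the same shard:

```python
import glob

import torch
from safetensors import safe_open

work_dir = 'OpenGVLab/InternVL3_5-8B-FP8'  # the --work-dir used during quantization
for shard in sorted(glob.glob(f'{work_dir}/*.safetensors')):
    with safe_open(shard, framework='pt', device='cpu') as f:
        keys = list(f.keys())
        for key in keys:
            tensor = f.get_tensor(key)
            if tensor.dtype == torch.float8_e4m3fn and key.endswith('.weight'):
                # Every blocked-FP8 weight should come with a per-block scale tensor.
                scale_key = key[:-len('.weight')] + '.weight_scale_inv'
                assert scale_key in keys, f'missing scale for {key}'
                print(key, tuple(tensor.shape), '-> scales', tuple(f.get_tensor(scale_key).shape))
```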
The PR also renames `scale` to `weight_scale_inv` in the existing `QTensor` and `QLinear` definitions:
@@ -16,7 +16,7 @@ class QTensor:
     This class wraps around a regular Pytorch tensor and adds quantization-specific parameters.
     """
     tensor: torch.Tensor
-    scale: torch.Tensor
+    weight_scale_inv: torch.Tensor
     zero_point: torch.Tensor = None

     def __post_init__(self):

Review comments on this change:
- Changing …
- @RunningLeon @grimoire any good ideas?

@@ -58,7 +58,7 @@ def forward(self, hidden_states):
         """Defines the computation performed at every call.

         Performs RMS normalization followed by dynamic quantization on hidden_states. Returns a QTensor which wraps the
-        quantized tensor along with its scale factor.
+        quantized tensor along with its weight_scale_inv factor.
         """
         hidden_states_quant, rms_scale = rms_norm_dynamic_quant(hidden_states,
                                                                 self.weight,

@@ -91,7 +91,7 @@ def __init__(self,
         self.out_features = out_features
         self.quant_dtype = quant_dtype
         self.register_buffer('weight', torch.empty((out_features, in_features), device=device, dtype=quant_dtype))
-        self.register_buffer('scale', torch.empty((out_features, 1), device=device, dtype=torch.float32))
+        self.register_buffer('weight_scale_inv', torch.empty((out_features, 1), device=device, dtype=torch.float32))
         if bias:
             self.register_buffer('bias', torch.empty(out_features, **factory_kwargs))
         else:

@@ -112,9 +112,9 @@ def from_float(cls, mod: nn.Module, initialization: bool = True, quant_dtype=tor
                            quant_dtype=quant_dtype)

         if initialization:
-            weight_quant, scale = per_channel_quant(mod.weight.detach(), quant_dtype)
+            weight_quant, weight_scale_inv = per_channel_quant(mod.weight.detach(), quant_dtype)
             q_mod.weight.data = weight_quant
-            q_mod.scale = scale
+            q_mod.weight_scale_inv = weight_scale_inv

             if mod.bias is not None:
                 q_mod.bias.data = mod.bias.detach()

@@ -132,12 +132,12 @@ def forward(self, input):
             input_quant, input_scale = per_token_quant_int8(input, 1e-7, quant_dtype=self.quant_dtype)
         else:
             assert isinstance(input, QTensor)
-            input_quant, input_scale = input.tensor, input.scale
+            input_quant, input_scale = input.tensor, input.weight_scale_inv

         out = matmul_kernel_dynamic_quant(input_quant,
                                           self.weight,
                                           input_scale,
-                                          self.scale,
+                                          self.weight_scale_inv,
                                           output_dtype=torch.float16,
                                           bias=self.bias)
         return out
Review discussion on the hard-coded skip patterns:

- These configurations are model-specific. We should adopt a more maintainable approach.
- I checked the vLLM FP8 compressor example and noticed that the ignored patterns are indeed model-specific. Currently, these patterns are passed as an input argument named `ignore` in the quantization recipe:
  https://github.com/vllm-project/llm-compressor/tree/main/examples/quantization_w8a8_fp8
  https://github.com/vllm-project/llm-compressor/blob/main/examples/quantization_w8a8_fp8/qwen2vl_example.py#L20
  How about we also expose this as a configurable input argument, allowing users to define their own ignore patterns as needed?
- @RunningLeon As discussed with @CUHKSZzxy, we propose a new `--skip-pattern config.py` option for custom skip patterns, alongside lmdeploy's internal defaults. What's your opinion?
- Personally, if only passing skip patterns, a config file is not necessary.
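To make the proposal concrete: it is not part of this PR, but one possible shape of the discussed option, passing skip patterns directly as a list argument and merging them with lmdeploy's built-in defaults rather than requiring a separate config file, is sketched below. All names here are hypothetical.

```python
# Hypothetical sketch only: illustrates the --skip-pattern idea under discussion,
# not an implemented lmdeploy interface.
from typing import List, Optional

DEFAULT_SKIP_PATTERNS = [
    'lm_head', 'embed_tokens', 'mlp.gate',
    'vision_model', 'mlp1', 'mlp2',
    'vision_tower', 'multi_modal_projector',
]


def resolve_skip_patterns(user_patterns: Optional[List[str]] = None) -> List[str]:
    """Merge user-supplied substring patterns with the built-in defaults."""
    patterns = list(DEFAULT_SKIP_PATTERNS)
    for pattern in user_patterns or []:
        if pattern not in patterns:
            patterns.append(pattern)
    return patterns


# A CLI built on this helper might then be invoked as, e.g.:
#   lmdeploy lite blocked_fp8 $HF_MODEL --skip-patterns visual router
```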