Support loading Quark quantized models in Transformers #36372
@@ -0,0 +1,84 @@
<!--Copyright 2025 Advanced Micro Devices, Inc. and The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Quark

[Quark](https://quark.docs.amd.com/latest/) is a deep learning quantization toolkit designed to be agnostic to specific data types, algorithms, and hardware. Different pre-processing strategies, algorithms and data types can be combined in Quark.

The PyTorch support integrated through 🤗 Transformers primarily targets AMD CPUs and GPUs and is mainly meant for evaluation purposes. For example, [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) can be used with the 🤗 Transformers backend to seamlessly evaluate a wide range of models quantized through Quark.

Users interested in Quark can refer to its [documentation](https://quark.docs.amd.com/latest/) to get started quantizing models and using them in supported open-source libraries!

Although Quark has its own checkpoint / [configuration format](https://huggingface.co/amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test/blob/main/config.json#L26), the library also supports producing models with a serialization layout compliant with other quantization/runtime implementations ([AutoAWQ](https://huggingface.co/docs/transformers/quantization/awq), [native fp8 in 🤗 Transformers](https://huggingface.co/docs/transformers/quantization/finegrained_fp8)).

To load Quark quantized models in Transformers, the `amd-quark` library first needs to be installed:

```bash
pip install amd-quark
```

## Support matrix

Models quantized through Quark support a large range of features that can be combined together. All quantized models, independently of their configuration, can seamlessly be reloaded through `PreTrainedModel.from_pretrained`.

The table below shows a few features supported by Quark:

| **Feature** | **Supported subset in Quark** |
|---------------------------------|-----------------------------------------------------------------------------------------------------------|
| Data types | int8, int4, int2, bfloat16, float16, fp8_e5m2, fp8_e4m3, fp6_e3m2, fp6_e2m3, fp4, OCP MX, MX6, MX9, bfp16 |
| Pre-quantization transformation | SmoothQuant, QuaRot, SpinQuant, AWQ |
| Quantization algorithm | GPTQ |
| Supported operators | ``nn.Linear``, ``nn.Conv2d``, ``nn.ConvTranspose2d``, ``nn.Embedding``, ``nn.EmbeddingBag`` |
| Granularity | per-tensor, per-channel, per-block, per-layer, per-layer type |
| KV cache | fp8 |
| Activation calibration | MinMax / Percentile / MSE |
| Quantization strategy | weight-only, static, dynamic, with or without output quantization |

## Models on Hugging Face Hub

Public models using Quark native serialization can be found at https://huggingface.co/models?other=quark.

Although Quark also supports [models using `quant_method="fp8"`](https://huggingface.co/models?other=fp8) and [models using `quant_method="awq"`](https://huggingface.co/models?other=awq), Transformers loads these models through [AutoAWQ](https://huggingface.co/docs/transformers/quantization/awq) or through its [native fp8 support](https://huggingface.co/docs/transformers/quantization/finegrained_fp8) instead.
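
One hypothetical way to check which quantization backend Transformers will pick for a given checkpoint is to inspect the `quant_method` declared in its configuration; the snippet below (reusing the model id from the example further down) is only an illustration:

```python
from transformers import AutoConfig

# The quantization_config of a Hub checkpoint is exposed as a plain dict on the config.
config = AutoConfig.from_pretrained("EmbeddedLLM/Llama-3.1-8B-Instruct-w_fp8_per_channel_sym")
print(config.quantization_config["quant_method"])
```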

## Using Quark models in Transformers

Here is an example of how one can load a Quark model in Transformers:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "EmbeddedLLM/Llama-3.1-8B-Instruct-w_fp8_per_channel_sym"
model = AutoModelForCausalLM.from_pretrained(model_id)
model = model.to("cuda")

print(model.model.layers[0].self_attn.q_proj)
# QParamsLinear(
#   (weight_quantizer): ScaledRealQuantizer()
#   (input_quantizer): ScaledRealQuantizer()
#   (output_quantizer): ScaledRealQuantizer()
# )

tokenizer = AutoTokenizer.from_pretrained(model_id)
inp = tokenizer("Where is a good place to cycle around Tokyo?", return_tensors="pt")
inp = inp.to("cuda")

res = model.generate(**inp, min_new_tokens=50, max_new_tokens=100)

print(tokenizer.batch_decode(res)[0])
# <|begin_of_text|>Where is a good place to cycle around Tokyo? There are several places in Tokyo that are suitable for cycling, depending on your skill level and interests. Here are a few suggestions:
# 1. Yoyogi Park: This park is a popular spot for cycling and has a wide, flat path that's perfect for beginners. You can also visit the Meiji Shrine, a famous Shinto shrine located in the park.
# 2. Imperial Palace East Garden: This beautiful garden has a large, flat path that's perfect for cycling. You can also visit the
```

@@ -536,6 +536,10 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
    str_to_torch_dtype["U32"] = torch.uint32
    str_to_torch_dtype["U64"] = torch.uint64

if is_torch_greater_or_equal("2.1.0"):
    str_to_torch_dtype["F8_E4M3"] = torch.float8_e4m3fn
    str_to_torch_dtype["F8_E5M2"] = torch.float8_e5m2

Comment on lines +539 to +541

@SunMarc @MekkCyber I also had to add this in fda836f following recent changes to modeling_utils.py, in order for the example in the documentation to work. This corresponds to https://github.com/huggingface/safetensors/blob/53fe06c3efd40ff62520f74818819590b2bc25de/bindings/python/py_src/safetensors/torch.py#L385-L386

Doesn't rocm only support …

yes, only … However, we are able to load models quantized in …
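
As a small illustration of the dtypes involved (assuming torch>=2.1; not part of this PR), float8 tensors can be created and stored independently of whether the hardware provides fp8 compute kernels, which is presumably what allows such checkpoints to be loaded even where fp8 matmuls are unavailable:

```python
import torch

# float8 storage dtypes exist starting with torch 2.1, so fp8 checkpoints can be
# deserialized even on devices without fp8 compute kernels.
t = torch.zeros(2, 2, dtype=torch.float8_e4m3fn)
print(t.dtype)              # torch.float8_e4m3fn
print(t.to(torch.float16))  # upcast for inspection
```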


def load_state_dict(
    checkpoint_file: Union[str, os.PathLike],

@@ -3672,6 +3676,10 @@ def to(self, *args, **kwargs):

        if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
            raise ValueError("`.to` is not supported for HQQ-quantized models.")

        if dtype_present_in_args and getattr(self, "quantization_method", None) == QuantizationMethod.QUARK:
            raise ValueError("Casting a Quark quantized model to a new `dtype` is not supported.")

        # Checks if the model has been loaded in 4-bit or 8-bit with BNB
        if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
            if dtype_present_in_args:
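
With this change, moving a Quark-quantized model across devices keeps working while dtype casts are rejected. A hypothetical illustration (reusing `model` from the documentation example above):

```python
import torch

model = model.to("cuda")  # fine: only the device changes

try:
    model = model.to(torch.float16)  # dtype cast on a Quark-quantized model
except ValueError as err:
    print(err)  # Casting a Quark quantized model to a new `dtype` is not supported.
```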

@@ -0,0 +1,113 @@
# coding=utf-8
# Copyright 2025 Advanced Micro Devices, Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Any, Dict

from ..file_utils import is_torch_available
from .base import HfQuantizer


if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel

if is_torch_available():
    import torch

from ..utils import is_accelerate_available, is_quark_available, logging


if is_accelerate_available():
    from accelerate.utils import set_module_tensor_to_device

logger = logging.get_logger(__name__)


CHECKPOINT_KEYS = {
    "weight_scale": "weight_quantizer.scale",
    "bias_scale": "bias_quantizer.scale",
    "input_scale": "input_quantizer.scale",
    "output_scale": "output_quantizer.scale",
    "weight_zero_point": "weight_quantizer.zero_point",
    "bias_zero_point": "bias_quantizer.zero_point",
    "input_zero_point": "input_quantizer.zero_point",
    "output_zero_point": "output_quantizer.zero_point",
}
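# For example (illustrative only), a checkpoint entry named
# "model.layers.0.self_attn.q_proj.weight_scale" would be remapped by
# `create_quantized_param` below to the module parameter
# "model.layers.0.self_attn.q_proj.weight_quantizer.scale" before being set.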


class QuarkHfQuantizer(HfQuantizer):
    """
    Quark quantizer (https://quark.docs.amd.com/latest/).
    """

    requires_calibration = True  # On-the-fly quantization with quark is not supported for now.
    required_packages = ["quark"]

    # Checkpoints are expected to be already quantized when loading a quark model. However, as some keys from
    # the checkpoint might mismatch the model parameters keys, we use the `create_quantized_param` method
    # to load the checkpoints, remapping the keys.
    requires_parameters_quantization = True

    def __init__(self, quantization_config, **kwargs):
        super().__init__(quantization_config, **kwargs)

        self.json_export_config = quantization_config.json_export_config

    def validate_environment(self, *args, **kwargs):
        if not is_quark_available():
            raise ImportError(
                "Loading a Quark quantized model requires the `quark` library but it was not found in the environment. Please refer to https://quark.docs.amd.com/latest/install.html."
            )

    def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
        from quark.torch.export.api import _map_to_quark

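        # Swap the float modules for their Quark counterparts (e.g. `QParamsLinear`,
        # as shown in the documentation example above) so that the pre-quantized
        # checkpoint can then be loaded into them.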
        _map_to_quark(
            model,
            self.quantization_config.quant_config,
            pack_method=self.json_export_config.pack_method,
            custom_mode=self.quantization_config.custom_mode,
        )

        return model

    def check_quantized_param(
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
    ) -> bool:
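        # Checkpoints produced with Quark are already quantized, so every parameter is
        # routed through `create_quantized_param`, which only remaps checkpoint keys.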
        return True

    def create_quantized_param(
        self, model, param, param_name, param_device, state_dict, unexpected_keys
    ) -> "torch.nn.Parameter":
        postfix = param_name.split(".")[-1]

        if postfix in CHECKPOINT_KEYS:
            param_name = param_name.replace(postfix, CHECKPOINT_KEYS[postfix])

        set_module_tensor_to_device(model, param_name, param_device, value=param)

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        return model

    def is_serializable(self, safe_serialization=None):
        return False

    @property
    def is_trainable(self):
        return False

Maybe we can do something there so that we are able to run these checkpoints in quark. Will it work OTB if we modify the `config.quantization_config` and pass the new config to the model in `from_pretrained`? Or we could add a function / context manager that modifies `AUTO_QUANTIZATION_CONFIG_MAPPING` and `AUTO_QUANTIZER_MAPPING`.
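
A rough, untested sketch of the first suggestion (purely hypothetical; the model id is the one from the documentation example, and whether this works out of the box is exactly the open question):

```python
from transformers import AutoConfig, AutoModelForCausalLM

model_id = "EmbeddedLLM/Llama-3.1-8B-Instruct-w_fp8_per_channel_sym"

# Reload the checkpoint's config, force the Quark quantizer to be selected,
# then hand the modified config back to from_pretrained.
config = AutoConfig.from_pretrained(model_id)
config.quantization_config["quant_method"] = "quark"
model = AutoModelForCausalLM.from_pretrained(model_id, config=config)
```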