Commit 1a37479

fxmarty-amd, SunMarc, BowenBao, and MekkCyber authored
Support loading Quark quantized models in Transformers (#36372)
* add quark quantizer
* add quark doc
* clean up doc
* fix tests
* make style
* more style fixes
* cleanup imports
* cleaning
* precise install
* Update docs/source/en/quantization/quark.md
  Co-authored-by: Marc Sun <[email protected]>
* Update tests/quantization/quark_integration/test_quark.py
  Co-authored-by: Marc Sun <[email protected]>
* Update src/transformers/utils/quantization_config.py
  Co-authored-by: Marc Sun <[email protected]>
* remove import guard as suggested
* update copyright headers
* add quark to transformers-quantization-latest-gpu Dockerfile
* make tests pass on transformers main + quark==0.7
* add missing F8_E4M3 and F8_E5M2 keys from str_to_torch_dtype

---------

Co-authored-by: Marc Sun <[email protected]>
Co-authored-by: Bowen Bao <[email protected]>
Co-authored-by: Mohamed Mekkouri <[email protected]>
1 parent ce091b1 commit 1a37479

15 files changed (+432, -1 lines changed)

docker/transformers-quantization-latest-gpu/Dockerfile

Lines changed: 3 additions & 0 deletions
@@ -79,6 +79,9 @@ RUN git clone https://github.com/NetEase-FuXi/EETQ.git && cd EETQ/ && git submod
 # Add compressed-tensors for quantization testing
 RUN python3 -m pip install --no-cache-dir compressed-tensors
 
+# Add AMD Quark for quantization testing
+RUN python3 -m pip install --no-cache-dir amd-quark
+
 # Add transformers in editable mode
 RUN python3 -m pip install --no-cache-dir -e ./transformers[dev-torch]

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -187,6 +187,8 @@
     title: Optimum
   - local: quantization/quanto
     title: Quanto
+  - local: quantization/quark
+    title: Quark
   - local: quantization/torchao
     title: torchao
   - local: quantization/spqr

docs/source/en/main_classes/quantization.md

Lines changed: 4 additions & 0 deletions
@@ -88,3 +88,7 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
 ## FineGrainedFP8Config
 
 [[autodoc]] FineGrainedFP8Config
+
+## QuarkConfig
+
+[[autodoc]] QuarkConfig

docs/source/en/quantization/overview.md

Lines changed: 2 additions & 1 deletion
@@ -40,6 +40,7 @@ Use the Space below to help you pick a quantization method depending on your har
 | [VPTQ](./vptq) | 🔴 | 🔴 | 🟢 | 🟡 | 🔴 | 🔴 | 🟢 | 1/8 | 🔴 | 🟢 | 🟢 | https://github.com/microsoft/VPTQ |
 | [FINEGRAINED_FP8](./finegrained_fp8) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | |
 | [SpQR](./spqr) | 🔴 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 3 | 🔴 | 🟢 | 🟢 | https://github.com/Vahe1994/SpQR/ |
+| [Quark](./quark.md) | 🔴 | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | ? | 2/4/6/8/9/16 | 🔴 | 🔴 | 🟢 | https://quark.docs.amd.com/latest/ |
 
 ## Resources
 
@@ -55,4 +56,4 @@ If you are looking for a user-friendly quantization experience, you can use the
 * [Bitsandbytes Space](https://huggingface.co/spaces/bnb-community/bnb-my-repo)
 * [GGUF Space](https://huggingface.co/spaces/ggml-org/gguf-my-repo)
 * [MLX Space](https://huggingface.co/spaces/mlx-community/mlx-my-repo)
-* [AuoQuant Notebook](https://colab.research.google.com/drive/1b6nqC7UZVt8bx4MksX7s656GXPM-eWw4?usp=sharing#scrollTo=ZC9Nsr9u5WhN)
+* [AuoQuant Notebook](https://colab.research.google.com/drive/1b6nqC7UZVt8bx4MksX7s656GXPM-eWw4?usp=sharing#scrollTo=ZC9Nsr9u5WhN)

docs/source/en/quantization/quark.md

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
+<!--Copyright 2025 Advanced Micro Devices, Inc. and The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+
+⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
+rendered properly in your Markdown viewer.
+
+-->
+
+# Quark
+
+[Quark](https://quark.docs.amd.com/latest/) is a deep learning quantization toolkit designed to be agnostic to specific data types, algorithms, and hardware. Different pre-processing strategies, algorithms and data-types can be combined in Quark.
+
+The PyTorch support integrated through 🤗 Transformers primarily targets AMD CPUs and GPUs, and is primarily meant to be used for evaluation purposes. For example, it is possible to use [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) with 🤗 Transformers backend and evaluate a wide range of models quantized through Quark seamlessly.
+
+Users interested in Quark can refer to its [documentation](https://quark.docs.amd.com/latest/) to get started quantizing models and using them in supported open-source libraries!
+
+Although Quark has its own checkpoint / [configuration format](https://huggingface.co/amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test/blob/main/config.json#L26), the library also supports producing models with a serialization layout compliant with other quantization/runtime implementations ([AutoAWQ](https://huggingface.co/docs/transformers/quantization/awq), [native fp8 in 🤗 Transformers](https://huggingface.co/docs/transformers/quantization/finegrained_fp8)).
+
+To be able to load Quark quantized models in Transformers, the library first needs to be installed:
+
+```bash
+pip install amd-quark
+```
+
+## Support matrix
+
+Models quantized through Quark support a large range of features, that can be combined together. All quantized models independently of their configuration can seamlessly be reloaded through `PretrainedModel.from_pretrained`.
+
+The table below shows a few features supported by Quark:
+
+| **Feature** | **Supported subset in Quark** | |
+|---------------------------------|-----------------------------------------------------------------------------------------------------------|---|
+| Data types | int8, int4, int2, bfloat16, float16, fp8_e5m2, fp8_e4m3, fp6_e3m2, fp6_e2m3, fp4, OCP MX, MX6, MX9, bfp16 | |
+| Pre-quantization transformation | SmoothQuant, QuaRot, SpinQuant, AWQ | |
+| Quantization algorithm | GPTQ | |
+| Supported operators | ``nn.Linear``, ``nn.Conv2d``, ``nn.ConvTranspose2d``, ``nn.Embedding``, ``nn.EmbeddingBag`` | |
+| Granularity | per-tensor, per-channel, per-block, per-layer, per-layer type | |
+| KV cache | fp8 | |
+| Activation calibration | MinMax / Percentile / MSE | |
+| Quantization strategy | weight-only, static, dynamic, with or without output quantization | |
+
+## Models on Hugging Face Hub
+
+Public models using Quark native serialization can be found at https://huggingface.co/models?other=quark.
+
+Although Quark also supports [models using `quant_method="fp8"`](https://huggingface.co/models?other=fp8) and [models using `quant_method="awq"`](https://huggingface.co/models?other=awq), Transformers loads these models rather through [AutoAWQ](https://huggingface.co/docs/transformers/quantization/awq) or uses the [native fp8 support in 🤗 Transformers](https://huggingface.co/docs/transformers/quantization/finegrained_fp8).
+
+## Using Quark models in Transformers
+
+Here is an example of how one can load a Quark model in Transformers:
+
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model_id = "EmbeddedLLM/Llama-3.1-8B-Instruct-w_fp8_per_channel_sym"
+model = AutoModelForCausalLM.from_pretrained(model_id)
+model = model.to("cuda")
+
+print(model.model.layers[0].self_attn.q_proj)
+# QParamsLinear(
+#   (weight_quantizer): ScaledRealQuantizer()
+#   (input_quantizer): ScaledRealQuantizer()
+#   (output_quantizer): ScaledRealQuantizer()
+# )
+
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+inp = tokenizer("Where is a good place to cycle around Tokyo?", return_tensors="pt")
+inp = inp.to("cuda")
+
+res = model.generate(**inp, min_new_tokens=50, max_new_tokens=100)
+
+print(tokenizer.batch_decode(res)[0])
+# <|begin_of_text|>Where is a good place to cycle around Tokyo? There are several places in Tokyo that are suitable for cycling, depending on your skill level and interests. Here are a few suggestions:
+# 1. Yoyogi Park: This park is a popular spot for cycling and has a wide, flat path that's perfect for beginners. You can also visit the Meiji Shrine, a famous Shinto shrine located in the park.
+# 2. Imperial Palace East Garden: This beautiful garden has a large, flat path that's perfect for cycling. You can also visit the
+```
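Reviewer note: the "Models on Hugging Face Hub" section of the new doc says that checkpoints tagged `quant_method="quark"` go through the Quark integration, while `"awq"` and `"fp8"` checkpoints are loaded through AutoAWQ or the native fp8 support. A minimal sketch of that routing follows. It is not the Transformers implementation; `LOADER_BY_QUANT_METHOD` and `describe_loader` are hypothetical names, and the descriptions are plain strings rather than real quantizer classes.

```python
# Hypothetical illustration of how the `quant_method` field of a checkpoint's
# `quantization_config` decides which loading path is used.
LOADER_BY_QUANT_METHOD = {
    "quark": "Quark integration (requires `pip install amd-quark`)",
    "awq": "AutoAWQ integration",
    "fp8": "native fine-grained fp8 support",
}


def describe_loader(config: dict) -> str:
    """Return a description of the loading path for a parsed config.json."""
    quant_config = config.get("quantization_config")
    if quant_config is None:
        return "no quantizer (unquantized checkpoint)"

    method = quant_config.get("quant_method")
    if method not in LOADER_BY_QUANT_METHOD:
        raise ValueError(f"Unknown quant_method: {method!r}")
    return LOADER_BY_QUANT_METHOD[method]


if __name__ == "__main__":
    print(describe_loader({"quantization_config": {"quant_method": "quark"}}))
    print(describe_loader({"quantization_config": {"quant_method": "awq"}}))
    print(describe_loader({}))
```

The point of the sketch is only that a Quark-produced checkpoint serialized in the AWQ or fp8 layout never reaches the Quark quantizer, because the dispatch key is the serialized `quant_method`, not the tool that produced the checkpoint.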

src/transformers/__init__.py

File mode changed: 100755 → 100644
Lines changed: 2 additions & 0 deletions
@@ -1046,6 +1046,7 @@
         "HiggsConfig",
         "HqqConfig",
         "QuantoConfig",
+        "QuarkConfig",
         "SpQRConfig",
         "TorchAoConfig",
         "VptqConfig",
@@ -6287,6 +6288,7 @@
         HiggsConfig,
         HqqConfig,
         QuantoConfig,
+        QuarkConfig,
         SpQRConfig,
         TorchAoConfig,
         VptqConfig,
VptqConfig,

src/transformers/modeling_utils.py

File mode changed: 100755 → 100644
Lines changed: 8 additions & 0 deletions
@@ -536,6 +536,10 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
     str_to_torch_dtype["U32"] = torch.uint32
     str_to_torch_dtype["U64"] = torch.uint64
 
+if is_torch_greater_or_equal("2.1.0"):
+    str_to_torch_dtype["F8_E4M3"] = torch.float8_e4m3fn
+    str_to_torch_dtype["F8_E5M2"] = torch.float8_e5m2
+
 
 def load_state_dict(
     checkpoint_file: Union[str, os.PathLike],
@@ -3675,6 +3679,10 @@ def to(self, *args, **kwargs):
 
         if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
             raise ValueError("`.to` is not supported for HQQ-quantized models.")
+
+        if dtype_present_in_args and getattr(self, "quantization_method", None) == QuantizationMethod.QUARK:
+            raise ValueError("Casting a Quark quantized model to a new `dtype` is not supported.")
+
         # Checks if the model has been loaded in 4-bit or 8-bit with BNB
         if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
             if dtype_present_in_args:
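Reviewer note: the second hunk above makes `.to` reject dtype casts for Quark models while still allowing device moves. The sketch below mimics that behavior without PyTorch so it can run anywhere. `FakeDtype` and `SketchModel` are hypothetical stand-ins, and the way `dtype_present_in_args` is computed here is an assumption, since that part of `to` is outside the hunk.

```python
from enum import Enum


class QuantizationMethod(str, Enum):
    """Stand-in for the Transformers enum; only the members used here."""

    QUARK = "quark"
    HQQ = "hqq"


class FakeDtype:
    """Hypothetical stand-in for torch.dtype so the sketch runs without PyTorch."""

    def __init__(self, name: str):
        self.name = name


class SketchModel:
    def __init__(self, quantization_method=None):
        self.quantization_method = quantization_method
        self.device = "cpu"

    def to(self, *args, **kwargs):
        # Assumption: a dtype counts as "present" if passed by keyword or positionally.
        dtype_present_in_args = "dtype" in kwargs or any(isinstance(a, FakeDtype) for a in args)

        if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
            raise ValueError("`.to` is not supported for HQQ-quantized models.")

        # Mirrors the added guard: only dtype casts are rejected for Quark models.
        if dtype_present_in_args and getattr(self, "quantization_method", None) == QuantizationMethod.QUARK:
            raise ValueError("Casting a Quark quantized model to a new `dtype` is not supported.")

        # Device moves remain allowed.
        for arg in args:
            if isinstance(arg, str):
                self.device = arg
        if "device" in kwargs:
            self.device = kwargs["device"]
        return self


if __name__ == "__main__":
    model = SketchModel(QuantizationMethod.QUARK)
    print(model.to("cuda").device)
    try:
        model.to(dtype=FakeDtype("float16"))
    except ValueError as err:
        print(err)
```

This matches the usage in the new documentation, where the loaded Quark model is moved with `model.to("cuda")` and never cast.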

src/transformers/quantizers/auto.py

File mode changed: 100755 → 100644
Lines changed: 5 additions & 0 deletions
@@ -1,4 +1,5 @@
 # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Modifications Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -31,6 +32,7 @@
     QuantizationConfigMixin,
     QuantizationMethod,
     QuantoConfig,
+    QuarkConfig,
     SpQRConfig,
     TorchAoConfig,
     VptqConfig,
@@ -49,6 +51,7 @@
 from .quantizer_higgs import HiggsHfQuantizer
 from .quantizer_hqq import HqqHfQuantizer
 from .quantizer_quanto import QuantoHfQuantizer
+from .quantizer_quark import QuarkHfQuantizer
 from .quantizer_spqr import SpQRHfQuantizer
 from .quantizer_torchao import TorchAoHfQuantizer
 from .quantizer_vptq import VptqHfQuantizer
@@ -61,6 +64,7 @@
     "gptq": GptqHfQuantizer,
     "aqlm": AqlmHfQuantizer,
     "quanto": QuantoHfQuantizer,
+    "quark": QuarkHfQuantizer,
     "eetq": EetqHfQuantizer,
     "higgs": HiggsHfQuantizer,
     "hqq": HqqHfQuantizer,
@@ -81,6 +85,7 @@
     "gptq": GPTQConfig,
     "aqlm": AqlmConfig,
     "quanto": QuantoConfig,
+    "quark": QuarkConfig,
     "hqq": HqqConfig,
     "compressed-tensors": CompressedTensorsConfig,
     "fbgemm_fp8": FbgemmFp8Config,
src/transformers/quantizers/quantizer_quark.py

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
+# coding=utf-8
+# Copyright 2025 Advanced Micro Devices, Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Any, Dict
+
+from ..file_utils import is_torch_available
+from .base import HfQuantizer
+
+
+if TYPE_CHECKING:
+    from ..modeling_utils import PreTrainedModel
+
+    if is_torch_available():
+        import torch
+
+from ..utils import is_accelerate_available, is_quark_available, logging
+
+
+if is_accelerate_available():
+    from accelerate.utils import set_module_tensor_to_device
+
+logger = logging.get_logger(__name__)
+
+
+CHECKPOINT_KEYS = {
+    "weight_scale": "weight_quantizer.scale",
+    "bias_scale": "bias_quantizer.scale",
+    "input_scale": "input_quantizer.scale",
+    "output_scale": "output_quantizer.scale",
+    "weight_zero_point": "weight_quantizer.zero_point",
+    "bias_zero_point": "bias_quantizer.zero_point",
+    "input_zero_point": "input_quantizer.zero_point",
+    "output_zero_point": "output_quantizer.zero_point",
+}
+
+
+class QuarkHfQuantizer(HfQuantizer):
+    """
+    Quark quantizer (https://quark.docs.amd.com/latest/).
+    """
+
+    requires_calibration = True  # On-the-fly quantization with quark is not supported for now.
+    required_packages = ["quark"]
+
+    # Checkpoints are expected to be already quantized when loading a quark model. However, as some keys from
+    # the checkpoint might mismatch the model parameters keys, we use the `create_quantized_param` method
+    # to load the checkpoints, remapping the keys.
+    requires_parameters_quantization = True
+
+    def __init__(self, quantization_config, **kwargs):
+        super().__init__(quantization_config, **kwargs)
+
+        self.json_export_config = quantization_config.json_export_config
+
+    def validate_environment(self, *args, **kwargs):
+        if not is_quark_available():
+            raise ImportError(
+                "Loading a Quark quantized model requires the `quark` library but it was not found in the environment. Please refer to https://quark.docs.amd.com/latest/install.html."
+            )
+
+    def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
+        from quark.torch.export.api import _map_to_quark
+
+        _map_to_quark(
+            model,
+            self.quantization_config.quant_config,
+            pack_method=self.json_export_config.pack_method,
+            custom_mode=self.quantization_config.custom_mode,
+        )
+
+        return model
+
+    def check_quantized_param(
+        self,
+        model: "PreTrainedModel",
+        param_value: "torch.Tensor",
+        param_name: str,
+        state_dict: Dict[str, Any],
+        **kwargs,
+    ) -> bool:
+        return True
+
+    def create_quantized_param(
+        self, model, param, param_name, param_device, state_dict, unexpected_keys
+    ) -> "torch.nn.Parameter":
+        postfix = param_name.split(".")[-1]
+
+        if postfix in CHECKPOINT_KEYS:
+            param_name = param_name.replace(postfix, CHECKPOINT_KEYS[postfix])
+
+        set_module_tensor_to_device(model, param_name, param_device, value=param)
+
+    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
+        return model
+
+    def is_serializable(self, safe_serialization=None):
+        return False
+
+    @property
+    def is_trainable(self):
+        return False
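Reviewer note: `create_quantized_param` above exists mainly to rename checkpoint keys such as `...q_proj.weight_scale` to the module layout `...q_proj.weight_quantizer.scale`. The standalone sketch below isolates that remapping. `remap_checkpoint_key` is a hypothetical helper, and unlike the `str.replace` call in the diff (which rewrites every occurrence of the postfix in the name) this variant rewrites only the final dotted segment.

```python
# Same mapping as in the diff: checkpoint key suffix -> attribute path on the
# quantized module.
CHECKPOINT_KEYS = {
    "weight_scale": "weight_quantizer.scale",
    "bias_scale": "bias_quantizer.scale",
    "input_scale": "input_quantizer.scale",
    "output_scale": "output_quantizer.scale",
    "weight_zero_point": "weight_quantizer.zero_point",
    "bias_zero_point": "bias_quantizer.zero_point",
    "input_zero_point": "input_quantizer.zero_point",
    "output_zero_point": "output_quantizer.zero_point",
}


def remap_checkpoint_key(param_name: str) -> str:
    """Rewrite only the last dotted segment of `param_name` if it is a known key."""
    prefix, _, postfix = param_name.rpartition(".")
    if postfix not in CHECKPOINT_KEYS:
        return param_name  # e.g. plain `weight` / `bias` keys pass through

    remapped = CHECKPOINT_KEYS[postfix]
    return f"{prefix}.{remapped}" if prefix else remapped


if __name__ == "__main__":
    for key in (
        "model.layers.0.self_attn.q_proj.weight_scale",
        "model.layers.0.self_attn.q_proj.weight",
        "input_zero_point",
    ):
        print(key, "->", remap_checkpoint_key(key))
```

For the key names shown in the diff the two approaches give the same result; the difference would only matter if a postfix such as `weight_scale` also appeared earlier in a parameter path.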

src/transformers/testing_utils.py

Lines changed: 8 additions & 0 deletions
@@ -116,6 +116,7 @@
     is_pytesseract_available,
     is_pytest_available,
     is_pytorch_quantization_available,
+    is_quark_available,
     is_rjieba_available,
     is_sacremoses_available,
     is_safetensors_available,
@@ -1299,6 +1300,13 @@ def require_fbgemm_gpu(test_case):
     return unittest.skipUnless(is_fbgemm_gpu_available(), "test requires fbgemm-gpu")(test_case)
 
 
+def require_quark(test_case):
+    """
+    Decorator for quark dependency
+    """
+    return unittest.skipUnless(is_quark_available(), "test requires quark")(test_case)
+
+
 def require_flute_hadamard(test_case):
     """
     Decorator marking a test that requires higgs and hadamard
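Reviewer note: `require_quark` follows the usual `unittest.skipUnless` pattern, so tests are skipped rather than failed when the dependency is missing. A self-contained sketch of the pattern is shown below. `quark_is_installed` and `require_quark_sketch` are hypothetical names, and probing with `importlib.util.find_spec` is an assumption about how availability could be checked; the real `is_quark_available` is not part of the hunks shown here.

```python
import importlib.util
import unittest


def quark_is_installed() -> bool:
    # Assumption: availability is approximated by whether a top-level
    # `quark` module can be located; the real helper may check more.
    return importlib.util.find_spec("quark") is not None


def require_quark_sketch(test_case):
    """Skip the decorated test (or test class) unless `quark` is importable."""
    return unittest.skipUnless(quark_is_installed(), "test requires quark")(test_case)


@require_quark_sketch
class QuarkSmokeTest(unittest.TestCase):
    def test_placeholder(self):
        # Only runs when quark is installed, so this holds by construction.
        self.assertTrue(quark_is_installed())


if __name__ == "__main__":
    unittest.main()
```

Applied to a class, the decorator skips every test method in it, which is convenient for an integration suite that cannot do anything useful without the library.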
