From 1f87b7ddd3405d2aafb989da1b1227a73b973037 Mon Sep 17 00:00:00 2001 From: Felix Marty Date: Fri, 21 Feb 2025 14:33:09 +0100 Subject: [PATCH 01/17] add quark quantizer --- src/transformers/__init__.py | 2 + src/transformers/modeling_utils.py | 1 + src/transformers/quantizers/auto.py | 7 +- .../quantizers/quantizer_quark.py | 103 ++++++++++++ src/transformers/testing_utils.py | 6 + src/transformers/utils/__init__.py | 1 + src/transformers/utils/import_utils.py | 5 + src/transformers/utils/quantization_config.py | 30 ++++ tests/quantization/quark/__init__.py | 0 tests/quantization/quark/test_quark.py | 158 ++++++++++++++++++ 10 files changed, 312 insertions(+), 1 deletion(-) mode change 100755 => 100644 src/transformers/__init__.py mode change 100755 => 100644 src/transformers/modeling_utils.py mode change 100755 => 100644 src/transformers/quantizers/auto.py create mode 100644 src/transformers/quantizers/quantizer_quark.py mode change 100755 => 100644 src/transformers/utils/__init__.py mode change 100755 => 100644 src/transformers/utils/import_utils.py mode change 100755 => 100644 src/transformers/utils/quantization_config.py create mode 100644 tests/quantization/quark/__init__.py create mode 100644 tests/quantization/quark/test_quark.py diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py old mode 100755 new mode 100644 index ed2682901008..a8f5f475f625 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1037,6 +1037,7 @@ "HiggsConfig", "HqqConfig", "QuantoConfig", + "QuarkConfig", "SpQRConfig", "TorchAoConfig", "VptqConfig", @@ -6248,6 +6249,7 @@ HiggsConfig, HqqConfig, QuantoConfig, + QuarkConfig, SpQRConfig, TorchAoConfig, VptqConfig, diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py old mode 100755 new mode 100644 index bd2354aec739..093449f822a0 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4931,6 +4931,7 @@ def _find_mismatched_keys( remove_prefix_from_model, ignore_mismatched_sizes, ) + if low_cpu_mem_usage: if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized: for key, param in model_to_load.state_dict().items(): diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py old mode 100755 new mode 100644 index 64634f98a44a..1d5f0b853673 --- a/src/transformers/quantizers/auto.py +++ b/src/transformers/quantizers/auto.py @@ -1,4 +1,5 @@ # Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Modifications Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -31,7 +32,8 @@ QuantizationConfigMixin, QuantizationMethod, QuantoConfig, - SpQRConfig, + QuarkConfig, + SpQRConfig, TorchAoConfig, VptqConfig, ) @@ -49,6 +51,7 @@ from .quantizer_higgs import HiggsHfQuantizer from .quantizer_hqq import HqqHfQuantizer from .quantizer_quanto import QuantoHfQuantizer +from .quantizer_quark import QuarkHfQuantizer from .quantizer_spqr import SpQRHfQuantizer from .quantizer_torchao import TorchAoHfQuantizer from .quantizer_vptq import VptqHfQuantizer @@ -61,6 +64,7 @@ "gptq": GptqHfQuantizer, "aqlm": AqlmHfQuantizer, "quanto": QuantoHfQuantizer, + "quark": QuarkHfQuantizer, "eetq": EetqHfQuantizer, "higgs": HiggsHfQuantizer, "hqq": HqqHfQuantizer, @@ -81,6 +85,7 @@ "gptq": GPTQConfig, "aqlm": AqlmConfig, "quanto": QuantoConfig, + "quark": QuarkConfig, "hqq": HqqConfig, "compressed-tensors": CompressedTensorsConfig, "fbgemm_fp8": FbgemmFp8Config, diff --git a/src/transformers/quantizers/quantizer_quark.py b/src/transformers/quantizers/quantizer_quark.py new file mode 100644 index 000000000000..131b8d119e89 --- /dev/null +++ b/src/transformers/quantizers/quantizer_quark.py @@ -0,0 +1,103 @@ +# Copyright 2024 Advanced Micro Devices, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Dict, Any + +from .base import HfQuantizer + + +if TYPE_CHECKING: + from ..modeling_utils import PreTrainedModel + +from ..utils import is_quark_available, is_accelerate_available, logging + +if is_accelerate_available(): + from accelerate.utils import set_module_tensor_to_device + +if is_quark_available(): + from quark.torch import ModelImporter + +logger = logging.get_logger(__name__) + + +CHECKPOINT_KEYS = { + "weight_scale": "weight_quantizer.scale", + "bias_scale": "bias_quantizer.scale", + "input_scale": "input_quantizer.scale", + "output_scale": "output_quantizer.scale", + "weight_zero_point": "weight_quantizer.zero_point", + "bias_zero_point": "bias_quantizer.zero_point", + "input_zero_point": "input_quantizer.zero_point", + "output_zero_point": "output_quantizer.zero_point", +} + +class QuarkHfQuantizer(HfQuantizer): + """ + Quark quantizer (https://quark.docs.amd.com/latest/). + """ + + requires_calibration = True # On-the-fly quantization with quark is not supported for now. + required_packages = ["quark"] + + # Checkpoints are expected to be already quantized when loading a quark model. However, as some keys from + # the checkpoint might mismatch the model parameters keys, we use the `create_quantized_param` method + # to load the checkpoints, remapping the keys. + requires_parameters_quantization = True + + def __init__(self, quantization_config, **kwargs): + super().__init__(quantization_config, **kwargs) + + self.json_export_config = quantization_config.json_export_config + + def validate_environment(self, *args, **kwargs): + if not is_quark_available(): + raise ImportError( + "Loading a Quark quantized model requires the `quark` library but it was not found in the environment. 
Please refer to https://quark.docs.amd.com/latest/install.html." + ) + + def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs): + # TODO: support legacy? legacy=self.quantization_config.legacy + ModelImporter.map_to_quark(model, self.quantization_config.quant_config, pack_method=self.json_export_config.pack_method, custom_mode=self.quantization_config.custom_mode) + + return model + + def check_quantized_param( + self, + model: "PreTrainedModel", + param_value: "torch.Tensor", + param_name: str, + state_dict: Dict[str, Any], + **kwargs, + ) -> bool: + return True + + def create_quantized_param(self, model, param, param_name, param_device, state_dict, unexpected_keys) -> "torch.nn.Parameter": + postfix = param_name.split(".")[-1] + + if postfix in CHECKPOINT_KEYS: + param_name = param_name.replace(postfix, CHECKPOINT_KEYS[postfix]) + + set_module_tensor_to_device(model, param_name, param_device, value=param) + + def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): + return model + + def is_serializable(self, safe_serialization=None): + # TODO: check serialization + return False + + @property + def is_trainable(self): + # TODO: check trainable + return False diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 1d575ad4a3a7..b00ede804083 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -112,6 +112,7 @@ is_pytesseract_available, is_pytest_available, is_pytorch_quantization_available, + is_quark_available, is_rjieba_available, is_sacremoses_available, is_safetensors_available, @@ -1254,6 +1255,11 @@ def require_fbgemm_gpu(test_case): """ return unittest.skipUnless(is_fbgemm_gpu_available(), "test requires fbgemm-gpu")(test_case) +def require_quark(test_case): + """ + Decorator for quark dependency + """ + return unittest.skipUnless(is_quark_available(), "test requires quark")(test_case) def require_flute_hadamard(test_case): """ diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py old mode 100755 new mode 100644 index 079e575a5e39..d26650a4a4cd --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -181,6 +181,7 @@ is_pytesseract_available, is_pytest_available, is_pytorch_quantization_available, + is_quark_available, is_rjieba_available, is_sacremoses_available, is_safetensors_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py old mode 100755 new mode 100644 index 960079a9e388..2e7ffb33518a --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -150,6 +150,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ # `importlib.metadata.version` doesn't work with `awq` _auto_awq_available = importlib.util.find_spec("awq") is not None _quanto_available = _is_package_available("quanto") +_quark_available = _is_package_available("quark") _is_optimum_quanto_available = False try: importlib.metadata.version("optimum_quanto") @@ -1044,6 +1045,10 @@ def is_optimum_quanto_available(): return _is_optimum_quanto_available +def is_quark_available(): + return _quark_available + + def is_compressed_tensors_available(): return _compressed_tensors_available diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py old mode 100755 new mode 100644 index 3fafca29b9c3..ad69cacda4da --- a/src/transformers/utils/quantization_config.py +++ 
b/src/transformers/utils/quantization_config.py @@ -2,6 +2,7 @@ # coding=utf-8 # Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# Modifications Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -29,6 +30,7 @@ from ..utils import ( is_auto_awq_available, is_gptqmodel_available, + is_quark_available, is_hqq_available, is_torch_available, is_torchao_available, @@ -59,6 +61,7 @@ class QuantizationMethod(str, Enum): BITNET = "bitnet" SPQR = "spqr" FP8 = "fp8" + QUARK = "quark" class AWQLinearVersion(str, Enum): @@ -1681,3 +1684,30 @@ def post_init(self): raise ValueError("weight_block_size must be a tuple of two integers") if self.weight_block_size[0] <= 0 or self.weight_block_size[1] <= 0: raise ValueError("weight_block_size must be a tuple of two positive integers") + + +class QuarkConfig(QuantizationConfigMixin): + def __init__( + self, + **kwargs, + ): + if is_torch_available() and is_quark_available(): + import quark.torch + + self.custom_mode = kwargs["quant_method"] # This might be e.g. `"fp8"` or `"awq"`. + self.legacy = "export" not in kwargs + + if self.custom_mode in ["awq", "fp8"]: + # Legacy (quark<1.0) or custom export. + self.quant_config = quark.torch.export.main_export.quant_config_parser.QuantConfigParser.from_custom_config(kwargs, is_bias_quantized=False) + self.json_export_config = quark.torch.export.config.config.JsonExporterConfig() + else: + self.quant_config = quark.torch.quantization.config.config.Config.from_dict(kwargs) + + if "export" in kwargs: + self.json_export_config = quark.torch.export.config.config.JsonExporterConfig(**kwargs["export"]) + else: + # Legacy (quark<1.0) or custom export. + self.json_export_config = quark.torch.export.config.config.JsonExporterConfig() + + self.quant_method = QuantizationMethod.QUARK diff --git a/tests/quantization/quark/__init__.py b/tests/quantization/quark/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/quantization/quark/test_quark.py b/tests/quantization/quark/test_quark.py new file mode 100644 index 000000000000..bfa715235803 --- /dev/null +++ b/tests/quantization/quark/test_quark.py @@ -0,0 +1,158 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, QuarkConfig, GenerationConfig +from transformers.utils.import_utils import is_quark_available +from transformers.testing_utils import ( + is_torch_available, + require_accelerate, + require_quark, + require_torch_gpu, + require_torch_multi_gpu, + slow, +) + + +if is_torch_available(): + import torch + +if is_quark_available(): + from quark.torch.export.nn.modules import QParamsLinear + + +class QuarkConfigTest(unittest.TestCase): + def test_unknown_arg(self): + QuarkConfig(foo=2) + + def test_commmon_args(self): + config = AutoConfig.from_pretrained("amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test") + QuarkConfig(**config.quant_config) + + def test_missing_args(self): + config = AutoConfig.from_pretrained("amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test") + + # TODO: delete one arg from quant_config + QuarkConfig(**config.quant_config) + +@slow +@require_quark +@require_torch_gpu +class QuarkTest(unittest.TestCase): + reference_model_name = "meta-llama/Llama-3.1-8B-Instruct" + quantized_model_name = "amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test" + + input_text = "Today I am in Paris and" + + EXPECTED_OUTPUTS = set() + EXPECTED_OUTPUTS.add("Today I am in Paris and I am not in Paris, France\nToday I am in Paris, Illinois") + + EXPECTED_RELATIVE_DIFFERENCE = 1.664253062 + device_map = None + + # called only once for all test in this class + @classmethod + def setUpClass(cls): + """ + Setup reference & quantized model + """ + cls.model_fp16 = AutoModelForCausalLM.from_pretrained( + cls.reference_model_name, torch_dtype=torch.float16, device_map=cls.device_map + ) + cls.mem_fp16 = cls.model_fp16.get_memory_footprint() + + cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name, use_fast=True) + + cls.quantized_model = AutoModelForCausalLM.from_pretrained( + cls.quantized_model_name, + torch_dtype=torch.float16, + device_map=cls.device_map, + ) + + def test_memory_footprint(self): + r""" + A simple test to check if the model conversion has been done correctly by checking on the + memory footprint of the converted model + """ + mem_quantized = self.quantized_model.get_memory_footprint() + + self.assertAlmostEqual(self.mem_fp16 / mem_quantized, self.EXPECTED_RELATIVE_DIFFERENCE) + + def test_device_and_dtype_assignment(self): + r""" + Test whether trying to cast (or assigning a device to) a model after quantization will throw an error. + Checks also if other models are casted correctly. 
+ """ + # This should work + if self.device_map is None: + _ = self.quantized_model.to(0) + + with self.assertRaises(ValueError): + # Tries with a `dtype`` + self.quantized_model.to(torch.float16) + + def test_original_dtype(self): + r""" + A simple test to check if the model succesfully stores the original dtype + """ + self.assertTrue(hasattr(self.quantized_model.config, "_pre_quantization_dtype")) + self.assertFalse(hasattr(self.model_fp16.config, "_pre_quantization_dtype")) + self.assertTrue(self.quantized_model.config._pre_quantization_dtype == torch.float16) + + def test_quantized_layers_class(self): + """ + Simple test to check if the model conversion has been done correctly by checking on + the class type of the linear layers of the converted models + """ + self.assertTrue(isinstance(self.quantized_model.model.layers[0].mlp.gate_proj, QParamsLinear)) + + def check_inference_correctness(self, model): + r""" + Test the generation quality of the quantized model and see that we are matching the expected output. + Given that we are operating on small numbers + the testing model is relatively small, we might not get + the same output across GPUs. So we'll generate few tokens (5-10) and check their output. + """ + # Check that inference pass works on the model + encoded_input = self.tokenizer(self.input_text, return_tensors="pt") + + gen_config = GenerationConfig( + max_new_tokens=15, + min_new_tokens=15, + use_cache=True, + num_beams=1, + do_sample=False, + ) + + # Check the exactness of the results + output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(0), generation_config=gen_config) + + # Get the generation + self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) + + def test_generate_quality(self): + """ + Simple test to check the quality of the model by comparing the generated tokens with the expected tokens + """ + if self.device_map is None: + self.check_inference_correctness(self.quantized_model.to(0)) + else: + self.check_inference_correctness(self.quantized_model) + +@require_accelerate +@require_torch_multi_gpu +@require_quark +class QuarkTestDeviceMap(QuarkTest): + device_map = "auto" From c405adb4cd8395ae6fded6ab4213336b7ad15f8d Mon Sep 17 00:00:00 2001 From: Felix Marty Date: Fri, 21 Feb 2025 16:28:41 +0100 Subject: [PATCH 02/17] add quark doc --- docs/source/en/_toctree.yml | 2 + docs/source/en/main_classes/quantization.md | 4 ++ docs/source/en/quantization/overview.md | 7 +++ docs/source/en/quantization/quark.md | 63 +++++++++++++++++++++ 4 files changed, 76 insertions(+) create mode 100644 docs/source/en/quantization/quark.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 7d7201da5027..fb039a0786e6 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -180,6 +180,8 @@ title: VPTQ - local: quantization/quanto title: Quanto + - local: quantization/quark + title: Quark - local: quantization/eetq title: EETQ - local: quantization/higgs diff --git a/docs/source/en/main_classes/quantization.md b/docs/source/en/main_classes/quantization.md index 6da5b8ce69b5..fb42e886bace 100755 --- a/docs/source/en/main_classes/quantization.md +++ b/docs/source/en/main_classes/quantization.md @@ -88,3 +88,7 @@ Learn how to quantize models in the [Quantization](../quantization) guide. 
## FineGrainedFP8Config [[autodoc]] FineGrainedFP8Config + +## QuarkConfig + +[[autodoc]] QuarkConfig diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md index 94696e300a57..955cb59b7275 100644 --- a/docs/source/en/quantization/overview.md +++ b/docs/source/en/quantization/overview.md @@ -63,6 +63,7 @@ Use the table below to help you decide which quantization method to use. | [VPTQ](./vptq.md) | 🔴 | 🔴 | 🟢 | 🟡 | 🔴 | 🔴 | 🟢 | 1/8 | 🔴 | 🟢 | 🟢 | https://github.com/microsoft/VPTQ | | [SpQR](./spqr.md) | 🔴 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 3 | 🔴 | 🟢 | 🟢 | https://github.com/Vahe1994/SpQR/ | | [FINEGRAINED_FP8](./finegrained_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | | +| [Quark](./quark.md) | 🔴 | 🟢 6 | 🟢 6 | 🟢 6 | 🟢 6 | 🟢 6 | ? | 2/4/6/8/9/16 | 🔴 | 🔴 | 🟢 | https://quark.docs.amd.com/latest/ | **1:** bitsandbytes is being refactored to support multiple backends beyond CUDA. Currently, ROCm (AMD GPU) and Intel CPU implementations are mature, with Intel XPU in progress and Apple Silicon support expected by Q4/Q1. For installation instructions and the latest backend updates, visit [this link](https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend). Check out [these docs](https://huggingface.co/docs/bitsandbytes/main/en/non_cuda_backends) for more details and feedback links. @@ -93,3 +94,9 @@ Use the table below to help you decide which quantization method to use. + + +**6:** Quark is hardware agnostic, and may not supported accelerated inference / kernels for every quantization scheme and every hardware and PyTorch distribution. + + + diff --git a/docs/source/en/quantization/quark.md b/docs/source/en/quantization/quark.md new file mode 100644 index 000000000000..986c3b63fd0b --- /dev/null +++ b/docs/source/en/quantization/quark.md @@ -0,0 +1,63 @@ + + +# Quark + +[Quark](https://quark.docs.amd.com/latest/) is a deep learning quantization toolkit designed to be agnostic to specific data types, algorithms, and hardware. Different pre-processing strategies, algorithms and data-types can be combined in Quark. + +The PyTorch support integrated through 🤗 Transformers primarily targets AMD CPUs and GPUs, and is primarily meant to be used for evaluation purposes. + +Users interested in Quark can refer to its [documentation](https://quark.docs.amd.com/latest/) to get started quantizing models and using them in supported open-source libraries! + +Although Quark has its own checkpoint / [configuration format](https://huggingface.co/amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test/blob/main/config.json#L26), the library also supports producing models with a serialization layout compliant with other quantization/runtime implementations ([AutoAWQ](https://huggingface.co/docs/transformers/quantization/awq), [native fp8 in 🤗 Transformers](https://huggingface.co/docs/transformers/quantization/finegrained_fp8)). + +## Support matrix + +Models quantized through Quark support a large range of features, that can be combined together. All quantized models independently of their configuration can seamlessly be reloaded through `PretrainedModel.from_pretrained`. 
+ +| **Feature** | **Supported subset in Quark** | | +|---------------------------------|-----------------------------------------------------------------------------------------------------------|---| +| Data types | int8, int4, int2, bfloat16, float16, fp8_e5m2, fp8_e4m3, fp6_e3m2, fp6_e2m3, fp4, OCP MX, MX6, MX9, bfp16 | | +| Pre-quantization transformation | SmoothQuant, QuaRot, SpinQuant, AWQ | | +| Quantization algorithm | GPTQ | | +| Supported operators | ``nn.Linear``, ``nn.Conv2d``, ``nn.ConvTranspose2d``, ``nn.Embedding``, ``nn.EmbeddingBag`` | | +| Granularity | per-tensor, per-channel, per-block, per-layer, per-layer type | | +| KV cache | fp8 | | +| Activation calibration | MinMax / Percentile / MSE | | +| Quantization strategy | weight-only, static, dynamic, with or without output quantization | | + +## Models on Hugging Face Hub + +Public models using Quark native serialization can be found at https://huggingface.co/models?other=quark. + +Although Quark also supports [models using `quant_method="fp8"`](https://huggingface.co/models?other=fp8) and [models using `quant_method="awq"`](https://huggingface.co/models?other=awq), Transformers loads these models rather through [AutoAWQ](https://huggingface.co/docs/transformers/quantization/awq) or uses the [native fp8 support in 🤗 Transformers](https://huggingface.co/docs/transformers/quantization/finegrained_fp8). + +## Using Quark models in Transformers + +```python +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained("EmbeddedLLM/Llama-3.1-8B-Instruct-w_fp8_per_channel_sym") +model = model.to("cuda") + +tokenizer = AutoTokenizer.from_pretrained("EmbeddedLLM/Llama-3.1-8B-Instruct-w_fp8_per_channel_sym") + +inp = tokenizer("Where is a good place to cycle around Tokyo?", return_tensors="pt").to("cuda") + +res = model.generate(**inp) + +print(tokenizer.batch_decode(res)) +``` \ No newline at end of file From eb189de5fef120f9c283725ce9c149ab951dd493 Mon Sep 17 00:00:00 2001 From: Felix Marty Date: Fri, 21 Feb 2025 09:06:54 -0700 Subject: [PATCH 03/17] clean up doc --- docs/source/en/quantization/quark.md | 27 ++++++++++++++----- .../quantizers/quantizer_quark.py | 11 ++++---- src/transformers/utils/quantization_config.py | 2 +- 3 files changed, 28 insertions(+), 12 deletions(-) diff --git a/docs/source/en/quantization/quark.md b/docs/source/en/quantization/quark.md index 986c3b63fd0b..1abe84deecd0 100644 --- a/docs/source/en/quantization/quark.md +++ b/docs/source/en/quantization/quark.md @@ -18,7 +18,7 @@ rendered properly in your Markdown viewer. [Quark](https://quark.docs.amd.com/latest/) is a deep learning quantization toolkit designed to be agnostic to specific data types, algorithms, and hardware. Different pre-processing strategies, algorithms and data-types can be combined in Quark. -The PyTorch support integrated through 🤗 Transformers primarily targets AMD CPUs and GPUs, and is primarily meant to be used for evaluation purposes. +The PyTorch support integrated through 🤗 Transformers primarily targets AMD CPUs and GPUs, and is primarily meant to be used for evaluation purposes. For example, it is possible to use [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) with 🤗 Transformers backend and evaluate a wide range of models quantized through Quark seamlessly. Users interested in Quark can refer to its [documentation](https://quark.docs.amd.com/latest/) to get started quantizing models and using them in supported open-source libraries! 
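A minimal sketch of the evaluation flow mentioned above, shown here for illustration only: it assumes lm-evaluation-harness >= 0.4 and its `simple_evaluate` Python API, uses the `wikitext` task as a placeholder, and points at the small Quark test checkpoint exercised in the tests added by this series. Adjust the model id, task, and batch size as needed.

```python
# Sketch (not part of the diff): evaluating a Quark-quantized checkpoint through the
# Hugging Face Transformers backend of lm-evaluation-harness.
# Assumes `pip install amd-quark lm-eval` and lm-eval's simple_evaluate API (>= 0.4).
import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",  # Transformers backend
    # Placeholder checkpoint; any Quark-quantized model on the Hub should load the same way.
    model_args="pretrained=amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test,dtype=float16",
    tasks=["wikitext"],  # placeholder task
    batch_size=8,
)

print(results["results"])
```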
@@ -28,6 +28,8 @@ Although Quark has its own checkpoint / [configuration format](https://huggingfa Models quantized through Quark support a large range of features, that can be combined together. All quantized models independently of their configuration can seamlessly be reloaded through `PretrainedModel.from_pretrained`. +The table below shows a few features supported by Quark: + | **Feature** | **Supported subset in Quark** | | |---------------------------------|-----------------------------------------------------------------------------------------------------------|---| | Data types | int8, int4, int2, bfloat16, float16, fp8_e5m2, fp8_e4m3, fp6_e3m2, fp6_e2m3, fp4, OCP MX, MX6, MX9, bfp16 | | @@ -47,17 +49,30 @@ Although Quark also supports [models using `quant_method="fp8"`](https://hugging ## Using Quark models in Transformers +Here is an example of how one can load a Quark model in Transformers: + ```python from transformers import AutoModelForCausalLM, AutoTokenizer -model = AutoModelForCausalLM.from_pretrained("EmbeddedLLM/Llama-3.1-8B-Instruct-w_fp8_per_channel_sym") +model_id = "EmbeddedLLM/Llama-3.1-8B-Instruct-w_fp8_per_channel_sym" +model = AutoModelForCausalLM.from_pretrained(model_id) model = model.to("cuda") -tokenizer = AutoTokenizer.from_pretrained("EmbeddedLLM/Llama-3.1-8B-Instruct-w_fp8_per_channel_sym") +print(model.model.layers[0].self_attn.q_proj) +# QParamsLinear( +# (weight_quantizer): ScaledRealQuantizer() +# (input_quantizer): ScaledRealQuantizer() +# (output_quantizer): ScaledRealQuantizer() +# ) -inp = tokenizer("Where is a good place to cycle around Tokyo?", return_tensors="pt").to("cuda") +tokenizer = AutoTokenizer.from_pretrained(model_id) +inp = tokenizer("Where is a good place to cycle around Tokyo?", return_tensors="pt") +inp = inp.to("cuda") -res = model.generate(**inp) +res = model.generate(**inp, min_new_tokens=50, max_new_tokens=100) -print(tokenizer.batch_decode(res)) +print(tokenizer.batch_decode(res)[0]) +# <|begin_of_text|>Where is a good place to cycle around Tokyo? There are several places in Tokyo that are suitable for cycling, depending on your skill level and interests. Here are a few suggestions: +# 1. Yoyogi Park: This park is a popular spot for cycling and has a wide, flat path that's perfect for beginners. You can also visit the Meiji Shrine, a famous Shinto shrine located in the park. +# 2. Imperial Palace East Garden: This beautiful garden has a large, flat path that's perfect for cycling. You can also visit the ``` \ No newline at end of file diff --git a/src/transformers/quantizers/quantizer_quark.py b/src/transformers/quantizers/quantizer_quark.py index 131b8d119e89..65326b6e1226 100644 --- a/src/transformers/quantizers/quantizer_quark.py +++ b/src/transformers/quantizers/quantizer_quark.py @@ -25,9 +25,6 @@ if is_accelerate_available(): from accelerate.utils import set_module_tensor_to_device -if is_quark_available(): - from quark.torch import ModelImporter - logger = logging.get_logger(__name__) @@ -67,8 +64,12 @@ def validate_environment(self, *args, **kwargs): ) def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs): - # TODO: support legacy? 
legacy=self.quantization_config.legacy - ModelImporter.map_to_quark(model, self.quantization_config.quant_config, pack_method=self.json_export_config.pack_method, custom_mode=self.quantization_config.custom_mode) + # Having imports not at the top of files is the approach taken for quantizers + # in Transformers - they quantizer_*.py files are not imported lazily. + if is_quark_available(): + from quark.torch.export.api import _map_to_quark + + _map_to_quark(model, self.quantization_config.quant_config, pack_method=self.json_export_config.pack_method, custom_mode=self.quantization_config.custom_mode) return model diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index ad69cacda4da..93856c5a3f9c 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -61,7 +61,7 @@ class QuantizationMethod(str, Enum): BITNET = "bitnet" SPQR = "spqr" FP8 = "fp8" - QUARK = "quark" + QUARK = "quark" class AWQLinearVersion(str, Enum): From 36d18cfe848ddf8d6d65da3712dea4f4aabe3b92 Mon Sep 17 00:00:00 2001 From: Felix Marty Date: Fri, 21 Feb 2025 09:43:05 -0700 Subject: [PATCH 04/17] fix tests --- src/transformers/modeling_utils.py | 4 +++ .../{quark => quark_integration}/__init__.py | 0 .../test_quark.py | 27 ++++++------------- 3 files changed, 12 insertions(+), 19 deletions(-) rename tests/quantization/{quark => quark_integration}/__init__.py (100%) rename tests/quantization/{quark => quark_integration}/test_quark.py (84%) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 093449f822a0..6fbe0f842a87 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3172,6 +3172,10 @@ def to(self, *args, **kwargs): if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ: raise ValueError("`.to` is not supported for HQQ-quantized models.") + + if dtype_present_in_args and getattr(self, "quantization_method", None) == QuantizationMethod.QUARK: + raise ValueError(f"Casting a Quark quantized model to a new `dtype` is not supported.") + # Checks if the model has been loaded in 4-bit or 8-bit with BNB if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: if dtype_present_in_args: diff --git a/tests/quantization/quark/__init__.py b/tests/quantization/quark_integration/__init__.py similarity index 100% rename from tests/quantization/quark/__init__.py rename to tests/quantization/quark_integration/__init__.py diff --git a/tests/quantization/quark/test_quark.py b/tests/quantization/quark_integration/test_quark.py similarity index 84% rename from tests/quantization/quark/test_quark.py rename to tests/quantization/quark_integration/test_quark.py index bfa715235803..1eb1956ec111 100644 --- a/tests/quantization/quark/test_quark.py +++ b/tests/quantization/quark_integration/test_quark.py @@ -31,22 +31,14 @@ import torch if is_quark_available(): - from quark.torch.export.nn.modules import QParamsLinear + from quark.torch.export.nn.modules.qparamslinear import QParamsLinear -class QuarkConfigTest(unittest.TestCase): - def test_unknown_arg(self): - QuarkConfig(foo=2) - +class QuarkConfigTest(unittest.TestCase): def test_commmon_args(self): config = AutoConfig.from_pretrained("amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test") - QuarkConfig(**config.quant_config) + QuarkConfig(**config.quantization_config) - def test_missing_args(self): - config = 
AutoConfig.from_pretrained("amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test") - - # TODO: delete one arg from quant_config - QuarkConfig(**config.quant_config) @slow @require_quark @@ -59,11 +51,12 @@ class QuarkTest(unittest.TestCase): EXPECTED_OUTPUTS = set() EXPECTED_OUTPUTS.add("Today I am in Paris and I am not in Paris, France\nToday I am in Paris, Illinois") + EXPECTED_OUTPUTS.add("Today I am in Paris and I am enjoying the city of light. I am not just any ordinary Paris") + EXPECTED_OUTPUTS.add("Today I am in Paris and I am enjoying my day off! The sun is shining, the birds are") - EXPECTED_RELATIVE_DIFFERENCE = 1.664253062 + EXPECTED_RELATIVE_DIFFERENCE = 1.66 device_map = None - # called only once for all test in this class @classmethod def setUpClass(cls): """ @@ -74,7 +67,7 @@ def setUpClass(cls): ) cls.mem_fp16 = cls.model_fp16.get_memory_footprint() - cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name, use_fast=True) + cls.tokenizer = AutoTokenizer.from_pretrained(cls.reference_model_name, use_fast=True) cls.quantized_model = AutoModelForCausalLM.from_pretrained( cls.quantized_model_name, @@ -83,13 +76,9 @@ def setUpClass(cls): ) def test_memory_footprint(self): - r""" - A simple test to check if the model conversion has been done correctly by checking on the - memory footprint of the converted model - """ mem_quantized = self.quantized_model.get_memory_footprint() - self.assertAlmostEqual(self.mem_fp16 / mem_quantized, self.EXPECTED_RELATIVE_DIFFERENCE) + self.assertTrue(self.mem_fp16 / mem_quantized > self.EXPECTED_RELATIVE_DIFFERENCE) def test_device_and_dtype_assignment(self): r""" From 8d233b469cc3923b1bf0bf0751501ab3eac7ee72 Mon Sep 17 00:00:00 2001 From: Felix Marty Date: Fri, 21 Feb 2025 09:45:03 -0700 Subject: [PATCH 05/17] make style --- src/transformers/modeling_utils.py | 4 ++-- src/transformers/quantizers/quantizer_quark.py | 5 +++-- src/transformers/utils/quantization_config.py | 2 +- tests/quantization/quark_integration/test_quark.py | 6 +++--- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 6fbe0f842a87..59fe3da21a9e 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3172,9 +3172,9 @@ def to(self, *args, **kwargs): if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ: raise ValueError("`.to` is not supported for HQQ-quantized models.") - + if dtype_present_in_args and getattr(self, "quantization_method", None) == QuantizationMethod.QUARK: - raise ValueError(f"Casting a Quark quantized model to a new `dtype` is not supported.") + raise ValueError("Casting a Quark quantized model to a new `dtype` is not supported.") # Checks if the model has been loaded in 4-bit or 8-bit with BNB if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: diff --git a/src/transformers/quantizers/quantizer_quark.py b/src/transformers/quantizers/quantizer_quark.py index 65326b6e1226..3f11cd6d0e1b 100644 --- a/src/transformers/quantizers/quantizer_quark.py +++ b/src/transformers/quantizers/quantizer_quark.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import TYPE_CHECKING, Dict, Any +from typing import TYPE_CHECKING, Any, Dict from .base import HfQuantizer @@ -20,7 +20,8 @@ if TYPE_CHECKING: from ..modeling_utils import PreTrainedModel -from ..utils import is_quark_available, is_accelerate_available, logging +from ..utils import is_accelerate_available, is_quark_available, logging + if is_accelerate_available(): from accelerate.utils import set_module_tensor_to_device diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 93856c5a3f9c..585aa1cc5950 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -30,8 +30,8 @@ from ..utils import ( is_auto_awq_available, is_gptqmodel_available, - is_quark_available, is_hqq_available, + is_quark_available, is_torch_available, is_torchao_available, logging, diff --git a/tests/quantization/quark_integration/test_quark.py b/tests/quantization/quark_integration/test_quark.py index 1eb1956ec111..586121b945ef 100644 --- a/tests/quantization/quark_integration/test_quark.py +++ b/tests/quantization/quark_integration/test_quark.py @@ -15,8 +15,7 @@ import unittest -from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, QuarkConfig, GenerationConfig -from transformers.utils.import_utils import is_quark_available +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, QuarkConfig from transformers.testing_utils import ( is_torch_available, require_accelerate, @@ -25,6 +24,7 @@ require_torch_multi_gpu, slow, ) +from transformers.utils.import_utils import is_quark_available if is_torch_available(): @@ -34,7 +34,7 @@ from quark.torch.export.nn.modules.qparamslinear import QParamsLinear -class QuarkConfigTest(unittest.TestCase): +class QuarkConfigTest(unittest.TestCase): def test_commmon_args(self): config = AutoConfig.from_pretrained("amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test") QuarkConfig(**config.quantization_config) From 5f24cee8910de49f0e848b970ec3569e9c49cf33 Mon Sep 17 00:00:00 2001 From: Felix Marty Date: Fri, 21 Feb 2025 09:47:02 -0700 Subject: [PATCH 06/17] more style fixes --- src/transformers/__init__.py | 2 +- src/transformers/quantizers/auto.py | 2 +- src/transformers/quantizers/quantizer_quark.py | 16 ++++++++++++++-- src/transformers/testing_utils.py | 2 ++ src/transformers/utils/quantization_config.py | 6 +++++- .../quantization/quark_integration/test_quark.py | 1 + 6 files changed, 24 insertions(+), 5 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index a8f5f475f625..e56707add469 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -6249,7 +6249,7 @@ HiggsConfig, HqqConfig, QuantoConfig, - QuarkConfig, + QuarkConfig, SpQRConfig, TorchAoConfig, VptqConfig, diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py index 1d5f0b853673..bc0335670514 100644 --- a/src/transformers/quantizers/auto.py +++ b/src/transformers/quantizers/auto.py @@ -33,7 +33,7 @@ QuantizationMethod, QuantoConfig, QuarkConfig, - SpQRConfig, + SpQRConfig, TorchAoConfig, VptqConfig, ) diff --git a/src/transformers/quantizers/quantizer_quark.py b/src/transformers/quantizers/quantizer_quark.py index 3f11cd6d0e1b..a7059f0f3cab 100644 --- a/src/transformers/quantizers/quantizer_quark.py +++ b/src/transformers/quantizers/quantizer_quark.py @@ -14,12 +14,16 @@ from typing import TYPE_CHECKING, Any, Dict +from ..file_utils import is_torch_available 
from .base import HfQuantizer if TYPE_CHECKING: from ..modeling_utils import PreTrainedModel + if is_torch_available(): + import torch + from ..utils import is_accelerate_available, is_quark_available, logging @@ -40,6 +44,7 @@ "output_zero_point": "output_quantizer.zero_point", } + class QuarkHfQuantizer(HfQuantizer): """ Quark quantizer (https://quark.docs.amd.com/latest/). @@ -70,7 +75,12 @@ def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwarg if is_quark_available(): from quark.torch.export.api import _map_to_quark - _map_to_quark(model, self.quantization_config.quant_config, pack_method=self.json_export_config.pack_method, custom_mode=self.quantization_config.custom_mode) + _map_to_quark( + model, + self.quantization_config.quant_config, + pack_method=self.json_export_config.pack_method, + custom_mode=self.quantization_config.custom_mode, + ) return model @@ -84,7 +94,9 @@ def check_quantized_param( ) -> bool: return True - def create_quantized_param(self, model, param, param_name, param_device, state_dict, unexpected_keys) -> "torch.nn.Parameter": + def create_quantized_param( + self, model, param, param_name, param_device, state_dict, unexpected_keys + ) -> "torch.nn.Parameter": postfix = param_name.split(".")[-1] if postfix in CHECKPOINT_KEYS: diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index b00ede804083..f8f8600963e6 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -1255,12 +1255,14 @@ def require_fbgemm_gpu(test_case): """ return unittest.skipUnless(is_fbgemm_gpu_available(), "test requires fbgemm-gpu")(test_case) + def require_quark(test_case): """ Decorator for quark dependency """ return unittest.skipUnless(is_quark_available(), "test requires quark")(test_case) + def require_flute_hadamard(test_case): """ Decorator marking a test that requires higgs and hadamard diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 585aa1cc5950..78d237e5db4b 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -1699,7 +1699,11 @@ def __init__( if self.custom_mode in ["awq", "fp8"]: # Legacy (quark<1.0) or custom export. 
- self.quant_config = quark.torch.export.main_export.quant_config_parser.QuantConfigParser.from_custom_config(kwargs, is_bias_quantized=False) + self.quant_config = ( + quark.torch.export.main_export.quant_config_parser.QuantConfigParser.from_custom_config( + kwargs, is_bias_quantized=False + ) + ) self.json_export_config = quark.torch.export.config.config.JsonExporterConfig() else: self.quant_config = quark.torch.quantization.config.config.Config.from_dict(kwargs) diff --git a/tests/quantization/quark_integration/test_quark.py b/tests/quantization/quark_integration/test_quark.py index 586121b945ef..ccbc8a926381 100644 --- a/tests/quantization/quark_integration/test_quark.py +++ b/tests/quantization/quark_integration/test_quark.py @@ -140,6 +140,7 @@ def test_generate_quality(self): else: self.check_inference_correctness(self.quantized_model) + @require_accelerate @require_torch_multi_gpu @require_quark From d275c873790157240cddabd0f920074f2ef2d730 Mon Sep 17 00:00:00 2001 From: Felix Marty Date: Fri, 21 Feb 2025 09:52:28 -0700 Subject: [PATCH 07/17] cleanup imports --- src/transformers/utils/quantization_config.py | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 78d237e5db4b..11b6195b656c 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -1692,26 +1692,25 @@ def __init__( **kwargs, ): if is_torch_available() and is_quark_available(): - import quark.torch + from quark.torch.export.config.config import JsonExporterConfig + from quark.torch.export.main_export.quant_config_parser import QuantConfigParser + from quark.torch.quantization.config.config import Config - self.custom_mode = kwargs["quant_method"] # This might be e.g. `"fp8"` or `"awq"`. + # This might be e.g. `"fp8"` or `"awq"`. + self.custom_mode = kwargs["quant_method"] self.legacy = "export" not in kwargs if self.custom_mode in ["awq", "fp8"]: # Legacy (quark<1.0) or custom export. - self.quant_config = ( - quark.torch.export.main_export.quant_config_parser.QuantConfigParser.from_custom_config( - kwargs, is_bias_quantized=False - ) - ) - self.json_export_config = quark.torch.export.config.config.JsonExporterConfig() + self.quant_config = QuantConfigParser.from_custom_config(kwargs, is_bias_quantized=False) + self.json_export_config = JsonExporterConfig() else: - self.quant_config = quark.torch.quantization.config.config.Config.from_dict(kwargs) + self.quant_config = Config.from_dict(kwargs) if "export" in kwargs: - self.json_export_config = quark.torch.export.config.config.JsonExporterConfig(**kwargs["export"]) + self.json_export_config = JsonExporterConfig(**kwargs["export"]) else: # Legacy (quark<1.0) or custom export. 
- self.json_export_config = quark.torch.export.config.config.JsonExporterConfig() + self.json_export_config = JsonExporterConfig() self.quant_method = QuantizationMethod.QUARK From f5e1817290e37d8d7343e7ecae794313f42b9e43 Mon Sep 17 00:00:00 2001 From: Felix Marty Date: Mon, 24 Feb 2025 07:36:11 -0700 Subject: [PATCH 08/17] cleaning --- docs/source/en/quantization/overview.md | 2 +- src/transformers/quantizers/quantizer_quark.py | 2 -- src/transformers/utils/import_utils.py | 11 +++++++++++ 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md index 955cb59b7275..68799d3766ac 100644 --- a/docs/source/en/quantization/overview.md +++ b/docs/source/en/quantization/overview.md @@ -96,7 +96,7 @@ Use the table below to help you decide which quantization method to use. -**6:** Quark is hardware agnostic, and may not supported accelerated inference / kernels for every quantization scheme and every hardware and PyTorch distribution. +**6:** Quark is hardware agnostic, and may not support accelerated inference / kernels for every quantization scheme, hardware and PyTorch distribution. diff --git a/src/transformers/quantizers/quantizer_quark.py b/src/transformers/quantizers/quantizer_quark.py index a7059f0f3cab..c40759fd53c5 100644 --- a/src/transformers/quantizers/quantizer_quark.py +++ b/src/transformers/quantizers/quantizer_quark.py @@ -108,10 +108,8 @@ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs return model def is_serializable(self, safe_serialization=None): - # TODO: check serialization return False @property def is_trainable(self): - # TODO: check trainable return False diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 2e7ffb33518a..9daedbf1346f 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -45,6 +45,11 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ package_version = "N/A" if package_exists: try: + # TODO: Once python 3.9 support is dropped, `importlib.metadata.packages_distributions()` + # should be used here to map from package name to distribution names + # e.g. PIL -> Pillow, Pillow-SIMD; quark -> amd-quark; onnxruntime -> onnxruntime-gpu. + # `importlib.metadata.packages_distributions()` is not available in Python 3.9. + # Primary method to get the package version package_version = importlib.metadata.version(pkg_name) except importlib.metadata.PackageNotFoundError: @@ -62,6 +67,12 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ except ImportError: # If the package can't be imported, it's not available package_exists = False + elif pkg_name == "quark": + # TODO: remove once `importlib.metadata.packages_distributions()` is supported. 
+ try: + package_version = importlib.metadata.version("amd-quark") + except Exception: + package_exists = False else: # For packages other than "torch", don't attempt the fallback and set as not available package_exists = False From 70e30faec04504b80d657d51fe1571c2f1f4122f Mon Sep 17 00:00:00 2001 From: Felix Marty Date: Mon, 24 Feb 2025 07:41:52 -0700 Subject: [PATCH 09/17] precise install --- docs/source/en/quantization/quark.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/source/en/quantization/quark.md b/docs/source/en/quantization/quark.md index 1abe84deecd0..66b6572e6c1e 100644 --- a/docs/source/en/quantization/quark.md +++ b/docs/source/en/quantization/quark.md @@ -24,6 +24,12 @@ Users interested in Quark can refer to its [documentation](https://quark.docs.am Although Quark has its own checkpoint / [configuration format](https://huggingface.co/amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test/blob/main/config.json#L26), the library also supports producing models with a serialization layout compliant with other quantization/runtime implementations ([AutoAWQ](https://huggingface.co/docs/transformers/quantization/awq), [native fp8 in 🤗 Transformers](https://huggingface.co/docs/transformers/quantization/finegrained_fp8)). +To be able to load Quark quantized models in Transformers, the library first needs to be installed: + +```bash +pip install amd-quark +``` + ## Support matrix Models quantized through Quark support a large range of features, that can be combined together. All quantized models independently of their configuration can seamlessly be reloaded through `PretrainedModel.from_pretrained`. From c2e5ba0c1a75aa75f1de789db789fba29cb73ab0 Mon Sep 17 00:00:00 2001 From: fxmarty-amd Date: Fri, 7 Mar 2025 11:49:28 +0100 Subject: [PATCH 10/17] Update docs/source/en/quantization/quark.md Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --- docs/source/en/quantization/quark.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/quantization/quark.md b/docs/source/en/quantization/quark.md index 66b6572e6c1e..8d60affbc280 100644 --- a/docs/source/en/quantization/quark.md +++ b/docs/source/en/quantization/quark.md @@ -1,4 +1,4 @@ -