diff --git a/docker/transformers-quantization-latest-gpu/Dockerfile b/docker/transformers-quantization-latest-gpu/Dockerfile index ff0655c7e3a7..33d8b10b02ee 100755 --- a/docker/transformers-quantization-latest-gpu/Dockerfile +++ b/docker/transformers-quantization-latest-gpu/Dockerfile @@ -79,6 +79,9 @@ RUN git clone https://github.com/NetEase-FuXi/EETQ.git && cd EETQ/ && git submod # Add compressed-tensors for quantization testing RUN python3 -m pip install --no-cache-dir compressed-tensors +# Add AMD Quark for quantization testing +RUN python3 -m pip install --no-cache-dir amd-quark + # Add transformers in editable mode RUN python3 -m pip install --no-cache-dir -e ./transformers[dev-torch] diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 79f8eb3d490d..5a5f34445432 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -187,6 +187,8 @@ title: Optimum - local: quantization/quanto title: Quanto + - local: quantization/quark + title: Quark - local: quantization/torchao title: torchao - local: quantization/spqr diff --git a/docs/source/en/main_classes/quantization.md b/docs/source/en/main_classes/quantization.md index 6da5b8ce69b5..fb42e886bace 100755 --- a/docs/source/en/main_classes/quantization.md +++ b/docs/source/en/main_classes/quantization.md @@ -88,3 +88,7 @@ Learn how to quantize models in the [Quantization](../quantization) guide. 
## FineGrainedFP8Config [[autodoc]] FineGrainedFP8Config + +## QuarkConfig + +[[autodoc]] QuarkConfig diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md index d69675e909aa..ac6fdb4a3dab 100644 --- a/docs/source/en/quantization/overview.md +++ b/docs/source/en/quantization/overview.md @@ -40,6 +40,7 @@ Use the Space below to help you pick a quantization method depending on your har | [VPTQ](./vptq) | 🔴 | 🔴 | 🟢 | 🟡 | 🔴 | 🔴 | 🟢 | 1/8 | 🔴 | 🟢 | 🟢 | https://github.com/microsoft/VPTQ | | [FINEGRAINED_FP8](./finegrained_fp8) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | | | [SpQR](./spqr) | 🔴 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 3 | 🔴 | 🟢 | 🟢 | https://github.com/Vahe1994/SpQR/ | +| [Quark](./quark.md) | 🔴 | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | ? | 2/4/6/8/9/16 | 🔴 | 🔴 | 🟢 | https://quark.docs.amd.com/latest/ | ## Resources @@ -55,4 +56,4 @@ If you are looking for a user-friendly quantization experience, you can use the * [Bitsandbytes Space](https://huggingface.co/spaces/bnb-community/bnb-my-repo) * [GGUF Space](https://huggingface.co/spaces/ggml-org/gguf-my-repo) * [MLX Space](https://huggingface.co/spaces/mlx-community/mlx-my-repo) -* [AuoQuant Notebook](https://colab.research.google.com/drive/1b6nqC7UZVt8bx4MksX7s656GXPM-eWw4?usp=sharing#scrollTo=ZC9Nsr9u5WhN) \ No newline at end of file +* [AuoQuant Notebook](https://colab.research.google.com/drive/1b6nqC7UZVt8bx4MksX7s656GXPM-eWw4?usp=sharing#scrollTo=ZC9Nsr9u5WhN) diff --git a/docs/source/en/quantization/quark.md b/docs/source/en/quantization/quark.md new file mode 100644 index 000000000000..8d60affbc280 --- /dev/null +++ b/docs/source/en/quantization/quark.md @@ -0,0 +1,84 @@ + + +# Quark + +[Quark](https://quark.docs.amd.com/latest/) is a deep learning quantization toolkit designed to be agnostic to specific data types, algorithms, and hardware. Different pre-processing strategies, algorithms and data-types can be combined in Quark. 
+
+The PyTorch support integrated through 🤗 Transformers primarily targets AMD CPUs and GPUs, and is mainly intended for evaluation purposes. For example, [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) can be used with the 🤗 Transformers backend to seamlessly evaluate a wide range of models quantized through Quark.
+
+Users interested in Quark can refer to its [documentation](https://quark.docs.amd.com/latest/) to get started quantizing models and using them in supported open-source libraries!
+
+Although Quark has its own checkpoint / [configuration format](https://huggingface.co/amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test/blob/main/config.json#L26), the library also supports producing models with a serialization layout compliant with other quantization/runtime implementations ([AutoAWQ](https://huggingface.co/docs/transformers/quantization/awq), [native fp8 in 🤗 Transformers](https://huggingface.co/docs/transformers/quantization/finegrained_fp8)).
+
+To load Quark quantized models in Transformers, the library first needs to be installed:
+
+```bash
+pip install amd-quark
+```
+
+## Support matrix
+
+Models quantized through Quark support a wide range of features that can be combined. All quantized models, independently of their configuration, can be seamlessly reloaded through `PreTrainedModel.from_pretrained`.
+
+The table below shows a few features supported by Quark:
+
+| **Feature** | **Supported subset in Quark** | |
+|---------------------------------|-----------------------------------------------------------------------------------------------------------|---|
+| Data types | int8, int4, int2, bfloat16, float16, fp8_e5m2, fp8_e4m3, fp6_e3m2, fp6_e2m3, fp4, OCP MX, MX6, MX9, bfp16 | |
+| Pre-quantization transformation | SmoothQuant, QuaRot, SpinQuant, AWQ | |
+| Quantization algorithm | GPTQ | |
+| Supported operators | ``nn.Linear``, ``nn.Conv2d``, ``nn.ConvTranspose2d``, ``nn.Embedding``, ``nn.EmbeddingBag`` | |
+| Granularity | per-tensor, per-channel, per-block, per-layer, per-layer type | |
+| KV cache | fp8 | |
+| Activation calibration | MinMax / Percentile / MSE | |
+| Quantization strategy | weight-only, static, dynamic, with or without output quantization | |
+
+## Models on Hugging Face Hub
+
+Public models using Quark native serialization can be found at https://huggingface.co/models?other=quark.
+
+Although Quark also supports [models using `quant_method="fp8"`](https://huggingface.co/models?other=fp8) and [models using `quant_method="awq"`](https://huggingface.co/models?other=awq), Transformers instead loads these models through [AutoAWQ](https://huggingface.co/docs/transformers/quantization/awq) or the [native fp8 support in 🤗 Transformers](https://huggingface.co/docs/transformers/quantization/finegrained_fp8).
+
+## Using Quark models in Transformers
+
+Here is an example of how one can load a Quark model in Transformers:
+
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model_id = "EmbeddedLLM/Llama-3.1-8B-Instruct-w_fp8_per_channel_sym"
+model = AutoModelForCausalLM.from_pretrained(model_id)
+model = model.to("cuda")
+
+print(model.model.layers[0].self_attn.q_proj)
+# QParamsLinear(
+#   (weight_quantizer): ScaledRealQuantizer()
+#   (input_quantizer): ScaledRealQuantizer()
+#   (output_quantizer): ScaledRealQuantizer()
+# )
+
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+inp = tokenizer("Where is a good place to cycle around Tokyo?", return_tensors="pt")
+inp = inp.to("cuda")
+
+res = model.generate(**inp, min_new_tokens=50, max_new_tokens=100)
+
+print(tokenizer.batch_decode(res)[0])
+# <|begin_of_text|>Where is a good place to cycle around Tokyo? There are several places in Tokyo that are suitable for cycling, depending on your skill level and interests. Here are a few suggestions:
+# 1. Yoyogi Park: This park is a popular spot for cycling and has a wide, flat path that's perfect for beginners. You can also visit the Meiji Shrine, a famous Shinto shrine located in the park.
+# 2. Imperial Palace East Garden: This beautiful garden has a large, flat path that's perfect for cycling.
You can also visit the +``` \ No newline at end of file diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py old mode 100755 new mode 100644 index 1e33af79ab59..8eff9c9fc551 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1042,6 +1042,7 @@ "HiggsConfig", "HqqConfig", "QuantoConfig", + "QuarkConfig", "SpQRConfig", "TorchAoConfig", "VptqConfig", @@ -6278,6 +6279,7 @@ HiggsConfig, HqqConfig, QuantoConfig, + QuarkConfig, SpQRConfig, TorchAoConfig, VptqConfig, diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py old mode 100755 new mode 100644 index 4158c82b4094..8fd7530789e2 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -536,6 +536,10 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True): str_to_torch_dtype["U32"] = torch.uint32 str_to_torch_dtype["U64"] = torch.uint64 +if is_torch_greater_or_equal("2.1.0"): + str_to_torch_dtype["F8_E4M3"] = torch.float8_e4m3fn + str_to_torch_dtype["F8_E5M2"] = torch.float8_e5m2 + def load_state_dict( checkpoint_file: Union[str, os.PathLike], @@ -3672,6 +3676,10 @@ def to(self, *args, **kwargs): if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ: raise ValueError("`.to` is not supported for HQQ-quantized models.") + + if dtype_present_in_args and getattr(self, "quantization_method", None) == QuantizationMethod.QUARK: + raise ValueError("Casting a Quark quantized model to a new `dtype` is not supported.") + # Checks if the model has been loaded in 4-bit or 8-bit with BNB if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: if dtype_present_in_args: diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py old mode 100755 new mode 100644 index 64634f98a44a..9d24b3539530 --- a/src/transformers/quantizers/auto.py +++ b/src/transformers/quantizers/auto.py @@ -1,4 +1,5 @@ # Copyright 2024 The HuggingFace Inc. team. 
All rights reserved. +# Modifications Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,6 +32,7 @@ QuantizationConfigMixin, QuantizationMethod, QuantoConfig, + QuarkConfig, SpQRConfig, TorchAoConfig, VptqConfig, @@ -49,6 +51,7 @@ from .quantizer_higgs import HiggsHfQuantizer from .quantizer_hqq import HqqHfQuantizer from .quantizer_quanto import QuantoHfQuantizer +from .quantizer_quark import QuarkHfQuantizer from .quantizer_spqr import SpQRHfQuantizer from .quantizer_torchao import TorchAoHfQuantizer from .quantizer_vptq import VptqHfQuantizer @@ -61,6 +64,7 @@ "gptq": GptqHfQuantizer, "aqlm": AqlmHfQuantizer, "quanto": QuantoHfQuantizer, + "quark": QuarkHfQuantizer, "eetq": EetqHfQuantizer, "higgs": HiggsHfQuantizer, "hqq": HqqHfQuantizer, @@ -81,6 +85,7 @@ "gptq": GPTQConfig, "aqlm": AqlmConfig, "quanto": QuantoConfig, + "quark": QuarkConfig, "hqq": HqqConfig, "compressed-tensors": CompressedTensorsConfig, "fbgemm_fp8": FbgemmFp8Config, diff --git a/src/transformers/quantizers/quantizer_quark.py b/src/transformers/quantizers/quantizer_quark.py new file mode 100644 index 000000000000..374360b1cb8f --- /dev/null +++ b/src/transformers/quantizers/quantizer_quark.py @@ -0,0 +1,113 @@ +# coding=utf-8 +# Copyright 2025 Advanced Micro Devices, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Any, Dict + +from ..file_utils import is_torch_available +from .base import HfQuantizer + + +if TYPE_CHECKING: + from ..modeling_utils import PreTrainedModel + + if is_torch_available(): + import torch + +from ..utils import is_accelerate_available, is_quark_available, logging + + +if is_accelerate_available(): + from accelerate.utils import set_module_tensor_to_device + +logger = logging.get_logger(__name__) + + +CHECKPOINT_KEYS = { + "weight_scale": "weight_quantizer.scale", + "bias_scale": "bias_quantizer.scale", + "input_scale": "input_quantizer.scale", + "output_scale": "output_quantizer.scale", + "weight_zero_point": "weight_quantizer.zero_point", + "bias_zero_point": "bias_quantizer.zero_point", + "input_zero_point": "input_quantizer.zero_point", + "output_zero_point": "output_quantizer.zero_point", +} + + +class QuarkHfQuantizer(HfQuantizer): + """ + Quark quantizer (https://quark.docs.amd.com/latest/). + """ + + requires_calibration = True # On-the-fly quantization with quark is not supported for now. + required_packages = ["quark"] + + # Checkpoints are expected to be already quantized when loading a quark model. However, as some keys from + # the checkpoint might mismatch the model parameters keys, we use the `create_quantized_param` method + # to load the checkpoints, remapping the keys. + requires_parameters_quantization = True + + def __init__(self, quantization_config, **kwargs): + super().__init__(quantization_config, **kwargs) + + self.json_export_config = quantization_config.json_export_config + + def validate_environment(self, *args, **kwargs): + if not is_quark_available(): + raise ImportError( + "Loading a Quark quantized model requires the `quark` library but it was not found in the environment. Please refer to https://quark.docs.amd.com/latest/install.html." 
+ ) + + def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs): + from quark.torch.export.api import _map_to_quark + + _map_to_quark( + model, + self.quantization_config.quant_config, + pack_method=self.json_export_config.pack_method, + custom_mode=self.quantization_config.custom_mode, + ) + + return model + + def check_quantized_param( + self, + model: "PreTrainedModel", + param_value: "torch.Tensor", + param_name: str, + state_dict: Dict[str, Any], + **kwargs, + ) -> bool: + return True + + def create_quantized_param( + self, model, param, param_name, param_device, state_dict, unexpected_keys + ) -> "torch.nn.Parameter": + postfix = param_name.split(".")[-1] + + if postfix in CHECKPOINT_KEYS: + param_name = param_name.replace(postfix, CHECKPOINT_KEYS[postfix]) + + set_module_tensor_to_device(model, param_name, param_device, value=param) + + def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): + return model + + def is_serializable(self, safe_serialization=None): + return False + + @property + def is_trainable(self): + return False diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index e2811ae9f108..40c9bf413efa 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -116,6 +116,7 @@ is_pytesseract_available, is_pytest_available, is_pytorch_quantization_available, + is_quark_available, is_rjieba_available, is_sacremoses_available, is_safetensors_available, @@ -1299,6 +1300,13 @@ def require_fbgemm_gpu(test_case): return unittest.skipUnless(is_fbgemm_gpu_available(), "test requires fbgemm-gpu")(test_case) +def require_quark(test_case): + """ + Decorator for quark dependency + """ + return unittest.skipUnless(is_quark_available(), "test requires quark")(test_case) + + def require_flute_hadamard(test_case): """ Decorator marking a test that requires higgs and hadamard diff --git a/src/transformers/utils/__init__.py 
b/src/transformers/utils/__init__.py old mode 100755 new mode 100644 index bdcb273c7e92..a549af2928da --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -181,6 +181,7 @@ is_pytesseract_available, is_pytest_available, is_pytorch_quantization_available, + is_quark_available, is_rich_available, is_rjieba_available, is_sacremoses_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py old mode 100755 new mode 100644 index ad4e685b24f4..b6eb2be5db4e --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -45,6 +45,11 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ package_version = "N/A" if package_exists: try: + # TODO: Once python 3.9 support is dropped, `importlib.metadata.packages_distributions()` + # should be used here to map from package name to distribution names + # e.g. PIL -> Pillow, Pillow-SIMD; quark -> amd-quark; onnxruntime -> onnxruntime-gpu. + # `importlib.metadata.packages_distributions()` is not available in Python 3.9. + # Primary method to get the package version package_version = importlib.metadata.version(pkg_name) except importlib.metadata.PackageNotFoundError: @@ -62,6 +67,12 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ except ImportError: # If the package can't be imported, it's not available package_exists = False + elif pkg_name == "quark": + # TODO: remove once `importlib.metadata.packages_distributions()` is supported. 
+ try: + package_version = importlib.metadata.version("amd-quark") + except Exception: + package_exists = False else: # For packages other than "torch", don't attempt the fallback and set as not available package_exists = False @@ -150,6 +161,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ _gptqmodel_available = _is_package_available("gptqmodel") # `importlib.metadata.version` doesn't work with `awq` _auto_awq_available = importlib.util.find_spec("awq") is not None +_quark_available = _is_package_available("quark") _is_optimum_quanto_available = False try: importlib.metadata.version("optimum_quanto") @@ -1118,6 +1130,10 @@ def is_optimum_quanto_available(): return _is_optimum_quanto_available +def is_quark_available(): + return _quark_available + + def is_compressed_tensors_available(): return _compressed_tensors_available diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py old mode 100755 new mode 100644 index 152572223fdf..6d859264005b --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -2,6 +2,7 @@ # coding=utf-8 # Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# Modifications Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -31,6 +32,7 @@ is_compressed_tensors_available, is_gptqmodel_available, is_hqq_available, + is_quark_available, is_torch_available, is_torchao_available, logging, @@ -60,6 +62,7 @@ class QuantizationMethod(str, Enum): BITNET = "bitnet" SPQR = "spqr" FP8 = "fp8" + QUARK = "quark" class AWQLinearVersion(str, Enum): @@ -1772,3 +1775,41 @@ def post_init(self): raise ValueError("weight_block_size must be a tuple of two integers") if self.weight_block_size[0] <= 0 or self.weight_block_size[1] <= 0: raise ValueError("weight_block_size must be a tuple of two positive integers") + + +class QuarkConfig(QuantizationConfigMixin): + def __init__( + self, + **kwargs, + ): + if is_torch_available() and is_quark_available(): + from quark import __version__ as quark_version + from quark.torch.export.config.config import JsonExporterConfig + from quark.torch.export.main_export.quant_config_parser import QuantConfigParser + from quark.torch.quantization.config.config import Config + + # This might be e.g. `"fp8"` or `"awq"`. + self.custom_mode = kwargs["quant_method"] + self.legacy = "export" not in kwargs + + if self.custom_mode in ["awq", "fp8"]: + # Legacy (quark<1.0) or custom export. + self.quant_config = QuantConfigParser.from_custom_config(kwargs, is_bias_quantized=False) + self.json_export_config = JsonExporterConfig() + else: + self.quant_config = Config.from_dict(kwargs) + + if "export" in kwargs: + # TODO: Remove this check once configuration version is handled natively by Quark. + if "min_kv_scale" in kwargs["export"] and version.parse(quark_version) < version.parse("0.8"): + min_kv_scale = kwargs["export"].pop("min_kv_scale") + logger.warning( + f"The parameter `min_kv_scale={min_kv_scale}` was found in the model config.json's `quantization_config.export` configuration, but this parameter is supported only for quark>=0.8. Ignoring this configuration parameter. Please update the `amd-quark` package." 
+ ) + + self.json_export_config = JsonExporterConfig(**kwargs["export"]) + else: + # Legacy (quark<1.0) or custom export. + self.json_export_config = JsonExporterConfig() + + self.quant_method = QuantizationMethod.QUARK diff --git a/tests/quantization/quark_integration/__init__.py b/tests/quantization/quark_integration/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/quantization/quark_integration/test_quark.py b/tests/quantization/quark_integration/test_quark.py new file mode 100644 index 000000000000..32a9f6a6d8fd --- /dev/null +++ b/tests/quantization/quark_integration/test_quark.py @@ -0,0 +1,143 @@ +# coding=utf-8 +# Copyright 2025 Advanced Micro Devices, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import unittest
+
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, QuarkConfig
+from transformers.testing_utils import (
+    is_torch_available,
+    require_accelerate,
+    require_quark,
+    require_torch_gpu,
+    require_torch_multi_gpu,
+    slow,
+)
+from transformers.utils.import_utils import is_quark_available
+
+
+if is_torch_available():
+    import torch
+
+if is_quark_available():
+    from quark.torch.export.nn.modules.qparamslinear import QParamsLinear
+
+
+class QuarkConfigTest(unittest.TestCase):
+    def test_common_args(self):
+        config = AutoConfig.from_pretrained("amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test")
+        QuarkConfig(**config.quantization_config)
+
+
+@slow
+@require_quark
+@require_torch_gpu
+class QuarkTest(unittest.TestCase):
+    reference_model_name = "meta-llama/Llama-3.1-8B-Instruct"
+    quantized_model_name = "amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test"
+
+    input_text = "Today I am in Paris and"
+
+    EXPECTED_OUTPUTS = set()
+    EXPECTED_OUTPUTS.add("Today I am in Paris and I am not in Paris, France\nToday I am in Paris, Illinois")
+    EXPECTED_OUTPUTS.add("Today I am in Paris and I am enjoying the city of light. I am not just any ordinary Paris")
+    EXPECTED_OUTPUTS.add("Today I am in Paris and I am enjoying my day off!
The sun is shining, the birds are")
+
+    EXPECTED_RELATIVE_DIFFERENCE = 1.66
+    device_map = None
+
+    @classmethod
+    def setUpClass(cls):
+        """
+        Setup reference & quantized model
+        """
+        cls.model_fp16 = AutoModelForCausalLM.from_pretrained(
+            cls.reference_model_name, torch_dtype=torch.float16, device_map=cls.device_map
+        )
+        cls.mem_fp16 = cls.model_fp16.get_memory_footprint()
+
+        cls.tokenizer = AutoTokenizer.from_pretrained(cls.reference_model_name, use_fast=True)
+
+        cls.quantized_model = AutoModelForCausalLM.from_pretrained(
+            cls.quantized_model_name,
+            torch_dtype=torch.float16,
+            device_map=cls.device_map,
+        )
+
+    def test_memory_footprint(self):
+        mem_quantized = self.quantized_model.get_memory_footprint()
+
+        self.assertTrue(self.mem_fp16 / mem_quantized > self.EXPECTED_RELATIVE_DIFFERENCE)
+
+    def test_device_and_dtype_assignment(self):
+        r"""
+        Test that casting a quantized model to a new `dtype` raises an error.
+        Also checks that assigning a device to the model still works.
+        """
+        # This should work
+        if self.device_map is None:
+            _ = self.quantized_model.to(0)
+
+        with self.assertRaises(ValueError):
+            # Tries with a `dtype`
+            self.quantized_model.to(torch.float16)
+
+    def test_original_dtype(self):
+        r"""
+        A simple test to check if the model successfully stores the original dtype
+        """
+        self.assertTrue(hasattr(self.quantized_model.config, "_pre_quantization_dtype"))
+        self.assertFalse(hasattr(self.model_fp16.config, "_pre_quantization_dtype"))
+        self.assertTrue(self.quantized_model.config._pre_quantization_dtype == torch.float16)
+
+        self.assertTrue(isinstance(self.quantized_model.model.layers[0].mlp.gate_proj, QParamsLinear))
+
+    def check_inference_correctness(self, model):
+        r"""
+        Test the generation quality of the quantized model and see that we are matching the expected output.
+        Given that we are operating on small numbers and the testing model is relatively small, we might not get
+        the same output across GPUs.
So we'll generate few tokens (5-10) and check their output. + """ + # Check that inference pass works on the model + encoded_input = self.tokenizer(self.input_text, return_tensors="pt") + + gen_config = GenerationConfig( + max_new_tokens=15, + min_new_tokens=15, + use_cache=True, + num_beams=1, + do_sample=False, + ) + + # Check the exactness of the results + output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(0), generation_config=gen_config) + + # Get the generation + self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) + + def test_generate_quality(self): + """ + Simple test to check the quality of the model by comparing the generated tokens with the expected tokens + """ + if self.device_map is None: + self.check_inference_correctness(self.quantized_model.to(0)) + else: + self.check_inference_correctness(self.quantized_model) + + +@require_accelerate +@require_torch_multi_gpu +@require_quark +class QuarkTestDeviceMap(QuarkTest): + device_map = "auto"
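
Review note on `create_quantized_param` in `quantizer_quark.py`: the method renames checkpoint keys such as `...weight_scale` to the module layout `...weight_quantizer.scale` using `CHECKPOINT_KEYS`. The standalone sketch below illustrates that remapping without requiring `torch` or `quark`. The helper name `remap_checkpoint_key` is hypothetical and not part of this PR. As an assumption about the intended behavior, it rewrites only the final dotted component (via `str.rpartition`), whereas the diff uses `str.replace`, which would also rewrite any earlier occurrence of the same substring in the parameter name.

```python
# Mapping copied from CHECKPOINT_KEYS in quantizer_quark.py (see the diff above).
CHECKPOINT_KEYS = {
    "weight_scale": "weight_quantizer.scale",
    "bias_scale": "bias_quantizer.scale",
    "input_scale": "input_quantizer.scale",
    "output_scale": "output_quantizer.scale",
    "weight_zero_point": "weight_quantizer.zero_point",
    "bias_zero_point": "bias_quantizer.zero_point",
    "input_zero_point": "input_quantizer.zero_point",
    "output_zero_point": "output_quantizer.zero_point",
}


def remap_checkpoint_key(param_name: str) -> str:
    """Hypothetical helper: map a checkpoint key to the Quark module layout.

    Only the last dotted component is rewritten; names whose last component
    is not in CHECKPOINT_KEYS are returned unchanged.
    """
    prefix, sep, postfix = param_name.rpartition(".")
    if postfix in CHECKPOINT_KEYS:
        return f"{prefix}{sep}{CHECKPOINT_KEYS[postfix]}"
    return param_name


if __name__ == "__main__":
    for name in (
        "model.layers.0.mlp.gate_proj.weight_scale",
        "model.layers.0.mlp.gate_proj.weight",
    ):
        print(name, "->", remap_checkpoint_key(name))
```

If the suffix-only behavior is what the quantizer intends, the same `rpartition` approach could replace the `str.replace` call before `set_module_tensor_to_device` is invoked; whether any real checkpoint contains a name where the two approaches differ is not established by this diff.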