diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index ab80ddffec50..bc81c24f7347 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -40,6 +40,7 @@ "models": [], "modular_pipelines": [], "pipelines": [], + "quantizers.pipe_quant_config": ["PipelineQuantizationConfig"], "quantizers.quantization_config": [], "schedulers": [], "utils": [ diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 0375fbb0856a..6b8ba55941b7 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1096,6 +1096,8 @@ def load_module(name, value): model.register_to_config(_name_or_path=pretrained_model_name_or_path) if device_map is not None: setattr(model, "hf_device_map", final_device_map) + if quantization_config is not None: + setattr(model, "quantization_config", quantization_config) return model @property diff --git a/src/diffusers/quantizers/__init__.py b/src/diffusers/quantizers/__init__.py index efd241875321..3ca867c12908 100644 --- a/src/diffusers/quantizers/__init__.py +++ b/src/diffusers/quantizers/__init__.py @@ -12,183 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect -from typing import Dict, List, Optional, Union -from ..utils import is_transformers_available, logging from .auto import DiffusersAutoQuantizer from .base import DiffusersQuantizer -from .quantization_config import QuantizationConfigMixin as DiffQuantConfigMixin - - -try: - from transformers.utils.quantization_config import QuantizationConfigMixin as TransformersQuantConfigMixin -except ImportError: - - class TransformersQuantConfigMixin: - pass - - -logger = logging.get_logger(__name__) - - -class PipelineQuantizationConfig: - """ - Configuration class to be used when applying quantization on-the-fly to [`~DiffusionPipeline.from_pretrained`]. 
- - Args: - quant_backend (`str`): Quantization backend to be used. When using this option, we assume that the backend - is available to both `diffusers` and `transformers`. - quant_kwargs (`dict`): Params to initialize the quantization backend class. - components_to_quantize (`list`): Components of a pipeline to be quantized. - quant_mapping (`dict`): Mapping defining the quantization specs to be used for the pipeline - components. When using this argument, users are not expected to provide `quant_backend`, `quant_kawargs`, - and `components_to_quantize`. - """ - - def __init__( - self, - quant_backend: str = None, - quant_kwargs: Dict[str, Union[str, float, int, dict]] = None, - components_to_quantize: Optional[List[str]] = None, - quant_mapping: Dict[str, Union[DiffQuantConfigMixin, "TransformersQuantConfigMixin"]] = None, - ): - self.quant_backend = quant_backend - # Initialize kwargs to be {} to set to the defaults. - self.quant_kwargs = quant_kwargs or {} - self.components_to_quantize = components_to_quantize - self.quant_mapping = quant_mapping - - self.post_init() - - def post_init(self): - quant_mapping = self.quant_mapping - self.is_granular = True if quant_mapping is not None else False - - self._validate_init_args() - - def _validate_init_args(self): - if self.quant_backend and self.quant_mapping: - raise ValueError("Both `quant_backend` and `quant_mapping` cannot be specified at the same time.") - - if not self.quant_mapping and not self.quant_backend: - raise ValueError("Must provide a `quant_backend` when not providing a `quant_mapping`.") - - if not self.quant_kwargs and not self.quant_mapping: - raise ValueError("Both `quant_kwargs` and `quant_mapping` cannot be None.") - - if self.quant_backend is not None: - self._validate_init_kwargs_in_backends() - - if self.quant_mapping is not None: - self._validate_quant_mapping_args() - - def _validate_init_kwargs_in_backends(self): - quant_backend = self.quant_backend - - 
self._check_backend_availability(quant_backend) - - quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list() - - if quant_config_mapping_transformers is not None: - init_kwargs_transformers = inspect.signature(quant_config_mapping_transformers[quant_backend].__init__) - init_kwargs_transformers = {name for name in init_kwargs_transformers.parameters if name != "self"} - else: - init_kwargs_transformers = None - - init_kwargs_diffusers = inspect.signature(quant_config_mapping_diffusers[quant_backend].__init__) - init_kwargs_diffusers = {name for name in init_kwargs_diffusers.parameters if name != "self"} - - if init_kwargs_transformers != init_kwargs_diffusers: - raise ValueError( - "The signatures of the __init__ methods of the quantization config classes in `diffusers` and `transformers` don't match. " - f"Please provide a `quant_mapping` instead, in the {self.__class__.__name__} class. Refer to [the docs](https://huggingface.co/docs/diffusers/main/en/quantization/overview#pipeline-level-quantization) to learn more about how " - "this mapping would look like." - ) - - def _validate_quant_mapping_args(self): - quant_mapping = self.quant_mapping - transformers_map, diffusers_map = self._get_quant_config_list() - - available_transformers = list(transformers_map.values()) if transformers_map else None - available_diffusers = list(diffusers_map.values()) - - for module_name, config in quant_mapping.items(): - if any(isinstance(config, cfg) for cfg in available_diffusers): - continue - - if available_transformers and any(isinstance(config, cfg) for cfg in available_transformers): - continue - - if available_transformers: - raise ValueError( - f"Provided config for module_name={module_name} could not be found. " - f"Available diffusers configs: {available_diffusers}; " - f"Available transformers configs: {available_transformers}." - ) - else: - raise ValueError( - f"Provided config for module_name={module_name} could not be found. 
" - f"Available diffusers configs: {available_diffusers}." - ) - - def _check_backend_availability(self, quant_backend: str): - quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list() - - available_backends_transformers = ( - list(quant_config_mapping_transformers.keys()) if quant_config_mapping_transformers else None - ) - available_backends_diffusers = list(quant_config_mapping_diffusers.keys()) - - if ( - available_backends_transformers and quant_backend not in available_backends_transformers - ) or quant_backend not in quant_config_mapping_diffusers: - error_message = f"Provided quant_backend={quant_backend} was not found." - if available_backends_transformers: - error_message += f"\nAvailable ones (transformers): {available_backends_transformers}." - error_message += f"\nAvailable ones (diffusers): {available_backends_diffusers}." - raise ValueError(error_message) - - def _resolve_quant_config(self, is_diffusers: bool = True, module_name: str = None): - quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list() - - quant_mapping = self.quant_mapping - components_to_quantize = self.components_to_quantize - - # Granular case - if self.is_granular and module_name in quant_mapping: - logger.debug(f"Initializing quantization config class for {module_name}.") - config = quant_mapping[module_name] - return config - - # Global config case - else: - should_quantize = False - # Only quantize the modules requested for. - if components_to_quantize and module_name in components_to_quantize: - should_quantize = True - # No specification for `components_to_quantize` means all modules should be quantized. 
- elif not self.is_granular and not components_to_quantize: - should_quantize = True - - if should_quantize: - logger.debug(f"Initializing quantization config class for {module_name}.") - mapping_to_use = quant_config_mapping_diffusers if is_diffusers else quant_config_mapping_transformers - quant_config_cls = mapping_to_use[self.quant_backend] - quant_kwargs = self.quant_kwargs - return quant_config_cls(**quant_kwargs) - - # Fallback: no applicable configuration found. - return None - - def _get_quant_config_list(self): - if is_transformers_available(): - from transformers.quantizers.auto import ( - AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_transformers, - ) - else: - quant_config_mapping_transformers = None - - from ..quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_diffusers - - return quant_config_mapping_transformers, quant_config_mapping_diffusers +from .pipe_quant_config import PipelineQuantizationConfig diff --git a/src/diffusers/quantizers/pipe_quant_config.py b/src/diffusers/quantizers/pipe_quant_config.py new file mode 100644 index 000000000000..5d02de16fd1c --- /dev/null +++ b/src/diffusers/quantizers/pipe_quant_config.py @@ -0,0 +1,202 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class PipelineQuantizationConfig:
    """
    Configuration class to be used when applying quantization on-the-fly to [`~DiffusionPipeline.from_pretrained`].

    Two mutually exclusive modes are supported:

    * *Global*: a single `quant_backend` (plus `quant_kwargs`) is applied to every component, or only to the ones
      listed in `components_to_quantize`.
    * *Granular*: `quant_mapping` provides a ready-made quantization config object per component name.

    Args:
        quant_backend (`str`): Quantization backend to be used. When using this option, we assume that the backend
            is available to both `diffusers` and `transformers`.
        quant_kwargs (`dict`): Params to initialize the quantization backend class.
        components_to_quantize (`list`): Components of a pipeline to be quantized.
        quant_mapping (`dict`): Mapping defining the quantization specs to be used for the pipeline
            components. When using this argument, users are not expected to provide `quant_backend`, `quant_kwargs`,
            and `components_to_quantize`.

    Raises:
        ValueError: If the combination of arguments is inconsistent, the backend is unknown, or an entry of
            `quant_mapping` is not an instance of a known quantization config class.
    """

    def __init__(
        self,
        quant_backend: Optional[str] = None,
        quant_kwargs: Optional[Dict[str, Union[str, float, int, dict]]] = None,
        components_to_quantize: Optional[List[str]] = None,
        quant_mapping: Optional[Dict[str, Union["DiffQuantConfigMixin", "TransformersQuantConfigMixin"]]] = None,
    ):
        self.quant_backend = quant_backend
        # Initialize kwargs to be {} to set to the defaults.
        self.quant_kwargs = quant_kwargs or {}
        self.components_to_quantize = components_to_quantize
        self.quant_mapping = quant_mapping
        # Book-keeping of the configs actually handed out by `_resolve_quant_config`.
        # Example: `{module_name: quant_config}`
        self.config_mapping = {}
        self.post_init()

    def post_init(self):
        """Derive `is_granular` from `quant_mapping` and validate the init arguments."""
        self.is_granular = self.quant_mapping is not None
        self._validate_init_args()

    def _validate_init_args(self):
        """Check that exactly one of the global / granular modes was requested and that it is well-formed."""
        if self.quant_backend and self.quant_mapping:
            raise ValueError("Both `quant_backend` and `quant_mapping` cannot be specified at the same time.")

        if not self.quant_mapping and not self.quant_backend:
            raise ValueError("Must provide a `quant_backend` when not providing a `quant_mapping`.")

        if not self.quant_kwargs and not self.quant_mapping:
            raise ValueError("Both `quant_kwargs` and `quant_mapping` cannot be None.")

        if self.quant_backend is not None:
            self._validate_init_kwargs_in_backends()

        if self.quant_mapping is not None:
            self._validate_quant_mapping_args()

    @staticmethod
    def _init_param_names(config_cls):
        """Return the set of `__init__` parameter names of `config_cls`, excluding `self`."""
        return {name for name in inspect.signature(config_cls.__init__).parameters if name != "self"}

    def _validate_init_kwargs_in_backends(self):
        """
        Ensure the backend exists and that its `diffusers` and `transformers` config classes share one signature.

        The same `quant_kwargs` are used to build the config for both libraries, so their `__init__` parameter names
        have to match.
        """
        quant_backend = self.quant_backend

        self._check_backend_availability(quant_backend)

        transformers_map, diffusers_map = self._get_quant_config_list()

        # No `transformers` config classes are available, so there is nothing to compare the `diffusers` signature
        # against. Comparing with `None` would reject every backend in that setup.
        if not transformers_map:
            return

        init_kwargs_transformers = self._init_param_names(transformers_map[quant_backend])
        init_kwargs_diffusers = self._init_param_names(diffusers_map[quant_backend])

        if init_kwargs_transformers != init_kwargs_diffusers:
            raise ValueError(
                "The signatures of the __init__ methods of the quantization config classes in `diffusers` and `transformers` don't match. "
                f"Please provide a `quant_mapping` instead, in the {self.__class__.__name__} class. Refer to [the docs](https://huggingface.co/docs/diffusers/main/en/quantization/overview#pipeline-level-quantization) to learn more about how "
                "this mapping would look like."
            )

    def _validate_quant_mapping_args(self):
        """Ensure every value of `quant_mapping` is an instance of a known `diffusers`/`transformers` config class."""
        transformers_map, diffusers_map = self._get_quant_config_list()

        available_transformers = list(transformers_map.values()) if transformers_map else None
        available_diffusers = list(diffusers_map.values())

        for module_name, config in self.quant_mapping.items():
            if isinstance(config, tuple(available_diffusers)):
                continue

            if available_transformers and isinstance(config, tuple(available_transformers)):
                continue

            if available_transformers:
                raise ValueError(
                    f"Provided config for module_name={module_name} could not be found. "
                    f"Available diffusers configs: {available_diffusers}; "
                    f"Available transformers configs: {available_transformers}."
                )
            else:
                raise ValueError(
                    f"Provided config for module_name={module_name} could not be found. "
                    f"Available diffusers configs: {available_diffusers}."
                )

    def _check_backend_availability(self, quant_backend: str):
        """Raise `ValueError` unless `quant_backend` is known to `diffusers` (and to `transformers`, if present)."""
        transformers_map, diffusers_map = self._get_quant_config_list()

        available_backends_transformers = list(transformers_map.keys()) if transformers_map else None
        available_backends_diffusers = list(diffusers_map.keys())

        missing_in_transformers = (
            available_backends_transformers and quant_backend not in available_backends_transformers
        )
        if missing_in_transformers or quant_backend not in diffusers_map:
            error_message = f"Provided quant_backend={quant_backend} was not found."
            if available_backends_transformers:
                error_message += f"\nAvailable ones (transformers): {available_backends_transformers}."
            error_message += f"\nAvailable ones (diffusers): {available_backends_diffusers}."
            raise ValueError(error_message)

    def _resolve_quant_config(self, is_diffusers: bool = True, module_name: str = None):
        """
        Return the quantization config to use for `module_name`, or `None` if it should not be quantized.

        Args:
            is_diffusers (`bool`): Whether to pick the config class from the `diffusers` mapping (`True`) or from
                the `transformers` mapping (`False`). Only used in the global (backend) mode.
            module_name (`str`): Name of the pipeline component being loaded.

        Every config that is returned is also recorded in `self.config_mapping`.
        """
        transformers_map, diffusers_map = self._get_quant_config_list()

        quant_mapping = self.quant_mapping
        components_to_quantize = self.components_to_quantize

        # Granular case: the user supplied a ready-made config for this module.
        if self.is_granular and module_name in quant_mapping:
            logger.debug(f"Initializing quantization config class for {module_name}.")
            config = quant_mapping[module_name]
            self.config_mapping[module_name] = config
            return config

        # Global config case
        should_quantize = False
        # Only quantize the modules requested for.
        if components_to_quantize and module_name in components_to_quantize:
            should_quantize = True
        # No specification for `components_to_quantize` means all modules should be quantized.
        elif not self.is_granular and not components_to_quantize:
            should_quantize = True

        if should_quantize:
            logger.debug(f"Initializing quantization config class for {module_name}.")
            mapping_to_use = diffusers_map if is_diffusers else transformers_map
            quant_config_cls = mapping_to_use[self.quant_backend]
            quant_obj = quant_config_cls(**self.quant_kwargs)
            self.config_mapping[module_name] = quant_obj
            return quant_obj

        # Fallback: no applicable configuration found.
        return None

    def _get_quant_config_list(self):
        """
        Return `(transformers_mapping, diffusers_mapping)` of backend name to quantization config class.

        The first element is `None` when `transformers` is not available.
        """
        if is_transformers_available():
            from transformers.quantizers.auto import (
                AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_transformers,
            )
        else:
            quant_config_mapping_transformers = None

        from ..quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_diffusers

        return quant_config_mapping_transformers, quant_config_mapping_diffusers

    def __repr__(self):
        """
        Return `"{module_name} {config}"` for every resolved config, ordered by module name.

        The entries are concatenated without a separator, and the result is an empty string until
        `_resolve_quant_config` has recorded at least one config.
        """
        return "".join(f"{module_name} {config}" for module_name, config in sorted(self.config_mapping.items()))
# NOTE: both functions below are methods of the pipeline-level quantization test case; the class header lies
# outside this hunk.
@parameterized.expand(["quant_kwargs", "quant_mapping"])
def test_quant_config_repr(self, method):
    """Check that the pipeline exposes its quantization config and that its string form carries the expected values.

    Runs once for the global (`quant_backend` + `quant_kwargs`) mode and once for the granular (`quant_mapping`)
    mode; both are expected to render the same `BitsAndBytesConfig` for the `transformer` component.
    """
    component_name = "transformer"
    if method == "quant_kwargs":
        components_to_quantize = [component_name]
        quant_config = PipelineQuantizationConfig(
            quant_backend="bitsandbytes_8bit",
            quant_kwargs={"load_in_8bit": True},
            components_to_quantize=components_to_quantize,
        )
    else:
        quant_config = PipelineQuantizationConfig(
            quant_mapping={component_name: BitsAndBytesConfig(load_in_8bit=True)}
        )

    pipe = DiffusionPipeline.from_pretrained(
        self.model_name,
        quantization_config=quant_config,
        torch_dtype=torch.bfloat16,
    )
    # `assertIsNotNone` reports the offending value on failure, unlike `assertTrue(x is not None)`.
    self.assertIsNotNone(getattr(pipe, "quantization_config", None))
    retrieved_config = pipe.quantization_config
    expected_config = """
transformer BitsAndBytesConfig {
  "_load_in_4bit": false,
  "_load_in_8bit": true,
  "bnb_4bit_compute_dtype": "float32",
  "bnb_4bit_quant_storage": "uint8",
  "bnb_4bit_quant_type": "fp4",
  "bnb_4bit_use_double_quant": false,
  "llm_int8_enable_fp32_cpu_offload": false,
  "llm_int8_has_fp16_weight": false,
  "llm_int8_skip_modules": null,
  "llm_int8_threshold": 6.0,
  "load_in_4bit": false,
  "load_in_8bit": true,
  "quant_method": "bitsandbytes"
}

"""
    expected_data = self._parse_config_string(expected_config)
    actual_data = self._parse_config_string(str(retrieved_config))
    # `assertEqual` prints a diff of the two dicts when they differ.
    self.assertEqual(actual_data, expected_data)


def _parse_config_string(self, config_string: str) -> dict:
    """Parse the JSON object that starts at the first `{` of `config_string`.

    Everything before the first opening brace (e.g. the `"<module_name> <ConfigClass> "` prefix) is ignored;
    everything from that brace to the end of the string must be a single valid JSON document.

    Raises:
        ValueError: If `config_string` contains no `{`.
    """
    first_brace = config_string.find("{")
    if first_brace == -1:
        raise ValueError("Could not find opening brace '{' in the string.")

    return json.loads(config_string[first_brace:])