From 644f9e27733d14621850cefb403b1ddf82fca107 Mon Sep 17 00:00:00 2001 From: keetrap Date: Mon, 27 Jan 2025 23:08:35 +0530 Subject: [PATCH 1/5] Added Support for Custom Quantization --- custom_quant_example.py | 81 +++++++++++++++++++++++++ src/transformers/modeling_utils.py | 6 +- src/transformers/quantizers/__init__.py | 2 +- src/transformers/quantizers/auto.py | 28 +++++++++ 4 files changed, 114 insertions(+), 3 deletions(-) create mode 100644 custom_quant_example.py diff --git a/custom_quant_example.py b/custom_quant_example.py new file mode 100644 index 000000000000..0ddde1af3a96 --- /dev/null +++ b/custom_quant_example.py @@ -0,0 +1,81 @@ +import json +import torch +from typing import Dict, Any +from transformers.quantizers import HfQuantizer +from transformers.utils.quantization_config import QuantizationConfigMixin +from transformers.quantizers import register_quantization_config, register_quantizer +from transformers import AutoModelForCausalLM +from transformers import AutoTokenizer + + +@register_quantization_config("custom") +class CustomConfig(QuantizationConfigMixin): + def __init__(self): + self.quant_method = "custom" + self.bits = 8 + + def to_dict(self) -> Dict[str, Any]: + output = { + "num_bits": self.bits, + } + return output + + def __repr__(self): + config_dict = self.to_dict() + return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n" + + def to_diff_dict(self) -> Dict[str, Any]: + + config_dict = self.to_dict() + + default_config_dict = CustomConfig().to_dict() + + serializable_config_dict = {} + + for key, value in config_dict.items(): + if value != default_config_dict[key]: + serializable_config_dict[key] = value + + return serializable_config_dict + + +@register_quantizer("custom") +class CustomQuantizer(HfQuantizer): + def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs): + super().__init__(quantization_config, **kwargs) + self.quantization_config = quantization_config + self.scale_map = {} + self.device = kwargs.get( + "device", "cuda" if torch.cuda.is_available() else "cpu" + ) + self.torch_dtype = kwargs.get("torch_dtype", torch.float32) + + def _process_model_before_weight_loading(self, model, **kwargs): + return True + + def _process_model_after_weight_loading(self, model, **kwargs): + return True + + def is_serializable(self) -> bool: + return True + + def is_trainable(self) -> bool: + return False + + +model_8bit = AutoModelForCausalLM.from_pretrained( + "facebook/opt-350m", quantization_config=CustomConfig(), torch_dtype="auto" +) + +tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") +input_text = "once there is" +inputs = tokenizer(input_text, return_tensors="pt") +output = model_8bit.generate( + **inputs, + max_length=100, + num_return_sequences=1, + no_repeat_ngram_size=2, +) +generated_text = tokenizer.decode(output[0], skip_special_tokens=True) + +print(generated_text) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index d869229af9c8..6abab2ae0905 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3641,8 +3641,10 @@ def from_pretrained( device_map = hf_quantizer.update_device_map(device_map) # In order to ensure popular quantization methods are supported. 
Can be disable with `disable_telemetry` - user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value - + try: + user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value + except: + user_agent["quant"] = hf_quantizer.quantization_config.quant_method # Force-set to `True` for more mem efficiency if low_cpu_mem_usage is None: low_cpu_mem_usage = True diff --git a/src/transformers/quantizers/__init__.py b/src/transformers/quantizers/__init__.py index 3409af4cd78c..96c8d4fa5043 100755 --- a/src/transformers/quantizers/__init__.py +++ b/src/transformers/quantizers/__init__.py @@ -11,5 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .auto import AutoHfQuantizer, AutoQuantizationConfig +from .auto import AutoHfQuantizer, AutoQuantizationConfig, register_quantization_config, register_quantizer from .base import HfQuantizer diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py index d5b51d038ab8..ce8a4d977f50 100755 --- a/src/transformers/quantizers/auto.py +++ b/src/transformers/quantizers/auto.py @@ -14,6 +14,7 @@ import warnings from typing import Dict, Optional, Union +from .base import HfQuantizer from ..models.auto.configuration_auto import AutoConfig from ..utils.quantization_config import ( AqlmConfig, @@ -195,3 +196,30 @@ def merge_quantization_configs( warnings.warn(warning_msg) return quantization_config + +def register_quantization_config(method: str): + """Register a custom quantization configuration.""" + + def register_config_fn(cls): + if method in AUTO_QUANTIZATION_CONFIG_MAPPING: + raise ValueError(f"Config '{method}' already registered") + + if not issubclass(cls, QuantizationConfigMixin): + raise ValueError(f"Config must extend QuantizationConfigMixin") + + AUTO_QUANTIZATION_CONFIG_MAPPING[method] = cls + return cls + return register_config_fn + +def register_quantizer(name: str): + """Register a custom quantizer.""" + def register_quantizer_fn(cls): + if name in AUTO_QUANTIZER_MAPPING: + raise ValueError(f"Quantizer '{name}' already registered") + + if not issubclass(cls, HfQuantizer): + raise ValueError(f"Quantizer must extend HfQuantizer") + + AUTO_QUANTIZER_MAPPING[name] = cls + return cls + return register_quantizer_fn \ No newline at end of file From 0b7cd98410e73e0d3cfd83100848e9f46d9cc20d Mon Sep 17 00:00:00 2001 From: keetrap Date: Tue, 28 Jan 2025 13:14:10 +0530 Subject: [PATCH 2/5] Update code --- custom_quant_example.py | 12 ++++++------ src/transformers/modeling_utils.py | 2 +- src/transformers/quantizers/auto.py | 16 ++++++++-------- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/custom_quant_example.py b/custom_quant_example.py index 0ddde1af3a96..b237e69fc536 100644 --- a/custom_quant_example.py +++ b/custom_quant_example.py @@ -1,11 +1,11 @@ import json +from typing import Any, Dict + import torch -from typing import Dict, Any -from transformers.quantizers import HfQuantizer + +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.quantizers import HfQuantizer, register_quantization_config, register_quantizer from transformers.utils.quantization_config import QuantizationConfigMixin -from transformers.quantizers import register_quantization_config, register_quantizer -from transformers import AutoModelForCausalLM -from transformers import AutoTokenizer @register_quantization_config("custom") @@ -72,7 +72,7 @@ def 
is_trainable(self) -> bool: inputs = tokenizer(input_text, return_tensors="pt") output = model_8bit.generate( **inputs, - max_length=100, + max_length=100, num_return_sequences=1, no_repeat_ngram_size=2, ) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 6abab2ae0905..f9fa9e9d3aac 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3643,7 +3643,7 @@ def from_pretrained( # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry` try: user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value - except: + except Exception: user_agent["quant"] = hf_quantizer.quantization_config.quant_method # Force-set to `True` for more mem efficiency if low_cpu_mem_usage is None: diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py index ce8a4d977f50..a6e561a69ed4 100755 --- a/src/transformers/quantizers/auto.py +++ b/src/transformers/quantizers/auto.py @@ -14,7 +14,6 @@ import warnings from typing import Dict, Optional, Union -from .base import HfQuantizer from ..models.auto.configuration_auto import AutoConfig from ..utils.quantization_config import ( AqlmConfig, @@ -33,6 +32,7 @@ TorchAoConfig, VptqConfig, ) +from .base import HfQuantizer from .quantizer_aqlm import AqlmHfQuantizer from .quantizer_awq import AwqQuantizer from .quantizer_bitnet import BitNetHfQuantizer @@ -203,11 +203,11 @@ def register_quantization_config(method: str): def register_config_fn(cls): if method in AUTO_QUANTIZATION_CONFIG_MAPPING: raise ValueError(f"Config '{method}' already registered") - + if not issubclass(cls, QuantizationConfigMixin): - raise ValueError(f"Config must extend QuantizationConfigMixin") - - AUTO_QUANTIZATION_CONFIG_MAPPING[method] = cls + raise ValueError("Config must extend QuantizationConfigMixin") + + AUTO_QUANTIZATION_CONFIG_MAPPING[method] = cls return cls return register_config_fn @@ -218,8 +218,8 @@ def register_quantizer_fn(cls): raise ValueError(f"Quantizer '{name}' already registered") if not issubclass(cls, HfQuantizer): - raise ValueError(f"Quantizer must extend HfQuantizer") - + raise ValueError("Quantizer must extend HfQuantizer") + AUTO_QUANTIZER_MAPPING[name] = cls return cls - return register_quantizer_fn \ No newline at end of file + return register_quantizer_fn From 7a9ce1fbb040473e1142ff2a06d38b0290c9e41a Mon Sep 17 00:00:00 2001 From: keetrap Date: Tue, 28 Jan 2025 13:25:20 +0530 Subject: [PATCH 3/5] code reformatted --- src/transformers/quantizers/auto.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py index a6e561a69ed4..990c4056b3e7 100755 --- a/src/transformers/quantizers/auto.py +++ b/src/transformers/quantizers/auto.py @@ -197,6 +197,7 @@ def merge_quantization_configs( return quantization_config + def register_quantization_config(method: str): """Register a custom quantization configuration.""" @@ -209,10 +210,13 @@ def register_config_fn(cls): AUTO_QUANTIZATION_CONFIG_MAPPING[method] = cls return cls + return register_config_fn + def register_quantizer(name: str): """Register a custom quantizer.""" + def register_quantizer_fn(cls): if name in AUTO_QUANTIZER_MAPPING: raise ValueError(f"Quantizer '{name}' already registered") @@ -222,4 +226,5 @@ def register_quantizer_fn(cls): AUTO_QUANTIZER_MAPPING[name] = cls return cls + return register_quantizer_fn From 89d102af141fc1cf5dcd85cd95862707e9705370 Mon Sep 17 00:00:00 2001 From: 
keetrap Date: Tue, 28 Jan 2025 17:18:13 +0530 Subject: [PATCH 4/5] Updated Changes --- examples/quantization/custom_quantization.py | 78 ++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 examples/quantization/custom_quantization.py diff --git a/examples/quantization/custom_quantization.py b/examples/quantization/custom_quantization.py new file mode 100644 index 000000000000..16b31cd8ebe4 --- /dev/null +++ b/examples/quantization/custom_quantization.py @@ -0,0 +1,78 @@ +import json +from typing import Any, Dict + +import torch + +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.quantizers import HfQuantizer, register_quantization_config, register_quantizer +from transformers.utils.quantization_config import QuantizationConfigMixin + + +@register_quantization_config("custom") +class CustomConfig(QuantizationConfigMixin): + def __init__(self): + self.quant_method = "custom" + self.bits = 8 + + def to_dict(self) -> Dict[str, Any]: + output = { + "num_bits": self.bits, + } + return output + + def __repr__(self): + config_dict = self.to_dict() + return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n" + + def to_diff_dict(self) -> Dict[str, Any]: + config_dict = self.to_dict() + + default_config_dict = CustomConfig().to_dict() + + serializable_config_dict = {} + + for key, value in config_dict.items(): + if value != default_config_dict[key]: + serializable_config_dict[key] = value + + return serializable_config_dict + + +@register_quantizer("custom") +class CustomQuantizer(HfQuantizer): + def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs): + super().__init__(quantization_config, **kwargs) + self.quantization_config = quantization_config + self.scale_map = {} + self.device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu") + self.torch_dtype = kwargs.get("torch_dtype", torch.float32) + + def _process_model_before_weight_loading(self, model, **kwargs): + return True + + def _process_model_after_weight_loading(self, model, **kwargs): + return True + + def is_serializable(self) -> bool: + return True + + def is_trainable(self) -> bool: + return False + + +model_8bit = AutoModelForCausalLM.from_pretrained( + "facebook/opt-350m", quantization_config=CustomConfig(), torch_dtype="auto" +) + +tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") +input_text = "once there is" +inputs = tokenizer(input_text, return_tensors="pt") +output = model_8bit.generate( + **inputs, + max_length=100, + num_return_sequences=1, + no_repeat_ngram_size=2, +) +generated_text = tokenizer.decode(output[0], skip_special_tokens=True) + +print(generated_text) From de6c7e7fd20f1800a3e875dcc7f83ef95a3c4c29 Mon Sep 17 00:00:00 2001 From: keetrap Date: Tue, 28 Jan 2025 17:18:44 +0530 Subject: [PATCH 5/5] Updated Changes --- custom_quant_example.py | 81 ------------------------------ src/transformers/modeling_utils.py | 4 +- 2 files changed, 2 insertions(+), 83 deletions(-) delete mode 100644 custom_quant_example.py diff --git a/custom_quant_example.py b/custom_quant_example.py deleted file mode 100644 index b237e69fc536..000000000000 --- a/custom_quant_example.py +++ /dev/null @@ -1,81 +0,0 @@ -import json -from typing import Any, Dict - -import torch - -from transformers import AutoModelForCausalLM, AutoTokenizer -from transformers.quantizers import HfQuantizer, register_quantization_config, register_quantizer -from transformers.utils.quantization_config import QuantizationConfigMixin - - 
-@register_quantization_config("custom") -class CustomConfig(QuantizationConfigMixin): - def __init__(self): - self.quant_method = "custom" - self.bits = 8 - - def to_dict(self) -> Dict[str, Any]: - output = { - "num_bits": self.bits, - } - return output - - def __repr__(self): - config_dict = self.to_dict() - return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n" - - def to_diff_dict(self) -> Dict[str, Any]: - - config_dict = self.to_dict() - - default_config_dict = CustomConfig().to_dict() - - serializable_config_dict = {} - - for key, value in config_dict.items(): - if value != default_config_dict[key]: - serializable_config_dict[key] = value - - return serializable_config_dict - - -@register_quantizer("custom") -class CustomQuantizer(HfQuantizer): - def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs): - super().__init__(quantization_config, **kwargs) - self.quantization_config = quantization_config - self.scale_map = {} - self.device = kwargs.get( - "device", "cuda" if torch.cuda.is_available() else "cpu" - ) - self.torch_dtype = kwargs.get("torch_dtype", torch.float32) - - def _process_model_before_weight_loading(self, model, **kwargs): - return True - - def _process_model_after_weight_loading(self, model, **kwargs): - return True - - def is_serializable(self) -> bool: - return True - - def is_trainable(self) -> bool: - return False - - -model_8bit = AutoModelForCausalLM.from_pretrained( - "facebook/opt-350m", quantization_config=CustomConfig(), torch_dtype="auto" -) - -tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") -input_text = "once there is" -inputs = tokenizer(input_text, return_tensors="pt") -output = model_8bit.generate( - **inputs, - max_length=100, - num_return_sequences=1, - no_repeat_ngram_size=2, -) -generated_text = tokenizer.decode(output[0], skip_special_tokens=True) - -print(generated_text) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index f9fa9e9d3aac..a5c232f09e3b 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3641,9 +3641,9 @@ def from_pretrained( device_map = hf_quantizer.update_device_map(device_map) # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry` - try: + if hasattr(hf_quantizer.quantization_config.quant_method, "value"): user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value - except Exception: + else: user_agent["quant"] = hf_quantizer.quantization_config.quant_method # Force-set to `True` for more mem efficiency if low_cpu_mem_usage is None: