
Added Support for Custom Quantization #35915


Merged
merged 10 commits on Feb 18, 2025
78 changes: 78 additions & 0 deletions examples/quantization/custom_quantization.py
@@ -0,0 +1,78 @@
import json
from typing import Any, Dict

import torch

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.quantizers import HfQuantizer, register_quantization_config, register_quantizer
from transformers.utils.quantization_config import QuantizationConfigMixin


# Registering the config under the "custom" method name lets the auto machinery resolve it.
@register_quantization_config("custom")
class CustomConfig(QuantizationConfigMixin):
    def __init__(self):
        self.quant_method = "custom"
        self.bits = 8

    def to_dict(self) -> Dict[str, Any]:
        output = {
            "num_bits": self.bits,
        }
        return output

    def __repr__(self):
        config_dict = self.to_dict()
        return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"

    def to_diff_dict(self) -> Dict[str, Any]:
        config_dict = self.to_dict()

        default_config_dict = CustomConfig().to_dict()

        # Serialize only the values that differ from a default config.
        serializable_config_dict = {}

        for key, value in config_dict.items():
            if value != default_config_dict[key]:
                serializable_config_dict[key] = value

        return serializable_config_dict


# A skeleton quantizer: the pre-/post-weight-loading hooks below are where real
# quantization logic would go.
@register_quantizer("custom")
class CustomQuantizer(HfQuantizer):
    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
        super().__init__(quantization_config, **kwargs)
        self.quantization_config = quantization_config
        self.scale_map = {}
        self.device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu")
        self.torch_dtype = kwargs.get("torch_dtype", torch.float32)

    def _process_model_before_weight_loading(self, model, **kwargs):
        return True

    def _process_model_after_weight_loading(self, model, **kwargs):
        return True

    def is_serializable(self) -> bool:
        return True

    def is_trainable(self) -> bool:
        return False


model_8bit = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m", quantization_config=CustomConfig(), torch_dtype="auto"
)

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
input_text = "once there is"
inputs = tokenizer(input_text, return_tensors="pt")
output = model_8bit.generate(
    **inputs,
    max_length=100,
    num_return_sequences=1,
    no_repeat_ngram_size=2,
)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

print(generated_text)
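
Quick sanity check, not part of the diff: with the classes above, the config serialization should behave as follows (expected outputs shown as comments):

cfg = CustomConfig()
print(cfg)                 # CustomConfig {"num_bits": 8}  (pretty-printed by __repr__)
print(cfg.to_dict())       # {'num_bits': 8}
print(cfg.to_diff_dict())  # {} - nothing deviates from a freshly constructed default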
6 changes: 4 additions & 2 deletions src/transformers/modeling_utils.py
@@ -3706,8 +3706,10 @@ def from_pretrained(
         device_map = hf_quantizer.update_device_map(device_map)
 
         # In order to ensure popular quantization methods are supported. Can be disabled with `disable_telemetry`
-        user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value
-
+        if hasattr(hf_quantizer.quantization_config.quant_method, "value"):
+            user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value
+        else:
+            user_agent["quant"] = hf_quantizer.quantization_config.quant_method
         # Force-set to `True` for more mem efficiency
         if low_cpu_mem_usage is None:
             low_cpu_mem_usage = True
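Why the hasattr check: built-in configs store quant_method as a QuantizationMethod enum, which exposes .value, while custom configs like the example above store a plain string. A minimal sketch of the two cases:

from transformers.utils.quantization_config import QuantizationMethod

for qm in (QuantizationMethod.BITS_AND_BYTES, "custom"):
    # enum members expose .value; plain strings are used as-is
    print(qm.value if hasattr(qm, "value") else qm)
# prints "bitsandbytes", then "custom"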
2 changes: 1 addition & 1 deletion src/transformers/quantizers/__init__.py
@@ -11,5 +11,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .auto import AutoHfQuantizer, AutoQuantizationConfig
+from .auto import AutoHfQuantizer, AutoQuantizationConfig, register_quantization_config, register_quantizer
 from .base import HfQuantizer
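
With this re-export in place, user code can import everything from one namespace, as the example script above does:

from transformers.quantizers import (
    AutoHfQuantizer,
    AutoQuantizationConfig,
    HfQuantizer,
    register_quantization_config,
    register_quantizer,
)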
33 changes: 33 additions & 0 deletions src/transformers/quantizers/auto.py
@@ -35,6 +35,7 @@
     TorchAoConfig,
     VptqConfig,
 )
+from .base import HfQuantizer
 from .quantizer_aqlm import AqlmHfQuantizer
 from .quantizer_awq import AwqQuantizer
 from .quantizer_bitnet import BitNetHfQuantizer
@@ -226,3 +227,35 @@ def supports_quant_method(quantization_config_dict):
         )
         return False
     return True
+
+
+def register_quantization_config(method: str):
+    """Register a custom quantization configuration."""
+
+    def register_config_fn(cls):
+        if method in AUTO_QUANTIZATION_CONFIG_MAPPING:
+            raise ValueError(f"Config '{method}' already registered")
+
+        if not issubclass(cls, QuantizationConfigMixin):
+            raise ValueError("Config must extend QuantizationConfigMixin")
+
+        AUTO_QUANTIZATION_CONFIG_MAPPING[method] = cls
+        return cls
+
+    return register_config_fn
+
+
+def register_quantizer(name: str):
+    """Register a custom quantizer."""
+
+    def register_quantizer_fn(cls):
+        if name in AUTO_QUANTIZER_MAPPING:
+            raise ValueError(f"Quantizer '{name}' already registered")
+
+        if not issubclass(cls, HfQuantizer):
+            raise ValueError("Quantizer must extend HfQuantizer")
+
+        AUTO_QUANTIZER_MAPPING[name] = cls
+        return cls
+
+    return register_quantizer_fn
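
For illustration, assuming the CustomConfig/CustomQuantizer example above has been run: the decorators simply populate the two module-level mappings, and reusing a registered name trips the guard's ValueError.

from transformers.quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING, AUTO_QUANTIZER_MAPPING

assert AUTO_QUANTIZATION_CONFIG_MAPPING["custom"] is CustomConfig
assert AUTO_QUANTIZER_MAPPING["custom"] is CustomQuantizer

try:
    @register_quantization_config("custom")  # "custom" is already taken by the example
    class DuplicateConfig(QuantizationConfigMixin):
        pass
except ValueError as err:
    print(err)  # Config 'custom' already registered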