-
Notifications
You must be signed in to change notification settings - Fork 750
Add CoreMLQuantizer in coremltools.optimize.torch to support PyTorch Export based quantization
#2162
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Add CoreMLQuantizer in coremltools.optimize.torch to support PyTorch Export based quantization
#2162
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
6af4a4c
Add support for PyTorch Export Quantizer
951f00b
skip tests if export APIs don't exist
pulkital 736ea8e
skip tests if export APIs don't exist
pulkital 18d4221
skip test if version < 2.2.0
pulkital 3528f25
fix linter issue
431d272
add quantization conversion test
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,12 +1,57 @@ | ||
| # Copyright (c) 2023, Apple Inc. All rights reserved. | ||
| # Copyright (c) 2024, Apple Inc. All rights reserved. | ||
| # | ||
| # Use of this source code is governed by a BSD-3-clause license that can be | ||
| # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause | ||
|
|
||
| import logging as _logging | ||
| from collections import OrderedDict as _OrderedDict | ||
| from typing import Any as _Any | ||
|
|
||
| _logger = _logging.getLogger(__name__) | ||
|
|
||
|
|
||
| def get_str(val: _Any): | ||
| if isinstance(val, float): | ||
| return f"{val:.5f}" | ||
| return str(val) | ||
|
|
||
|
|
||
| class RegistryMixin: | ||
| REGISTRY = None | ||
|
|
||
| @classmethod | ||
| def register(cls, name: str): | ||
| if cls.REGISTRY is None: | ||
| cls.REGISTRY = _OrderedDict() | ||
|
|
||
| def inner_wrapper(wrapped_obj): | ||
| if name in cls.REGISTRY: | ||
| _logger.warning( | ||
| f"Name: {name} is already registered with object: {cls.REGISTRY[name].__name__} " | ||
| f"in registry: {cls.__name__}" | ||
| f"Over-writing the name with new class: {wrapped_obj.__name__}" | ||
| ) | ||
| cls.REGISTRY[name] = wrapped_obj | ||
| return wrapped_obj | ||
|
|
||
| return inner_wrapper | ||
|
|
||
| @classmethod | ||
| def _get_object(cls, name: str): | ||
| if name in cls.REGISTRY: | ||
| return cls.REGISTRY[name] | ||
| raise NotImplementedError( | ||
| f"No object is registered with name: {name} in registry {cls.__name__}." | ||
| ) | ||
|
|
||
|
|
||
| class ClassRegistryMixin(RegistryMixin): | ||
| @classmethod | ||
| def get_class(cls, name: str): | ||
| return cls._get_object(name) | ||
|
|
||
|
|
||
| class FunctionRegistryMixin(RegistryMixin): | ||
| @classmethod | ||
| def get_function(cls, name: str): | ||
| return cls._get_object(name) |
109 changes: 109 additions & 0 deletions
109
coremltools/optimize/torch/quantization/_annotation_config.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,109 @@ | ||
| # Copyright (c) 2024, Apple Inc. All rights reserved. | ||
| # | ||
| # Use of this source code is governed by a BSD-3-clause license that can be | ||
| # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause | ||
|
|
||
| from typing import Optional as _Optional | ||
|
|
||
| import torch as _torch | ||
| import torch.ao.quantization as _aoquant | ||
| from attr import define as _define | ||
YifanShenSZ marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| from torch.ao.quantization.quantizer.quantizer import ( | ||
| QuantizationSpec as _TorchQuantizationSpec, | ||
| ) | ||
|
|
||
| from coremltools.optimize.torch.quantization.quantization_config import ( | ||
| ModuleLinearQuantizerConfig as _ModuleLinearQuantizerConfig, | ||
| ) | ||
| from coremltools.optimize.torch.quantization.quantization_config import ObserverType as _ObserverType | ||
| from coremltools.optimize.torch.quantization.quantization_config import ( | ||
| QuantizationScheme as _QuantizationScheme, | ||
| ) | ||
|
|
||
|
|
||
| @_define | ||
| class AnnotationConfig: | ||
| """ | ||
| Module/Operator level configuration class for :py:class:`CoreMLQuantizer`. | ||
|
|
||
| For each module/operator, defines the dtype, quantization scheme and observer type | ||
| for input(s), output and weights (if any). | ||
| """ | ||
|
|
||
| input_activation: _Optional[_TorchQuantizationSpec] = None | ||
| output_activation: _Optional[_TorchQuantizationSpec] = None | ||
| weight: _Optional[_TorchQuantizationSpec] = None | ||
|
|
||
| @staticmethod | ||
| def _normalize_dtype(dtype: _torch.dtype) -> _torch.dtype: | ||
| """ | ||
| PyTorch export quantizer only supports uint8 and int8 data types, | ||
| so we map the quantized dtypes to the corresponding supported dtype. | ||
| """ | ||
| dtype_map = { | ||
| _torch.quint8: _torch.uint8, | ||
| _torch.qint8: _torch.int8, | ||
| } | ||
| return dtype_map.get(dtype, dtype) | ||
|
|
||
| @classmethod | ||
| def from_quantization_config( | ||
| cls, | ||
| quantization_config: _Optional[_ModuleLinearQuantizerConfig], | ||
| ) -> _Optional["AnnotationConfig"]: | ||
| """ | ||
| Creates a :py:class:`AnnotationConfig` from ``ModuleLinearQuantizerConfig`` | ||
| """ | ||
| if ( | ||
| quantization_config is None | ||
| or quantization_config.weight_dtype == _torch.float32 | ||
| ): | ||
| return None | ||
|
|
||
| # Activation QSpec | ||
| if quantization_config.activation_dtype == _torch.float32: | ||
| output_activation_qspec = None | ||
| else: | ||
| activation_qscheme = _QuantizationScheme.get_qscheme( | ||
| quantization_config.quantization_scheme, | ||
| is_per_channel=False, | ||
| ) | ||
| activation_dtype = cls._normalize_dtype( | ||
| quantization_config.activation_dtype | ||
| ) | ||
| output_activation_qspec = _TorchQuantizationSpec( | ||
| observer_or_fake_quant_ctr=_aoquant.FakeQuantize.with_args( | ||
| observer=_ObserverType.get_observer( | ||
| quantization_config.activation_observer, | ||
| is_per_channel=False, | ||
| ), | ||
| dtype=activation_dtype, | ||
| qscheme=activation_qscheme, | ||
| ), | ||
| dtype=activation_dtype, | ||
| qscheme=activation_qscheme, | ||
| ) | ||
|
|
||
| # Weight QSpec | ||
| weight_qscheme = _QuantizationScheme.get_qscheme( | ||
| quantization_config.quantization_scheme, | ||
| is_per_channel=quantization_config.weight_per_channel, | ||
| ) | ||
| weight_dtype = cls._normalize_dtype(quantization_config.weight_dtype) | ||
| weight_qspec = _TorchQuantizationSpec( | ||
| observer_or_fake_quant_ctr=_aoquant.FakeQuantize.with_args( | ||
| observer=_ObserverType.get_observer( | ||
| quantization_config.weight_observer, | ||
| is_per_channel=quantization_config.weight_per_channel, | ||
| ), | ||
| dtype=weight_dtype, | ||
| qscheme=weight_qscheme, | ||
| ), | ||
| dtype=weight_dtype, | ||
| qscheme=weight_qscheme, | ||
| ) | ||
| return AnnotationConfig( | ||
| input_activation=output_activation_qspec, | ||
| output_activation=output_activation_qspec, | ||
| weight=weight_qspec, | ||
| ) | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.