From 6af4a4c0dd6fb8a8cb3fec0babc65d92a943b357 Mon Sep 17 00:00:00 2001 From: Pulkit Agrawal Date: Fri, 8 Mar 2024 02:56:26 -0800 Subject: [PATCH 1/6] Add support for PyTorch Export Quantizer --- .../optimize/torch/_utils/python_utils.py | 47 +- .../torch/quantization/_annotation_config.py | 109 +++ .../torch/quantization/_coreml_quantizer.py | 624 ++++++++++++++ .../quantization/_coreml_quantizer_utils.py | 767 ++++++++++++++++++ .../quantization/test_coreml_quantizer.py | 219 +++++ 5 files changed, 1765 insertions(+), 1 deletion(-) create mode 100644 coremltools/optimize/torch/quantization/_annotation_config.py create mode 100644 coremltools/optimize/torch/quantization/_coreml_quantizer.py create mode 100644 coremltools/optimize/torch/quantization/_coreml_quantizer_utils.py create mode 100644 coremltools/test/optimize/torch/quantization/test_coreml_quantizer.py diff --git a/coremltools/optimize/torch/_utils/python_utils.py b/coremltools/optimize/torch/_utils/python_utils.py index 33f9cf2fa..ba39c081f 100644 --- a/coremltools/optimize/torch/_utils/python_utils.py +++ b/coremltools/optimize/torch/_utils/python_utils.py @@ -1,12 +1,57 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause +import logging as _logging +from collections import OrderedDict as _OrderedDict from typing import Any as _Any +_logger = _logging.getLogger(__name__) + def get_str(val: _Any): if isinstance(val, float): return f"{val:.5f}" return str(val) + + +class RegistryMixin: + REGISTRY = None + + @classmethod + def register(cls, name: str): + if cls.REGISTRY is None: + cls.REGISTRY = _OrderedDict() + + def inner_wrapper(wrapped_obj): + if name in cls.REGISTRY: + _logger.warning( + f"Name: {name} is already registered with object: {cls.REGISTRY[name].__name__} " + f"in registry: {cls.__name__}" + f"Over-writing the name with new class: {wrapped_obj.__name__}" + ) + cls.REGISTRY[name] = wrapped_obj + return wrapped_obj + + return inner_wrapper + + @classmethod + def _get_object(cls, name: str): + if name in cls.REGISTRY: + return cls.REGISTRY[name] + raise NotImplementedError( + f"No object is registered with name: {name} in registry {cls.__name__}." + ) + + +class ClassRegistryMixin(RegistryMixin): + @classmethod + def get_class(cls, name: str): + return cls._get_object(name) + + +class FunctionRegistryMixin(RegistryMixin): + @classmethod + def get_function(cls, name: str): + return cls._get_object(name) diff --git a/coremltools/optimize/torch/quantization/_annotation_config.py b/coremltools/optimize/torch/quantization/_annotation_config.py new file mode 100644 index 000000000..2a498b7f3 --- /dev/null +++ b/coremltools/optimize/torch/quantization/_annotation_config.py @@ -0,0 +1,109 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. 
+# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +from typing import Optional as _Optional + +import torch as _torch +import torch.ao.quantization as _aoquant +from attr import define as _define +from torch.ao.quantization.quantizer.quantizer import ( + QuantizationSpec as _TorchQuantizationSpec, +) + +from coremltools.optimize.torch.quantization.quantization_config import ( + ModuleLinearQuantizerConfig as _ModuleLinearQuantizerConfig, +) +from coremltools.optimize.torch.quantization.quantization_config import ObserverType as _ObserverType +from coremltools.optimize.torch.quantization.quantization_config import ( + QuantizationScheme as _QuantizationScheme, +) + + +@_define +class AnnotationConfig: + """ + Module/Operator level configuration class for :py:class:`CoreMLQuantizer`. + + For each module/operator, defines the dtype, quantization scheme and observer type + for input(s), output and weights (if any). + """ + + input_activation: _Optional[_TorchQuantizationSpec] = None + output_activation: _Optional[_TorchQuantizationSpec] = None + weight: _Optional[_TorchQuantizationSpec] = None + + @staticmethod + def _normalize_dtype(dtype: _torch.dtype) -> _torch.dtype: + """ + PyTorch export quantizer only supports uint8 and int8 data types, + so we map the quantized dtypes to the corresponding supported dtype. + """ + dtype_map = { + _torch.quint8: _torch.uint8, + _torch.qint8: _torch.int8, + } + return dtype_map.get(dtype, dtype) + + @classmethod + def from_quantization_config( + cls, + quantization_config: _Optional[_ModuleLinearQuantizerConfig], + ) -> _Optional["AnnotationConfig"]: + """ + Creates a :py:class:`AnnotationConfig` from ``ModuleLinearQuantizerConfig`` + """ + if ( + quantization_config is None + or quantization_config.weight_dtype == _torch.float32 + ): + return None + + # Activation QSpec + if quantization_config.activation_dtype == _torch.float32: + output_activation_qspec = None + else: + activation_qscheme = _QuantizationScheme.get_qscheme( + quantization_config.quantization_scheme, + is_per_channel=False, + ) + activation_dtype = cls._normalize_dtype( + quantization_config.activation_dtype + ) + output_activation_qspec = _TorchQuantizationSpec( + observer_or_fake_quant_ctr=_aoquant.FakeQuantize.with_args( + observer=_ObserverType.get_observer( + quantization_config.activation_observer, + is_per_channel=False, + ), + dtype=activation_dtype, + qscheme=activation_qscheme, + ), + dtype=activation_dtype, + qscheme=activation_qscheme, + ) + + # Weight QSpec + weight_qscheme = _QuantizationScheme.get_qscheme( + quantization_config.quantization_scheme, + is_per_channel=quantization_config.weight_per_channel, + ) + weight_dtype = cls._normalize_dtype(quantization_config.weight_dtype) + weight_qspec = _TorchQuantizationSpec( + observer_or_fake_quant_ctr=_aoquant.FakeQuantize.with_args( + observer=_ObserverType.get_observer( + quantization_config.weight_observer, + is_per_channel=quantization_config.weight_per_channel, + ), + dtype=weight_dtype, + qscheme=weight_qscheme, + ), + dtype=weight_dtype, + qscheme=weight_qscheme, + ) + return AnnotationConfig( + input_activation=output_activation_qspec, + output_activation=output_activation_qspec, + weight=weight_qspec, + ) diff --git a/coremltools/optimize/torch/quantization/_coreml_quantizer.py b/coremltools/optimize/torch/quantization/_coreml_quantizer.py new file mode 100644 index 000000000..c46bacf46 --- /dev/null +++ 
b/coremltools/optimize/torch/quantization/_coreml_quantizer.py @@ -0,0 +1,624 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import operator as _operator +from typing import Callable as _Callable +from typing import List as _List +from typing import Optional as _Optional + +import torch as _torch +from torch.ao.quantization.quantizer.quantizer import Quantizer as _TorchQuantizer +from torch.ao.quantization.quantizer.xnnpack_quantizer import _get_module_name_filter +from torch.fx import Node as _Node + +import coremltools.optimize.torch.quantization._coreml_quantizer_utils as _annotation_utils +from coremltools.optimize.torch._utils.python_utils import FunctionRegistryMixin as _FunctionRegistryMixin +from coremltools.optimize.torch.quantization._annotation_config import ( + AnnotationConfig as _AnnotationConfig, +) +from coremltools.optimize.torch.quantization.quantization_config import ( + LinearQuantizerConfig as _LinearQuantizerConfig, +) +from coremltools.optimize.torch.quantization.quantization_config import ( + ModuleLinearQuantizerConfig as _ModuleLinearQuantizerConfig, +) + + +class _AnnotationPatternRegistry(_FunctionRegistryMixin): + """ + A registry of quantization annotation rules. + """ + @classmethod + def get_annotators(cls): + return cls.REGISTRY + + +@_AnnotationPatternRegistry.register("conv_act") +def _annotate_conv_act( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates input -> conv -> activation -> output + """ + return _annotation_utils.annotate_conv_bn_act_helper( + model, quantization_config, filter_fn, use_bn=False + ) + + +@_AnnotationPatternRegistry.register("conv_bn_act") +def _annotate_conv_bn_act( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates input -> conv -> batch_norm -> activation -> output + """ + return _annotation_utils.annotate_conv_bn_act_helper( + model, quantization_config, filter_fn, use_bn=True + ) + + +@_AnnotationPatternRegistry.register("conv_bn") +def _annotate_conv_bn( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates input -> conv -> batch_norm -> output + """ + annotated_partitions = [] + + conv_dims = [1, 2, 3] + for conv_dim in conv_dims: + pattern_gm = _annotation_utils.get_conv_bn_pattern( + conv_dim, act_fn=None, act_in_place=False + ) + annotated_partitions.extend( + _annotation_utils.annotate_weighted_mod_pattern( + model, pattern_gm, quantization_config, filter_fn + ) + ) + return annotated_partitions + + +@_AnnotationPatternRegistry.register("conv") +def _annotate_conv( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates input -> conv -> output + """ + annotated_partitions = [] + for conv_dim in [1, 2, 3]: + pattern_gm = _annotation_utils.get_conv_pattern(conv_dim=conv_dim, act_fn=None) + annotated_partitions.extend( + _annotation_utils.annotate_weighted_mod_pattern( + model, 
pattern_gm, quantization_config, filter_fn + ) + ) + + return annotated_partitions + + +@_AnnotationPatternRegistry.register("linear_act") +def _annotate_linear_act( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates input -> linear -> activation -> output + """ + return _annotation_utils.annotate_linear_bn_act_helper( + model, quantization_config, filter_fn, use_bn=False + ) + + +@_AnnotationPatternRegistry.register("linear_bn_act") +def _annotate_linear_bn_act( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates input -> linear -> batch_norm -> activation -> output + """ + return _annotation_utils.annotate_linear_bn_act_helper( + model, quantization_config, filter_fn, use_bn=True + ) + + +@_AnnotationPatternRegistry.register("linear_bn") +def _annotate_linear_bn( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates input -> linear -> batch_norm -> output + """ + pattern_gm = _annotation_utils.get_linear_bn_pattern( + act_fn=None, act_in_place=False + ) + return _annotation_utils.annotate_weighted_mod_pattern( + model, pattern_gm, quantization_config, filter_fn + ) + + +@_AnnotationPatternRegistry.register("linear") +def _annotate_linear( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates input -> linear -> output + """ + pattern_gm = _annotation_utils.get_linear_pattern(act_fn=None) + return _annotation_utils.annotate_weighted_mod_pattern( + model, pattern_gm, quantization_config, filter_fn + ) + + +@_AnnotationPatternRegistry.register("add_act") +def _annotate_add_act( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates + + input_1 --- + \ + --> add -> activation -> output + / + input_2 --- + """ + ops = [_operator.add, _torch.add, _operator.iadd] + return _annotation_utils.annotate_binary_op_helper( + model, ops, quantization_config, filter_fn + ) + + +@_AnnotationPatternRegistry.register("add") +def _annotate_add( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates + + input_1 --- + \ + --> add -> output + / + input_2 --- + """ + annotated_partitions = [] + ops = [_operator.add, _torch.add, _operator.iadd] + for binary_op in ops: + pattern_gm = _annotation_utils.get_binary_op_act_pattern(binary_op, None) + annotated_partitions.extend( + _annotation_utils.annotate_binary_op_act_pattern( + model, pattern_gm, quantization_config, filter_fn + ) + ) + return annotated_partitions + + +@_AnnotationPatternRegistry.register("mul_act") +def _annotate_mul_act( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates + + input_1 --- + \ + --> mul -> activation -> output + / + input_2 --- + """ + 
ops = [_operator.mul, _torch.mul, _operator.imul] + return _annotation_utils.annotate_binary_op_helper( + model, ops, quantization_config, filter_fn + ) + + +@_AnnotationPatternRegistry.register("mul") +def _annotate_mul( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates + + input_1 --- + \ + --> mul -> output + / + input_2 --- + """ + annotated_partitions = [] + ops = [_operator.mul, _torch.mul, _operator.imul] + for binary_op in ops: + pattern_gm = _annotation_utils.get_binary_op_act_pattern(binary_op, None) + annotated_partitions.extend( + _annotation_utils.annotate_binary_op_act_pattern( + model, pattern_gm, quantization_config, filter_fn + ) + ) + return annotated_partitions + + +@_AnnotationPatternRegistry.register("matmul_act") +def _annotate_matmul_act( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates + + input_1 --- + \ + --> matmul -> activation -> output + / + input_2 --- + """ + return _annotation_utils.annotate_binary_op_helper( + model, [_torch.matmul], quantization_config, filter_fn + ) + + +@_AnnotationPatternRegistry.register("matmul") +def _annotate_matmul( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates + + input_1 --- + \ + --> matmul -> output + / + input_2 --- + """ + pattern_gm = _annotation_utils.get_binary_op_act_pattern(_torch.matmul, None) + return _annotation_utils.annotate_binary_op_act_pattern( + model, pattern_gm, quantization_config, filter_fn + ) + + +@_AnnotationPatternRegistry.register("max_pool1d") +def _annotate_max_pool1d( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates + + input -> max_pool1d -> output + """ + return _annotation_utils.annotate_unary_shared_observer_ops( + model, + [_torch.nn.MaxPool1d, _torch.nn.functional.max_pool1d, _torch.max_pool1d], + quantization_config, + filter_fn, + ) + + +@_AnnotationPatternRegistry.register("max_pool2d") +def _annotate_max_pool2d( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates + + input -> max_pool2d -> output + """ + return _annotation_utils.annotate_unary_shared_observer_ops( + model, + [_torch.nn.MaxPool2d, _torch.nn.functional.max_pool2d, _torch.max_pool2d], + quantization_config, + filter_fn, + ) + + +@_AnnotationPatternRegistry.register("max_pool3d") +def _annotate_max_pool3d( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates + + input -> max_pool3d -> output + """ + return _annotation_utils.annotate_unary_shared_observer_ops( + model, + [_torch.nn.MaxPool3d, _torch.nn.functional.max_pool3d, _torch.max_pool3d], + quantization_config, + filter_fn, + ) + + +@_AnnotationPatternRegistry.register("adaptive_avg_pool1d") +def _annotate_adaptive_avg_pool1d( + model: _torch.fx.GraphModule, + quantization_config: 
_Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates + + input -> adaptive_avg_pool1d -> output + """ + return _annotation_utils.annotate_unary_shared_observer_ops( + model, + [ + _torch.nn.AdaptiveAvgPool1d, + _torch.nn.functional.adaptive_avg_pool1d, + _torch.adaptive_avg_pool1d, + ], + quantization_config, + filter_fn, + ) + + +@_AnnotationPatternRegistry.register("adaptive_avg_pool2d") +def _annotate_adaptive_avg_pool2d( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates + + input -> adaptive_avg_pool2d -> output + """ + return _annotation_utils.annotate_unary_shared_observer_ops( + model, + [_torch.nn.AdaptiveAvgPool2d, _torch.nn.functional.adaptive_avg_pool2d], + quantization_config, + filter_fn, + ) + + +@_AnnotationPatternRegistry.register("adaptive_avg_pool3d") +def _annotate_adaptive_avg_pool3d( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates + + input -> adaptive_avg_pool3d -> output + """ + return _annotation_utils.annotate_unary_shared_observer_ops( + model, + [_torch.nn.AdaptiveAvgPool3d, _torch.nn.functional.adaptive_avg_pool3d], + quantization_config, + filter_fn, + ) + + +@_AnnotationPatternRegistry.register("avg_pool1d") +def _annotate_avg_pool1d( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates + + input -> avg_pool1d -> output + """ + return _annotation_utils.annotate_unary_shared_observer_ops( + model, + [ + _torch.nn.AvgPool1d, + _torch.nn.functional.avg_pool1d, + _torch.avg_pool1d, + _torch.mean, + ], + quantization_config, + filter_fn, + ) + + +@_AnnotationPatternRegistry.register("avg_pool2d") +def _annotate_avg_pool2d( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates + + input -> avg_pool2d -> output + """ + return _annotation_utils.annotate_unary_shared_observer_ops( + model, + [ + _torch.nn.AvgPool2d, + _torch.nn.functional.avg_pool2d, + _torch._C._nn.avg_pool2d, + ], + quantization_config, + filter_fn, + ) + + +@_AnnotationPatternRegistry.register("avg_pool3d") +def _annotate_avg_pool3d( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates + + input -> avg_pool3d -> output + """ + return _annotation_utils.annotate_unary_shared_observer_ops( + model, + [ + _torch.nn.AvgPool3d, + _torch.nn.functional.avg_pool3d, + _torch._C._nn.avg_pool3d, + ], + quantization_config, + filter_fn, + ) + + +@_AnnotationPatternRegistry.register("flatten") +def _annotate_flatten( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates + + input -> flatten -> output + """ + return _annotation_utils.annotate_unary_shared_observer_ops( + model, + [ + _torch.nn.Flatten, + _torch.flatten, + ], + quantization_config, + filter_fn, + ) + 
+
+class CoreMLQuantizer(_TorchQuantizer):
+    """
+    Annotates all recognized patterns using ``config``.
+
+    Extends :py:class:`torch.ao.quantization.quantizer.quantizer.Quantizer` to
+    add support for quantization patterns supported by Core ML.
+
+    Use it in conjunction with the PyTorch 2.0 ``prepare_pt2e`` and ``prepare_qat_pt2e`` APIs
+    for post-training weight and activation quantization using calibration data and
+    for quantization-aware training (QAT).
+
+    Example:
+
+        .. code-block:: python
+
+            from collections import OrderedDict
+
+            import torch
+            import torch.nn as nn
+            from torch._export import capture_pre_autograd_graph
+            from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_qat_pt2e
+
+            from coremltools.optimize.torch.quantization._coreml_quantizer import CoreMLQuantizer
+            from coremltools.optimize.torch.quantization.quantization_config import (
+                LinearQuantizerConfig,
+            )
+
+            model = nn.Sequential(
+                OrderedDict(
+                    {
+                        "conv": nn.Conv2d(1, 20, (3, 3)),
+                        "relu1": nn.ReLU(),
+                        "conv2": nn.Conv2d(20, 20, (3, 3)),
+                        "relu2": nn.ReLU(),
+                    }
+                )
+            )
+
+            loss_fn = define_loss()
+
+            # initialize the annotator with the quantization config
+            config = LinearQuantizerConfig.from_dict(
+                {
+                    "global_config": {
+                        "quantization_scheme": "symmetric",
+                    }
+                }
+            )
+            quantizer = CoreMLQuantizer(config)
+
+            example_inputs = torch.randn(1, 1, 28, 28)
+
+            # create the export graph
+            exported_model = capture_pre_autograd_graph(model, (example_inputs,))
+
+            # prepare the model to insert FakeQuantize layers for QAT
+            prepared_model = prepare_qat_pt2e(exported_model, quantizer)
+
+            # use the prepared model in your PyTorch training loop
+            for inputs, labels in data:
+                output = prepared_model(inputs)
+                loss = loss_fn(output, labels)
+                loss.backward()
+                optimizer.step()
+                # turn observers/quantizers on/off depending on the iteration number
+
+            # convert operations to their quantized counterparts using parameters learned via QAT
+            model = convert_pt2e(prepared_model)
+    """
+
+    def __init__(self, config: _Optional[_LinearQuantizerConfig]):
+        self._config = config
+
+    def _annotate_all_patterns(
+        self,
+        model: _torch.fx.GraphModule,
+        quantization_config: _Optional[_ModuleLinearQuantizerConfig],
+        filter_fn: _Optional[_Callable[[_Node], bool]] = None,
+    ):
+        annotators = _AnnotationPatternRegistry.get_annotators()
+        for _, annotator in annotators.items():
+            annotation_config = _AnnotationConfig.from_quantization_config(
+                quantization_config
+            )
+            annotator(model, annotation_config, filter_fn)
+
+    def annotate(self, model: _torch.fx.GraphModule) -> _torch.fx.GraphModule:
+        # First, annotate all modules/operations which have name-based configs
+        module_name_list = list(self._config.module_name_configs.keys())
+        for module_name, config in self._config.module_name_configs.items():
+            self._annotate_all_patterns(
+                model, config, _get_module_name_filter(module_name)
+            )
+
+        # Next, annotate all modules/operations which have type-based configs
+        tp_list = list(self._config.module_type_configs.keys())
+        for module_type, config in self._config.module_type_configs.items():
+            self._annotate_all_patterns(
+                model, config, _annotation_utils.get_object_type_filter(module_type)
+            )
+
+        # Annotate all other modules/operations
+        self._annotate_all_patterns(
+            model,
+            self._config.global_config,
+            _annotation_utils.get_not_object_type_or_name_filter(
+                tp_list, module_name_list
+            ),
+        )
+        return model
+
+    def validate(self, model: _torch.fx.GraphModule) -> None:
+        pass
diff --git a/coremltools/optimize/torch/quantization/_coreml_quantizer_utils.py b/coremltools/optimize/torch/quantization/_coreml_quantizer_utils.py
new file mode 100644
index 000000000..f03b9a6c2
--- /dev/null
+++ 
b/coremltools/optimize/torch/quantization/_coreml_quantizer_utils.py @@ -0,0 +1,767 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import itertools as _itertools +from typing import Callable as _Callable +from typing import List as _List +from typing import Optional as _Optional +from typing import Tuple as _Tuple + +import torch as _torch +import torch.nn.functional as _F +from torch.ao.quantization.pt2e.utils import ( + get_aten_graph_module as _get_aten_graph_module, +) +from torch.ao.quantization.quantizer.quantizer import ( + FixedQParamsQuantizationSpec as _FixedQParamsQuantizationSpec, +) +from torch.ao.quantization.quantizer.quantizer import ( + QuantizationAnnotation as _QuantizationAnnotation, +) +from torch.ao.quantization.quantizer.quantizer import ( + QuantizationSpec as _TorchQuantizationSpec, +) +from torch.ao.quantization.quantizer.quantizer import ( + QuantizationSpecBase as _TorchQuantizationSpecBase, +) +from torch.ao.quantization.quantizer.quantizer import ( + SharedQuantizationSpec as _SharedQuantizationSpec, +) +from torch.ao.quantization.quantizer.xnnpack_quantizer import _get_module_name_filter +from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( + _is_annotated, + _mark_nodes_as_annotated, +) +from torch.fx import Node as _Node +from torch.fx.passes.utils.matcher_with_name_node_map_utils import ( + SubgraphMatcherWithNameNodeMap as _SubgraphMatcherWithNameNodeMap, +) +from torch.fx.passes.utils.source_matcher_utils import ( + get_source_partitions as _get_source_partitions, +) + +from coremltools.optimize.torch.quantization._annotation_config import ( + AnnotationConfig as _AnnotationConfig, +) + +# All activations recognized for conv-act/conv-bn-act patterns +_supported_activations = ( + _F.relu, + _F.relu6, + _F.leaky_relu, + _F.silu, + _F.elu, + _F.celu, + _F.selu, + _F.mish, + _F.hardtanh, + _F.hardswish, + _F.hardsigmoid, +) + + +# These activation functions don't have an inplace argument +_supported_activations_no_inplace = (_F.gelu, _F.sigmoid, _F.logsigmoid, _F.tanh) + + +# Map of dimension to convolution function +_conv_fn_map = {1: _F.conv1d, 2: _F.conv2d, 3: _F.conv3d} + + +def _adjust_activation_qspec( + node: _torch.fx.Node, qspec: _Optional[_TorchQuantizationSpecBase] +) -> _Optional[_TorchQuantizationSpecBase]: + """ + Adjust quantization spec for ops which can use fixed qparams + or ops for which we can use affine quantization mode during + symmetric quantization because their output is always positive. 
+ """ + if qspec is None: + return qspec + + tanh_qspec = _FixedQParamsQuantizationSpec( + dtype=_torch.uint8, + scale=2.0 / 256.0, + zero_point=128, + quant_min=0, + quant_max=255, + qscheme=_torch.per_tensor_symmetric, + ) + + sigmoid_qspec = _FixedQParamsQuantizationSpec( + dtype=_torch.uint8, + scale=1.0 / 256.0, + zero_point=0, + quant_min=0, + quant_max=255, + qscheme=_torch.per_tensor_affine, + ) + + fixed_q_params_ops = { + _torch.ops.aten.tanh.default: tanh_qspec, + _torch.ops.aten.tanh_.default: tanh_qspec, + _torch.ops.aten.sigmoid.default: sigmoid_qspec, + _torch.ops.aten.sigmoid_.default: sigmoid_qspec, + _torch.ops.aten.hardsigmoid.default: sigmoid_qspec, + _torch.ops.aten.hardsigmoid_.default: sigmoid_qspec, + } + + always_affine_ops = ( + _torch.ops.aten.relu.default, + _torch.ops.aten.relu_.default, + _torch.ops.aten.relu6.default, + _torch.ops.aten.relu6_.default, + ) + + # ReLU6 activation maps to _torch.ops.aten.hardtanh.default with + # min_val = 0 and max_val = 6 + is_always_affine_op = node.target in always_affine_ops or ( + node.target + in [_torch.ops.aten.hardtanh.default, _torch.ops.aten.hardtanh_.default] + and node.args[1] == 0 # min_val, corresponding to ReLU6 + and node.args[2] == 6 # max_val, corresponding to ReLU6 + ) + + if node.target in fixed_q_params_ops: + return _TorchQuantizationSpec( + observer_or_fake_quant_ctr=qspec.observer_or_fake_quant_ctr, + dtype=qspec.dtype, + qscheme=fixed_q_params_ops[node.target].qscheme, + ) + # FIXME: Because of a bug in PyTorch in function _create_obs_or_fq_from_qspec + # in module torch/ao/quantization/fx/prepare.py which creates a + # FixedQParamsFakeQuantize partial, instead of an instance, we cannot + # actually create FixedQParamsQuantizationSpec + if is_always_affine_op: + return _TorchQuantizationSpec( + observer_or_fake_quant_ctr=qspec.observer_or_fake_quant_ctr, + dtype=qspec.dtype, + qscheme=_torch.per_tensor_affine, + ) + return qspec + + +def get_object_type_filter(tp: _Callable): + """ + Returns a filter which returns True if a node in the final exported graph + was created because of an object of type ``tp`` + """ + + def object_type_filter(n: _Node) -> bool: + # example: { + # 'add_10': ('add', ) + # } + nn_module_stack = n.meta.get("nn_module_stack", {}) + types = [t for _, t in nn_module_stack.values()] + source_fn_stack = n.meta.get("source_fn_stack", {}) + types.extend([t for _, t in source_fn_stack]) + return tp in types + + return object_type_filter + + +def get_not_object_type_or_name_filter( + tp_list: _List[_Callable], module_name_list: _List[str] +) -> _Callable[[_Node], bool]: + """ + Returns a filter which returns True if a node in the final exported graph + was not created using any modules with names in ``module_name_list`` or objects with + type in ``tp_list``. 
+ """ + object_type_filters = [get_object_type_filter(tp) for tp in tp_list] + module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list] + + def not_object_type_or_name_filter(n: _Node) -> bool: + return not any(f(n) for f in object_type_filters + module_name_list_filters) + + return not_object_type_or_name_filter + + +def _get_weighted_mod_pattern( + mod_fn: _Callable, + example_inputs: _Tuple[_torch.Tensor, ...], + act_fn: _Optional[_Callable] = None, + act_in_place: bool = False, +) -> _torch.nn.Module: + """ + Returns an aten graph corresponding to a sequence of these ops: + input -> weighted_mod -> activation -> output + + A weighted mod is a module which has a weight and bias, such as a + convolution module or a linear module. + + No activation is used if ``act_fn`` is ``None``. + ``act_fn`` is an activation function from _supported_activations or + _supported_activations_no_inplace + """ + + def pattern(input, weight, bias): + mod_out = mod_fn(input, weight, bias) + output = mod_out + node_dict = { + "input": input, + "mod": mod_out, + "weight": weight, + "bias": bias, + } + if act_fn is not None: + # Only add output if activation function is applied to model output + output = act_fn(output, inplace=True) if act_in_place else act_fn(output) + node_dict["output"] = output + return output, node_dict + + return _get_aten_graph_module(pattern, example_inputs, is_cuda=False) + + +def _get_weighted_mod_bn_pattern( + mod_fn: _Callable, + example_inputs: _Tuple[_torch.Tensor, ...], + act_fn: _Optional[_Callable] = None, + act_in_place: bool = False, +) -> _torch.nn.Module: + """ + Returns an aten graph corresponding to a sequence of these ops: + input -> weighted_mod -> batch_norm -> activation -> output + + A weighted mod is a module which has a weight and bias, such as a + convolution module or a linear module. + + No activation is used if ``act_fn`` is ``None``. + ``act_fn`` is an activation function from _supported_activations or + _supported_activations_no_inplace + """ + + def pattern(input, weight, bias, bn_weight, bn_bias, bn_run_mean, bn_run_var): + mod_out = mod_fn(input, weight, bias) + output = _F.batch_norm( + mod_out, bn_run_mean, bn_run_var, bn_weight, bn_bias, training=True + ) + if act_fn is not None: + output = act_fn(output, inplace=True) if act_in_place else act_fn(output) + return output, { + "input": input, + "mod": mod_out, + "weight": weight, + "bias": bias, + "output": output, + } + + return _get_aten_graph_module(pattern, example_inputs, is_cuda=False) + + +def get_binary_op_act_pattern( + binary_op: _Callable, + act_fn: _Optional[_Callable] = None, + act_in_place: bool = False, +) -> _torch.nn.Module: + """ + Returns an aten graph corresponding to a sequence of these ops: + input_1 --- + \ + --> binary_op -> activation -> output + / + input_2 --- + + A binary op is any operation which consumes two inputs to create one output, + such as addition or multiplication. + + No activation is used if ``act_fn`` is ``None``. 
+ ``act_fn`` is an activation function from _supported_activations or + _supported_activations_no_inplace + """ + + def pattern(input_1, input_2): + binary_op_out = binary_op(input_1, input_2) + node_dict = { + "binary_op": binary_op_out, + } + output = binary_op_out + if act_fn is not None: + output = act_fn(output, inplace=True) if act_in_place else act_fn(output) + node_dict["output"] = output + return output, node_dict + + example_inputs = (_torch.randn(1), _torch.randn(1)) + return _get_aten_graph_module(pattern, example_inputs, is_cuda=False) + + +def get_conv_pattern( + conv_dim: int, act_fn: _Optional[_Callable] = None, act_in_place: bool = False +) -> _torch.nn.Module: + """ + Returns an aten graph corresponding to a sequence of these ops: + input -> conv -> activation -> output + + No activation is used if ``act_fn`` is ``None``. + ``act_fn`` is an activation function from _supported_activations or + _supported_activations_no_inplace + """ + assert ( + conv_dim in _conv_fn_map + ), f"Dimension {conv_dim} is not supported for Convolution layers." + + example_inputs = ( + _torch.randn(1, 1, *[3] * conv_dim), # input + _torch.randn(1, 1, *[1] * conv_dim), # conv weight + _torch.randn(1), # conv bias + ) + return _get_weighted_mod_pattern( + _conv_fn_map[conv_dim], example_inputs, act_fn, act_in_place + ) + + +def get_conv_bn_pattern( + conv_dim: int, act_fn: _Optional[_Callable] = None, act_in_place: bool = False +) -> _torch.nn.Module: + """ + Returns an aten graph corresponding to a sequence of these ops: + input -> conv -> batch_norm -> activation -> output + + No activation is used if ``act_fn`` is ``None``. + ``act_fn`` is an activation function from _supported_activations or + _supported_activations_no_inplace + """ + assert ( + conv_dim in _conv_fn_map + ), f"Dimension {conv_dim} is not supported for Convolution layers." + + example_inputs = ( + _torch.randn(1, 1, *[3] * conv_dim), # input + _torch.randn(1, 1, *[1] * conv_dim), # conv weight + _torch.randn(1), # conv bias + _torch.randn(1), # bn_weight + _torch.randn(1), # bn_bias + _torch.randn(1), # bn_run_mean + _torch.randn(1), # bn_run_var + ) + return _get_weighted_mod_bn_pattern( + _conv_fn_map[conv_dim], example_inputs, act_fn, act_in_place + ) + + +def get_linear_pattern( + act_fn: _Optional[_Callable] = None, act_in_place: bool = False +) -> _torch.nn.Module: + """ + Returns an aten graph corresponding to a sequence of these ops: + input -> linear -> activation -> output + + No activation is used if ``act_fn`` is ``None``. + ``act_fn`` is an activation function from _supported_activations or + _supported_activations_no_inplace + """ + example_inputs = ( + _torch.randn(1, 1), # input + _torch.randn(3, 1), # linear weight + _torch.randn(3), # linear bias + ) + return _get_weighted_mod_pattern(_F.linear, example_inputs, act_fn, act_in_place) + + +def get_linear_bn_pattern( + act_fn: _Optional[_Callable] = None, act_in_place: bool = False +) -> _torch.nn.Module: + """ + Returns an aten graph corresponding to a sequence of these ops: + input -> linear -> batch_norm -> activation -> output + + No activation is used if ``act_fn`` is ``None``. 
+ ``act_fn`` is an activation function from _supported_activations or + _supported_activations_no_inplace + """ + example_inputs = ( + _torch.randn(2, 1), # input + _torch.randn(3, 1), # linear weight + _torch.randn(3), # linear bias + _torch.randn(3), # bn_weight + _torch.randn(3), # bn_bias + _torch.randn(3), # bn_run_mean + _torch.randn(3), # bn_run_var + ) + return _get_weighted_mod_bn_pattern(_F.linear, example_inputs, act_fn, act_in_place) + + +def annotate_weighted_mod_pattern( + model: _torch.fx.GraphModule, + pattern_gm: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]], +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates all nodes in ``model``, which match the pattern specified by ``pattern_gm`` + using ``quantization_config``. + + ``pattern_gm`` captures patterns of the following type: + + input -> weighted_mod -> batch_norm -> activation -> output + + batch_norm and activation may or may not be applied in the pattern. + + Only annotates those patterns in which all nodes return True when ``filter_fn`` is applied + to them. + """ + model.graph.eliminate_dead_code() + model.recompile() + + matcher = _SubgraphMatcherWithNameNodeMap(pattern_gm, ignore_literals=True) + matches = matcher.match(model.graph) + + annotated_partitions = [] + for match in matches: + name_node_map = match.name_node_map + input_node = name_node_map["input"] + mod_node = name_node_map["mod"] + weight_node = name_node_map["weight"] + bias_node = name_node_map["bias"] + if "output" in name_node_map: + # In this case, an activation is applied to the weighted module output + output_node = name_node_map["output"] + # If the output is same as mod_node, it means we have an inplace activation, + # so we need to correct the mod_node + if mod_node == output_node: + mod_node = mod_node.args[0] + else: + output_node = None + + # Validate mod args + if mod_node.args[0] is not input_node: + raise ValueError( + f"Weighted module arg did not contain input node ", input_node + ) + if mod_node.args[1] is not weight_node: + raise ValueError( + f"Weighted module arg did not contain weight node ", weight_node + ) + if len(mod_node.args) > 2 and mod_node.args[2] is not bias_node: + raise ValueError( + f"Weighted module arg did not contain bias node ", bias_node + ) + + # Skip if the partition is already annotated or is filtered out by the user + partition = [mod_node, weight_node] + if bias_node is not None: + partition.append(bias_node) + if _is_annotated(partition): + continue + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + # Annotate conv inputs and pattern output + input_qspec_map = dict() + if not _is_annotated([input_node]): + input_qspec_map[input_node] = ( + quantization_config.input_activation if quantization_config else None + ) + else: + input_qspec_map[input_node] = input_node.meta[ + "quantization_annotation" + ].output_qspec + + input_qspec_map[weight_node] = ( + quantization_config.weight if quantization_config else None + ) + output_qspec = ( + quantization_config.output_activation if quantization_config else None + ) + if bias_node is not None: + input_qspec_map[bias_node] = None + + if output_node is None: + mod_node.meta["quantization_annotation"] = _QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_qspec, + _annotated=True, + ) + else: + mod_node.meta["quantization_annotation"] = _QuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + ) + if not 
_is_annotated([output_node]):
+                output_qspec = _adjust_activation_qspec(
+                    node=output_node, qspec=output_qspec
+                )
+                output_node.meta["quantization_annotation"] = _QuantizationAnnotation(
+                    output_qspec=output_qspec,
+                    _annotated=True,
+                )
+
+        _mark_nodes_as_annotated(partition)
+        annotated_partitions.append(partition)
+    return annotated_partitions
+
+
+def annotate_binary_op_act_pattern(
+    model: _torch.fx.GraphModule,
+    pattern_gm: _torch.fx.GraphModule,
+    quantization_config: _Optional[_AnnotationConfig],
+    filter_fn: _Optional[_Callable[[_Node], bool]] = None,
+) -> _Optional[_List[_List[_Node]]]:
+    """
+    Annotates all nodes in ``model``, which match the pattern specified by ``pattern_gm``
+    using ``quantization_config``.
+
+    ``pattern_gm`` captures patterns of the following type:
+
+    input_1 ---
+               \
+                --> binary_op -> activation -> output
+               /
+    input_2 ---
+
+    activation may or may not be applied in the pattern.
+
+    Only annotates those patterns in which all nodes return True when ``filter_fn`` is applied
+    to them.
+    """
+    model.graph.eliminate_dead_code()
+    model.recompile()
+
+    matcher = _SubgraphMatcherWithNameNodeMap(pattern_gm, ignore_literals=True)
+    matches = matcher.match(model.graph)
+
+    annotated_partitions = []
+    for match in matches:
+        name_node_map = match.name_node_map
+        binary_op_node: _Node = name_node_map["binary_op"]
+        if "output" in name_node_map:
+            # In this case, an activation is applied to the binary op output
+            output_node = name_node_map["output"]
+            # If the output is the same as binary_op_node, it means we have an inplace activation,
+            # so we need to correct the binary_op_node
+            if binary_op_node == output_node:
+                binary_op_node = binary_op_node.args[0]
+            partition = [output_node, binary_op_node]
+        else:
+            output_node = None
+            partition = [binary_op_node]
+
+        if output_node is not None and len(binary_op_node.users) > 1:
+            raise ValueError("Binary op with activation has more than one user.")
+
+        if _is_annotated(partition):
+            continue
+
+        if filter_fn and any(not filter_fn(n) for n in partition):
+            continue
+
+        input_act_qspec = (
+            quantization_config.input_activation if quantization_config else None
+        )
+        output_act_qspec = (
+            quantization_config.output_activation if quantization_config else None
+        )
+
+        input_qspec_map = {}
+        input_act0 = binary_op_node.args[0]
+        if isinstance(input_act0, _Node):
+            input_qspec_map[input_act0] = input_act_qspec
+
+        input_act1 = binary_op_node.args[1]
+        if isinstance(input_act1, _Node):
+            input_qspec_map[input_act1] = input_act_qspec
+
+        if output_node is None:
+            binary_op_node.meta["quantization_annotation"] = _QuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                output_qspec=output_act_qspec,
+                _annotated=True,
+            )
+        else:
+            binary_op_node.meta["quantization_annotation"] = _QuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                _annotated=True,
+            )
+            output_act_qspec = _adjust_activation_qspec(
+                node=output_node, qspec=output_act_qspec
+            )
+            output_node.meta["quantization_annotation"] = _QuantizationAnnotation(
+                output_qspec=output_act_qspec,
+                _annotated=True,
+            )
+        _mark_nodes_as_annotated(partition)
+        annotated_partitions.append(partition)
+    return annotated_partitions
+
+
+def annotate_unary_shared_observer_ops(
+    model: _torch.fx.GraphModule,
+    ops: _List[_Callable],
+    quantization_config: _Optional[_AnnotationConfig],
+    filter_fn: _Optional[_Callable[[_Node], bool]] = None,
+) -> _Optional[_List[_List[_Node]]]:
+    """
+    Annotates all nodes in ``model``, 
which correspond to unary ops specified in ``ops``. + + input --> op --> output + + input and output nodes share the same quantization parameters. + """ + partitions = _get_source_partitions(model.graph, ops, filter_fn) + annotated_partitions = [] + for _, op_partitions in partitions.items(): + for partition in op_partitions: + output_node = partition.output_nodes[0] + op_node = partition.nodes[0] + if _is_annotated([output_node, op_node]): + continue + + input_node = op_node.args[0] + + input_act_qspec = ( + quantization_config.input_activation if quantization_config else None + ) + output_act_qspec = ( + quantization_config.output_activation if quantization_config else None + ) + + if ( + "quantization_annotation" not in input_node.meta + or not input_node.meta["quantization_annotation"]._annotated + or input_node.meta["quantization_annotation"].output_qspec is None + or input_act_qspec is None + or output_act_qspec is None + ): + continue + + # input and output of op will share quantization parameter with input of op + act_qspec = _SharedQuantizationSpec(input_node) + op_node.meta["quantization_annotation"] = _QuantizationAnnotation( + input_qspec_map={ + input_node: act_qspec, + }, + _annotated=True, + ) + output_node.meta["quantization_annotation"] = _QuantizationAnnotation( + output_qspec=act_qspec, + _annotated=True, + ) + annotated_partitions.append(partition.nodes) + return annotated_partitions + + +def annotate_conv_bn_act_helper( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, + use_bn: bool = False, +) -> _Optional[_List[_List[_Node]]]: + """ + A helper function for annotating all patterns involving convolution operations, i.e., + + input -> conv -> batch_norm -> activation -> output + + conv can be either 1D, 2D or 3D + batch_norm and activation may or may not be applied. + """ + annotated_partitions = [] + + pattern_map = { + True: get_conv_bn_pattern, + False: get_conv_pattern, + } + + conv_dims = [1, 2, 3] + combinations = _itertools.product(conv_dims, _supported_activations, [True, False]) + for conv_dim, act_fn, act_in_place in combinations: + pattern_gm = pattern_map[use_bn](conv_dim, act_fn, act_in_place) + annotated_partitions.extend( + annotate_weighted_mod_pattern( + model, pattern_gm, quantization_config, filter_fn + ) + ) + + combinations = _itertools.product(conv_dims, _supported_activations_no_inplace) + for conv_dim, act_fn in combinations: + pattern_gm = pattern_map[use_bn](conv_dim, act_fn, act_in_place=False) + annotated_partitions.extend( + annotate_weighted_mod_pattern( + model, pattern_gm, quantization_config, filter_fn + ) + ) + + return annotated_partitions + + +def annotate_linear_bn_act_helper( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, + use_bn: bool = False, +) -> _Optional[_List[_List[_Node]]]: + """ + A helper function for annotating all patterns involving linear operations, i.e., + + input -> linear -> batch_norm -> activation -> output + + batch_norm and activation may or may not be applied. 
+ """ + annotated_partitions = [] + + pattern_map = { + True: get_linear_bn_pattern, + False: get_linear_pattern, + } + + combinations = _itertools.product(_supported_activations, [True, False]) + for act_fn, act_in_place in combinations: + pattern_gm = pattern_map[use_bn](act_fn, act_in_place) + annotated_partitions.extend( + annotate_weighted_mod_pattern( + model, pattern_gm, quantization_config, filter_fn + ) + ) + + for act_fn in _supported_activations_no_inplace: + pattern_gm = pattern_map[use_bn](act_fn, act_in_place=False) + annotated_partitions.extend( + annotate_weighted_mod_pattern( + model, pattern_gm, quantization_config, filter_fn + ) + ) + + return annotated_partitions + + +def annotate_binary_op_helper( + model: _torch.fx.GraphModule, + binary_ops: _List[_Callable], + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + A helper function for annotating all patterns involving binary operations, i.e., + using ``quantization_config``. + + input_1 --- + \ + --> binary_op -> activation -> output + / + input_2 --- + + activation may or may not be applied in the pattern. + """ + annotated_partitions = [] + + combinations = _itertools.product(binary_ops, _supported_activations, [True, False]) + for binary_op, act_fn, act_in_place in combinations: + pattern_gm = get_binary_op_act_pattern(binary_op, act_fn, act_in_place) + annotated_partitions.extend( + annotate_binary_op_act_pattern( + model, pattern_gm, quantization_config, filter_fn + ) + ) + + combinations = _itertools.product(binary_ops, _supported_activations_no_inplace) + for binary_op, act_fn in combinations: + pattern_gm = get_binary_op_act_pattern(binary_op, act_fn, act_in_place=False) + annotated_partitions.extend( + annotate_binary_op_act_pattern( + model, pattern_gm, quantization_config, filter_fn + ) + ) + + return annotated_partitions diff --git a/coremltools/test/optimize/torch/quantization/test_coreml_quantizer.py b/coremltools/test/optimize/torch/quantization/test_coreml_quantizer.py new file mode 100644 index 000000000..85b42e54a --- /dev/null +++ b/coremltools/test/optimize/torch/quantization/test_coreml_quantizer.py @@ -0,0 +1,219 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. 
+# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +from collections import OrderedDict +from typing import Dict, Optional + +import pytest +import torch +import torch.nn as nn +from torch._export import capture_pre_autograd_graph +from torch.ao.quantization.quantize_pt2e import ( + convert_pt2e, + prepare_pt2e, + prepare_qat_pt2e, +) +from torch.fx import Node + +from coremltools.optimize.torch.quantization._coreml_quantizer import CoreMLQuantizer +from coremltools.optimize.torch.quantization.quantization_config import ( + LinearQuantizerConfig, + QuantizationScheme, +) + +activations = { + nn.ReLU: { + True: torch.ops.aten.relu_.default, + False: torch.ops.aten.relu.default, + }, + nn.ReLU6: { + True: torch.ops.aten.hardtanh_.default, + False: torch.ops.aten.hardtanh.default, + }, + nn.LeakyReLU: { + True: torch.ops.aten.leaky_relu_.default, + False: torch.ops.aten.leaky_relu.default, + }, + nn.SiLU: { + True: torch.ops.aten.silu_.default, + False: torch.ops.aten.silu.default, + }, + nn.ELU: { + True: torch.ops.aten.elu_.default, + False: torch.ops.aten.elu.default, + }, + nn.CELU: { + True: torch.ops.aten.celu_.default, + False: torch.ops.aten.celu.default, + }, + nn.SELU: { + True: torch.ops.aten.selu_.default, + False: torch.ops.aten.selu.default, + }, + nn.Mish: { + True: torch.ops.aten.mish_.default, + False: torch.ops.aten.mish.default, + }, + nn.Hardtanh: { + True: torch.ops.aten.hardtanh_.default, + False: torch.ops.aten.hardtanh.default, + }, + nn.Hardswish: { + True: torch.ops.aten.hardswish_.default, + False: torch.ops.aten.hardswish.default, + }, + nn.Hardsigmoid: { + True: torch.ops.aten.hardsigmoid_.default, + False: torch.ops.aten.hardsigmoid.default, + }, + nn.GELU: { + False: torch.ops.aten.gelu.default, + }, + nn.Sigmoid: { + False: torch.ops.aten.sigmoid.default, + }, + nn.LogSigmoid: { + False: torch.ops.aten.log_sigmoid.default, + }, + nn.Tanh: { + False: torch.ops.aten.tanh.default, + }, +} + + +@pytest.fixture(scope="module") +def model_for_quant() -> torch.nn.Module: + model_dict = OrderedDict() + activation_dict = {} + idx = 0 + start_idx = idx + for act_fn in activations: + for inplace in activations[act_fn].keys(): + inp_channels = 1 if idx == start_idx else 20 + model_dict[f"conv_{idx}"] = torch.nn.Conv2d( + inp_channels, 20, (3, 3), padding=(1, 1) + ) + model_dict[f"act_{idx}"] = act_fn(inplace=inplace) if inplace else act_fn() + activation_dict[idx] = activations[act_fn][inplace] + idx += 1 + model_dict[f"conv_{idx}"] = torch.nn.Conv2d(20, 20, (3, 3), padding=(1, 1)) + model_dict[f"bn_{idx}"] = nn.BatchNorm2d(20) + model_dict[f"act_{idx}"] = act_fn(inplace=inplace) if inplace else act_fn() + activation_dict[idx] = activations[act_fn][inplace] + idx += 1 + + model_dict["flatten"] = torch.nn.Flatten(start_dim=2) + start_idx = idx + for act_fn in activations: + for inplace in activations[act_fn].keys(): + inp_channels = 784 if idx == start_idx else 20 + model_dict[f"lin_{idx}"] = nn.Linear(inp_channels, 20) + model_dict[f"act_{idx}"] = act_fn(inplace=inplace) if inplace else act_fn() + activation_dict[idx] = activations[act_fn][inplace] + idx += 1 + model_dict[f"lin_{idx}"] = nn.Linear(20, 20) + model_dict[f"bn_{idx}"] = nn.BatchNorm1d(20) + model_dict[f"act_{idx}"] = act_fn(inplace=inplace) if inplace else act_fn() + activation_dict[idx] = activations[act_fn][inplace] + idx += 1 + return nn.Sequential(model_dict) + + +def get_node_map(model: 
torch.fx.GraphModule) -> Dict[str, Node]: + """ + Return a dictionary of node name to node + """ + node_map = {} + for node in model.graph.nodes: + node_map[node.name] = node + return node_map + + +@pytest.fixture(scope="module") +def config(request) -> LinearQuantizerConfig: + quantization_scheme, weight_per_channel, activation_dtype = request.param + return LinearQuantizerConfig.from_dict( + { + "global_config": { + "quantization_scheme": quantization_scheme, + "milestones": [0, 0, 10, 10], + "activation_dtype": activation_dtype, + "weight_dtype": torch.qint8, + "weight_per_channel": weight_per_channel, + } + } + ) + + +def quantize_model( + model: nn.Module, + data: torch.Tensor, + quantization_config: Optional[LinearQuantizerConfig] = None, + is_qat: bool = True, +): + quantizer = CoreMLQuantizer(quantization_config) + exported_model = capture_pre_autograd_graph(model, (data,)) + if is_qat: + prepared_model = prepare_qat_pt2e(exported_model, quantizer) + else: + prepared_model = prepare_pt2e(exported_model, quantizer) + prepared_model(data) + converted_model = convert_pt2e(prepared_model, use_reference_representation=False) + return converted_model + + +@pytest.mark.parametrize( + "config", + [ + (QuantizationScheme.symmetric, True, torch.quint8), + (QuantizationScheme.symmetric, True, torch.float32), + ], + indirect=True, +) +@pytest.mark.parametrize("is_qat", [True, False]) +def test_weight_module_act_fusion(model_for_quant, is_qat, config): + model = model_for_quant + data = torch.randn(2, 1, 28, 28) + converted_model = quantize_model(model, data, config, is_qat=is_qat) + + node_map = get_node_map(converted_model) + mod_nodes = [torch.ops.aten.conv2d.default, torch.ops.aten.linear.default] + activation_dtype = config.global_config.activation_dtype + + for node_name, node in node_map.items(): + if node.target in mod_nodes: + if activation_dtype == torch.float32: + assert ( + node.args[0].target + != torch.ops.quantized_decomposed.dequantize_per_tensor.default + ) + else: + assert ( + node.args[0].target + == torch.ops.quantized_decomposed.dequantize_per_tensor.default + ) + + assert ( + node.args[1].target + == torch.ops.quantized_decomposed.dequantize_per_channel.default + ) + assert len(node.users) == 1 + act_node = list(node.users.keys())[0] + if act_node.target == torch.ops.aten._native_batch_norm_legit.default: + act_node = act_node.next.next + assert len(act_node.users) == 1 + post_act_node = list(act_node.users.keys())[0] + if activation_dtype == torch.float32: + assert ( + post_act_node.target + != torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + else: + assert ( + post_act_node.target + == torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + # necessary to clear cache, otherwise tests fail with cache_size_limit reached + torch._dynamo.reset() From 951f00bd7918573598c895f9d9ba784f3f34447a Mon Sep 17 00:00:00 2001 From: Pulkit Agrawal Date: Fri, 8 Mar 2024 15:02:46 -0800 Subject: [PATCH 2/6] skip tests if export APIs don't exist --- .../quantization/test_coreml_quantizer.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/coremltools/test/optimize/torch/quantization/test_coreml_quantizer.py b/coremltools/test/optimize/torch/quantization/test_coreml_quantizer.py index 85b42e54a..df696eeab 100644 --- a/coremltools/test/optimize/torch/quantization/test_coreml_quantizer.py +++ b/coremltools/test/optimize/torch/quantization/test_coreml_quantizer.py @@ -9,19 +9,24 @@ import pytest import torch import torch.nn as nn 
-from torch._export import capture_pre_autograd_graph -from torch.ao.quantization.quantize_pt2e import ( - convert_pt2e, - prepare_pt2e, - prepare_qat_pt2e, -) + from torch.fx import Node + from coremltools.optimize.torch.quantization._coreml_quantizer import CoreMLQuantizer from coremltools.optimize.torch.quantization.quantization_config import ( LinearQuantizerConfig, QuantizationScheme, ) +from coremltools._deps import _HAS_TORCH_EXPORT_API +if _HAS_TORCH_EXPORT_API: + from torch._export import capture_pre_autograd_graph + from torch.ao.quantization.quantize_pt2e import ( + convert_pt2e, + prepare_pt2e, + prepare_qat_pt2e, + ) + activations = { nn.ReLU: { @@ -173,6 +178,7 @@ def quantize_model( indirect=True, ) @pytest.mark.parametrize("is_qat", [True, False]) +@pytest.mark.skipif(not _HAS_TORCH_EXPORT_API, reason="This test requires PyTorch Export APIs.") def test_weight_module_act_fusion(model_for_quant, is_qat, config): model = model_for_quant data = torch.randn(2, 1, 28, 28) From 736ea8e5530fcf6a3ffb803613224ee3ce66c710 Mon Sep 17 00:00:00 2001 From: Pulkit Agrawal Date: Fri, 8 Mar 2024 15:04:16 -0800 Subject: [PATCH 3/6] skip tests if export APIs don't exist --- .../test/optimize/torch/quantization/test_coreml_quantizer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/coremltools/test/optimize/torch/quantization/test_coreml_quantizer.py b/coremltools/test/optimize/torch/quantization/test_coreml_quantizer.py index df696eeab..5af4d3388 100644 --- a/coremltools/test/optimize/torch/quantization/test_coreml_quantizer.py +++ b/coremltools/test/optimize/torch/quantization/test_coreml_quantizer.py @@ -12,7 +12,6 @@ from torch.fx import Node - from coremltools.optimize.torch.quantization._coreml_quantizer import CoreMLQuantizer from coremltools.optimize.torch.quantization.quantization_config import ( LinearQuantizerConfig, From 18d42216b96066703874a622c14b83565e13e12b Mon Sep 17 00:00:00 2001 From: Pulkit Agrawal Date: Fri, 8 Mar 2024 16:59:51 -0800 Subject: [PATCH 4/6] skip test if version < 2.2.0 --- .../optimize/torch/quantization/test_coreml_quantizer.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/coremltools/test/optimize/torch/quantization/test_coreml_quantizer.py b/coremltools/test/optimize/torch/quantization/test_coreml_quantizer.py index 5af4d3388..3f3ebd969 100644 --- a/coremltools/test/optimize/torch/quantization/test_coreml_quantizer.py +++ b/coremltools/test/optimize/torch/quantization/test_coreml_quantizer.py @@ -12,7 +12,6 @@ from torch.fx import Node -from coremltools.optimize.torch.quantization._coreml_quantizer import CoreMLQuantizer from coremltools.optimize.torch.quantization.quantization_config import ( LinearQuantizerConfig, QuantizationScheme, @@ -26,6 +25,11 @@ prepare_qat_pt2e, ) +_TORCH_VERSION = torch.__version__ +_EXPECTED_TORCH_VERSION = '2.2.0' +if _TORCH_VERSION >= _EXPECTED_TORCH_VERSION: + from coremltools.optimize.torch.quantization._coreml_quantizer import CoreMLQuantizer + activations = { nn.ReLU: { @@ -177,7 +181,8 @@ def quantize_model( indirect=True, ) @pytest.mark.parametrize("is_qat", [True, False]) -@pytest.mark.skipif(not _HAS_TORCH_EXPORT_API, reason="This test requires PyTorch Export APIs.") +@pytest.mark.skipif(not _HAS_TORCH_EXPORT_API or _TORCH_VERSION < _EXPECTED_TORCH_VERSION, + reason="This test requires PyTorch Export APIs and PyTorch >= 2.2.0.") def test_weight_module_act_fusion(model_for_quant, is_qat, config): model = model_for_quant data = torch.randn(2, 1, 28, 28) From 
3528f2532915c1eee056c2c8f8584174497cfed1 Mon Sep 17 00:00:00 2001 From: yifan_shen3 Date: Sat, 9 Mar 2024 00:44:08 -0800 Subject: [PATCH 5/6] fix linter issue --- .../optimize/torch/quantization/_coreml_quantizer_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/coremltools/optimize/torch/quantization/_coreml_quantizer_utils.py b/coremltools/optimize/torch/quantization/_coreml_quantizer_utils.py index f03b9a6c2..e1501c266 100644 --- a/coremltools/optimize/torch/quantization/_coreml_quantizer_utils.py +++ b/coremltools/optimize/torch/quantization/_coreml_quantizer_utils.py @@ -428,15 +428,15 @@ def annotate_weighted_mod_pattern( # Validate mod args if mod_node.args[0] is not input_node: raise ValueError( - f"Weighted module arg did not contain input node ", input_node + f"Weighted module arg did not contain input node {input_node}" ) if mod_node.args[1] is not weight_node: raise ValueError( - f"Weighted module arg did not contain weight node ", weight_node + f"Weighted module arg did not contain weight node {weight_node}" ) if len(mod_node.args) > 2 and mod_node.args[2] is not bias_node: raise ValueError( - f"Weighted module arg did not contain bias node ", bias_node + f"Weighted module arg did not contain bias node {bias_node}" ) # Skip if the partition is already annotated or is filtered out by the user From 431d272eaaf46d800ca82fbf944c79b4b6772018 Mon Sep 17 00:00:00 2001 From: yifan_shen3 Date: Sat, 9 Mar 2024 12:49:45 -0800 Subject: [PATCH 6/6] add quantization conversion test --- .../test/test_executorch_quantization.py | 116 ++++++++++++++++-- 1 file changed, 105 insertions(+), 11 deletions(-) diff --git a/coremltools/converters/mil/frontend/torch/test/test_executorch_quantization.py b/coremltools/converters/mil/frontend/torch/test/test_executorch_quantization.py index 7fe257c3f..f171ee0b7 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_executorch_quantization.py +++ b/coremltools/converters/mil/frontend/torch/test/test_executorch_quantization.py @@ -3,27 +3,84 @@ # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause +import itertools import pytest +from typing import Tuple + +from coremltools._deps import _HAS_EXECUTORCH + +if not _HAS_EXECUTORCH: + pytest.skip(allow_module_level=True, reason="executorch is required") import torch from torch._export import capture_pre_autograd_graph -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e, prepare_qat_pt2e from torch.ao.quantization.quantizer.xnnpack_quantizer import ( get_symmetric_quantization_config, XNNPACKQuantizer, ) -import coremltools as ct -from coremltools._deps import _HAS_EXECUTORCH +_TORCH_VERSION = torch.__version__ +_EXPECTED_TORCH_VERSION = "2.2.0" +if _TORCH_VERSION < _EXPECTED_TORCH_VERSION: + pytest.skip(allow_module_level=True, reason=f"PyTorch {_EXPECTED_TORCH_VERSION} or higher is required") -if not _HAS_EXECUTORCH: - pytest.skip(allow_module_level=True, reason="executorch is required") +import coremltools as ct +from coremltools.optimize.torch.quantization.quantization_config import ( + LinearQuantizerConfig, + QuantizationScheme, +) +from coremltools.optimize.torch.quantization._coreml_quantizer import CoreMLQuantizer from .testing_utils import TorchBaseTest, TorchFrontend class TestExecutorchQuantization(TorchBaseTest): - def test_conv_relu(self): + @staticmethod + 
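Editorial note on the linter fix in PATCH 5 above: passing the node as a second positional argument to `ValueError` makes the exception render as a tuple of arguments, whereas interpolating it into the f-string produces the intended message. The `input_node` value below is a stand-in chosen for illustration:

input_node = "conv_input"

# Before the fix: the f-string has no placeholder and the node becomes a
# second exception argument, so str(exc) prints a tuple.
exc = ValueError(f"Weighted module arg did not contain input node ", input_node)
print(str(exc))  # ('Weighted module arg did not contain input node ', 'conv_input')

# After the fix: the node is interpolated into the message itself.
exc = ValueError(f"Weighted module arg did not contain input node {input_node}")
print(str(exc))  # Weighted module arg did not contain input node conv_input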
def make_torch_quantized_graph( + model, + example_inputs: Tuple[torch.Tensor], + quantizer_name: str, + quantization_type: str, + ) -> torch.fx.GraphModule: + assert quantizer_name in {"XNNPack", "CoreML"} + assert quantization_type in {"PTQ", "QAT"} + + pre_autograd_aten_dialect = capture_pre_autograd_graph(model, example_inputs) + + if quantizer_name == "XNNPack": + quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config()) + elif quantizer_name == "CoreML": + quantization_config = LinearQuantizerConfig.from_dict( + { + "global_config": { + "quantization_scheme": QuantizationScheme.symmetric, + "milestones": [0, 0, 10, 10], + "activation_dtype": torch.quint8, + "weight_dtype": torch.qint8, + "weight_per_channel": True, + } + } + ) + quantizer = CoreMLQuantizer(quantization_config) + + if quantization_type == "PTQ": + prepared_graph = prepare_pt2e(pre_autograd_aten_dialect, quantizer) + elif quantization_type == "QAT": + prepared_graph = prepare_qat_pt2e(pre_autograd_aten_dialect, quantizer) + + prepared_graph(*example_inputs) + converted_graph = convert_pt2e(prepared_graph) + return converted_graph + + @pytest.mark.parametrize( + "quantizer_name, quantization_type", + itertools.product( + ("XNNPack", "CoreML"), + ("PTQ", "QAT") + ) + ) + def test_conv_relu(self, quantizer_name, quantization_type): SHAPE = (1, 3, 256, 256) class Model(torch.nn.Module): @@ -40,12 +97,49 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: model = Model() - example_args = (torch.randn(SHAPE),) - pre_autograd_aten_dialect = capture_pre_autograd_graph(model, example_args) + example_inputs = (torch.randn(SHAPE),) + converted_graph = self.make_torch_quantized_graph( + model, + example_inputs, + quantizer_name, + quantization_type, + ) - quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config()) - prepared_graph = prepare_pt2e(pre_autograd_aten_dialect, quantizer) - converted_graph = convert_pt2e(prepared_graph) + self.run_compare_torch( + SHAPE, + converted_graph, + frontend=TorchFrontend.EXIR, + backend=("mlprogram", "fp16"), + minimum_deployment_target=ct.target.iOS17, + ) + + @pytest.mark.parametrize( + "quantizer_name, quantization_type", + itertools.product( + ("XNNPack", "CoreML"), + ("PTQ", "QAT") + ) + ) + def test_linear(self, quantizer_name, quantization_type): + SHAPE = (1, 5) + + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 10) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + model = Model() + + example_inputs = (torch.randn(SHAPE),) + converted_graph = self.make_torch_quantized_graph( + model, + example_inputs, + quantizer_name, + quantization_type, + ) self.run_compare_torch( SHAPE,
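Editorial note, not part of the patch: outside the test harness, a quantized graph like the one returned by `make_torch_quantized_graph` above could be converted with the public coremltools API instead of `run_compare_torch`. The sketch below is an assumption about the surrounding workflow; it reuses the `converted_graph` and `example_inputs` names from that helper and presumes a coremltools build with ExportedProgram (EXIR) conversion support:

import torch
import coremltools as ct

# Re-export the pt2e-converted GraphModule to an ExportedProgram, then hand it
# to ct.convert, mirroring the iOS17 deployment target used by the tests above
# so that the quantize/dequantize ops are representable.
exported_program = torch.export.export(converted_graph, example_inputs)
mlmodel = ct.convert(
    exported_program,
    minimum_deployment_target=ct.target.iOS17,
)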