@@ -3,27 +3,84 @@
# Use of this source code is governed by a BSD-3-clause license that can be
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

import itertools
import pytest
from typing import Tuple

from coremltools._deps import _HAS_EXECUTORCH

if not _HAS_EXECUTORCH:
    pytest.skip(allow_module_level=True, reason="executorch is required")

import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e, prepare_qat_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)

import coremltools as ct
from coremltools._deps import _HAS_EXECUTORCH
_TORCH_VERSION = torch.__version__
_EXPECTED_TORCH_VERSION = "2.2.0"
if _TORCH_VERSION < _EXPECTED_TORCH_VERSION:
    pytest.skip(allow_module_level=True, reason=f"PyTorch {_EXPECTED_TORCH_VERSION} or higher is required")

if not _HAS_EXECUTORCH:
    pytest.skip(allow_module_level=True, reason="executorch is required")
import coremltools as ct
from coremltools.optimize.torch.quantization.quantization_config import (
    LinearQuantizerConfig,
    QuantizationScheme,
)
from coremltools.optimize.torch.quantization._coreml_quantizer import CoreMLQuantizer

from .testing_utils import TorchBaseTest, TorchFrontend


class TestExecutorchQuantization(TorchBaseTest):
    def test_conv_relu(self):
    @staticmethod
    def make_torch_quantized_graph(
        model,
        example_inputs: Tuple[torch.Tensor],
        quantizer_name: str,
        quantization_type: str,
    ) -> torch.fx.GraphModule:
        assert quantizer_name in {"XNNPack", "CoreML"}
        assert quantization_type in {"PTQ", "QAT"}

        pre_autograd_aten_dialect = capture_pre_autograd_graph(model, example_inputs)

        if quantizer_name == "XNNPack":
            quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
        elif quantizer_name == "CoreML":
            quantization_config = LinearQuantizerConfig.from_dict(
                {
                    "global_config": {
                        "quantization_scheme": QuantizationScheme.symmetric,
                        "milestones": [0, 0, 10, 10],
                        "activation_dtype": torch.quint8,
                        "weight_dtype": torch.qint8,
                        "weight_per_channel": True,
                    }
                }
            )
            quantizer = CoreMLQuantizer(quantization_config)

        if quantization_type == "PTQ":
            prepared_graph = prepare_pt2e(pre_autograd_aten_dialect, quantizer)
        elif quantization_type == "QAT":
            prepared_graph = prepare_qat_pt2e(pre_autograd_aten_dialect, quantizer)

        prepared_graph(*example_inputs)
        converted_graph = convert_pt2e(prepared_graph)
        return converted_graph

    @pytest.mark.parametrize(
        "quantizer_name, quantization_type",
        itertools.product(
            ("XNNPack", "CoreML"),
            ("PTQ", "QAT")
        )
    )
    def test_conv_relu(self, quantizer_name, quantization_type):
        SHAPE = (1, 3, 256, 256)

        class Model(torch.nn.Module):
@@ -40,12 +97,49 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

        model = Model()

        example_args = (torch.randn(SHAPE),)
        pre_autograd_aten_dialect = capture_pre_autograd_graph(model, example_args)
        example_inputs = (torch.randn(SHAPE),)
        converted_graph = self.make_torch_quantized_graph(
            model,
            example_inputs,
            quantizer_name,
            quantization_type,
        )

        quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
        prepared_graph = prepare_pt2e(pre_autograd_aten_dialect, quantizer)
        converted_graph = convert_pt2e(prepared_graph)
        self.run_compare_torch(
            SHAPE,
            converted_graph,
            frontend=TorchFrontend.EXIR,
            backend=("mlprogram", "fp16"),
            minimum_deployment_target=ct.target.iOS17,
        )

    @pytest.mark.parametrize(
        "quantizer_name, quantization_type",
        itertools.product(
            ("XNNPack", "CoreML"),
            ("PTQ", "QAT")
        )
    )
    def test_linear(self, quantizer_name, quantization_type):
        SHAPE = (1, 5)

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.linear(x)

        model = Model()

        example_inputs = (torch.randn(SHAPE),)
        converted_graph = self.make_torch_quantized_graph(
            model,
            example_inputs,
            quantizer_name,
            quantization_type,
        )

        self.run_compare_torch(
            SHAPE,
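The tests above exercise the quantized graphs only through `run_compare_torch`. As a rough standalone sketch of the same flow, the snippet below captures, annotates, calibrates, and converts a toy model, then hands it to Core ML; re-exporting with `torch.export.export` and passing the resulting program directly to `ct.convert` are assumptions about the EXIR frontend, not code taken from this diff.

```python
import torch
import coremltools as ct
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

model = torch.nn.Sequential(torch.nn.Conv2d(3, 16, 3), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 3, 256, 256),)

# Capture, annotate, calibrate, and fold quantization parameters (PTQ path).
graph = capture_pre_autograd_graph(model, example_inputs)
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
prepared = prepare_pt2e(graph, quantizer)
prepared(*example_inputs)  # one calibration pass
quantized = convert_pt2e(prepared)

# Assumption: the EXIR frontend accepts a torch.export program directly.
exported = torch.export.export(quantized, example_inputs)
mlmodel = ct.convert(
    exported,
    convert_to="mlprogram",
    minimum_deployment_target=ct.target.iOS17,
)
```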
47 changes: 46 additions & 1 deletion coremltools/optimize/torch/_utils/python_utils.py
@@ -1,12 +1,57 @@
# Copyright (c) 2023, Apple Inc. All rights reserved.
# Copyright (c) 2024, Apple Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-3-clause license that can be
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

import logging as _logging
from collections import OrderedDict as _OrderedDict
from typing import Any as _Any

_logger = _logging.getLogger(__name__)


def get_str(val: _Any):
    if isinstance(val, float):
        return f"{val:.5f}"
    return str(val)


class RegistryMixin:
    REGISTRY = None

    @classmethod
    def register(cls, name: str):
        if cls.REGISTRY is None:
            cls.REGISTRY = _OrderedDict()

        def inner_wrapper(wrapped_obj):
            if name in cls.REGISTRY:
                _logger.warning(
                    f"Name: {name} is already registered with object: {cls.REGISTRY[name].__name__} "
                    f"in registry: {cls.__name__}. "
                    f"Over-writing the name with new class: {wrapped_obj.__name__}"
                )
            cls.REGISTRY[name] = wrapped_obj
            return wrapped_obj

        return inner_wrapper

    @classmethod
    def _get_object(cls, name: str):
        if name in cls.REGISTRY:
            return cls.REGISTRY[name]
        raise NotImplementedError(
            f"No object is registered with name: {name} in registry {cls.__name__}."
        )


class ClassRegistryMixin(RegistryMixin):
    @classmethod
    def get_class(cls, name: str):
        return cls._get_object(name)


class FunctionRegistryMixin(RegistryMixin):
    @classmethod
    def get_function(cls, name: str):
        return cls._get_object(name)
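`RegistryMixin` keeps a per-class ordered registry that is populated through the `register` decorator and read back through `get_class` / `get_function`. A minimal usage sketch follows; the `OptimizerRegistry` subclass and the `"sgd"` entry are illustrative names, not part of the codebase.

```python
from coremltools.optimize.torch._utils.python_utils import ClassRegistryMixin


class OptimizerRegistry(ClassRegistryMixin):
    """Illustrative registry; not part of this diff."""


@OptimizerRegistry.register("sgd")
class SGDHandler:
    pass


handler_cls = OptimizerRegistry.get_class("sgd")  # returns SGDHandler
# OptimizerRegistry.get_class("adam") would raise NotImplementedError.
```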
109 changes: 109 additions & 0 deletions coremltools/optimize/torch/quantization/_annotation_config.py
@@ -0,0 +1,109 @@
# Copyright (c) 2024, Apple Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-3-clause license that can be
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

from typing import Optional as _Optional

import torch as _torch
import torch.ao.quantization as _aoquant
from attr import define as _define
from torch.ao.quantization.quantizer.quantizer import (
    QuantizationSpec as _TorchQuantizationSpec,
)

from coremltools.optimize.torch.quantization.quantization_config import (
    ModuleLinearQuantizerConfig as _ModuleLinearQuantizerConfig,
)
from coremltools.optimize.torch.quantization.quantization_config import ObserverType as _ObserverType
from coremltools.optimize.torch.quantization.quantization_config import (
    QuantizationScheme as _QuantizationScheme,
)


@_define
class AnnotationConfig:
"""
Module/Operator level configuration class for :py:class:`CoreMLQuantizer`.

For each module/operator, defines the dtype, quantization scheme and observer type
for input(s), output and weights (if any).
"""

input_activation: _Optional[_TorchQuantizationSpec] = None
output_activation: _Optional[_TorchQuantizationSpec] = None
weight: _Optional[_TorchQuantizationSpec] = None

    @staticmethod
    def _normalize_dtype(dtype: _torch.dtype) -> _torch.dtype:
        """
        PyTorch export quantizer only supports uint8 and int8 data types,
        so we map the quantized dtypes to the corresponding supported dtype.
        """
        dtype_map = {
            _torch.quint8: _torch.uint8,
            _torch.qint8: _torch.int8,
        }
        return dtype_map.get(dtype, dtype)

    @classmethod
    def from_quantization_config(
        cls,
        quantization_config: _Optional[_ModuleLinearQuantizerConfig],
    ) -> _Optional["AnnotationConfig"]:
        """
        Creates a :py:class:`AnnotationConfig` from ``ModuleLinearQuantizerConfig``
        """
        if (
            quantization_config is None
            or quantization_config.weight_dtype == _torch.float32
        ):
            return None

        # Activation QSpec
        if quantization_config.activation_dtype == _torch.float32:
            output_activation_qspec = None
        else:
            activation_qscheme = _QuantizationScheme.get_qscheme(
                quantization_config.quantization_scheme,
                is_per_channel=False,
            )
            activation_dtype = cls._normalize_dtype(
                quantization_config.activation_dtype
            )
            output_activation_qspec = _TorchQuantizationSpec(
                observer_or_fake_quant_ctr=_aoquant.FakeQuantize.with_args(
                    observer=_ObserverType.get_observer(
                        quantization_config.activation_observer,
                        is_per_channel=False,
                    ),
                    dtype=activation_dtype,
                    qscheme=activation_qscheme,
                ),
                dtype=activation_dtype,
                qscheme=activation_qscheme,
            )

        # Weight QSpec
        weight_qscheme = _QuantizationScheme.get_qscheme(
            quantization_config.quantization_scheme,
            is_per_channel=quantization_config.weight_per_channel,
        )
        weight_dtype = cls._normalize_dtype(quantization_config.weight_dtype)
        weight_qspec = _TorchQuantizationSpec(
            observer_or_fake_quant_ctr=_aoquant.FakeQuantize.with_args(
                observer=_ObserverType.get_observer(
                    quantization_config.weight_observer,
                    is_per_channel=quantization_config.weight_per_channel,
                ),
                dtype=weight_dtype,
                qscheme=weight_qscheme,
            ),
            dtype=weight_dtype,
            qscheme=weight_qscheme,
        )
        return AnnotationConfig(
            input_activation=output_activation_qspec,
            output_activation=output_activation_qspec,
            weight=weight_qspec,
        )
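A short sketch of how `from_quantization_config` is consumed: build a per-module config, derive the annotation, and read the resulting quantization specs. Constructing `ModuleLinearQuantizerConfig` with these keyword arguments is an assumption about that class; its definition is not part of this diff.

```python
import torch
from coremltools.optimize.torch.quantization.quantization_config import (
    ModuleLinearQuantizerConfig,
    QuantizationScheme,
)
from coremltools.optimize.torch.quantization._annotation_config import AnnotationConfig

# Assumed keyword arguments; see quantization_config.py for the authoritative
# signature and defaults (observers, milestones, etc.).
module_config = ModuleLinearQuantizerConfig(
    quantization_scheme=QuantizationScheme.symmetric,
    activation_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    weight_per_channel=True,
)

annotation = AnnotationConfig.from_quantization_config(module_config)
# annotation.output_activation and annotation.weight now hold QuantizationSpec objects,
# with quint8/qint8 normalized to uint8/int8 for the PT2E export quantizer.

# Passing None (or a config whose weight_dtype is torch.float32) yields no annotation:
assert AnnotationConfig.from_quantization_config(None) is None
```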