Skip to content

Conversation

@pulkital
Copy link
Collaborator

@pulkital pulkital commented Mar 8, 2024

Adds a CoreMLQuantizer which extends PyTorch 2.0 Export Quantizer to enable CoreML quantization rules.

Usage:

import torch.nn as nn
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_qat_pt2e

from coremltools.optimize.torch.quantization._coreml_quantizer import CoreMLQuantizer

model = nn.Sequential(
    OrderedDict(
        {
            "conv": nn.Conv2d(1, 20, (3, 3)),
            "relu1": nn.ReLU(),
            "conv2": nn.Conv2d(20, 20, (3, 3)),
            "relu2": nn.ReLU(),
        }
    )
)

loss_fn = define_loss()

# initialize the annotator with quantization config
config = LinearQuantizerConfig.from_dict(
    {
        "global_config": {
            "quantization_scheme": "symmetric",
        }
    }
)
quantizer = CoreMLQuantizer(config)

example_inputs = torch.randn(1, 1, 28, 28)

# create export graph
exported_model = capture_pre_autograd_graph(model, (example_inputs,))

# prepare the model to insert FakeQuantize layers for QAT
prepared_model = prepare_qat_pt2e(exported_model, quantizer)

# use prepared model in your PyTorch training loop
for inputs, labels in data:
    output = prepared_model(inputs)
    loss = loss_fn(output, labels)
    loss.backward()
    optimizer.step()
    # turn observers/quantizers on/off depending on iteration number

# convert operations to their quanitzed counterparts using parameters learnt via QAT
model = convert_pt2e(prepared_model)

@pulkital pulkital requested a review from YifanShenSZ March 8, 2024 20:42
@pulkital pulkital changed the title Add CoreMLQuantizer in coremltools.optimize.torch to support PyTorch Export Quantizer Add CoreMLQuantizer in coremltools.optimize.torch to support PyTorch Export based quantization Mar 8, 2024
@pulkital pulkital force-pushed the add-support-for-pytorch-export-quantizer branch from 53b2b7c to 6af4a4c Compare March 8, 2024 20:47
@YifanShenSZ YifanShenSZ self-assigned this Mar 8, 2024
@YifanShenSZ YifanShenSZ added PyTorch (not traced) ct.optimize Question/issue related to the coremltool's optimization package labels Mar 8, 2024
@YifanShenSZ
Copy link
Collaborator

Added 2 changes:

  1. Fix linter issue
  2. Quantization conversion test

Testing: ✅

  1. GitLab CI
  2. Locally verified backward compatibility with pytorch 1.13.1
  3. Locally verified executorch tests

Copy link
Collaborator

@YifanShenSZ YifanShenSZ left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome! We are now one step closer to ExecuTorch CoreMLQuantizer! Thanks for the great work!

@YifanShenSZ YifanShenSZ merged commit 01e9845 into main Mar 9, 2024
@YifanShenSZ YifanShenSZ deleted the add-support-for-pytorch-export-quantizer branch March 9, 2024 21:06
@jerryzh168
Copy link

jerryzh168 commented Dec 24, 2025

@pulkital @YifanShenSZ we are deprecating pt2e quantization from pytorch/pytorch (pytorch/pytorch#169151),

wondering if you can help updating the callsites to import from torchao instead?

from torch.ao.quantization.quantizer.quantizer import Quantizer as _TorchQuantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer import _get_module_name_filter

before

from torch.ao.quantization.quantizer.quantizer import Quantizer as _TorchQuantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer import _get_module_name_filter

after:

from torchao.quantization.pt2e.quantizer import Quantizer as _TorchQuantizer
from torchao.quantization.pt2e.quantizer import get_module_name_filter

@jerryzh168
Copy link

please take a look #2634

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ct.optimize Question/issue related to the coremltool's optimization package PyTorch (not traced)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants