@@ -57,11 +57,11 @@
 from torchao.quantization.pt2e.quantizer.embedding_quantizer import (  # noqa: F811
     EmbeddingQuantizer,
 )
-from torchao.quantization.pt2e.quantizer.xnnpack_quantizer import (
+from torchao.testing.pt2e._xnnpack_quantizer import (
     XNNPACKQuantizer,
     get_symmetric_quantization_config,
 )
-from torchao.quantization.pt2e.quantizer.xnnpack_quantizer_utils import (
+from torchao.testing.pt2e._xnnpack_quantizer_utils import (
     OP_TO_ANNOTATOR,
     QuantizationConfig,
 )
@@ -1328,6 +1328,40 @@ def validate(self, model: torch.fx.GraphModule) -> None:
|
1328 | 1328 | with self.assertRaises(Exception):
|
1329 | 1329 | m = prepare_pt2e(m, BackendAQuantizer())
|
1330 | 1330 |
|
| 1331 | + def _quantize(self, m, quantizer, example_inputs, is_qat: bool = False): |
| 1332 | + # resetting dynamo cache |
| 1333 | + torch._dynamo.reset() |
| 1334 | + |
| 1335 | + m = export_for_training( |
| 1336 | + m, |
| 1337 | + example_inputs, |
| 1338 | + ).module() |
| 1339 | + if is_qat: |
| 1340 | + m = prepare_qat_pt2e(m, quantizer) |
| 1341 | + else: |
| 1342 | + m = prepare_pt2e(m, quantizer) |
| 1343 | + m(*example_inputs) |
| 1344 | + m = convert_pt2e(m) |
| 1345 | + return m |
| 1346 | + |
| 1347 | + def _get_pt2e_quantized_linear(self, is_per_channel=False) -> torch.fx.GraphModule: |
| 1348 | + class M(torch.nn.Module): |
| 1349 | + def __init__(self) -> None: |
| 1350 | + super().__init__() |
| 1351 | + self.linear = torch.nn.Linear(2, 2) |
| 1352 | + |
| 1353 | + def forward(self, x): |
| 1354 | + return self.linear(x) |
| 1355 | + |
| 1356 | + quantizer = XNNPACKQuantizer() |
| 1357 | + operator_config = get_symmetric_quantization_config( |
| 1358 | + is_per_channel=is_per_channel |
| 1359 | + ) |
| 1360 | + quantizer.set_global(operator_config) |
| 1361 | + example_inputs = (torch.randn(2, 2),) |
| 1362 | + m = M().eval() |
| 1363 | + return self._quantize(m, quantizer, example_inputs) |
| 1364 | + |
1331 | 1365 | def test_fold_quantize(self):
|
1332 | 1366 | """Test to make sure the quantized model gets quantized weight (quantize_per_tensor op is folded)"""
|
1333 | 1367 | m = self._get_pt2e_quantized_linear()
|
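For readers following along, the new _quantize helper is a thin wrapper around the standard PT2E flow: capture the module with export_for_training, insert observers with prepare_pt2e (or prepare_qat_pt2e for QAT), run calibration inputs through the prepared module, then lower it with convert_pt2e. A minimal standalone sketch of that flow follows; the import paths for export_for_training and the prepare_pt2e/convert_pt2e entry points are assumptions based on current torch/torchao layouts, since this diff does not show the test file's import block.

# Sketch of the PT2E flow that _quantize wraps (import paths are assumptions).
import torch
from torch.export import export_for_training
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.testing.pt2e._xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

model = torch.nn.Sequential(torch.nn.Linear(2, 2)).eval()
example_inputs = (torch.randn(2, 2),)

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))

m = export_for_training(model, example_inputs).module()  # capture an FX graph
m = prepare_pt2e(m, quantizer)   # insert observers per the quantizer's annotations
m(*example_inputs)               # calibrate on representative inputs
m = convert_pt2e(m)              # fold observers into quantize/dequantize ops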