
Commit 10ed77c

[3/n] Flatten import path (#2125)
* [3/n] Flatten import path

  Summary: We only keep two import paths:

  ```
  torchao.quantization.pt2e
  torchao.quantization.pt2e.quantizer
  ```

  for all classes and utils.

  Next: remove the underscore before public util functions.

  Test Plan: pytest test/quantization/pt2e

  Reviewers:
  Subscribers:
  Tasks:
  Tags:

* move refactor
* fix circular import
* import
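For orientation, here is a minimal sketch of the import surface as used by the updated tests below. These are exactly the module paths that appear in the changed files; whether every symbol is additionally re-exported directly from `torchao.quantization.pt2e` is not shown in this diff.

```python
# Quantizer-facing classes remain under torchao.quantization.pt2e.quantizer
from torchao.quantization.pt2e.quantizer import QuantizationAnnotation, Quantizer

# prepare/convert entry points, as imported by the updated tests
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

# The XNNPACK example quantizer moves out of the public tree into testing-only helpers
from torchao.testing.pt2e._xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
```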
1 parent 868afa6 commit 10ed77c

23 files changed, +314 -1720 lines

test/quantization/pt2e/test_duplicate_dq.py (+2 -2)

```diff
@@ -26,10 +26,10 @@
     Quantizer,
     SharedQuantizationSpec,
 )
-from torchao.quantization.pt2e.quantizer.xnnpack_quantizer import (
+from torchao.testing.pt2e._xnnpack_quantizer import (
     get_symmetric_quantization_config,
 )
-from torchao.quantization.pt2e.quantizer.xnnpack_quantizer_utils import (
+from torchao.testing.pt2e._xnnpack_quantizer_utils import (
     OP_TO_ANNOTATOR,
     QuantizationConfig,
 )
```

test/quantization/pt2e/test_metadata_porting.py (+2 -2)

```diff
@@ -16,10 +16,10 @@
 
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torchao.quantization.pt2e.quantizer import QuantizationAnnotation, Quantizer
-from torchao.quantization.pt2e.quantizer.xnnpack_quantizer import (
+from torchao.testing.pt2e._xnnpack_quantizer import (
     get_symmetric_quantization_config,
 )
-from torchao.quantization.pt2e.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR
+from torchao.testing.pt2e._xnnpack_quantizer_utils import OP_TO_ANNOTATOR
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_7
 
 
```
test/quantization/pt2e/test_numeric_debugger.py (+1 -1)

```diff
@@ -24,7 +24,7 @@
 )
 from torchao.quantization.pt2e.graph_utils import bfs_trace_with_node_process
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
-from torchao.quantization.pt2e.quantizer.xnnpack_quantizer import (
+from torchao.testing.pt2e._xnnpack_quantizer import (
     XNNPACKQuantizer,
     get_symmetric_quantization_config,
 )
```

test/quantization/pt2e/test_quantize_pt2e.py (+36 -2)

```diff
@@ -57,11 +57,11 @@
 from torchao.quantization.pt2e.quantizer.embedding_quantizer import (  # noqa: F811
     EmbeddingQuantizer,
 )
-from torchao.quantization.pt2e.quantizer.xnnpack_quantizer import (
+from torchao.testing.pt2e._xnnpack_quantizer import (
     XNNPACKQuantizer,
     get_symmetric_quantization_config,
 )
-from torchao.quantization.pt2e.quantizer.xnnpack_quantizer_utils import (
+from torchao.testing.pt2e._xnnpack_quantizer_utils import (
     OP_TO_ANNOTATOR,
     QuantizationConfig,
 )
@@ -1328,6 +1328,40 @@ def validate(self, model: torch.fx.GraphModule) -> None:
         with self.assertRaises(Exception):
             m = prepare_pt2e(m, BackendAQuantizer())
 
+    def _quantize(self, m, quantizer, example_inputs, is_qat: bool = False):
+        # resetting dynamo cache
+        torch._dynamo.reset()
+
+        m = export_for_training(
+            m,
+            example_inputs,
+        ).module()
+        if is_qat:
+            m = prepare_qat_pt2e(m, quantizer)
+        else:
+            m = prepare_pt2e(m, quantizer)
+        m(*example_inputs)
+        m = convert_pt2e(m)
+        return m
+
+    def _get_pt2e_quantized_linear(self, is_per_channel=False) -> torch.fx.GraphModule:
+        class M(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linear = torch.nn.Linear(2, 2)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        quantizer = XNNPACKQuantizer()
+        operator_config = get_symmetric_quantization_config(
+            is_per_channel=is_per_channel
+        )
+        quantizer.set_global(operator_config)
+        example_inputs = (torch.randn(2, 2),)
+        m = M().eval()
+        return self._quantize(m, quantizer, example_inputs)
+
     def test_fold_quantize(self):
         """Test to make sure the quantized model gets quantized weight (quantize_per_tensor op is folded)"""
         m = self._get_pt2e_quantized_linear()
```
test/quantization/pt2e/test_quantize_pt2e_qat.py (+1 -1)

```diff
@@ -47,7 +47,7 @@
     QuantizationSpec,
     Quantizer,
 )
-from torchao.quantization.pt2e.quantizer.xnnpack_quantizer import (
+from torchao.testing.pt2e._xnnpack_quantizer import (
     XNNPACKQuantizer,
     get_symmetric_quantization_config,
 )
```

test/quantization/pt2e/test_representation.py (+1 -1)

```diff
@@ -23,7 +23,7 @@
 
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torchao.quantization.pt2e.quantizer import Quantizer
-from torchao.quantization.pt2e.quantizer.xnnpack_quantizer import (
+from torchao.testing.pt2e._xnnpack_quantizer import (
     XNNPACKQuantizer,
     get_symmetric_quantization_config,
 )
```
