
Commit 90136b3

szyszyzys authored and facebook-github-bot committed
Move codebook (LUT) generation methods into common utils. Update functions to be more compatible with coreml. (#2772)
Summary: Pull Request resolved: #2772
Reviewed By: metascroy
Differential Revision: D79595460
1 parent e6b38bb commit 90136b3


10 files changed, +1043 -665 lines changed

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
import torch.nn.functional as F
from parameterized import param, parameterized

from torchao.prototype.quantization.codebook_coreml.codebook_ops import (
    choose_qparams_and_quantize_codebook_coreml as choose_qparams_and_quantize_codebook_coreml_original,
)
from torchao.prototype.quantization.codebook_coreml.codebook_ops import (
    choose_qparams_and_quantize_codebook_coreml_refactored,
    dequantize_codebook,
)
from torchao.quantization.quant_primitives import (
    _DTYPE_TO_BIT_WIDTH,
)


class TestCoreMLQuantCompatibility(unittest.TestCase):
    TEST_CASES = [
        param(grouping_type="column", group_size=128, tensor_shape=(16, 1024)),
    ]

    @parameterized.expand(TEST_CASES)
    def test_functional_equivalence(self, grouping_type, group_size, tensor_shape):
        input_tensor = torch.randn(tensor_shape, dtype=torch.float32)
        code_dtype = torch.uint4
        nbits = _DTYPE_TO_BIT_WIDTH[code_dtype]
        torch.manual_seed(42)

        # --- Get results from reference implementations ---
        block_size = [-1, group_size]
        expected_luts, expected_codes = (
            choose_qparams_and_quantize_codebook_coreml_original(
                input_tensor, code_dtype, block_size.copy()
            )
        )

        actual_luts, actual_codes = (
            choose_qparams_and_quantize_codebook_coreml_refactored(
                input_tensor, code_dtype, block_size.copy()
            )
        )

        # Ensure codes are long for dequantize op compatibility
        expected_codes = expected_codes.to(torch.long)
        actual_codes = actual_codes.to(torch.long)

        self.assertEqual(
            actual_luts.shape,
            expected_luts.shape,
            "LUT shapes do not match after processing",
        )
        self.assertEqual(
            actual_codes.shape, expected_codes.shape, "Code shapes do not match"
        )

        dequant_expected = dequantize_codebook(
            expected_codes, expected_luts, nbits, block_size
        )
        dequant_actual = dequantize_codebook(
            actual_codes, actual_luts, nbits, block_size
        )

        expected_error = torch.mean((input_tensor - dequant_expected) ** 2).item()
        actual_error = torch.mean((input_tensor - dequant_actual) ** 2).item()

        self.assertAlmostEqual(
            actual_error,
            expected_error,
            delta=1e-5,
            msg="Dequantization error differs significantly between implementations",
        )


if __name__ == "__main__":
    unittest.main()
Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import copy
import tempfile
import unittest

import torch
import torch.nn as nn
from parameterized import param, parameterized
from torch import uint1, uint2, uint3, uint4

from torchao.prototype.quantization.codebook_groupwise.api import (
    GroupwiseLutWeightConfig,
)
from torchao.prototype.quantization.codebook_utils.codebook_utils import (
    group_size_to_block_shapes,
)
from torchao.quantization.quant_api import quantize_


class TestGroupwiseLowbitWeightLut(unittest.TestCase):
    """
    Test suite for the GroupwiseLutWeight quantization scheme, updated for the
    new simplified API.
    """

    TEST_CASES = [
        param(
            code_dtype=code_dtype,
            lut_group_size=lut_group_size,
            weight_dtype=weight_dtype,
            has_bias=has_bias,
        )
        for code_dtype in [uint1, uint2, uint3, uint4]
        for lut_group_size in [256, 512]
        for weight_dtype in [torch.float32]
        for has_bias in [True, False]
    ]

    # --------------------------------------------------------------------------
    # Test 1: End-to-End Model Accuracy
    # --------------------------------------------------------------------------
    @parameterized.expand(TEST_CASES)
    def test_e2e_accuracy_vs_reference(
        self,
        code_dtype,
        lut_group_size,
        weight_dtype,
        has_bias,
    ):
        """
        Tests the numerical accuracy of the full quantized model against a reference.
        This now uses the `use_qdq_reference` flag instead of layout objects.
        """
        m, k, n = 3, 64, 32
        activations = torch.randn(m, k, dtype=weight_dtype)
        model = nn.Sequential(nn.Linear(k, n, bias=has_bias)).to(dtype=weight_dtype)

        # --- 2. Update tensor_shape to reflect the new (k, n) layout ---
        lut_block_shape = group_size_to_block_shapes(
            lut_group_size=lut_group_size, tensor_shape=(n, k)
        )

        # --- Quantize using C++ ops ---
        quantized_model = copy.deepcopy(model)
        perf_config = GroupwiseLutWeightConfig(
            code_dtype=code_dtype,
            weight_dtype=weight_dtype,
            lut_block_shape=lut_block_shape,
            use_qdq_reference=False,
        )
        quantize_(quantized_model, perf_config)
        with torch.no_grad():
            actual_result = quantized_model(activations)

        # --- Quantize for Reference (using Python ops) ---
        reference_model = copy.deepcopy(model)
        ref_config = GroupwiseLutWeightConfig(
            code_dtype=code_dtype,
            weight_dtype=weight_dtype,
            lut_block_shape=lut_block_shape,
            use_qdq_reference=True,
        )
        quantize_(reference_model, ref_config)
        with torch.no_grad():
            expected_result = reference_model(activations)
        # Compare results
        self.assertTrue(
            torch.allclose(actual_result, expected_result, atol=1e-2, rtol=1e-2)
        )

    def tearDown(self):
        """
        Clear the TorchDynamo cache after each test case to prevent
        recompilation errors in parameterized tests.
        """
        super().tearDown()
        torch._dynamo.reset()

    # --------------------------------------------------------------------------
    # Test 2: Deployment Readiness (Updated for new API)
    # --------------------------------------------------------------------------
    @parameterized.expand(TEST_CASES)
    def test_export_compile_aoti(
        self,
        code_dtype,
        lut_group_size,
        weight_dtype,
        has_bias,
    ):
        """
        Tests that the quantized model can be exported and compiled.
        """
        k, n = 64, 32
        activations = torch.randn(2, k, dtype=weight_dtype)
        model = (
            nn.Sequential(nn.Linear(k, n, bias=has_bias)).to(dtype=weight_dtype).eval()
        )
        lut_block_shape = group_size_to_block_shapes(
            lut_group_size=lut_group_size,
            tensor_shape=(n, k),
        )

        # Configure the quantization using the new API
        config = GroupwiseLutWeightConfig(
            code_dtype=code_dtype,
            weight_dtype=weight_dtype,
            lut_block_shape=lut_block_shape,
            use_qdq_reference=False,
        )
        quantize_(model, config)

        with torch.no_grad():
            eager_results = model(activations)

        # Export and Compile
        exported_model = torch.export.export(model, (activations,))
        compiled_model = torch.compile(model, fullgraph=True)

        with tempfile.TemporaryDirectory() as tmpdir, torch.no_grad():
            # Check exported model
            exported_results = exported_model.module()(activations)
            self.assertTrue(
                torch.allclose(eager_results, exported_results, atol=1e-3, rtol=1e-3)
            )

            # Check compiled model
            compiled_results = compiled_model(activations)
            self.assertTrue(
                torch.allclose(eager_results, compiled_results, atol=1e-3, rtol=1e-3)
            )

            # Check AOTI compiled model using the packaging API
            package_path = f"{tmpdir}/model.pt2"
            torch._inductor.aoti_compile_and_package(
                exported_model, package_path=package_path
            )
            aoti_model = torch._inductor.aoti_load_package(package_path)
            aoti_results = aoti_model(activations)
            self.assertTrue(
                torch.allclose(eager_results, aoti_results, atol=1e-3, rtol=1e-3)
            )


if __name__ == "__main__":
    unittest.main()

torchao/prototype/quantization/codebook_coreml/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,6 +1,7 @@
 from .api import CodebookWeightOnlyConfig
 from .codebook_ops import (
     choose_qparams_and_quantize_codebook_coreml,
+    choose_qparams_and_quantize_codebook_coreml_refactored,
     dequantize_codebook,
 )
 from .codebook_quantized_tensor import CodebookQuantizedTensor
@@ -9,5 +10,6 @@
     "CodebookQuantizedTensor",
     "CodebookWeightOnlyConfig",
     "choose_qparams_and_quantize_codebook_coreml",
+    "choose_qparams_and_quantize_codebook_coreml_refactored",
     "dequantize_codebook",
 ]
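With this re-export, the refactored op is reachable from the package root. Below is a minimal round-trip sketch modeled on the new compatibility test above; the weight shape and group size are illustrative assumptions rather than part of this commit, and coremltools must be installed for the k-means step.

import torch

from torchao.prototype.quantization.codebook_coreml import (
    choose_qparams_and_quantize_codebook_coreml_refactored,
    dequantize_codebook,
)
from torchao.quantization.quant_primitives import _DTYPE_TO_BIT_WIDTH

# Illustrative 16 x 1024 weight with one 4-bit LUT per group of 128 columns.
weight = torch.randn(16, 1024)
code_dtype = torch.uint4
block_size = [-1, 128]  # -1 spans the whole dimension (all rows share each LUT)

luts, codes = choose_qparams_and_quantize_codebook_coreml_refactored(
    weight, code_dtype, block_size.copy()
)
# Codes are cast to long before dequantization, as in the test above.
dequant = dequantize_codebook(
    codes.to(torch.long), luts, _DTYPE_TO_BIT_WIDTH[code_dtype], block_size
)
print(torch.mean((weight - dequant) ** 2).item())  # reconstruction MSE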

torchao/prototype/quantization/codebook_coreml/codebook_ops.py

Lines changed: 106 additions & 0 deletions
@@ -117,6 +117,112 @@ def choose_qparams_and_quantize_codebook_coreml(
 
     return res_lut, res_w
 
+def choose_qparams_and_quantize_codebook_coreml_refactored(
+    input_tensor: torch.Tensor,
+    code_dtype: torch.dtype,
+    block_size: List[int],
+    force_kmeans1d: bool = False,
+    cluster_dim: int = 1,
+    vector_axis: Optional[int] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Initialize the codebook using k-means clustering on blocks of the input tensor.
+
+    Args:
+        input_tensor (torch.Tensor): The input tensor to be quantized.
+        code_dtype (torch.dtype): The dtype for the codes. [torch.uint1, ..., torch.uint8]
+        block_size (List[int]): Block sizes for how many elements in each dimension share
+            the same lookup table (len(block_size) == input_tensor.dim()).
+            Each dimension of input_tensor must be divisible by the corresponding element of block_size.
+            Lookup tables are indexed by {(di // bi) for i in input_tensor.dim()}.
+            For example, if the input tensor has shape (N, K) and block_size is (N, group_size),
+            there is one lookup table per group_size columns, i.e., (K // group_size) lookup tables in total.
+        force_kmeans1d (bool): Use kmeans1d regardless of the number of weights.
+        cluster_dim (int): The vector size for vector lookup table quantization.
+            E.g., when cluster_dim is 4, instead of quantizing each scalar value one by one, the tensor
+            is quantized in units of 4-element vectors; each vector of the original tensor is mapped
+            to a vector in the codebook (lookup table) based on the indices.
+        vector_axis (Optional[int]): Used in vector quantization; see
+            https://github.com/apple/coremltools/blob/1c0e5cb1c1e3ab759af107b54f2be18b7c03f8aa/coremltools/optimize/_utils.py#L371
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The codebook (lookup table) tensor and the quantized tensor (codes, torch.uint8).
+        The LUT has dimension (g0, ..., g(N-1), 2**nbits, vec_dim), where:
+        * The first N dimensions index over the different tables (gi = input_tensor.shape[i] // block_size[i] in each dimension)
+        * Dimension N + 1 indexes over the nbit indices (2**nbits)
+        * Dimension N + 2 indexes over the lookup values (shape = 1 for scalar)
+    """
+    assert code_dtype in list(_SUB_BYTE_UINT_BOUNDS.keys()) + [torch.uint8]
+    nbits = _DTYPE_TO_BIT_WIDTH[code_dtype]
+    assert nbits >= 1 and nbits <= 8, f"nbits must be in [1, 8], got {nbits}"
+    assert input_tensor.dim() == 2, "Currently only rank 2 tensors are supported"
+    assert cluster_dim == 1, f"only cluster_dim == 1 is supported right now, got {cluster_dim}"
+
+    original_shape = input_tensor.shape
+    N, K = original_shape
+    input_tensor = input_tensor.detach()
+
+    # --- Process block_size ---
+    assert len(block_size) == 2
+    processed_block_size = block_size.copy()
+    if processed_block_size[0] == -1:
+        processed_block_size[0] = N
+    if processed_block_size[1] == -1:
+        processed_block_size[1] = K
+
+    row_block_size, col_block_size = processed_block_size
+    assert N % row_block_size == 0, f"Tensor rows ({N}) not divisible by row block size ({row_block_size})"
+    assert K % col_block_size == 0, f"Tensor cols ({K}) not divisible by col block size ({col_block_size})"
+
+    # --- Determine and execute grouping strategy ---
+    is_col_grouping = col_block_size < K and row_block_size == N
+    is_row_grouping = row_block_size < N and col_block_size == K
+    assert is_col_grouping or is_row_grouping, "Invalid block_size. Must group by either rows or columns, not both or neither."
+
+    res_lut_list = []
+    from coremltools.models.neural_network.quantization_utils import (
+        _get_kmeans_lookup_table_and_weight,
+    )
+    if is_col_grouping:
+        # STRATEGY 1: Group by COLUMNS (original behavior)
+        num_luts = K // col_block_size
+        reshaped_tensor = input_tensor.reshape(N, num_luts, col_block_size)
+        res_codes = torch.zeros_like(reshaped_tensor, dtype=torch.uint8)
+
+        for i in range(num_luts):
+            block_to_quantize = reshaped_tensor[:, i, :]
+            lut, w = _get_kmeans_lookup_table_and_weight(
+                nbits, block_to_quantize, force_kmeans1d, cluster_dim, vector_axis
+            )
+            res_lut_list.append(torch.from_numpy(lut))
+            res_codes[:, i, :] = torch.from_numpy(w.reshape(N, col_block_size))
+
+        # Shape to match the CoreML spec: (1, num_luts, 2**nbits, 1)
+        final_luts = torch.stack(res_lut_list, dim=0).reshape(1, num_luts, 2**nbits, 1)
+
+    else:  # is_row_grouping
+        # STRATEGY 2: Group by ROWS (the groupwise wrapper's behavior)
+        num_luts = N // row_block_size
+        reshaped_tensor = input_tensor.reshape(num_luts, row_block_size, K)
+        res_codes = torch.zeros_like(reshaped_tensor, dtype=torch.uint8)
+
+        for i in range(num_luts):
+            block_to_quantize = reshaped_tensor[i, :, :]
+            lut, w = _get_kmeans_lookup_table_and_weight(
+                nbits, block_to_quantize, force_kmeans1d, cluster_dim, vector_axis
+            )
+            res_lut_list.append(torch.from_numpy(lut))
+            res_codes[i, :, :] = torch.from_numpy(w.reshape(row_block_size, K))
+
+        final_luts_stacked = torch.stack(res_lut_list, dim=0)  # Shape: (num_luts, 2**nbits, 1)
+
+        # Reshape to the consistent 4D format: (num_row_groups, 1, 2**nbits, 1)
+        final_luts = final_luts_stacked.reshape(num_luts, 1, 2**nbits, 1)
+
+    # Reshape codes back to the original tensor shape
+    final_codes = res_codes.reshape(*original_shape)
+
+    return final_luts, final_codes
 
 @register_custom_op
 def dequantize_codebook(
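To make the documented LUT layout concrete, here is a hedged shape walk-through of the two grouping strategies. The tensor sizes are illustrative assumptions (not from this commit), and coremltools is required at runtime because the op delegates to _get_kmeans_lookup_table_and_weight.

import torch

from torchao.prototype.quantization.codebook_coreml.codebook_ops import (
    choose_qparams_and_quantize_codebook_coreml_refactored,
)

# Illustrative 16 x 1024 weight (assumed shape, matching the new compatibility test).
w = torch.randn(16, 1024)

# Column grouping: all 16 rows share one 4-bit LUT per block of 128 columns,
# giving 1024 // 128 = 8 tables.
luts, codes = choose_qparams_and_quantize_codebook_coreml_refactored(
    w, torch.uint4, [-1, 128]
)
assert luts.shape == (1, 8, 16, 1)  # (row_groups, col_groups, 2**nbits, vec_dim)
assert codes.shape == (16, 1024) and codes.dtype == torch.uint8

# Row grouping: each block of 4 rows gets its own LUT spanning all 1024 columns,
# giving 16 // 4 = 4 tables.
luts, codes = choose_qparams_and_quantize_codebook_coreml_refactored(
    w, torch.uint4, [4, -1]
)
assert luts.shape == (4, 1, 16, 1)
assert codes.shape == (16, 1024)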
