
Commit df46e7a

INT4 XPU enabling (#1577)
* enable floating zero point with little numerical issue
* unify weight packing
* review to view
* add justify contiguous
* fix torch.compile
* remove typos in tests
* overload copy_ for torch.load
* copy_ needs the 2nd arg to be int4
* expose preserve_zero
* refactor zero_point_domain dispatch
* export zero_point_domain and preserve_zero as the top arguments
* format
* fix parameter initialization in UT
* encapsulate version check as helpers; remove zero_point_dtype assigning; enable zp dtype: u8/s8/s16/s32/s64
* fix zero_point_dtype

Signed-off-by: Meng, Hengyu <[email protected]>
1 parent: 31f119e

15 files changed: +841 −136 lines
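
A minimal sketch of what this change enables, based on the tests added below (it assumes a PyTorch 2.8+ build with XPU support; the group_size value and the ZeroPointDomain.INT choice are illustrative, not required):

import torch

from torchao.dtypes import Int4XPULayout
from torchao.quantization import int4_weight_only, quantize_
from torchao.quantization.quant_primitives import ZeroPointDomain

# Toy module on an Intel GPU ("xpu") device; int4 weight-only quantization
# expects bfloat16 weights.
m = torch.nn.Sequential(torch.nn.Linear(1024, 1024, dtype=torch.bfloat16)).to("xpu")

# INT4 weight-only quantization using the new XPU layout; zero_point_domain is
# now exposed as a top-level argument of int4_weight_only.
quantize_(
    m,
    int4_weight_only(
        group_size=32,
        layout=Int4XPULayout(),
        zero_point_domain=ZeroPointDomain.INT,
    ),
)

x = torch.randn(16, 1024, dtype=torch.bfloat16, device="xpu")
y = m(x)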

test/dtypes/test_affine_quantized.py (+79 -60)

@@ -15,7 +15,12 @@
 )
 
 from torchao.core.config import AOBaseConfig
-from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout
+from torchao.dtypes import (
+    CutlassInt4PackedLayout,
+    Int4CPULayout,
+    Int4XPULayout,
+    SemiSparseLayout,
+)
 from torchao.quantization import (
     Int4WeightOnlyConfig,
     Int8DynamicActivationInt8WeightConfig,
@@ -31,7 +36,8 @@
 from torchao.testing.utils import skip_if_no_cuda, skip_if_rocm
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_5,
-    TORCH_VERSION_AT_LEAST_2_6,
+    check_cpu_version,
+    check_xpu_version,
     is_fbcode,
     is_ROCM,
     is_sm_at_least_89,
@@ -52,15 +58,19 @@ def get_quantization_functions(
         int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC),
     ]
     if do_int4:
-        if device == "cpu" and TORCH_VERSION_AT_LEAST_2_6:
+        if check_cpu_version(device):
             base_functions.append(
                 int4_weight_only(group_size=32, layout=Int4CPULayout())
             )
+        elif check_xpu_version(device):
+            base_functions.append(
+                int4_weight_only(group_size=32, layout=Int4XPULayout())
+            )
         if int4_zp_int:
             base_functions.append(
                 int4_weight_only(
                     group_size=32,
-                    layout=Int4CPULayout(),
+                    layout=Int4XPULayout(),
                     zero_point_domain=ZeroPointDomain.INT,
                 )
             )
@@ -77,7 +87,7 @@ def get_quantization_functions(
         )
         base_functions.append(int4_dynamic_activation_int4_weight())
 
-    if do_sparse:
+    if do_sparse and device != "xpu":
         base_functions.append(
             int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())
         )
@@ -89,6 +99,10 @@ def get_quantization_functions(
 
 
 class TestAffineQuantized(TestCase):
+    GPU_DEVICES = (["cuda"] if torch.cuda.is_available() else []) + (
+        ["xpu"] if torch.xpu.is_available() else []
+    )
+
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_tensor_core_layout_transpose(self):
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
@@ -109,51 +123,53 @@ def test_tensor_core_layout_transpose(self):
         aqt_shape = aqt.shape
         self.assertEqual(aqt_shape, shape)
 
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @common_utils.parametrize(
-        "apply_quant",
-        get_quantization_functions(is_cusparselt_available, True, "cuda", True),
-    )
-    @skip_if_rocm("ROCm enablement in progress")
-    def test_weights_only(self, apply_quant):
-        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-        if isinstance(apply_quant, AOBaseConfig):
-            quantize_(linear, apply_quant)
-            ql = linear
-        else:
-            # TODO(#1690): delete this once config migration is done
-            ql = apply_quant(linear)
-        with tempfile.NamedTemporaryFile() as f:
-            torch.save(ql.state_dict(), f)
-            f.seek(0)
-            # `weights_only=True` is enabled for torch 2.5+
-            if TORCH_VERSION_AT_LEAST_2_5:
-                _ = torch.load(f, weights_only=True)
-            else:
-                _ = torch.load(f, weights_only=False)
-
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available")
+    def test_weights_only(self):
+        for device in self.GPU_DEVICES:
+            apply_quant_list = get_quantization_functions(
+                is_cusparselt_available, True, device, True
+            )
+            for apply_quant in apply_quant_list:
+                linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=device)
+                if isinstance(apply_quant, AOBaseConfig):
+                    quantize_(linear, apply_quant)
+                    ql = linear
+                else:
+                    # TODO(#1690): delete this once config migration is done
+                    ql = apply_quant(linear)
+                with tempfile.NamedTemporaryFile() as f:
+                    torch.save(ql.state_dict(), f)
+                    f.seek(0)
+                    # `weights_only=True` is enabled for torch 2.5+
+                    if TORCH_VERSION_AT_LEAST_2_5:
+                        _ = torch.load(f, weights_only=True)
+                    else:
+                        _ = torch.load(f, weights_only=False)
+
+    @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available")
     @common_utils.parametrize("apply_quant", get_quantization_functions(False, False))
     def test_to_device(self, apply_quant):
-        def _apply(module, config_or_subclass_inserter):
-            if isinstance(config_or_subclass_inserter, AOBaseConfig):
-                quantize_(module, config_or_subclass_inserter)
-            else:
-                # TODO(#1690): delete this once config migration is done
-                module = config_or_subclass_inserter(module)
-            return module
+        for device in self.GPU_DEVICES:
 
-        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = _apply(linear, apply_quant)
-        ql.to("cuda")
+            def _apply(module, config_or_subclass_inserter):
+                if isinstance(config_or_subclass_inserter, AOBaseConfig):
+                    quantize_(module, config_or_subclass_inserter)
+                else:
+                    # TODO(#1690): delete this once config migration is done
+                    module = config_or_subclass_inserter(module)
+                return module
 
-        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = _apply(linear, apply_quant)
-        ql.to(device="cuda")
+            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+            ql = _apply(linear, apply_quant)
+            ql.to(device)
 
-        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = _apply(linear, apply_quant)
-        ql.cuda()
+            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+            ql = _apply(linear, apply_quant)
+            ql.to(device=device)
+
+            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+            ql = _apply(linear, apply_quant)
+            ql.to(device)
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_register_new_dispatch(self):
@@ -203,20 +219,19 @@ def apply_uint6_weight_only_quant(linear):
 
         deregister_aqt_quantized_linear_dispatch(dispatch_condition)
 
-    @common_utils.parametrize(
-        "apply_quant", get_quantization_functions(is_cusparselt_available, True)
-    )
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @skip_if_rocm("ROCm enablement in progress")
-    def test_print_quantized_module(self, apply_quant):
-        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-        if isinstance(apply_quant, AOBaseConfig):
-            quantize_(linear, apply_quant)
-            ql = linear
-        else:
-            # TODO(#1690): delete this once config migration is done
-            ql = apply_quant(linear)
-        assert "AffineQuantizedTensor" in str(ql)
+    @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available")
+    def test_print_quantized_module(self):
+        for device in self.GPU_DEVICES:
+            apply_quant_list = get_quantization_functions(True, True, device, True)
+            for apply_quant in apply_quant_list:
+                linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=device)
+                if isinstance(apply_quant, AOBaseConfig):
+                    quantize_(linear, apply_quant)
+                    ql = linear
+                else:
+                    # TODO(#1690): delete this once config migration is done
+                    ql = apply_quant(linear)
+                assert "AffineQuantizedTensor" in str(ql)
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @common_utils.parametrize(
@@ -267,7 +282,11 @@ def test_copy__mismatch_metadata(self, apply_quant):
 
 
 class TestAffineQuantizedBasic(TestCase):
-    COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
+    COMMON_DEVICES = (
+        ["cpu"]
+        + (["cuda"] if torch.cuda.is_available() else [])
+        + (["xpu"] if torch.xpu.is_available() else [])
+    )
     COMMON_DTYPES = [torch.bfloat16]
 
     @common_utils.parametrize("device", COMMON_DEVICES)
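
The hunks above swap the inline device == "cpu" and TORCH_VERSION_AT_LEAST_2_6 checks for the new check_cpu_version / check_xpu_version helpers from torchao.utils. Their implementation lives in another file changed by this commit and is not shown here; a plausible sketch, inferred from the call sites (which pass both device strings and torch.device objects, with the XPU path gated on torch 2.8+ elsewhere in this commit), is:

import torch

from torchao.utils import TORCH_VERSION_AT_LEAST_2_6, TORCH_VERSION_AT_LEAST_2_8


def check_cpu_version(device):
    # Hypothetical: True when `device` is a CPU device and torch is new
    # enough for the Int4CPULayout path.
    if isinstance(device, str):
        device = torch.device(device)
    return device.type == "cpu" and TORCH_VERSION_AT_LEAST_2_6


def check_xpu_version(device):
    # Hypothetical: True when `device` is an XPU device and torch is new
    # enough for the Int4XPULayout path.
    if isinstance(device, str):
        device = torch.device(device)
    return device.type == "xpu" and TORCH_VERSION_AT_LEAST_2_8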

test/integration/test_integration.py (+12 -7)

@@ -19,8 +19,7 @@
 from torch._inductor.utils import run_and_get_code
 
 import torchao
-from torchao.dtypes import Int4CPULayout, TensorCoreTiledLayout
-from torchao.dtypes.utils import is_device
+from torchao.dtypes import Int4CPULayout, Int4XPULayout, TensorCoreTiledLayout
 from torchao.quantization import safe_int_mm
 from torchao.quantization.autoquant import (
     AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
@@ -84,6 +83,8 @@
     TORCH_VERSION_AT_LEAST_2_6,
     TORCH_VERSION_AT_LEAST_2_7,
     benchmark_model,
+    check_cpu_version,
+    check_xpu_version,
     is_fbcode,
     is_sm_at_least_90,
     unwrap_tensor_subclass,
@@ -146,17 +147,19 @@ def _int8da_int8w_api(
 
 
 def _int4wo_api(mod, use_hqq=False):
-    if (
-        is_device(next(mod.parameters()).device.type, "cpu")
-        and TORCH_VERSION_AT_LEAST_2_6
-    ):
+    if check_cpu_version(next(mod.parameters()).device):
         quantize_(
             mod,
             int4_weight_only(
                 layout=Int4CPULayout(), use_hqq=use_hqq, set_inductor_config=False
            ),
        )
        unwrap_tensor_subclass(mod)
+    elif check_xpu_version(next(mod.parameters()).device):
+        quantize_(
+            mod, int4_weight_only(layout=Int4XPULayout()), set_inductor_config=False
+        )
+        unwrap_tensor_subclass(mod)
     elif TORCH_VERSION_AT_LEAST_2_4:
         quantize_(mod, int4_weight_only(set_inductor_config=False))
         if not TORCH_VERSION_AT_LEAST_2_5:
@@ -1129,8 +1132,10 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
         layout_list = []
-        if device == "cpu" and TORCH_VERSION_AT_LEAST_2_6:
+        if check_cpu_version(device):
             layout_list.append(Int4CPULayout())
+        elif check_xpu_version(device):
+            layout_list.append(Int4XPULayout())
         else:
             for inner_k_tiles in [4, 2]:
                 layout_list.append(TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles))

test/quantization/test_quant_api.py (+58 -17)

@@ -24,7 +24,11 @@
 from torchao import quantize_
 from torchao._models.llama.model import Transformer, prepare_inputs_for_model
 from torchao._models.llama.tokenizer import get_tokenizer
-from torchao.dtypes import AffineQuantizedTensor
+from torchao.dtypes import (
+    AffineQuantizedTensor,
+    Int4CPULayout,
+    Int4XPULayout,
+)
 from torchao.quantization import LinearActivationQuantizedTensor
 from torchao.quantization.quant_api import (
     Quantizer,
@@ -54,6 +58,7 @@
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
+    TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_89,
     is_sm_at_least_90,
     unwrap_tensor_subclass,
@@ -189,6 +194,10 @@ def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
 
 
 class TestQuantFlow(TestCase):
+    GPU_DEVICES = (["cuda"] if torch.cuda.is_available() else []) + (
+        ["xpu"] if torch.xpu.is_available() else []
+    )
+
     def test_dynamic_quant_gpu_singleline(self):
         m = ToyLinearModel().eval()
         example_inputs = m.example_inputs()
@@ -229,6 +238,34 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
         compiled = m(*example_inputs)
         torch.testing.assert_close(quantized, compiled, atol=0, rtol=0)
 
+    @unittest.skipIf(not torch.xpu.is_available(), "Need XPU available")
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "only works for torch 2.8+")
+    def test_int4_wo_quant_save_load(self):
+        m = ToyLinearModel().eval().cpu()
+
+        def api(model):
+            quantize_(model, int4_weight_only(layout=Int4XPULayout()))
+            unwrap_tensor_subclass(model)
+
+        api(m)
+
+        example_inputs = m.example_inputs()
+        ref = m(*example_inputs)
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(m.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
+
+        m2 = ToyLinearModel().eval().cpu()
+        api(m2)
+
+        m2.load_state_dict(state_dict)
+        m2 = m2.to(device="xpu")
+        example_inputs = map(lambda x: x.xpu(), example_inputs)
+        res = m2(*example_inputs)
+
+        torch.testing.assert_close(ref, res.cpu())
+
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "only works for torch 2.4+")
     def test_int8_wo_quant_save_load(self):
@@ -615,25 +652,31 @@ def test_quantized_tensor_subclass_8da4w(self, mapping_type):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "Test currently doesn't work for 2.5+")
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available")
     def test_quantized_tensor_subclass_int4(self):
-        # use 1024 so that we don't need padding
-        m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
-        m_copy = copy.deepcopy(m)
-        example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")
+        for device in self.GPU_DEVICES:
+            # use 1024 so that we don't need padding
+            m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to(device)
+            m_copy = copy.deepcopy(m)
+            example_inputs = m.example_inputs(dtype=torch.bfloat16, device=device)
 
-        group_size = 32
-        quantize_(m, int4_weight_only(group_size=group_size))
-        assert isinstance(m.linear1.weight, AffineQuantizedTensor)
-        assert isinstance(m.linear2.weight, AffineQuantizedTensor)
+            group_size = 32
+            if device == "xpu":
+                quantize_(
+                    m, int4_weight_only(group_size=group_size, layout=Int4XPULayout())
+                )
+            else:
+                quantize_(m, int4_weight_only(group_size=group_size))
+            assert isinstance(m.linear1.weight, AffineQuantizedTensor)
+            assert isinstance(m.linear2.weight, AffineQuantizedTensor)
 
-        # reference
-        _ref_change_linear_weights_to_int4_woqtensors(m_copy, groupsize=group_size)
+            # reference
+            _ref_change_linear_weights_to_int4_woqtensors(m_copy, groupsize=group_size)
 
-        res = m(*example_inputs)
-        ref = m_copy(*example_inputs)
+            res = m(*example_inputs)
+            ref = m_copy(*example_inputs)
 
-        self.assertTrue(torch.equal(res, ref))
+            self.assertTrue(torch.equal(res, ref))
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -799,8 +842,6 @@ def reset_memory():
     @common_utils.parametrize("x_dim", [2, 3])
    @common_utils.parametrize("use_hqq", [True, False])
    def test_int4wo_cpu(self, dtype, x_dim, use_hqq):
-        from torchao.dtypes import Int4CPULayout
-
         device = "cpu"
         m = ToyLinearModel().eval().to(dtype).to(device)
         example_inputs = m.example_inputs(dtype=dtype, device=device)