@@ -15,7 +15,12 @@
 )
 
 from torchao.core.config import AOBaseConfig
-from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout
+from torchao.dtypes import (
+    CutlassInt4PackedLayout,
+    Int4CPULayout,
+    Int4XPULayout,
+    SemiSparseLayout,
+)
 from torchao.quantization import (
     Int4WeightOnlyConfig,
     Int8DynamicActivationInt8WeightConfig,
@@ -31,7 +36,8 @@
 from torchao.testing.utils import skip_if_no_cuda, skip_if_rocm
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_5,
-    TORCH_VERSION_AT_LEAST_2_6,
+    check_cpu_version,
+    check_xpu_version,
     is_fbcode,
     is_ROCM,
     is_sm_at_least_89,
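Note on the import hunks above: the hard-coded TORCH_VERSION_AT_LEAST_2_6 flag is dropped in favor of the check_cpu_version / check_xpu_version helpers from torchao.utils, which the hunks below use to pick a layout per device. The snippet below is only a rough, hypothetical sketch of what such a device-plus-version gate might look like; the name _check_device_version and its signature are assumptions for illustration, not torchao's actual implementation.

# Hypothetical sketch of a combined device / torch-version gate (illustrative only).
import torch

def _check_device_version(device, device_type: str, min_torch: str) -> bool:
    """True if `device` has type `device_type` and the running torch is new enough."""
    if isinstance(device, str):
        device = torch.device(device)
    # torch.__version__ is a TorchVersion, which supports comparison with version strings
    return device.type == device_type and torch.__version__ >= min_torch

# e.g. gate an Int4CPULayout code path on CPU plus torch >= 2.6 (assumed threshold)
use_cpu_int4 = _check_device_version("cpu", "cpu", "2.6.0")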
@@ -52,15 +58,19 @@ def get_quantization_functions(
         int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC),
     ]
     if do_int4:
-        if device == "cpu" and TORCH_VERSION_AT_LEAST_2_6:
+        if check_cpu_version(device):
             base_functions.append(
                 int4_weight_only(group_size=32, layout=Int4CPULayout())
             )
+        elif check_xpu_version(device):
+            base_functions.append(
+                int4_weight_only(group_size=32, layout=Int4XPULayout())
+            )
         if int4_zp_int:
             base_functions.append(
                 int4_weight_only(
                     group_size=32,
-                    layout=Int4CPULayout(),
+                    layout=Int4XPULayout(),
                     zero_point_domain=ZeroPointDomain.INT,
                 )
             )
@@ -77,7 +87,7 @@ def get_quantization_functions(
             )
             base_functions.append(int4_dynamic_activation_int4_weight())
 
-    if do_sparse:
+    if do_sparse and device != "xpu":
         base_functions.append(
             int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())
         )
@@ -89,6 +99,10 @@ def get_quantization_functions(
 
 
 class TestAffineQuantized(TestCase):
+    GPU_DEVICES = (["cuda"] if torch.cuda.is_available() else []) + (
+        ["xpu"] if torch.xpu.is_available() else []
+    )
+
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_tensor_core_layout_transpose(self):
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
@@ -109,51 +123,53 @@ def test_tensor_core_layout_transpose(self):
         aqt_shape = aqt.shape
         self.assertEqual(aqt_shape, shape)
 
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @common_utils.parametrize(
-        "apply_quant",
-        get_quantization_functions(is_cusparselt_available, True, "cuda", True),
-    )
-    @skip_if_rocm("ROCm enablement in progress")
-    def test_weights_only(self, apply_quant):
-        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-        if isinstance(apply_quant, AOBaseConfig):
-            quantize_(linear, apply_quant)
-            ql = linear
-        else:
-            # TODO(#1690): delete this once config migration is done
-            ql = apply_quant(linear)
-        with tempfile.NamedTemporaryFile() as f:
-            torch.save(ql.state_dict(), f)
-            f.seek(0)
-            # `weights_only=True` is enabled for torch 2.5+
-            if TORCH_VERSION_AT_LEAST_2_5:
-                _ = torch.load(f, weights_only=True)
-            else:
-                _ = torch.load(f, weights_only=False)
-
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available")
+    def test_weights_only(self):
+        for device in self.GPU_DEVICES:
+            apply_quant_list = get_quantization_functions(
+                is_cusparselt_available, True, device, True
+            )
+            for apply_quant in apply_quant_list:
+                linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=device)
+                if isinstance(apply_quant, AOBaseConfig):
+                    quantize_(linear, apply_quant)
+                    ql = linear
+                else:
+                    # TODO(#1690): delete this once config migration is done
+                    ql = apply_quant(linear)
+                with tempfile.NamedTemporaryFile() as f:
+                    torch.save(ql.state_dict(), f)
+                    f.seek(0)
+                    # `weights_only=True` is enabled for torch 2.5+
+                    if TORCH_VERSION_AT_LEAST_2_5:
+                        _ = torch.load(f, weights_only=True)
+                    else:
+                        _ = torch.load(f, weights_only=False)
+
+    @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available")
     @common_utils.parametrize("apply_quant", get_quantization_functions(False, False))
     def test_to_device(self, apply_quant):
-        def _apply(module, config_or_subclass_inserter):
-            if isinstance(config_or_subclass_inserter, AOBaseConfig):
-                quantize_(module, config_or_subclass_inserter)
-            else:
-                # TODO(#1690): delete this once config migration is done
-                module = config_or_subclass_inserter(module)
-            return module
+        for device in self.GPU_DEVICES:
 
-        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = _apply(linear, apply_quant)
-        ql.to("cuda")
+            def _apply(module, config_or_subclass_inserter):
+                if isinstance(config_or_subclass_inserter, AOBaseConfig):
+                    quantize_(module, config_or_subclass_inserter)
+                else:
+                    # TODO(#1690): delete this once config migration is done
+                    module = config_or_subclass_inserter(module)
+                return module
 
-        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = _apply(linear, apply_quant)
-        ql.to(device="cuda")
+            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+            ql = _apply(linear, apply_quant)
+            ql.to(device)
 
-        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = _apply(linear, apply_quant)
-        ql.cuda()
+            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+            ql = _apply(linear, apply_quant)
+            ql.to(device=device)
+
+            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+            ql = _apply(linear, apply_quant)
+            ql.to(device)
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_register_new_dispatch(self):
@@ -203,20 +219,19 @@ def apply_uint6_weight_only_quant(linear):
 
         deregister_aqt_quantized_linear_dispatch(dispatch_condition)
 
-    @common_utils.parametrize(
-        "apply_quant", get_quantization_functions(is_cusparselt_available, True)
-    )
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @skip_if_rocm("ROCm enablement in progress")
-    def test_print_quantized_module(self, apply_quant):
-        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-        if isinstance(apply_quant, AOBaseConfig):
-            quantize_(linear, apply_quant)
-            ql = linear
-        else:
-            # TODO(#1690): delete this once config migration is done
-            ql = apply_quant(linear)
-        assert "AffineQuantizedTensor" in str(ql)
+    @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available")
+    def test_print_quantized_module(self):
+        for device in self.GPU_DEVICES:
+            apply_quant_list = get_quantization_functions(True, True, device, True)
+            for apply_quant in apply_quant_list:
+                linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=device)
+                if isinstance(apply_quant, AOBaseConfig):
+                    quantize_(linear, apply_quant)
+                    ql = linear
+                else:
+                    # TODO(#1690): delete this once config migration is done
+                    ql = apply_quant(linear)
+                assert "AffineQuantizedTensor" in str(ql)
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @common_utils.parametrize(
@@ -267,7 +282,11 @@ def test_copy__mismatch_metadata(self, apply_quant):
 
 
 class TestAffineQuantizedBasic(TestCase):
-    COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
+    COMMON_DEVICES = (
+        ["cpu"]
+        + (["cuda"] if torch.cuda.is_available() else [])
+        + (["xpu"] if torch.xpu.is_available() else [])
+    )
     COMMON_DTYPES = [torch.bfloat16]
 
     @common_utils.parametrize("device", COMMON_DEVICES)
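For reference, a minimal sketch of the device-dispatched int4 path these tests now exercise, mirroring get_quantization_functions above: it picks Int4CPULayout or Int4XPULayout based on the device and applies the config with quantize_. This assumes a torchao build that ships these layouts and, for the xpu branch, an available XPU runtime; the helper name quantize_linear_for is made up for illustration.

# Minimal usage sketch (assumes the torchao APIs shown in the diff above;
# `quantize_linear_for` is a hypothetical helper, not part of torchao).
import torch
from torchao.dtypes import Int4CPULayout, Int4XPULayout
from torchao.quantization import int4_weight_only, quantize_
from torchao.utils import check_cpu_version, check_xpu_version


def quantize_linear_for(device: str) -> torch.nn.Linear:
    """Quantize a toy linear layer with the int4 weight-only config matching `device`."""
    linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=device)
    if check_cpu_version(device):
        config = int4_weight_only(group_size=32, layout=Int4CPULayout())
    elif check_xpu_version(device):
        config = int4_weight_only(group_size=32, layout=Int4XPULayout())
    else:
        # default layout (e.g. tensor-core tiled on CUDA)
        config = int4_weight_only(group_size=32)
    quantize_(linear, config)  # quantizes the module's weight in place
    return linear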