1
1
# RUN: %PYTHON %s | FileCheck %s
2
2
3
+ import numpy as np
3
4
from mlir .ir import *
4
5
from mlir .dialects import quant
5
6
@@ -18,21 +19,28 @@ def test_type_hierarchy():
18
19
any = Type .parse ("!quant.any<i8<-8:7>:f32>" )
19
20
uniform = Type .parse ("!quant.uniform<i8<-8:7>:f32, 0.99872:127>" )
20
21
per_axis = Type .parse ("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>" )
22
+ sub_channel = Type .parse (
23
+ "!quant.uniform<i8:f32:{0:1,1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>"
24
+ )
21
25
calibrated = Type .parse ("!quant.calibrated<f32<-0.998:1.2321>>" )
22
26
23
27
assert not quant .QuantizedType .isinstance (i8 )
24
28
assert quant .QuantizedType .isinstance (any )
25
29
assert quant .QuantizedType .isinstance (uniform )
26
30
assert quant .QuantizedType .isinstance (per_axis )
31
+ assert quant .QuantizedType .isinstance (sub_channel )
27
32
assert quant .QuantizedType .isinstance (calibrated )
28
33
29
34
assert quant .AnyQuantizedType .isinstance (any )
30
35
assert quant .UniformQuantizedType .isinstance (uniform )
31
36
assert quant .UniformQuantizedPerAxisType .isinstance (per_axis )
37
+ assert quant .UniformQuantizedSubChannelType .isinstance (sub_channel )
32
38
assert quant .CalibratedQuantizedType .isinstance (calibrated )
33
39
34
40
assert not quant .AnyQuantizedType .isinstance (uniform )
35
41
assert not quant .UniformQuantizedType .isinstance (per_axis )
42
+ assert not quant .UniformQuantizedType .isinstance (sub_channel )
43
+ assert not quant .UniformQuantizedPerAxisType .isinstance (sub_channel )
36
44
37
45
38
46
# CHECK-LABEL: TEST: test_any_quantized_type
@@ -121,6 +129,47 @@ def test_uniform_per_axis_type():
121
129
assert per_axis == Type .parse ("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>" )
122
130
123
131
132
+ # CHECK-LABEL: TEST: test_uniform_sub_channel_type
133
+ @run
134
+ def test_uniform_sub_channel_type ():
135
+ with Context ():
136
+ i8 = IntegerType .get_signless (8 )
137
+ f32 = F32Type .get ()
138
+ sub_channel = quant .UniformQuantizedSubChannelType .get (
139
+ quant .QuantizedType .FLAG_SIGNED ,
140
+ i8 ,
141
+ f32 ,
142
+ DenseElementsAttr .get (
143
+ np .asarray ([2.0 , 3.0 , 4.0 , 5.0 ], np .float32 ).reshape (2 , 2 )
144
+ ),
145
+ DenseElementsAttr .get (np .asarray ([10 , 20 , 30 , 40 ], np .int8 ).reshape (2 , 2 )),
146
+ [0 , 1 ],
147
+ [1 , 2 ],
148
+ storage_type_min = quant .QuantizedType .default_minimum_for_integer (
149
+ is_signed = True , integral_width = 8
150
+ ),
151
+ storage_type_max = quant .QuantizedType .default_maximum_for_integer (
152
+ is_signed = True , integral_width = 8
153
+ ),
154
+ )
155
+
156
+ # CHECK: quantized dimensions: [0, 1]
157
+ print (f"quantized dimensions: { sub_channel .quantized_dimensions } " )
158
+ # CHECK: block sizes: [1, 2]
159
+ print (f"block sizes: { sub_channel .block_sizes } " )
160
+ # CHECK: scales: {{\[}}[2. 3.]
161
+ # CHECK: [4. 5.]]
162
+ print (f"scales: { np .asarray (sub_channel .scales )} " )
163
+ # CHECK: zero-points: {{\[}}[10 20]
164
+ # CHECK: [30 40]]
165
+ print (f"zero-points: { np .asarray (sub_channel .zero_points )} " )
166
+ # CHECK: !quant.uniform<i8:f32:{0:1,1:2}, {{\{}}{2.000000e+00:10, 3.000000e+00:20}, {4.000000e+00:30, 5.000000e+00:40}}>
167
+ print (sub_channel )
168
+ assert sub_channel == Type .parse (
169
+ "!quant.uniform<i8:f32:{0:1,1:2},{{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>"
170
+ )
171
+
172
+
124
173
# CHECK-LABEL: TEST: test_calibrated_type
125
174
@run
126
175
def test_calibrated_type ():
0 commit comments