Skip to content

Commit 95e7df2

Browse files
committed
Addings c-api and py-apis
1 parent 50ff818 commit 95e7df2

File tree

5 files changed

+75
-2
lines changed

5 files changed

+75
-2
lines changed

Diff for: mlir/lib/Bindings/Python/DialectQuant.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@
99
#include <cstdint>
1010
#include <vector>
1111

12+
#include "mlir-c/BuiltinAttributes.h"
1213
#include "mlir-c/Dialect/Quant.h"
1314
#include "mlir-c/IR.h"
14-
#include "mlir/Bindings/Python/NanobindAdaptors.h"
1515
#include "mlir/Bindings/Python/Nanobind.h"
16+
#include "mlir/Bindings/Python/NanobindAdaptors.h"
1617

1718
namespace nb = nanobind;
1819
using namespace llvm;

Diff for: mlir/lib/CAPI/Dialect/Quant.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir-c/Dialect/Quant.h"
10+
#include "mlir-c/BuiltinAttributes.h"
1011
#include "mlir/CAPI/Registration.h"
1112
#include "mlir/Dialect/Quant/IR/Quant.h"
1213
#include "mlir/Dialect/Quant/IR/QuantTypes.h"

Diff for: mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi

+21-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

55

6-
from mlir.ir import Type
6+
from mlir.ir import DenseElementsAttr, Type
77

88
__all__ = [
99
"QuantizedType",
@@ -109,6 +109,26 @@ class UniformQuantizedPerAxisType(QuantizedType):
109109
@property
110110
def is_fixed_point(self) -> bool: ...
111111

112+
class UniformQuantizedSubChannelType(QuantizedType):
113+
114+
@classmethod
115+
def get(cls, flags: int, storage_type: Type, expressed_type: Type,
116+
scales: DenseElementsAttr, zero_points: DenseElementsAttr,
117+
quantized_dimensions: list[int], block_sizes: list[int],
118+
storage_type_min: int, storage_type_max: int):
119+
...
120+
121+
@property
122+
def quantized_dimensions(self) -> list[int]: ...
123+
124+
@property
125+
def block_sizes(self) -> list[int]: ...
126+
127+
@property
128+
def scales(self) -> DenseElementsAttr: ...
129+
130+
@property
131+
def zero_points(self) -> DenseElementsAttr: ...
112132

113133
def CalibratedQuantizedType(QuantizedType):
114134

Diff for: mlir/test/CAPI/quant.c

+2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
// RUN: mlir-capi-quant-test 2>&1 | FileCheck %s
1111

1212
#include "mlir-c/Dialect/Quant.h"
13+
#include "mlir-c/BuiltinAttributes.h"
1314
#include "mlir-c/BuiltinTypes.h"
1415
#include "mlir-c/IR.h"
1516

@@ -357,6 +358,7 @@ int main(void) {
357358
testAnyQuantizedType(ctx);
358359
testUniformType(ctx);
359360
testUniformPerAxisType(ctx);
361+
testUniformSubChannelType(ctx);
360362
testCalibratedType(ctx);
361363
mlirContextDestroy(ctx);
362364
return EXIT_SUCCESS;

Diff for: mlir/test/python/dialects/quant.py

+49
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# RUN: %PYTHON %s | FileCheck %s
22

3+
import numpy as np
34
from mlir.ir import *
45
from mlir.dialects import quant
56

@@ -18,21 +19,28 @@ def test_type_hierarchy():
1819
any = Type.parse("!quant.any<i8<-8:7>:f32>")
1920
uniform = Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")
2021
per_axis = Type.parse("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
22+
sub_channel = Type.parse(
23+
"!quant.uniform<i8:f32:{0:1,1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>"
24+
)
2125
calibrated = Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")
2226

2327
assert not quant.QuantizedType.isinstance(i8)
2428
assert quant.QuantizedType.isinstance(any)
2529
assert quant.QuantizedType.isinstance(uniform)
2630
assert quant.QuantizedType.isinstance(per_axis)
31+
assert quant.QuantizedType.isinstance(sub_channel)
2732
assert quant.QuantizedType.isinstance(calibrated)
2833

2934
assert quant.AnyQuantizedType.isinstance(any)
3035
assert quant.UniformQuantizedType.isinstance(uniform)
3136
assert quant.UniformQuantizedPerAxisType.isinstance(per_axis)
37+
assert quant.UniformQuantizedSubChannelType.isinstance(sub_channel)
3238
assert quant.CalibratedQuantizedType.isinstance(calibrated)
3339

3440
assert not quant.AnyQuantizedType.isinstance(uniform)
3541
assert not quant.UniformQuantizedType.isinstance(per_axis)
42+
assert not quant.UniformQuantizedType.isinstance(sub_channel)
43+
assert not quant.UniformQuantizedPerAxisType.isinstance(sub_channel)
3644

3745

3846
# CHECK-LABEL: TEST: test_any_quantized_type
@@ -121,6 +129,47 @@ def test_uniform_per_axis_type():
121129
assert per_axis == Type.parse("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
122130

123131

132+
# CHECK-LABEL: TEST: test_uniform_sub_channel_type
133+
@run
134+
def test_uniform_sub_channel_type():
135+
with Context():
136+
i8 = IntegerType.get_signless(8)
137+
f32 = F32Type.get()
138+
sub_channel = quant.UniformQuantizedSubChannelType.get(
139+
quant.QuantizedType.FLAG_SIGNED,
140+
i8,
141+
f32,
142+
DenseElementsAttr.get(
143+
np.asarray([2.0, 3.0, 4.0, 5.0], np.float32).reshape(2, 2)
144+
),
145+
DenseElementsAttr.get(np.asarray([10, 20, 30, 40], np.int8).reshape(2, 2)),
146+
[0, 1],
147+
[1, 2],
148+
storage_type_min=quant.QuantizedType.default_minimum_for_integer(
149+
is_signed=True, integral_width=8
150+
),
151+
storage_type_max=quant.QuantizedType.default_maximum_for_integer(
152+
is_signed=True, integral_width=8
153+
),
154+
)
155+
156+
# CHECK: quantized dimensions: [0, 1]
157+
print(f"quantized dimensions: {sub_channel.quantized_dimensions}")
158+
# CHECK: block sizes: [1, 2]
159+
print(f"block sizes: {sub_channel.block_sizes}")
160+
# CHECK: scales: {{\[}}[2. 3.]
161+
# CHECK: [4. 5.]]
162+
print(f"scales: {np.asarray(sub_channel.scales)}")
163+
# CHECK: zero-points: {{\[}}[10 20]
164+
# CHECK: [30 40]]
165+
print(f"zero-points: {np.asarray(sub_channel.zero_points)}")
166+
# CHECK: !quant.uniform<i8:f32:{0:1,1:2}, {{\{}}{2.000000e+00:10, 3.000000e+00:20}, {4.000000e+00:30, 5.000000e+00:40}}>
167+
print(sub_channel)
168+
assert sub_channel == Type.parse(
169+
"!quant.uniform<i8:f32:{0:1,1:2},{{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>"
170+
)
171+
172+
124173
# CHECK-LABEL: TEST: test_calibrated_type
125174
@run
126175
def test_calibrated_type():

0 commit comments

Comments
 (0)