Skip to content

Commit ce179a5

Browse files
committed
Add pass to normalize generic quantized types to specific quantized types
1 parent a6426bb commit ce179a5

File tree

4 files changed

+264
-0
lines changed

4 files changed

+264
-0
lines changed

mlir/include/mlir/Dialect/Quant/Transforms/Passes.td

+33
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,39 @@ def LowerQuantOps : Pass<"lower-quant-ops", "func::FuncOp"> {
3131
];
3232
}
3333

34+
def NormalizeQuantTypes : Pass<"normalize-quant-types"> {
  let summary = "Normalize generic quantized types to specific quantized types";
  let description = [{
    This pass converts generic quantized types in the `quant` dialect to more
    specific types when possible.

    The following conversions are performed:

    1. Sub-channel to per-axis: If the shape of the scales tensor of a
       sub-channel quantized type has exactly one dimension different from
       one (all other dimensions are one), it is converted to a per-axis
       quantized type along that dimension.

       For example:

       * `!quant.uniform<i8:f32:{0:1}, {{2.0}, {3.0}}>`
         -> `!quant.uniform<i8:f32:0, {2.0, 3.0}>`
       * `tensor<?x?x!quant.uniform<i8:f32:{0:1,1:4}, {{2.0}, {3.0}}>>`
         -> `tensor<?x?x!quant.uniform<i8:f32:0, {2.0, 3.0}>>`

    2. Sub-channel to per-tensor: If a sub-channel quantized type has only
       one scale or zero-point, it is converted to a per-tensor
       quantized type.

       For example:

       * `!quant.uniform<i8:f32:{}, {{2.0}}>`
         -> `!quant.uniform<i8:f32, 2.0>`
       * `tensor<?x?x!quant.uniform<i8:f32:{0:1,1:4}, {{2.0}}>>`
         -> `tensor<?x?x!quant.uniform<i8:f32, 2.0>>`
  }];
  let dependentDialects = ["func::FuncDialect", "quant::QuantDialect"];
}
66+
3467
def StripFuncQuantTypes : Pass<"strip-func-quant-types"> {
3568
let summary = "Strip quantized types from function headers";
3669
let description = [{

mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_mlir_dialect_library(MLIRQuantTransforms
22
LowerQuantOps.cpp
3+
NormalizeQuantTypes.cpp
34
StripFuncQuantTypes.cpp
45

56
ADDITIONAL_HEADER_DIRS
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
//===- NormalizeQuantTypes.cpp - Normalize quantized types ----------------===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
//
10+
// Normalize generic quantized types to specific quantized types
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/Func/IR/FuncOps.h"
15+
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
16+
#include "mlir/Dialect/Quant/IR/Quant.h"
17+
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
18+
#include "mlir/Dialect/Quant/Transforms/Passes.h"
19+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
20+
#include "mlir/Transforms/DialectConversion.h"
21+
22+
namespace mlir {
23+
namespace quant {
24+
25+
#define GEN_PASS_DEF_NORMALIZEQUANTTYPES
26+
#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
27+
28+
namespace {
29+
30+
/// Returns true if the given sub-channel quantized type is convertible to a
31+
/// per-tensor quantized type. This is true if the sub-channel type has only
32+
/// one scale and one zero point.
33+
///
34+
/// Assumes that `tensorType` is a tensor with element type
35+
/// `quant::UniformQuantizedSubChannelType`.
36+
static bool isConvertibleToPerTensor(TensorType tensorType) {
37+
return cast<UniformQuantizedSubChannelType>(tensorType.getElementType())
38+
.getScales()
39+
.getType()
40+
.getNumElements() == 1;
41+
}
42+
43+
/// Returns true if the given sub-channel quantized type is convertible to a
44+
/// per-axis quantized type. This is true if the shape of the scales tensor has
45+
/// all but one non-one value.
46+
///
47+
/// Assumes that `tensorType` is a tensor with element type
48+
/// `quant::UniformQuantizedSubChannelType`.
49+
static bool isConvertibleToPerAxis(TensorType tensorType) {
50+
auto shape = cast<UniformQuantizedSubChannelType>(tensorType.getElementType())
51+
.getScales()
52+
.getType()
53+
.getShape();
54+
return llvm::count_if(shape, [](int64_t dim) { return dim != 1; }) == 1;
55+
}
56+
57+
/// This class defines a type converter that converts sub-channel quantized
58+
/// types to per-tensor or per-axis quantized types whenever possible.
59+
class NormalizedQuantTypesConverter : public TypeConverter {
60+
61+
static Type convertType(Type type) {
62+
auto tensorType = dyn_cast<TensorType>(type);
63+
if (!tensorType) {
64+
return type;
65+
}
66+
67+
auto subChannelType =
68+
dyn_cast<UniformQuantizedSubChannelType>(tensorType.getElementType());
69+
if (!subChannelType) {
70+
return type;
71+
}
72+
73+
if (isConvertibleToPerTensor(tensorType)) {
74+
double scale =
75+
subChannelType.getScales().getValues<APFloat>()[0].convertToDouble();
76+
int64_t zeroPoint =
77+
subChannelType.getZeroPoints().getValues<APInt>()[0].getSExtValue();
78+
auto perTensorType = UniformQuantizedType::get(
79+
subChannelType.getFlags(), subChannelType.getStorageType(),
80+
subChannelType.getExpressedType(), scale, zeroPoint,
81+
subChannelType.getStorageTypeMin(),
82+
subChannelType.getStorageTypeMax());
83+
return tensorType.clone(perTensorType);
84+
}
85+
86+
if (isConvertibleToPerAxis(tensorType)) {
87+
auto shape = subChannelType.getScales().getType().getShape();
88+
auto quantizedDimItr =
89+
llvm::find_if(shape, [](int64_t dim) { return dim != 1; });
90+
auto scales = llvm::to_vector(llvm::map_range(
91+
subChannelType.getScales().getValues<APFloat>(),
92+
[](APFloat scale) { return scale.convertToDouble(); }));
93+
auto zeroPoints = llvm::to_vector(llvm::map_range(
94+
subChannelType.getZeroPoints().getValues<APInt>(),
95+
[](APInt zeroPoint) { return zeroPoint.getSExtValue(); }));
96+
auto perAxisType = UniformQuantizedPerAxisType::get(
97+
subChannelType.getFlags(), subChannelType.getStorageType(),
98+
subChannelType.getExpressedType(), scales, zeroPoints,
99+
quantizedDimItr - shape.begin(), subChannelType.getStorageTypeMin(),
100+
subChannelType.getStorageTypeMax());
101+
return tensorType.clone(perAxisType);
102+
}
103+
return type;
104+
}
105+
106+
public:
107+
explicit NormalizedQuantTypesConverter() { addConversion(convertType); }
108+
};
109+
110+
/// This class implements a conversion pattern that converts any generic
111+
/// operation with sub-channel quantized types to an equivalent operation with
112+
/// per-tensor or per-axis quantized types.
113+
class ConvertGenericOpwithSubChannelType : public ConversionPattern {
114+
public:
115+
ConvertGenericOpwithSubChannelType(TypeConverter &typeConverter,
116+
MLIRContext *context)
117+
: ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {}
118+
119+
LogicalResult
120+
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
121+
ConversionPatternRewriter &rewriter) const final {
122+
SmallVector<Type> resultTypes;
123+
if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
124+
return failure();
125+
126+
auto *newOp = Operation::create(
127+
op->getLoc(), op->getName(), resultTypes, operands, op->getAttrs(),
128+
op->getPropertiesStorage(), op->getSuccessors(), op->getNumRegions());
129+
for (auto regions : llvm::zip(op->getRegions(), newOp->getRegions())) {
130+
Region &before = std::get<0>(regions);
131+
Region &parent = std::get<1>(regions);
132+
rewriter.inlineRegionBefore(before, parent, parent.end());
133+
if (failed(rewriter.convertRegionTypes(&parent, *typeConverter)))
134+
return failure();
135+
}
136+
rewriter.insert(newOp);
137+
rewriter.replaceOp(op, newOp->getResults());
138+
return success();
139+
}
140+
};
141+
142+
// Conversion pass
143+
class NormalizeQuantTypes
144+
: public impl::NormalizeQuantTypesBase<NormalizeQuantTypes> {
145+
public:
146+
void runOnOperation() override {
147+
148+
auto moduleOp = cast<ModuleOp>(getOperation());
149+
auto *context = &getContext();
150+
151+
NormalizedQuantTypesConverter typeConverter;
152+
ConversionTarget target(*context);
153+
154+
// Determine legal operations.
155+
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
156+
return typeConverter.isSignatureLegal(op.getFunctionType()) &&
157+
typeConverter.isLegal(&op.getBody());
158+
});
159+
target.markUnknownOpDynamicallyLegal([&](Operation *op) {
160+
return typeConverter.isLegal(op->getOperandTypes()) &&
161+
typeConverter.isLegal(op->getResultTypes());
162+
});
163+
164+
// Register conversion patterns
165+
RewritePatternSet patterns(context);
166+
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
167+
patterns, typeConverter);
168+
patterns.add<ConvertGenericOpwithSubChannelType>(typeConverter, context);
169+
170+
// Apply conversion
171+
if (failed(applyFullConversion(moduleOp, target, std::move(patterns))))
172+
signalPassFailure();
173+
}
174+
};
175+
176+
} // namespace
177+
178+
} // namespace quant
179+
} // namespace mlir
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
// RUN: mlir-opt %s --normalize-quant-types --split-input-file | FileCheck %s

// Case 1: sub-channel types whose scales tensor holds a single element are
// normalized to per-tensor quantized types, both in function signatures
// (callee declaration and caller definition) and in operation types.

// CHECK-LABEL: @callee(
// CHECK-SAME: [[PER_TENSOR:tensor<\?x\?x!quant.uniform<i8:f32, 2.000000e\+00:127>>]],
// CHECK-SAME: [[PER_TENSOR]]
// CHECK-SAME: ([[PER_TENSOR]], [[PER_TENSOR]])
// CHECK-LABEL: @normalize_quant_types_to_per_tensor
// CHECK-SAME: %[[ARG_0:.*]]: [[PER_TENSOR:tensor<\?x\?x!quant.uniform<i8:f32, 2.000000e\+00:127>>]],
// CHECK-SAME: %[[ARG_1:.*]]: [[PER_TENSOR]]
// CHECK-SAME: ([[PER_TENSOR]], [[PER_TENSOR]])
// CHECK: %[[TEMP_0:.*]] = "test.custom_op"(%[[ARG_0]]) : ([[PER_TENSOR]]) -> [[PER_TENSOR]]
// CHECK: %[[TEMP_1:.*]] = "test.custom_op"(%[[ARG_1]]) : ([[PER_TENSOR]]) -> [[PER_TENSOR]]
// CHECK: %[[TEMP_3:.*]]:2 = call @callee(%[[TEMP_0]], %[[TEMP_1]])
// CHECK: return %[[TEMP_3]]#0, %[[TEMP_3]]#1 : [[PER_TENSOR]], [[PER_TENSOR]]

!qalias1 = !quant.uniform<i8:f32:{}, {{2.0:127}}>
!qalias2 = !quant.uniform<i8:f32:{0:1,1:4}, {{2.0:127}}>

func.func private @callee(tensor<?x?x!qalias1>, tensor<?x?x!qalias2>) -> (tensor<?x?x!qalias1>, tensor<?x?x!qalias2>)

func.func @normalize_quant_types_to_per_tensor(%arg0: tensor<?x?x!qalias1>,
    %arg1: tensor<?x?x!qalias2>) -> (tensor<?x?x!qalias1>, tensor<?x?x!qalias2>) {
  %0 = "test.custom_op"(%arg0) : (tensor<?x?x!qalias1>) -> tensor<?x?x!qalias1>
  %1 = "test.custom_op"(%arg1) : (tensor<?x?x!qalias2>) -> tensor<?x?x!qalias2>
  %3:2 = func.call @callee(%0, %1) : (tensor<?x?x!qalias1>, tensor<?x?x!qalias2>) -> (tensor<?x?x!qalias1>, tensor<?x?x!qalias2>)
  return %3#0, %3#1 : tensor<?x?x!qalias1>, tensor<?x?x!qalias2>
}

// -----

// Case 2: sub-channel types whose scales tensor has exactly one non-unit
// dimension are normalized to per-axis quantized types along that dimension.

// CHECK-LABEL: @normalize_quant_types_to_per_axis
// CHECK-SAME: %[[ARG_0:.*]]: [[PER_AXIS:tensor<\?x\?x!quant.uniform<i8:f32:0, \{2.000000e\+00:127,3.000000e\+00:127\}>>]],
// CHECK-SAME: %[[ARG_1:.*]]: [[PER_AXIS]]
// CHECK-SAME: ([[PER_AXIS]], [[PER_AXIS]])
// CHECK: %[[TEMP_0:.*]] = "test.custom_op"(%[[ARG_0]]) : ([[PER_AXIS]]) -> [[PER_AXIS]]
// CHECK: %[[TEMP_1:.*]] = "test.custom_op"(%[[ARG_1]]) : ([[PER_AXIS]]) -> [[PER_AXIS]]
// CHECK: %[[TEMP_3:.*]]:2 = call @callee(%[[TEMP_0]], %[[TEMP_1]])
// CHECK: return %[[TEMP_3]]#0, %[[TEMP_3]]#1 : [[PER_AXIS]], [[PER_AXIS]]

!qalias1 = !quant.uniform<i8:f32:{0:1}, {{2.0:127}, {3.0:127}}>
!qalias2 = !quant.uniform<i8:f32:{0:1,1:4}, {{2.0:127}, {3.0:127}}>

func.func private @callee(tensor<?x?x!qalias1>, tensor<?x?x!qalias2>) -> (tensor<?x?x!qalias1>, tensor<?x?x!qalias2>)

func.func @normalize_quant_types_to_per_axis(%arg0: tensor<?x?x!qalias1>,
    %arg1: tensor<?x?x!qalias2>) -> (tensor<?x?x!qalias1>, tensor<?x?x!qalias2>) {
  %0 = "test.custom_op"(%arg0) : (tensor<?x?x!qalias1>) -> tensor<?x?x!qalias1>
  %1 = "test.custom_op"(%arg1) : (tensor<?x?x!qalias2>) -> tensor<?x?x!qalias2>
  %3:2 = func.call @callee(%0, %1) : (tensor<?x?x!qalias1>, tensor<?x?x!qalias2>) -> (tensor<?x?x!qalias1>, tensor<?x?x!qalias2>)
  return %3#0, %3#1 : tensor<?x?x!qalias1>, tensor<?x?x!qalias2>
}

0 commit comments

Comments
 (0)