Skip to content

Commit b09432f

Browse files
committed
address review comments I
1 parent 6d27e45 commit b09432f

File tree

8 files changed

+34
-24
lines changed

8 files changed

+34
-24
lines changed

mlir/include/mlir/Dialect/Quant/Transforms/Passes.td

+6-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def LowerQuantOps : Pass<"lower-quant-ops", "func::FuncOp"> {
3131
];
3232
}
3333

34-
def NormalizeQuantTypes : Pass<"normalize-quant-types"> {
34+
def NormalizeQuantTypes : Pass<"normalize-quant-types", "func::FuncOp"> {
3535
let summary = "Normalize generic quantized types to specific quantized types";
3636
let description = [{
3737
This pass converts generic quantized types in the `quant` dialect to more
@@ -60,6 +60,11 @@ def NormalizeQuantTypes : Pass<"normalize-quant-types"> {
6060
-> `!quant.uniform<i8:f32, 2.0>`
6161
* `tensor<?x?x!quant.uniform<i8:f32:{0:1, 0:4}, {{2.0}}>>`
6262
-> `tensor<?x?x!quant.uniform<i8:f32, 2.0>>`
63+
64+
The rationale for these conversions is that the decompositions / handling of
65+
more precise quantized types tends to be more efficient than treating
66+
everything as subchannel type.
67+
6368
}];
6469
let dependentDialects = ["func::FuncDialect", "quant::QuantDialect"];
6570
}

mlir/lib/Dialect/Quant/IR/QuantOps.cpp

+13-6
Original file line numberDiff line numberDiff line change
@@ -167,16 +167,19 @@ LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
167167
return op->emitError(
168168
"expressed type in quantized type expected to match float type");
169169

170-
// Veriy integrity of per-axis quantization information, if present.
170+
// Verify integrity of per-axis quantization information, if present.
171171
if (auto quantizedPerAxisType =
172172
dyn_cast<UniformQuantizedPerAxisType>(quantizedType)) {
173173
return verifyPerAxisQuantization(op, quantizedPerAxisType, containerType);
174-
} else if (auto quantizedSubChannelType =
175-
dyn_cast<UniformQuantizedSubChannelType>(quantizedType)) {
174+
}
175+
176+
if (auto quantizedSubChannelType =
177+
dyn_cast<UniformQuantizedSubChannelType>(quantizedType)) {
176178
return verifySubChannelQuantization(op, quantizedSubChannelType,
177179
containerType);
178180
}
179181

182+
// At this point the type is UniformQuantizedType
180183
return success();
181184
}
182185

@@ -268,14 +271,18 @@ LogicalResult StorageCastOp::verify() {
268271
// the quantization type may appear in the input or the result, their tensor
269272
// shapes are guaranteed to be identical at this point.
270273
if (auto quantizedPerAxisType =
271-
dyn_cast<UniformQuantizedPerAxisType>(quantizedType))
274+
dyn_cast<UniformQuantizedPerAxisType>(quantizedType)) {
272275
return verifyPerAxisQuantization(*this, quantizedPerAxisType,
273276
getInput().getType());
274-
else if (auto quantizedSunChannelType =
275-
dyn_cast<UniformQuantizedSubChannelType>(quantizedType))
277+
}
278+
279+
if (auto quantizedSunChannelType =
280+
dyn_cast<UniformQuantizedSubChannelType>(quantizedType)) {
276281
return verifySubChannelQuantization(*this, quantizedSunChannelType,
277282
getInput().getType());
283+
}
278284

285+
// At this point the type is UniformQuantizedType
279286
return success();
280287
}
281288

mlir/lib/Dialect/Quant/IR/TypeParser.cpp

+6-8
Original file line numberDiff line numberDiff line change
@@ -518,12 +518,10 @@ static void
518518
printBlockSizeInfo(ArrayRef<std::pair<int32_t, int64_t>> blockSizeInfo,
519519
DialectAsmPrinter &out) {
520520
out << "{";
521-
llvm::interleave(
522-
llvm::seq<size_t>(0, blockSizeInfo.size()), out,
523-
[&](size_t index) {
521+
llvm::interleaveComma(
522+
llvm::seq<size_t>(0, blockSizeInfo.size()), out, [&](size_t index) {
524523
out << blockSizeInfo[index].first << ":" << blockSizeInfo[index].second;
525-
},
526-
",");
524+
});
527525
out << "}";
528526
}
529527

@@ -593,7 +591,7 @@ void printDenseQuantizationParameters(ArrayRef<APFloat> scales,
593591
SmallVector<unsigned, 4> counter(rank, 0);
594592
unsigned openBrackets = 0;
595593

596-
auto bumpCounter = [&]() {
594+
auto incrementCounterAndDelimit = [&]() {
597595
++counter[rank - 1];
598596
for (unsigned i = rank - 1; i > 0; --i) {
599597
if (counter[i] >= shape[i]) {
@@ -605,7 +603,7 @@ void printDenseQuantizationParameters(ArrayRef<APFloat> scales,
605603
}
606604
};
607605

608-
for (unsigned idx = 0, e = scales.size(); idx != e; ++idx) {
606+
for (unsigned idx = 0, e = scales.size(); idx < e; ++idx) {
609607
if (idx != 0)
610608
out << ", ";
611609
while (openBrackets++ < rank)
@@ -615,7 +613,7 @@ void printDenseQuantizationParameters(ArrayRef<APFloat> scales,
615613
if (zeroPoints[idx] != 0) {
616614
out << ":" << zeroPoints[idx];
617615
}
618-
bumpCounter();
616+
incrementCounterAndDelimit();
619617
}
620618
while (openBrackets-- > 0)
621619
out << '}';

mlir/lib/Dialect/Quant/Transforms/NormalizeQuantTypes.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,6 @@ class NormalizeQuantTypes
145145
public:
146146
void runOnOperation() override {
147147

148-
auto moduleOp = cast<ModuleOp>(getOperation());
149148
auto *context = &getContext();
150149

151150
NormalizedQuantTypesConverter typeConverter;
@@ -168,7 +167,8 @@ class NormalizeQuantTypes
168167
patterns.add<ConvertGenericOpwithSubChannelType>(typeConverter, context);
169168

170169
// Apply conversion
171-
if (failed(applyFullConversion(moduleOp, target, std::move(patterns))))
170+
if (failed(
171+
applyFullConversion(getOperation(), target, std::move(patterns))))
172172
signalPassFailure();
173173
}
174174
};

mlir/test/CAPI/quant.c

+2-2
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ void testUniformSubChannelType(MlirContext ctx) {
210210

211211
MlirType subChannelParsed =
212212
mlirTypeParseGet(ctx, mlirStringRefCreateFromCString(
213-
"!quant.uniform<i8:f32:{0:1,1:2}, "
213+
"!quant.uniform<i8:f32:{0:1, 1:2}, "
214214
"{{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>"));
215215

216216
MlirType i8 = mlirIntegerTypeGet(ctx, 8);
@@ -321,7 +321,7 @@ void testUniformSubChannelType(MlirContext ctx) {
321321
// CHECK: equal: 1
322322
fprintf(stderr, "equal: %d\n", mlirTypeEqual(subChannel, subChannelParsed));
323323

324-
// CHECK: !quant.uniform<i8:f32:{0:1,1:2},
324+
// CHECK: !quant.uniform<i8:f32:{0:1, 1:2},
325325
// {{.*}}2.000000e+00:10, 3.000000e+00:20},
326326
// {4.000000e+00:30, 5.000000e+00:40{{.*}}}}>
327327
mlirTypeDump(subChannel);

mlir/test/Dialect/Quant/Bytecode/types.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,6 @@ module @parseUniformPerAxisMixed attributes {
7070

7171
// CHECK-LABEL: parseUniformSubChannel
7272
module @parseUniformSubChannel attributes {
73-
// CHECK: !quant.uniform<i8:f32:{0:1,1:2}, {{\{}}{2.000000e+00:10, 3.000000e+00:20}, {4.000000e+00:30, 5.000000e+00:40}}>
73+
// CHECK: !quant.uniform<i8:f32:{0:1, 1:2}, {{\{}}{2.000000e+00:10, 3.000000e+00:20}, {4.000000e+00:30, 5.000000e+00:40}}>
7474
bytecode.test = !quant.uniform<i8:f32:{0:1, 1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>
7575
} {}

mlir/test/Dialect/Quant/lower-quant-ops.mlir

+2-2
Original file line numberDiff line numberDiff line change
@@ -535,7 +535,7 @@ func.func @qcast_per_channel_unranked(%arg0: tensor<*xf32>) -> tensor<*x!qalias>
535535
// CHECK: linalg.yield %[[STORED_INT]] : i8
536536
// CHECK: } -> tensor<2x?x?x4xi8>
537537

538-
// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[GENERIC]] : tensor<2x?x?x4xi8> to tensor<2x?x?x4x!quant.uniform<i8:f32:{0:1,3:2}, {{.*}}2.000000e+00:10, 3.000000e+00:20{{.*}}, {{.*}}4.000000e+00:30, 5.000000e+00:40{{.*}}>>
538+
// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[GENERIC]] : tensor<2x?x?x4xi8> to tensor<2x?x?x4x!quant.uniform<i8:f32:{0:1, 3:2}, {{.*}}2.000000e+00:10, 3.000000e+00:20{{.*}}, {{.*}}4.000000e+00:30, 5.000000e+00:40{{.*}}>>
539539
// CHECK: return %[[STORED_QUANT]]
540540

541541
!qalias = !quant.uniform<i8:f32:{0:1, 3:2}, {{{{2.0:10, 3.0:20}}}, {{{4.0:30, 5.0:40}}}}>
@@ -565,7 +565,7 @@ func.func @qcast_sub_channel_ranked(%arg0: tensor<2x?x?x4xf32>) -> tensor<2x?x?x
565565
// CHECK: linalg.yield %[[STORED_INT]] : i8
566566
// CHECK: } -> tensor<2x3x5x4xi8>
567567

568-
// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[GENERIC]] : tensor<2x3x5x4xi8> to tensor<2x3x5x4x!quant.uniform<i8:f32:{0:1,3:2}, {{.*}}2.000000e+00:10, 3.000000e+00:20{{.*}}, {{.*}}4.000000e+00:30, 5.000000e+00:40{{.*}}>>
568+
// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[GENERIC]] : tensor<2x3x5x4xi8> to tensor<2x3x5x4x!quant.uniform<i8:f32:{0:1, 3:2}, {{.*}}2.000000e+00:10, 3.000000e+00:20{{.*}}, {{.*}}4.000000e+00:30, 5.000000e+00:40{{.*}}>>
569569
// CHECK: return %[[STORED_QUANT]]
570570

571571
!qalias = !quant.uniform<i8:f32:{0:1, 3:2}, {{{{2.0:10, 3.0:20}}}, {{{4.0:30, 5.0:40}}}}>

mlir/test/Dialect/Quant/parse-uniform.mlir

+2-2
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,8 @@ func.func @parse() -> !qalias {
157157

158158
// -----
159159
// Sub-channel scales and zero points (mixed affine and fixedpoint)
160-
// CHECK: !quant.uniform<u8:f32:{0:1,1:2}, {{\{}}{2.000000e+00:120, 3.000000e+00:127}, {4.000000e+00, 5.000000e+00}}>
161-
!qalias = !quant.uniform<u8:f32:{0:1,1:2}, {{2.0:120,3.0:127}, {4.0,5.0}}>
160+
// CHECK: !quant.uniform<u8:f32:{0:1, 1:2}, {{\{}}{2.000000e+00:120, 3.000000e+00:127}, {4.000000e+00, 5.000000e+00}}>
161+
!qalias = !quant.uniform<u8:f32:{0:1, 1:2}, {{2.0:120,3.0:127}, {4.0,5.0}}>
162162
func.func @parse() -> !qalias {
163163
%0 = "foo"() : () -> !qalias
164164
return %0 : !qalias

0 commit comments

Comments
 (0)