Skip to content

Commit f10302e

Browse files
[mlir] Require folders to produce Values of same type (#75887)
This commit adds extra assertions to `OperationFolder` and `OpBuilder` to ensure that the types of the folded SSA values match with the result types of the op. There used to be checks that discard the folded results if the types do not match. This commit makes these checks stricter and turns them into assertions. Discarding folded results with the wrong type (without failing explicitly) can hide bugs in op folders. Two such bugs became apparent in MLIR (and some more in downstream projects) and are fixed with this change. Note: The existing type checks were introduced in https://reviews.llvm.org/D95991. Migration guide: If you see failing assertions (`folder produced value of incorrect type`; make sure to run with assertions enabled!), run with `-debug` or dump the operation right before the failing assertion. This will point you to the op that has the broken folder. A common mistake is a mismatch between static/dynamic dimensions (e.g., input has a static dimension but folded result has a dynamic dimension).
1 parent 560564f commit f10302e

File tree

11 files changed

+36
-48
lines changed

11 files changed

+36
-48
lines changed

flang/lib/Optimizer/Dialect/FIROps.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -625,11 +625,13 @@ void fir::BoxAddrOp::build(mlir::OpBuilder &builder,
625625
mlir::OpFoldResult fir::BoxAddrOp::fold(FoldAdaptor adaptor) {
626626
if (auto *v = getVal().getDefiningOp()) {
627627
if (auto box = mlir::dyn_cast<fir::EmboxOp>(v)) {
628-
if (!box.getSlice()) // Fold only if not sliced
628+
// Fold only if not sliced
629+
if (!box.getSlice() && box.getMemref().getType() == getType())
629630
return box.getMemref();
630631
}
631632
if (auto box = mlir::dyn_cast<fir::EmboxCharOp>(v))
632-
return box.getMemref();
633+
if (box.getMemref().getType() == getType())
634+
return box.getMemref();
633635
}
634636
return {};
635637
}

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -1352,9 +1352,11 @@ OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
13521352
setOperand(src);
13531353
return getResult();
13541354
}
1355+
13551356
// trunci(zexti(a)) -> a
13561357
// trunci(sexti(a)) -> a
1357-
return src;
1358+
if (srcType == dstType)
1359+
return src;
13581360
}
13591361

13601362
// trunci(trunci(a)) -> trunci(a))

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -771,6 +771,8 @@ OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
771771
ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
772772
if (!inputTy.hasRank()) \
773773
return {}; \
774+
if (inputTy != getType()) \
775+
return {}; \
774776
if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1) \
775777
return getInput(); \
776778
return {}; \

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -1602,9 +1602,10 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
16021602
return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
16031603
: 0;
16041604
};
1605+
16051606
// If splat or broadcast from a scalar, just return the source scalar.
16061607
unsigned broadcastSrcRank = getRank(source.getType());
1607-
if (broadcastSrcRank == 0)
1608+
if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
16081609
return source;
16091610

16101611
unsigned extractResultRank = getRank(extractOp.getType());

mlir/lib/IR/Builders.cpp

+1-4
Original file line numberDiff line numberDiff line change
@@ -486,14 +486,11 @@ LogicalResult OpBuilder::tryFold(Operation *op,
486486

487487
// Populate the results with the folded results.
488488
Dialect *dialect = op->getDialect();
489-
for (auto it : llvm::zip(foldResults, opResults.getTypes())) {
489+
for (auto it : llvm::zip_equal(foldResults, opResults.getTypes())) {
490490
Type expectedType = std::get<1>(it);
491491

492492
// Normal values get pushed back directly.
493493
if (auto value = llvm::dyn_cast_if_present<Value>(std::get<0>(it))) {
494-
if (value.getType() != expectedType)
495-
return cleanupFailure();
496-
497494
results.push_back(value);
498495
continue;
499496
}

mlir/lib/IR/Operation.cpp

+24-2
Original file line numberDiff line numberDiff line change
@@ -606,13 +606,30 @@ void Operation::setSuccessor(Block *block, unsigned index) {
606606
getBlockOperands()[index].set(block);
607607
}
608608

609+
#ifndef NDEBUG
610+
/// Assert that the folded results (in case of values) have the same type as
611+
/// the results of the given op.
612+
static void checkFoldResultTypes(Operation *op,
613+
SmallVectorImpl<OpFoldResult> &results) {
614+
if (!results.empty())
615+
for (auto [ofr, opResult] : llvm::zip_equal(results, op->getResults()))
616+
if (auto value = ofr.dyn_cast<Value>())
617+
assert(value.getType() == opResult.getType() &&
618+
"folder produced value of incorrect type");
619+
}
620+
#endif // NDEBUG
621+
609622
/// Attempt to fold this operation using the Op's registered foldHook.
610623
LogicalResult Operation::fold(ArrayRef<Attribute> operands,
611624
SmallVectorImpl<OpFoldResult> &results) {
612625
// If we have a registered operation definition matching this one, use it to
613626
// try to constant fold the operation.
614-
if (succeeded(name.foldHook(this, operands, results)))
627+
if (succeeded(name.foldHook(this, operands, results))) {
628+
#ifndef NDEBUG
629+
checkFoldResultTypes(this, results);
630+
#endif // NDEBUG
615631
return success();
632+
}
616633

617634
// Otherwise, fall back on the dialect hook to handle it.
618635
Dialect *dialect = getDialect();
@@ -623,7 +640,12 @@ LogicalResult Operation::fold(ArrayRef<Attribute> operands,
623640
if (!interface)
624641
return failure();
625642

626-
return interface->fold(this, operands, results);
643+
LogicalResult status = interface->fold(this, operands, results);
644+
#ifndef NDEBUG
645+
if (succeeded(status))
646+
checkFoldResultTypes(this, results);
647+
#endif // NDEBUG
648+
return status;
627649
}
628650

629651
LogicalResult Operation::fold(SmallVectorImpl<OpFoldResult> &results) {

mlir/lib/Transforms/Utils/FoldUtils.cpp

-4
Original file line numberDiff line numberDiff line change
@@ -247,10 +247,6 @@ OperationFolder::processFoldResults(Operation *op,
247247

248248
// Check if the result was an SSA value.
249249
if (auto repl = llvm::dyn_cast_if_present<Value>(foldResults[i])) {
250-
if (repl.getType() != op->getResult(i).getType()) {
251-
results.clear();
252-
return failure();
253-
}
254250
results.emplace_back(repl);
255251
continue;
256252
}

mlir/test/Transforms/test-canonicalize.mlir

-13
Original file line numberDiff line numberDiff line change
@@ -70,19 +70,6 @@ func.func @test_commutative_multi_cst(%arg0: i32, %arg1: i32) -> (i32, i32) {
7070
return %y, %z: i32, i32
7171
}
7272

73-
// CHECK-LABEL: func @typemismatch
74-
75-
func.func @typemismatch() -> i32 {
76-
%c42 = arith.constant 42.0 : f32
77-
78-
// The "passthrough_fold" folder will naively return its operand, but we don't
79-
// want to fold here because of the type mismatch.
80-
81-
// CHECK: "test.passthrough_fold"
82-
%0 = "test.passthrough_fold"(%c42) : (f32) -> (i32)
83-
return %0 : i32
84-
}
85-
8673
// CHECK-LABEL: test_dialect_canonicalizer
8774
func.func @test_dialect_canonicalizer() -> (i32) {
8875
%0 = "test.dialect_canonicalizable"() : () -> (i32)

mlir/test/Transforms/test-legalizer.mlir

-10
Original file line numberDiff line numberDiff line change
@@ -310,16 +310,6 @@ builtin.module {
310310

311311
// -----
312312

313-
// The "passthrough_fold" folder will naively return its operand, but we don't
314-
// want to fold here because of the type mismatch.
315-
func.func @typemismatch(%arg: f32) -> i32 {
316-
// expected-remark@+1 {{op 'test.passthrough_fold' is not legalizable}}
317-
%0 = "test.passthrough_fold"(%arg) : (f32) -> (i32)
318-
"test.return"(%0) : (i32) -> ()
319-
}
320-
321-
// -----
322-
323313
// expected-remark @below {{applyPartialConversion failed}}
324314
module {
325315
func.func private @callee(%0 : f32) -> f32

mlir/test/lib/Dialect/Test/TestDialect.cpp

-4
Original file line numberDiff line numberDiff line change
@@ -542,10 +542,6 @@ OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) {
542542
return {};
543543
}
544544

545-
OpFoldResult TestPassthroughFold::fold(FoldAdaptor adaptor) {
546-
return getOperand();
547-
}
548-
549545
OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) {
550546
int64_t sum = 0;
551547
if (auto value = dyn_cast_or_null<IntegerAttr>(adaptor.getOp()))

mlir/test/lib/Dialect/Test/TestOps.td

-7
Original file line numberDiff line numberDiff line change
@@ -1363,13 +1363,6 @@ def TestOpFoldWithFoldAdaptor
13631363
let hasFolder = 1;
13641364
}
13651365

1366-
// An op that always fold itself.
1367-
def TestPassthroughFold : TEST_Op<"passthrough_fold"> {
1368-
let arguments = (ins AnyType:$op);
1369-
let results = (outs AnyType);
1370-
let hasFolder = 1;
1371-
}
1372-
13731366
def TestDialectCanonicalizerOp : TEST_Op<"dialect_canonicalizable"> {
13741367
let arguments = (ins);
13751368
let results = (outs I32);

0 commit comments

Comments
 (0)