31 changes: 23 additions & 8 deletions clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp
@@ -244,12 +244,24 @@ mlir::Value SCFLoop::findIVInitValue() {
auto remapAddr = rewriter->getRemappedValue(ivAddr);
if (!remapAddr)
return nullptr;
if (!remapAddr.hasOneUse())
return nullptr;
auto memrefStore = dyn_cast<mlir::memref::StoreOp>(*remapAddr.user_begin());
if (!memrefStore)
return nullptr;
return memrefStore->getOperand(0);
if (auto castOp =
mlir::dyn_cast<mlir::memref::CastOp>(remapAddr.getDefiningOp())) {
remapAddr = castOp->getOperand(0);
if (!remapAddr)
return nullptr;
    // The alloca has two uses: one is the CastOp and the second is the
    // StoreOp (which bypasses the CastOp).
if (remapAddr.getNumUses() > 2)
return nullptr;
} else {
if (!remapAddr.hasOneUse())
return nullptr;
}
for (auto user : remapAddr.getUsers()) {
if (auto memrefStore = dyn_cast<mlir::memref::StoreOp>(user))
return memrefStore->getOperand(0);
}
return nullptr;
}

void SCFLoop::analysis() {
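Note: with scalar allocas now lowered through a memref.cast (see LowerCIRToMLIR.cpp below), the remapped induction-variable address may be produced by a cast. A sketch of the IR pattern this function now matches; value names are illustrative, not from the patch:

    %iv   = memref.alloca() : memref<1xi32>                  // two uses: the cast and the store
    %cast = memref.cast %iv : memref<1xi32> to memref<?xi32>
    memref.store %init, %iv[%c0] : memref<1xi32>             // the store bypasses the cast

findIVInitValue then returns %init, the value operand of the store.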
@@ -340,10 +352,13 @@ void SCFLoop::transferToSCFForOp() {
// The operations before the loop have been transferred to MLIR.
// So we need to go through getRemappedValue to find the operations.
auto remapAddr = rewriter->getRemappedValue(ivAddr);

if (auto castOp =
mlir::dyn_cast<mlir::memref::CastOp>(remapAddr.getDefiningOp()))
remapAddr = castOp->getOperand(0);
// Since this is a canonical loop we can remove the alloca + initial store op
rewriter->eraseOp(remapAddr.getDefiningOp());
rewriter->eraseOp(*remapAddr.user_begin());
for (auto user : remapAddr.getUsers())
rewriter->eraseOp(user);
}

void SCFLoop::transformToSCFWhileOp() {
183 changes: 125 additions & 58 deletions clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp
@@ -306,11 +306,22 @@ class CIRAllocaOpLowering : public mlir::OpConversionPattern<cir::AllocaOp> {
if (mlir::isa<cir::ArrayType>(adaptor.getAllocaType())) {
if (!memreftype)
return mlir::LogicalResult::failure();
rewriter.replaceOpWithNewOp<mlir::memref::AllocaOp>(
op, memreftype, op.getAlignmentAttr());
} else {
memreftype = mlir::MemRefType::get({}, mlirType);
memreftype = mlir::MemRefType::get({1}, mlirType);
auto allocaOp = mlir::memref::AllocaOp::create(
rewriter, op.getLoc(), memreftype, op.getAlignmentAttr());
      // Cast from memref<1xMlirType> to memref<?xMlirType>. This is needed
      // because the type converter produces memref<?xMlirType> for non-array
      // cir.ptr types; the cast will be eliminated later in load/store
      // lowering.
auto targetType =
mlir::MemRefType::get({mlir::ShapedType::kDynamic}, mlirType);
auto castOp = mlir::memref::CastOp::create(rewriter, op.getLoc(),
targetType, allocaOp);
rewriter.replaceOp(op, castOp);
}
rewriter.replaceOpWithNewOp<mlir::memref::AllocaOp>(op, memreftype,
op.getAlignmentAttr());
return mlir::LogicalResult::success();
}
};
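Note: for scalar allocas the pattern now emits a 1-element memref and immediately casts it to the dynamically shaped type the converter assigns to cir.ptr. Roughly, for a 4-byte int (illustrative):

    %0 = memref.alloca() {alignment = 4 : i64} : memref<1xi32>
    %1 = memref.cast %0 : memref<1xi32> to memref<?xi32>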
@@ -327,9 +338,29 @@ static bool findBaseAndIndices(mlir::Value addr, mlir::Value &base,
addr = addrOp->getOperand(0);
eraseList.push_back(addrOp);
}
if (auto castOp = addr.getDefiningOp<mlir::memref::CastOp>()) {
auto castInput = castOp->getOperand(0);
if (castInput.getDefiningOp<mlir::memref::AllocaOp>() ||
castInput.getDefiningOp<mlir::memref::GetGlobalOp>()) {
      // The AllocaOp and GetGlobalOp lowerings produce 1-element memrefs.
indices.push_back(
mlir::arith::ConstantIndexOp::create(rewriter, castOp.getLoc(), 0));
addr = castInput;
eraseList.push_back(castOp);
}
}
base = addr;
if (indices.size() == 0)
if (indices.size() == 0) {
auto memrefType = mlir::cast<mlir::MemRefType>(base.getType());
auto rank = memrefType.getRank();
indices.reserve(rank);
for (unsigned d = 0; d < rank; ++d) {
mlir::Value zero = mlir::arith::ConstantIndexOp::create(
rewriter, base.getLoc(), /*value=*/0);
indices.push_back(zero);
}
return false;
}
std::reverse(indices.begin(), indices.end());
return true;
}
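Note: the helper now looks through a memref.cast whose input is an alloca or get_global, materializing a zero index and queueing the cast for erasure; when no indices were collected, it returns one zero index per dimension of the base and reports that nothing needs erasing. Sketch of the cast case (names illustrative):

    %a = memref.alloca() : memref<1xi32>
    %c = memref.cast %a : memref<1xi32> to memref<?xi32>
    // findBaseAndIndices(%c, ...) yields base = %a, indices = [%c0],
    // and pushes %c onto eraseList.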
@@ -350,25 +381,33 @@ static void eraseIfSafe(mlir::Value oldAddr, mlir::Value newAddr,
for (auto *user : newAddr.getUsers()) {
if (auto loadOpUser = mlir::dyn_cast_or_null<mlir::memref::LoadOp>(*user)) {
if (!loadOpUser.getIndices().empty()) {
auto strideVal = loadOpUser.getIndices()[0];
if (strideVal ==
mlir::dyn_cast<mlir::memref::ReinterpretCastOp>(eraseList.back())
.getOffsets()[0])
if (auto reinterpretOp =
mlir::dyn_cast<mlir::memref::ReinterpretCastOp>(
eraseList.back())) {
auto strideVal = loadOpUser.getIndices()[0];
if (strideVal == reinterpretOp.getOffsets()[0])
++newUsedNum;
} else if (auto castOp =
mlir::dyn_cast<mlir::memref::CastOp>(eraseList.back()))
++newUsedNum;
}
} else if (auto storeOpUser =
mlir::dyn_cast_or_null<mlir::memref::StoreOp>(*user)) {
if (!storeOpUser.getIndices().empty()) {
auto strideVal = storeOpUser.getIndices()[0];
if (strideVal ==
mlir::dyn_cast<mlir::memref::ReinterpretCastOp>(eraseList.back())
.getOffsets()[0])
if (auto reinterpretOp =
mlir::dyn_cast<mlir::memref::ReinterpretCastOp>(
eraseList.back())) {
auto strideVal = storeOpUser.getIndices()[0];
if (strideVal == reinterpretOp.getOffsets()[0])
++newUsedNum;
} else if (auto castOp =
mlir::dyn_cast<mlir::memref::CastOp>(eraseList.back()))
++newUsedNum;
}
}
}
// If all load/store ops using forwarded offsets from the current
// memref.reinterpret_cast ops erase the memref.reinterpret_cast ops
// If all load/store ops are using forwarded offsets from the current
// memref.(reinterpret_)cast ops, erase them
if (oldUsedNum == newUsedNum) {
for (auto op : eraseList)
rewriter.eraseOp(op);
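Note: eraseIfSafe now accepts either a memref.reinterpret_cast or a plain memref.cast at the back of the erase list; every rewritten load/store that indexes the new base counts toward newUsedNum, e.g. (illustrative):

    %base = memref.alloca() : memref<1xi32>
    %cast = memref.cast %base : memref<1xi32> to memref<?xi32>  // eraseList.back()
    %v    = memref.load %base[%c0] : memref<1xi32>              // counted in newUsedNum

Once newUsedNum matches the old address's use count, the cast chain is erased.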
@@ -385,10 +424,6 @@ prepareReinterpretMetadata(mlir::MemRefType type,
strides.clear();

for (int64_t dim : type.getShape()) {
if (mlir::ShapedType::isDynamic(dim)) {
anchorOp->emitError("dynamic memref sizes are not supported yet");
return mlir::failure();
}
sizes.push_back(rewriter.getIndexAttr(dim));
}

@@ -421,15 +456,12 @@ class CIRLoadOpLowering : public mlir::OpConversionPattern<cir::LoadOp> {
SmallVector<mlir::Value> indices;
SmallVector<mlir::Operation *> eraseList;
mlir::memref::LoadOp newLoad;
if (findBaseAndIndices(adaptor.getAddr(), base, indices, eraseList,
rewriter)) {
newLoad = mlir::memref::LoadOp::create(rewriter, op.getLoc(), base,
indices, op.getIsNontemporal());
bool eraseIntermediateOp = findBaseAndIndices(adaptor.getAddr(), base,
indices, eraseList, rewriter);
newLoad = mlir::memref::LoadOp::create(rewriter, op.getLoc(), base, indices,
op.getIsNontemporal());
if (eraseIntermediateOp)
eraseIfSafe(op.getAddr(), base, eraseList, rewriter);
} else
newLoad = mlir::memref::LoadOp::create(
rewriter, op.getLoc(), adaptor.getAddr(), mlir::ValueRange{},
op.getIsNontemporal());

// Convert adapted result to its original type if needed.
mlir::Value result = emitFromMemory(rewriter, op, newLoad.getResult());
Expand All @@ -451,15 +483,13 @@ class CIRStoreOpLowering : public mlir::OpConversionPattern<cir::StoreOp> {

// Convert adapted value to its memory type if needed.
mlir::Value value = emitToMemory(rewriter, op, adaptor.getValue());
if (findBaseAndIndices(adaptor.getAddr(), base, indices, eraseList,
rewriter)) {
rewriter.replaceOpWithNewOp<mlir::memref::StoreOp>(
op, value, base, indices, op.getIsNontemporal());
bool eraseIntermediateOp = findBaseAndIndices(adaptor.getAddr(), base,
indices, eraseList, rewriter);
rewriter.replaceOpWithNewOp<mlir::memref::StoreOp>(op, value, base, indices,
op.getIsNontemporal());
if (eraseIntermediateOp)
eraseIfSafe(op.getAddr(), base, eraseList, rewriter);
} else
rewriter.replaceOpWithNewOp<mlir::memref::StoreOp>(
op, value, adaptor.getAddr(), mlir::ValueRange{},
op.getIsNontemporal());

return mlir::LogicalResult::success();
}
};
@@ -1157,7 +1187,7 @@ class CIRGlobalOpLowering : public mlir::OpConversionPattern<cir::GlobalOp> {
return mlir::failure();
auto memrefType = mlir::dyn_cast<mlir::MemRefType>(convertedType);
if (!memrefType)
memrefType = mlir::MemRefType::get({}, convertedType);
memrefType = mlir::MemRefType::get({1}, convertedType);
// Add an optional alignment to the global memref.
mlir::IntegerAttr memrefAlignment =
op.getAlignment()
@@ -1196,7 +1226,7 @@ class CIRGlobalOpLowering : public mlir::OpConversionPattern<cir::GlobalOp> {
} else
initialValue = mlir::Attribute();
} else {
auto rtt = mlir::RankedTensorType::get({}, convertedType);
auto rtt = mlir::RankedTensorType::get({1}, convertedType);
if (mlir::isa<mlir::IntegerType>(convertedType))
initialValue = mlir::DenseIntElementsAttr::get(rtt, 0);
else if (mlir::isa<mlir::FloatType>(convertedType)) {
Expand All @@ -1207,13 +1237,13 @@ class CIRGlobalOpLowering : public mlir::OpConversionPattern<cir::GlobalOp> {
initialValue = mlir::Attribute();
}
} else if (auto intAttr = mlir::dyn_cast<cir::IntAttr>(init.value())) {
auto rtt = mlir::RankedTensorType::get({}, convertedType);
auto rtt = mlir::RankedTensorType::get({1}, convertedType);
initialValue = mlir::DenseIntElementsAttr::get(rtt, intAttr.getValue());
} else if (auto fltAttr = mlir::dyn_cast<cir::FPAttr>(init.value())) {
auto rtt = mlir::RankedTensorType::get({}, convertedType);
auto rtt = mlir::RankedTensorType::get({1}, convertedType);
initialValue = mlir::DenseFPElementsAttr::get(rtt, fltAttr.getValue());
} else if (auto boolAttr = mlir::dyn_cast<cir::BoolAttr>(init.value())) {
auto rtt = mlir::RankedTensorType::get({}, convertedType);
auto rtt = mlir::RankedTensorType::get({1}, convertedType);
initialValue =
mlir::DenseIntElementsAttr::get(rtt, (char)boolAttr.getValue());
} else
@@ -1249,10 +1279,29 @@ class CIRGetGlobalOpLowering
rewriter.eraseOp(op);
return mlir::success();
}
auto globalOpType =
convertTypeForMemory(*getTypeConverter(), op.getType().getPointee());
if (!globalOpType)
return mlir::failure();
auto memrefType = mlir::dyn_cast<mlir::MemRefType>(globalOpType);
if (!memrefType)
memrefType = mlir::MemRefType::get({1}, globalOpType);

auto type = getTypeConverter()->convertType(op.getType());
auto symbol = op.getName();
rewriter.replaceOpWithNewOp<mlir::memref::GetGlobalOp>(op, type, symbol);
auto getGlobalOp = mlir::memref::GetGlobalOp::create(rewriter, op.getLoc(),
memrefType, symbol);

if (isa<cir::ArrayType>(op.getType().getPointee())) {
rewriter.replaceOp(op, getGlobalOp);
} else {
// Cast from memref<1xmlirType> to memref<?xmlirType>. This is needed
      // since the type converter produces memref<?xmlirType> for non-array
      // cir.ptr types. The cast will be eliminated later in load/store
      // lowering.
auto targetType = getTypeConverter()->convertType(op.getType());
auto castOp = mlir::memref::CastOp::create(rewriter, op.getLoc(),
targetType, getGlobalOp);
rewriter.replaceOp(op, castOp);
}
return mlir::success();
}
};
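Note: scalar cir.get_global now mirrors the scalar-alloca path: a typed 1-element get_global followed by a cast to the converter's dynamic memref type. Sketch for a hypothetical i32 global @g:

    %0 = memref.get_global @g : memref<1xi32>
    %1 = memref.cast %0 : memref<1xi32> to memref<?xi32>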
@@ -1630,6 +1679,30 @@ class CIRPtrStrideOpLowering
return getTypeConverter()->convertType(ty);
}

// Rewrite
// cir.ptr_stride(%base, %stride)
// to
// memref.reinterpret_cast (%base, %stride)
//
mlir::LogicalResult rewritePtrStrideToReinterpret(
cir::PtrStrideOp op, mlir::Value base, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
auto ptrType = op.getType();
auto memrefType = llvm::cast<mlir::MemRefType>(convertTy(ptrType));
auto stride = adaptor.getStride();
auto indexType = rewriter.getIndexType();
// Generate casting if the stride is not index type.
if (stride.getType() != indexType)
stride = mlir::arith::IndexCastOp::create(rewriter, op.getLoc(),
indexType, stride);

rewriter.replaceOpWithNewOp<mlir::memref::ReinterpretCastOp>(
op, memrefType, base, stride, mlir::ValueRange{}, mlir::ValueRange{},
llvm::ArrayRef<mlir::NamedAttribute>{});

return mlir::success();
}

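Note: a schematic of the rewrite this helper performs; the result type is taken from the type converter, and the memref types shown here are a guess for a scalar i32 pointer:

    // cir.ptr_stride(%base, %stride) becomes:
    %off = arith.index_cast %stride : i32 to index
    %p   = memref.reinterpret_cast %base to offset: [%off], sizes: [], strides: []
           : memref<?xi32> to memref<i32, strided<[], offset: ?>>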
// Rewrite
// %0 = cir.cast array_to_ptrdecay %base
// cir.ptr_stride(%0, %stride)
@@ -1647,20 +1720,9 @@
if (!baseDefiningOp)
return mlir::failure();

auto base = baseDefiningOp->getOperand(0);
auto ptrType = op.getType();
auto memrefType = llvm::cast<mlir::MemRefType>(convertTy(ptrType));
auto stride = adaptor.getStride();
auto indexType = rewriter.getIndexType();

// Generate casting if the stride is not index type.
if (stride.getType() != indexType)
stride = mlir::arith::IndexCastOp::create(rewriter, op.getLoc(),
indexType, stride);

rewriter.replaceOpWithNewOp<mlir::memref::ReinterpretCastOp>(
op, memrefType, base, stride, mlir::ValueRange{}, mlir::ValueRange{},
llvm::ArrayRef<mlir::NamedAttribute>{});
if (mlir::failed(rewritePtrStrideToReinterpret(
op, baseDefiningOp->getOperand(0), adaptor, rewriter)))
return mlir::failure();

rewriter.eraseOp(baseDefiningOp);
return mlir::success();
@@ -1669,8 +1731,13 @@
mlir::LogicalResult
matchAndRewrite(cir::PtrStrideOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
if (isCastArrayToPtrConsumer(op) && isLoadStoreOrCastArrayToPtrProduer(op))
return rewriteArrayDecay(op, adaptor, rewriter);
if (isLoadStoreOrCastArrayToPtrProduer(op)) {
if (isCastArrayToPtrConsumer(op))
return rewriteArrayDecay(op, adaptor, rewriter);
else
return rewritePtrStrideToReinterpret(op, adaptor.getBase(), adaptor,
rewriter);
}

auto base = adaptor.getBase();
auto stride = adaptor.getStride();
@@ -1801,7 +1868,7 @@ static mlir::TypeConverter prepareTypeConverter() {
return nullptr;
if (isa<cir::ArrayType>(type.getPointee()))
return ty;
return mlir::MemRefType::get({}, ty);
return mlir::MemRefType::get({mlir::ShapedType::kDynamic}, ty);
});
converter.addConversion(
[&](mlir::IntegerType type) -> mlir::Type { return type; });
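Note: the assumed type mapping under the updated converter (illustrative):

    !cir.ptr<!s32i>                  -> memref<?xi32>   // scalars: rank-1 dynamic, was rank-0
    !cir.ptr<!cir.array<!s32i x 3>>  -> memref<3xi32>   // arrays: unchanged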
6 changes: 3 additions & 3 deletions clang/test/CIR/Lowering/ThroughMLIR/array.c
@@ -9,7 +9,7 @@ int test_array1() {
// CIR: %{{.*}} = cir.get_element %[[ARRAY]][{{.*}}] : (!cir.ptr<!cir.array<!s32i x 3>>, !s32i) -> !cir.ptr<!s32i>

// MLIR-LABEL: func @test_array1
// MLIR: %{{.*}} = memref.alloca() {alignment = 4 : i64} : memref<i32>
// MLIR: %{{.*}} = memref.alloca() {alignment = 4 : i64} : memref<1xi32>
// MLIR: %[[ARRAY:.*]] = memref.alloca() {alignment = 4 : i64} : memref<3xi32>
// MLIR: %{{.*}} = memref.load %[[ARRAY]][%{{.*}}] : memref<3xi32>
int a[3];
@@ -23,7 +23,7 @@ int test_array2() {
// CIR: %{{.*}} = cir.get_element %{{.*}}[%{{.*}}] : (!cir.ptr<!cir.array<!s32i x 4>>, !s32i) -> !cir.ptr<!s32i>

// MLIR-LABEL: func @test_array2
// MLIR: %{{.*}} = memref.alloca() {alignment = 4 : i64} : memref<i32>
// MLIR: %{{.*}} = memref.alloca() {alignment = 4 : i64} : memref<1xi32>
// MLIR: %[[ARRAY:.*]] = memref.alloca() {alignment = 16 : i64} : memref<3x4xi32>
// MLIR: %{{.*}} = memref.load %[[ARRAY]][%{{.*}}, %{{.*}}] : memref<3x4xi32>
int a[3][4];
@@ -42,7 +42,7 @@ int test_array3() {
// CIR: %{{.*}} = cir.load align(4) %[[ELEM3]] : !cir.ptr<!s32i>, !s32i

// MLIR-LABEL: func @test_array3
// MLIR: %{{.*}} = memref.alloca() {alignment = 4 : i64} : memref<i32>
// MLIR: %{{.*}} = memref.alloca() {alignment = 4 : i64} : memref<1xi32>
// MLIR: %[[ARRAY:.*]] = memref.alloca() {alignment = 4 : i64} : memref<3xi32>
// MLIR: %[[IDX1:.*]] = arith.index_cast %{{.*}} : i32 to index
// MLIR: %{{.*}} = memref.load %[[ARRAY]][%[[IDX1]]] : memref<3xi32>
9 changes: 5 additions & 4 deletions clang/test/CIR/Lowering/ThroughMLIR/bool.cir
@@ -1,5 +1,5 @@
// RUN: cir-opt %s -cir-to-mlir -o - | FileCheck %s -check-prefix=MLIR
// RUN: cir-opt %s -cir-to-mlir -cir-mlir-to-llvm -o - | mlir-translate -mlir-to-llvmir | FileCheck %s -check-prefix=LLVM
// RUN: cir-opt %s -cir-to-mlir -cir-mlir-to-llvm -canonicalize -o - | mlir-translate -mlir-to-llvmir | FileCheck %s -check-prefix=LLVM

#false = #cir.bool<false> : !cir.bool
#true = #cir.bool<true> : !cir.bool
@@ -13,12 +13,13 @@ module {
}

// MLIR: func @foo() {
// MLIR: [[Value:%[a-z0-9]+]] = memref.alloca() {alignment = 1 : i64} : memref<i8>
// MLIR: %[[VALUE:[a-z0-9]+]] = memref.alloca() {alignment = 1 : i64} : memref<1xi8>
// MLIR: %[[CONST:.*]] = arith.constant true
// MLIR: %[[BOOL_TO_MEM:.*]] = arith.extui %[[CONST]] : i1 to i8
// MLIR-NEXT: memref.store %[[BOOL_TO_MEM]], [[Value]][] : memref<i8>
// MLIR: %[[CONST0:[a-z0-9]+]] = arith.constant 0 : index
// MLIR-NEXT: memref.store %[[BOOL_TO_MEM]], %[[VALUE]][%[[CONST0]]] : memref<1xi8>
// return

// LLVM: = alloca i8, i64
// LLVM: store i8 1, ptr %5
// LLVM: store i8 1, ptr %1
// LLVM: ret