Skip to content

Commit f1d5c63

Browse files
committed
[lumen] allow LLVMDialect to inherit LLVMContext
By allowing the LLVMDialect to be created with a reference to a non-owned LLVMContext, we can create one LLVMContext per thread and use it for everything. This prevents the issue where MLIR-generated modules were created with their own context and cannot be linked together or worked with using the thread-global LLVMContext we create for that purpose in Lumen.
1 parent 7fabecc commit f1d5c63

File tree

2 files changed

+57
-4
lines changed

2 files changed

+57
-4
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def LLVM_Dialect : Dialect {
2121
let cppNamespace = "LLVM";
2222
let hasRegionArgAttrVerify = 1;
2323
let extraClassDeclaration = [{
24+
LLVMDialect(mlir::MLIRContext *, llvm::LLVMContext *);
2425
~LLVMDialect();
2526
llvm::LLVMContext &getLLVMContext();
2627
llvm::Module &getLLVMModule();

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1633,12 +1633,32 @@ static LogicalResult verify(FenceOp &op) {
16331633
namespace mlir {
16341634
namespace LLVM {
16351635
namespace detail {
1636+
struct LLVMContextHandle {
1637+
bool owned;
1638+
llvm::LLVMContext *context;
1639+
1640+
LLVMContextHandle() :
1641+
owned(true), context(new llvm::LLVMContext()) {}
1642+
LLVMContextHandle(llvm::LLVMContext *ctx) :
1643+
owned(false), context(ctx) {}
1644+
1645+
~LLVMContextHandle() {
1646+
if (owned)
1647+
delete context;
1648+
}
1649+
};
1650+
16361651
struct LLVMDialectImpl {
1637-
LLVMDialectImpl() : module("LLVMDialectModule", llvmContext) {}
1652+
LLVMDialectImpl()
1653+
: module("LLVMDialectModule", *llvmContext.context) {}
1654+
LLVMDialectImpl(llvm::LLVMContext *ctx)
1655+
: llvmContext(ctx), module("LLVMDialectModule", *ctx) {}
16381656

1639-
llvm::LLVMContext llvmContext;
1657+
LLVMContextHandle llvmContext;
16401658
llvm::Module module;
16411659

1660+
bool ownsContext;
1661+
16421662
/// A set of LLVMTypes that are cached on construction to avoid any lookups or
16431663
/// locking.
16441664
LLVMType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty;
@@ -1653,6 +1673,38 @@ struct LLVMDialectImpl {
16531673
} // end namespace LLVM
16541674
} // end namespace mlir
16551675

1676+
LLVMDialect::LLVMDialect(MLIRContext *context, llvm::LLVMContext *llvmCtx)
1677+
: Dialect(getDialectNamespace(), context),
1678+
impl(new detail::LLVMDialectImpl(llvmCtx)) {
1679+
addTypes<LLVMType>();
1680+
addOperations<
1681+
#define GET_OP_LIST
1682+
#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
1683+
>();
1684+
1685+
// Support unknown operations because not all LLVM operations are registered.
1686+
allowUnknownOperations();
1687+
1688+
// Cache some of the common LLVM types to avoid the need for lookups/locking.
1689+
auto &llvmContext = impl->module.getContext();
1690+
/// Integer Types.
1691+
impl->int1Ty = LLVMType::get(context, llvm::Type::getInt1Ty(llvmContext));
1692+
impl->int8Ty = LLVMType::get(context, llvm::Type::getInt8Ty(llvmContext));
1693+
impl->int16Ty = LLVMType::get(context, llvm::Type::getInt16Ty(llvmContext));
1694+
impl->int32Ty = LLVMType::get(context, llvm::Type::getInt32Ty(llvmContext));
1695+
impl->int64Ty = LLVMType::get(context, llvm::Type::getInt64Ty(llvmContext));
1696+
impl->int128Ty = LLVMType::get(context, llvm::Type::getInt128Ty(llvmContext));
1697+
/// Float Types.
1698+
impl->doubleTy = LLVMType::get(context, llvm::Type::getDoubleTy(llvmContext));
1699+
impl->floatTy = LLVMType::get(context, llvm::Type::getFloatTy(llvmContext));
1700+
impl->halfTy = LLVMType::get(context, llvm::Type::getHalfTy(llvmContext));
1701+
impl->fp128Ty = LLVMType::get(context, llvm::Type::getFP128Ty(llvmContext));
1702+
impl->x86_fp80Ty =
1703+
LLVMType::get(context, llvm::Type::getX86_FP80Ty(llvmContext));
1704+
/// Other Types.
1705+
impl->voidTy = LLVMType::get(context, llvm::Type::getVoidTy(llvmContext));
1706+
}
1707+
16561708
LLVMDialect::LLVMDialect(MLIRContext *context)
16571709
: Dialect(getDialectNamespace(), context),
16581710
impl(new detail::LLVMDialectImpl()) {
@@ -1666,7 +1718,7 @@ LLVMDialect::LLVMDialect(MLIRContext *context)
16661718
allowUnknownOperations();
16671719

16681720
// Cache some of the common LLVM types to avoid the need for lookups/locking.
1669-
auto &llvmContext = impl->llvmContext;
1721+
auto &llvmContext = impl->module.getContext();
16701722
/// Integer Types.
16711723
impl->int1Ty = LLVMType::get(context, llvm::Type::getInt1Ty(llvmContext));
16721724
impl->int8Ty = LLVMType::get(context, llvm::Type::getInt8Ty(llvmContext));
@@ -1690,7 +1742,7 @@ LLVMDialect::~LLVMDialect() {}
16901742
#define GET_OP_CLASSES
16911743
#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
16921744

1693-
llvm::LLVMContext &LLVMDialect::getLLVMContext() { return impl->llvmContext; }
1745+
llvm::LLVMContext &LLVMDialect::getLLVMContext() { return impl->module.getContext(); }
16941746
llvm::Module &LLVMDialect::getLLVMModule() { return impl->module; }
16951747

16961748
/// Parse a type registered to this dialect.

0 commit comments

Comments
 (0)