Skip to content

Commit 8416a68

Browse files
wsmosespengmai
andauthored
LLVM integrate for mlir changes (#2118)
* LLVM integrate for mlir changes * fmt * fixup * fmt * minor * fixup * fmt * fix opq cast * fmt * fix * fix * Fix programPoint before/after semantics --------- Co-authored-by: Jacob Peng <[email protected]>
1 parent 192d923 commit 8416a68

19 files changed

+166
-161
lines changed

.github/workflows/enzyme-mlir.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ jobs:
3636
- uses: actions/checkout@v4
3737
with:
3838
repository: 'llvm/llvm-project'
39-
ref: '54a49658990e827173f3a3198331df7cbe50b0c0'
39+
ref: '36a405519bf54c7b9bc1247286c59beca0d8eff8'
4040
path: 'llvm-project'
4141

4242
- name: Get MLIR commit hash

enzyme/Enzyme/AdjointGenerator.h

+9-10
Original file line numberDiff line numberDiff line change
@@ -1137,8 +1137,8 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
11371137
} else {
11381138
maskL = lookup(mask, Builder2);
11391139
Type *tys[] = {valType, orig_ptr->getType()};
1140-
auto F = Intrinsic::getDeclaration(gutils->oldFunc->getParent(),
1141-
Intrinsic::masked_load, tys);
1140+
auto F = getIntrinsicDeclaration(gutils->oldFunc->getParent(),
1141+
Intrinsic::masked_load, tys);
11421142
Value *alignv =
11431143
ConstantInt::get(Type::getInt32Ty(mask->getContext()),
11441144
align ? align->value() : 0);
@@ -3789,10 +3789,9 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
37893789
case Intrinsic::nvvm_barrier0_or: {
37903790
SmallVector<Value *, 1> args = {};
37913791
auto cal = cast<CallInst>(Builder2.CreateCall(
3792-
Intrinsic::getDeclaration(M, Intrinsic::nvvm_barrier0), args));
3793-
cal->setCallingConv(
3794-
Intrinsic::getDeclaration(M, Intrinsic::nvvm_barrier0)
3795-
->getCallingConv());
3792+
getIntrinsicDeclaration(M, Intrinsic::nvvm_barrier0), args));
3793+
cal->setCallingConv(getIntrinsicDeclaration(M, Intrinsic::nvvm_barrier0)
3794+
->getCallingConv());
37963795
cal->setDebugLoc(gutils->getNewFromOriginal(I.getDebugLoc()));
37973796
return false;
37983797
}
@@ -3804,8 +3803,8 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
38043803
case Intrinsic::nvvm_membar_sys: {
38053804
SmallVector<Value *, 1> args = {};
38063805
auto cal = cast<CallInst>(
3807-
Builder2.CreateCall(Intrinsic::getDeclaration(M, ID), args));
3808-
cal->setCallingConv(Intrinsic::getDeclaration(M, ID)->getCallingConv());
3806+
Builder2.CreateCall(getIntrinsicDeclaration(M, ID), args));
3807+
cal->setCallingConv(getIntrinsicDeclaration(M, ID)->getCallingConv());
38093808
cal->setDebugLoc(gutils->getNewFromOriginal(I.getDebugLoc()));
38103809
return false;
38113810
}
@@ -3818,9 +3817,9 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
38183817
lookup(gutils->getNewFromOriginal(orig_ops[1]), Builder2)};
38193818
Type *tys[] = {args[1]->getType()};
38203819
auto cal = Builder2.CreateCall(
3821-
Intrinsic::getDeclaration(M, Intrinsic::lifetime_end, tys), args);
3820+
getIntrinsicDeclaration(M, Intrinsic::lifetime_end, tys), args);
38223821
cal->setCallingConv(
3823-
Intrinsic::getDeclaration(M, Intrinsic::lifetime_end, tys)
3822+
getIntrinsicDeclaration(M, Intrinsic::lifetime_end, tys)
38243823
->getCallingConv());
38253824
return false;
38263825
}

enzyme/Enzyme/CallDerivatives.cpp

+21-34
Original file line numberDiff line numberDiff line change
@@ -277,8 +277,8 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
277277
Type *tys[] = {dbuf->getType(), len_arg->getType()};
278278

279279
auto memset = cast<CallInst>(Builder2.CreateCall(
280-
Intrinsic::getDeclaration(called->getParent(), Intrinsic::memset,
281-
tys),
280+
getIntrinsicDeclaration(called->getParent(), Intrinsic::memset,
281+
tys),
282282
nargs, BufferDefs));
283283
memset->addParamAttr(0, Attribute::NonNull);
284284
} else if (funcName == "MPI_Isend" || funcName == "PMPI_Isend") {
@@ -887,8 +887,8 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
887887
ValueType::None, ValueType::None, ValueType::None},
888888
Builder2, /*lookup*/ true);
889889
auto memset = cast<CallInst>(Builder2.CreateCall(
890-
Intrinsic::getDeclaration(gutils->newFunc->getParent(),
891-
Intrinsic::memset, tys),
890+
getIntrinsicDeclaration(gutils->newFunc->getParent(),
891+
Intrinsic::memset, tys),
892892
nargs));
893893
memset->addParamAttr(0, Attribute::NonNull);
894894
}
@@ -1057,8 +1057,8 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
10571057

10581058
Type *tys[] = {shadow->getType(), buf->getType(), len_arg->getType()};
10591059

1060-
auto memcpyF = Intrinsic::getDeclaration(gutils->newFunc->getParent(),
1061-
Intrinsic::memcpy, tys);
1060+
auto memcpyF = getIntrinsicDeclaration(gutils->newFunc->getParent(),
1061+
Intrinsic::memcpy, tys);
10621062

10631063
auto mem =
10641064
cast<CallInst>(Builder2.CreateCall(memcpyF, nargs, BufferDefs));
@@ -1080,8 +1080,8 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
10801080
Value *args[] = {shadow, val_arg, len_arg, volatile_arg};
10811081
Type *tys[] = {args[0]->getType(), args[2]->getType()};
10821082
auto memset = cast<CallInst>(Builder2.CreateCall(
1083-
Intrinsic::getDeclaration(gutils->newFunc->getParent(),
1084-
Intrinsic::memset, tys),
1083+
getIntrinsicDeclaration(gutils->newFunc->getParent(),
1084+
Intrinsic::memset, tys),
10851085
args, BufferDefs));
10861086
memset->addParamAttr(0, Attribute::NonNull);
10871087
Builder2.CreateBr(mergeBlock);
@@ -1262,8 +1262,8 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
12621262
Type *tys[] = {nargs[0]->getType(), nargs[1]->getType(),
12631263
len_arg->getType()};
12641264

1265-
auto memcpyF = Intrinsic::getDeclaration(gutils->newFunc->getParent(),
1266-
Intrinsic::memcpy, tys);
1265+
auto memcpyF = getIntrinsicDeclaration(gutils->newFunc->getParent(),
1266+
Intrinsic::memcpy, tys);
12671267

12681268
auto mem =
12691269
cast<CallInst>(Builder2.CreateCall(memcpyF, nargs, BufferDefs));
@@ -1314,8 +1314,8 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
13141314
Value *args[] = {shadow_recvbuf, val_arg, len_arg, volatile_arg};
13151315
Type *tys[] = {args[0]->getType(), args[2]->getType()};
13161316
auto memset = cast<CallInst>(Builder2.CreateCall(
1317-
Intrinsic::getDeclaration(gutils->newFunc->getParent(),
1318-
Intrinsic::memset, tys),
1317+
getIntrinsicDeclaration(gutils->newFunc->getParent(),
1318+
Intrinsic::memset, tys),
13191319
args, BufferDefs));
13201320
memset->addParamAttr(0, Attribute::NonNull);
13211321

@@ -1496,8 +1496,8 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
14961496
Value *args[] = {shadow_recvbuf, val_arg, len_arg, volatile_arg};
14971497
Type *tys[] = {args[0]->getType(), args[2]->getType()};
14981498
auto memset = cast<CallInst>(Builder2.CreateCall(
1499-
Intrinsic::getDeclaration(gutils->newFunc->getParent(),
1500-
Intrinsic::memset, tys),
1499+
getIntrinsicDeclaration(gutils->newFunc->getParent(),
1500+
Intrinsic::memset, tys),
15011501
args, BufferDefs));
15021502
memset->addParamAttr(0, Attribute::NonNull);
15031503

@@ -1696,8 +1696,8 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
16961696
Value *args[] = {shadow_recvbuf, val_arg, recvlen_arg, volatile_arg};
16971697
Type *tys[] = {args[0]->getType(), args[2]->getType()};
16981698
auto memset = cast<CallInst>(Builder2.CreateCall(
1699-
Intrinsic::getDeclaration(gutils->newFunc->getParent(),
1700-
Intrinsic::memset, tys),
1699+
getIntrinsicDeclaration(gutils->newFunc->getParent(),
1700+
Intrinsic::memset, tys),
17011701
args, BufferDefs));
17021702
memset->addParamAttr(0, Attribute::NonNull);
17031703

@@ -1917,8 +1917,8 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
19171917
Value *args[] = {shadow_recvbuf, val_arg, recvlen_arg, volatile_arg};
19181918
Type *tys[] = {args[0]->getType(), args[2]->getType()};
19191919
auto memset = cast<CallInst>(Builder2.CreateCall(
1920-
Intrinsic::getDeclaration(gutils->newFunc->getParent(),
1921-
Intrinsic::memset, tys),
1920+
getIntrinsicDeclaration(gutils->newFunc->getParent(),
1921+
Intrinsic::memset, tys),
19221922
args, BufferDefs));
19231923
memset->addParamAttr(0, Attribute::NonNull);
19241924
}
@@ -2129,8 +2129,8 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
21292129
Value *args[] = {shadow_recvbuf, val_arg, recvlen_arg, volatile_arg};
21302130
Type *tys[] = {args[0]->getType(), args[2]->getType()};
21312131
auto memset = cast<CallInst>(Builder2.CreateCall(
2132-
Intrinsic::getDeclaration(gutils->newFunc->getParent(),
2133-
Intrinsic::memset, tys),
2132+
getIntrinsicDeclaration(gutils->newFunc->getParent(),
2133+
Intrinsic::memset, tys),
21342134
args, BufferDefs));
21352135
memset->addParamAttr(0, Attribute::NonNull);
21362136
}
@@ -3822,20 +3822,7 @@ bool AdjointGenerator::handleKnownCallDerivatives(
38223822

38233823
if (funcName == "posix_memalign" ||
38243824
funcName == "cudaMallocHost") {
3825-
auto volatile_arg = ConstantInt::getFalse(call.getContext());
3826-
3827-
Value *nargs[] = {dst_arg, val_arg, len_arg, volatile_arg};
3828-
3829-
Type *tys[] = {dst_arg->getType(), len_arg->getType()};
3830-
3831-
auto memset = cast<CallInst>(BuilderZ.CreateCall(
3832-
Intrinsic::getDeclaration(gutils->newFunc->getParent(),
3833-
Intrinsic::memset, tys),
3834-
nargs));
3835-
// memset->addParamAttr(0,
3836-
// Attribute::getWithAlignment(Context,
3837-
// inst->getAlignment()));
3838-
memset->addParamAttr(0, Attribute::NonNull);
3825+
BuilderZ.CreateMemSet(dst_arg, val_arg, len_arg, MaybeAlign());
38393826
} else if (funcName == "cudaMalloc") {
38403827
Type *tys[] = {PT, val_arg->getType(), len_arg->getType()};
38413828
auto F = M->getOrInsertFunction(

enzyme/Enzyme/DiffeGradientUtils.cpp

+8-8
Original file line numberDiff line numberDiff line change
@@ -589,8 +589,8 @@ DiffeGradientUtils::addToDiffe(Value *val, Value *dif, IRBuilder<> &BuilderM,
589589
// store->setAlignment(align);
590590
} else {
591591
Type *tys[] = {res->getType(), ptr->getType()};
592-
auto F = Intrinsic::getDeclaration(oldFunc->getParent(),
593-
Intrinsic::masked_store, tys);
592+
auto F = getIntrinsicDeclaration(oldFunc->getParent(),
593+
Intrinsic::masked_store, tys);
594594
auto align = cast<AllocaInst>(ptr)->getAlign().value();
595595
assert(align);
596596
Value *alignv =
@@ -608,8 +608,8 @@ DiffeGradientUtils::addToDiffe(Value *val, Value *dif, IRBuilder<> &BuilderM,
608608
// store->setAlignment(align);
609609
} else {
610610
Type *tys[] = {res->getType(), ptr->getType()};
611-
auto F = Intrinsic::getDeclaration(oldFunc->getParent(),
612-
Intrinsic::masked_store, tys);
611+
auto F = getIntrinsicDeclaration(oldFunc->getParent(),
612+
Intrinsic::masked_store, tys);
613613
auto align = cast<AllocaInst>(ptr)->getAlign().value();
614614
assert(align);
615615
Value *alignv =
@@ -1076,10 +1076,10 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(Instruction *orig,
10761076
applyChainRule(BuilderM, rule, ptr, dif);
10771077
} else {
10781078
Type *tys[] = {addingType, origptr->getType()};
1079-
auto LF = Intrinsic::getDeclaration(oldFunc->getParent(),
1080-
Intrinsic::masked_load, tys);
1081-
auto SF = Intrinsic::getDeclaration(oldFunc->getParent(),
1082-
Intrinsic::masked_store, tys);
1079+
auto LF = getIntrinsicDeclaration(oldFunc->getParent(),
1080+
Intrinsic::masked_load, tys);
1081+
auto SF = getIntrinsicDeclaration(oldFunc->getParent(),
1082+
Intrinsic::masked_store, tys);
10831083
unsigned aligni = align ? align->value() : 0;
10841084

10851085
if (aligni != 0)

enzyme/Enzyme/Enzyme.cpp

+13-14
Original file line numberDiff line numberDiff line change
@@ -2853,12 +2853,17 @@ class EnzymeBase {
28532853
return Logic.PPC.FAM.getResult<TargetLibraryAnalysis>(F);
28542854
};
28552855

2856+
TargetTransformInfo TTI(F->getParent()->getDataLayout());
28562857
auto GetInlineCost = [&](CallBase &CB) {
2857-
TargetTransformInfo TTI(F->getParent()->getDataLayout());
28582858
auto cst = llvm::getInlineCost(CB, Params, TTI, getAC, GetTLI);
28592859
return cst;
28602860
};
2861-
if (llvm::shouldInline(*cur, GetInlineCost, ORE)) {
2861+
#if LLVM_VERSION_MAJOR >= 20
2862+
if (llvm::shouldInline(*cur, TTI, GetInlineCost, ORE))
2863+
#else
2864+
if (llvm::shouldInline(*cur, GetInlineCost, ORE))
2865+
#endif
2866+
{
28622867
InlineFunctionInfo IFI;
28632868
InlineResult IR = InlineFunction(*cur, IFI);
28642869
if (IR.isSuccess()) {
@@ -2964,19 +2969,13 @@ class EnzymeBase {
29642969
if (F && F->getName() == "f90_mzero8") {
29652970
IRBuilder<> B(CI);
29662971

2967-
SmallVector<Value *, 4> args;
2968-
args.push_back(CI->getArgOperand(0));
2969-
args.push_back(
2970-
ConstantInt::get(Type::getInt8Ty(M.getContext()), 0));
2971-
args.push_back(B.CreateMul(
2972+
Value *args[3];
2973+
args[0] = CI->getArgOperand(0);
2974+
args[1] = ConstantInt::get(Type::getInt8Ty(M.getContext()), 0);
2975+
args[2] = B.CreateMul(
29722976
CI->getArgOperand(1),
2973-
ConstantInt::get(CI->getArgOperand(1)->getType(), 8)));
2974-
args.push_back(ConstantInt::getFalse(M.getContext()));
2975-
2976-
Type *tys[] = {args[0]->getType(), args[2]->getType()};
2977-
auto memsetIntr =
2978-
Intrinsic::getDeclaration(&M, Intrinsic::memset, tys);
2979-
B.CreateCall(memsetIntr, args);
2977+
ConstantInt::get(CI->getArgOperand(1)->getType(), 8));
2978+
B.CreateMemSet(args[0], args[1], args[2], MaybeAlign());
29802979

29812980
CI->eraseFromParent();
29822981
}

enzyme/Enzyme/EnzymeLogic.cpp

+7-7
Original file line numberDiff line numberDiff line change
@@ -4349,18 +4349,18 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
43494349

43504350
Value *tx, *ty, *tz;
43514351
if (Arch == Triple::nvptx || Arch == Triple::nvptx64) {
4352-
tx = ebuilder.CreateCall(Intrinsic::getDeclaration(
4352+
tx = ebuilder.CreateCall(getIntrinsicDeclaration(
43534353
gutils->newFunc->getParent(), Intrinsic::nvvm_read_ptx_sreg_tid_x));
4354-
ty = ebuilder.CreateCall(Intrinsic::getDeclaration(
4354+
ty = ebuilder.CreateCall(getIntrinsicDeclaration(
43554355
gutils->newFunc->getParent(), Intrinsic::nvvm_read_ptx_sreg_tid_y));
4356-
tz = ebuilder.CreateCall(Intrinsic::getDeclaration(
4356+
tz = ebuilder.CreateCall(getIntrinsicDeclaration(
43574357
gutils->newFunc->getParent(), Intrinsic::nvvm_read_ptx_sreg_tid_z));
43584358
} else if (Arch == Triple::amdgcn) {
4359-
tx = ebuilder.CreateCall(Intrinsic::getDeclaration(
4359+
tx = ebuilder.CreateCall(getIntrinsicDeclaration(
43604360
gutils->newFunc->getParent(), Intrinsic::amdgcn_workitem_id_x));
4361-
ty = ebuilder.CreateCall(Intrinsic::getDeclaration(
4361+
ty = ebuilder.CreateCall(getIntrinsicDeclaration(
43624362
gutils->newFunc->getParent(), Intrinsic::amdgcn_workitem_id_y));
4363-
tz = ebuilder.CreateCall(Intrinsic::getDeclaration(
4363+
tz = ebuilder.CreateCall(getIntrinsicDeclaration(
43644364
gutils->newFunc->getParent(), Intrinsic::amdgcn_workitem_id_z));
43654365
} else {
43664366
llvm_unreachable("unknown gpu architecture");
@@ -4377,7 +4377,7 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
43774377
? (llvm::Intrinsic::ID)Intrinsic::amdgcn_s_barrier
43784378
: (llvm::Intrinsic::ID)Intrinsic::nvvm_barrier0;
43794379
instbuilder.CreateCall(
4380-
Intrinsic::getDeclaration(gutils->newFunc->getParent(), BarrierInst),
4380+
getIntrinsicDeclaration(gutils->newFunc->getParent(), BarrierInst),
43814381
{});
43824382
OldEntryInsts->moveAfter(entry);
43834383
sharedBlock->moveAfter(entry);

enzyme/Enzyme/FunctionUtils.cpp

+11-12
Original file line numberDiff line numberDiff line change
@@ -313,19 +313,20 @@ void RecursivelyReplaceAddressSpace(Value *AI, Value *rep, bool legal) {
313313
continue;
314314
}
315315
IRBuilder<> B(CI);
316-
auto nCI = cast<CastInst>(B.CreateCast(
316+
auto nCI0 = B.CreateCast(
317317
CI->getOpcode(), rep,
318318
#if LLVM_VERSION_MAJOR < 17
319319
PointerType::get(CI->getType()->getPointerElementType(),
320320
cast<PointerType>(rep->getType())->getAddressSpace())
321321
#else
322322
rep->getType()
323323
#endif
324-
));
325-
nCI->takeName(CI);
324+
);
325+
if (auto nCI = dyn_cast<CastInst>(nCI0))
326+
nCI->takeName(CI);
326327
for (auto U : CI->users()) {
327328
Todo.push_back(
328-
std::make_tuple((Value *)nCI, (Value *)CI, cast<Instruction>(U)));
329+
std::make_tuple((Value *)nCI0, (Value *)CI, cast<Instruction>(U)));
329330
}
330331
toErase.push_back(CI);
331332
continue;
@@ -387,12 +388,10 @@ void RecursivelyReplaceAddressSpace(Value *AI, Value *rep, bool legal) {
387388

388389
Value *nargs[] = {rep, MS->getArgOperand(1), MS->getArgOperand(2),
389390
MS->getArgOperand(3)};
390-
391391
Type *tys[] = {nargs[0]->getType(), nargs[2]->getType()};
392-
393392
auto nMS = cast<CallInst>(B.CreateCall(
394-
Intrinsic::getDeclaration(MS->getParent()->getParent()->getParent(),
395-
Intrinsic::memset, tys),
393+
getIntrinsicDeclaration(MS->getParent()->getParent()->getParent(),
394+
Intrinsic::memset, tys),
396395
nargs));
397396
nMS->copyMetadata(*MS);
398397
nMS->setAttributes(MS->getAttributes());
@@ -415,8 +414,8 @@ void RecursivelyReplaceAddressSpace(Value *AI, Value *rep, bool legal) {
415414
nargs[2]->getType()};
416415

417416
auto nMTI = cast<CallInst>(B.CreateCall(
418-
Intrinsic::getDeclaration(MTI->getParent()->getParent()->getParent(),
419-
MTI->getIntrinsicID(), tys),
417+
getIntrinsicDeclaration(MTI->getParent()->getParent()->getParent(),
418+
MTI->getIntrinsicID(), tys),
420419
nargs));
421420
nMTI->copyMetadata(*MTI);
422421
nMTI->setAttributes(MTI->getAttributes());
@@ -914,7 +913,7 @@ void PreProcessCache::ReplaceReallocs(Function *NewF, bool mem2reg) {
914913
Type *tys[] = {next->getType(), p->getType(), old->getType()};
915914

916915
auto memcpyF =
917-
Intrinsic::getDeclaration(NewF->getParent(), Intrinsic::memcpy, tys);
916+
getIntrinsicDeclaration(NewF->getParent(), Intrinsic::memcpy, tys);
918917

919918
auto mem = cast<CallInst>(B.CreateCall(memcpyF, nargs));
920919
mem->setCallingConv(memcpyF->getCallingConv());
@@ -1733,7 +1732,7 @@ Function *PreProcessCache::preprocessForClone(Function *F,
17331732
Type *tys[] = {args[0]->getType(), args[1]->getType(),
17341733
args[2]->getType()};
17351734
auto intr =
1736-
Intrinsic::getDeclaration(g.getParent(), Intrinsic::memcpy, tys);
1735+
getIntrinsicDeclaration(g.getParent(), Intrinsic::memcpy, tys);
17371736
{
17381737

17391738
auto cal = bb.CreateCall(intr, args);

0 commit comments

Comments
 (0)