Skip to content

Commit

Permalink
Include FMA in manual type legalization.
Browse files Browse the repository at this point in the history
It too is promoted incorrectly by current LLVM versions.
  • Loading branch information
hvdijk committed Feb 1, 2024
1 parent 57c8972 commit 84d055b
Showing 1 changed file with 84 additions and 41 deletions.
125 changes: 84 additions & 41 deletions modules/compiler/utils/source/manual_type_legalization_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/InstrTypes.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/IntrinsicInst.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/PassManager.h>
#include <llvm/IR/Type.h>
Expand All @@ -31,25 +32,39 @@ using namespace llvm;

PreservedAnalyses compiler::utils::ManualTypeLegalizationPass::run(
Function &F, FunctionAnalysisManager &FAM) {
auto &TTI = FAM.getResult<TargetIRAnalysis>(F);

auto *HalfT = Type::getHalfTy(F.getContext());
auto *FloatT = Type::getFloatTy(F.getContext());

// Targets where half is a legal type do not need this pass. Targets where
// half is promoted using "soft promotion" rules also do not need this pass.
// We cannot reliably determine which targets these are, but that is okay: on
// targets where this pass is not needed it does no harm; it merely wastes
// time.
auto *DoubleT = Type::getDoubleTy(F.getContext());

// Targets where half is a legal type, and targets where half is promoted
// using "soft promotion" rules, are assumed to implement basic operators
// correctly. We cannot reliably determine which targets use "soft promotion"
// rules so we hardcode the list here.
//
// FMA is promoted incorrectly on all targets without hardware support, even
// when using "soft promotion" rules; only targets that have native support
// implement it correctly at the moment.
//
// Both for operators and FMA, whether the target implements the operation
// correctly may depend on the target feature string. We ignore that here for
// simplicity.
const llvm::Triple TT(F.getParent()->getTargetTriple());
if (TTI.isTypeLegal(HalfT) || TT.isX86() || TT.isRISCV()) {

auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
const bool HaveCorrectHalfOps =
TTI.isTypeLegal(HalfT) || TT.isX86() || TT.isRISCV();
const bool HaveCorrectHalfFMA = TT.isRISCV();
if (HaveCorrectHalfOps && HaveCorrectHalfFMA) {
return PreservedAnalyses::all();
}

DenseMap<Value *, Value *> FPExtVals;
IRBuilder<> B(F.getContext());

auto CreateFPExt = [&](Value *V, Type *ExtTy) {
auto CreateFPExt = [&](Value *V, Type *Ty, Type *ExtTy) {
(void)Ty;
assert(V->getType() == Ty &&
"Expected matching types for floating point operation");
auto *&FPExt = FPExtVals[V];
if (!FPExt) {
if (auto *I = dyn_cast<Instruction>(V)) {
Expand Down Expand Up @@ -78,43 +93,71 @@ PreservedAnalyses compiler::utils::ManualTypeLegalizationPass::run(

for (auto &BB : F) {
for (auto &I : make_early_inc_range(BB)) {
auto *BO = dyn_cast<BinaryOperator>(&I);
if (!BO) continue;

auto *T = BO->getType();
auto *T = I.getType();
auto *VecT = dyn_cast<VectorType>(T);
auto *ElT = VecT ? VecT->getElementType() : T;

if (ElT != HalfT) continue;

auto *LHS = BO->getOperand(0);
auto *RHS = BO->getOperand(1);
assert(LHS->getType() == T &&
"Expected matching types for floating point operation");
assert(RHS->getType() == T &&
"Expected matching types for floating point operation");

auto *ExtElT = FloatT;
auto *ExtT =
VecT ? VectorType::get(ExtElT, VecT->getElementCount()) : ExtElT;

auto *LHSExt = CreateFPExt(LHS, ExtT);
auto *RHSExt = CreateFPExt(RHS, ExtT);

B.SetInsertPoint(BO);

B.setFastMathFlags(BO->getFastMathFlags());
auto *OpExt = B.CreateBinOp(BO->getOpcode(), LHSExt, RHSExt,
BO->getName() + ".fpext");
B.clearFastMathFlags();

auto *Trunc = B.CreateFPTrunc(OpExt, T);
Trunc->takeName(BO);

BO->replaceAllUsesWith(Trunc);
BO->eraseFromParent();
if (!HaveCorrectHalfOps) {
if (auto *BO = dyn_cast<BinaryOperator>(&I)) {
Type *const ExtElT = FloatT;
Type *const ExtT =
VecT ? VectorType::get(ExtElT, VecT->getElementCount()) : ExtElT;
Value *const PromotedOperands[] = {
CreateFPExt(BO->getOperand(0), T, ExtT),
CreateFPExt(BO->getOperand(1), T, ExtT),
};
B.SetInsertPoint(BO);
B.setFastMathFlags(BO->getFastMathFlags());
auto *const PromotedOperation =
B.CreateBinOp(BO->getOpcode(), PromotedOperands[0],
PromotedOperands[1], BO->getName() + ".fpext");
B.clearFastMathFlags();

auto *const Trunc = B.CreateFPTrunc(PromotedOperation, T);
Trunc->takeName(BO);

BO->replaceAllUsesWith(Trunc);
BO->eraseFromParent();

Changed = true;
continue;
}
}

Changed = true;
if (!HaveCorrectHalfFMA) {
if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
if (II->getIntrinsicID() == Intrinsic::fma) {
Type *const ExtElT = DoubleT;
Type *const ExtT =
VecT ? VectorType::get(ExtElT, VecT->getElementCount())
: ExtElT;
Value *const PromotedArguments[] = {
CreateFPExt(II->getArgOperand(0), T, ExtT),
CreateFPExt(II->getArgOperand(1), T, ExtT),
CreateFPExt(II->getArgOperand(2), T, ExtT),
};
B.SetInsertPoint(II);
// Because the arguments are half values promoted to double, their
// product is exactly representable in type double, so the result is
// the same whether the multiply and add are fused or kept as separate
// operations. We can therefore use FMulAdd rather than FMA.
auto *const PromotedOperation =
B.CreateIntrinsic(ExtT, Intrinsic::fmuladd, PromotedArguments,
II, II->getName() + ".fpext");

auto *const Trunc = B.CreateFPTrunc(PromotedOperation, T);
Trunc->takeName(II);

II->replaceAllUsesWith(Trunc);
II->eraseFromParent();

Changed = true;
continue;
}
}
}
}
}

Expand Down

0 comments on commit 84d055b

Please sign in to comment.