Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -425,9 +425,9 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
}
for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
const Function &F = *FI;
if (F.isDeclaration())
if (F.isDeclaration() || !isKernel(&F))
continue;
Register FReg = MAI->getFuncReg(F.getGlobalIdentifier());
Register FReg = MAI->getFuncReg(F.getName().str());
assert(FReg.isValid());
if (MDNode *Node = F.getMetadata("reqd_work_group_size"))
outputExecutionModeFromMDNode(FReg, Node, ExecutionMode::LocalSize);
Expand Down Expand Up @@ -462,7 +462,7 @@ void SPIRVAsmPrinter::outputAnnotations(const Module &M) {
Register Reg;
if (isa<Function>(AnnotatedVar)) {
auto *Func = cast<Function>(AnnotatedVar);
Reg = MAI->getFuncReg(Func->getGlobalIdentifier());
Reg = MAI->getFuncReg(Func->getName().str());
} else {
llvm_unreachable("Unsupported value in llvm.global.annotations");
}
Expand Down
12 changes: 7 additions & 5 deletions llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,17 +231,19 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
i++;
}

// Name the function.
if (F.hasName())
// Name the function but skip the kernels that are wrappers at this point.
if (F.hasName() && !isKernel(&F))
buildOpName(FuncVReg, F.getName(), MIRBuilder);

// Handle entry points and function linkage.
if (F.getCallingConv() == CallingConv::SPIR_KERNEL) {
if (isKernel(&F)) {
auto ExecModel = ExecutionModel::Kernel;
auto MIB = MIRBuilder.buildInstr(SPIRV::OpEntryPoint)
.addImm(ExecModel)
.addUse(FuncVReg);
addStringImm(F.getName(), MIB);
StringRef Name = F.getName();
Name.consume_front("__spirv_entry_");
addStringImm(Name, MIB);
} else if (F.getLinkage() == GlobalValue::LinkageTypes::ExternalLinkage ||
F.getLinkage() == GlobalValue::LinkOnceODRLinkage) {
auto LnkTy = F.isDeclaration() ? LinkageType::Import : LinkageType::Export;
Expand Down Expand Up @@ -273,7 +275,7 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
std::string FuncName = Info.Callee.getGlobal()->getGlobalIdentifier();
std::string DemangledName = isOclOrSpirvBuiltin(FuncName);
if (!DemangledName.empty()) {
if (!DemangledName.empty() && CF && CF->isDeclaration()) {
// TODO: check that it's OCL builtin, then apply OpenCL_std.
const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
if (ST->canUseExtInstSet(ExtInstSet::OpenCL_std)) {
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/SPIRV/SPIRVMCInstLower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ void SPIRVMCInstLower::lower(const MachineInstr *MI, MCInst &OutMI,
MI->print(errs());
llvm_unreachable("unknown operand type");
case MachineOperand::MO_GlobalAddress: {
Register FuncReg = MAI->getFuncReg(MO.getGlobal()->getGlobalIdentifier());
Register FuncReg = MAI->getFuncReg(MO.getGlobal()->getName().str());
assert(FuncReg.isValid() && "Cannot find function Id");
MCOp = MCOperand::createReg(FuncReg);
break;
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
Register GlobalReg = MAI.getRegisterAlias(MI.getMF(), Reg);
assert(GlobalReg.isValid());
// TODO: check that it does not conflict with existing entries.
MAI.FuncNameMap[F.getGlobalIdentifier()] = GlobalReg;
MAI.FuncNameMap[F.getName()] = GlobalReg;
}
}

Expand Down
73 changes: 73 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVPreTranslationLegalizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "SPIRV.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Transforms/Utils/Cloning.h"

Expand Down Expand Up @@ -114,13 +115,85 @@ Function *SPIRVPreTranslationLegalizer::processFunctionSignature(Function *F) {
return NewF;
}

static Function *getOrCreateFunction(Module *M, Type *RetTy,
ArrayRef<Type *> ArgTypes,
StringRef Name) {
FunctionType *FT = FunctionType::get(RetTy, ArgTypes, false);
Function *F = M->getFunction(Name);
if (!F || F->getFunctionType() != FT) {
auto NewF = Function::Create(FT, GlobalValue::ExternalLinkage, Name, M);
if (F)
NewF->setDSOLocal(F->isDSOLocal());
F = NewF;
F->setCallingConv(CallingConv::SPIR_FUNC);
}
return F;
}

// Add a wrapper around the kernel function to act as an entry point.
static void addKernelEntryPointWrapper(Module *M, Function *F) {
F->setCallingConv(CallingConv::SPIR_FUNC);
FunctionType *FType = F->getFunctionType();
std::string WrapName =
"__spirv_entry_" + static_cast<std::string>(F->getName());
Function *WrapFn =
getOrCreateFunction(M, F->getReturnType(), FType->params(), WrapName);

auto *CallBB = BasicBlock::Create(M->getContext(), "", WrapFn);
IRBuilder<> Builder(CallBB);

Function::arg_iterator DestI = WrapFn->arg_begin();
for (const Argument &I : F->args()) {
DestI->setName(I.getName());
DestI++;
}
SmallVector<Value *, 1> Args;
for (Argument &I : WrapFn->args())
Args.emplace_back(&I);

auto *CI = CallInst::Create(F, ArrayRef<Value *>(Args), "", CallBB);
CI->setCallingConv(F->getCallingConv());
CI->setAttributes(F->getAttributes());

// Copy over all the metadata excepting debug info.
// TODO: removed some metadata from F if it's necessery.
SmallVector<std::pair<unsigned, MDNode *>> MDs;
F->getAllMetadata(MDs);
WrapFn->setAttributes(F->getAttributes());
for (auto MD = MDs.begin(), End = MDs.end(); MD != End; ++MD) {
if (MD->first != LLVMContext::MD_dbg)
WrapFn->addMetadata(MD->first, *MD->second);
}
WrapFn->setCallingConv(CallingConv::SPIR_KERNEL);
WrapFn->setLinkage(llvm::GlobalValue::InternalLinkage);

Builder.CreateRet(F->getReturnType()->isVoidTy() ? nullptr : CI);

// Have to find the spir-v metadata for execution mode and transfer it to
// the wrapper.
auto Node = M->getNamedMetadata("spirv.ExecutionMode");
if (Node) {
for (unsigned i = 0; i < Node->getNumOperands(); i++) {
MDNode *MDN = cast<MDNode>(Node->getOperand(i));
const MDOperand &MDOp = MDN->getOperand(0);
auto *CMeta = dyn_cast<ConstantAsMetadata>(MDOp);
Function *MDF = dyn_cast<Function>(CMeta->getValue());
if (MDF == F)
MDN->replaceOperandWith(0, ValueAsMetadata::get(WrapFn));
}
}
}

bool SPIRVPreTranslationLegalizer::runOnModule(Module &M) {
std::vector<Function *> FuncsWorklist;
bool Changed = false;
for (auto &F : M)
FuncsWorklist.push_back(&F);

for (auto *Func : FuncsWorklist) {
if (isKernel(Func))
addKernelEntryPointWrapper(&M, Func);

auto *F = processFunctionSignature(Func);

bool CreatedNewF = F != Func;
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -338,3 +338,7 @@ std::string isOclOrSpirvBuiltin(StringRef Name) {
.getAsInteger(10, Len);
return Name.substr(Start, Len).str();
}

bool isKernel(const Function *F) {
return F->getCallingConv() == CallingConv::SPIR_KERNEL;
}
3 changes: 3 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,7 @@ llvm::Type *getMDOperandAsType(const llvm::MDNode *N, unsigned I);
// Return a demangled name with arg type info by itaniumDemangle().
// If the parser fails, return only function name.
std::string isOclOrSpirvBuiltin(llvm::StringRef Name);

// Check if it is a kernel function.
bool isKernel(const llvm::Function *F);
#endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H
2 changes: 1 addition & 1 deletion llvm/test/CodeGen/SPIRV/entry_point_func.ll
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ define spir_kernel void @testfunction() {
; CHECK-SPIRV: OpDecorate %[[FUNC]] LinkageAttributes "testfunction" Export
; CHECK-SPIRV: %[[FUNC]] = OpFunction %2 None %3
; CHECK-SPIRV: %[[EP]] = OpFunction %2 None %3
; CHECK-SPIRV: %8 = OpFunctionCall %2 %[[FUNC]]
; CHECK-SPIRV: OpFunctionCall %2 %[[FUNC]]
4 changes: 2 additions & 2 deletions llvm/test/CodeGen/SPIRV/opencl/basic/get_global_offset.ll
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

target triple = "spirv64-unknown-unknown"

; CHECK: OpEntryPoint Kernel %[[test_func:[0-9]+]] "test"
; CHECK: OpEntryPoint Kernel %[[test_func_wrap:[0-9]+]] "test"
; CHECK: OpName %[[outOffsets:[0-9]+]] "outOffsets"
; CHECK: OpName %[[test_func]] "test"
; CHECK: OpName %[[test_func:[0-9]+]] "test"
; CHECK: OpName %[[f2_decl:[0-9]+]] "BuiltInGlobalOffset"
; CHECK: OpDecorate %[[f2_decl]] LinkageAttributes "BuiltInGlobalOffset" Import
; CHECK: %[[int_ty:[0-9]+]] = OpTypeInt 32 0
Expand Down