From 532f83fc24d63b9aef3fe77a862dfb689bfc9991 Mon Sep 17 00:00:00 2001 From: Ilia Diachkov Date: Wed, 29 Jun 2022 10:57:00 +0300 Subject: [PATCH] [SPIR-V] implement EntryPoint wrappers --- llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp | 6 +- llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp | 12 +-- llvm/lib/Target/SPIRV/SPIRVMCInstLower.cpp | 2 +- llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 2 +- .../SPIRV/SPIRVPreTranslationLegalizer.cpp | 73 +++++++++++++++++++ llvm/lib/Target/SPIRV/SPIRVUtils.cpp | 4 + llvm/lib/Target/SPIRV/SPIRVUtils.h | 3 + llvm/test/CodeGen/SPIRV/entry_point_func.ll | 2 +- .../SPIRV/opencl/basic/get_global_offset.ll | 4 +- 9 files changed, 95 insertions(+), 13 deletions(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp index c6c1dd93fd41..2bdd068a69a5 100644 --- a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp @@ -425,9 +425,9 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) { } for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) { const Function &F = *FI; - if (F.isDeclaration()) + if (F.isDeclaration() || !isKernel(&F)) continue; - Register FReg = MAI->getFuncReg(F.getGlobalIdentifier()); + Register FReg = MAI->getFuncReg(F.getName().str()); assert(FReg.isValid()); if (MDNode *Node = F.getMetadata("reqd_work_group_size")) outputExecutionModeFromMDNode(FReg, Node, ExecutionMode::LocalSize); @@ -462,7 +462,7 @@ void SPIRVAsmPrinter::outputAnnotations(const Module &M) { Register Reg; if (isa(AnnotatedVar)) { auto *Func = cast(AnnotatedVar); - Reg = MAI->getFuncReg(Func->getGlobalIdentifier()); + Reg = MAI->getFuncReg(Func->getName().str()); } else { llvm_unreachable("Unsupported value in llvm.global.annotations"); } diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp index 9aced9c6c81f..0b2411c74df7 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp @@ -231,17 +231,19 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, i++; } - // Name the function. - if (F.hasName()) + // Name the function but skip the kernels that are wrappers at this point. + if (F.hasName() && !isKernel(&F)) buildOpName(FuncVReg, F.getName(), MIRBuilder); // Handle entry points and function linkage. - if (F.getCallingConv() == CallingConv::SPIR_KERNEL) { + if (isKernel(&F)) { auto ExecModel = ExecutionModel::Kernel; auto MIB = MIRBuilder.buildInstr(SPIRV::OpEntryPoint) .addImm(ExecModel) .addUse(FuncVReg); - addStringImm(F.getName(), MIB); + StringRef Name = F.getName(); + Name.consume_front("__spirv_entry_"); + addStringImm(Name, MIB); } else if (F.getLinkage() == GlobalValue::LinkageTypes::ExternalLinkage || F.getLinkage() == GlobalValue::LinkOnceODRLinkage) { auto LnkTy = F.isDeclaration() ? LinkageType::Import : LinkageType::Export; @@ -273,7 +275,7 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0]; std::string FuncName = Info.Callee.getGlobal()->getGlobalIdentifier(); std::string DemangledName = isOclOrSpirvBuiltin(FuncName); - if (!DemangledName.empty()) { + if (!DemangledName.empty() && CF && CF->isDeclaration()) { // TODO: check that it's OCL builtin, then apply OpenCL_std. const auto *ST = static_cast(&MF.getSubtarget()); if (ST->canUseExtInstSet(ExtInstSet::OpenCL_std)) { diff --git a/llvm/lib/Target/SPIRV/SPIRVMCInstLower.cpp b/llvm/lib/Target/SPIRV/SPIRVMCInstLower.cpp index 59fd37ee6c7a..6df4f31cd6ea 100644 --- a/llvm/lib/Target/SPIRV/SPIRVMCInstLower.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVMCInstLower.cpp @@ -34,7 +34,7 @@ void SPIRVMCInstLower::lower(const MachineInstr *MI, MCInst &OutMI, MI->print(errs()); llvm_unreachable("unknown operand type"); case MachineOperand::MO_GlobalAddress: { - Register FuncReg = MAI->getFuncReg(MO.getGlobal()->getGlobalIdentifier()); + Register FuncReg = MAI->getFuncReg(MO.getGlobal()->getName().str()); assert(FuncReg.isValid() && "Cannot find function Id"); MCOp = MCOperand::createReg(FuncReg); break; diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index 0bb51d3d46cd..7fe49a2b4f77 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -271,7 +271,7 @@ void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI, Register GlobalReg = MAI.getRegisterAlias(MI.getMF(), Reg); assert(GlobalReg.isValid()); // TODO: check that it does not conflict with existing entries. - MAI.FuncNameMap[F.getGlobalIdentifier()] = GlobalReg; + MAI.FuncNameMap[F.getName()] = GlobalReg; } } diff --git a/llvm/lib/Target/SPIRV/SPIRVPreTranslationLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreTranslationLegalizer.cpp index 9196a771cf69..49317f7ff9e2 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPreTranslationLegalizer.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPreTranslationLegalizer.cpp @@ -16,6 +16,7 @@ #include "SPIRV.h" #include "SPIRVTargetMachine.h" +#include "SPIRVUtils.h" #include "llvm/IR/IRBuilder.h" #include "llvm/Transforms/Utils/Cloning.h" @@ -114,6 +115,75 @@ Function *SPIRVPreTranslationLegalizer::processFunctionSignature(Function *F) { return NewF; } +static Function *getOrCreateFunction(Module *M, Type *RetTy, + ArrayRef ArgTypes, + StringRef Name) { + FunctionType *FT = FunctionType::get(RetTy, ArgTypes, false); + Function *F = M->getFunction(Name); + if (!F || F->getFunctionType() != FT) { + auto NewF = Function::Create(FT, GlobalValue::ExternalLinkage, Name, M); + if (F) + NewF->setDSOLocal(F->isDSOLocal()); + F = NewF; + F->setCallingConv(CallingConv::SPIR_FUNC); + } + return F; +} + +// Add a wrapper around the kernel function to act as an entry point. +static void addKernelEntryPointWrapper(Module *M, Function *F) { + F->setCallingConv(CallingConv::SPIR_FUNC); + FunctionType *FType = F->getFunctionType(); + std::string WrapName = + "__spirv_entry_" + static_cast(F->getName()); + Function *WrapFn = + getOrCreateFunction(M, F->getReturnType(), FType->params(), WrapName); + + auto *CallBB = BasicBlock::Create(M->getContext(), "", WrapFn); + IRBuilder<> Builder(CallBB); + + Function::arg_iterator DestI = WrapFn->arg_begin(); + for (const Argument &I : F->args()) { + DestI->setName(I.getName()); + DestI++; + } + SmallVector Args; + for (Argument &I : WrapFn->args()) + Args.emplace_back(&I); + + auto *CI = CallInst::Create(F, ArrayRef(Args), "", CallBB); + CI->setCallingConv(F->getCallingConv()); + CI->setAttributes(F->getAttributes()); + + // Copy over all the metadata excepting debug info. + // TODO: removed some metadata from F if it's necessery. + SmallVector> MDs; + F->getAllMetadata(MDs); + WrapFn->setAttributes(F->getAttributes()); + for (auto MD = MDs.begin(), End = MDs.end(); MD != End; ++MD) { + if (MD->first != LLVMContext::MD_dbg) + WrapFn->addMetadata(MD->first, *MD->second); + } + WrapFn->setCallingConv(CallingConv::SPIR_KERNEL); + WrapFn->setLinkage(llvm::GlobalValue::InternalLinkage); + + Builder.CreateRet(F->getReturnType()->isVoidTy() ? nullptr : CI); + + // Have to find the spir-v metadata for execution mode and transfer it to + // the wrapper. + auto Node = M->getNamedMetadata("spirv.ExecutionMode"); + if (Node) { + for (unsigned i = 0; i < Node->getNumOperands(); i++) { + MDNode *MDN = cast(Node->getOperand(i)); + const MDOperand &MDOp = MDN->getOperand(0); + auto *CMeta = dyn_cast(MDOp); + Function *MDF = dyn_cast(CMeta->getValue()); + if (MDF == F) + MDN->replaceOperandWith(0, ValueAsMetadata::get(WrapFn)); + } + } +} + bool SPIRVPreTranslationLegalizer::runOnModule(Module &M) { std::vector FuncsWorklist; bool Changed = false; @@ -121,6 +191,9 @@ bool SPIRVPreTranslationLegalizer::runOnModule(Module &M) { FuncsWorklist.push_back(&F); for (auto *Func : FuncsWorklist) { + if (isKernel(Func)) + addKernelEntryPointWrapper(&M, Func); + auto *F = processFunctionSignature(Func); bool CreatedNewF = F != Func; diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp index 7b33ee41b446..7dad8bedd7af 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp @@ -338,3 +338,7 @@ std::string isOclOrSpirvBuiltin(StringRef Name) { .getAsInteger(10, Len); return Name.substr(Start, Len).str(); } + +bool isKernel(const Function *F) { + return F->getCallingConv() == CallingConv::SPIR_KERNEL; +} diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h index 7b2d9a5f2ce3..249635e4f49f 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.h +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h @@ -96,4 +96,7 @@ llvm::Type *getMDOperandAsType(const llvm::MDNode *N, unsigned I); // Return a demangled name with arg type info by itaniumDemangle(). // If the parser fails, return only function name. std::string isOclOrSpirvBuiltin(llvm::StringRef Name); + +// Check if it is a kernel function. +bool isKernel(const llvm::Function *F); #endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H diff --git a/llvm/test/CodeGen/SPIRV/entry_point_func.ll b/llvm/test/CodeGen/SPIRV/entry_point_func.ll index da9d5b33e327..629c3728a0c6 100644 --- a/llvm/test/CodeGen/SPIRV/entry_point_func.ll +++ b/llvm/test/CodeGen/SPIRV/entry_point_func.ll @@ -15,4 +15,4 @@ define spir_kernel void @testfunction() { ; CHECK-SPIRV: OpDecorate %[[FUNC]] LinkageAttributes "testfunction" Export ; CHECK-SPIRV: %[[FUNC]] = OpFunction %2 None %3 ; CHECK-SPIRV: %[[EP]] = OpFunction %2 None %3 -; CHECK-SPIRV: %8 = OpFunctionCall %2 %[[FUNC]] +; CHECK-SPIRV: OpFunctionCall %2 %[[FUNC]] diff --git a/llvm/test/CodeGen/SPIRV/opencl/basic/get_global_offset.ll b/llvm/test/CodeGen/SPIRV/opencl/basic/get_global_offset.ll index 518c843f7c84..fac5ba624abf 100644 --- a/llvm/test/CodeGen/SPIRV/opencl/basic/get_global_offset.ll +++ b/llvm/test/CodeGen/SPIRV/opencl/basic/get_global_offset.ll @@ -2,9 +2,9 @@ target triple = "spirv64-unknown-unknown" -; CHECK: OpEntryPoint Kernel %[[test_func:[0-9]+]] "test" +; CHECK: OpEntryPoint Kernel %[[test_func_wrap:[0-9]+]] "test" ; CHECK: OpName %[[outOffsets:[0-9]+]] "outOffsets" -; CHECK: OpName %[[test_func]] "test" +; CHECK: OpName %[[test_func:[0-9]+]] "test" ; CHECK: OpName %[[f2_decl:[0-9]+]] "BuiltInGlobalOffset" ; CHECK: OpDecorate %[[f2_decl]] LinkageAttributes "BuiltInGlobalOffset" Import ; CHECK: %[[int_ty:[0-9]+]] = OpTypeInt 32 0