Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 17 additions & 11 deletions llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,6 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
// Generate a SPIR-V type for the function.
auto MRI = MIRBuilder.getMRI();
Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass);
if (F.isDeclaration())
GR->add(&F, &MIRBuilder.getMF(), FuncVReg);
SPIRVType *RetTy = GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder);
Expand All @@ -313,22 +312,28 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
// Build the OpTypeFunction declaring it.
uint32_t FuncControl = getFunctionControl(F);

MIRBuilder.buildInstr(SPIRV::OpFunction)
.addDef(FuncVReg)
.addUse(GR->getSPIRVTypeID(RetTy))
.addImm(FuncControl)
.addUse(GR->getSPIRVTypeID(FuncTy));
auto B = MIRBuilder.buildInstr(SPIRV::OpFunction)
.addDef(FuncVReg)
.addUse(GR->getSPIRVTypeID(RetTy))
.addImm(FuncControl)
.addUse(GR->getSPIRVTypeID(FuncTy));
const auto &ST = GR->CurMF->getSubtarget();
constrainSelectedInstRegOperands(*B.getInstr(), *ST.getInstrInfo(),
*ST.getRegisterInfo(), *ST.getRegBankInfo());

// Add OpFunctionParameters.
int i = 0;
for (const auto &Arg : F.args()) {
assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs");
MRI->setRegClass(VRegs[i][0], &SPIRV::IDRegClass);
MIRBuilder.buildInstr(SPIRV::OpFunctionParameter)
.addDef(VRegs[i][0])
.addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i]));
auto B = MIRBuilder.buildInstr(SPIRV::OpFunctionParameter)
.addDef(VRegs[i][0])
.addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i]));
if (F.isDeclaration())
GR->add(&Arg, &MIRBuilder.getMF(), VRegs[i][0]);
const auto &ST = GR->CurMF->getSubtarget();
constrainSelectedInstRegOperands(*B.getInstr(), *ST.getInstrInfo(),
*ST.getRegisterInfo(),
*ST.getRegBankInfo());
i++;
}
// Name the function.
Expand Down Expand Up @@ -421,7 +426,8 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,

// Make sure there's a valid return reg, even for functions returning void.
if (!ResVReg.isValid())
ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
ResVReg =
MIRBuilder.getMRI()->createGenericVirtualRegister(LLT::scalar(32));
SPIRVType *RetType =
GR->assignTypeToVReg(FTy->getReturnType(), ResVReg, MIRBuilder);

Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,10 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
return nullptr;
TypesInProcessing.insert(Ty);
SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
const auto &ST = CurMF->getSubtarget();
constrainSelectedInstRegOperands(*const_cast<MachineInstr *>(SpirvType),
*ST.getInstrInfo(), *ST.getRegisterInfo(),
*ST.getRegBankInfo());
TypesInProcessing.erase(Ty);
VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
SPIRVToLLVMType[SpirvType] = Ty;
Expand Down Expand Up @@ -972,6 +976,10 @@ SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy,
VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType;
SPIRVToLLVMType[SpirvType] = LLVMTy;
DT.add(LLVMTy, CurMF, getSPIRVTypeID(SpirvType));
const auto &ST = CurMF->getSubtarget();
constrainSelectedInstRegOperands(*const_cast<MachineInstr *>(SpirvType),
*ST.getInstrInfo(), *ST.getRegisterInfo(),
*ST.getRegBankInfo());
return SpirvType;
}

Expand Down
5 changes: 2 additions & 3 deletions llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,8 @@ void SPIRVInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
}

bool SPIRVInstrInfo::expandPostRAPseudo(MachineInstr &MI) const {
if (MI.getOpcode() == SPIRV::GET_ID || MI.getOpcode() == SPIRV::GET_fID ||
MI.getOpcode() == SPIRV::GET_pID || MI.getOpcode() == SPIRV::GET_vfID ||
MI.getOpcode() == SPIRV::GET_vID) {
if (MI.getOpcode() == SPIRV::GET_ID || MI.getOpcode() == SPIRV::GET_SID ||
MI.getOpcode() == SPIRV::GET_VID) {
auto &MRI = MI.getMF()->getRegInfo();
MRI.replaceRegWith(MI.getOperand(0).getReg(), MI.getOperand(1).getReg());
MI.eraseFromParent();
Expand Down
113 changes: 40 additions & 73 deletions llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,10 @@ include "SPIRVSymbolicOperands.td"

// Codegen only metadata instructions
let isCodeGenOnly=1 in {
def ASSIGN_TYPE: Pseudo<(outs ANYID:$dst_id), (ins ANYID:$src_id, TYPE:$src_ty)>;
def DECL_TYPE: Pseudo<(outs ANYID:$dst_id), (ins ANYID:$src_id, TYPE:$src_ty)>;
def ASSIGN_TYPE: Pseudo<(outs ANYID:$dst_id), (ins unknown:$src_id, TYPE:$src_ty)>;
def GET_ID: Pseudo<(outs ID:$dst_id), (ins ANYID:$src)>;
def GET_fID: Pseudo<(outs fID:$dst_id), (ins ANYID:$src)>;
def GET_pID: Pseudo<(outs pID:$dst_id), (ins ANYID:$src)>;
def GET_vID: Pseudo<(outs vID:$dst_id), (ins ANYID:$src)>;
def GET_vfID: Pseudo<(outs vfID:$dst_id), (ins ANYID:$src)>;
def GET_SID: Pseudo<(outs SID:$dst_id), (ins ANYID:$src)>;
def GET_VID: Pseudo<(outs VID:$dst_id), (ins ANYID:$src)>;
}

def SPVTypeBin : SDTypeProfile<1, 2, []>;
Expand All @@ -38,41 +35,10 @@ class BinOpTyped<string name, bits<16> opCode, RegisterClass CID, SDNode node>
: Op<opCode, (outs ID:$dst), (ins TYPE:$src_ty, CID:$src, CID:$src2),
"$dst = "#name#" $src_ty $src $src2", [(set ID:$dst, (assigntype (node CID:$src, CID:$src2), TYPE:$src_ty))]>;

class TernOpTyped<string name, bits<16> opCode, RegisterClass CCond, RegisterClass CID, SDNode node>
: Op<opCode, (outs ID:$dst), (ins TYPE:$src_ty, CCond:$cond, CID:$src1, CID:$src2),
"$dst = "#name#" $src_ty $cond $src1 $src2", [(set ID:$dst, (assigntype (node CCond:$cond, CID:$src1, CID:$src2), TYPE:$src_ty))]>;

multiclass BinOpTypedGen<string name, bits<16> opCode, SDNode node, bit genF = 0, bit genV = 0> {
if genF then
def S: BinOpTyped<name, opCode, fID, node>;
else
def S: BinOpTyped<name, opCode, ID, node>;
multiclass BinOpTypedGen<string name, bits<16> opCode, SDNode node, bit genV = 0> {
def S: BinOpTyped<name, opCode, SID, node>;
if genV then {
if genF then
def V: BinOpTyped<name, opCode, vfID, node>;
else
def V: BinOpTyped<name, opCode, vID, node>;
}
}

multiclass TernOpTypedGen<string name, bits<16> opCode, SDNode node, bit genI = 1, bit genF = 0, bit genV = 0> {
if genF then {
def SFSCond: TernOpTyped<name, opCode, ID, fID, node>;
def SFVCond: TernOpTyped<name, opCode, vID, fID, node>;
}
if genI then {
def SISCond: TernOpTyped<name, opCode, ID, ID, node>;
def SIVCond: TernOpTyped<name, opCode, vID, ID, node>;
}
if genV then {
if genF then {
def VFSCond: TernOpTyped<name, opCode, ID, vfID, node>;
def VFVCond: TernOpTyped<name, opCode, vID, vfID, node>;
}
if genI then {
def VISCond: TernOpTyped<name, opCode, ID, vID, node>;
def VIVCond: TernOpTyped<name, opCode, vID, vID, node>;
}
def V: BinOpTyped<name, opCode, VID, node>;
}
}

Expand All @@ -99,7 +65,7 @@ def OpSource: Op<3, (outs), (ins SourceLanguage:$lang, i32imm:$version, variable
"OpSource $lang $version">;
def OpSourceExtension: Op<4, (outs), (ins StringImm:$extension, variable_ops),
"OpSourceExtension $extension">;
def OpName: Op<5, (outs), (ins ANY:$tar, StringImm:$name, variable_ops), "OpName $tar $name">;
def OpName: Op<5, (outs), (ins ANYID:$tar, StringImm:$name, variable_ops), "OpName $tar $name">;
def OpMemberName: Op<6, (outs), (ins TYPE:$ty, i32imm:$mem, StringImm:$name, variable_ops),
"OpMemberName $ty $mem $name">;
def OpString: Op<7, (outs ID:$r), (ins StringImm:$s, variable_ops), "$r = OpString $s">;
Expand All @@ -110,17 +76,17 @@ def OpModuleProcessed: Op<330, (outs), (ins StringImm:$process, variable_ops),

// 3.42.3 Annotation Instructions

def OpDecorate: Op<71, (outs), (ins ANY:$target, Decoration:$dec, variable_ops),
def OpDecorate: Op<71, (outs), (ins ANYID:$target, Decoration:$dec, variable_ops),
"OpDecorate $target $dec">;
def OpMemberDecorate: Op<72, (outs), (ins TYPE:$t, i32imm:$m, Decoration:$d, variable_ops),
"OpMemberDecorate $t $m $d">;

// TODO Currently some deprecated opcodes are missing: OpDecorationGroup,
// OpGroupDecorate and OpGroupMemberDecorate

def OpDecorateId: Op<332, (outs), (ins ANY:$target, Decoration:$dec, variable_ops),
def OpDecorateId: Op<332, (outs), (ins ANYID:$target, Decoration:$dec, variable_ops),
"OpDecorateId $target $dec">;
def OpDecorateString: Op<5632, (outs), (ins ANY:$t, Decoration:$d, StringImm:$s, variable_ops),
def OpDecorateString: Op<5632, (outs), (ins ANYID:$t, Decoration:$d, StringImm:$s, variable_ops),
"OpDecorateString $t $d $s">;
def OpMemberDecorateString: Op<5633, (outs),
(ins TYPE:$ty, i32imm:$mem, Decoration:$dec, StringImm:$str, variable_ops),
Expand All @@ -131,7 +97,7 @@ def OpMemberDecorateString: Op<5633, (outs),
def OpExtension: Op<10, (outs), (ins StringImm:$name, variable_ops), "OpExtension $name">;
def OpExtInstImport: Op<11, (outs ID:$res), (ins StringImm:$extInstsName, variable_ops),
"$res = OpExtInstImport $extInstsName">;
def OpExtInst: Op<12, (outs ID:$res), (ins TYPE:$ty, ID:$set, Extension:$inst, variable_ops),
def OpExtInst: Op<12, (outs ID:$res), (ins TYPE:$ty, i32imm:$set, Extension:$inst, variable_ops),
"$res = OpExtInst $ty $set $inst">;

// 3.42.5 Mode-Setting Instructions
Expand Down Expand Up @@ -198,9 +164,9 @@ return CurDAG->getTargetConstant(
N->getValueAP().bitcastToAPInt().getZExtValue(), SDLoc(N), MVT::i32);
}]>;

def fimm_to_i32 : SDNodeXForm<imm, [{
def fimm_to_i32 : SDNodeXForm<fpimm, [{
return CurDAG->getTargetConstant(
N->getValueAPF().bitcastToAPInt().getZExtValue(), SDLoc(N), MVT::i32);
N->getValueAPF().bitcastToAPInt().getZExtValue(), SDLoc(N), MVT::f32);
}]>;

def gi_bitcast_fimm_to_i32 : GICustomOperandRenderer<"renderFImm32">,
Expand All @@ -210,15 +176,16 @@ def gi_bitcast_imm_to_i32 : GICustomOperandRenderer<"renderImm32">,
GISDNodeXFormEquiv<imm_to_i32>;

def PseudoConstI: IntImmLeaf<i32, [{ return Imm.getBitWidth() <= 32; }], imm_to_i32>;
def PseudoConstF: FPImmLeaf<f32, [{ return true; }], fimm_to_i32>;
def PseudoConstF: FPImmLeaf<f32, [{ return true; }], fimm_to_i32>;

def ConstPseudoTrue: IntImmLeaf<i32, [{ return Imm.getBitWidth() == 1 && Imm.getZExtValue() == 1; }]>;
def ConstPseudoFalse: IntImmLeaf<i32, [{ return Imm.getBitWidth() == 1 && Imm.getZExtValue() == 0; }]>;
def ConstPseudoNull: IntImmLeaf<i64, [{ return Imm.isZero(); }]>;

multiclass IntFPImm<bits<16> opCode, string name> {
def I: Op<opCode, (outs ID:$dst), (ins TYPE:$type, ID:$src, variable_ops),
"$dst = "#name#" $type $src", [(set ID:$dst, (assigntype PseudoConstI:$src, TYPE:$type))]>;
def F: Op<opCode, (outs ID:$dst), (ins TYPE:$type, fID:$src, variable_ops),
def F: Op<opCode, (outs ID:$dst), (ins TYPE:$type, FID:$src, variable_ops),
"$dst = "#name#" $type $src", [(set ID:$dst, (assigntype PseudoConstF:$src, TYPE:$type))]>;
}

Expand Down Expand Up @@ -428,9 +395,8 @@ def OpBitcast : UnOp<"OpBitcast", 124>;

// 3.42.12 Composite Instructions

def OpVectorExtractDynamic: Op<77, (outs ID:$res), (ins TYPE:$type, vID:$vec, ID:$idx),
"$res = OpVectorExtractDynamic $type $vec $idx", [(set ID:$res, (assigntype (extractelt vID:$vec, ID:$idx), TYPE:$type))]>;

def OpVectorExtractDynamic: Op<77, (outs ID:$res), (ins TYPE:$ty, ID:$vec, ID:$idx),
"$res = OpVectorExtractDynamic $ty $vec $idx">;
def OpVectorInsertDynamic: Op<78, (outs ID:$res), (ins TYPE:$ty, ID:$vec, ID:$comp, ID:$idx),
"$res = OpVectorInsertDynamic $ty $vec $comp $idx">;
def OpVectorShuffle: Op<79, (outs ID:$res), (ins TYPE:$ty, ID:$v1, ID:$v2, variable_ops),
Expand All @@ -448,27 +414,27 @@ def OpCopyLogical: UnOp<"OpCopyLogical", 400>;
// 3.42.13 Arithmetic Instructions

def OpSNegate: UnOp<"OpSNegate", 126>;
def OpFNegate: UnOpTyped<"OpFNegate", 127, fID, fneg>;
def OpFNegateV: UnOpTyped<"OpFNegate", 127, vfID, fneg>;
defm OpIAdd: BinOpTypedGen<"OpIAdd", 128, add, 0, 1>;
defm OpFAdd: BinOpTypedGen<"OpFAdd", 129, fadd, 1, 1>;
def OpFNegate: UnOpTyped<"OpFNegate", 127, SID, fneg>;
def OpFNegateV: UnOpTyped<"OpFNegate", 127, VID, fneg>;
defm OpIAdd: BinOpTypedGen<"OpIAdd", 128, add, 1>;
defm OpFAdd: BinOpTypedGen<"OpFAdd", 129, fadd, 1>;

defm OpISub: BinOpTypedGen<"OpISub", 130, sub, 0, 1>;
defm OpFSub: BinOpTypedGen<"OpFSub", 131, fsub, 1, 1>;
defm OpISub: BinOpTypedGen<"OpISub", 130, sub, 1>;
defm OpFSub: BinOpTypedGen<"OpFSub", 131, fsub, 1>;

defm OpIMul: BinOpTypedGen<"OpIMul", 132, mul, 0, 1>;
defm OpFMul: BinOpTypedGen<"OpFMul", 133, fmul, 1, 1>;
defm OpIMul: BinOpTypedGen<"OpIMul", 132, mul, 1>;
defm OpFMul: BinOpTypedGen<"OpFMul", 133, fmul, 1>;

defm OpUDiv: BinOpTypedGen<"OpUDiv", 134, udiv, 0, 1>;
defm OpSDiv: BinOpTypedGen<"OpSDiv", 135, sdiv, 0, 1>;
defm OpFDiv: BinOpTypedGen<"OpFDiv", 136, fdiv, 1, 1>;
defm OpUDiv: BinOpTypedGen<"OpUDiv", 134, udiv, 1>;
defm OpSDiv: BinOpTypedGen<"OpSDiv", 135, sdiv, 1>;
defm OpFDiv: BinOpTypedGen<"OpFDiv", 136, fdiv, 1>;

defm OpUMod: BinOpTypedGen<"OpUMod", 137, urem, 0, 1>;
defm OpSRem: BinOpTypedGen<"OpSRem", 138, srem, 0, 1>;
defm OpUMod: BinOpTypedGen<"OpUMod", 137, urem, 1>;
defm OpSRem: BinOpTypedGen<"OpSRem", 138, srem, 1>;

def OpSMod: BinOp<"OpSMod", 139>;

defm OpFRem: BinOpTypedGen<"OpFRem", 140, frem, 1, 1>;
defm OpFRem: BinOpTypedGen<"OpFRem", 140, frem, 1>;
def OpFMod: BinOp<"OpFMod", 141>;

def OpVectorTimesScalar: BinOp<"OpVectorTimesScalar", 142>;
Expand All @@ -487,13 +453,13 @@ def OpSMulExtended: BinOp<"OpSMulExtended", 152>;

// 3.42.14 Bit Instructions

defm OpShiftRightLogical: BinOpTypedGen<"OpShiftRightLogical", 194, srl, 0, 1>;
defm OpShiftRightArithmetic: BinOpTypedGen<"OpShiftRightArithmetic", 195, sra, 0, 1>;
defm OpShiftLeftLogical: BinOpTypedGen<"OpShiftLeftLogical", 196, shl, 0, 1>;
defm OpShiftRightLogical: BinOpTypedGen<"OpShiftRightLogical", 194, srl, 1>;
defm OpShiftRightArithmetic: BinOpTypedGen<"OpShiftRightArithmetic", 195, sra, 1>;
defm OpShiftLeftLogical: BinOpTypedGen<"OpShiftLeftLogical", 196, shl, 1>;

defm OpBitwiseOr: BinOpTypedGen<"OpBitwiseOr", 197, or, 0, 1>;
defm OpBitwiseXor: BinOpTypedGen<"OpBitwiseXor", 198, xor, 0, 1>;
defm OpBitwiseAnd: BinOpTypedGen<"OpBitwiseAnd", 199, and, 0, 1>;
defm OpBitwiseOr: BinOpTypedGen<"OpBitwiseOr", 197, or, 1>;
defm OpBitwiseXor: BinOpTypedGen<"OpBitwiseXor", 198, xor, 1>;
defm OpBitwiseAnd: BinOpTypedGen<"OpBitwiseAnd", 199, and, 1>;
def OpNot: UnOp<"OpNot", 200>;

def OpBitFieldInsert: Op<201, (outs ID:$res),
Expand Down Expand Up @@ -531,7 +497,8 @@ def OpLogicalOr: BinOp<"OpLogicalOr", 166>;
def OpLogicalAnd: BinOp<"OpLogicalAnd", 167>;
def OpLogicalNot: UnOp<"OpLogicalNot", 168>;

defm OpSelect: TernOpTypedGen<"OpSelect", 169, select, 1, 1, 1>;
def OpSelect: Op<169, (outs ID:$dst), (ins TYPE:$src_ty, ID:$cond, ID:$src1, ID:$src2),
"$dst = OpSelect $src_ty $cond $src1 $src2">;

def OpIEqual: BinOp<"OpIEqual", 170>;
def OpINotEqual: BinOp<"OpINotEqual", 171>;
Expand Down
Loading