diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp index 8b618686ee7d..08a46dee4ff9 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp @@ -303,7 +303,6 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, // Generate a SPIR-V type for the function. auto MRI = MIRBuilder.getMRI(); Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32)); - MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass); if (F.isDeclaration()) GR->add(&F, &MIRBuilder.getMF(), FuncVReg); SPIRVType *RetTy = GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder); @@ -313,22 +312,28 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, // Build the OpTypeFunction declaring it. uint32_t FuncControl = getFunctionControl(F); - MIRBuilder.buildInstr(SPIRV::OpFunction) - .addDef(FuncVReg) - .addUse(GR->getSPIRVTypeID(RetTy)) - .addImm(FuncControl) - .addUse(GR->getSPIRVTypeID(FuncTy)); + auto B = MIRBuilder.buildInstr(SPIRV::OpFunction) + .addDef(FuncVReg) + .addUse(GR->getSPIRVTypeID(RetTy)) + .addImm(FuncControl) + .addUse(GR->getSPIRVTypeID(FuncTy)); + const auto &ST = GR->CurMF->getSubtarget(); + constrainSelectedInstRegOperands(*B.getInstr(), *ST.getInstrInfo(), + *ST.getRegisterInfo(), *ST.getRegBankInfo()); // Add OpFunctionParameters. int i = 0; for (const auto &Arg : F.args()) { assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs"); - MRI->setRegClass(VRegs[i][0], &SPIRV::IDRegClass); - MIRBuilder.buildInstr(SPIRV::OpFunctionParameter) - .addDef(VRegs[i][0]) - .addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i])); + auto B = MIRBuilder.buildInstr(SPIRV::OpFunctionParameter) + .addDef(VRegs[i][0]) + .addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i])); if (F.isDeclaration()) GR->add(&Arg, &MIRBuilder.getMF(), VRegs[i][0]); + const auto &ST = GR->CurMF->getSubtarget(); + constrainSelectedInstRegOperands(*B.getInstr(), *ST.getInstrInfo(), + *ST.getRegisterInfo(), + *ST.getRegBankInfo()); i++; } // Name the function. @@ -421,7 +426,8 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, // Make sure there's a valid return reg, even for functions returning void. if (!ResVReg.isValid()) - ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass); + ResVReg = + MIRBuilder.getMRI()->createGenericVirtualRegister(LLT::scalar(32)); SPIRVType *RetType = GR->assignTypeToVReg(FTy->getReturnType(), ResVReg, MIRBuilder); diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp index 062188abbf5e..fb0dbb88c850 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -727,6 +727,10 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType( return nullptr; TypesInProcessing.insert(Ty); SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR); + const auto &ST = CurMF->getSubtarget(); + constrainSelectedInstRegOperands(*const_cast(SpirvType), + *ST.getInstrInfo(), *ST.getRegisterInfo(), + *ST.getRegBankInfo()); TypesInProcessing.erase(Ty); VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType; SPIRVToLLVMType[SpirvType] = Ty; @@ -972,6 +976,10 @@ SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy, VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType; SPIRVToLLVMType[SpirvType] = LLVMTy; DT.add(LLVMTy, CurMF, getSPIRVTypeID(SpirvType)); + const auto &ST = CurMF->getSubtarget(); + constrainSelectedInstRegOperands(*const_cast(SpirvType), + *ST.getInstrInfo(), *ST.getRegisterInfo(), + *ST.getRegBankInfo()); return SpirvType; } diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp index 42317453a237..bd512a6690fc 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp @@ -245,9 +245,8 @@ void SPIRVInstrInfo::copyPhysReg(MachineBasicBlock &MBB, } bool SPIRVInstrInfo::expandPostRAPseudo(MachineInstr &MI) const { - if (MI.getOpcode() == SPIRV::GET_ID || MI.getOpcode() == SPIRV::GET_fID || - MI.getOpcode() == SPIRV::GET_pID || MI.getOpcode() == SPIRV::GET_vfID || - MI.getOpcode() == SPIRV::GET_vID) { + if (MI.getOpcode() == SPIRV::GET_ID || MI.getOpcode() == SPIRV::GET_SID || + MI.getOpcode() == SPIRV::GET_VID) { auto &MRI = MI.getMF()->getRegInfo(); MRI.replaceRegWith(MI.getOperand(0).getReg(), MI.getOperand(1).getReg()); MI.eraseFromParent(); diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td index 44b5536becf7..14e4b6bcd374 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td +++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td @@ -15,13 +15,10 @@ include "SPIRVSymbolicOperands.td" // Codegen only metadata instructions let isCodeGenOnly=1 in { - def ASSIGN_TYPE: Pseudo<(outs ANYID:$dst_id), (ins ANYID:$src_id, TYPE:$src_ty)>; - def DECL_TYPE: Pseudo<(outs ANYID:$dst_id), (ins ANYID:$src_id, TYPE:$src_ty)>; + def ASSIGN_TYPE: Pseudo<(outs ANYID:$dst_id), (ins unknown:$src_id, TYPE:$src_ty)>; def GET_ID: Pseudo<(outs ID:$dst_id), (ins ANYID:$src)>; - def GET_fID: Pseudo<(outs fID:$dst_id), (ins ANYID:$src)>; - def GET_pID: Pseudo<(outs pID:$dst_id), (ins ANYID:$src)>; - def GET_vID: Pseudo<(outs vID:$dst_id), (ins ANYID:$src)>; - def GET_vfID: Pseudo<(outs vfID:$dst_id), (ins ANYID:$src)>; + def GET_SID: Pseudo<(outs SID:$dst_id), (ins ANYID:$src)>; + def GET_VID: Pseudo<(outs VID:$dst_id), (ins ANYID:$src)>; } def SPVTypeBin : SDTypeProfile<1, 2, []>; @@ -38,41 +35,10 @@ class BinOpTyped opCode, RegisterClass CID, SDNode node> : Op; -class TernOpTyped opCode, RegisterClass CCond, RegisterClass CID, SDNode node> - : Op; - -multiclass BinOpTypedGen opCode, SDNode node, bit genF = 0, bit genV = 0> { - if genF then - def S: BinOpTyped; - else - def S: BinOpTyped; +multiclass BinOpTypedGen opCode, SDNode node, bit genV = 0> { + def S: BinOpTyped; if genV then { - if genF then - def V: BinOpTyped; - else - def V: BinOpTyped; - } -} - -multiclass TernOpTypedGen opCode, SDNode node, bit genI = 1, bit genF = 0, bit genV = 0> { - if genF then { - def SFSCond: TernOpTyped; - def SFVCond: TernOpTyped; - } - if genI then { - def SISCond: TernOpTyped; - def SIVCond: TernOpTyped; - } - if genV then { - if genF then { - def VFSCond: TernOpTyped; - def VFVCond: TernOpTyped; - } - if genI then { - def VISCond: TernOpTyped; - def VIVCond: TernOpTyped; - } + def V: BinOpTyped; } } @@ -99,7 +65,7 @@ def OpSource: Op<3, (outs), (ins SourceLanguage:$lang, i32imm:$version, variable "OpSource $lang $version">; def OpSourceExtension: Op<4, (outs), (ins StringImm:$extension, variable_ops), "OpSourceExtension $extension">; -def OpName: Op<5, (outs), (ins ANY:$tar, StringImm:$name, variable_ops), "OpName $tar $name">; +def OpName: Op<5, (outs), (ins ANYID:$tar, StringImm:$name, variable_ops), "OpName $tar $name">; def OpMemberName: Op<6, (outs), (ins TYPE:$ty, i32imm:$mem, StringImm:$name, variable_ops), "OpMemberName $ty $mem $name">; def OpString: Op<7, (outs ID:$r), (ins StringImm:$s, variable_ops), "$r = OpString $s">; @@ -110,7 +76,7 @@ def OpModuleProcessed: Op<330, (outs), (ins StringImm:$process, variable_ops), // 3.42.3 Annotation Instructions -def OpDecorate: Op<71, (outs), (ins ANY:$target, Decoration:$dec, variable_ops), +def OpDecorate: Op<71, (outs), (ins ANYID:$target, Decoration:$dec, variable_ops), "OpDecorate $target $dec">; def OpMemberDecorate: Op<72, (outs), (ins TYPE:$t, i32imm:$m, Decoration:$d, variable_ops), "OpMemberDecorate $t $m $d">; @@ -118,9 +84,9 @@ def OpMemberDecorate: Op<72, (outs), (ins TYPE:$t, i32imm:$m, Decoration:$d, var // TODO Currently some deprecated opcodes are missing: OpDecorationGroup, // OpGroupDecorate and OpGroupMemberDecorate -def OpDecorateId: Op<332, (outs), (ins ANY:$target, Decoration:$dec, variable_ops), +def OpDecorateId: Op<332, (outs), (ins ANYID:$target, Decoration:$dec, variable_ops), "OpDecorateId $target $dec">; -def OpDecorateString: Op<5632, (outs), (ins ANY:$t, Decoration:$d, StringImm:$s, variable_ops), +def OpDecorateString: Op<5632, (outs), (ins ANYID:$t, Decoration:$d, StringImm:$s, variable_ops), "OpDecorateString $t $d $s">; def OpMemberDecorateString: Op<5633, (outs), (ins TYPE:$ty, i32imm:$mem, Decoration:$dec, StringImm:$str, variable_ops), @@ -131,7 +97,7 @@ def OpMemberDecorateString: Op<5633, (outs), def OpExtension: Op<10, (outs), (ins StringImm:$name, variable_ops), "OpExtension $name">; def OpExtInstImport: Op<11, (outs ID:$res), (ins StringImm:$extInstsName, variable_ops), "$res = OpExtInstImport $extInstsName">; -def OpExtInst: Op<12, (outs ID:$res), (ins TYPE:$ty, ID:$set, Extension:$inst, variable_ops), +def OpExtInst: Op<12, (outs ID:$res), (ins TYPE:$ty, i32imm:$set, Extension:$inst, variable_ops), "$res = OpExtInst $ty $set $inst">; // 3.42.5 Mode-Setting Instructions @@ -198,9 +164,9 @@ return CurDAG->getTargetConstant( N->getValueAP().bitcastToAPInt().getZExtValue(), SDLoc(N), MVT::i32); }]>; -def fimm_to_i32 : SDNodeXFormgetTargetConstant( - N->getValueAPF().bitcastToAPInt().getZExtValue(), SDLoc(N), MVT::i32); + N->getValueAPF().bitcastToAPInt().getZExtValue(), SDLoc(N), MVT::f32); }]>; def gi_bitcast_fimm_to_i32 : GICustomOperandRenderer<"renderFImm32">, @@ -210,7 +176,8 @@ def gi_bitcast_imm_to_i32 : GICustomOperandRenderer<"renderImm32">, GISDNodeXFormEquiv; def PseudoConstI: IntImmLeaf; -def PseudoConstF: FPImmLeaf; +def PseudoConstF: FPImmLeaf; + def ConstPseudoTrue: IntImmLeaf; def ConstPseudoFalse: IntImmLeaf; def ConstPseudoNull: IntImmLeaf; @@ -218,7 +185,7 @@ def ConstPseudoNull: IntImmLeaf; multiclass IntFPImm opCode, string name> { def I: Op; - def F: Op; } @@ -428,9 +395,8 @@ def OpBitcast : UnOp<"OpBitcast", 124>; // 3.42.12 Composite Instructions -def OpVectorExtractDynamic: Op<77, (outs ID:$res), (ins TYPE:$type, vID:$vec, ID:$idx), - "$res = OpVectorExtractDynamic $type $vec $idx", [(set ID:$res, (assigntype (extractelt vID:$vec, ID:$idx), TYPE:$type))]>; - +def OpVectorExtractDynamic: Op<77, (outs ID:$res), (ins TYPE:$ty, ID:$vec, ID:$idx), + "$res = OpVectorExtractDynamic $ty $vec $idx">; def OpVectorInsertDynamic: Op<78, (outs ID:$res), (ins TYPE:$ty, ID:$vec, ID:$comp, ID:$idx), "$res = OpVectorInsertDynamic $ty $vec $comp $idx">; def OpVectorShuffle: Op<79, (outs ID:$res), (ins TYPE:$ty, ID:$v1, ID:$v2, variable_ops), @@ -448,27 +414,27 @@ def OpCopyLogical: UnOp<"OpCopyLogical", 400>; // 3.42.13 Arithmetic Instructions def OpSNegate: UnOp<"OpSNegate", 126>; -def OpFNegate: UnOpTyped<"OpFNegate", 127, fID, fneg>; -def OpFNegateV: UnOpTyped<"OpFNegate", 127, vfID, fneg>; -defm OpIAdd: BinOpTypedGen<"OpIAdd", 128, add, 0, 1>; -defm OpFAdd: BinOpTypedGen<"OpFAdd", 129, fadd, 1, 1>; +def OpFNegate: UnOpTyped<"OpFNegate", 127, SID, fneg>; +def OpFNegateV: UnOpTyped<"OpFNegate", 127, VID, fneg>; +defm OpIAdd: BinOpTypedGen<"OpIAdd", 128, add, 1>; +defm OpFAdd: BinOpTypedGen<"OpFAdd", 129, fadd, 1>; -defm OpISub: BinOpTypedGen<"OpISub", 130, sub, 0, 1>; -defm OpFSub: BinOpTypedGen<"OpFSub", 131, fsub, 1, 1>; +defm OpISub: BinOpTypedGen<"OpISub", 130, sub, 1>; +defm OpFSub: BinOpTypedGen<"OpFSub", 131, fsub, 1>; -defm OpIMul: BinOpTypedGen<"OpIMul", 132, mul, 0, 1>; -defm OpFMul: BinOpTypedGen<"OpFMul", 133, fmul, 1, 1>; +defm OpIMul: BinOpTypedGen<"OpIMul", 132, mul, 1>; +defm OpFMul: BinOpTypedGen<"OpFMul", 133, fmul, 1>; -defm OpUDiv: BinOpTypedGen<"OpUDiv", 134, udiv, 0, 1>; -defm OpSDiv: BinOpTypedGen<"OpSDiv", 135, sdiv, 0, 1>; -defm OpFDiv: BinOpTypedGen<"OpFDiv", 136, fdiv, 1, 1>; +defm OpUDiv: BinOpTypedGen<"OpUDiv", 134, udiv, 1>; +defm OpSDiv: BinOpTypedGen<"OpSDiv", 135, sdiv, 1>; +defm OpFDiv: BinOpTypedGen<"OpFDiv", 136, fdiv, 1>; -defm OpUMod: BinOpTypedGen<"OpUMod", 137, urem, 0, 1>; -defm OpSRem: BinOpTypedGen<"OpSRem", 138, srem, 0, 1>; +defm OpUMod: BinOpTypedGen<"OpUMod", 137, urem, 1>; +defm OpSRem: BinOpTypedGen<"OpSRem", 138, srem, 1>; def OpSMod: BinOp<"OpSMod", 139>; -defm OpFRem: BinOpTypedGen<"OpFRem", 140, frem, 1, 1>; +defm OpFRem: BinOpTypedGen<"OpFRem", 140, frem, 1>; def OpFMod: BinOp<"OpFMod", 141>; def OpVectorTimesScalar: BinOp<"OpVectorTimesScalar", 142>; @@ -487,13 +453,13 @@ def OpSMulExtended: BinOp<"OpSMulExtended", 152>; // 3.42.14 Bit Instructions -defm OpShiftRightLogical: BinOpTypedGen<"OpShiftRightLogical", 194, srl, 0, 1>; -defm OpShiftRightArithmetic: BinOpTypedGen<"OpShiftRightArithmetic", 195, sra, 0, 1>; -defm OpShiftLeftLogical: BinOpTypedGen<"OpShiftLeftLogical", 196, shl, 0, 1>; +defm OpShiftRightLogical: BinOpTypedGen<"OpShiftRightLogical", 194, srl, 1>; +defm OpShiftRightArithmetic: BinOpTypedGen<"OpShiftRightArithmetic", 195, sra, 1>; +defm OpShiftLeftLogical: BinOpTypedGen<"OpShiftLeftLogical", 196, shl, 1>; -defm OpBitwiseOr: BinOpTypedGen<"OpBitwiseOr", 197, or, 0, 1>; -defm OpBitwiseXor: BinOpTypedGen<"OpBitwiseXor", 198, xor, 0, 1>; -defm OpBitwiseAnd: BinOpTypedGen<"OpBitwiseAnd", 199, and, 0, 1>; +defm OpBitwiseOr: BinOpTypedGen<"OpBitwiseOr", 197, or, 1>; +defm OpBitwiseXor: BinOpTypedGen<"OpBitwiseXor", 198, xor, 1>; +defm OpBitwiseAnd: BinOpTypedGen<"OpBitwiseAnd", 199, and, 1>; def OpNot: UnOp<"OpNot", 200>; def OpBitFieldInsert: Op<201, (outs ID:$res), @@ -531,7 +497,8 @@ def OpLogicalOr: BinOp<"OpLogicalOr", 166>; def OpLogicalAnd: BinOp<"OpLogicalAnd", 167>; def OpLogicalNot: UnOp<"OpLogicalNot", 168>; -defm OpSelect: TernOpTypedGen<"OpSelect", 169, select, 1, 1, 1>; +def OpSelect: Op<169, (outs ID:$dst), (ins TYPE:$src_ty, ID:$cond, ID:$src1, ID:$src2), + "$dst = OpSelect $src_ty $cond $src1 $src2">; def OpIEqual: BinOp<"OpIEqual", 170>; def OpINotEqual: BinOp<"OpINotEqual", 171>; diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 5507e9254fae..704ad4d3f8f3 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -127,8 +127,10 @@ class SPIRVInstructionSelector : public InstructionSelector { bool selectConst(Register ResVReg, const SPIRVType *ResType, const APInt &Imm, MachineInstr &I) const; - bool selectSelect(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, - bool IsSigned) const; + bool selectBoolSelect(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I, bool IsSigned) const; + bool selectSelect(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; bool selectIToF(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, bool IsSigned, unsigned Opcode) const; bool selectExt(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, @@ -178,6 +180,8 @@ class SPIRVInstructionSelector : public InstructionSelector { Register buildZerosVal(const SPIRVType *ResType, MachineInstr &I) const; Register buildOnesVal(bool AllOnes, const SPIRVType *ResType, MachineInstr &I) const; + bool buildSelect(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, + Register TrueValue, Register FalseValue) const; }; } // end anonymous namespace @@ -319,6 +323,9 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg, case TargetOpcode::G_PHI: return selectPhi(ResVReg, ResType, I); + case TargetOpcode::G_SELECT: + return selectSelect(ResVReg, ResType, I); + case TargetOpcode::G_FPTOSI: return selectUnOp(ResVReg, ResType, I, SPIRV::OpConvertFToS); case TargetOpcode::G_FPTOUI: @@ -1087,23 +1094,34 @@ Register SPIRVInstructionSelector::buildOnesVal(bool AllOnes, return GR.getOrCreateConstInt(One.getZExtValue(), I, ResType, TII); } -bool SPIRVInstructionSelector::selectSelect(Register ResVReg, - const SPIRVType *ResType, - MachineInstr &I, - bool IsSigned) const { +bool SPIRVInstructionSelector::selectBoolSelect(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I, + bool IsSigned) const { // To extend a bool, we need to use OpSelect between constants. Register ZeroReg = buildZerosVal(ResType, I); Register OneReg = buildOnesVal(IsSigned, ResType, I); - bool IsScalarBool = - GR.isScalarOfType(I.getOperand(1).getReg(), SPIRV::OpTypeBool); - unsigned Opcode = - IsScalarBool ? SPIRV::OpSelectSISCond : SPIRV::OpSelectSIVCond; - return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(Opcode)) + return buildSelect(ResVReg, ResType, I, OneReg, ZeroReg); +} + +bool SPIRVInstructionSelector::selectSelect(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + return buildSelect(ResVReg, ResType, I, I.getOperand(2).getReg(), + I.getOperand(3).getReg()); +} + +bool SPIRVInstructionSelector::buildSelect(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I, Register TrueValue, + Register FalseValue) const { + // To extend a bool, we need to use OpSelect between constants. + return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpSelect)) .addDef(ResVReg) .addUse(GR.getSPIRVTypeID(ResType)) .addUse(I.getOperand(1).getReg()) - .addUse(OneReg) - .addUse(ZeroReg) + .addUse(TrueValue) + .addUse(FalseValue) .constrainAllUses(TII, TRI, RBI); } @@ -1122,7 +1140,7 @@ bool SPIRVInstructionSelector::selectIToF(Register ResVReg, TmpType = GR.getOrCreateSPIRVVectorType(TmpType, NumElts, I, TII); } SrcReg = MRI->createVirtualRegister(&SPIRV::IDRegClass); - selectSelect(SrcReg, TmpType, I, false); + selectBoolSelect(SrcReg, TmpType, I, false); } return selectUnOpWithSrc(ResVReg, ResType, I, SrcReg, Opcode); } @@ -1131,7 +1149,7 @@ bool SPIRVInstructionSelector::selectExt(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, bool IsSigned) const { if (GR.isScalarOrVectorOfType(I.getOperand(1).getReg(), SPIRV::OpTypeBool)) - return selectSelect(ResVReg, ResType, I, IsSigned); + return selectBoolSelect(ResVReg, ResType, I, IsSigned); unsigned Opcode = IsSigned ? SPIRV::OpSConvert : SPIRV::OpUConvert; return selectUnOp(ResVReg, ResType, I, Opcode); } diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp index b0028f8c80a4..dd1b8a0feb86 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp @@ -46,7 +46,6 @@ static const std::set TypeFoldingSupportingOpcs = { TargetOpcode::G_SHL, TargetOpcode::G_ASHR, TargetOpcode::G_LSHR, - TargetOpcode::G_SELECT, TargetOpcode::G_EXTRACT_VECTOR_ELT, }; @@ -199,6 +198,9 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { all(typeInSet(0, allBoolScalarsAndVectors), typeInSet(1, allFloatScalarsAndVectors))); + getActionDefinitionsBuilder(G_SELECT).legalIf(all( + typeInSet(0, allScalarsAndVectors), typeInSet(1, allScalarsAndVectors))); + getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND, G_ATOMICRMW_MAX, G_ATOMICRMW_MIN, G_ATOMICRMW_SUB, G_ATOMICRMW_XOR, diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp index 1e664ca6cfcd..8394fc853ce6 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp @@ -179,8 +179,6 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR, } if (SpirvTy) GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF()); - if (!MRI.getRegClassOrNull(Reg)) - MRI.setRegClass(Reg, &SPIRV::IDRegClass); } } return SpirvTy; @@ -201,8 +199,6 @@ Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy, (Def->getNextNode() ? Def->getNextNode()->getIterator() : Def->getParent()->end())); Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg)); - if (auto *RC = MRI.getRegClassOrNull(Reg)) - MRI.setRegClass(NewReg, RC); SpirvTy = SpirvTy ? SpirvTy : GR->getOrCreateSPIRVType(Ty, MIB); GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF()); // This is to make it convenient for Legalizer to get the SPIRVType @@ -211,13 +207,15 @@ Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy, // Copy MIFlags from Def to ASSIGN_TYPE instruction. It's required to keep // the flags after instruction selection. const uint16_t Flags = Def->getFlags(); - MIB.buildInstr(SPIRV::ASSIGN_TYPE) - .addDef(Reg) - .addUse(NewReg) - .addUse(GR->getSPIRVTypeID(SpirvTy)) - .setMIFlags(Flags); + auto B = MIB.buildInstr(SPIRV::ASSIGN_TYPE) + .addDef(Reg) + .addUse(NewReg) + .addUse(GR->getSPIRVTypeID(SpirvTy)) + .setMIFlags(Flags); Def->getOperand(0).setReg(NewReg); - MRI.setRegClass(Reg, &SPIRV::ANYIDRegClass); + const auto &ST = GR->CurMF->getSubtarget(); + constrainSelectedInstRegOperands(*B.getInstr(), *ST.getInstrInfo(), + *ST.getRegisterInfo(), *ST.getRegBankInfo()); return NewReg; } } // namespace llvm @@ -305,25 +303,12 @@ createNewIdReg(Register ValReg, unsigned Opcode, MachineRegisterInfo &MRI, LLT NewT = LLT::scalar(32); SPIRVType *SpvType = GR.getSPIRVTypeForVReg(ValReg); assert(SpvType && "VReg is expected to have SPIRV type"); - bool IsFloat = SpvType->getOpcode() == SPIRV::OpTypeFloat; - bool IsVectorFloat = - SpvType->getOpcode() == SPIRV::OpTypeVector && - GR.getSPIRVTypeForVReg(SpvType->getOperand(1).getReg())->getOpcode() == - SPIRV::OpTypeFloat; - IsFloat |= IsVectorFloat; - auto GetIdOp = IsFloat ? SPIRV::GET_fID : SPIRV::GET_ID; - auto DstClass = IsFloat ? &SPIRV::fIDRegClass : &SPIRV::IDRegClass; - if (MRI.getType(ValReg).isPointer()) { - NewT = LLT::pointer(0, 32); - GetIdOp = SPIRV::GET_pID; - DstClass = &SPIRV::pIDRegClass; - } else if (MRI.getType(ValReg).isVector()) { + auto GetIdOp = SPIRV::GET_SID; + if (MRI.getType(ValReg).isVector()) { NewT = LLT::fixed_vector(2, NewT); - GetIdOp = IsFloat ? SPIRV::GET_vfID : SPIRV::GET_vID; - DstClass = IsFloat ? &SPIRV::vfIDRegClass : &SPIRV::vIDRegClass; + GetIdOp = SPIRV::GET_VID; } Register IdReg = MRI.createGenericVirtualRegister(NewT); - MRI.setRegClass(IdReg, DstClass); return {IdReg, GetIdOp}; } @@ -336,15 +321,20 @@ static void processInstr(MachineInstr &MI, MachineIRBuilder &MIB, auto NewReg = createNewIdReg(MI.getOperand(0).getReg(), Opc, MRI, *GR).first; AssignTypeInst.getOperand(1).setReg(NewReg); MI.getOperand(0).setReg(NewReg); - MIB.setInsertPt(*MI.getParent(), - (MI.getNextNode() ? MI.getNextNode()->getIterator() - : MI.getParent()->end())); + MIB.setInsertPt(*MI.getParent(), MI.getIterator()); for (auto &Op : MI.operands()) { if (!Op.isReg() || Op.isDef()) continue; auto IdOpInfo = createNewIdReg(Op.getReg(), Opc, MRI, *GR); - MIB.buildInstr(IdOpInfo.second).addDef(IdOpInfo.first).addUse(Op.getReg()); + auto B = MIB.buildInstr(IdOpInfo.second) + .addDef(IdOpInfo.first) + .addUse(Op.getReg()); Op.setReg(IdOpInfo.first); + + const auto &ST = GR->CurMF->getSubtarget(); + constrainSelectedInstRegOperands(*B.getInstr(), *ST.getInstrInfo(), + *ST.getRegisterInfo(), + *ST.getRegBankInfo()); } } @@ -373,8 +363,6 @@ static void processInstrsWithTypeFolding(MachineFunction &MF, if (!isTypeFoldingSupported(Opcode)) continue; Register DstReg = MI.getOperand(0).getReg(); - if (MRI.getType(DstReg).isVector()) - MRI.setRegClass(DstReg, &SPIRV::IDRegClass); // Don't need to reset type of register holding constant and used in // G_ADDRSPACE_CAST, since it braaks legalizer. if (Opcode == TargetOpcode::G_CONSTANT && MRI.hasOneUse(DstReg)) { diff --git a/llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.cpp index 9bf9d7fe5b39..7352ac56847e 100644 --- a/llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.cpp @@ -27,21 +27,5 @@ using namespace llvm; const RegisterBank & SPIRVRegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC, LLT Ty) const { - switch (RC.getID()) { - case SPIRV::TYPERegClassID: - return SPIRV::TYPERegBank; - case SPIRV::pIDRegClassID: - case SPIRV::IDRegClassID: - return SPIRV::IDRegBank; - case SPIRV::fIDRegClassID: - return SPIRV::fIDRegBank; - case SPIRV::vIDRegClassID: - return SPIRV::vIDRegBank; - case SPIRV::vfIDRegClassID: - return SPIRV::vfIDRegBank; - case SPIRV::ANYIDRegClassID: - case SPIRV::ANYRegClassID: - return SPIRV::IDRegBank; - } - llvm_unreachable("Unknown register class"); + return SPIRV::IDRegBank; } diff --git a/llvm/lib/Target/SPIRV/SPIRVRegisterBanks.td b/llvm/lib/Target/SPIRV/SPIRVRegisterBanks.td index 90c7f3a6e672..e489af857c65 100644 --- a/llvm/lib/Target/SPIRV/SPIRVRegisterBanks.td +++ b/llvm/lib/Target/SPIRV/SPIRVRegisterBanks.td @@ -6,10 +6,6 @@ // //===----------------------------------------------------------------------===// -// Although RegisterBankSelection is disabled we need to distinct the banks -// as InstructionSelector RegClass checking code relies on them -def IDRegBank : RegisterBank<"IDBank", [ID]>; -def fIDRegBank : RegisterBank<"fIDBank", [fID]>; -def vIDRegBank : RegisterBank<"vIDBank", [vID]>; -def vfIDRegBank : RegisterBank<"vfIDBank", [vfID]>; -def TYPERegBank : RegisterBank<"TYPEBank", [TYPE]>; +// Although RegBankSelect is disabled, we need to have +// at least one regbank to support instruction selector +def IDRegBank : RegisterBank<"IDBank", [ID]>; \ No newline at end of file diff --git a/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td b/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td index d0b64b6895d0..152ee291d7e7 100644 --- a/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td +++ b/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td @@ -12,7 +12,6 @@ let Namespace = "SPIRV" in { def p0 : PtrValueType ; - // All registers are for 32-bit identifiers, so have a single dummy register // Class for registers that are the result of OpTypeXXX instructions def TYPE0 : Register<"TYPE0">; @@ -21,19 +20,16 @@ let Namespace = "SPIRV" in { // Class for every other non-type ID def ID0 : Register<"ID0">; def ID : RegisterClass<"SPIRV", [i32], 32, (add ID0)>; - def fID0 : Register<"FID0">; - def fID : RegisterClass<"SPIRV", [f32], 32, (add fID0)>; - def pID0 : Register<"pID0">; - def pID : RegisterClass<"SPIRV", [p0], 32, (add pID0)>; - def vID0 : Register<"pID0">; - def vID : RegisterClass<"SPIRV", [v2i32], 32, (add vID0)>; - def vfID0 : Register<"pID0">; - def vfID : RegisterClass<"SPIRV", [v2f32], 32, (add vfID0)>; + def VID0 : Register<"VID0">; - def ANYID : RegisterClass<"SPIRV", [i32, f32, p0, v2i32, v2f32], 32, (add ID, fID, pID, vID, vfID)>; + // TODO: FID register class is only needed for constants selection, + // consider redesigning the pattern to get rid of this + def FID0 : Register<"FID0">; + def FID : RegisterClass<"SPIRV", [f32], 32, (add FID0)>; - // A few instructions like OpName can take ids from both type and non-type - // instructions, so we need a super-class to allow for both to count as valid - // arguments for these instructions. - def ANY : RegisterClass<"SPIRV", [i32], 32, (add TYPE, ID)>; + // Scalar ID + def SID : RegisterClass<"SPIRV", [i32, f32, p0], 32, (add ID0)>; + // Vector ID + def VID : RegisterClass<"SPIRV", [v2i32, v2f32], 32, (add VID0)>; + def ANYID : RegisterClass<"SPIRV", [i32, f32, p0, v2i32, v2f32], 32, (add SID, VID)>; }