From cd671d39b073a50bf5ea5f2aecacd24e908e1008 Mon Sep 17 00:00:00 2001 From: Aleksandr Bezzubikov Date: Wed, 12 Apr 2023 19:08:37 -0700 Subject: [PATCH] [SPIR-V] Start cleaning up register classes. In order to fix EXPENSIVE_CHECKS failures the hierarchy of register classes for SPIRV target needs to be adjusted. This patch removes redundant FP regclasses since they effectively mimic type inference for existing generic FP instructions. Also this change begins replacing explicit setRegClass calls with simply constraining instruction operands. --- llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp | 28 +++-- llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 8 ++ llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp | 5 +- llvm/lib/Target/SPIRV/SPIRVInstrInfo.td | 113 +++++++----------- .../Target/SPIRV/SPIRVInstructionSelector.cpp | 48 +++++--- llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp | 4 +- llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp | 52 ++++---- .../Target/SPIRV/SPIRVRegisterBankInfo.cpp | 18 +-- llvm/lib/Target/SPIRV/SPIRVRegisterBanks.td | 10 +- llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td | 24 ++-- 10 files changed, 137 insertions(+), 173 deletions(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp index 8b618686ee7d..08a46dee4ff9 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp @@ -303,7 +303,6 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, // Generate a SPIR-V type for the function. auto MRI = MIRBuilder.getMRI(); Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32)); - MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass); if (F.isDeclaration()) GR->add(&F, &MIRBuilder.getMF(), FuncVReg); SPIRVType *RetTy = GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder); @@ -313,22 +312,28 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, // Build the OpTypeFunction declaring it. uint32_t FuncControl = getFunctionControl(F); - MIRBuilder.buildInstr(SPIRV::OpFunction) - .addDef(FuncVReg) - .addUse(GR->getSPIRVTypeID(RetTy)) - .addImm(FuncControl) - .addUse(GR->getSPIRVTypeID(FuncTy)); + auto B = MIRBuilder.buildInstr(SPIRV::OpFunction) + .addDef(FuncVReg) + .addUse(GR->getSPIRVTypeID(RetTy)) + .addImm(FuncControl) + .addUse(GR->getSPIRVTypeID(FuncTy)); + const auto &ST = GR->CurMF->getSubtarget(); + constrainSelectedInstRegOperands(*B.getInstr(), *ST.getInstrInfo(), + *ST.getRegisterInfo(), *ST.getRegBankInfo()); // Add OpFunctionParameters. int i = 0; for (const auto &Arg : F.args()) { assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs"); - MRI->setRegClass(VRegs[i][0], &SPIRV::IDRegClass); - MIRBuilder.buildInstr(SPIRV::OpFunctionParameter) - .addDef(VRegs[i][0]) - .addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i])); + auto B = MIRBuilder.buildInstr(SPIRV::OpFunctionParameter) + .addDef(VRegs[i][0]) + .addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i])); if (F.isDeclaration()) GR->add(&Arg, &MIRBuilder.getMF(), VRegs[i][0]); + const auto &ST = GR->CurMF->getSubtarget(); + constrainSelectedInstRegOperands(*B.getInstr(), *ST.getInstrInfo(), + *ST.getRegisterInfo(), + *ST.getRegBankInfo()); i++; } // Name the function. @@ -421,7 +426,8 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, // Make sure there's a valid return reg, even for functions returning void. if (!ResVReg.isValid()) - ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass); + ResVReg = + MIRBuilder.getMRI()->createGenericVirtualRegister(LLT::scalar(32)); SPIRVType *RetType = GR->assignTypeToVReg(FTy->getReturnType(), ResVReg, MIRBuilder); diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp index 062188abbf5e..fb0dbb88c850 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -727,6 +727,10 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType( return nullptr; TypesInProcessing.insert(Ty); SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR); + const auto &ST = CurMF->getSubtarget(); + constrainSelectedInstRegOperands(*const_cast(SpirvType), + *ST.getInstrInfo(), *ST.getRegisterInfo(), + *ST.getRegBankInfo()); TypesInProcessing.erase(Ty); VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType; SPIRVToLLVMType[SpirvType] = Ty; @@ -972,6 +976,10 @@ SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy, VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType; SPIRVToLLVMType[SpirvType] = LLVMTy; DT.add(LLVMTy, CurMF, getSPIRVTypeID(SpirvType)); + const auto &ST = CurMF->getSubtarget(); + constrainSelectedInstRegOperands(*const_cast(SpirvType), + *ST.getInstrInfo(), *ST.getRegisterInfo(), + *ST.getRegBankInfo()); return SpirvType; } diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp index 42317453a237..bd512a6690fc 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp @@ -245,9 +245,8 @@ void SPIRVInstrInfo::copyPhysReg(MachineBasicBlock &MBB, } bool SPIRVInstrInfo::expandPostRAPseudo(MachineInstr &MI) const { - if (MI.getOpcode() == SPIRV::GET_ID || MI.getOpcode() == SPIRV::GET_fID || - MI.getOpcode() == SPIRV::GET_pID || MI.getOpcode() == SPIRV::GET_vfID || - MI.getOpcode() == SPIRV::GET_vID) { + if (MI.getOpcode() == SPIRV::GET_ID || MI.getOpcode() == SPIRV::GET_SID || + MI.getOpcode() == SPIRV::GET_VID) { auto &MRI = MI.getMF()->getRegInfo(); MRI.replaceRegWith(MI.getOperand(0).getReg(), MI.getOperand(1).getReg()); MI.eraseFromParent(); diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td index 44b5536becf7..14e4b6bcd374 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td +++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td @@ -15,13 +15,10 @@ include "SPIRVSymbolicOperands.td" // Codegen only metadata instructions let isCodeGenOnly=1 in { - def ASSIGN_TYPE: Pseudo<(outs ANYID:$dst_id), (ins ANYID:$src_id, TYPE:$src_ty)>; - def DECL_TYPE: Pseudo<(outs ANYID:$dst_id), (ins ANYID:$src_id, TYPE:$src_ty)>; + def ASSIGN_TYPE: Pseudo<(outs ANYID:$dst_id), (ins unknown:$src_id, TYPE:$src_ty)>; def GET_ID: Pseudo<(outs ID:$dst_id), (ins ANYID:$src)>; - def GET_fID: Pseudo<(outs fID:$dst_id), (ins ANYID:$src)>; - def GET_pID: Pseudo<(outs pID:$dst_id), (ins ANYID:$src)>; - def GET_vID: Pseudo<(outs vID:$dst_id), (ins ANYID:$src)>; - def GET_vfID: Pseudo<(outs vfID:$dst_id), (ins ANYID:$src)>; + def GET_SID: Pseudo<(outs SID:$dst_id), (ins ANYID:$src)>; + def GET_VID: Pseudo<(outs VID:$dst_id), (ins ANYID:$src)>; } def SPVTypeBin : SDTypeProfile<1, 2, []>; @@ -38,41 +35,10 @@ class BinOpTyped opCode, RegisterClass CID, SDNode node> : Op; -class TernOpTyped opCode, RegisterClass CCond, RegisterClass CID, SDNode node> - : Op; - -multiclass BinOpTypedGen opCode, SDNode node, bit genF = 0, bit genV = 0> { - if genF then - def S: BinOpTyped; - else - def S: BinOpTyped; +multiclass BinOpTypedGen opCode, SDNode node, bit genV = 0> { + def S: BinOpTyped; if genV then { - if genF then - def V: BinOpTyped; - else - def V: BinOpTyped; - } -} - -multiclass TernOpTypedGen opCode, SDNode node, bit genI = 1, bit genF = 0, bit genV = 0> { - if genF then { - def SFSCond: TernOpTyped; - def SFVCond: TernOpTyped; - } - if genI then { - def SISCond: TernOpTyped; - def SIVCond: TernOpTyped; - } - if genV then { - if genF then { - def VFSCond: TernOpTyped; - def VFVCond: TernOpTyped; - } - if genI then { - def VISCond: TernOpTyped; - def VIVCond: TernOpTyped; - } + def V: BinOpTyped; } } @@ -99,7 +65,7 @@ def OpSource: Op<3, (outs), (ins SourceLanguage:$lang, i32imm:$version, variable "OpSource $lang $version">; def OpSourceExtension: Op<4, (outs), (ins StringImm:$extension, variable_ops), "OpSourceExtension $extension">; -def OpName: Op<5, (outs), (ins ANY:$tar, StringImm:$name, variable_ops), "OpName $tar $name">; +def OpName: Op<5, (outs), (ins ANYID:$tar, StringImm:$name, variable_ops), "OpName $tar $name">; def OpMemberName: Op<6, (outs), (ins TYPE:$ty, i32imm:$mem, StringImm:$name, variable_ops), "OpMemberName $ty $mem $name">; def OpString: Op<7, (outs ID:$r), (ins StringImm:$s, variable_ops), "$r = OpString $s">; @@ -110,7 +76,7 @@ def OpModuleProcessed: Op<330, (outs), (ins StringImm:$process, variable_ops), // 3.42.3 Annotation Instructions -def OpDecorate: Op<71, (outs), (ins ANY:$target, Decoration:$dec, variable_ops), +def OpDecorate: Op<71, (outs), (ins ANYID:$target, Decoration:$dec, variable_ops), "OpDecorate $target $dec">; def OpMemberDecorate: Op<72, (outs), (ins TYPE:$t, i32imm:$m, Decoration:$d, variable_ops), "OpMemberDecorate $t $m $d">; @@ -118,9 +84,9 @@ def OpMemberDecorate: Op<72, (outs), (ins TYPE:$t, i32imm:$m, Decoration:$d, var // TODO Currently some deprecated opcodes are missing: OpDecorationGroup, // OpGroupDecorate and OpGroupMemberDecorate -def OpDecorateId: Op<332, (outs), (ins ANY:$target, Decoration:$dec, variable_ops), +def OpDecorateId: Op<332, (outs), (ins ANYID:$target, Decoration:$dec, variable_ops), "OpDecorateId $target $dec">; -def OpDecorateString: Op<5632, (outs), (ins ANY:$t, Decoration:$d, StringImm:$s, variable_ops), +def OpDecorateString: Op<5632, (outs), (ins ANYID:$t, Decoration:$d, StringImm:$s, variable_ops), "OpDecorateString $t $d $s">; def OpMemberDecorateString: Op<5633, (outs), (ins TYPE:$ty, i32imm:$mem, Decoration:$dec, StringImm:$str, variable_ops), @@ -131,7 +97,7 @@ def OpMemberDecorateString: Op<5633, (outs), def OpExtension: Op<10, (outs), (ins StringImm:$name, variable_ops), "OpExtension $name">; def OpExtInstImport: Op<11, (outs ID:$res), (ins StringImm:$extInstsName, variable_ops), "$res = OpExtInstImport $extInstsName">; -def OpExtInst: Op<12, (outs ID:$res), (ins TYPE:$ty, ID:$set, Extension:$inst, variable_ops), +def OpExtInst: Op<12, (outs ID:$res), (ins TYPE:$ty, i32imm:$set, Extension:$inst, variable_ops), "$res = OpExtInst $ty $set $inst">; // 3.42.5 Mode-Setting Instructions @@ -198,9 +164,9 @@ return CurDAG->getTargetConstant( N->getValueAP().bitcastToAPInt().getZExtValue(), SDLoc(N), MVT::i32); }]>; -def fimm_to_i32 : SDNodeXFormgetTargetConstant( - N->getValueAPF().bitcastToAPInt().getZExtValue(), SDLoc(N), MVT::i32); + N->getValueAPF().bitcastToAPInt().getZExtValue(), SDLoc(N), MVT::f32); }]>; def gi_bitcast_fimm_to_i32 : GICustomOperandRenderer<"renderFImm32">, @@ -210,7 +176,8 @@ def gi_bitcast_imm_to_i32 : GICustomOperandRenderer<"renderImm32">, GISDNodeXFormEquiv; def PseudoConstI: IntImmLeaf; -def PseudoConstF: FPImmLeaf; +def PseudoConstF: FPImmLeaf; + def ConstPseudoTrue: IntImmLeaf; def ConstPseudoFalse: IntImmLeaf; def ConstPseudoNull: IntImmLeaf; @@ -218,7 +185,7 @@ def ConstPseudoNull: IntImmLeaf; multiclass IntFPImm opCode, string name> { def I: Op; - def F: Op; } @@ -428,9 +395,8 @@ def OpBitcast : UnOp<"OpBitcast", 124>; // 3.42.12 Composite Instructions -def OpVectorExtractDynamic: Op<77, (outs ID:$res), (ins TYPE:$type, vID:$vec, ID:$idx), - "$res = OpVectorExtractDynamic $type $vec $idx", [(set ID:$res, (assigntype (extractelt vID:$vec, ID:$idx), TYPE:$type))]>; - +def OpVectorExtractDynamic: Op<77, (outs ID:$res), (ins TYPE:$ty, ID:$vec, ID:$idx), + "$res = OpVectorExtractDynamic $ty $vec $idx">; def OpVectorInsertDynamic: Op<78, (outs ID:$res), (ins TYPE:$ty, ID:$vec, ID:$comp, ID:$idx), "$res = OpVectorInsertDynamic $ty $vec $comp $idx">; def OpVectorShuffle: Op<79, (outs ID:$res), (ins TYPE:$ty, ID:$v1, ID:$v2, variable_ops), @@ -448,27 +414,27 @@ def OpCopyLogical: UnOp<"OpCopyLogical", 400>; // 3.42.13 Arithmetic Instructions def OpSNegate: UnOp<"OpSNegate", 126>; -def OpFNegate: UnOpTyped<"OpFNegate", 127, fID, fneg>; -def OpFNegateV: UnOpTyped<"OpFNegate", 127, vfID, fneg>; -defm OpIAdd: BinOpTypedGen<"OpIAdd", 128, add, 0, 1>; -defm OpFAdd: BinOpTypedGen<"OpFAdd", 129, fadd, 1, 1>; +def OpFNegate: UnOpTyped<"OpFNegate", 127, SID, fneg>; +def OpFNegateV: UnOpTyped<"OpFNegate", 127, VID, fneg>; +defm OpIAdd: BinOpTypedGen<"OpIAdd", 128, add, 1>; +defm OpFAdd: BinOpTypedGen<"OpFAdd", 129, fadd, 1>; -defm OpISub: BinOpTypedGen<"OpISub", 130, sub, 0, 1>; -defm OpFSub: BinOpTypedGen<"OpFSub", 131, fsub, 1, 1>; +defm OpISub: BinOpTypedGen<"OpISub", 130, sub, 1>; +defm OpFSub: BinOpTypedGen<"OpFSub", 131, fsub, 1>; -defm OpIMul: BinOpTypedGen<"OpIMul", 132, mul, 0, 1>; -defm OpFMul: BinOpTypedGen<"OpFMul", 133, fmul, 1, 1>; +defm OpIMul: BinOpTypedGen<"OpIMul", 132, mul, 1>; +defm OpFMul: BinOpTypedGen<"OpFMul", 133, fmul, 1>; -defm OpUDiv: BinOpTypedGen<"OpUDiv", 134, udiv, 0, 1>; -defm OpSDiv: BinOpTypedGen<"OpSDiv", 135, sdiv, 0, 1>; -defm OpFDiv: BinOpTypedGen<"OpFDiv", 136, fdiv, 1, 1>; +defm OpUDiv: BinOpTypedGen<"OpUDiv", 134, udiv, 1>; +defm OpSDiv: BinOpTypedGen<"OpSDiv", 135, sdiv, 1>; +defm OpFDiv: BinOpTypedGen<"OpFDiv", 136, fdiv, 1>; -defm OpUMod: BinOpTypedGen<"OpUMod", 137, urem, 0, 1>; -defm OpSRem: BinOpTypedGen<"OpSRem", 138, srem, 0, 1>; +defm OpUMod: BinOpTypedGen<"OpUMod", 137, urem, 1>; +defm OpSRem: BinOpTypedGen<"OpSRem", 138, srem, 1>; def OpSMod: BinOp<"OpSMod", 139>; -defm OpFRem: BinOpTypedGen<"OpFRem", 140, frem, 1, 1>; +defm OpFRem: BinOpTypedGen<"OpFRem", 140, frem, 1>; def OpFMod: BinOp<"OpFMod", 141>; def OpVectorTimesScalar: BinOp<"OpVectorTimesScalar", 142>; @@ -487,13 +453,13 @@ def OpSMulExtended: BinOp<"OpSMulExtended", 152>; // 3.42.14 Bit Instructions -defm OpShiftRightLogical: BinOpTypedGen<"OpShiftRightLogical", 194, srl, 0, 1>; -defm OpShiftRightArithmetic: BinOpTypedGen<"OpShiftRightArithmetic", 195, sra, 0, 1>; -defm OpShiftLeftLogical: BinOpTypedGen<"OpShiftLeftLogical", 196, shl, 0, 1>; +defm OpShiftRightLogical: BinOpTypedGen<"OpShiftRightLogical", 194, srl, 1>; +defm OpShiftRightArithmetic: BinOpTypedGen<"OpShiftRightArithmetic", 195, sra, 1>; +defm OpShiftLeftLogical: BinOpTypedGen<"OpShiftLeftLogical", 196, shl, 1>; -defm OpBitwiseOr: BinOpTypedGen<"OpBitwiseOr", 197, or, 0, 1>; -defm OpBitwiseXor: BinOpTypedGen<"OpBitwiseXor", 198, xor, 0, 1>; -defm OpBitwiseAnd: BinOpTypedGen<"OpBitwiseAnd", 199, and, 0, 1>; +defm OpBitwiseOr: BinOpTypedGen<"OpBitwiseOr", 197, or, 1>; +defm OpBitwiseXor: BinOpTypedGen<"OpBitwiseXor", 198, xor, 1>; +defm OpBitwiseAnd: BinOpTypedGen<"OpBitwiseAnd", 199, and, 1>; def OpNot: UnOp<"OpNot", 200>; def OpBitFieldInsert: Op<201, (outs ID:$res), @@ -531,7 +497,8 @@ def OpLogicalOr: BinOp<"OpLogicalOr", 166>; def OpLogicalAnd: BinOp<"OpLogicalAnd", 167>; def OpLogicalNot: UnOp<"OpLogicalNot", 168>; -defm OpSelect: TernOpTypedGen<"OpSelect", 169, select, 1, 1, 1>; +def OpSelect: Op<169, (outs ID:$dst), (ins TYPE:$src_ty, ID:$cond, ID:$src1, ID:$src2), + "$dst = OpSelect $src_ty $cond $src1 $src2">; def OpIEqual: BinOp<"OpIEqual", 170>; def OpINotEqual: BinOp<"OpINotEqual", 171>; diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 5507e9254fae..704ad4d3f8f3 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -127,8 +127,10 @@ class SPIRVInstructionSelector : public InstructionSelector { bool selectConst(Register ResVReg, const SPIRVType *ResType, const APInt &Imm, MachineInstr &I) const; - bool selectSelect(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, - bool IsSigned) const; + bool selectBoolSelect(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I, bool IsSigned) const; + bool selectSelect(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; bool selectIToF(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, bool IsSigned, unsigned Opcode) const; bool selectExt(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, @@ -178,6 +180,8 @@ class SPIRVInstructionSelector : public InstructionSelector { Register buildZerosVal(const SPIRVType *ResType, MachineInstr &I) const; Register buildOnesVal(bool AllOnes, const SPIRVType *ResType, MachineInstr &I) const; + bool buildSelect(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, + Register TrueValue, Register FalseValue) const; }; } // end anonymous namespace @@ -319,6 +323,9 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg, case TargetOpcode::G_PHI: return selectPhi(ResVReg, ResType, I); + case TargetOpcode::G_SELECT: + return selectSelect(ResVReg, ResType, I); + case TargetOpcode::G_FPTOSI: return selectUnOp(ResVReg, ResType, I, SPIRV::OpConvertFToS); case TargetOpcode::G_FPTOUI: @@ -1087,23 +1094,34 @@ Register SPIRVInstructionSelector::buildOnesVal(bool AllOnes, return GR.getOrCreateConstInt(One.getZExtValue(), I, ResType, TII); } -bool SPIRVInstructionSelector::selectSelect(Register ResVReg, - const SPIRVType *ResType, - MachineInstr &I, - bool IsSigned) const { +bool SPIRVInstructionSelector::selectBoolSelect(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I, + bool IsSigned) const { // To extend a bool, we need to use OpSelect between constants. Register ZeroReg = buildZerosVal(ResType, I); Register OneReg = buildOnesVal(IsSigned, ResType, I); - bool IsScalarBool = - GR.isScalarOfType(I.getOperand(1).getReg(), SPIRV::OpTypeBool); - unsigned Opcode = - IsScalarBool ? SPIRV::OpSelectSISCond : SPIRV::OpSelectSIVCond; - return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(Opcode)) + return buildSelect(ResVReg, ResType, I, OneReg, ZeroReg); +} + +bool SPIRVInstructionSelector::selectSelect(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + return buildSelect(ResVReg, ResType, I, I.getOperand(2).getReg(), + I.getOperand(3).getReg()); +} + +bool SPIRVInstructionSelector::buildSelect(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I, Register TrueValue, + Register FalseValue) const { + // To extend a bool, we need to use OpSelect between constants. + return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpSelect)) .addDef(ResVReg) .addUse(GR.getSPIRVTypeID(ResType)) .addUse(I.getOperand(1).getReg()) - .addUse(OneReg) - .addUse(ZeroReg) + .addUse(TrueValue) + .addUse(FalseValue) .constrainAllUses(TII, TRI, RBI); } @@ -1122,7 +1140,7 @@ bool SPIRVInstructionSelector::selectIToF(Register ResVReg, TmpType = GR.getOrCreateSPIRVVectorType(TmpType, NumElts, I, TII); } SrcReg = MRI->createVirtualRegister(&SPIRV::IDRegClass); - selectSelect(SrcReg, TmpType, I, false); + selectBoolSelect(SrcReg, TmpType, I, false); } return selectUnOpWithSrc(ResVReg, ResType, I, SrcReg, Opcode); } @@ -1131,7 +1149,7 @@ bool SPIRVInstructionSelector::selectExt(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, bool IsSigned) const { if (GR.isScalarOrVectorOfType(I.getOperand(1).getReg(), SPIRV::OpTypeBool)) - return selectSelect(ResVReg, ResType, I, IsSigned); + return selectBoolSelect(ResVReg, ResType, I, IsSigned); unsigned Opcode = IsSigned ? SPIRV::OpSConvert : SPIRV::OpUConvert; return selectUnOp(ResVReg, ResType, I, Opcode); } diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp index b0028f8c80a4..dd1b8a0feb86 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp @@ -46,7 +46,6 @@ static const std::set TypeFoldingSupportingOpcs = { TargetOpcode::G_SHL, TargetOpcode::G_ASHR, TargetOpcode::G_LSHR, - TargetOpcode::G_SELECT, TargetOpcode::G_EXTRACT_VECTOR_ELT, }; @@ -199,6 +198,9 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { all(typeInSet(0, allBoolScalarsAndVectors), typeInSet(1, allFloatScalarsAndVectors))); + getActionDefinitionsBuilder(G_SELECT).legalIf(all( + typeInSet(0, allScalarsAndVectors), typeInSet(1, allScalarsAndVectors))); + getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND, G_ATOMICRMW_MAX, G_ATOMICRMW_MIN, G_ATOMICRMW_SUB, G_ATOMICRMW_XOR, diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp index 1e664ca6cfcd..8394fc853ce6 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp @@ -179,8 +179,6 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR, } if (SpirvTy) GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF()); - if (!MRI.getRegClassOrNull(Reg)) - MRI.setRegClass(Reg, &SPIRV::IDRegClass); } } return SpirvTy; @@ -201,8 +199,6 @@ Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy, (Def->getNextNode() ? Def->getNextNode()->getIterator() : Def->getParent()->end())); Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg)); - if (auto *RC = MRI.getRegClassOrNull(Reg)) - MRI.setRegClass(NewReg, RC); SpirvTy = SpirvTy ? SpirvTy : GR->getOrCreateSPIRVType(Ty, MIB); GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF()); // This is to make it convenient for Legalizer to get the SPIRVType @@ -211,13 +207,15 @@ Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy, // Copy MIFlags from Def to ASSIGN_TYPE instruction. It's required to keep // the flags after instruction selection. const uint16_t Flags = Def->getFlags(); - MIB.buildInstr(SPIRV::ASSIGN_TYPE) - .addDef(Reg) - .addUse(NewReg) - .addUse(GR->getSPIRVTypeID(SpirvTy)) - .setMIFlags(Flags); + auto B = MIB.buildInstr(SPIRV::ASSIGN_TYPE) + .addDef(Reg) + .addUse(NewReg) + .addUse(GR->getSPIRVTypeID(SpirvTy)) + .setMIFlags(Flags); Def->getOperand(0).setReg(NewReg); - MRI.setRegClass(Reg, &SPIRV::ANYIDRegClass); + const auto &ST = GR->CurMF->getSubtarget(); + constrainSelectedInstRegOperands(*B.getInstr(), *ST.getInstrInfo(), + *ST.getRegisterInfo(), *ST.getRegBankInfo()); return NewReg; } } // namespace llvm @@ -305,25 +303,12 @@ createNewIdReg(Register ValReg, unsigned Opcode, MachineRegisterInfo &MRI, LLT NewT = LLT::scalar(32); SPIRVType *SpvType = GR.getSPIRVTypeForVReg(ValReg); assert(SpvType && "VReg is expected to have SPIRV type"); - bool IsFloat = SpvType->getOpcode() == SPIRV::OpTypeFloat; - bool IsVectorFloat = - SpvType->getOpcode() == SPIRV::OpTypeVector && - GR.getSPIRVTypeForVReg(SpvType->getOperand(1).getReg())->getOpcode() == - SPIRV::OpTypeFloat; - IsFloat |= IsVectorFloat; - auto GetIdOp = IsFloat ? SPIRV::GET_fID : SPIRV::GET_ID; - auto DstClass = IsFloat ? &SPIRV::fIDRegClass : &SPIRV::IDRegClass; - if (MRI.getType(ValReg).isPointer()) { - NewT = LLT::pointer(0, 32); - GetIdOp = SPIRV::GET_pID; - DstClass = &SPIRV::pIDRegClass; - } else if (MRI.getType(ValReg).isVector()) { + auto GetIdOp = SPIRV::GET_SID; + if (MRI.getType(ValReg).isVector()) { NewT = LLT::fixed_vector(2, NewT); - GetIdOp = IsFloat ? SPIRV::GET_vfID : SPIRV::GET_vID; - DstClass = IsFloat ? &SPIRV::vfIDRegClass : &SPIRV::vIDRegClass; + GetIdOp = SPIRV::GET_VID; } Register IdReg = MRI.createGenericVirtualRegister(NewT); - MRI.setRegClass(IdReg, DstClass); return {IdReg, GetIdOp}; } @@ -336,15 +321,20 @@ static void processInstr(MachineInstr &MI, MachineIRBuilder &MIB, auto NewReg = createNewIdReg(MI.getOperand(0).getReg(), Opc, MRI, *GR).first; AssignTypeInst.getOperand(1).setReg(NewReg); MI.getOperand(0).setReg(NewReg); - MIB.setInsertPt(*MI.getParent(), - (MI.getNextNode() ? MI.getNextNode()->getIterator() - : MI.getParent()->end())); + MIB.setInsertPt(*MI.getParent(), MI.getIterator()); for (auto &Op : MI.operands()) { if (!Op.isReg() || Op.isDef()) continue; auto IdOpInfo = createNewIdReg(Op.getReg(), Opc, MRI, *GR); - MIB.buildInstr(IdOpInfo.second).addDef(IdOpInfo.first).addUse(Op.getReg()); + auto B = MIB.buildInstr(IdOpInfo.second) + .addDef(IdOpInfo.first) + .addUse(Op.getReg()); Op.setReg(IdOpInfo.first); + + const auto &ST = GR->CurMF->getSubtarget(); + constrainSelectedInstRegOperands(*B.getInstr(), *ST.getInstrInfo(), + *ST.getRegisterInfo(), + *ST.getRegBankInfo()); } } @@ -373,8 +363,6 @@ static void processInstrsWithTypeFolding(MachineFunction &MF, if (!isTypeFoldingSupported(Opcode)) continue; Register DstReg = MI.getOperand(0).getReg(); - if (MRI.getType(DstReg).isVector()) - MRI.setRegClass(DstReg, &SPIRV::IDRegClass); // Don't need to reset type of register holding constant and used in // G_ADDRSPACE_CAST, since it braaks legalizer. if (Opcode == TargetOpcode::G_CONSTANT && MRI.hasOneUse(DstReg)) { diff --git a/llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.cpp index 9bf9d7fe5b39..7352ac56847e 100644 --- a/llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.cpp @@ -27,21 +27,5 @@ using namespace llvm; const RegisterBank & SPIRVRegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC, LLT Ty) const { - switch (RC.getID()) { - case SPIRV::TYPERegClassID: - return SPIRV::TYPERegBank; - case SPIRV::pIDRegClassID: - case SPIRV::IDRegClassID: - return SPIRV::IDRegBank; - case SPIRV::fIDRegClassID: - return SPIRV::fIDRegBank; - case SPIRV::vIDRegClassID: - return SPIRV::vIDRegBank; - case SPIRV::vfIDRegClassID: - return SPIRV::vfIDRegBank; - case SPIRV::ANYIDRegClassID: - case SPIRV::ANYRegClassID: - return SPIRV::IDRegBank; - } - llvm_unreachable("Unknown register class"); + return SPIRV::IDRegBank; } diff --git a/llvm/lib/Target/SPIRV/SPIRVRegisterBanks.td b/llvm/lib/Target/SPIRV/SPIRVRegisterBanks.td index 90c7f3a6e672..e489af857c65 100644 --- a/llvm/lib/Target/SPIRV/SPIRVRegisterBanks.td +++ b/llvm/lib/Target/SPIRV/SPIRVRegisterBanks.td @@ -6,10 +6,6 @@ // //===----------------------------------------------------------------------===// -// Although RegisterBankSelection is disabled we need to distinct the banks -// as InstructionSelector RegClass checking code relies on them -def IDRegBank : RegisterBank<"IDBank", [ID]>; -def fIDRegBank : RegisterBank<"fIDBank", [fID]>; -def vIDRegBank : RegisterBank<"vIDBank", [vID]>; -def vfIDRegBank : RegisterBank<"vfIDBank", [vfID]>; -def TYPERegBank : RegisterBank<"TYPEBank", [TYPE]>; +// Although RegBankSelect is disabled, we need to have +// at least one regbank to support instruction selector +def IDRegBank : RegisterBank<"IDBank", [ID]>; \ No newline at end of file diff --git a/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td b/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td index d0b64b6895d0..152ee291d7e7 100644 --- a/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td +++ b/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td @@ -12,7 +12,6 @@ let Namespace = "SPIRV" in { def p0 : PtrValueType ; - // All registers are for 32-bit identifiers, so have a single dummy register // Class for registers that are the result of OpTypeXXX instructions def TYPE0 : Register<"TYPE0">; @@ -21,19 +20,16 @@ let Namespace = "SPIRV" in { // Class for every other non-type ID def ID0 : Register<"ID0">; def ID : RegisterClass<"SPIRV", [i32], 32, (add ID0)>; - def fID0 : Register<"FID0">; - def fID : RegisterClass<"SPIRV", [f32], 32, (add fID0)>; - def pID0 : Register<"pID0">; - def pID : RegisterClass<"SPIRV", [p0], 32, (add pID0)>; - def vID0 : Register<"pID0">; - def vID : RegisterClass<"SPIRV", [v2i32], 32, (add vID0)>; - def vfID0 : Register<"pID0">; - def vfID : RegisterClass<"SPIRV", [v2f32], 32, (add vfID0)>; + def VID0 : Register<"VID0">; - def ANYID : RegisterClass<"SPIRV", [i32, f32, p0, v2i32, v2f32], 32, (add ID, fID, pID, vID, vfID)>; + // TODO: FID register class is only needed for constants selection, + // consider redesigning the pattern to get rid of this + def FID0 : Register<"FID0">; + def FID : RegisterClass<"SPIRV", [f32], 32, (add FID0)>; - // A few instructions like OpName can take ids from both type and non-type - // instructions, so we need a super-class to allow for both to count as valid - // arguments for these instructions. - def ANY : RegisterClass<"SPIRV", [i32], 32, (add TYPE, ID)>; + // Scalar ID + def SID : RegisterClass<"SPIRV", [i32, f32, p0], 32, (add ID0)>; + // Vector ID + def VID : RegisterClass<"SPIRV", [v2i32, v2f32], 32, (add VID0)>; + def ANYID : RegisterClass<"SPIRV", [i32, f32, p0, v2i32, v2f32], 32, (add SID, VID)>; }