-
Notifications
You must be signed in to change notification settings - Fork 14.6k
[SPIR-V] Add support for the SPIR-V extension SPV_INTEL_tensor_float32_conversion #150090
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[SPIR-V] Add support for the SPIR-V extension SPV_INTEL_tensor_float32_conversion #150090
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
@llvm/pr-subscribers-backend-spir-v Author: None (YixingZhang007) ChangesThis PR is to add support for the SPIR-V extension SPV_INTEL_tensor_float32_conversion (https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_tensor_float32_conversion.asciidoc) Full diff: https://github.com/llvm/llvm-project/pull/150090.diff 5 Files Affected:
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 6ec7544767c52..1c7c1750af1c9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -148,6 +148,7 @@ struct ConvertBuiltin {
bool IsSaturated;
bool IsRounded;
bool IsBfloat16;
+ bool IsTF32;
FPRoundingMode::FPRoundingMode RoundingMode;
};
@@ -2677,8 +2678,20 @@ static bool generateConvertInst(const StringRef DemangledCall,
}
} else if (GR->isScalarOrVectorOfType(Call->ReturnRegister,
SPIRV::OpTypeFloat)) {
- // Float -> Float
- Opcode = SPIRV::OpFConvert;
+ if(Builtin->IsTF32){
+ const auto *ST = static_cast<const SPIRVSubtarget *>(
+ &MIRBuilder.getMF().getSubtarget());
+ if (!ST->canUseExtension(
+ SPIRV::Extension::SPV_INTEL_bfloat16_conversion))
+ NeedExtMsg = "SPV_INTEL_bfloat16_conversion";
+ IsRightComponentsNumber =
+ GR->getScalarOrVectorComponentCount(Call->Arguments[0]) ==
+ GR->getScalarOrVectorComponentCount(Call->ReturnRegister);
+ Opcode = SPIRV::OpRoundFToTF32INTEL;
+ } else {
+ Float -> Float
+ Opcode = SPIRV::OpFConvert;
+ }
}
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index ea78dcd135267..326109c9fdff4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -1461,6 +1461,7 @@ class ConvertBuiltin<string name, InstructionSet set> {
bit IsRounded = !not(!eq(!find(name, "_rt"), -1));
bit IsBfloat16 = !or(!not(!eq(!find(name, "BF16"), -1)),
!not(!eq(!find(name, "bfloat16"), -1)));
+ bit IsTF32 = !not(!eq(!find(name, "TF32"), -1));
FPRoundingMode RoundingMode = !cond(!not(!eq(!find(name, "_rte"), -1)) : RTE,
!not(!eq(!find(name, "_rtz"), -1)) : RTZ,
!not(!eq(!find(name, "_rtp"), -1)) : RTP,
@@ -1472,7 +1473,7 @@ class ConvertBuiltin<string name, InstructionSet set> {
def ConvertBuiltins : GenericTable {
let FilterClass = "ConvertBuiltin";
let Fields = ["Name", "Set", "IsDestinationSigned", "IsSaturated",
- "IsRounded", "IsBfloat16", "RoundingMode"];
+ "IsRounded", "IsBfloat16", "IsTF32", "RoundingMode"];
string TypeOf_Set = "InstructionSet";
string TypeOf_RoundingMode = "FPRoundingMode";
}
@@ -1556,6 +1557,24 @@ foreach conv = ["FToBF16INTEL", "BF16ToFINTEL"] in {
def : ConvertBuiltin<!strconcat("__spirv_Convert", conv), OpenCL_std>;
}
+// SPV_INTEL_tensor_float32_conversion
+// Multiclass used to define at the same time both a demangled builtin records
+// and a corresponding convert builtin records.
+multiclass DemangledTF32ConvertBuiltin<string name1, string name2> {
+ // Create records for scalar and vector conversions.
+ foreach i = ["", "2", "3", "4", "8", "16"] in {
+ def : DemangledBuiltin<!strconcat("intel_convert_", name1, i, name2, i), OpenCL_std, Convert, 1, 1>;
+ def : ConvertBuiltin<!strconcat("intel_convert_", name1, i, name2, i), OpenCL_std>;
+ }
+}
+
+defm : DemangledTF32ConvertBuiltin<"ConvertFToTF32INTEL">;
+
+foreach conv = ["FToTF32INTEL"] in {
+ def : DemangledBuiltin<!strconcat("__spirv_Convert", conv), OpenCL_std, Convert, 1, 1>;
+ def : ConvertBuiltin<!strconcat("__spirv_Convert", conv), OpenCL_std>;
+}
+
//===----------------------------------------------------------------------===//
// Class defining a vector data load/store builtin record used for lowering
// into OpExtInst instruction.
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index 049ba0275f223..a04ed6a42c868 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -441,10 +441,13 @@ def OpBitcast : UnOp<"OpBitcast", 124>;
def OpPtrCastToCrossWorkgroupINTEL : UnOp<"OpPtrCastToCrossWorkgroupINTEL", 5934>;
def OpCrossWorkgroupCastToPtrINTEL : UnOp<"OpCrossWorkgroupCastToPtrINTEL", 5938>;
-// SPV_INTEL_bfloat16_conversion
+// SPV_INTEL_tensor_float32_conversion
def OpConvertFToBF16INTEL : UnOp<"OpConvertFToBF16INTEL", 6116>;
def OpConvertBF16ToFINTEL : UnOp<"OpConvertBF16ToFINTEL", 6117>;
+// SPV_INTEL_bfloat16_conversion
+def OpRoundFToTF32INTEL : UnOp<"OpRoundFToTF32INTEL", 6426>;
+
// 3.42.12 Composite Instructions
def OpVectorExtractDynamic: Op<77, (outs ID:$res), (ins TYPE:$type, vID:$vec, ID:$idx),
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index ad976e5288927..c252fc5897518 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1564,6 +1564,12 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::BFloat16ConversionINTEL);
}
break;
+ case SPIRV::OpRoundFToTF32INTEL:
+ if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_tensor_float32_conversion)) {
+ Reqs.addExtension(SPIRV::Extension::SPV_INTEL_tensor_float32_conversion);
+ Reqs.addCapability(SPIRV::Capability::TF32ConversionINTEL);
+ }
+ break;
case SPIRV::OpVariableLengthArrayINTEL:
case SPIRV::OpSaveMemoryINTEL:
case SPIRV::OpRestoreMemoryINTEL:
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 548e9b717c161..7b2139a1c84a8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -320,6 +320,7 @@ defm SPV_INTEL_subgroup_matrix_multiply_accumulate : ExtensionOperand<121>;
defm SPV_INTEL_2d_block_io : ExtensionOperand<122>;
defm SPV_INTEL_int4 : ExtensionOperand<123>;
defm SPV_KHR_float_controls2 : ExtensionOperand<124>;
+defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125>;
//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
@@ -502,6 +503,7 @@ defm VariableLengthArrayINTEL : CapabilityOperand<5817, 0, 0, [SPV_INTEL_variabl
defm GroupUniformArithmeticKHR : CapabilityOperand<6400, 0, 0, [SPV_KHR_uniform_group_instructions], []>;
defm USMStorageClassesINTEL : CapabilityOperand<5935, 0, 0, [SPV_INTEL_usm_storage_classes], [Kernel]>;
defm BFloat16ConversionINTEL : CapabilityOperand<6115, 0, 0, [SPV_INTEL_bfloat16_conversion], []>;
+defm TF32ConversionINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>;
defm GlobalVariableHostAccessINTEL : CapabilityOperand<6187, 0, 0, [SPV_INTEL_global_variable_host_access], []>;
defm HostAccessINTEL : CapabilityOperand<6188, 0, 0, [SPV_INTEL_global_variable_host_access], []>;
defm GlobalVariableFPGADecorationsINTEL : CapabilityOperand<6189, 0, 0, [SPV_INTEL_global_variable_fpga_decorations], []>;
|
You can test this locally with the following command:git-clang-format --diff HEAD~1 HEAD --extensions cpp -- llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp View the diff from clang-format here.diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 8e8c34e2b..469c04018 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -2678,13 +2678,13 @@ static bool generateConvertInst(const StringRef DemangledCall,
}
} else if (GR->isScalarOrVectorOfType(Call->ReturnRegister,
SPIRV::OpTypeFloat)) {
- if(Builtin->IsTF32){
+ if (Builtin->IsTF32) {
const auto *ST = static_cast<const SPIRVSubtarget *>(
- &MIRBuilder.getMF().getSubtarget());
+ &MIRBuilder.getMF().getSubtarget());
if (!ST->canUseExtension(
SPIRV::Extension::SPV_INTEL_bfloat16_conversion))
NeedExtMsg = "SPV_INTEL_bfloat16_conversion";
- IsRightComponentsNumber =
+ IsRightComponentsNumber =
GR->getScalarOrVectorComponentCount(Call->Arguments[0]) ==
GR->getScalarOrVectorComponentCount(Call->ReturnRegister);
Opcode = SPIRV::OpRoundFToTF32INTEL;
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index c252fc589..228f2227b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1565,7 +1565,8 @@ void addInstrRequirements(const MachineInstr &MI,
}
break;
case SPIRV::OpRoundFToTF32INTEL:
- if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_tensor_float32_conversion)) {
+ if (ST.canUseExtension(
+ SPIRV::Extension::SPV_INTEL_tensor_float32_conversion)) {
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_tensor_float32_conversion);
Reqs.addCapability(SPIRV::Capability::TF32ConversionINTEL);
}
|
This PR is to add support for the SPIR-V extension SPV_INTEL_tensor_float32_conversion (https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_tensor_float32_conversion.asciidoc)