[SPIR-V] Add support for the SPIR-V extension SPV_INTEL_tensor_float32_conversion #150090

Open · wants to merge 4 commits into main
23 changes: 19 additions & 4 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -148,6 +148,7 @@ struct ConvertBuiltin {
bool IsSaturated;
bool IsRounded;
bool IsBfloat16;
bool IsTF32;
FPRoundingMode::FPRoundingMode RoundingMode;
};

@@ -230,6 +231,7 @@ std::string lookupBuiltinNameHelper(StringRef DemangledCall,
// - "__spirv_SubgroupImageMediaBlockReadINTEL"
// - "__spirv_SubgroupImageMediaBlockWriteINTEL"
// - "__spirv_Convert"
// - "__spirv_Round"
// - "__spirv_UConvert"
// - "__spirv_SConvert"
// - "__spirv_FConvert"
@@ -242,7 +244,7 @@ std::string lookupBuiltinNameHelper(StringRef DemangledCall,
"SDotKHR|SUDotKHR|SDotAccSatKHR|UDotAccSatKHR|SUDotAccSatKHR|"
"ReadClockKHR|SubgroupBlockReadINTEL|SubgroupImageBlockReadINTEL|"
"SubgroupImageMediaBlockReadINTEL|SubgroupImageMediaBlockWriteINTEL|"
"Convert|"
"Convert|Round|"
"UConvert|SConvert|FConvert|SatConvert)[^_]*)(_R[^_]*_?(\\w+)?.*)?");
std::smatch Match;
if (std::regex_match(BuiltinName, Match, SpvWithR) && Match.size() > 1) {
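For context, here is a minimal standalone sketch of how the extended expression now recognizes the new `__spirv_Round*` spelling. The pattern below keeps only the conversion-related alternatives from the full expression above, and the `_Rfloat` suffixed input is a hypothetical example of the return-type-suffix form the regex accepts; this is an illustration, not LLVM code.

```cpp
// Standalone sketch: abbreviated form of the SpvWithR pattern shown above.
#include <iostream>
#include <regex>
#include <string>

int main() {
  const std::regex SpvWithR(
      "(__spirv_(Convert|Round|UConvert|SConvert|FConvert|SatConvert)[^_]*)"
      "(_R[^_]*_?(\\w+)?.*)?");
  // The first name comes from the tests added in this patch; the second shows
  // how a hypothetical _R return-type suffix would land in match group 3.
  for (const std::string Name :
       {"__spirv_RoundFToTF32INTEL", "__spirv_RoundFToTF32INTEL_Rfloat"}) {
    std::smatch Match;
    if (std::regex_match(Name, Match, SpvWithR))
      std::cout << Name << " -> builtin name '" << Match[1] << "', suffix '"
                << Match[3] << "'\n";
  }
}
```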
@@ -697,7 +699,8 @@ static bool buildAtomicStoreInst(const SPIRV::IncomingCall *Call,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
if (Call->isSpirvOp())
return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call, Register(0));
return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call,
Register(0));

Register ScopeRegister =
buildConstantIntReg32(SPIRV::Scope::Device, MIRBuilder, GR);
@@ -2677,8 +2680,20 @@ static bool generateConvertInst(const StringRef DemangledCall,
}
} else if (GR->isScalarOrVectorOfType(Call->ReturnRegister,
SPIRV::OpTypeFloat)) {
// Float -> Float
Opcode = SPIRV::OpFConvert;
if (Builtin->IsTF32) {
const auto *ST = static_cast<const SPIRVSubtarget *>(
&MIRBuilder.getMF().getSubtarget());
if (!ST->canUseExtension(
SPIRV::Extension::SPV_INTEL_tensor_float32_conversion))
NeedExtMsg = "SPV_INTEL_tensor_float32_conversion";
IsRightComponentsNumber =
GR->getScalarOrVectorComponentCount(Call->Arguments[0]) ==
GR->getScalarOrVectorComponentCount(Call->ReturnRegister);
Opcode = SPIRV::OpRoundFToTF32INTEL;
} else {
// Float -> Float
Opcode = SPIRV::OpFConvert;
}
}
}

23 changes: 22 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -1461,6 +1461,8 @@ class ConvertBuiltin<string name, InstructionSet set> {
bit IsRounded = !not(!eq(!find(name, "_rt"), -1));
bit IsBfloat16 = !or(!not(!eq(!find(name, "BF16"), -1)),
!not(!eq(!find(name, "bfloat16"), -1)));
bit IsTF32 = !or(!not(!eq(!find(name, "TF32"), -1)),
!not(!eq(!find(name, "tensor_float32"), -1)));
FPRoundingMode RoundingMode = !cond(!not(!eq(!find(name, "_rte"), -1)) : RTE,
!not(!eq(!find(name, "_rtz"), -1)) : RTZ,
!not(!eq(!find(name, "_rtp"), -1)) : RTP,
@@ -1472,7 +1474,7 @@ class ConvertBuiltin<string name, InstructionSet set> {
def ConvertBuiltins : GenericTable {
let FilterClass = "ConvertBuiltin";
let Fields = ["Name", "Set", "IsDestinationSigned", "IsSaturated",
"IsRounded", "IsBfloat16", "RoundingMode"];
"IsRounded", "IsBfloat16", "IsTF32", "RoundingMode"];
string TypeOf_Set = "InstructionSet";
string TypeOf_RoundingMode = "FPRoundingMode";
}
@@ -1556,6 +1558,25 @@ foreach conv = ["FToBF16INTEL", "BF16ToFINTEL"] in {
def : ConvertBuiltin<!strconcat("__spirv_Convert", conv), OpenCL_std>;
}

// cl_intel_tensor_float32_conversions / SPV_INTEL_tensor_float32_conversion
// Multiclass used to define both a demangled builtin record and the
// corresponding convert builtin record at the same time (the generated names
// are sketched after these records).
multiclass DemangledTF32RoundBuiltin<string name1, string name2> {
// Create records for scalar and vector conversions.
foreach i = ["", "2", "3", "4", "8", "16"] in {
def : DemangledBuiltin<!strconcat("intel_round_", name1, i, name2, i), OpenCL_std, Convert, 1, 1>;
def : ConvertBuiltin<!strconcat("intel_round_", name1, i, name2, i), OpenCL_std>;
}
}

defm : DemangledTF32RoundBuiltin<"tensor_float32", "_as_float">;
defm : DemangledTF32RoundBuiltin<"as_tensor_float32", "_float">;

foreach conv = ["FToTF32INTEL"] in {
def : DemangledBuiltin<!strconcat("__spirv_Round", conv), OpenCL_std, Convert, 1, 1>;
def : ConvertBuiltin<!strconcat("__spirv_Round", conv), OpenCL_std>;
}
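For concreteness, this small standalone sketch (illustrative only; it simply mirrors the string concatenation performed by the `DemangledTF32RoundBuiltin` multiclass and the two `defm` instantiations above) prints the demangled OpenCL builtin names that end up in the generated tables. They are the same names exercised by the OpenCL-style kernel in the new positive test.

```cpp
// Prints the demangled builtin names generated by DemangledTF32RoundBuiltin
// for the two instantiations above (scalar plus vector widths 2..16).
#include <iostream>
#include <utility>

int main() {
  const char *Widths[] = {"", "2", "3", "4", "8", "16"};
  const std::pair<const char *, const char *> Instances[] = {
      {"tensor_float32", "_as_float"}, {"as_tensor_float32", "_float"}};
  for (const auto &[Name1, Name2] : Instances)
    for (const char *W : Widths)
      std::cout << "intel_round_" << Name1 << W << Name2 << W << "\n";
}
```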

//===----------------------------------------------------------------------===//
// Class defining a vector data load/store builtin record used for lowering
// into OpExtInst instruction.
4 changes: 3 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -102,7 +102,9 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
SPIRV::Extension::Extension::SPV_INTEL_2d_block_io},
{"SPV_INTEL_int4", SPIRV::Extension::Extension::SPV_INTEL_int4},
{"SPV_KHR_float_controls2",
SPIRV::Extension::Extension::SPV_KHR_float_controls2}};
SPIRV::Extension::Extension::SPV_KHR_float_controls2},
{"SPV_INTEL_tensor_float32_conversion",
SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}};

bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
StringRef ArgValue,
3 changes: 3 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -445,6 +445,9 @@ def OpCrossWorkgroupCastToPtrINTEL : UnOp<"OpCrossWorkgroupCastToPtrINTEL", 5938
def OpConvertFToBF16INTEL : UnOp<"OpConvertFToBF16INTEL", 6116>;
def OpConvertBF16ToFINTEL : UnOp<"OpConvertBF16ToFINTEL", 6117>;

// SPV_INTEL_tensor_float32_conversion
def OpRoundFToTF32INTEL : UnOp<"OpRoundFToTF32INTEL", 6426>;

// 3.42.12 Composite Instructions

def OpVectorExtractDynamic: Op<77, (outs ID:$res), (ins TYPE:$type, vID:$vec, ID:$idx),
7 changes: 7 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1564,6 +1564,13 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::BFloat16ConversionINTEL);
}
break;
case SPIRV::OpRoundFToTF32INTEL:
if (ST.canUseExtension(
SPIRV::Extension::SPV_INTEL_tensor_float32_conversion)) {
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_tensor_float32_conversion);
Reqs.addCapability(SPIRV::Capability::TensorFloat32RoundingINTEL);
}
break;
case SPIRV::OpVariableLengthArrayINTEL:
case SPIRV::OpSaveMemoryINTEL:
case SPIRV::OpRestoreMemoryINTEL:
2 changes: 2 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -320,6 +320,7 @@ defm SPV_INTEL_subgroup_matrix_multiply_accumulate : ExtensionOperand<121>;
defm SPV_INTEL_2d_block_io : ExtensionOperand<122>;
defm SPV_INTEL_int4 : ExtensionOperand<123>;
defm SPV_KHR_float_controls2 : ExtensionOperand<124>;
defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125>;

//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
@@ -529,6 +530,7 @@ defm Subgroup2DBlockTransformINTEL : CapabilityOperand<6229, 0, 0, [SPV_INTEL_2d
defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d_block_io], [Subgroup2DBlockIOINTEL]>;
defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;
defm TensorFloat32RoundingINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>;

//===----------------------------------------------------------------------===//
// Multiclass used to define SourceLanguage enum values and at the same time
@@ -0,0 +1,12 @@
; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_tensor_float32_conversion %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
; CHECK-ERROR: result and argument must have the same number of components

target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown"

define spir_func void @test(<8 x float> %in) {
%res = tail call spir_func float @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float> %in)
ret void
}

declare spir_func float @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float>)
@@ -0,0 +1,12 @@
; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_tensor_float32_conversion %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
; CHECK-ERROR: result and argument must have the same number of components

target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown"

define spir_func void @test(<8 x float> %in) {
%res = tail call spir_func <4 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float> %in)
ret void
}

declare spir_func <4 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float>)
@@ -0,0 +1,62 @@
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_tensor_float32_conversion %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_tensor_float32_conversion %s -o - -filetype=obj | spirv-val %}

; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
; CHECK-ERROR: the builtin requires the following SPIR-V extension: SPV_INTEL_tensor_float32_conversion

; CHECK: OpCapability TensorFloat32RoundingINTEL
; CHECK: OpExtension "SPV_INTEL_tensor_float32_conversion"

; CHECK-DAG: %[[VoidTy:.*]] = OpTypeVoid
; CHECK-DAG: %[[FP32Ty:.*]] = OpTypeFloat 32
; CHECK-DAG: %[[VecFloat2:.*]] = OpTypeVector %[[FP32Ty]] 2
; CHECK-DAG: %[[VecFloat3:.*]] = OpTypeVector %[[FP32Ty]] 3
; CHECK-DAG: %[[VecFloat4:.*]] = OpTypeVector %[[FP32Ty]] 4
; CHECK-DAG: %[[VecFloat8:.*]] = OpTypeVector %[[FP32Ty]] 8
; CHECK-DAG: %[[VecFloat16:.*]] = OpTypeVector %[[FP32Ty]] 16
; CHECK-DAG: %[[FloatConstId:.*]] = OpConstant %[[FP32Ty]] 1.5

; CHECK: OpFunction %[[VoidTy]]
; CHECK: %[[FP32ValId:.*]] = OpFunctionParameter %[[FP32Ty]]
; CHECK: %[[FP32v8ValId:.*]] = OpFunctionParameter %[[VecFloat8]]
; CHECK: OpRoundFToTF32INTEL %[[FP32Ty]] %[[FP32ValId]]
; CHECK: OpRoundFToTF32INTEL %[[VecFloat8]] %[[FP32v8ValId]]
; CHECK: OpRoundFToTF32INTEL %[[FP32Ty]] %[[FloatConstId]]

; CHECK: OpRoundFToTF32INTEL %[[FP32Ty]]
; CHECK: OpRoundFToTF32INTEL %[[VecFloat2]]
; CHECK: OpRoundFToTF32INTEL %[[VecFloat3]]
; CHECK: OpRoundFToTF32INTEL %[[VecFloat4]]
; CHECK: OpRoundFToTF32INTEL %[[VecFloat8]]
; CHECK: OpRoundFToTF32INTEL %[[VecFloat16]]

target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown"

define spir_func void @test(float %a, <8 x float> %in) {
%res1 = tail call spir_func float @_Z25__spirv_RoundFToTF32INTELf(float %a)
%res2 = tail call spir_func <8 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float> %in)
%res3 = tail call spir_func float @_Z25__spirv_RoundFToTF32INTELf(float 1.500000e+00)
ret void
}

declare spir_func float @_Z25__spirv_RoundFToTF32INTELf(float)
declare spir_func <8 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float>)

define dso_local spir_kernel void @test_ocl(float %a) {
entry:
%res4 = call spir_func float @_Z35intel_round_as_tensor_float32_floatt(float 0.000000e+00)
%res5 = call spir_func <2 x float> @_Z37intel_round_as_tensor_float322_float2Dv2_t(<2 x float> zeroinitializer)
%res6 = call spir_func <3 x float> @_Z37intel_round_as_tensor_float323_float3Dv3_t(<3 x float> zeroinitializer)
%res7 = call spir_func <4 x float> @_Z37intel_round_as_tensor_float324_float4Dv4_t(<4 x float> zeroinitializer)
%res8 = call spir_func <8 x float> @_Z37intel_round_as_tensor_float328_float8Dv8_t(<8 x float> zeroinitializer)
%res9 = call spir_func <16 x float> @_Z39intel_round_as_tensor_float3216_float16Dv16_t(<16 x float> zeroinitializer)
ret void
}

declare spir_func float @_Z35intel_round_as_tensor_float32_floatt(float)
declare spir_func <2 x float> @_Z37intel_round_as_tensor_float322_float2Dv2_t(<2 x float>)
declare spir_func <3 x float> @_Z37intel_round_as_tensor_float323_float3Dv3_t(<3 x float>)
declare spir_func <4 x float> @_Z37intel_round_as_tensor_float324_float4Dv4_t(<4 x float>)
declare spir_func <8 x float> @_Z37intel_round_as_tensor_float328_float8Dv8_t(<8 x float>)
declare spir_func <16 x float> @_Z39intel_round_as_tensor_float3216_float16Dv16_t(<16 x float>)
17 changes: 14 additions & 3 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -405,6 +405,7 @@ def SPV_INTEL_memory_access_aliasing : I32EnumAttrCase<"SPV_INTEL_me
def SPV_INTEL_split_barrier : I32EnumAttrCase<"SPV_INTEL_split_barrier", 4029>;
def SPV_INTEL_bfloat16_conversion : I32EnumAttrCase<"SPV_INTEL_bfloat16_conversion", 4031>;
def SPV_INTEL_cache_controls : I32EnumAttrCase<"SPV_INTEL_cache_controls", 4032>;
def SPV_INTEL_tensor_float32_conversion : I32EnumAttrCase<"SPV_INTEL_tensor_float32_conversion", 4033>;

def SPV_NV_compute_shader_derivatives : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>;
def SPV_NV_cooperative_matrix : I32EnumAttrCase<"SPV_NV_cooperative_matrix", 5001>;
@@ -474,7 +475,8 @@ def SPIRV_ExtensionAttr :
SPV_NV_shader_image_footprint, SPV_NV_shader_sm_builtins,
SPV_NV_shader_subgroup_partitioned, SPV_NV_shading_rate,
SPV_NV_stereo_view_rendering, SPV_NV_viewport_array2, SPV_NV_bindless_texture,
SPV_NV_ray_tracing_motion_blur, SPV_NVX_multiview_per_view_attributes
SPV_NV_ray_tracing_motion_blur, SPV_NVX_multiview_per_view_attributes,
SPV_INTEL_tensor_float32_conversion
]>;

//===----------------------------------------------------------------------===//
@@ -1465,6 +1467,12 @@ def SPIRV_C_Bfloat16ConversionINTEL : I32EnumAttrCase<"B
];
}

def SPIRV_C_TensorFloat32RoundingINTEL : I32EnumAttrCase<"TensorFloat32RoundingINTEL", 6425> {
list<Availability> availability = [
Extension<[SPV_INTEL_tensor_float32_conversion]>
];
}

def SPIRV_C_CacheControlsINTEL : I32EnumAttrCase<"CacheControlsINTEL", 6441> {
list<Availability> availability = [
Extension<[SPV_INTEL_cache_controls]>
@@ -1567,7 +1575,8 @@ def SPIRV_CapabilityAttr :
SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV,
SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL,
SPIRV_C_CacheControlsINTEL, SPIRV_C_BFloat16TypeKHR,
SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR
SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR,
SPIRV_C_TensorFloat32RoundingINTEL
]>;

def SPIRV_AM_Logical : I32EnumAttrCase<"Logical", 0>;
@@ -4586,6 +4595,7 @@ def SPIRV_OC_OpControlBarrierArriveINTEL : I32EnumAttrCase<"OpControlBarrie
def SPIRV_OC_OpControlBarrierWaitINTEL : I32EnumAttrCase<"OpControlBarrierWaitINTEL", 6143>;
def SPIRV_OC_OpGroupIMulKHR : I32EnumAttrCase<"OpGroupIMulKHR", 6401>;
def SPIRV_OC_OpGroupFMulKHR : I32EnumAttrCase<"OpGroupFMulKHR", 6402>;
def SPIRV_OC_OpRoundFToTF32INTEL : I32EnumAttrCase<"OpRoundFToTF32INTEL", 6426>;

def SPIRV_OpcodeAttr :
SPIRV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [
@@ -4690,7 +4700,8 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT,
SPIRV_OC_OpConvertFToBF16INTEL, SPIRV_OC_OpConvertBF16ToFINTEL,
SPIRV_OC_OpControlBarrierArriveINTEL, SPIRV_OC_OpControlBarrierWaitINTEL,
SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR
SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR,
SPIRV_OC_OpRoundFToTF32INTEL
]>;

// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
54 changes: 54 additions & 0 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
@@ -11,6 +11,7 @@
// at (https://github.com/intel/llvm)
// Supported extensions
// * SPV_INTEL_bfloat16_conversion
// * SPV_INTEL_tensor_float32_conversion
//===----------------------------------------------------------------------===//


@@ -110,6 +111,59 @@ def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", []> {
let hasVerifier = 1;
}

// -----

def SPIRV_INTELRoundFToTF32Op : SPIRV_IntelVendorOp<"RoundFToTF32", []> {
let summary = "See extension SPV_INTEL_tensor_float32_conversion";

let description = [{
Converts the value numerically from a 32-bit floating-point type to tensor
float32, with rounding to the nearest even.

Result Type must be a scalar or vector of 32-bit floating-point type. The
component width must be 32 bits. The bit pattern in the Result represents a
tensor float32 value.

Float Value must be a scalar or vector of floating-point type.
It must have the same number of components as Result Type. The component width must be 32 bits.

Results are computed per component.


```
round-f-to-tf32-op ::= ssa-id `=` `spirv.INTEL.RoundFToTF32` ssa-use
`:` operand-type `to` result-type
```

#### Example:

```mlir
%1 = spirv.INTEL.RoundFToTF32 %0 : f32 to f32
%3 = spirv.INTEL.RoundFToTF32 %2 : vector<3xf32> to vector<3xf32>
```

}];


let availability = [
MinVersion<SPIRV_V_1_0>,
MaxVersion<SPIRV_V_1_6>,
Extension<[SPV_INTEL_tensor_float32_conversion]>,
Capability<[SPIRV_C_TensorFloat32RoundingINTEL]>
];

let arguments = (ins
SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$operand
);

let results = (outs
SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$result
);
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `to` type($result)
}];

let hasVerifier = 1;
}

// -----
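For background on the numeric effect the op description above refers to: TF32 is commonly described as keeping 1 sign bit, 8 exponent bits, and 10 explicit mantissa bits inside a 32-bit container. The following is a minimal sketch of rounding a binary32 value to TF32 precision with round-to-nearest-even, under that assumed layout and ignoring NaN/infinity handling; it is not code from this patch.

```cpp
// Minimal sketch: round a binary32 value to TF32 precision (10 explicit
// mantissa bits) with ties-to-even, keeping the 32-bit container.
// Assumes the usual TF32 layout; NaN/infinity handling is omitted.
#include <bit>
#include <cstdint>
#include <cstdio>

float roundToTF32(float X) {
  uint32_t Bits = std::bit_cast<uint32_t>(X);
  constexpr unsigned Dropped = 23 - 10;     // binary32 has 23 mantissa bits
  uint32_t Half = 1u << (Dropped - 1);
  uint32_t Bias = Half - 1 + ((Bits >> Dropped) & 1); // ties go to even
  Bits = (Bits + Bias) & ~((1u << Dropped) - 1);
  return std::bit_cast<float>(Bits);
}

int main() {
  float In = 1.00048828125f; // 1 + 2^-11: half a TF32 ulp, ties to even -> 1.0
  std::printf("%.11f -> %.11f\n", In, roundToTF32(In));
}
```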

21 changes: 21 additions & 0 deletions mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
@@ -311,6 +311,27 @@ LogicalResult INTELConvertFToBF16Op::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// spirv.INTELRoundFToTF32Op
//===----------------------------------------------------------------------===//

LogicalResult INTELRoundFToTF32Op::verify() {
auto operandType = getOperand().getType();
auto resultType = getResult().getType();
// ODS checks that vector result type and vector operand type have the same
// shape.
if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
unsigned operandNumElements = vectorType.getNumElements();
unsigned resultNumElements =
llvm::cast<VectorType>(resultType).getNumElements();
if (operandNumElements != resultNumElements) {
return emitOpError(
"operand and result must have same number of elements");
}
}
return success();
}

//===----------------------------------------------------------------------===//
// spirv.FConvertOp
//===----------------------------------------------------------------------===//