diff --git a/include/LLVMSPIRVExtensions.inc b/include/LLVMSPIRVExtensions.inc index e052efb1cf..278b31acf9 100644 --- a/include/LLVMSPIRVExtensions.inc +++ b/include/LLVMSPIRVExtensions.inc @@ -82,4 +82,5 @@ EXT(SPV_INTEL_ternary_bitwise_function) EXT(SPV_INTEL_int4) EXT(SPV_INTEL_function_variants) EXT(SPV_INTEL_shader_atomic_bfloat16) +EXT(SPV_EXT_float8) EXT(SPV_INTEL_predicated_io) diff --git a/lib/SPIRV/SPIRVInternal.h b/lib/SPIRV/SPIRVInternal.h index 207658fc1c..20a53254e8 100644 --- a/lib/SPIRV/SPIRVInternal.h +++ b/lib/SPIRV/SPIRVInternal.h @@ -373,6 +373,7 @@ const static char ConvertHandleToImageINTEL[] = "ConvertHandleToImageINTEL"; const static char ConvertHandleToSamplerINTEL[] = "ConvertHandleToSamplerINTEL"; const static char ConvertHandleToSampledImageINTEL[] = "ConvertHandleToSampledImageINTEL"; +const static char InternalBuiltinPrefix[] = "__builtin_spirv_"; } // namespace kSPIRVName namespace kSPIRVPostfix { @@ -665,7 +666,7 @@ Op getSPIRVFuncOC(StringRef Name, SmallVectorImpl *Dec = nullptr); bool getSPIRVBuiltin(const std::string &Name, spv::BuiltIn &Builtin); /// \param Name LLVM function name -/// \param DemangledName demanged name of the OpenCL built-in function +/// \param DemangledName demangled name of the OpenCL built-in function /// \returns true if Name is the name of the OpenCL built-in function, /// false for other functions bool oclIsBuiltin(StringRef Name, StringRef &DemangledName, bool IsCpp = false); @@ -728,6 +729,9 @@ CallInst *addCallInst(Module *M, StringRef FuncName, Type *RetTy, StringRef InstName = SPIR_TEMP_NAME_PREFIX_CALL, bool TakeFuncName = true); +/// Check if an LLVM type is spirv.CooperativeMatrixKHR. +bool isLLVMCooperativeMatrixType(llvm::Type *Ty); + /// Add a call instruction for SPIR-V builtin function. CallInst *addCallInstSPIRV(Module *M, StringRef FuncName, Type *RetTy, ArrayRef Args, AttributeList *Attrs, @@ -1029,6 +1033,84 @@ bool postProcessBuiltinsReturningStruct(Module *M, bool IsCpp = false); bool postProcessBuiltinsWithArrayArguments(Module *M, bool IsCpp = false); +/// \param MangledName LLVM function name. +/// \param DemangledName demangled name of the input function if it is the +/// translator's internal built-in function. +/// \returns true if MangledName is the name of the translator's internal +/// built-in function, false for other functions. +/// Used for 'mini'-floats conversion functions +bool isInternalSPIRVBuiltin(StringRef MangledName, StringRef &DemangledName); + +// Wrapper around SPIR-V 1.6.4 FP Encoding to be used in the conversion +// descriptor +enum FPEncodingWrap { + Integer = FPEncoding::FPEncodingMax - 1, + IEEE754 = FPEncoding::FPEncodingMax, + BF16 = FPEncoding::FPEncodingBFloat16KHR, + E4M3 = FPEncoding::FPEncodingFloat8E4M3EXT, + E5M2 = FPEncoding::FPEncodingFloat8E5M2EXT, +}; + +// Structure describing non-trivial conversions (FP8 and int4) +struct FPConversionDesc { + FPEncodingWrap SrcEncoding; + FPEncodingWrap DstEncoding; + SPIRVWord ConvOpCode; + + // To use as a key in std::map + bool operator==(const FPConversionDesc &Other) const { + return SrcEncoding == Other.SrcEncoding && + DstEncoding == Other.DstEncoding && ConvOpCode == Other.ConvOpCode; + } + + bool operator<(const FPConversionDesc &Other) const { + if (ConvOpCode != Other.ConvOpCode) + return ConvOpCode < Other.ConvOpCode; + if (SrcEncoding != Other.SrcEncoding) + return SrcEncoding < Other.SrcEncoding; + return DstEncoding < Other.DstEncoding; + } +}; + +// Maps internal builtin name to conversion descriptor +typedef SPIRVMap FPConvertToEncodingMap; + +// clang-format off +template <> inline void FPConvertToEncodingMap::init() { + // 8-bit conversions + add("ConvertE4M3ToFP16EXT", + {FPEncodingWrap::E4M3, FPEncodingWrap::IEEE754, OpFConvert}); + add("ConvertE5M2ToFP16EXT", + {FPEncodingWrap::E5M2, FPEncodingWrap::IEEE754, OpFConvert}); + add("ConvertE4M3ToBF16EXT", + {FPEncodingWrap::E4M3, FPEncodingWrap::BF16, OpFConvert}); + add("ConvertE5M2ToBF16EXT", + {FPEncodingWrap::E5M2, FPEncodingWrap::BF16, OpFConvert}); + add("ConvertFP16ToE4M3EXT", + {FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3, OpFConvert}); + add("ConvertFP16ToE5M2EXT", + {FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2, OpFConvert}); + add("ConvertBF16ToE4M3EXT", + {FPEncodingWrap::BF16, FPEncodingWrap::E4M3, OpFConvert}); + add("ConvertBF16ToE5M2EXT", + {FPEncodingWrap::BF16, FPEncodingWrap::E5M2, OpFConvert}); + + add("ConvertInt4ToE4M3INTEL", + {FPEncodingWrap::Integer, FPEncodingWrap::E4M3, OpConvertSToF}); + add("ConvertInt4ToE5M2INTEL", + {FPEncodingWrap::Integer, FPEncodingWrap::E5M2, OpConvertSToF}); + add("ConvertInt4ToFP16INTEL", + {FPEncodingWrap::Integer, FPEncodingWrap::IEEE754, OpConvertSToF}); + add("ConvertInt4ToBF16INTEL", + {FPEncodingWrap::Integer, FPEncodingWrap::BF16, OpConvertSToF}); + add("ConvertFP16ToInt4INTEL", + {FPEncodingWrap::IEEE754, FPEncodingWrap::Integer, OpConvertFToS}); + add("ConvertBF16ToInt4INTEL", + {FPEncodingWrap::BF16, FPEncodingWrap::Integer, OpConvertFToS}); +} + +// clang-format on + } // namespace SPIRV #endif // SPIRV_SPIRVINTERNAL_H diff --git a/lib/SPIRV/SPIRVReader.cpp b/lib/SPIRV/SPIRVReader.cpp index a9c8270601..70cfbef303 100644 --- a/lib/SPIRV/SPIRVReader.cpp +++ b/lib/SPIRV/SPIRVReader.cpp @@ -297,6 +297,9 @@ std::optional SPIRVToLLVM::getAlignment(SPIRVValue *V) { Type *SPIRVToLLVM::transFPType(SPIRVType *T) { switch (T->getFloatBitWidth()) { + case 8: + // No LLVM IR counter part for FP8 - map it on i8 + return Type::getIntNTy(*Context, 8); case 16: if (T->isTypeFloat(16, FPEncodingBFloat16KHR)) return Type::getBFloatTy(*Context); @@ -1049,6 +1052,22 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F, CastInst::CastOps CO = Instruction::BitCast; bool IsExt = Dst->getScalarSizeInBits() > Src->getType()->getScalarSizeInBits(); + + auto GetFPEncoding = [](SPIRVType *Ty) -> FPEncodingWrap { + if (Ty->isTypeFloat()) { + unsigned Enc = + static_cast(Ty)->getFloatingPointEncoding(); + return static_cast(Enc); + } + if (Ty->isTypeInt()) + return FPEncodingWrap::Integer; + return FPEncodingWrap::IEEE754; + }; + + auto IsFP8Encoding = [](FPEncodingWrap Encoding) -> bool { + return Encoding == FPEncodingWrap::E4M3 || Encoding == FPEncodingWrap::E5M2; + }; + switch (BC->getOpCode()) { case OpPtrCastToGeneric: case OpGenericCastToPtr: @@ -1070,10 +1089,58 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F, case OpUConvert: CO = IsExt ? Instruction::ZExt : Instruction::Trunc; break; - case OpFConvert: - CO = IsExt ? Instruction::FPExt : Instruction::FPTrunc; + case OpConvertSToF: + case OpConvertFToS: + case OpConvertUToF: + case OpConvertFToU: + case OpFConvert: { + const auto OC = BC->getOpCode(); + { + auto SPVOps = BC->getOperands(); + auto *SPVSrcTy = SPVOps[0]->getType(); + auto *SPVDstTy = BC->getType(); + + auto GetEncodingAndUpdateType = + [GetFPEncoding](SPIRVType *&SPVTy) -> FPEncodingWrap { + if (SPVTy->isTypeVector()) { + SPVTy = SPVTy->getVectorComponentType(); + } else if (SPVTy->isTypeCooperativeMatrixKHR()) { + auto *MT = static_cast(SPVTy); + SPVTy = MT->getCompType(); + } + return GetFPEncoding(SPVTy); + }; + + FPEncodingWrap SrcEnc = GetEncodingAndUpdateType(SPVSrcTy); + FPEncodingWrap DstEnc = GetEncodingAndUpdateType(SPVDstTy); + if (IsFP8Encoding(SrcEnc) || IsFP8Encoding(DstEnc) || + SPVSrcTy->isTypeInt(4) || SPVDstTy->isTypeInt(4)) { + FPConversionDesc FPDesc = {SrcEnc, DstEnc, BC->getOpCode()}; + auto Conv = SPIRV::FPConvertToEncodingMap::rmap(FPDesc); + std::vector Ops = {Src}; + std::vector OpsTys = {Src->getType()}; + + std::string BuiltinName = + kSPIRVName::InternalBuiltinPrefix + std::string(Conv); + BuiltinFuncMangleInfo Info; + std::string MangledName = mangleBuiltin(BuiltinName, OpsTys, &Info); + + FunctionType *FTy = FunctionType::get(Dst, OpsTys, false); + FunctionCallee Func = M->getOrInsertFunction(MangledName, FTy); + return CallInst::Create(Func, Ops, "", BB); + } + } + + if (OC == OpFConvert) { + CO = IsExt ? Instruction::FPExt : Instruction::FPTrunc; + break; + } + CO = static_cast(OpCodeMap::rmap(OC)); break; + } case OpBitcast: + if (!Dst->isPointerTy() && Dst == Src->getType()) + return Src; // OpBitcast need to be handled as a special-case when the source is a // pointer and the destination is not a pointer, and where the source is not // a pointer and the destination is a pointer. This is supported by the @@ -2970,11 +3037,29 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F, if (isCvtOpCode(OC) && OC != OpGenericCastToPtrExplicit) { auto *BI = static_cast(BV); Value *Inst = nullptr; - if (BI->hasFPRoundingMode() || BI->isSaturatedConversion() || - BI->getType()->isTypeCooperativeMatrixKHR()) + if (BI->hasFPRoundingMode() || BI->isSaturatedConversion()) { Inst = transSPIRVBuiltinFromInst(BI, BB); - else + } else if (BI->getType()->isTypeCooperativeMatrixKHR()) { + // For cooperative matrix conversions generate __builtin_spirv + // conversions instead of __spirv_FConvert in case of mini-float + // type element type. + auto *OutMatrixElementTy = + static_cast(BI->getType()) + ->getCompType(); + auto *InMatrixElementTy = + static_cast( + static_cast(BI)->getOperand(0)->getType()) + ->getCompType(); + if (OutMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E4M3EXT) || + OutMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E5M2EXT) || + InMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E4M3EXT) || + InMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E5M2EXT)) + Inst = transConvertInst(BV, F, BB); + else + Inst = transSPIRVBuiltinFromInst(BI, BB); + } else { Inst = transConvertInst(BV, F, BB); + } return mapValue(BV, Inst); } return mapValue( diff --git a/lib/SPIRV/SPIRVUtil.cpp b/lib/SPIRV/SPIRVUtil.cpp index 2fbd505ab7..14fc6d0523 100644 --- a/lib/SPIRV/SPIRVUtil.cpp +++ b/lib/SPIRV/SPIRVUtil.cpp @@ -40,6 +40,7 @@ // This file needs to be included before anything that declares // llvm::PointerType to avoid a compilation bug on MSVC. +#include "llvm/Demangle/Demangle.h" #include "llvm/Demangle/ItaniumDemangle.h" #include "FunctionDescriptor.h" @@ -267,6 +268,12 @@ bool isSYCLBfloat16Type(llvm::Type *Ty) { return false; } +bool isLLVMCooperativeMatrixType(llvm::Type *Ty) { + if (auto *TargetTy = dyn_cast(Ty)) + return TargetTy->getName() == "spirv.CooperativeMatrixKHR"; + return false; +} + Function *getOrCreateFunction(Module *M, Type *RetTy, ArrayRef ArgTypes, StringRef Name, BuiltinFuncMangleInfo *Mangle, AttributeList *Attrs, bool TakeName) { @@ -439,7 +446,7 @@ bool getSPIRVBuiltin(const std::string &OrigName, spv::BuiltIn &B) { return getByName(R.str(), B); } -// Demangled name is a substring of the name. The DemangledName is updated only +// DemangledName is a substring of Name. The DemangledName is updated only // if true is returned bool oclIsBuiltin(StringRef Name, StringRef &DemangledName, bool IsCpp) { if (Name == "printf") { @@ -484,6 +491,21 @@ bool oclIsBuiltin(StringRef Name, StringRef &DemangledName, bool IsCpp) { return false; } +// DemangledName is a substring of Name. The DemangledName is updated only +// if true is returned. +bool isInternalSPIRVBuiltin(StringRef Name, StringRef &DemangledName) { + if (!Name.starts_with("_Z")) + return false; + constexpr unsigned DemangledNameLenStart = 2; + size_t Start = Name.find_first_not_of("0123456789", DemangledNameLenStart); + if (!Name.substr(Start, Name.size() - 1) + .starts_with(kSPIRVName::InternalBuiltinPrefix)) + return false; + DemangledName = llvm::itaniumDemangle(Name.data(), false); + DemangledName.consume_front(kSPIRVName::InternalBuiltinPrefix); + return true; +} + // Check if a mangled type Name is unsigned bool isMangledTypeUnsigned(char Mangled) { return Mangled == 'h' /* uchar */ diff --git a/lib/SPIRV/SPIRVWriter.cpp b/lib/SPIRV/SPIRVWriter.cpp index d363f060a3..673f2139a5 100644 --- a/lib/SPIRV/SPIRVWriter.cpp +++ b/lib/SPIRV/SPIRVWriter.cpp @@ -909,6 +909,26 @@ SPIRVFunction *LLVMToSPIRVBase::transFunctionDecl(Function *F) { return nullptr; } + // Don't translate FP conversion translator builtins as function declarations. + auto MangledName = F->getName(); + StringRef DemangledName; + if (isInternalSPIRVBuiltin(MangledName, DemangledName)) { + if (SPIRV::FPConvertToEncodingMap::find(DemangledName)) { + // Create an early exit here if none of the extensions are enabled. + // Proper checks for the required extensions will be done during TypeFloat + // generation. + if (!BM->isAllowedToUseExtension(ExtensionID::SPV_EXT_float8) && + !BM->isAllowedToUseExtension(ExtensionID::SPV_INTEL_int4)) { + std::string ErrorStr = + "One of the following extensions: SPV_EXT_float8, " + "SPV_INTEL_int4 should be enabled to process " + "conversion builtins"; + getErrorLog().checkError(false, SPIRVEC_RequiresExtension, F, ErrorStr); + } + return nullptr; + } + } + SPIRVTypeFunction *BFT = static_cast(transScavengedType(F)); SPIRVFunction *BF = @@ -5505,6 +5525,163 @@ SPIRVValue *LLVMToSPIRVBase::transDirectCallInst(CallInst *CI, return BV; } + if (isInternalSPIRVBuiltin(MangledName, DemangledName)) { + // Logic of the code below is described in + // docs/SPIRVMiniFloatsRepresentationInLLVM.rst + // A quick recap of the document: + // For FP8 types (which don't have appropriate counterparts in LLVM) + // the translator expect to see external function calls with __builtin_spirv + // prefix, names of the functions encode the used in the conversion + // FP types and will be used by the translator for proper TypeFloat values + // generation. Since in LLVM IR FP8 types are represented by + // integer type, we have to insert bitcasts for dependent values that expect + // integer inputs or produces integer output that is used in the conversion. + if (SPIRV::FPConvertToEncodingMap::find(DemangledName)) { + FPConversionDesc FPDesc = + SPIRV::FPConvertToEncodingMap::map(DemangledName); + Value *Src = CI->getOperand(0); + Type *LLVMSrcTy = Src->getType(); + Type *LLVMDstTy = CI->getType(); + SPIRVType *SrcTy = nullptr; + SPIRVType *DstTy = nullptr; + + auto GetScalarTy = [&](Type *Ty) -> Type * { + if (isLLVMCooperativeMatrixType(Ty)) + return cast(Ty)->getTypeParameter(0); + return Ty->getScalarType(); + }; + + // FP types representable in LLVM IR, no need for special handling + // Also, int4 matrices remain the same. + if (FPDesc.SrcEncoding == FPEncodingWrap::IEEE754 || + FPDesc.SrcEncoding == FPEncodingWrap::BF16 || + (FPDesc.SrcEncoding == FPEncodingWrap::Integer && + isLLVMCooperativeMatrixType(LLVMDstTy))) + SrcTy = transType(LLVMSrcTy); + if (FPDesc.DstEncoding == FPEncodingWrap::IEEE754 || + FPDesc.DstEncoding == FPEncodingWrap::BF16 || + (FPDesc.DstEncoding == FPEncodingWrap::Integer && + (isLLVMCooperativeMatrixType(LLVMDstTy) || + FPDesc.SrcEncoding == FPEncodingWrap::Integer))) + DstTy = transType(LLVMDstTy); + + SPIRVValue *SrcOp = transValue(Src, BB); + + // TODO: unify SrcTy and DstTy processing into a single routine. + if (!SrcTy) { + // Src type is 'mini' float or int4 + Type *SrcScalarTy = GetScalarTy(LLVMSrcTy); + unsigned SrcTyWidth = cast(SrcScalarTy)->getBitWidth(); + unsigned SrcVecSize = 0; + if (SrcTyWidth == 32) { + // Int4 packed in 32-bit integer, change Src type and vector size + assert(FPDesc.SrcEncoding == FPEncodingWrap::Integer && + "Unknown FP encoding"); + assert(!isLLVMCooperativeMatrixType(LLVMSrcTy) && + "Int4 matrices must not be packed"); + SrcVecSize = 8; + SrcTyWidth = 4; + } else if (SrcTyWidth == 8 && + FPDesc.SrcEncoding == FPEncodingWrap::Integer) { + assert(!isLLVMCooperativeMatrixType(LLVMSrcTy) && + "Int4 matrices must not be packed"); + // Int4 packed in 8-bit integer, change Src type and vector size + SrcVecSize = 2; + SrcTyWidth = 4; + } else { + if (LLVMSrcTy->isVectorTy()) + SrcVecSize = + cast(LLVMSrcTy)->getElementCount().getFixedValue(); + } + if (FPDesc.SrcEncoding == FPEncodingWrap::Integer) { + SrcTy = BM->addIntegerType(SrcTyWidth); + } else { + SrcTy = BM->addFloatType(SrcTyWidth, FPDesc.SrcEncoding); + } + if (SrcVecSize > 0) + SrcTy = BM->addVectorType(SrcTy, SrcVecSize); + + if (isLLVMCooperativeMatrixType(LLVMSrcTy)) { + // Create FP8 matrix with a new type and insert a bitcast. + SrcTy = BM->addCooperativeMatrixKHRType( + SrcTy, + static_cast(transType(LLVMSrcTy)) + ->getArgs()); + SrcOp = BM->addUnaryInst(OpBitcast, SrcTy, SrcOp, BB); + } else if (FPDesc.SrcEncoding != FPEncodingWrap::Integer || + (SrcTy->isTypeVector() && !LLVMSrcTy->isVectorTy())) { + // Create bitcast for FP8 and packed Int4. + SrcOp = BM->addUnaryInst(OpBitcast, SrcTy, SrcOp, BB); + } + } + if (!DstTy) { + // Dst type is 'mini' float or int4 + Type *DstScalarTy = GetScalarTy(LLVMDstTy); + unsigned DstTyWidth = cast(DstScalarTy)->getBitWidth(); + unsigned DstVecSize = 0; + if (DstTyWidth == 32) { + // Int4 packed in 32-bit integer, change Dst type and vector size + assert(FPDesc.DstEncoding == FPEncodingWrap::Integer && + "Unknown FP encoding"); + assert(!isLLVMCooperativeMatrixType(LLVMDstTy) && + "Int4 matrices must not be packed"); + DstVecSize = 8; + DstTyWidth = 4; + } else if (DstTyWidth == 8 && + FPDesc.DstEncoding == FPEncodingWrap::Integer) { + assert(!isLLVMCooperativeMatrixType(LLVMDstTy) && + "Int4 matrices must not be packed"); + // Int4 packed in 8-bit integer, change Dst type and vector size + DstVecSize = 2; + DstTyWidth = 4; + } else { + // Currently unused in SYCL + if (LLVMDstTy->isVectorTy()) + DstVecSize = + cast(LLVMDstTy)->getElementCount().getFixedValue(); + } + if (FPDesc.DstEncoding == FPEncodingWrap::Integer) { + DstTy = BM->addIntegerType(DstTyWidth); + } else { + DstTy = BM->addFloatType(DstTyWidth, FPDesc.DstEncoding); + } + + if (isLLVMCooperativeMatrixType(LLVMDstTy)) + // Create FP8 matrix with a new type. + DstTy = BM->addCooperativeMatrixKHRType( + DstTy, + static_cast(transType(LLVMDstTy)) + ->getArgs()); + + if (DstVecSize > 0) + DstTy = BM->addVectorType(DstTy, DstVecSize); + } + + std::vector Ops = {SrcOp}; + const auto OC = static_cast(FPDesc.ConvOpCode); + + // Translate operands for stochastic roundings. + for (size_t I = 1; I != CI->arg_size(); ++I) + Ops.push_back(transValue(CI->getOperand(I), BB)); + + SPIRVValue *Conv = BM->addInstTemplate(OC, BM->getIds(Ops), BB, DstTy); + + // Representable in LLVM FP types: bitcast is not needed. + if (FPDesc.DstEncoding == FPEncodingWrap::IEEE754 || + FPDesc.DstEncoding == FPEncodingWrap::BF16) + return Conv; + // Originally not-packed integer. + if (FPDesc.DstEncoding == FPEncodingWrap::Integer && + (DstTy->isTypeVector() == LLVMDstTy->isVectorTy() || + isLLVMCooperativeMatrixType(LLVMDstTy))) + return Conv; + // Need to adjust types: create bitcast for FP8 and packed Int4. + SPIRVValue *BitCast = + BM->addUnaryInst(OpBitcast, transType(CI->getType()), Conv, BB); + return BitCast; + } + } + SmallVector Dec; if (isBuiltinTransToExtInst(CI->getCalledFunction(), &ExtSetKind, &ExtOp, &Dec)) { diff --git a/lib/SPIRV/libSPIRV/SPIRVModule.cpp b/lib/SPIRV/libSPIRV/SPIRVModule.cpp index a680574ecc..6012b1723f 100644 --- a/lib/SPIRV/libSPIRV/SPIRVModule.cpp +++ b/lib/SPIRV/libSPIRV/SPIRVModule.cpp @@ -644,6 +644,8 @@ class SPIRVModuleImpl : public SPIRVModule { FloatTypeMap; SmallDenseMap, SPIRVTypePointer *, 4> PointerTypeMap; + SmallDenseMap, SPIRVTypeVector *, 4> + VectorTypeMap; std::unordered_map LiteralMap; std::vector DebugInstVec; std::vector AuxDataInstVec; @@ -1163,7 +1165,13 @@ void SPIRVModuleImpl::closeStructType(SPIRVTypeStruct *T, bool Packed) { SPIRVTypeVector *SPIRVModuleImpl::addVectorType(SPIRVType *CompType, SPIRVWord CompCount) { - return addType(new SPIRVTypeVector(this, getId(), CompType, CompCount)); + auto Desc = std::make_pair(CompType, CompCount); + auto Loc = VectorTypeMap.find(Desc); + if (Loc != VectorTypeMap.end()) + return Loc->second; + auto *Ty = new SPIRVTypeVector(this, getId(), CompType, CompCount); + VectorTypeMap[Desc] = Ty; + return addType(Ty); } SPIRVTypeJointMatrixINTEL * diff --git a/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h b/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h index bfa1f5b19a..323fd5af68 100644 --- a/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h +++ b/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h @@ -691,6 +691,8 @@ template <> inline void SPIRVMap::init() { add(CapabilityFunctionVariantsINTEL, "FunctionVariantsINTEL"); add(CapabilitySpecConditionalINTEL, "SpecConditionalINTEL"); add(internal::CapabilityBFloat16ArithmeticINTEL, "BFloat16ArithmeticINTEL"); + add(CapabilityFloat8EXT, "Float8EXT"); + add(CapabilityFloat8CooperativeMatrixEXT, "Float8CooperativeMatrixEXT"); add(internal::CapabilityPredicatedIOINTEL, "PredicatedIOINTEL"); } SPIRV_DEF_NAMEMAP(Capability, SPIRVCapabilityNameMap) diff --git a/lib/SPIRV/libSPIRV/SPIRVType.h b/lib/SPIRV/libSPIRV/SPIRVType.h index 34d97d904a..d7a520a345 100644 --- a/lib/SPIRV/libSPIRV/SPIRVType.h +++ b/lib/SPIRV/libSPIRV/SPIRVType.h @@ -237,6 +237,9 @@ class SPIRVTypeFloat : public SPIRVType { std::optional getRequiredExtension() const override { if (isTypeFloat(16, FPEncodingBFloat16KHR)) return ExtensionID::SPV_KHR_bfloat16; + if (isTypeFloat(8, FPEncodingFloat8E4M3EXT) || + isTypeFloat(8, FPEncodingFloat8E5M2EXT)) + return ExtensionID::SPV_EXT_float8; return {}; } @@ -250,8 +253,12 @@ class SPIRVTypeFloat : public SPIRVType { if (std::any_of(Extensions.begin(), Extensions.end(), [](const std::string &I) { return I == "cl_khr_fp16"; })) CV.push_back(CapabilityFloat16); - } else if (isTypeFloat(64)) + } else if (isTypeFloat(64)) { CV.push_back(CapabilityFloat64); + } else if (isTypeFloat(8, FPEncodingFloat8E4M3EXT) || + isTypeFloat(8, FPEncodingFloat8E5M2EXT)) { + CV.push_back(CapabilityFloat8EXT); + } return CV; } @@ -274,10 +281,14 @@ class SPIRVTypeFloat : public SPIRVType { void validate() const override { SPIRVEntry::validate(); - assert(BitWidth >= 16 && BitWidth <= 64 && "Invalid bit width"); + assert( + (BitWidth == 8 || BitWidth == 16 || BitWidth == 32 || BitWidth == 64) && + "Invalid bit width"); assert( (FloatingPointEncoding == FPEncodingMax || - (BitWidth == 16 && FloatingPointEncoding == FPEncodingBFloat16KHR)) && + (BitWidth == 16 && FloatingPointEncoding == FPEncodingBFloat16KHR) || + (BitWidth == 8 && FloatingPointEncoding == FPEncodingFloat8E4M3EXT) || + (BitWidth == 8 && FloatingPointEncoding == FPEncodingFloat8E5M2EXT)) && "Invalid floating point encoding"); } @@ -1249,9 +1260,13 @@ class SPIRVTypeCooperativeMatrixKHR : public SPIRVType { else if (CompType->isTypeInt() && static_cast(CompType)->getBitWidth() == 4) CV.push_back(CapabilityInt4CooperativeMatrixINTEL); + else if (CompType->isTypeFloat(8, FPEncodingFloat8E4M3EXT) || + CompType->isTypeFloat(8, FPEncodingFloat8E5M2EXT)) + CV.push_back(CapabilityFloat8CooperativeMatrixEXT); return CV; } + std::vector getArgs() const { return Args; } SPIRVType *getCompType() const { return CompType; } SPIRVValue *getScope() const { return Args[0]; } SPIRVValue *getRows() const { return Args[1]; } diff --git a/spirv-headers-tag.conf b/spirv-headers-tag.conf index 9e127c1b87..ece42c8712 100644 --- a/spirv-headers-tag.conf +++ b/spirv-headers-tag.conf @@ -1 +1 @@ -9e3836d7d6023843a72ecd3fbf3f09b1b6747a9e +01e0577914a75a2569c846778c2f93aa8e6feddd diff --git a/test/extensions/EXT/SPV_EXT_float8/conversions_matrix.ll b/test/extensions/EXT/SPV_EXT_float8/conversions_matrix.ll new file mode 100644 index 0000000000..5a6642214e --- /dev/null +++ b/test/extensions/EXT/SPV_EXT_float8/conversions_matrix.ll @@ -0,0 +1,63 @@ +; This tests checks if FP8 matrix conversions work fine. + +; RUN: llvm-as %s -o %t.bc +; RUN: llvm-spirv %t.bc -o %t.spv --spirv-ext=+SPV_EXT_float8,+SPV_KHR_cooperative_matrix +; RUN: llvm-spirv %t.spv -o %t.spt --to-text +; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV +; RUN: llvm-spirv %t.spv -o %t.rev.bc -r --spirv-target-env=SPV-IR +; RUN: llvm-dis %t.rev.bc -o %t.rev.ll +; RUN: FileCheck < %t.rev.ll %s --check-prefix=CHECK-LLVM + +; TODO: RUNx: spirv-val + +; CHECK-SPIRV-DAG: CooperativeMatrixKHR +; CHECK-SPIRV-DAG: Float8EXT +; CHECK-SPIRV-DAG: Float8CooperativeMatrixEXT +; CHECK-SPIRV-DAG: "SPV_EXT_float8" +; CHECK-SPIRV-DAG: "SPV_KHR_cooperative_matrix" + +; CHECK-SPIRV-DAG: TypeInt [[#Int8Ty:]] 8 0 +; CHECK-SPIRV-DAG: TypeFloat [[#FP8Ty:]] 8 4214 +; CHECK-SPIRV-DAG: TypeFloat [[#FP16Ty:]] 16 +; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#Int8MatrixTy:]] [[#Int8Ty]] +; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#FP8MatrixTy:]] [[#FP8Ty]] +; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#FP16MatrixTy:]] [[#FP16Ty]] + +; CHECK-SPIRV: CompositeConstruct [[#FP16MatrixTy]] [[#M:]] [[#]] +; CHECK-SPIRV: FConvert [[#FP8MatrixTy]] [[#Conv:]] [[#M]] +; CHECK-SPIRV: Bitcast [[#Int8MatrixTy]] [[#]] [[#Conv]] + +; CHECK-LLVM: %[[#M:]] = call spir_func target("spirv.CooperativeMatrixKHR", half, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructDh(half 0xH0000) +; CHECK-LLVM: call target("spirv.CooperativeMatrixKHR", i8, 3, 12, 12, 2) @_Z36__builtin_spirv_ConvertFP16ToE4M3EXTPU3AS144__spirv_CooperativeMatrixKHR__half_3_12_12_2(target("spirv.CooperativeMatrixKHR", half, 3, 12, 12, 2) %[[#M]]) + +; ModuleID = 'test.bc' +target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v24:32:32-v32:32:32-v48:64:64-v64:64:64-v96:128:128-v128:128:128-v192:256:256-v256:256:256-v512:512:512-v1024:1024:1024-G1" +target triple = "spir-unknown-unknown" + +; Function Attrs: nounwind +define spir_func void @int4_hf8() #0 { +entry: + %0 = call spir_func target("spirv.CooperativeMatrixKHR", half, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructDh(half 0.0) #0 + %1 = call target("spirv.CooperativeMatrixKHR", i8, 3, 12, 12, 2) @_Z36__builtin_spirv_ConvertFP16ToE4M3EXTPU3AS144__spirv_CooperativeMatrixKHR__half_3_12_12_2(target("spirv.CooperativeMatrixKHR", half, 3, 12, 12, 2) %0) + ret void +} + +; Function Attrs: nounwind +declare spir_func target("spirv.CooperativeMatrixKHR", half, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructDh(half) #0 + +declare target("spirv.CooperativeMatrixKHR", i8, 3, 12, 12, 2) @_Z36__builtin_spirv_ConvertFP16ToE4M3EXTPU3AS144__spirv_CooperativeMatrixKHR__half_3_12_12_2(target("spirv.CooperativeMatrixKHR", half, 3, 12, 12, 2)) + +attributes #0 = { nounwind } + +!spirv.MemoryModel = !{!0} +!opencl.enable.FP_CONTRACT = !{} +!spirv.Source = !{!1} +!opencl.spir.version = !{!0} +!opencl.used.extensions = !{!2} +!opencl.used.optional.core.features = !{!2} +!spirv.Generator = !{!3} + +!0 = !{i32 1, i32 2} +!1 = !{i32 0, i32 0} +!2 = !{} +!3 = !{i16 6, i16 14} diff --git a/test/extensions/EXT/SPV_EXT_float8/conversions_scalar_vector.ll b/test/extensions/EXT/SPV_EXT_float8/conversions_scalar_vector.ll new file mode 100644 index 0000000000..b0f08b92cd --- /dev/null +++ b/test/extensions/EXT/SPV_EXT_float8/conversions_scalar_vector.ll @@ -0,0 +1,330 @@ +; This tests checks if FP8 scalar and vector conversions specified by +; __builtin_spirv_* external function calls translated correctly. + +; RUN: llvm-as %s -o %t.bc +; RUN: llvm-spirv %t.bc -o %t.spv --spirv-ext=+SPV_EXT_float8,+SPV_KHR_bfloat16 +; RUN: llvm-spirv %t.spv -o %t.spt --to-text +; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV +; RUN: llvm-spirv %t.spv -o %t.rev.bc -r --spirv-target-env=SPV-IR +; RUN: llvm-dis %t.rev.bc -o %t.rev.ll +; RUN: FileCheck < %t.rev.ll %s --check-prefix=CHECK-LLVM + +; TODO: RUNx: spirv-val + +; CHECK-SPIRV-DAG: Capability Float8EXT + +; CHECK-SPIRV-DAG: Extension "SPV_EXT_float8" + +; CHECK-SPIRV-DAG: Name [[#e4m3_hf16_scalar:]] "e4m3_hf16_scalar" +; CHECK-SPIRV-DAG: Name [[#e4m3_hf16_vector:]] "e4m3_hf16_vector" +; CHECK-SPIRV-DAG: Name [[#e5m2_hf16_scalar:]] "e5m2_hf16_scalar" +; CHECK-SPIRV-DAG: Name [[#e5m2_hf16_vector:]] "e5m2_hf16_vector" +; CHECK-SPIRV-DAG: Name [[#e4m3_bf16_scalar:]] "e4m3_bf16_scalar" +; CHECK-SPIRV-DAG: Name [[#e4m3_bf16_vector:]] "e4m3_bf16_vector" +; CHECK-SPIRV-DAG: Name [[#e5m2_bf16_scalar:]] "e5m2_bf16_scalar" +; CHECK-SPIRV-DAG: Name [[#e5m2_bf16_vector:]] "e5m2_bf16_vector" +; CHECK-SPIRV-DAG: Name [[#hf16_e4m3_scalar:]] "hf16_e4m3_scalar" +; CHECK-SPIRV-DAG: Name [[#hf16_e4m3_vector:]] "hf16_e4m3_vector" +; CHECK-SPIRV-DAG: Name [[#hf16_e5m2_scalar:]] "hf16_e5m2_scalar" +; CHECK-SPIRV-DAG: Name [[#hf16_e5m2_vector:]] "hf16_e5m2_vector" +; CHECK-SPIRV-DAG: Name [[#bf16_e4m3_scalar:]] "bf16_e4m3_scalar" +; CHECK-SPIRV-DAG: Name [[#bf16_e4m3_vector:]] "bf16_e4m3_vector" +; CHECK-SPIRV-DAG: Name [[#bf16_e5m2_scalar:]] "bf16_e5m2_scalar" +; CHECK-SPIRV-DAG: Name [[#bf16_e5m2_vector:]] "bf16_e5m2_vector" + +; CHECK-SPIRV-DAG: TypeInt [[#Int8Ty:]] 8 0 +; CHECK-SPIRV-DAG: TypeVector [[#Int8VecTy:]] [[#Int8Ty]] 8 +; CHECK-SPIRV-DAG: Constant [[#Int8Ty]] [[#Int8Const:]] 1 +; CHECK-SPIRV-DAG: ConstantComposite [[#Int8VecTy]] [[#Int8VecConst:]] [[#Int8Const]] [[#Int8Const]] [[#Int8Const]] [[#Int8Const]] [[#Int8Const]] [[#Int8Const]] [[#Int8Const]] [[#Int8Const]] + +; CHECK-SPIRV-DAG: TypeFloat [[#E4M3Ty:]] 8 4214 +; CHECK-SPIRV-DAG: TypeVector [[#E4M3VecTy:]] [[#E4M3Ty]] 8 + +; CHECK-SPIRV-DAG: TypeFloat [[#E5M2Ty:]] 8 4215 +; CHECK-SPIRV-DAG: TypeVector [[#E5M2VecTy:]] [[#E5M2Ty]] 8 + +; CHECK-SPIRV-DAG: TypeFloat [[#HFloat16Ty:]] 16 {{$}} +; CHECK-SPIRV-DAG: TypeVector [[#HFloat16VecTy:]] [[#HFloat16Ty]] 8 +; CHECK-SPIRV-DAG: Constant [[#HFloat16Ty]] [[#HalfConst:]] 15360 +; CHECK-SPIRV-DAG: ConstantComposite [[#HFloat16VecTy]] [[#HalfVecConst:]] [[#HalfConst]] [[#HalfConst]] [[#HalfConst]] [[#HalfConst]] [[#HalfConst]] [[#HalfConst]] [[#HalfConst]] [[#HalfConst]] + +; CHECK-SPIRV-DAG: TypeFloat [[#BFloat16Ty:]] 16 0 +; CHECK-SPIRV-DAG: TypeVector [[#BFloat16VecTy:]] [[#BFloat16Ty]] 8 +; CHECK-SPIRV-DAG: Constant [[#BFloat16Ty]] [[#BfloatConst:]] 16256 +; CHECK-SPIRV-DAG: ConstantComposite [[#BFloat16VecTy]] [[#BfloatVecConst:]] [[#BfloatConst]] [[#BfloatConst]] [[#BfloatConst]] [[#BfloatConst]] [[#BfloatConst]] [[#BfloatConst]] [[#BfloatConst]] [[#BfloatConst]] + +target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024" +target triple = "spir-unknown-unknown" + +; CHECK-SPIRV: Function [[#]] [[#e4m3_hf16_scalar]] [[#]] +; CHECK-SPIRV: Bitcast [[#E4M3Ty]] [[#Cast1:]] [[#Int8Const]] +; CHECK-SPIRV: FConvert [[#HFloat16Ty]] [[#Conv:]] [[#Cast1]] +; CHECK-SPIRV: ReturnValue [[#Conv]] + +; CHECK-LLVM-LABEL: e4m3_hf16_scalar +; CHECK-LLVM: %[[#Call:]] = call half @_Z36__builtin_spirv_ConvertE4M3ToFP16EXTc(i8 1) +; CHECK-LLVM: ret half %[[#Call]] + +define spir_func half @e4m3_hf16_scalar() { +entry: + %0 = call half @_Z36__builtin_spirv_ConvertE4M3ToFP16EXTc(i8 1) + ret half %0 +} + +declare dso_local spir_func half @_Z36__builtin_spirv_ConvertE4M3ToFP16EXTc(i8) + +; CHECK-SPIRV: Function [[#]] [[#e4m3_hf16_vector]] [[#]] +; CHECK-SPIRV: Bitcast [[#E4M3VecTy]] [[#Cast1:]] [[#Int8VecConst]] +; CHECK-SPIRV: FConvert [[#HFloat16VecTy]] [[#Conv:]] [[#Cast1]] +; CHECK-SPIRV: ReturnValue [[#Conv]] + +; CHECK-LLVM-LABEL: e4m3_hf16_vector +; CHECK-LLVM: %[[#Call:]] = call <8 x half> @_Z36__builtin_spirv_ConvertE4M3ToFP16EXTDv8_c(<8 x i8> splat (i8 1)) +; CHECK-LLVM: ret <8 x half> %[[#Call]] + +define spir_func <8 x half> @e4m3_hf16_vector() { +entry: + %0 = call <8 x half> @_Z36__builtin_spirv_ConvertE4M3ToFP16EXTDv8_i(<8 x i8> ) + ret <8 x half> %0 +} + +declare dso_local spir_func <8 x half> @_Z36__builtin_spirv_ConvertE4M3ToFP16EXTDv8_i(<8 x i8>) + +; CHECK-SPIRV: Function [[#]] [[#e5m2_hf16_scalar]] [[#]] +; CHECK-SPIRV: Bitcast [[#E5M2Ty]] [[#Cast1:]] [[#Int8Const]] +; CHECK-SPIRV: FConvert [[#HFloat16Ty]] [[#Conv:]] [[#Cast1]] +; CHECK-SPIRV: ReturnValue [[#Conv]] + +; CHECK-LLVM-LABEL: e5m2_hf16_scalar +; CHECK-LLVM: %[[#Call:]] = call half @_Z36__builtin_spirv_ConvertE5M2ToFP16EXTc(i8 1) +; CHECK-LLVM: ret half %[[#Call]] + +define spir_func half @e5m2_hf16_scalar() { +entry: + %0 = call half @_Z36__builtin_spirv_ConvertE5M2ToFP16EXTc(i8 1) + ret half %0 +} + +declare dso_local spir_func half @_Z36__builtin_spirv_ConvertE5M2ToFP16EXTc(i8) + +; CHECK-SPIRV: Function [[#]] [[#e5m2_hf16_vector]] [[#]] +; CHECK-SPIRV: Bitcast [[#E5M2VecTy]] [[#Cast1:]] [[#Int8VecConst]] +; CHECK-SPIRV: FConvert [[#HFloat16VecTy]] [[#Conv:]] [[#Cast1]] +; CHECK-SPIRV: ReturnValue [[#Conv]] + +; CHECK-LLVM-LABEL: e5m2_hf16_vector +; CHECK-LLVM: %[[#Call:]] = call <8 x half> @_Z36__builtin_spirv_ConvertE5M2ToFP16EXTDv8_c(<8 x i8> splat (i8 1)) +; CHECK-LLVM: ret <8 x half> %[[#Call]] + +define spir_func <8 x half> @e5m2_hf16_vector() { +entry: + %0 = call <8 x half> @_Z36__builtin_spirv_ConvertE5M2ToFP16EXTDv8_i(<8 x i8> ) + ret <8 x half> %0 +} + +declare dso_local spir_func <8 x half> @_Z36__builtin_spirv_ConvertE5M2ToFP16EXTDv8_i(<8 x i8>) + +; CHECK-SPIRV: Function [[#]] [[#e4m3_bf16_scalar]] [[#]] +; CHECK-SPIRV: Bitcast [[#E4M3Ty]] [[#Cast1:]] [[#Int8Const]] +; CHECK-SPIRV: FConvert [[#BFloat16Ty]] [[#Conv:]] [[#Cast1]] +; CHECK-SPIRV: ReturnValue [[#Conv]] + +; CHECK-LLVM-LABEL: e4m3_bf16_scalar +; CHECK-LLVM: %[[#Call:]] = call bfloat @_Z36__builtin_spirv_ConvertE4M3ToBF16EXTc(i8 1) +; CHECK-LLVM: ret bfloat %[[#Call]] + +define spir_func bfloat @e4m3_bf16_scalar() { +entry: + %0 = call bfloat @_Z36__builtin_spirv_ConvertE4M3ToBF16EXTc(i8 1) + ret bfloat %0 +} + +declare dso_local spir_func bfloat @_Z36__builtin_spirv_ConvertE4M3ToBF16EXTc(i8) + +; CHECK-SPIRV: Function [[#]] [[#e4m3_bf16_vector]] [[#]] +; CHECK-SPIRV: Bitcast [[#E4M3VecTy]] [[#Cast1:]] [[#Int8VecConst]] +; CHECK-SPIRV: FConvert [[#BFloat16VecTy]] [[#Conv:]] [[#Cast1]] +; CHECK-SPIRV: ReturnValue [[#Conv]] + +; CHECK-LLVM-LABEL: e4m3_bf16_vector +; CHECK-LLVM: %[[#Call:]] = call <8 x bfloat> @_Z36__builtin_spirv_ConvertE4M3ToBF16EXTDv8_c(<8 x i8> splat (i8 1)) +; CHECK-LLVM: ret <8 x bfloat> %[[#Call]] + +define spir_func <8 x bfloat> @e4m3_bf16_vector() { +entry: + %0 = call <8 x bfloat> @_Z36__builtin_spirv_ConvertE4M3ToBF16EXTDv8_i(<8 x i8> ) + ret <8 x bfloat> %0 +} + +declare dso_local spir_func <8 x bfloat> @_Z36__builtin_spirv_ConvertE4M3ToBF16EXTDv8_i(<8 x i8>) + +; CHECK-SPIRV: Function [[#]] [[#e5m2_bf16_scalar]] [[#]] +; CHECK-SPIRV: Bitcast [[#E5M2Ty]] [[#Cast1:]] [[#Int8Const]] +; CHECK-SPIRV: FConvert [[#BFloat16Ty]] [[#Conv:]] [[#Cast1]] +; CHECK-SPIRV: ReturnValue [[#Conv]] + +; CHECK-LLVM-LABEL: e5m2_bf16_scalar +; CHECK-LLVM: %[[#Call:]] = call bfloat @_Z36__builtin_spirv_ConvertE5M2ToBF16EXTc(i8 1) +; CHECK-LLVM: ret bfloat %[[#Call]] + +define spir_func bfloat @e5m2_bf16_scalar() { +entry: + %0 = call bfloat @_Z36__builtin_spirv_ConvertE5M2ToBF16EXTc(i8 1) + ret bfloat %0 +} + +declare dso_local spir_func bfloat @_Z36__builtin_spirv_ConvertE5M2ToBF16EXTc(i8) + +; CHECK-SPIRV: Function [[#]] [[#e5m2_bf16_vector]] [[#]] +; CHECK-SPIRV: Bitcast [[#E5M2VecTy]] [[#Cast1:]] [[#Int8VecConst]] +; CHECK-SPIRV: FConvert [[#BFloat16VecTy]] [[#Conv:]] [[#Cast1]] +; CHECK-SPIRV: ReturnValue [[#Conv]] + +; CHECK-LLVM-LABEL: e5m2_bf16_vector +; CHECK-LLVM: %[[#Call:]] = call <8 x bfloat> @_Z36__builtin_spirv_ConvertE5M2ToBF16EXTDv8_c(<8 x i8> splat (i8 1)) +; CHECK-LLVM: ret <8 x bfloat> %[[#Call]] + +define spir_func <8 x bfloat> @e5m2_bf16_vector() { +entry: + %0 = call <8 x bfloat> @_Z36__builtin_spirv_ConvertE5M2ToBF16EXTDv8_i(<8 x i8> ) + ret <8 x bfloat> %0 +} + +declare dso_local spir_func <8 x bfloat> @_Z36__builtin_spirv_ConvertE5M2ToBF16EXTDv8_i(<8 x i8>) + +; CHECK-SPIRV: Function [[#]] [[#hf16_e4m3_scalar]] [[#]] +; CHECK-SPIRV: FConvert [[#E4M3Ty]] [[#Conv:]] [[#HalfConst]] +; CHECK-SPIRV: Bitcast [[#Int8Ty]] [[#Cast1:]] [[#Conv]] +; CHECK-SPIRV: ReturnValue [[#Cast1]] + +; CHECK-LLVM-LABEL: hf16_e4m3_scalar +; CHECK-LLVM: %[[#Call:]] = call i8 @_Z36__builtin_spirv_ConvertFP16ToE4M3EXTDh(half 0xH3C00) +; CHECK-LLVM: ret i8 %[[#Call]] + +define spir_func i8 @hf16_e4m3_scalar() { +entry: + %0 = call i8 @_Z36__builtin_spirv_ConvertFP16ToE4M3EXTDh(half 0xH3C00) + ret i8 %0 +} + +declare dso_local spir_func i8 @_Z36__builtin_spirv_ConvertFP16ToE4M3EXTDh(half) + +; CHECK-SPIRV: Function [[#]] [[#hf16_e4m3_vector]] [[#]] +; CHECK-SPIRV: FConvert [[#E4M3VecTy]] [[#Conv:]] [[#HalfVecConst]] +; CHECK-SPIRV: Bitcast [[#Int8VecTy]] [[#Cast1:]] [[#Conv]] +; CHECK-SPIRV: ReturnValue [[#Cast1]] + +; CHECK-LLVM-LABEL: hf16_e4m3_vector +; CHECK-LLVM: %[[#Call:]] = call <8 x i8> @_Z36__builtin_spirv_ConvertFP16ToE4M3EXTDv8_Dh(<8 x half> splat (half 0xH3C00)) +; CHECK-LLVM: ret <8 x i8> %[[#Call]] + +define spir_func <8 x i8> @hf16_e4m3_vector() { +entry: + %0 = call <8 x i8> @_Z36__builtin_spirv_ConvertFP16ToE4M3EXTDv8_Dh(<8 x half> ) + ret <8 x i8> %0 +} + +declare dso_local spir_func <8 x i8> @_Z36__builtin_spirv_ConvertFP16ToE4M3EXTDv8_Dh(<8 x half>) + +; CHECK-SPIRV: Function [[#]] [[#hf16_e5m2_scalar]] [[#]] +; CHECK-SPIRV: FConvert [[#E5M2Ty]] [[#Conv:]] [[#HalfConst]] +; CHECK-SPIRV: Bitcast [[#Int8Ty]] [[#Cast1:]] [[#Conv]] +; CHECK-SPIRV: ReturnValue [[#Cast1]] + +; CHECK-LLVM-LABEL: hf16_e5m2_scalar +; CHECK-LLVM: %[[#Call:]] = call i8 @_Z36__builtin_spirv_ConvertFP16ToE5M2EXTDh(half 0xH3C00) +; CHECK-LLVM: ret i8 %[[#Call]] + +define spir_func i8 @hf16_e5m2_scalar() { +entry: + %0 = call i8 @_Z36__builtin_spirv_ConvertFP16ToE5M2EXTDh(half 0xH3C00) + ret i8 %0 +} + +declare dso_local spir_func i8 @_Z36__builtin_spirv_ConvertFP16ToE5M2EXTDh(half) + +; CHECK-SPIRV: Function [[#]] [[#hf16_e5m2_vector]] [[#]] +; CHECK-SPIRV: FConvert [[#E5M2VecTy]] [[#Conv:]] [[#HalfVecConst]] +; CHECK-SPIRV: Bitcast [[#Int8VecTy]] [[#Cast1:]] [[#Conv]] +; CHECK-SPIRV: ReturnValue [[#Cast1]] + +; CHECK-LLVM-LABEL: hf16_e5m2_vector +; CHECK-LLVM: %[[#Call:]] = call <8 x i8> @_Z36__builtin_spirv_ConvertFP16ToE5M2EXTDv8_Dh(<8 x half> splat (half 0xH3C00)) +; CHECK-LLVM: ret <8 x i8> %[[#Call]] + +define spir_func <8 x i8> @hf16_e5m2_vector() { +entry: + %0 = call <8 x i8> @_Z36__builtin_spirv_ConvertFP16ToE5M2EXTDv8_Dh(<8 x half> ) + ret <8 x i8> %0 +} + +declare dso_local spir_func <8 x i8> @_Z36__builtin_spirv_ConvertFP16ToE5M2EXTDv8_Dh(<8 x half>) + +; CHECK-SPIRV: Function [[#]] [[#bf16_e4m3_scalar]] [[#]] +; CHECK-SPIRV: FConvert [[#E4M3Ty]] [[#Conv:]] [[#BfloatConst]] +; CHECK-SPIRV: Bitcast [[#Int8Ty]] [[#Cast1:]] [[#Conv]] +; CHECK-SPIRV: ReturnValue [[#Cast1]] + +; CHECK-LLVM-LABEL: bf16_e4m3_scalar +; CHECK-LLVM: %[[#Call:]] = call i8 @_Z36__builtin_spirv_ConvertBF16ToE4M3EXTDF16b(bfloat 0xR3F80) +; CHECK-LLVM: ret i8 %[[#Call]] + +define spir_func i8 @bf16_e4m3_scalar() { +entry: + %0 = call i8 @_Z36__builtin_spirv_ConvertBF16ToE4M3EXTDF16b(bfloat 0xR3F80) + ret i8 %0 +} + +declare dso_local spir_func i8 @_Z36__builtin_spirv_ConvertBF16ToE4M3EXTDF16b(bfloat) + +; CHECK-SPIRV: Function [[#]] [[#bf16_e4m3_vector]] [[#]] +; CHECK-SPIRV: FConvert [[#E4M3VecTy]] [[#Conv:]] [[#BfloatVecConst]] +; CHECK-SPIRV: Bitcast [[#Int8VecTy]] [[#Cast1:]] [[#Conv]] +; CHECK-SPIRV: ReturnValue [[#Cast1]] + +; CHECK-LLVM-LABEL: bf16_e4m3_vector +; CHECK-LLVM: %[[#Call:]] = call <8 x i8> @_Z36__builtin_spirv_ConvertBF16ToE4M3EXTDv8_DF16b(<8 x bfloat> splat (bfloat 0xR3F80)) +; CHECK-LLVM: ret <8 x i8> %[[#Call]] + +define spir_func <8 x i8> @bf16_e4m3_vector() { +entry: + %0 = call <8 x i8> @_Z36__builtin_spirv_ConvertBF16ToE4M3EXTDv8_DF16b(<8 x bfloat> ) + ret <8 x i8> %0 +} + +declare dso_local spir_func <8 x i8> @_Z36__builtin_spirv_ConvertBF16ToE4M3EXTDv8_DF16b(<8 x bfloat>) + +; CHECK-SPIRV: Function [[#]] [[#bf16_e5m2_scalar]] [[#]] +; CHECK-SPIRV: FConvert [[#E5M2Ty]] [[#Conv:]] [[#BfloatConst]] +; CHECK-SPIRV: Bitcast [[#Int8Ty]] [[#Cast1:]] [[#Conv]] +; CHECK-SPIRV: ReturnValue [[#Cast1]] + +; CHECK-LLVM-LABEL: bf16_e5m2_scalar +; CHECK-LLVM: %[[#Call:]] = call i8 @_Z36__builtin_spirv_ConvertBF16ToE5M2EXTDF16b(bfloat 0xR3F80) +; CHECK-LLVM: ret i8 %[[#Call]] + +define spir_func i8 @bf16_e5m2_scalar() { +entry: + %0 = call i8 @_Z36__builtin_spirv_ConvertBF16ToE5M2EXTDF16b(bfloat 0xR3F80) + ret i8 %0 +} + +declare dso_local spir_func i8 @_Z36__builtin_spirv_ConvertBF16ToE5M2EXTDF16b(bfloat) + +; CHECK-SPIRV: Function [[#]] [[#bf16_e5m2_vector]] [[#]] +; CHECK-SPIRV: FConvert [[#E5M2VecTy]] [[#Conv:]] [[#BfloatVecConst]] +; CHECK-SPIRV: Bitcast [[#Int8VecTy]] [[#Cast1:]] [[#Conv]] +; CHECK-SPIRV: ReturnValue [[#Cast1]] + +; CHECK-LLVM-LABEL: bf16_e5m2_vector +; CHECK-LLVM: %[[#Call:]] = call <8 x i8> @_Z36__builtin_spirv_ConvertBF16ToE5M2EXTDv8_DF16b(<8 x bfloat> splat (bfloat 0xR3F80)) +; CHECK-LLVM: ret <8 x i8> %[[#Call]] + +define spir_func <8 x i8> @bf16_e5m2_vector() { +entry: + %0 = call <8 x i8> @_Z36__builtin_spirv_ConvertBF16ToE5M2EXTDv8_DF16b(<8 x bfloat> ) + ret <8 x i8> %0 +} + +declare dso_local spir_func <8 x i8> @_Z36__builtin_spirv_ConvertBF16ToE5M2EXTDv8_DF16b(<8 x bfloat>) + diff --git a/test/extensions/EXT/SPV_EXT_float8/misc/check_spirv_builtin.ll b/test/extensions/EXT/SPV_EXT_float8/misc/check_spirv_builtin.ll new file mode 100644 index 0000000000..86cf31b1ea --- /dev/null +++ b/test/extensions/EXT/SPV_EXT_float8/misc/check_spirv_builtin.ll @@ -0,0 +1,25 @@ +; This test checks, that function with __builtin_spirv placed in the middle of +; the name is not translated as internal builtin. + +; RUN: llvm-as %s -o %t.bc +; RUN: llvm-spirv %t.bc -o %t.spv +; RUN: llvm-spirv %t.spv -o %t.spt --to-text +; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV +; RUN: llvm-spirv %t.spv -o %t.rev.bc -r --spirv-target-env=SPV-IR +; RUN: llvm-dis %t.rev.bc -o %t.rev.ll +; RUN: FileCheck < %t.rev.ll %s --check-prefix=CHECK-LLVM + +; CHECK-SPIRV: _Z19boo__builtin_spirv_fs +; CHECK-LLVM: _Z19boo__builtin_spirv_fs + +target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024" +target triple = "spir-unknown-unknown" + +; Function Attrs: nounwind readnone +define spir_func void @foo() { +entry: + %0 = call spir_func half @_Z19boo__builtin_spirv_fs(float 1.0, i16 4) + ret void +} + +declare dso_local spir_func half @_Z19boo__builtin_spirv_fs(float, i16) diff --git a/test/extensions/EXT/SPV_EXT_float8/negative/no-enabled-extensions.ll b/test/extensions/EXT/SPV_EXT_float8/negative/no-enabled-extensions.ll new file mode 100644 index 0000000000..dd9ecb899c --- /dev/null +++ b/test/extensions/EXT/SPV_EXT_float8/negative/no-enabled-extensions.ll @@ -0,0 +1,19 @@ +; RUN: llvm-as %s -o %t.bc +; RUN: not llvm-spirv %t.bc 2>&1 \ +; RUN: | FileCheck %s --check-prefix=CHECK-ERROR + +; CHECK-ERROR: RequiresExtension: Feature requires the following SPIR-V extension: +; CHECK-ERROR-NEXT: One of the following extensions: SPV_EXT_float8, +; CHECK-ERROR-SAME: SPV_INTEL_int4 should be enabled to process conversion builtins +; CHECK-ERROR-NEXT: declare dso_local spir_func i8 @_Z36__builtin_spirv_ConvertFP16ToE4M3EXTDh(half) + +target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024" +target triple = "spir-unknown-unknown" + +define spir_func i8 @fp16_hf8() { +entry: + %0 = call spir_func i8 @_Z36__builtin_spirv_ConvertFP16ToE4M3EXTDh(half 0.0) + ret i8 %0 +} + +declare dso_local spir_func i8 @_Z36__builtin_spirv_ConvertFP16ToE4M3EXTDh(half) diff --git a/test/extensions/INTEL/SPV_INTEL_int4/conversions_packed.ll b/test/extensions/INTEL/SPV_INTEL_int4/conversions_packed.ll new file mode 100644 index 0000000000..653d63e08a --- /dev/null +++ b/test/extensions/INTEL/SPV_INTEL_int4/conversions_packed.ll @@ -0,0 +1,138 @@ +; This test checks if Int4 packed conversions specified by +; __builtin_spirv_* external function calls are translated correctly. +; Not all of the instructions are tested here, only one per the following +; case: +; 1. from packed Int4 to ... : +; a. packed in 32-bit +; b. packed in 8-bit +; 2. to packed Int4 from ... : +; a. packed in 32-bit +; b. packed in 8-bit + +; RUN: llvm-as %s -o %t.bc +; RUN: llvm-spirv %t.bc -o %t.spv --spirv-ext=+SPV_EXT_float8,+SPV_INTEL_int4,+SPV_KHR_bfloat16 +; RUN: llvm-spirv %t.spv -o %t.spt --to-text +; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV +; RUN: llvm-spirv %t.spv -o %t.rev.bc -r --spirv-target-env=SPV-IR +; RUN: llvm-dis %t.rev.bc -o %t.rev.ll +; RUN: FileCheck < %t.rev.ll %s --check-prefix=CHECK-LLVM + +; TODO: RUNx: spirv-val + +; CHECK-SPIRV-DAG: Capability Float8EXT +; CHECK-SPIRV-DAG: Capability Int4TypeINTEL +; CHECK-SPIRV-DAG: Extension "SPV_EXT_float8" +; CHECK-SPIRV-DAG: Extension "SPV_INTEL_int4" + +; CHECK-SPIRV-DAG: Name [[#int4_e4m3_32:]] "int4_e4m3_32" +; CHECK-SPIRV-DAG: Name [[#int4_e4m3_8:]] "int4_e4m3_8" +; CHECK-SPIRV-DAG: Name [[#hf16_int4_32:]] "hf16_int4_32" +; CHECK-SPIRV-DAG: Name [[#hf16_int4_8:]] "hf16_int4_8" + +; CHECK-SPIRV-DAG: TypeInt [[#Int32Ty:]] 32 0 +; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Int32Const:]] 1 + +; CHECK-SPIRV-DAG: TypeInt [[#Int8Ty:]] 8 0 +; CHECK-SPIRV-DAG: TypeVector [[#Int8Vec8Ty:]] [[#Int8Ty]] 8 +; CHECK-SPIRV-DAG: TypeVector [[#Int8Vec2Ty:]] [[#Int8Ty]] 2 +; CHECK-SPIRV-DAG: Constant [[#Int8Ty]] [[#Int8Const:]] 1 + +; CHECK-SPIRV-DAG: TypeInt [[#Int4Ty:]] 4 0 +; CHECK-SPIRV-DAG: TypeVector [[#Int4Vec8Ty:]] [[#Int4Ty]] 8 +; CHECK-SPIRV-DAG: TypeVector [[#Int4Vec2Ty:]] [[#Int4Ty]] 2 + +; CHECK-SPIRV-DAG: TypeFloat [[#Float8E4M3Ty:]] 8 4214 +; CHECK-SPIRV-DAG: TypeVector [[#Float8E4M3Vec8Ty:]] [[#Float8E4M3Ty]] 8 +; CHECK-SPIRV-DAG: TypeVector [[#Float8E4M3Vec2Ty:]] [[#Float8E4M3Ty]] 2 + +; CHECK-SPIRV-DAG: TypeFloat [[#HFloat16Ty:]] 16 {{$}} +; CHECK-SPIRV-DAG: TypeVector [[#HFloat16Vec8Ty:]] [[#HFloat16Ty]] 8 +; CHECK-SPIRV-DAG: TypeVector [[#HFloat16Vec2Ty:]] [[#HFloat16Ty]] 2 +; CHECK-SPIRV-DAG: Constant [[#HFloat16Ty]] [[#HFloat16Const:]] 15360 +; CHECK-SPIRV-DAG: ConstantComposite [[#HFloat16Vec8Ty]] [[#HFloat16Vec8Const:]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] +; CHECK-SPIRV-DAG: ConstantComposite [[#HFloat16Vec2Ty]] [[#HFloat16Vec2Const:]] [[#HFloat16Const]] [[#HFloat16Const]] + +target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024" +target triple = "spir-unknown-unknown" + +; Packed in 32-bit integer + +; CHECK-SPIRV: Function [[#]] [[#int4_e4m3_32]] [[#]] +; CHECK-SPIRV: Bitcast [[#Int4Vec8Ty]] [[#Cast1:]] [[#Int32Const]] +; CHECK-SPIRV: ConvertSToF [[#Float8E4M3Vec8Ty]] [[#Conv:]] [[#Const1:]] +; CHECK-SPIRV: Bitcast [[#Int8Vec8Ty]] [[#Cast2:]] [[#Conv]] +; CHECK-SPIRV: ReturnValue [[#Cast2]] + +; CHECK-LLVM-LABEL: int4_e4m3_32 +; CHECK-LLVM: %[[#Cast:]] = bitcast i32 1 to <8 x i4> +; CHECK-LLVM: %[[#Conv:]] = call <8 x i8> @_Z38__builtin_spirv_ConvertInt4ToE4M3INTELDv8_i(<8 x i4> %[[#Cast]]) +; CHECK-LLVM: ret <8 x i8> %[[#Conv]] + +; Function Attrs: nounwind readnone +define spir_func <8 x i8> @int4_e4m3_32() { +entry: + %0 = call spir_func <8 x i8> @_Z38__builtin_spirv_ConvertInt4ToE4M3INTELi(i32 1) + ret <8 x i8> %0 +} + +declare dso_local spir_func <8 x i8> @_Z38__builtin_spirv_ConvertInt4ToE4M3INTELi(i32) + +; CHECK-SPIRV: Function [[#]] [[#hf16_int4_32]] [[#]] +; CHECK-SPIRV: ConvertFToS [[#Int4Vec8Ty]] [[#Conv:]] [[#HFloat16Vec8Const:]] +; CHECK-SPIRV: Bitcast [[#Int32Ty]] [[#Cast2:]] [[#Conv]] +; CHECK-SPIRV: ReturnValue [[#Cast2]] + +; CHECK-LLVM-LABEL: hf16_int4_32 +; CHECK-LLVM: %[[#Conv:]] = call <8 x i4> @_Z38__builtin_spirv_ConvertFP16ToInt4INTELDv8_Dh(<8 x half> splat (half 0xH3C00)) +; CHECK-LLVM: [[#Cast:]] = bitcast <8 x i4> %[[#Conv]] to i32 +; CHECK-LLVM: ret i32 %[[#Cast]] + +; Function Attrs: nounwind readnone +define spir_func i32 @hf16_int4_32() { +entry: + %0 = call spir_func i32 @_Z38__builtin_spirv_ConvertFP16ToInt4INTELi(<8 x half> ) + ret i32 %0 +} + +declare dso_local spir_func i32 @_Z38__builtin_spirv_ConvertFP16ToInt4INTELi(<8 x half>) + +; Packed in 8-bit integer + +; CHECK-SPIRV: Function [[#]] [[#int4_e4m3_8]] [[#]] +; CHECK-SPIRV: Bitcast [[#Int4Vec2Ty]] [[#Cast1:]] [[#Int8Const]] +; CHECK-SPIRV: ConvertSToF [[#Float8E4M3Vec2Ty]] [[#Conv:]] [[#Const1:]] +; CHECK-SPIRV: Bitcast [[#Int8Vec2Ty]] [[#Cast2:]] [[#Conv]] +; CHECK-SPIRV: ReturnValue [[#Cast2]] + +; CHECK-LLVM-LABEL: int4_e4m3_8 +; CHECK-LLVM: %[[#Cast:]] = bitcast i8 1 to <2 x i4> +; CHECK-LLVM: %[[#Conv:]] = call <2 x i8> @_Z38__builtin_spirv_ConvertInt4ToE4M3INTELDv2_i(<2 x i4> %[[#Cast]]) +; CHECK-LLVM: ret <2 x i8> %[[#Conv]] + +; Function Attrs: nounwind readnone +define spir_func <2 x i8> @int4_e4m3_8() { +entry: + %0 = call spir_func <2 x i8> @_Z38__builtin_spirv_ConvertInt4ToE4M3INTELc(i8 1) + ret <2 x i8> %0 +} + +declare dso_local spir_func <2 x i8> @_Z38__builtin_spirv_ConvertInt4ToE4M3INTELc(i8) + +; CHECK-SPIRV: Function [[#]] [[#hf16_int4_8]] [[#]] +; CHECK-SPIRV: ConvertFToS [[#Int4Vec2Ty]] [[#Conv:]] [[#HFloat16Vec2Const:]] +; CHECK-SPIRV: Bitcast [[#Int8Ty]] [[#Cast2:]] [[#Conv]] +; CHECK-SPIRV: ReturnValue [[#Cast2]] + +; CHECK-LLVM-LABEL: hf16_int4_8 +; CHECK-LLVM: %[[#Conv:]] = call <2 x i4> @_Z38__builtin_spirv_ConvertFP16ToInt4INTELDv2_Dh(<2 x half> splat (half 0xH3C00)) +; CHECK-LLVM: [[#Cast:]] = bitcast <2 x i4> %[[#Conv]] to i8 +; CHECK-LLVM: ret i8 %[[#Cast]] + +; Function Attrs: nounwind readnone +define spir_func i8 @hf16_int4_8() { +entry: + %0 = call spir_func i8 @_Z38__builtin_spirv_ConvertFP16ToInt4INTELc(<2 x half> ) + ret i8 %0 +} + +declare dso_local spir_func i8 @_Z38__builtin_spirv_ConvertFP16ToInt4INTELc(<2 x half>)