Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/LLVMSPIRVExtensions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,5 @@ EXT(SPV_INTEL_ternary_bitwise_function)
EXT(SPV_INTEL_int4)
EXT(SPV_INTEL_function_variants)
EXT(SPV_INTEL_shader_atomic_bfloat16)
EXT(SPV_EXT_float8)
EXT(SPV_INTEL_predicated_io)
84 changes: 83 additions & 1 deletion lib/SPIRV/SPIRVInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ const static char ConvertHandleToImageINTEL[] = "ConvertHandleToImageINTEL";
const static char ConvertHandleToSamplerINTEL[] = "ConvertHandleToSamplerINTEL";
const static char ConvertHandleToSampledImageINTEL[] =
"ConvertHandleToSampledImageINTEL";
const static char InternalBuiltinPrefix[] = "__builtin_spirv_";
} // namespace kSPIRVName

namespace kSPIRVPostfix {
Expand Down Expand Up @@ -665,7 +666,7 @@ Op getSPIRVFuncOC(StringRef Name, SmallVectorImpl<std::string> *Dec = nullptr);
bool getSPIRVBuiltin(const std::string &Name, spv::BuiltIn &Builtin);

/// \param Name LLVM function name
/// \param DemangledName demanged name of the OpenCL built-in function
/// \param DemangledName demangled name of the OpenCL built-in function
/// \returns true if Name is the name of the OpenCL built-in function,
/// false for other functions
bool oclIsBuiltin(StringRef Name, StringRef &DemangledName, bool IsCpp = false);
Expand Down Expand Up @@ -728,6 +729,9 @@ CallInst *addCallInst(Module *M, StringRef FuncName, Type *RetTy,
StringRef InstName = SPIR_TEMP_NAME_PREFIX_CALL,
bool TakeFuncName = true);

/// Check if an LLVM type is spirv.CooperativeMatrixKHR.
bool isLLVMCooperativeMatrixType(llvm::Type *Ty);

/// Add a call instruction for SPIR-V builtin function.
CallInst *addCallInstSPIRV(Module *M, StringRef FuncName, Type *RetTy,
ArrayRef<Value *> Args, AttributeList *Attrs,
Expand Down Expand Up @@ -1029,6 +1033,84 @@ bool postProcessBuiltinsReturningStruct(Module *M, bool IsCpp = false);

bool postProcessBuiltinsWithArrayArguments(Module *M, bool IsCpp = false);

/// \param MangledName LLVM function name.
/// \param DemangledName demangled name of the input function if it is the
/// translator's internal built-in function.
/// \returns true if MangledName is the name of the translator's internal
/// built-in function, false for other functions.
/// Used for 'mini'-floats conversion functions
bool isInternalSPIRVBuiltin(StringRef MangledName, StringRef &DemangledName);

// Wrapper around SPIR-V 1.6.4 FP Encoding to be used in the conversion
// descriptor
enum FPEncodingWrap {
Integer = FPEncoding::FPEncodingMax - 1,
IEEE754 = FPEncoding::FPEncodingMax,
BF16 = FPEncoding::FPEncodingBFloat16KHR,
E4M3 = FPEncoding::FPEncodingFloat8E4M3EXT,
E5M2 = FPEncoding::FPEncodingFloat8E5M2EXT,
};

// Structure describing non-trivial conversions (FP8 and int4)
struct FPConversionDesc {
FPEncodingWrap SrcEncoding;
FPEncodingWrap DstEncoding;
SPIRVWord ConvOpCode;

// To use as a key in std::map
bool operator==(const FPConversionDesc &Other) const {
return SrcEncoding == Other.SrcEncoding &&
DstEncoding == Other.DstEncoding && ConvOpCode == Other.ConvOpCode;
}

bool operator<(const FPConversionDesc &Other) const {
if (ConvOpCode != Other.ConvOpCode)
return ConvOpCode < Other.ConvOpCode;
if (SrcEncoding != Other.SrcEncoding)
return SrcEncoding < Other.SrcEncoding;
return DstEncoding < Other.DstEncoding;
}
};

// Maps internal builtin name to conversion descriptor
typedef SPIRVMap<llvm::StringRef, FPConversionDesc> FPConvertToEncodingMap;

// clang-format off
template <> inline void FPConvertToEncodingMap::init() {
// 8-bit conversions
add("ConvertE4M3ToFP16EXT",
{FPEncodingWrap::E4M3, FPEncodingWrap::IEEE754, OpFConvert});
add("ConvertE5M2ToFP16EXT",
{FPEncodingWrap::E5M2, FPEncodingWrap::IEEE754, OpFConvert});
add("ConvertE4M3ToBF16EXT",
{FPEncodingWrap::E4M3, FPEncodingWrap::BF16, OpFConvert});
add("ConvertE5M2ToBF16EXT",
{FPEncodingWrap::E5M2, FPEncodingWrap::BF16, OpFConvert});
add("ConvertFP16ToE4M3EXT",
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3, OpFConvert});
add("ConvertFP16ToE5M2EXT",
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2, OpFConvert});
add("ConvertBF16ToE4M3EXT",
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3, OpFConvert});
add("ConvertBF16ToE5M2EXT",
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2, OpFConvert});

add("ConvertInt4ToE4M3INTEL",
{FPEncodingWrap::Integer, FPEncodingWrap::E4M3, OpConvertSToF});
add("ConvertInt4ToE5M2INTEL",
{FPEncodingWrap::Integer, FPEncodingWrap::E5M2, OpConvertSToF});
add("ConvertInt4ToFP16INTEL",
{FPEncodingWrap::Integer, FPEncodingWrap::IEEE754, OpConvertSToF});
add("ConvertInt4ToBF16INTEL",
{FPEncodingWrap::Integer, FPEncodingWrap::BF16, OpConvertSToF});
add("ConvertFP16ToInt4INTEL",
{FPEncodingWrap::IEEE754, FPEncodingWrap::Integer, OpConvertFToS});
add("ConvertBF16ToInt4INTEL",
{FPEncodingWrap::BF16, FPEncodingWrap::Integer, OpConvertFToS});
}

// clang-format on

} // namespace SPIRV

#endif // SPIRV_SPIRVINTERNAL_H
95 changes: 90 additions & 5 deletions lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,9 @@ std::optional<uint64_t> SPIRVToLLVM::getAlignment(SPIRVValue *V) {

Type *SPIRVToLLVM::transFPType(SPIRVType *T) {
switch (T->getFloatBitWidth()) {
case 8:
// No LLVM IR counter part for FP8 - map it on i8
return Type::getIntNTy(*Context, 8);
case 16:
if (T->isTypeFloat(16, FPEncodingBFloat16KHR))
return Type::getBFloatTy(*Context);
Expand Down Expand Up @@ -1049,6 +1052,22 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
CastInst::CastOps CO = Instruction::BitCast;
bool IsExt =
Dst->getScalarSizeInBits() > Src->getType()->getScalarSizeInBits();

auto GetFPEncoding = [](SPIRVType *Ty) -> FPEncodingWrap {
if (Ty->isTypeFloat()) {
unsigned Enc =
static_cast<SPIRVTypeFloat *>(Ty)->getFloatingPointEncoding();
return static_cast<FPEncodingWrap>(Enc);
}
if (Ty->isTypeInt())
return FPEncodingWrap::Integer;
return FPEncodingWrap::IEEE754;
};

auto IsFP8Encoding = [](FPEncodingWrap Encoding) -> bool {
return Encoding == FPEncodingWrap::E4M3 || Encoding == FPEncodingWrap::E5M2;
};

switch (BC->getOpCode()) {
case OpPtrCastToGeneric:
case OpGenericCastToPtr:
Expand All @@ -1070,10 +1089,58 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
case OpUConvert:
CO = IsExt ? Instruction::ZExt : Instruction::Trunc;
break;
case OpFConvert:
CO = IsExt ? Instruction::FPExt : Instruction::FPTrunc;
case OpConvertSToF:
case OpConvertFToS:
case OpConvertUToF:
case OpConvertFToU:
case OpFConvert: {
const auto OC = BC->getOpCode();
{
auto SPVOps = BC->getOperands();
auto *SPVSrcTy = SPVOps[0]->getType();
auto *SPVDstTy = BC->getType();

auto GetEncodingAndUpdateType =
[GetFPEncoding](SPIRVType *&SPVTy) -> FPEncodingWrap {
if (SPVTy->isTypeVector()) {
SPVTy = SPVTy->getVectorComponentType();
} else if (SPVTy->isTypeCooperativeMatrixKHR()) {
auto *MT = static_cast<SPIRVTypeCooperativeMatrixKHR *>(SPVTy);
SPVTy = MT->getCompType();
}
return GetFPEncoding(SPVTy);
};

FPEncodingWrap SrcEnc = GetEncodingAndUpdateType(SPVSrcTy);
FPEncodingWrap DstEnc = GetEncodingAndUpdateType(SPVDstTy);
if (IsFP8Encoding(SrcEnc) || IsFP8Encoding(DstEnc) ||
SPVSrcTy->isTypeInt(4) || SPVDstTy->isTypeInt(4)) {
FPConversionDesc FPDesc = {SrcEnc, DstEnc, BC->getOpCode()};
auto Conv = SPIRV::FPConvertToEncodingMap::rmap(FPDesc);
std::vector<Value *> Ops = {Src};
std::vector<Type *> OpsTys = {Src->getType()};

std::string BuiltinName =
kSPIRVName::InternalBuiltinPrefix + std::string(Conv);
BuiltinFuncMangleInfo Info;
std::string MangledName = mangleBuiltin(BuiltinName, OpsTys, &Info);

FunctionType *FTy = FunctionType::get(Dst, OpsTys, false);
FunctionCallee Func = M->getOrInsertFunction(MangledName, FTy);
return CallInst::Create(Func, Ops, "", BB);
}
}

if (OC == OpFConvert) {
CO = IsExt ? Instruction::FPExt : Instruction::FPTrunc;
break;
}
CO = static_cast<CastInst::CastOps>(OpCodeMap::rmap(OC));
break;
}
case OpBitcast:
if (!Dst->isPointerTy() && Dst == Src->getType())
return Src;
// OpBitcast need to be handled as a special-case when the source is a
// pointer and the destination is not a pointer, and where the source is not
// a pointer and the destination is a pointer. This is supported by the
Expand Down Expand Up @@ -2970,11 +3037,29 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
if (isCvtOpCode(OC) && OC != OpGenericCastToPtrExplicit) {
auto *BI = static_cast<SPIRVInstruction *>(BV);
Value *Inst = nullptr;
if (BI->hasFPRoundingMode() || BI->isSaturatedConversion() ||
BI->getType()->isTypeCooperativeMatrixKHR())
if (BI->hasFPRoundingMode() || BI->isSaturatedConversion()) {
Inst = transSPIRVBuiltinFromInst(BI, BB);
else
} else if (BI->getType()->isTypeCooperativeMatrixKHR()) {
// For cooperative matrix conversions generate __builtin_spirv
// conversions instead of __spirv_FConvert in case of mini-float
// type element type.
auto *OutMatrixElementTy =
static_cast<SPIRVTypeCooperativeMatrixKHR *>(BI->getType())
->getCompType();
auto *InMatrixElementTy =
static_cast<SPIRVTypeCooperativeMatrixKHR *>(
static_cast<SPIRVUnary *>(BI)->getOperand(0)->getType())
->getCompType();
if (OutMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E4M3EXT) ||
OutMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E5M2EXT) ||
InMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E4M3EXT) ||
InMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E5M2EXT))
Inst = transConvertInst(BV, F, BB);
else
Inst = transSPIRVBuiltinFromInst(BI, BB);
} else {
Inst = transConvertInst(BV, F, BB);
}
return mapValue(BV, Inst);
}
return mapValue(
Expand Down
24 changes: 23 additions & 1 deletion lib/SPIRV/SPIRVUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

// This file needs to be included before anything that declares
// llvm::PointerType to avoid a compilation bug on MSVC.
#include "llvm/Demangle/Demangle.h"
#include "llvm/Demangle/ItaniumDemangle.h"

#include "FunctionDescriptor.h"
Expand Down Expand Up @@ -267,6 +268,12 @@ bool isSYCLBfloat16Type(llvm::Type *Ty) {
return false;
}

bool isLLVMCooperativeMatrixType(llvm::Type *Ty) {
if (auto *TargetTy = dyn_cast<TargetExtType>(Ty))
return TargetTy->getName() == "spirv.CooperativeMatrixKHR";
return false;
}

Function *getOrCreateFunction(Module *M, Type *RetTy, ArrayRef<Type *> ArgTypes,
StringRef Name, BuiltinFuncMangleInfo *Mangle,
AttributeList *Attrs, bool TakeName) {
Expand Down Expand Up @@ -439,7 +446,7 @@ bool getSPIRVBuiltin(const std::string &OrigName, spv::BuiltIn &B) {
return getByName(R.str(), B);
}

// Demangled name is a substring of the name. The DemangledName is updated only
// DemangledName is a substring of Name. The DemangledName is updated only
// if true is returned
bool oclIsBuiltin(StringRef Name, StringRef &DemangledName, bool IsCpp) {
if (Name == "printf") {
Expand Down Expand Up @@ -484,6 +491,21 @@ bool oclIsBuiltin(StringRef Name, StringRef &DemangledName, bool IsCpp) {
return false;
}

// DemangledName is a substring of Name. The DemangledName is updated only
// if true is returned.
bool isInternalSPIRVBuiltin(StringRef Name, StringRef &DemangledName) {
if (!Name.starts_with("_Z"))
return false;
constexpr unsigned DemangledNameLenStart = 2;
size_t Start = Name.find_first_not_of("0123456789", DemangledNameLenStart);
if (!Name.substr(Start, Name.size() - 1)
.starts_with(kSPIRVName::InternalBuiltinPrefix))
return false;
DemangledName = llvm::itaniumDemangle(Name.data(), false);
DemangledName.consume_front(kSPIRVName::InternalBuiltinPrefix);
return true;
}

// Check if a mangled type Name is unsigned
bool isMangledTypeUnsigned(char Mangled) {
return Mangled == 'h' /* uchar */
Expand Down
Loading
Loading