Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
92472c8
Implement extension SPV_KHR_float_controls2
jmmartinez Nov 18, 2025
8718010
[Review] set FPFastMathDefault equal to 0 for every entry-point
jmmartinez Dec 22, 2025
4530695
[Review] Add test showing reassoc->AllowTransform->reassoc contract
jmmartinez Dec 22, 2025
c079c61
[Review] add fcmp and ExtInst tests
jmmartinez Dec 22, 2025
4ad2388
Ignore FPFastMathMode decorations attached as metadata
jmmartinez Dec 22, 2025
515b153
Add test with multiple floating point types
jmmartinez Dec 22, 2025
8a4f37f
[Review] reword FPFastMathMode decoration metadata comment
jmmartinez Jan 5, 2026
5c36ab5
[Review] fix typo in execution_mode_default.ll
jmmartinez Jan 5, 2026
c476837
[Review] rename CHECK->SPIRV in execution_mode_default.ll
jmmartinez Jan 5, 2026
cf9063e
[Review] add reverse translation tests for execution_mode_default.ll
jmmartinez Jan 5, 2026
73ece56
[Review] remove word-count from SPIRV checks in execution_mode_defaul…
jmmartinez Jan 5, 2026
6e99715
[Review] remove word-count from SPIRV checks in fp-decorate-twice.ll
jmmartinez Jan 5, 2026
5bba51f
[Review] Pre-commit tests before changing the logic of transFPFastMat…
jmmartinez Jan 5, 2026
be484ad
[Review] Update transFPFastMathDefault
jmmartinez Jan 5, 2026
100b998
[Review][NFC] Comments and renames
jmmartinez Jan 9, 2026
b9f0ccb
[Review][NFC] Add assertions
jmmartinez Jan 9, 2026
39e4141
[Review] Propagate FP mode to vector types
jmmartinez Jan 9, 2026
6de8e77
[Review] missing space
jmmartinez Jan 13, 2026
e0bd0eb
[Review] missing .
jmmartinez Jan 13, 2026
a90ccf7
[Review] operaiton -> operation
jmmartinez Jan 13, 2026
1108325
[Review] Do not set FPFastMathDefault to 0 by default / translate Con…
jmmartinez Jan 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/LLVMSPIRVExtensions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,4 @@ EXT(SPV_INTEL_predicated_io)
EXT(SPV_INTEL_sigmoid)
EXT(SPV_INTEL_float4)
EXT(SPV_INTEL_fp_conversions)
EXT(SPV_KHR_float_controls2)
10 changes: 7 additions & 3 deletions lib/SPIRV/SPIRVMDWalker.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ class SPIRVMDWalker {

bool atEnd() const { return !(M && I < E); }

template <typename T> MDWrapper &get(T &V) {
template <typename IntegerType,
typename = std::enable_if_t<std::is_integral_v<IntegerType>>>
MDWrapper &get(IntegerType &V) {
if (!Q)
assert(I < E && "out of bound");
if (atEnd())
Expand All @@ -115,12 +117,14 @@ class SPIRVMDWalker {
return *this;
}

MDWrapper &get(Function *&F) {
template <typename ValueType,
typename = std::enable_if_t<std::is_base_of_v<Value, ValueType>>>
MDWrapper &get(ValueType *&V) {
if (!Q)
assert(I < E && "out of bound");
if (atEnd())
return *this;
F = mdconst::dyn_extract<Function>(M->getOperand(I++));
V = mdconst::dyn_extract<ValueType>(M->getOperand(I++));
return *this;
}

Expand Down
103 changes: 85 additions & 18 deletions lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1218,26 +1218,24 @@ static void applyNoIntegerWrapDecorations(const SPIRVValue *BV,
}
}

static void applyFPFastMathModeDecorations(const SPIRVValue *BV,
Instruction *Inst) {
SPIRVWord V;
FastMathFlags FMF;
void SPIRVToLLVM::applyFPFastMathModeDecorations(const SPIRVValue *BV,
Instruction *Inst) {
if (!isa<FPMathOperator>(Inst))
return;

SPIRVWord V{0};
if (BV->hasDecorate(DecorationFPFastMathMode, 0, &V)) {
if (V & FPFastMathModeNotNaNMask)
FMF.setNoNaNs();
if (V & FPFastMathModeNotInfMask)
FMF.setNoInfs();
if (V & FPFastMathModeNSZMask)
FMF.setNoSignedZeros();
if (V & FPFastMathModeAllowRecipMask)
FMF.setAllowReciprocal();
if (V & FPFastMathModeAllowContractFastINTELMask)
FMF.setAllowContract();
if (V & FPFastMathModeAllowReassocINTELMask)
FMF.setAllowReassoc();
if (V & FPFastMathModeFastMask)
FMF.setFast();
FastMathFlags FMF = translateFastMathFlags(V);
Inst->setFastMathFlags(FMF);
return;
}

// Get the scalar type to handle vector operands. And get the first operand
// type (instead of the result) due to fcmp instructions.
Type *FloatType = Inst->getOperand(0)->getType()->getScalarType();
auto Func2FMF = FuncToFastMathFlags.find({Inst->getFunction(), FloatType});
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm tempted to remove this FuncToFastMathFlags stuff.

It's used to set the FPFastMathFlags that are attached to the execution mode to the individual instructions of a kernel.

But since we're preserving the FPFastMathFlags in the metadata; I'm thinking that this is not needed anymore.

@maarquitos14 should I remove this ?

if (Func2FMF != FuncToFastMathFlags.end()) {
Inst->setFastMathFlags(Func2FMF->second);
}
}

Expand Down Expand Up @@ -3443,6 +3441,73 @@ static void validatePhiPredecessors(Function *F) {
}
} // namespace

FastMathFlags SPIRVToLLVM::translateFastMathFlags(SPIRVWord V) const {
FastMathFlags FMF;
if (V & FPFastMathModeNotNaNMask)
FMF.setNoNaNs();
if (V & FPFastMathModeNotInfMask)
FMF.setNoInfs();
if (V & FPFastMathModeNSZMask)
FMF.setNoSignedZeros();
if (V & FPFastMathModeAllowRecipMask)
FMF.setAllowReciprocal();
static_assert(FPFastMathModeAllowContractFastINTELMask ==
FPFastMathModeAllowContractMask);
if (V & FPFastMathModeAllowContractFastINTELMask)
FMF.setAllowContract();
static_assert(FPFastMathModeAllowReassocINTELMask ==
FPFastMathModeAllowReassocMask);
if (V & FPFastMathModeAllowReassocINTELMask)
FMF.setAllowReassoc();
// There is no FPFastMathMode flag that represents LLVM approximate functions
// flag `afn`. Even the FPFastMathMode Fast flag should not imply it, but to
// avoid changing the previous behaviour we make it equivalent to LLVM's.
if (V & FPFastMathModeFastMask)
FMF.setFast();
if (V & FPFastMathModeAllowTransformMask) {
// AllowTransform requires the AllowContract and AllowReassoc bits to be
// set.
assert(FMF.allowContract() && FMF.allowReassoc() &&
"The FPFastMathMode AllowTransform requires AllowContract and "
"AllowReassoc to be set");
}

return FMF;
}

void SPIRVToLLVM::parseFloatControls2ExecutionModeId(SPIRVFunction *BF,
Function *F) {

auto [Begin, End] =
BF->getExecutionModeRange(spv::ExecutionModeFPFastMathDefault);
if (Begin == End)
return;

LLVMContext &C = F->getContext();
NamedMDNode *ExecModeMD =
M->getOrInsertNamedMetadata(kSPIRVMD::ExecutionMode);

Metadata *FPFastMathMode[4] = {ConstantAsMetadata::get(F),
ConstantAsMetadata::get(getUInt32(
M, spv::ExecutionModeFPFastMathDefault)),
nullptr, nullptr};

for (auto [_, EM] : make_range(Begin, End)) {
const auto &Literals = EM->getLiterals();
assert(Literals.size() == 2);
SPIRVWord FloatTyId = Literals[0];
SPIRVType *FloatSPIRVType = BM->get<SPIRVType>(FloatTyId);
Type *FloatType = transFPType(FloatSPIRVType);
SPIRVWord Flags = *transIdAsConstant(Literals[1]);
FuncToFastMathFlags.try_emplace({F, FloatType},
translateFastMathFlags(Flags));

FPFastMathMode[2] = ConstantAsMetadata::get(PoisonValue::get(FloatType));
FPFastMathMode[3] = ConstantAsMetadata::get(getUInt32(M, Flags));
ExecModeMD->addOperand(MDNode::get(C, FPFastMathMode));
}
}

Function *SPIRVToLLVM::transFunction(SPIRVFunction *BF, unsigned AS) {
auto Loc = FuncMap.find(BF);
if (Loc != FuncMap.end())
Expand Down Expand Up @@ -3515,6 +3580,8 @@ Function *SPIRVToLLVM::transFunction(SPIRVFunction *BF, unsigned AS) {
: CallingConv::SPIR_FUNC);
transFunctionAttrs(BF, F);

parseFloatControls2ExecutionModeId(BF, F);

// Creating all basic blocks before creating instructions.
for (size_t I = 0, E = BF->getNumBasicBlock(); I != E; ++I) {
transValue(BF->getBasicBlock(I), F, nullptr);
Expand Down
8 changes: 8 additions & 0 deletions lib/SPIRV/SPIRVReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,14 @@ class SPIRVToLLVM : private BuiltinCallHelper {
void
transFunctionPointerCallArgumentAttributes(SPIRVValue *BV, CallInst *CI,
SPIRVTypeFunction *CalledFnTy);

using FunctionAndTypeIdPair = std::pair<Function *, Type *>;
using FunctionToFastMathFlagsMap =
DenseMap<FunctionAndTypeIdPair, FastMathFlags>;
FunctionToFastMathFlagsMap FuncToFastMathFlags;
FastMathFlags translateFastMathFlags(SPIRVWord V) const;
void parseFloatControls2ExecutionModeId(SPIRVFunction *BF, Function *F);
void applyFPFastMathModeDecorations(const SPIRVValue *BV, Instruction *Inst);
}; // class SPIRVToLLVM

} // namespace SPIRV
Expand Down
128 changes: 110 additions & 18 deletions lib/SPIRV/SPIRVWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3122,6 +3122,13 @@ static void transMetadataDecorations(Metadata *MD, SPIRVValue *Target) {
static_cast<StoreCacheControl>(CacheControl->getZExtValue())));
break;
}
case DecorationFPFastMathMode: {
// Ignore this decoration. FPFastMathMode is set through the LLVM-IR
// fast-math flags (e.g. reassoc, contract) associated with the
// instruction. It should not be set through metadata, since LLVM passes
// are free to ignore it.
break;
}
default: {
if (NumOperands == 1) {
Target->addDecorate(new SPIRVDecorate(DecoKind, Target));
Expand Down Expand Up @@ -3183,11 +3190,26 @@ bool LLVMToSPIRVBase::transDecoration(Value *V, SPIRVValue *BV) {
((Opcode == Instruction::FNeg || Opcode == Instruction::FCmp ||
BV->isExtInst()) &&
BM->isAllowedToUseVersion(VersionNumber::SPIRV_1_6))) {
bool AllowFloatControls2 =
BM->isAllowedToUseExtension(ExtensionID::SPV_KHR_float_controls2);
bool AllowIntelFpFastMathMode =
BM->isAllowedToUseExtension(ExtensionID::SPV_INTEL_fp_fast_math_mode);
FastMathFlags FMF = BVF->getFastMathFlags();
SPIRVWord M{0};
if (FMF.isFast())
M |= FPFastMathModeFastMask;
else {
if (FMF.isFast()) {
if (AllowFloatControls2) {
// When SPV_KHR_float_controls2 is used, setting the fast math flag
// bit is deprecated. Set the rest of the bits instead.
M |= FPFastMathModeNotNaNMask | FPFastMathModeNotInfMask |
FPFastMathModeNSZMask | FPFastMathModeAllowRecipMask |
FPFastMathModeAllowTransformMask |
FPFastMathModeAllowReassocMask | FPFastMathModeAllowContractMask;
BM->addCapability(CapabilityFloatControls2);
BM->addExtension(ExtensionID::SPV_KHR_float_controls2);
} else {
M |= FPFastMathModeFastMask;
}
} else {
if (FMF.noNaNs())
M |= FPFastMathModeNotNaNMask;
if (FMF.noInfs())
Expand All @@ -3196,14 +3218,31 @@ bool LLVMToSPIRVBase::transDecoration(Value *V, SPIRVValue *BV) {
M |= FPFastMathModeNSZMask;
if (FMF.allowReciprocal())
M |= FPFastMathModeAllowRecipMask;
if (BM->isAllowedToUseExtension(
ExtensionID::SPV_INTEL_fp_fast_math_mode)) {
if (FMF.allowContract()) {
M |= FPFastMathModeAllowContractFastINTELMask;
BM->addCapability(CapabilityFPFastMathModeINTEL);
BM->addExtension(ExtensionID::SPV_INTEL_fp_fast_math_mode);
if (FMF.allowContract()) {
if (AllowFloatControls2 || AllowIntelFpFastMathMode) {
static_assert(FPFastMathModeAllowContractFastINTELMask ==
FPFastMathModeAllowContractMask);
M |= FPFastMathModeAllowContractMask;
BM->addCapability(AllowFloatControls2
? CapabilityFloatControls2
: CapabilityFPFastMathModeINTEL);
BM->addExtension(AllowFloatControls2
? ExtensionID::SPV_KHR_float_controls2
: ExtensionID::SPV_INTEL_fp_fast_math_mode);
}
if (FMF.allowReassoc()) {
}
if (FMF.allowReassoc()) {
if (AllowFloatControls2) {
// LLVM reassoc maps to SPIRV transform, see
// https://github.com/KhronosGroup/SPIRV-Registry/issues/326 for
// details. Because we are enabling AllowTransform, we must enable
// AllowReassoc and AllowContract too, as required by SPIRV spec.
M |= FPFastMathModeAllowTransformMask |
FPFastMathModeAllowReassocMask |
FPFastMathModeAllowContractMask;
BM->addCapability(CapabilityFloatControls2);
BM->addExtension(ExtensionID::SPV_KHR_float_controls2);
} else if (AllowIntelFpFastMathMode) {
M |= FPFastMathModeAllowReassocINTELMask;
BM->addCapability(CapabilityFPFastMathModeINTEL);
BM->addExtension(ExtensionID::SPV_INTEL_fp_fast_math_mode);
Expand Down Expand Up @@ -6403,7 +6442,22 @@ SPIRVInstruction *LLVMToSPIRVBase::transBuiltinToInst(StringRef DemangledName,
return Inst;
}

void LLVMToSPIRVBase::setExecutionModeFPFastMathDefault(
SPIRVFunction *BF, SPIRVType *FloatSPIRVType, SPIRVWord FlagsLiteral) {
assert(BM->isAllowedToUseExtension(ExtensionID::SPV_KHR_float_controls2));

SPIRVConstant *Flags = BM->getLiteralAsConstant(FlagsLiteral);
BF->addExecutionMode(
new SPIRVExecutionModeId(BF, spv::ExecutionModeFPFastMathDefault,
FloatSPIRVType->getId(), Flags->getId()));

BM->addCapability(CapabilityFloatControls2);
BM->addExtension(ExtensionID::SPV_KHR_float_controls2);
}

bool LLVMToSPIRVBase::transExecutionMode() {
transFPContract();

if (auto NMD = SPIRVMDWalker(*M).getNamedMD(kSPIRVMD::ExecutionMode)) {
while (!NMD.atEnd()) {
unsigned EMode = ~0U;
Expand All @@ -6425,8 +6479,18 @@ bool LLVMToSPIRVBase::transExecutionMode() {

switch (EMode) {
case spv::ExecutionModeContractionOff:
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode))));
// With SPV_KHR_float_controls2 this execution mode is deprecated.
// We cannot set only the contract flag to 0, so we set all flags to 0.
if (BM->isAllowedToUseExtension(ExtensionID::SPV_KHR_float_controls2)) {
for (auto [_, FloatSPIRVType] : TypeMap) {
if (FloatSPIRVType->isTypeFloat()) {
setExecutionModeFPFastMathDefault(BF, FloatSPIRVType, 0);
}
}
} else {
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode))));
}
break;
case spv::ExecutionModeInitializer:
case spv::ExecutionModeFinalizer:
Expand Down Expand Up @@ -6513,9 +6577,21 @@ bool LLVMToSPIRVBase::transExecutionMode() {
BM->addCapability(CapabilityVectorComputeINTEL);
} break;

case spv::ExecutionModeSignedZeroInfNanPreserve:
// With SPV_KHR_float_controls2 this execution mode is deprecated.
// Map this execution mode to the FPFastMathDefault with all flags set
// to 0.
if (BM->isAllowedToUseExtension(ExtensionID::SPV_KHR_float_controls2)) {
unsigned BitWidth;
N.get(BitWidth);

SPIRVType *FloatSPIRVType = BM->addFloatType(BitWidth);
setExecutionModeFPFastMathDefault(BF, FloatSPIRVType, 0);
break;
}
[[fallthrough]];
case spv::ExecutionModeDenormPreserve:
case spv::ExecutionModeDenormFlushToZero:
case spv::ExecutionModeSignedZeroInfNanPreserve:
case spv::ExecutionModeRoundingModeRTE:
case spv::ExecutionModeRoundingModeRTZ: {
if (BM->isAllowedToUseVersion(VersionNumber::SPIRV_1_4)) {
Expand All @@ -6542,20 +6618,32 @@ bool LLVMToSPIRVBase::transExecutionMode() {
break;
AddSingleArgExecutionMode(static_cast<ExecutionMode>(EMode));
} break;
case spv::ExecutionModeFPFastMathDefault: {
if (!BM->isAllowedToUseExtension(ExtensionID::SPV_KHR_float_controls2))
break;
assert(F);
BM->addCapability(CapabilityFloatControls2);
BM->addExtension(ExtensionID::SPV_KHR_float_controls2);
PoisonValue *V;
unsigned FlagsLiteral;
N.get(V).get(FlagsLiteral);
SPIRVType *FloatSPIRVType = transType(V->getType());
setExecutionModeFPFastMathDefault(BF, FloatSPIRVType, FlagsLiteral);
break;
}
default:
llvm_unreachable("invalid execution mode");
}
}
}

transFPContract();

return true;
}

void LLVMToSPIRVBase::transFPContract() {
FPContractMode Mode = BM->getFPContractMode();

LLVMContext &C = M->getContext();
for (Function &F : *M) {
SPIRVValue *TranslatedF = getTranslatedValue(&F);
if (!TranslatedF) {
Expand All @@ -6564,7 +6652,7 @@ void LLVMToSPIRVBase::transFPContract() {
SPIRVFunction *BF = static_cast<SPIRVFunction *>(TranslatedF);

bool IsKernelEntryPoint =
BF->getModule()->isEntryPoint(spv::ExecutionModelKernel, BF->getId());
BM->isEntryPoint(spv::ExecutionModelKernel, BF->getId());
if (!IsKernelEntryPoint)
continue;

Expand All @@ -6585,8 +6673,12 @@ void LLVMToSPIRVBase::transFPContract() {
}

if (DisableContraction) {
BF->addExecutionMode(BF->getModule()->add(new SPIRVExecutionMode(
OpExecutionMode, BF, spv::ExecutionModeContractionOff)));
NamedMDNode *ExecModeMD =
M->getOrInsertNamedMetadata(kSPIRVMD::ExecutionMode);
Metadata *ContractionOff[2] = {ConstantAsMetadata::get(&F),
ConstantAsMetadata::get(getUInt32(
M, spv::ExecutionModeContractionOff))};
ExecModeMD->addOperand(MDNode::get(C, ContractionOff));
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions lib/SPIRV/SPIRVWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,9 @@ class LLVMToSPIRVBase : protected BuiltinCallHelper {
bool translate();
bool transExecutionMode();
void transFPContract();
void setExecutionModeFPFastMathDefault(SPIRVFunction *BF,
SPIRVType *FloatSPIRVType,
SPIRVWord FlagsLiteral);
SPIRVValue *transConstant(Value *V);
/// Translate a reference to a constant in a constant expression. This may
/// involve inserting extra bitcasts to correct type issues.
Expand Down
Loading
Loading