Skip to content

Commit 3604d2c

Browse files
Merge pull request #856 from Devsh-Graphics-Programming/intrinsics_adjustments
Intrinsics adjustments
2 parents 4e43183 + f6a69fe commit 3604d2c

File tree

5 files changed

+73
-59
lines changed

5 files changed

+73
-59
lines changed

Diff for: include/nbl/builtin/hlsl/cpp_compat/impl/intrinsics_impl.hlsl

+61-21
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@ template<typename T NBL_STRUCT_CONSTRAINABLE>
103103
struct nMax_helper;
104104
template<typename T NBL_STRUCT_CONSTRAINABLE>
105105
struct nClamp_helper;
106-
106+
// Primary template for the fused multiply-add helper. Only declared here;
// the usable definitions are the constrained specializations that follow
// (an auto-specialized SPIR-V-backed one, a C++ scalar one, and a
// per-component vector one).
template<typename T NBL_STRUCT_CONSTRAINABLE>
107+
struct fma_helper;
107108

108109
#ifdef __HLSL_VERSION // HLSL only specializations
109110

@@ -134,6 +135,7 @@ template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(find_lsb_helper, findIL
134135
#undef FIND_MSB_LSB_RETURN_TYPE
135136

136137
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(bitReverse_helper, bitReverse, (T), (T), T)
138+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(dot_helper, dot, (T), (T)(T), typename vector_traits<T>::scalar_type)
137139
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(transpose_helper, transpose, (T), (T), T)
138140
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(length_helper, length, (T), (T), typename vector_traits<T>::scalar_type)
139141
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(normalize_helper, normalize, (T), (T), T)
@@ -162,6 +164,7 @@ template<typename T, typename U> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(refract_hel
162164
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(nMax_helper, nMax, (T), (T)(T), T)
163165
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(nMin_helper, nMin, (T), (T)(T), T)
164166
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(nClamp_helper, nClamp, (T), (T)(T), T)
167+
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(fma_helper, fma, (T), (T)(T)(T), T)
165168

166169
#define BITCOUNT_HELPER_RETRUN_TYPE conditional_t<is_vector_v<T>, vector<int32_t, vector_traits<T>::Dimension>, int32_t>
167170
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(bitCount_helper, bitCount, (T), (T), BITCOUNT_HELPER_RETRUN_TYPE)
@@ -599,6 +602,16 @@ struct nClamp_helper<T>
599602
}
600603
};
601604

605+
// C++-side specialization of fma_helper for floating-point scalars.
// Sits in the region closed by `#endif // C++ only specializations`, so it
// uses a plain `requires` clause and the C++ standard library directly.
template<typename FloatingPoint>
606+
requires concepts::FloatingPointScalar<FloatingPoint>
607+
struct fma_helper<FloatingPoint>
608+
{
609+
// Returns x * y + z as computed by std::fma.
static FloatingPoint __call(NBL_CONST_REF_ARG(FloatingPoint) x, NBL_CONST_REF_ARG(FloatingPoint) y, NBL_CONST_REF_ARG(FloatingPoint) z)
610+
{
611+
return std::fma(x, y, z);
612+
}
613+
};
614+
602615
#endif // C++ only specializations
603616

604617
// C++ and HLSL specializations
@@ -613,25 +626,6 @@ struct bitReverseAs_helper<T NBL_PARTIAL_REQ_BOT(concepts::UnsignedIntegralScala
613626
}
614627
};
615628

616-
template<typename Vectorial>
617-
NBL_PARTIAL_REQ_TOP(concepts::Vectorial<Vectorial>)
618-
struct dot_helper<Vectorial NBL_PARTIAL_REQ_BOT(concepts::Vectorial<Vectorial>) >
619-
{
620-
using scalar_type = typename vector_traits<Vectorial>::scalar_type;
621-
622-
static inline scalar_type __call(NBL_CONST_REF_ARG(Vectorial) lhs, NBL_CONST_REF_ARG(Vectorial) rhs)
623-
{
624-
static const uint32_t ArrayDim = vector_traits<Vectorial>::Dimension;
625-
static array_get<Vectorial, scalar_type> getter;
626-
627-
scalar_type retval = getter(lhs, 0) * getter(rhs, 0);
628-
for (uint32_t i = 1; i < ArrayDim; ++i)
629-
retval = retval + getter(lhs, i) * getter(rhs, i);
630-
631-
return retval;
632-
}
633-
};
634-
635629
#ifdef __HLSL_VERSION
636630
// SPIR-V already defines specializations for builtin vector types
637631
#define VECTOR_SPECIALIZATION_CONCEPT concepts::Vectorial<T> && !is_vector_v<T>
@@ -888,8 +882,54 @@ struct mix_helper<T, U NBL_PARTIAL_REQ_BOT(concepts::Vectorial<T> && concepts::B
888882
}
889883
};
890884

885+
// Component-wise fma for vector-like types matching VECTOR_SPECIALIZATION_CONCEPT.
// NOTE(review): the macro definition visible earlier in this file (HLSL branch)
// is `concepts::Vectorial<T> && !is_vector_v<T>`; the definition active at this
// point is not visible in this hunk - confirm.
template<typename T>
886+
NBL_PARTIAL_REQ_TOP(VECTOR_SPECIALIZATION_CONCEPT)
887+
struct fma_helper<T NBL_PARTIAL_REQ_BOT(VECTOR_SPECIALIZATION_CONCEPT) >
888+
{
889+
using return_t = T;
890+
// Returns a vector whose i-th component is fma(x[i], y[i], z[i]).
static return_t __call(NBL_CONST_REF_ARG(T) x, NBL_CONST_REF_ARG(T) y, NBL_CONST_REF_ARG(T) z)
891+
{
892+
using traits = hlsl::vector_traits<T>;
893+
// Generic accessors so the loop does not rely on operator[] of T.
array_get<T, typename traits::scalar_type> getter;
894+
array_set<T, typename traits::scalar_type> setter;
895+
896+
// Declared without an initializer; every component in [0, Dimension) is
// written by the loop below before the value is returned.
return_t output;
897+
for (uint32_t i = 0; i < traits::Dimension; ++i)
898+
// Delegate each component to the scalar fma_helper specialization.
setter(output, i, fma_helper<typename traits::scalar_type>::__call(getter(x, i), getter(y, i), getter(z, i)));
899+
900+
return output;
901+
}
902+
};
903+
904+
// Constraint selecting which types use the generic dot product below.
// Under HLSL, builtin vectors (is_vector_v) are excluded from this
// specialization; on the C++ side every Vectorial type is accepted.
#ifdef __HLSL_VERSION
905+
#define DOT_HELPER_REQUIREMENT (concepts::Vectorial<Vectorial> && !is_vector_v<Vectorial>)
906+
#else
907+
#define DOT_HELPER_REQUIREMENT concepts::Vectorial<Vectorial>
908+
#endif
909+
910+
// Generic dot product: sums lhs[i] * rhs[i] over all components, reading
// components through array_get.
template<typename Vectorial>
911+
NBL_PARTIAL_REQ_TOP(DOT_HELPER_REQUIREMENT)
912+
struct dot_helper<Vectorial NBL_PARTIAL_REQ_BOT(DOT_HELPER_REQUIREMENT) >
913+
{
914+
using scalar_type = typename vector_traits<Vectorial>::scalar_type;
915+
916+
// Returns the scalar dot product of lhs and rhs.
static inline scalar_type __call(NBL_CONST_REF_ARG(Vectorial) lhs, NBL_CONST_REF_ARG(Vectorial) rhs)
917+
{
918+
static const uint32_t ArrayDim = vector_traits<Vectorial>::Dimension;
919+
static array_get<Vectorial, scalar_type> getter;
920+
921+
// Component 0 is read unconditionally to seed the accumulator, so the
// vector is expected to have at least one component.
scalar_type retval = getter(lhs, 0) * getter(rhs, 0);
922+
for (uint32_t i = 1; i < ArrayDim; ++i)
923+
// Accumulate with a fused multiply-add: retval = lhs[i] * rhs[i] + retval.
// NOTE(review): this needs fma_helper<scalar_type> to be defined. The only
// scalar specialization visible in this diff is for floating-point types;
// verify that integer-component vectors still compile and behave as intended.
// NOTE(review): a fused operation may round differently from a separate
// multiply followed by an add - confirm callers tolerate that.
retval = fma_helper<scalar_type>::__call(getter(lhs, i), getter(rhs, i), retval);
924+
925+
return retval;
926+
}
927+
};
928+
929+
// The requirement macro is local to the definition above.
#undef DOT_HELPER_REQUIREMENT
930+
891931
}
892932
}
893933
}
894934

895-
#endif
935+
#endif

Diff for: include/nbl/builtin/hlsl/cpp_compat/intrinsics.hlsl

+6
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,12 @@ inline int32_t2 unpackDouble2x32(T val)
295295
return NAMESPACE::unpackDouble2x32(val);
296296
}
297297

298+
// Public fma entry point: forwards x, y, z to
// cpp_compat_intrinsics_impl::fma_helper<T>::__call and returns its result.
// Which implementation runs is decided by the fma_helper specialization
// selected for T.
template<typename T>
299+
inline T fma(NBL_CONST_REF_ARG(T) x, NBL_CONST_REF_ARG(T) y, NBL_CONST_REF_ARG(T) z)
300+
{
301+
return cpp_compat_intrinsics_impl::fma_helper<T>::__call(x, y, z);
302+
}
303+
298304
#undef NAMESPACE
299305

300306
}

Diff for: include/nbl/builtin/hlsl/spirv_intrinsics/core.hlsl

+4
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,10 @@ template<typename T NBL_FUNC_REQUIRES(is_floating_point_v<T> && is_vector_v<T>)
318318
[[vk::ext_instruction(spv::OpIsInf)]]
319319
vector<bool, vector_traits<T>::Dimension> isInf(T val);
320320

321+
// Declares dot(lhs, rhs) as a direct binding to the SPIR-V OpDot instruction,
// returning the vector's scalar type. Accepted for any type satisfying
// is_vector_v.
// NOTE(review): the SPIR-V specification defines OpDot for floating-point
// component types only, while the constraint here does not check the
// component type - confirm integer vectors are never routed to this overload.
template<typename Vector NBL_FUNC_REQUIRES(is_vector_v<Vector>)
322+
[[vk::ext_instruction( spv::OpDot )]]
323+
typename vector_traits<Vector>::scalar_type dot(Vector lhs, Vector rhs);
324+
321325
template<typename Matrix>
322326
[[vk::ext_instruction( spv::OpTranspose )]]
323327
Matrix transpose(Matrix mat);

Diff for: include/nbl/builtin/hlsl/tgmath.hlsl

+2-6
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
#include <nbl/builtin/hlsl/spirv_intrinsics/core.hlsl>
1414
#include <nbl/builtin/hlsl/concepts/core.hlsl>
1515
#include <nbl/builtin/hlsl/concepts/vector.hlsl>
16+
#include <nbl/builtin/hlsl/cpp_compat/intrinsics.hlsl>
17+
1618
// C++ headers
1719
#ifndef __HLSL_VERSION
1820
#include <algorithm>
@@ -211,12 +213,6 @@ inline T ceil(NBL_CONST_REF_ARG(T) val)
211213
return tgmath_impl::ceil_helper<T>::__call(val);
212214
}
213215

214-
template<typename T>
215-
inline T fma(NBL_CONST_REF_ARG(T) x, NBL_CONST_REF_ARG(T) y, NBL_CONST_REF_ARG(T) z)
216-
{
217-
return tgmath_impl::fma_helper<T>::__call(x, y, z);
218-
}
219-
220216
template<typename T, typename U>
221217
inline T ldexp(NBL_CONST_REF_ARG(T) arg, NBL_CONST_REF_ARG(U) exp)
222218
{

Diff for: include/nbl/builtin/hlsl/tgmath/impl.hlsl

-32
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,6 @@ template<typename T NBL_STRUCT_CONSTRAINABLE>
8383
struct trunc_helper;
8484
template<typename T NBL_STRUCT_CONSTRAINABLE>
8585
struct ceil_helper;
86-
template<typename T NBL_STRUCT_CONSTRAINABLE>
87-
struct fma_helper;
8886
template<typename T, typename U NBL_STRUCT_CONSTRAINABLE>
8987
struct ldexp_helper;
9088
template<typename T NBL_STRUCT_CONSTRAINABLE>
@@ -138,7 +136,6 @@ template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(roundEven_helper, round
138136
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(trunc_helper, trunc, (T), (T), T)
139137
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(ceil_helper, ceil, (T), (T), T)
140138
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(pow_helper, pow, (T), (T)(T), T)
141-
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(fma_helper, fma, (T), (T)(T)(T), T)
142139
template<typename T, typename U> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(ldexp_helper, ldexp, (T)(U), (T)(U), T)
143140
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(modfStruct_helper, modfStruct, (T), (T), ModfOutput<T>)
144141
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(frexpStruct_helper, frexpStruct, (T), (T), FrexpOutput<T>)
@@ -338,16 +335,6 @@ struct roundEven_helper<FloatingPoint NBL_PARTIAL_REQ_BOT(concepts::FloatingPoin
338335
}
339336
};
340337

341-
template<typename FloatingPoint>
342-
NBL_PARTIAL_REQ_TOP(concepts::FloatingPointScalar<FloatingPoint>)
343-
struct fma_helper<FloatingPoint NBL_PARTIAL_REQ_BOT(concepts::FloatingPointScalar<FloatingPoint>) >
344-
{
345-
static FloatingPoint __call(NBL_CONST_REF_ARG(FloatingPoint) x, NBL_CONST_REF_ARG(FloatingPoint) y, NBL_CONST_REF_ARG(FloatingPoint) z)
346-
{
347-
return std::fma(x, y, z);
348-
}
349-
};
350-
351338
template<typename T, typename U>
352339
NBL_PARTIAL_REQ_TOP(concepts::FloatingPointScalar<T> && concepts::IntegralScalar<U>)
353340
struct ldexp_helper<T, U NBL_PARTIAL_REQ_BOT(concepts::FloatingPointScalar<T> && concepts::IntegralScalar<U>) >
@@ -618,25 +605,6 @@ struct pow_helper<T NBL_PARTIAL_REQ_BOT(VECTOR_SPECIALIZATION_CONCEPT) >
618605
}
619606
};
620607

621-
template<typename T>
622-
NBL_PARTIAL_REQ_TOP(VECTOR_SPECIALIZATION_CONCEPT)
623-
struct fma_helper<T NBL_PARTIAL_REQ_BOT(VECTOR_SPECIALIZATION_CONCEPT) >
624-
{
625-
using return_t = T;
626-
static return_t __call(NBL_CONST_REF_ARG(T) x, NBL_CONST_REF_ARG(T) y, NBL_CONST_REF_ARG(T) z)
627-
{
628-
using traits = hlsl::vector_traits<T>;
629-
array_get<T, typename traits::scalar_type> getter;
630-
array_set<T, typename traits::scalar_type> setter;
631-
632-
return_t output;
633-
for (uint32_t i = 0; i < traits::Dimension; ++i)
634-
setter(output, i, fma_helper<typename traits::scalar_type>::__call(getter(x, i), getter(y, i), getter(z, i)));
635-
636-
return output;
637-
}
638-
};
639-
640608
template<typename T, typename U>
641609
NBL_PARTIAL_REQ_TOP(VECTOR_SPECIALIZATION_CONCEPT && (vector_traits<T>::Dimension == vector_traits<U>::Dimension))
642610
struct ldexp_helper<T, U NBL_PARTIAL_REQ_BOT(VECTOR_SPECIALIZATION_CONCEPT && (vector_traits<T>::Dimension == vector_traits<U>::Dimension)) >

0 commit comments

Comments
 (0)