@@ -103,7 +103,8 @@ template<typename T NBL_STRUCT_CONSTRAINABLE>
103
103
struct nMax_helper;
104
104
template<typename T NBL_STRUCT_CONSTRAINABLE>
105
105
struct nClamp_helper;
106
-
106
+ template<typename T NBL_STRUCT_CONSTRAINABLE>
107
+ struct fma_helper;
107
108
108
109
#ifdef __HLSL_VERSION // HLSL only specializations
109
110
@@ -134,6 +135,7 @@ template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(find_lsb_helper, findIL
134
135
#undef FIND_MSB_LSB_RETURN_TYPE
135
136
136
137
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (bitReverse_helper, bitReverse, (T), (T), T)
138
+ template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (dot_helper, dot, (T), (T)(T), typename vector_traits<T>::scalar_type)
137
139
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (transpose_helper, transpose, (T), (T), T)
138
140
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (length_helper, length, (T), (T), typename vector_traits<T>::scalar_type)
139
141
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (normalize_helper, normalize, (T), (T), T)
@@ -162,6 +164,7 @@ template<typename T, typename U> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(refract_hel
162
164
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (nMax_helper, nMax, (T), (T)(T), T)
163
165
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (nMin_helper, nMin, (T), (T)(T), T)
164
166
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (nClamp_helper, nClamp, (T), (T)(T), T)
167
+ template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (fma_helper, fma, (T), (T)(T)(T), T)
165
168
166
169
#define BITCOUNT_HELPER_RETRUN_TYPE conditional_t<is_vector_v<T>, vector <int32_t, vector_traits<T>::Dimension>, int32_t>
167
170
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (bitCount_helper, bitCount, (T), (T), BITCOUNT_HELPER_RETRUN_TYPE)
@@ -599,6 +602,16 @@ struct nClamp_helper<T>
599
602
}
600
603
};
601
604
605
+ template<typename FloatingPoint>
606
+ requires concepts::FloatingPointScalar<FloatingPoint>
607
+ struct fma_helper<FloatingPoint>
608
+ {
609
+ static FloatingPoint __call (NBL_CONST_REF_ARG (FloatingPoint) x, NBL_CONST_REF_ARG (FloatingPoint) y, NBL_CONST_REF_ARG (FloatingPoint) z)
610
+ {
611
+ return std::fma (x, y, z);
612
+ }
613
+ };
614
+
602
615
#endif // C++ only specializations
603
616
604
617
// C++ and HLSL specializations
@@ -613,25 +626,6 @@ struct bitReverseAs_helper<T NBL_PARTIAL_REQ_BOT(concepts::UnsignedIntegralScala
613
626
}
614
627
};
615
628
616
- template<typename Vectorial>
617
- NBL_PARTIAL_REQ_TOP (concepts::Vectorial<Vectorial>)
618
- struct dot_helper<Vectorial NBL_PARTIAL_REQ_BOT (concepts::Vectorial<Vectorial>) >
619
- {
620
- using scalar_type = typename vector_traits<Vectorial>::scalar_type;
621
-
622
- static inline scalar_type __call (NBL_CONST_REF_ARG (Vectorial) lhs, NBL_CONST_REF_ARG (Vectorial) rhs)
623
- {
624
- static const uint32_t ArrayDim = vector_traits<Vectorial>::Dimension;
625
- static array_get<Vectorial, scalar_type> getter;
626
-
627
- scalar_type retval = getter (lhs, 0 ) * getter (rhs, 0 );
628
- for (uint32_t i = 1 ; i < ArrayDim; ++i)
629
- retval = retval + getter (lhs, i) * getter (rhs, i);
630
-
631
- return retval;
632
- }
633
- };
634
-
635
629
#ifdef __HLSL_VERSION
636
630
// SPIR-V already defines specializations for builtin vector types
637
631
#define VECTOR_SPECIALIZATION_CONCEPT concepts::Vectorial<T> && !is_vector_v<T>
@@ -888,8 +882,54 @@ struct mix_helper<T, U NBL_PARTIAL_REQ_BOT(concepts::Vectorial<T> && concepts::B
888
882
}
889
883
};
890
884
885
+ template<typename T>
886
+ NBL_PARTIAL_REQ_TOP (VECTOR_SPECIALIZATION_CONCEPT)
887
+ struct fma_helper<T NBL_PARTIAL_REQ_BOT (VECTOR_SPECIALIZATION_CONCEPT) >
888
+ {
889
+ using return_t = T;
890
+ static return_t __call (NBL_CONST_REF_ARG (T) x, NBL_CONST_REF_ARG (T) y, NBL_CONST_REF_ARG (T) z)
891
+ {
892
+ using traits = hlsl::vector_traits<T>;
893
+ array_get<T, typename traits::scalar_type> getter;
894
+ array_set<T, typename traits::scalar_type> setter;
895
+
896
+ return_t output;
897
+ for (uint32_t i = 0 ; i < traits::Dimension; ++i)
898
+ setter (output, i, fma_helper<typename traits::scalar_type>::__call (getter (x, i), getter (y, i), getter (z, i)));
899
+
900
+ return output;
901
+ }
902
+ };
903
+
904
+ #ifdef __HLSL_VERSION
905
+ #define DOT_HELPER_REQUIREMENT (concepts::Vectorial<Vectorial> && !is_vector_v<Vectorial>)
906
+ #else
907
+ #define DOT_HELPER_REQUIREMENT concepts::Vectorial<Vectorial>
908
+ #endif
909
+
910
+ template<typename Vectorial>
911
+ NBL_PARTIAL_REQ_TOP (DOT_HELPER_REQUIREMENT)
912
+ struct dot_helper<Vectorial NBL_PARTIAL_REQ_BOT (DOT_HELPER_REQUIREMENT) >
913
+ {
914
+ using scalar_type = typename vector_traits<Vectorial>::scalar_type;
915
+
916
+ static inline scalar_type __call (NBL_CONST_REF_ARG (Vectorial) lhs, NBL_CONST_REF_ARG (Vectorial) rhs)
917
+ {
918
+ static const uint32_t ArrayDim = vector_traits<Vectorial>::Dimension;
919
+ static array_get<Vectorial, scalar_type> getter;
920
+
921
+ scalar_type retval = getter (lhs, 0 ) * getter (rhs, 0 );
922
+ for (uint32_t i = 1 ; i < ArrayDim; ++i)
923
+ retval = fma_helper<scalar_type>::__call (getter (lhs, i), getter (rhs, i), retval);
924
+
925
+ return retval;
926
+ }
927
+ };
928
+
929
+ #undef DOT_HELPER_REQUIREMENT
930
+
891
931
}
892
932
}
893
933
}
894
934
895
- #endif
935
+ #endif
0 commit comments