Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL] Remove underspecified vec::vector_t #17867

Merged
merged 2 commits into from
Apr 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion sycl/include/sycl/detail/builtins/helper_macros.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,8 @@
[](NUM_ARGS##_AUTO_ARG) { return (NS::NAME)(NUM_ARGS##_ARG); }, \
NUM_ARGS##_ARG); \
} else { \
return __VA_ARGS__(NUM_ARGS##_CONVERTED_ARG); \
return bit_cast<detail::ENABLER<NUM_ARGS##_TEMPLATE_TYPE>>( \
__VA_ARGS__(NUM_ARGS##_CONVERTED_ARG)); \
} \
}

Expand Down
4 changes: 3 additions & 1 deletion sycl/include/sycl/detail/builtins/math_functions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,9 @@ auto builtin_delegate_ptr_impl(FuncTy F, PtrTy p, Ts... xs) {
detail::NON_SCALAR_ENABLER<SYCL_CONCAT(LESS_ONE(NUM_ARGS), _TEMPLATE_TYPE), \
PtrTy> \
NAME(SYCL_CONCAT(LESS_ONE(NUM_ARGS), _TEMPLATE_TYPE_ARG), PtrTy p) { \
return detail::NAME##_impl(SYCL_CONCAT(LESS_ONE(NUM_ARGS), _ARG), p); \
return bit_cast<detail::NON_SCALAR_ENABLER< \
SYCL_CONCAT(LESS_ONE(NUM_ARGS), _TEMPLATE_TYPE), PtrTy>>( \
detail::NAME##_impl(SYCL_CONCAT(LESS_ONE(NUM_ARGS), _ARG), p)); \
}

#if __SYCL_DEVICE_ONLY__
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ auto builtin_device_rel_impl(FuncTy F, const Ts &...xs) {
// the relation builtin (vector of int16_t/int32_t/int64_t depending on the
// arguments' element type).
auto ret = F(builtins::convert_arg(xs)...);
vec<signed char, num_elements<T>::value> tmp{ret};
auto tmp = bit_cast<vec<signed char, num_elements<T>::value>>(ret);
using res_elem_type = fixed_width_signed<sizeof(get_elem_type_t<T>)>;
static_assert(is_scalar_arithmetic_v<res_elem_type>);
return tmp.template convert<res_elem_type>();
Expand Down
19 changes: 10 additions & 9 deletions sycl/include/sycl/detail/spirv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -946,11 +946,12 @@ EnableIfNativeShuffle<T> Shuffle(GroupT g, T x, id<1> local_id) {
return result;
} else if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
GroupT>) {
return __spirv_GroupNonUniformShuffle(group_scope<GroupT>::value,
convertToOpenCLType(x), LocalId);
return convertFromOpenCLTypeFor<T>(__spirv_GroupNonUniformShuffle(
group_scope<GroupT>::value, convertToOpenCLType(x), LocalId));
} else {
// Subgroup.
return __spirv_SubgroupShuffleINTEL(convertToOpenCLType(x), LocalId);
return convertFromOpenCLTypeFor<T>(
__spirv_SubgroupShuffleINTEL(convertToOpenCLType(x), LocalId));
}
#else
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
Expand Down Expand Up @@ -987,8 +988,8 @@ EnableIfNativeShuffle<T> ShuffleXor(GroupT g, T x, id<1> mask) {
convertToOpenCLType(x), TargetId);
} else {
// Subgroup.
return __spirv_SubgroupShuffleXorINTEL(convertToOpenCLType(x),
static_cast<uint32_t>(mask.get(0)));
return convertFromOpenCLTypeFor<T>(__spirv_SubgroupShuffleXorINTEL(
convertToOpenCLType(x), static_cast<uint32_t>(mask.get(0))));
}
#else
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
Expand Down Expand Up @@ -1035,8 +1036,8 @@ EnableIfNativeShuffle<T> ShuffleDown(GroupT g, T x, uint32_t delta) {
convertToOpenCLType(x), TargetId);
} else {
// Subgroup.
return __spirv_SubgroupShuffleDownINTEL(convertToOpenCLType(x),
convertToOpenCLType(x), delta);
return convertFromOpenCLTypeFor<T>(__spirv_SubgroupShuffleDownINTEL(
convertToOpenCLType(x), convertToOpenCLType(x), delta));
}
#else
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
Expand Down Expand Up @@ -1079,8 +1080,8 @@ EnableIfNativeShuffle<T> ShuffleUp(GroupT g, T x, uint32_t delta) {
convertToOpenCLType(x), TargetId);
} else {
// Subgroup.
return __spirv_SubgroupShuffleUpINTEL(convertToOpenCLType(x),
convertToOpenCLType(x), delta);
return convertFromOpenCLTypeFor<T>(__spirv_SubgroupShuffleUpINTEL(
convertToOpenCLType(x), convertToOpenCLType(x), delta));
}
#else
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
Expand Down
30 changes: 30 additions & 0 deletions sycl/include/sycl/detail/vector_convert.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,36 @@ using ConvertBoolAndByteT =
template <typename DataT, int NumElements>
template <typename convertT, rounding_mode roundingMode>
vec<convertT, NumElements> vec<DataT, NumElements>::convert() const {
#if !__SYCL_USE_LIBSYCL8_VEC_IMPL
auto getValue = [this](int Index) {
using RetType = typename std::conditional_t<
detail::is_byte_v<DataT>, int8_t,
#ifdef __SYCL_DEVICE_ONLY__
typename detail::map_type<
DataT,
#if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
std::byte, /*->*/ std::uint8_t, //
#endif
bool, /*->*/ std::uint8_t, //
sycl::half, /*->*/ sycl::detail::half_impl::StorageT, //
sycl::ext::oneapi::bfloat16,
/*->*/ sycl::ext::oneapi::bfloat16::Bfloat16StorageT, //
char, /*->*/ detail::ConvertToOpenCLType_t<char>, //
DataT, /*->*/ DataT //
>::type
#else
DataT
#endif
>;

#ifdef __SYCL_DEVICE_ONLY__
if constexpr (std::is_same_v<DataT, sycl::ext::oneapi::bfloat16>)
return sycl::bit_cast<RetType>(this->m_Data[Index]);
else
#endif
return static_cast<RetType>(this->m_Data[Index]);
};
#endif
using T = detail::ConvertBoolAndByteT<DataT>;
using R = detail::ConvertBoolAndByteT<convertT>;
using bfloat16 = sycl::ext::oneapi::bfloat16;
Expand Down
7 changes: 7 additions & 0 deletions sycl/include/sycl/vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -423,9 +423,11 @@ class __SYCL_EBO Swizzle
using element_type = DataT;
using value_type = DataT;

#if __SYCL_USE_LIBSYCL8_VEC_IMPL
#ifdef __SYCL_DEVICE_ONLY__
using vector_t = typename vec<DataT, NumElements>::vector_t;
#endif // __SYCL_DEVICE_ONLY__
#endif

Swizzle() = delete;
Swizzle(const Swizzle &) = delete;
Expand Down Expand Up @@ -497,6 +499,7 @@ class __SYCL_EBO vec :

using Base = detail::vec_base<DataT, NumElements>;

#if __SYCL_USE_LIBSYCL8_VEC_IMPL
#ifdef __SYCL_DEVICE_ONLY__
using element_type_for_vector_t = typename detail::map_type<
DataT,
Expand Down Expand Up @@ -541,6 +544,7 @@ class __SYCL_EBO vec :

private:
#endif // __SYCL_DEVICE_ONLY__
#endif

#if __SYCL_USE_LIBSYCL8_VEC_IMPL
template <int... Indexes>
Expand Down Expand Up @@ -618,6 +622,7 @@ class __SYCL_EBO vec :
static constexpr size_t get_size() { return byte_size(); }
static constexpr size_t byte_size() noexcept { return sizeof(Base); }

#if __SYCL_USE_LIBSYCL8_VEC_IMPL
private:
// getValue should be able to operate on different underlying
// types: enum cl_float#N , builtin vector float#N, builtin type float.
Expand All @@ -640,6 +645,8 @@ class __SYCL_EBO vec :
}

public:
#endif

// Out-of-class definition is in `sycl/detail/vector_convert.hpp`
template <typename convertT,
rounding_mode roundingMode = rounding_mode::automatic>
Expand Down
5 changes: 3 additions & 2 deletions sycl/test-e2e/DeviceLib/built-ins/printf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,11 @@ int main() {
sycl::vec<int, 4> v4{5, 6, 7, 8};
#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
// On SPIRV devices, vectors can be printed via native OpenCL types:
using ocl_int4 = sycl::vec<int, 4>::vector_t;
using ocl_int4 = int __attribute__((ext_vector_type(4)));
{
static const CONSTANT char format[] = "%v4hld\n";
ext::oneapi::experimental::printf(format, (ocl_int4)v4);
ext::oneapi::experimental::printf(format,
sycl::bit_cast<ocl_int4>(v4));
}

// However, you are still able to print them by-element:
Expand Down
19 changes: 19 additions & 0 deletions sycl/test/basic_tests/generic_type_traits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,18 +134,36 @@ int main() {
#endif

#ifdef __SYCL_DEVICE_ONLY__
static_assert(
std::is_same_v<d::ConvertToOpenCLType_t<s::vec<s::opencl::cl_int, 2>>,
s::opencl::cl_int __attribute__((ext_vector_type(2)))>);
static_assert(
std::is_same_v<d::ConvertToOpenCLType_t<s::vec<long long, 2>>,
s::opencl::cl_long __attribute__((ext_vector_type(2)))>);
#if __SYCL_USE_LIBSYCL8_VEC_IMPL
static_assert(
std::is_same_v<d::ConvertToOpenCLType_t<s::vec<s::opencl::cl_int, 2>>,
s::vec<s::opencl::cl_int, 2>::vector_t>);
static_assert(std::is_same_v<d::ConvertToOpenCLType_t<s::vec<long long, 2>>,
s::vec<s::opencl::cl_long, 2>::vector_t>);
#endif
static_assert(std::is_same_v<
d::ConvertToOpenCLType_t<s::multi_ptr<
s::opencl::cl_int, s::access::address_space::global_space,
s::access::decorated::yes>>,
s::multi_ptr<s::opencl::cl_int,
s::access::address_space::global_space,
s::access::decorated::yes>::pointer>);
static_assert(
std::is_same_v<
d::ConvertToOpenCLType_t<
s::multi_ptr<s::vec<s::opencl::cl_int, 4>,
s::access::address_space::global_space,
s::access::decorated::yes>>,
s::multi_ptr<s::opencl::cl_int __attribute__((ext_vector_type(4))),
s::access::address_space::global_space,
s::access::decorated::yes>::pointer>);
#if __SYCL_USE_LIBSYCL8_VEC_IMPL
static_assert(
std::is_same_v<d::ConvertToOpenCLType_t<
s::multi_ptr<s::vec<s::opencl::cl_int, 4>,
Expand All @@ -154,6 +172,7 @@ int main() {
s::multi_ptr<s::vec<s::opencl::cl_int, 4>::vector_t,
s::access::address_space::global_space,
s::access::decorated::yes>::pointer>);
#endif
#endif
static_assert(std::is_same_v<d::ConvertToOpenCLType_t<s::half>,
d::half_impl::BIsRepresentationT>);
Expand Down
4 changes: 2 additions & 2 deletions sycl/test/basic_tests/vectors/assign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ static_assert( std::is_assignable_v<vec<float, 1>, half>);
static_assert( std::is_assignable_v<vec<float, 1>, float>);
static_assert( std::is_assignable_v<vec<float, 1>, double>);
#if __SYCL_DEVICE_ONLY__
static_assert( std::is_assignable_v<vec<float, 1>, vec<half, 1>>);
static_assert(EXCEPT_IN_PREVIEW std::is_assignable_v<vec<float, 1>, vec<half, 1>>);
#else
static_assert(EXCEPT_IN_PREVIEW std::is_assignable_v<vec<float, 1>, vec<half, 1>>);
#endif
Expand All @@ -73,7 +73,7 @@ static_assert( std::is_assignable_v<vec<float, 2>, half>);
static_assert( std::is_assignable_v<vec<float, 2>, float>);
static_assert( std::is_assignable_v<vec<float, 2>, double>);
#if __SYCL_DEVICE_ONLY__
static_assert( std::is_assignable_v<vec<float, 2>, vec<half, 1>>);
static_assert(EXCEPT_IN_PREVIEW std::is_assignable_v<vec<float, 2>, vec<half, 1>>);
#else
static_assert( !std::is_assignable_v<vec<float, 2>, vec<half, 1>>);
#endif
Expand Down
2 changes: 2 additions & 0 deletions sycl/test/basic_tests/vectors/swizzle_aliases.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ int main() {
sycl::vec<int, 4> X{1};
static_assert(std::is_same_v<decltype(X.swizzle<0>())::element_type, int>);
static_assert(std::is_same_v<decltype(X.swizzle<0>())::value_type, int>);
#if __SYCL_USE_LIBSYCL8_VEC_IMPL
#ifdef __SYCL_DEVICE_ONLY__
static_assert(std::is_same_v<decltype(X.swizzle<0>())::vector_t,
sycl::vec<int, 1>::vector_t>);
#endif // __SYCL_DEVICE_ONLY__
#endif
});
return 0;
}
Loading