Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 33 additions & 38 deletions libcudacxx/include/cuda/__mdspan/host_device_accessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@

#include <cuda/__driver/driver_api.h>
#include <cuda/__memory/address_space.h>
#include <cuda/__memory/is_pointer_accessible.h>
#include <cuda/std/__concepts/concept_macros.h>
#include <cuda/std/__cstddef/types.h>
#include <cuda/std/__iterator/concepts.h>
#include <cuda/std/__memory/pointer_traits.h>
#include <cuda/std/__type_traits/always_false.h>
#include <cuda/std/__type_traits/is_constant_evaluated.h>
#include <cuda/std/__type_traits/is_constructible.h>
#include <cuda/std/__type_traits/is_convertible.h>
#include <cuda/std/__type_traits/is_default_constructible.h>
Expand All @@ -35,8 +37,6 @@
#include <cuda/std/__type_traits/is_nothrow_default_constructible.h>
#include <cuda/std/__utility/declval.h>
#include <cuda/std/__utility/move.h>
#include <cuda/std/cassert>
#include <cuda/std/cstddef>

#include <cuda/std/__cccl/prologue.h>

Expand Down Expand Up @@ -105,21 +105,13 @@ class __host_accessor : public _Accessor
noexcept(::cuda::std::declval<_Accessor>().offset(::cuda::std::declval<__data_handle_type>(), 0));

#if !_CCCL_COMPILER(NVRTC)
[[nodiscard]] _CCCL_HOST_API static constexpr bool
__is_host_accessible_pointer([[maybe_unused]] __data_handle_type __p) noexcept
[[nodiscard]]
_CCCL_HOST_API static constexpr bool __is_host_accessible_pointer([[maybe_unused]] __data_handle_type __p) noexcept
{
# if _CCCL_HAS_CTK()
if constexpr (::cuda::std::contiguous_iterator<__data_handle_type>)
{
_CCCL_IF_NOT_CONSTEVAL_DEFAULT
{
auto __p1 = ::cuda::std::to_address(__p);
::CUmemorytype __type{};
const auto __status =
::cuda::__driver::__pointerGetAttributeNoThrow<::CU_POINTER_ATTRIBUTE_MEMORY_TYPE>(__type, __p1);
return (__status != ::cudaSuccess) || __type == ::CU_MEMORYTYPE_HOST;
}
return true;
return ::cuda::is_host_accessible(::cuda::std::to_address(__p));
}
else
# endif // _CCCL_HAS_CTK()
Expand Down Expand Up @@ -242,29 +234,26 @@ class __device_accessor : public _Accessor
static constexpr bool __is_offset_noexcept =
noexcept(::cuda::std::declval<_Accessor>().offset(::cuda::std::declval<__data_handle_type>(), 0));

[[nodiscard]] _CCCL_API static constexpr bool
[[nodiscard]] _CCCL_API static bool
__is_device_accessible_pointer_from_host([[maybe_unused]] __data_handle_type __p) noexcept
{
#if _CCCL_HAS_CTK()
#if _CCCL_HAS_CTK() && !_CCCL_COMPILER(NVRTC)
if constexpr (::cuda::std::contiguous_iterator<__data_handle_type>)
{
auto __p1 = ::cuda::std::to_address(__p);
::CUmemorytype __type{};
const auto __status =
::cuda::__driver::__pointerGetAttributeNoThrow<::CU_POINTER_ATTRIBUTE_MEMORY_TYPE>(__type, __p1);
return (__status != ::cudaSuccess) || __type == ::CU_MEMORYTYPE_DEVICE;
static const auto __dev_id = static_cast<int>(::cuda::__driver::__ctxGetDevice());
return ::cuda::is_device_accessible(::cuda::std::to_address(__p), ::cuda::device_ref{__dev_id});
}
else
#endif // _CCCL_HAS_CTK()
#endif // _CCCL_HAS_CTK() && !_CCCL_COMPILER(NVRTC)
{
return true; // cannot be verified
}
}

#if _CCCL_DEVICE_COMPILATION()

[[nodiscard]] _CCCL_HIDE_FROM_ABI _CCCL_DEVICE static constexpr bool
__is_device_accessible_pointer_from_device(__data_handle_type __p) noexcept
[[nodiscard]] _CCCL_HIDE_FROM_ABI
_CCCL_DEVICE static constexpr bool __is_device_accessible_pointer_from_device(__data_handle_type __p) noexcept
{
return ::cuda::device::is_address_from(__p, ::cuda::device::address_space::global)
|| ::cuda::device::is_address_from(__p, ::cuda::device::address_space::shared)
Expand All @@ -278,8 +267,12 @@ class __device_accessor : public _Accessor

_CCCL_API static constexpr void __check_device_pointer([[maybe_unused]] __data_handle_type __p) noexcept
{
NV_IF_TARGET(NV_IS_HOST,
(_CCCL_ASSERT(__is_device_accessible_pointer_from_host(__p), "The pointer is not device accessible");))
if (!::cuda::std::__cccl_default_is_constant_evaluated())
{
NV_IF_TARGET(
NV_IS_HOST,
(_CCCL_ASSERT(__is_device_accessible_pointer_from_host(__p), "The pointer is not device accessible");))
}
}

public:
Expand Down Expand Up @@ -357,10 +350,13 @@ class __device_accessor : public _Accessor

_CCCL_API constexpr reference access(data_handle_type __p, size_t __i) const noexcept(__is_access_noexcept)
{
NV_IF_ELSE_TARGET(
NV_IS_DEVICE,
(_CCCL_ASSERT(__is_device_accessible_pointer_from_device(__p), "The pointer is not device accessible");),
(_CCCL_VERIFY(false, "cuda::device_accessor cannot be used in HOST code");))
if (!::cuda::std::__cccl_default_is_constant_evaluated())
{
NV_IF_ELSE_TARGET(
NV_IS_DEVICE,
(_CCCL_ASSERT(__is_device_accessible_pointer_from_device(__p), "The pointer is not device accessible");),
(_CCCL_VERIFY(false, "cuda::device_accessor cannot be used in HOST code");))
}
return _Accessor::access(__p, __i);
}

Expand All @@ -372,7 +368,10 @@ class __device_accessor : public _Accessor

[[nodiscard]] _CCCL_API constexpr bool __detectably_invalid(data_handle_type __p, size_t) const noexcept
{
NV_IF_ELSE_TARGET(NV_IS_HOST, (return __is_device_accessible_pointer_from_host(__p);), (return false;))
if (!::cuda::std::__cccl_default_is_constant_evaluated())
{
NV_IF_ELSE_TARGET(NV_IS_HOST, (return __is_device_accessible_pointer_from_host(__p);), (return false;))
}
}
};

Expand All @@ -396,17 +395,13 @@ class __managed_accessor : public _Accessor

[[nodiscard]] _CCCL_API static constexpr bool __is_managed_pointer([[maybe_unused]] __data_handle_type __p) noexcept
{
#if _CCCL_HAS_CTK()
#if _CCCL_HAS_CTK() && !_CCCL_COMPILER(NVRTC)
if constexpr (::cuda::std::contiguous_iterator<__data_handle_type>)
{
const auto __p1 = ::cuda::std::to_address(__p);
bool __is_managed{};
const auto __status =
::cuda::__driver::__pointerGetAttributeNoThrow<::CU_POINTER_ATTRIBUTE_IS_MANAGED>(__is_managed, __p1);
return (__status != ::cudaSuccess) || __is_managed;
return ::cuda::is_managed(::cuda::std::to_address(__p));
}
else
#endif // _CCCL_HAS_CTK()
#endif // _CCCL_HAS_CTK() && !_CCCL_COMPILER(NVRTC)
{
return true; // cannot be verified
}
Expand Down
Loading