Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -395,4 +395,4 @@ namespace hpx {
} fill_n{};
} // namespace hpx

#endif // DOXYGEN
#endif // DOXYGEN
3 changes: 3 additions & 0 deletions libs/core/async_cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ set(async_cuda_headers
hpx/async_cuda/get_targets.hpp
hpx/async_cuda/target.hpp
hpx/async_cuda/transform_stream.hpp
hpx/async_cuda/thrust/policy.hpp
hpx/async_cuda/thrust/algorithms.hpp
hpx/async_cuda/thrust/detail/algorithm_map.hpp
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add a CMake option that is enabled only if the user has enabled support for Thrust? We probably also need support for finding Thrust/CCCL in CMake to begin with (we can't assume that Thrust/CCCL is installed in the system directories).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we would only need to check for CCCL, since Thrust and CCCL are distributed together, but yes, this is a valid point. I will do that.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thrust has been distributed with CCCL for quite some time now.

)

# Default location is $HPX_ROOT/libs/async_cuda/include_compatibility
Expand Down
39 changes: 39 additions & 0 deletions libs/core/async_cuda/include/hpx/async_cuda/thrust/algorithms.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/// \file
/// Universal algorithm dispatch for Thrust integration with HPX

#pragma once

#include <hpx/async_cuda/thrust/policy.hpp>
#include <hpx/parallel/algorithms/fill.hpp>
#include <hpx/parallel/algorithms/for_each.hpp>
#include <hpx/functional/tag_invoke.hpp>
#include <hpx/async_cuda/thrust/detail/algorithm_map.hpp>
#include <hpx/concepts/concepts.hpp>
#include <cuda_runtime.h>

#include <iostream>
#include <type_traits>

namespace hpx {
namespace async_cuda {
namespace thrust {

/// Universal \c tag_invoke overload that routes an HPX algorithm tag to the
/// Thrust function registered for it in \c detail::algorithm_map.
///
/// The overload only participates in overload resolution when
///   - \c std::decay_t<ThrustPolicy> satisfies
///     \c is_thrust_execution_policy_v, and
///   - \c detail::algorithm_map<HPXTag>::invoke(policy, args...) is a
///     well-formed expression (checked by the trailing return type).
///
/// \tparam HPXTag       type of the HPX algorithm tag; it is used purely for
///                      selecting the \c algorithm_map specialization
/// \tparam ThrustPolicy execution policy type (forwarding reference)
/// \tparam Args         types of the remaining algorithm arguments
///
/// \param policy execution policy, forwarded to \c algorithm_map::invoke
/// \param args   algorithm arguments, forwarded to \c algorithm_map::invoke
///
/// \returns whatever \c detail::algorithm_map<HPXTag>::invoke returns
template <typename HPXTag, typename ThrustPolicy, typename... Args,
    HPX_CONCEPT_REQUIRES_(
        is_thrust_execution_policy_v<std::decay_t<ThrustPolicy>>
    )>
auto tag_invoke(HPXTag /* tag: only its type matters */,
    ThrustPolicy&& policy, Args&&... args)
    -> decltype(detail::algorithm_map<HPXTag>::invoke(
        std::forward<ThrustPolicy>(policy), std::forward<Args>(args)...))
{
    // The tag value itself carries no state; leaving the parameter unnamed
    // avoids an unused-parameter warning. All work is delegated to the
    // specialization selected by the tag type.
    return detail::algorithm_map<HPXTag>::invoke(
        std::forward<ThrustPolicy>(policy), std::forward<Args>(args)...);
}
} // namespace thrust
} // namespace async_cuda
} // namespace hpx
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
/// \file
/// Template specialization mapping from HPX algorithm tags to Thrust functions
/// This file contains all the HPX tag -> Thrust function mappings, expressed
/// as explicit specializations of detail::algorithm_map.

#pragma once

// HPX algorithm headers for tag types
#include <hpx/parallel/algorithms/fill.hpp>
#include <hpx/parallel/algorithms/copy.hpp>
#include <hpx/parallel/algorithms/transform.hpp>
#include <hpx/parallel/algorithms/for_each.hpp>
#include <hpx/parallel/algorithms/reduce.hpp>
#include <hpx/parallel/algorithms/sort.hpp>
#include <hpx/parallel/algorithms/find.hpp>
#include <hpx/parallel/algorithms/count.hpp>

// Centralized Thrust algorithm headers
#include <thrust/fill.h>
#include <thrust/copy.h>
#include <thrust/transform.h>
#include <thrust/for_each.h>
#include <thrust/reduce.h>
#include <thrust/sort.h>
#include <thrust/find.h>
#include <thrust/count.h>
#include <thrust/unique.h>
#include <thrust/reverse.h>
#include <thrust/scan.h>

#include <type_traits>
#include <utility>

namespace hpx {
namespace async_cuda {
namespace thrust {
namespace detail {


// Primary template of the HPX-tag -> Thrust-function map.
//
// It is declared but deliberately never defined. Only the explicit
// specializations that follow provide `invoke` and `name`, so naming
// algorithm_map<Tag>::invoke for a tag without a specialization is ill-formed.
template<typename HPXTag>
struct algorithm_map; // No definition = compilation error for unmapped algorithms


// Each specialization maps one HPX algorithm tag to its Thrust equivalent
// Pattern: template<> struct algorithm_map<hpx::tag_t> { static invoke(...) }


/// Maps the HPX tag \c hpx::fill_t onto \c ::thrust::fill.
template <>
struct algorithm_map<hpx::fill_t>
{
    /// Calls \c ::thrust::fill(policy.get(), args...).
    ///
    /// The return type is spelled as a trailing \c decltype instead of
    /// \c decltype(auto): with a deduced return type the body has to be
    /// instantiated whenever the call is merely named inside \c decltype,
    /// which turns an ill-formed Thrust call into a hard error rather than
    /// a substitution failure.
    ///
    /// \param policy object whose \c get() result is passed as the first
    ///               argument of the Thrust call
    /// \param args   remaining arguments, perfectly forwarded
    /// \returns the result of the Thrust call (same type as before)
    template <typename Policy, typename... Args>
    static constexpr auto invoke(Policy&& policy, Args&&... args)
        -> decltype(::thrust::fill(policy.get(), std::forward<Args>(args)...))
    {
        return ::thrust::fill(policy.get(), std::forward<Args>(args)...);
    }

    /// Name of the Thrust function this tag is mapped to.
    static constexpr char const* name() { return "thrust::fill"; }
};

/// Maps the HPX tag \c hpx::copy_t onto \c ::thrust::copy.
template <>
struct algorithm_map<hpx::copy_t>
{
    /// Calls \c ::thrust::copy(policy.get(), args...).
    ///
    /// The return type is spelled as a trailing \c decltype instead of
    /// \c decltype(auto): with a deduced return type the body has to be
    /// instantiated whenever the call is merely named inside \c decltype,
    /// which turns an ill-formed Thrust call into a hard error rather than
    /// a substitution failure.
    ///
    /// \param policy object whose \c get() result is passed as the first
    ///               argument of the Thrust call
    /// \param args   remaining arguments, perfectly forwarded
    /// \returns the result of the Thrust call (same type as before)
    template <typename Policy, typename... Args>
    static constexpr auto invoke(Policy&& policy, Args&&... args)
        -> decltype(::thrust::copy(policy.get(), std::forward<Args>(args)...))
    {
        return ::thrust::copy(policy.get(), std::forward<Args>(args)...);
    }

    /// Name of the Thrust function this tag is mapped to.
    static constexpr char const* name() { return "thrust::copy"; }
};

/// Maps the HPX tag \c hpx::transform_t onto \c ::thrust::transform.
template <>
struct algorithm_map<hpx::transform_t>
{
    /// Calls \c ::thrust::transform(policy.get(), args...).
    ///
    /// The return type is spelled as a trailing \c decltype instead of
    /// \c decltype(auto): with a deduced return type the body has to be
    /// instantiated whenever the call is merely named inside \c decltype,
    /// which turns an ill-formed Thrust call into a hard error rather than
    /// a substitution failure.
    ///
    /// \param policy object whose \c get() result is passed as the first
    ///               argument of the Thrust call
    /// \param args   remaining arguments, perfectly forwarded
    /// \returns the result of the Thrust call (same type as before)
    template <typename Policy, typename... Args>
    static constexpr auto invoke(Policy&& policy, Args&&... args)
        -> decltype(
            ::thrust::transform(policy.get(), std::forward<Args>(args)...))
    {
        return ::thrust::transform(policy.get(), std::forward<Args>(args)...);
    }

    /// Name of the Thrust function this tag is mapped to.
    static constexpr char const* name() { return "thrust::transform"; }
};

/// Maps the HPX tag \c hpx::for_each_t onto \c ::thrust::for_each.
template <>
struct algorithm_map<hpx::for_each_t>
{
    /// Calls \c ::thrust::for_each(policy.get(), args...).
    ///
    /// The return type is spelled as a trailing \c decltype instead of
    /// \c decltype(auto): with a deduced return type the body has to be
    /// instantiated whenever the call is merely named inside \c decltype,
    /// which turns an ill-formed Thrust call into a hard error rather than
    /// a substitution failure.
    ///
    /// \param policy object whose \c get() result is passed as the first
    ///               argument of the Thrust call
    /// \param args   remaining arguments, perfectly forwarded
    /// \returns the result of the Thrust call (same type as before)
    template <typename Policy, typename... Args>
    static constexpr auto invoke(Policy&& policy, Args&&... args)
        -> decltype(
            ::thrust::for_each(policy.get(), std::forward<Args>(args)...))
    {
        return ::thrust::for_each(policy.get(), std::forward<Args>(args)...);
    }

    /// Name of the Thrust function this tag is mapped to.
    static constexpr char const* name() { return "thrust::for_each"; }
};

/// Maps the HPX tag \c hpx::reduce_t onto \c ::thrust::reduce.
template <>
struct algorithm_map<hpx::reduce_t>
{
    /// Calls \c ::thrust::reduce(policy.get(), args...).
    ///
    /// The return type is spelled as a trailing \c decltype instead of
    /// \c decltype(auto): with a deduced return type the body has to be
    /// instantiated whenever the call is merely named inside \c decltype,
    /// which turns an ill-formed Thrust call into a hard error rather than
    /// a substitution failure.
    ///
    /// \param policy object whose \c get() result is passed as the first
    ///               argument of the Thrust call
    /// \param args   remaining arguments, perfectly forwarded
    /// \returns the result of the Thrust call (same type as before)
    template <typename Policy, typename... Args>
    static constexpr auto invoke(Policy&& policy, Args&&... args)
        -> decltype(
            ::thrust::reduce(policy.get(), std::forward<Args>(args)...))
    {
        return ::thrust::reduce(policy.get(), std::forward<Args>(args)...);
    }

    /// Name of the Thrust function this tag is mapped to.
    static constexpr char const* name() { return "thrust::reduce"; }
};

/// Maps the HPX tag \c hpx::sort_t onto \c ::thrust::sort.
template <>
struct algorithm_map<hpx::sort_t>
{
    /// Calls \c ::thrust::sort(policy.get(), args...).
    ///
    /// The return type is spelled as a trailing \c decltype instead of
    /// \c decltype(auto): with a deduced return type the body has to be
    /// instantiated whenever the call is merely named inside \c decltype,
    /// which turns an ill-formed Thrust call into a hard error rather than
    /// a substitution failure.
    ///
    /// \param policy object whose \c get() result is passed as the first
    ///               argument of the Thrust call
    /// \param args   remaining arguments, perfectly forwarded
    /// \returns the result of the Thrust call (same type as before)
    template <typename Policy, typename... Args>
    static constexpr auto invoke(Policy&& policy, Args&&... args)
        -> decltype(::thrust::sort(policy.get(), std::forward<Args>(args)...))
    {
        return ::thrust::sort(policy.get(), std::forward<Args>(args)...);
    }

    /// Name of the Thrust function this tag is mapped to.
    static constexpr char const* name() { return "thrust::sort"; }
};

/// Maps the HPX tag \c hpx::find_t onto \c ::thrust::find.
template <>
struct algorithm_map<hpx::find_t>
{
    /// Calls \c ::thrust::find(policy.get(), args...).
    ///
    /// The return type is spelled as a trailing \c decltype instead of
    /// \c decltype(auto): with a deduced return type the body has to be
    /// instantiated whenever the call is merely named inside \c decltype,
    /// which turns an ill-formed Thrust call into a hard error rather than
    /// a substitution failure.
    ///
    /// \param policy object whose \c get() result is passed as the first
    ///               argument of the Thrust call
    /// \param args   remaining arguments, perfectly forwarded
    /// \returns the result of the Thrust call (same type as before)
    template <typename Policy, typename... Args>
    static constexpr auto invoke(Policy&& policy, Args&&... args)
        -> decltype(::thrust::find(policy.get(), std::forward<Args>(args)...))
    {
        return ::thrust::find(policy.get(), std::forward<Args>(args)...);
    }

    /// Name of the Thrust function this tag is mapped to.
    static constexpr char const* name() { return "thrust::find"; }
};

/// Maps the HPX tag \c hpx::count_t onto \c ::thrust::count.
template <>
struct algorithm_map<hpx::count_t>
{
    /// Calls \c ::thrust::count(policy.get(), args...).
    ///
    /// The return type is spelled as a trailing \c decltype instead of
    /// \c decltype(auto): with a deduced return type the body has to be
    /// instantiated whenever the call is merely named inside \c decltype,
    /// which turns an ill-formed Thrust call into a hard error rather than
    /// a substitution failure.
    ///
    /// \param policy object whose \c get() result is passed as the first
    ///               argument of the Thrust call
    /// \param args   remaining arguments, perfectly forwarded
    /// \returns the result of the Thrust call (same type as before)
    template <typename Policy, typename... Args>
    static constexpr auto invoke(Policy&& policy, Args&&... args)
        -> decltype(::thrust::count(policy.get(), std::forward<Args>(args)...))
    {
        return ::thrust::count(policy.get(), std::forward<Args>(args)...);
    }

    /// Name of the Thrust function this tag is mapped to.
    static constexpr char const* name() { return "thrust::count"; }
};


/// Maps the HPX tag \c hpx::unique_t onto \c ::thrust::unique.
///
/// NOTE(review): the HPX include list of this header names fill, copy,
/// transform, for_each, reduce, sort, find and count only; the header that
/// declares \c hpx::unique_t is presumably reached transitively -- confirm,
/// or include it explicitly.
template <>
struct algorithm_map<hpx::unique_t>
{
    /// Calls \c ::thrust::unique(policy.get(), args...).
    ///
    /// The return type is spelled as a trailing \c decltype instead of
    /// \c decltype(auto): with a deduced return type the body has to be
    /// instantiated whenever the call is merely named inside \c decltype,
    /// which turns an ill-formed Thrust call into a hard error rather than
    /// a substitution failure.
    ///
    /// \param policy object whose \c get() result is passed as the first
    ///               argument of the Thrust call
    /// \param args   remaining arguments, perfectly forwarded
    /// \returns the result of the Thrust call (same type as before)
    template <typename Policy, typename... Args>
    static constexpr auto invoke(Policy&& policy, Args&&... args)
        -> decltype(
            ::thrust::unique(policy.get(), std::forward<Args>(args)...))
    {
        return ::thrust::unique(policy.get(), std::forward<Args>(args)...);
    }

    /// Name of the Thrust function this tag is mapped to.
    static constexpr char const* name() { return "thrust::unique"; }
};

/// Maps the HPX tag \c hpx::reverse_t onto \c ::thrust::reverse.
///
/// NOTE(review): the HPX include list of this header names fill, copy,
/// transform, for_each, reduce, sort, find and count only; the header that
/// declares \c hpx::reverse_t is presumably reached transitively -- confirm,
/// or include it explicitly.
template <>
struct algorithm_map<hpx::reverse_t>
{
    /// Calls \c ::thrust::reverse(policy.get(), args...).
    ///
    /// The return type is spelled as a trailing \c decltype instead of
    /// \c decltype(auto): with a deduced return type the body has to be
    /// instantiated whenever the call is merely named inside \c decltype,
    /// which turns an ill-formed Thrust call into a hard error rather than
    /// a substitution failure.
    ///
    /// \param policy object whose \c get() result is passed as the first
    ///               argument of the Thrust call
    /// \param args   remaining arguments, perfectly forwarded
    /// \returns the result of the Thrust call (same type as before)
    template <typename Policy, typename... Args>
    static constexpr auto invoke(Policy&& policy, Args&&... args)
        -> decltype(
            ::thrust::reverse(policy.get(), std::forward<Args>(args)...))
    {
        return ::thrust::reverse(policy.get(), std::forward<Args>(args)...);
    }

    /// Name of the Thrust function this tag is mapped to.
    static constexpr char const* name() { return "thrust::reverse"; }
};

/// Maps the HPX tag \c hpx::inclusive_scan_t onto \c ::thrust::inclusive_scan.
///
/// NOTE(review): the HPX include list of this header names fill, copy,
/// transform, for_each, reduce, sort, find and count only; the header that
/// declares \c hpx::inclusive_scan_t is presumably reached transitively --
/// confirm, or include it explicitly.
///
/// NOTE(review): arguments are forwarded in the caller's order. The
/// std-style inclusive_scan signature places the binary operation before the
/// initial value, while Thrust's overload taking an initial value appears to
/// expect the opposite order -- verify against the Thrust API before relying
/// on the (op, init) form.
template <>
struct algorithm_map<hpx::inclusive_scan_t>
{
    /// Calls \c ::thrust::inclusive_scan(policy.get(), args...).
    ///
    /// The return type is spelled as a trailing \c decltype instead of
    /// \c decltype(auto): with a deduced return type the body has to be
    /// instantiated whenever the call is merely named inside \c decltype,
    /// which turns an ill-formed Thrust call into a hard error rather than
    /// a substitution failure.
    ///
    /// \param policy object whose \c get() result is passed as the first
    ///               argument of the Thrust call
    /// \param args   remaining arguments, perfectly forwarded
    /// \returns the result of the Thrust call (same type as before)
    template <typename Policy, typename... Args>
    static constexpr auto invoke(Policy&& policy, Args&&... args)
        -> decltype(::thrust::inclusive_scan(
            policy.get(), std::forward<Args>(args)...))
    {
        return ::thrust::inclusive_scan(
            policy.get(), std::forward<Args>(args)...);
    }

    /// Name of the Thrust function this tag is mapped to.
    static constexpr char const* name() { return "thrust::inclusive_scan"; }
};

/// Maps the HPX tag \c hpx::exclusive_scan_t onto \c ::thrust::exclusive_scan.
///
/// NOTE(review): the HPX include list of this header names fill, copy,
/// transform, for_each, reduce, sort, find and count only; the header that
/// declares \c hpx::exclusive_scan_t is presumably reached transitively --
/// confirm, or include it explicitly.
template <>
struct algorithm_map<hpx::exclusive_scan_t>
{
    /// Calls \c ::thrust::exclusive_scan(policy.get(), args...).
    ///
    /// The return type is spelled as a trailing \c decltype instead of
    /// \c decltype(auto): with a deduced return type the body has to be
    /// instantiated whenever the call is merely named inside \c decltype,
    /// which turns an ill-formed Thrust call into a hard error rather than
    /// a substitution failure.
    ///
    /// \param policy object whose \c get() result is passed as the first
    ///               argument of the Thrust call
    /// \param args   remaining arguments, perfectly forwarded
    /// \returns the result of the Thrust call (same type as before)
    template <typename Policy, typename... Args>
    static constexpr auto invoke(Policy&& policy, Args&&... args)
        -> decltype(::thrust::exclusive_scan(
            policy.get(), std::forward<Args>(args)...))
    {
        return ::thrust::exclusive_scan(
            policy.get(), std::forward<Args>(args)...);
    }

    /// Name of the Thrust function this tag is mapped to.
    static constexpr char const* name() { return "thrust::exclusive_scan"; }
};


// Detection alias for mapped algorithms.
//
// Names `void` whenever algorithm_map<Tag>::invoke(ExecPolicy, Ts...) is a
// well-formed call expression and is ill-formed otherwise, so it can be
// placed in a SFINAE context to enable or disable an overload. Despite the
// `is_` prefix it is a type alias, not a boolean trait.
//
// NOTE(review): meant for the universal tag_invoke, but the tag_invoke
// visible in this change spells out the decltype itself -- confirm whether
// this alias is still needed.
template <class Tag, class ExecPolicy, class... Ts>
using is_algorithm_mapped = std::void_t<decltype(
    algorithm_map<Tag>::invoke(
        std::declval<ExecPolicy>(), std::declval<Ts>()...))>;

} // namespace detail
} // namespace thrust
} // namespace async_cuda
} // namespace hpx
Loading
Loading