-
-
Notifications
You must be signed in to change notification settings - Fork 491
Integrating thrust into HPX #6744
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 2 commits
9cbfc1b
4885a83
6b6ab66
c81e6db
a5b04f9
9cecd6e
dcf58e2
dfb9673
3eb496e
cf89016
9c853f0
c93a97a
a591b3c
8e2b2a7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -395,4 +395,4 @@ namespace hpx { | |
| } fill_n{}; | ||
| } // namespace hpx | ||
|
|
||
| #endif // DOXYGEN | ||
| #endif // DOXYGEN | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,6 +25,9 @@ set(async_cuda_headers | |
| hpx/async_cuda/get_targets.hpp | ||
| hpx/async_cuda/target.hpp | ||
| hpx/async_cuda/transform_stream.hpp | ||
| hpx/async_cuda/thrust/policy.hpp | ||
| hpx/async_cuda/thrust/algorithms.hpp | ||
| hpx/async_cuda/thrust/detail/algorithm_map.hpp | ||
|
||
| ) | ||
|
|
||
| # Default location is $HPX_ROOT/libs/async_cuda/include_compatibility | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| /// \file | ||
hkaiser marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| /// Universal algorithm dispatch for Thrust integration with HPX | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <hpx/async_cuda/thrust/policy.hpp> | ||
| #include <hpx/parallel/algorithms/fill.hpp> | ||
| #include <hpx/parallel/algorithms/for_each.hpp> | ||
| #include <hpx/functional/tag_invoke.hpp> | ||
| #include <hpx/async_cuda/thrust/detail/algorithm_map.hpp> | ||
| #include <hpx/concepts/concepts.hpp> | ||
| #include <cuda_runtime.h> | ||
|
|
||
| #include <iostream> | ||
adityacodes30 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| #include <type_traits> | ||
|
|
||
| namespace hpx { | ||
| namespace async_cuda { | ||
| namespace thrust { | ||
adityacodes30 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| template<typename HPXTag, typename ThrustPolicy, typename... Args, | ||
| HPX_CONCEPT_REQUIRES_( | ||
| is_thrust_execution_policy_v<std::decay_t<ThrustPolicy>> | ||
adityacodes30 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| )> | ||
| auto tag_invoke(HPXTag tag, ThrustPolicy&& policy, Args&&... args) | ||
| -> decltype(detail::algorithm_map<HPXTag>::invoke(std::forward<ThrustPolicy>(policy), std::forward<Args>(args)...)) { | ||
|
|
||
|
|
||
| // Universal dispatch to the mapped Thrust function | ||
| // This calls detail::algorithm_map<HPXTag>::invoke(policy, args...) | ||
| // which in turn calls the appropriate ::thrust::algorithm(policy.get(), args...) | ||
| return detail::algorithm_map<HPXTag>::invoke( | ||
| std::forward<ThrustPolicy>(policy), | ||
| std::forward<Args>(args)... | ||
| ); | ||
| } | ||
| } // namespace thrust | ||
| } // namespace async_cuda | ||
| } // namespace hpx | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,179 @@ | ||
| /// \file | ||
adityacodes30 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| /// Template specialization mapping from HPX algorithm tags to Thrust functions | ||
| /// This file contains all the HPX tag -> Thrust function mappings using | ||
|
|
||
| #pragma once | ||
|
|
||
| // HPX algorithm headers for tag types | ||
| #include <hpx/parallel/algorithms/fill.hpp> | ||
| #include <hpx/parallel/algorithms/copy.hpp> | ||
| #include <hpx/parallel/algorithms/transform.hpp> | ||
| #include <hpx/parallel/algorithms/for_each.hpp> | ||
| #include <hpx/parallel/algorithms/reduce.hpp> | ||
| #include <hpx/parallel/algorithms/sort.hpp> | ||
| #include <hpx/parallel/algorithms/find.hpp> | ||
| #include <hpx/parallel/algorithms/count.hpp> | ||
|
|
||
| // Centralized Thrust algorithm headers | ||
| #include <thrust/fill.h> | ||
| #include <thrust/copy.h> | ||
| #include <thrust/transform.h> | ||
| #include <thrust/for_each.h> | ||
| #include <thrust/reduce.h> | ||
| #include <thrust/sort.h> | ||
| #include <thrust/find.h> | ||
| #include <thrust/count.h> | ||
| #include <thrust/unique.h> | ||
| #include <thrust/reverse.h> | ||
| #include <thrust/scan.h> | ||
|
|
||
| #include <type_traits> | ||
| #include <utility> | ||
|
|
||
| namespace hpx { | ||
| namespace async_cuda { | ||
| namespace thrust { | ||
| namespace detail { | ||
|
|
||
|
|
||
| template<typename HPXTag> | ||
| struct algorithm_map; // No definition = compilation error for unmapped algorithms | ||
|
|
||
|
|
||
| // Each specialization maps one HPX algorithm tag to its Thrust equivalent | ||
| // Pattern: template<> struct algorithm_map<hpx::tag_t> { static invoke(...) } | ||
|
|
||
|
|
||
| template<> | ||
| struct algorithm_map<hpx::fill_t> { | ||
| template<typename Policy, typename... Args> | ||
| static constexpr decltype(auto) invoke(Policy&& policy, Args&&... args) { | ||
| return ::thrust::fill(policy.get(), std::forward<Args>(args)...); | ||
adityacodes30 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| static constexpr char const* name() { return "thrust::fill"; } | ||
adityacodes30 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| }; | ||
|
|
||
| template<> | ||
| struct algorithm_map<hpx::copy_t> { | ||
| template<typename Policy, typename... Args> | ||
| static constexpr decltype(auto) invoke(Policy&& policy, Args&&... args) { | ||
| return ::thrust::copy(policy.get(), std::forward<Args>(args)...); | ||
| } | ||
|
|
||
| static constexpr char const* name() { return "thrust::copy"; } | ||
| }; | ||
|
|
||
| template<> | ||
| struct algorithm_map<hpx::transform_t> { | ||
| template<typename Policy, typename... Args> | ||
| static constexpr decltype(auto) invoke(Policy&& policy, Args&&... args) { | ||
| return ::thrust::transform(policy.get(), std::forward<Args>(args)...); | ||
| } | ||
|
|
||
| static constexpr char const* name() { return "thrust::transform"; } | ||
| }; | ||
|
|
||
| template<> | ||
| struct algorithm_map<hpx::for_each_t> { | ||
| template<typename Policy, typename... Args> | ||
| static constexpr decltype(auto) invoke(Policy&& policy, Args&&... args) { | ||
| return ::thrust::for_each(policy.get(), std::forward<Args>(args)...); | ||
| } | ||
|
|
||
| static constexpr char const* name() { return "thrust::for_each"; } | ||
| }; | ||
|
|
||
| template<> | ||
| struct algorithm_map<hpx::reduce_t> { | ||
| template<typename Policy, typename... Args> | ||
| static constexpr decltype(auto) invoke(Policy&& policy, Args&&... args) { | ||
| return ::thrust::reduce(policy.get(), std::forward<Args>(args)...); | ||
| } | ||
|
|
||
| static constexpr char const* name() { return "thrust::reduce"; } | ||
| }; | ||
|
|
||
| template<> | ||
| struct algorithm_map<hpx::sort_t> { | ||
| template<typename Policy, typename... Args> | ||
| static constexpr decltype(auto) invoke(Policy&& policy, Args&&... args) { | ||
| return ::thrust::sort(policy.get(), std::forward<Args>(args)...); | ||
| } | ||
|
|
||
| static constexpr char const* name() { return "thrust::sort"; } | ||
| }; | ||
|
|
||
| template<> | ||
| struct algorithm_map<hpx::find_t> { | ||
| template<typename Policy, typename... Args> | ||
| static constexpr decltype(auto) invoke(Policy&& policy, Args&&... args) { | ||
| return ::thrust::find(policy.get(), std::forward<Args>(args)...); | ||
| } | ||
|
|
||
| static constexpr char const* name() { return "thrust::find"; } | ||
| }; | ||
|
|
||
| template<> | ||
| struct algorithm_map<hpx::count_t> { | ||
| template<typename Policy, typename... Args> | ||
| static constexpr decltype(auto) invoke(Policy&& policy, Args&&... args) { | ||
| return ::thrust::count(policy.get(), std::forward<Args>(args)...); | ||
| } | ||
|
|
||
| static constexpr char const* name() { return "thrust::count"; } | ||
| }; | ||
|
|
||
|
|
||
| template<> | ||
| struct algorithm_map<hpx::unique_t> { | ||
| template<typename Policy, typename... Args> | ||
| static constexpr decltype(auto) invoke(Policy&& policy, Args&&... args) { | ||
| return ::thrust::unique(policy.get(), std::forward<Args>(args)...); | ||
| } | ||
|
|
||
| static constexpr char const* name() { return "thrust::unique"; } | ||
| }; | ||
|
|
||
| template<> | ||
| struct algorithm_map<hpx::reverse_t> { | ||
| template<typename Policy, typename... Args> | ||
| static constexpr decltype(auto) invoke(Policy&& policy, Args&&... args) { | ||
| return ::thrust::reverse(policy.get(), std::forward<Args>(args)...); | ||
| } | ||
|
|
||
| static constexpr char const* name() { return "thrust::reverse"; } | ||
| }; | ||
|
|
||
| template<> | ||
| struct algorithm_map<hpx::inclusive_scan_t> { | ||
| template<typename Policy, typename... Args> | ||
| static constexpr decltype(auto) invoke(Policy&& policy, Args&&... args) { | ||
| return ::thrust::inclusive_scan(policy.get(), std::forward<Args>(args)...); | ||
| } | ||
|
|
||
| static constexpr char const* name() { return "thrust::inclusive_scan"; } | ||
| }; | ||
|
|
||
| template<> | ||
| struct algorithm_map<hpx::exclusive_scan_t> { | ||
| template<typename Policy, typename... Args> | ||
| static constexpr decltype(auto) invoke(Policy&& policy, Args&&... args) { | ||
| return ::thrust::exclusive_scan(policy.get(), std::forward<Args>(args)...); | ||
| } | ||
|
|
||
| static constexpr char const* name() { return "thrust::exclusive_scan"; } | ||
| }; | ||
|
|
||
|
|
||
| // SFINAE HELPER - Check if algorithm is mapped at compile time | ||
| // This is used in the universal tag_invoke to enable/disable the overload | ||
| template<typename HPXTag, typename Policy, typename... Args> | ||
| using is_algorithm_mapped = std::void_t< | ||
| decltype(algorithm_map<HPXTag>::invoke(std::declval<Policy>(), std::declval<Args>()...)) | ||
| >; | ||
|
|
||
| } // namespace detail | ||
| } // namespace thrust | ||
| } // namespace async_cuda | ||
| } // namespace hpx | ||
Uh oh!
There was an error while loading. Please reload this page.