-
-
Notifications
You must be signed in to change notification settings - Fork 491
Feature/hierarchical collectives #6668
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
f9e165f
7bab42e
38d3b49
ef6f3d3
1a83e5d
390273b
9d0e89e
a20a467
fc2c307
633e52d
a0e8b58
bed46b0
dc16e12
0cbc5b7
0d63554
3bad0f0
eda96af
ea58a7b
0443894
4a887d9
1b4027e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -392,6 +392,79 @@ namespace hpx::collectives { | |||||||
| this_site, generation, root_site), | ||||||||
| this_site); | ||||||||
| } | ||||||||
|
|
||||||||
| template <typename T> | ||||||||
| hpx::future<std::decay_t<T>> broadcast_to_hierarchically( | ||||||||
| std::vector<std::tuple<communicator,int>> communicators, | ||||||||
| T&& local_result, | ||||||||
| this_site_arg this_site = this_site_arg(), | ||||||||
| generation_arg generation = generation_arg(), | ||||||||
| root_site_arg root_site = root_site_arg(), | ||||||||
| int arity = 2) | ||||||||
| { | ||||||||
| if (this_site == static_cast<std::size_t>(-1)) | ||||||||
| { | ||||||||
| this_site = agas::get_locality_id(); | ||||||||
| } | ||||||||
| if (generation == 0) | ||||||||
| { | ||||||||
| return hpx::make_exceptional_future<T>(HPX_GET_EXCEPTION( | ||||||||
| hpx::error::bad_parameter, "hpx::collectives::scatter_to", | ||||||||
| "the generation number shouldn't be zero")); | ||||||||
| } | ||||||||
|
|
||||||||
| communicator current_communicator = std::get<0>(communicators[0]); | ||||||||
| int current_site = std::get<1>(communicators[0]); | ||||||||
| if (this_site == root_site) | ||||||||
| { | ||||||||
| T current_local_result = std::move(local_result); | ||||||||
| for (int i = 0; i < communicators.size()-1;i++) | ||||||||
| { | ||||||||
| current_communicator = std::get<0>(communicators[i]); | ||||||||
| current_local_result = broadcast_to(current_communicator, std::move(current_local_result), generation, this_site_arg(0)).get(); | ||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We do have a synchronous overload of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please use |
||||||||
| } | ||||||||
| current_communicator = std::get<0>(communicators[communicators.size()-1]); | ||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| return broadcast_to(current_communicator, std::move(current_local_result), generation, this_site_arg(0)); | ||||||||
| } | ||||||||
| } | ||||||||
|
|
||||||||
| template <typename T> | ||||||||
| hpx::future<T> broadcast_from_hierarchically( | ||||||||
| std::vector<std::tuple<communicator,int>> communicators, | ||||||||
| this_site_arg this_site = this_site_arg(), | ||||||||
| generation_arg generation = generation_arg(), | ||||||||
| root_site_arg root_site = root_site_arg(), | ||||||||
| int arity = 2) | ||||||||
| { | ||||||||
| if (this_site == static_cast<std::size_t>(-1)) | ||||||||
| { | ||||||||
| this_site = agas::get_locality_id(); | ||||||||
| } | ||||||||
| if (generation == 0) | ||||||||
| { | ||||||||
| return hpx::make_exceptional_future<T>(HPX_GET_EXCEPTION( | ||||||||
| hpx::error::bad_parameter, "hpx::collectives::scatter_to", | ||||||||
| "the generation number shouldn't be zero")); | ||||||||
| } | ||||||||
|
|
||||||||
| communicator current_communicator = std::get<0>(communicators[0]); | ||||||||
| int current_site = std::get<1>(communicators[0]); | ||||||||
|
Comment on lines
+450
to
+451
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| if (this_site != root_site && communicators.size()>1) | ||||||||
| { | ||||||||
| T current_local_result = broadcast_from<T>(current_communicator, generation, this_site_arg(current_site)).get(); | ||||||||
| for (int i = 1; i < communicators.size()-1; i++) | ||||||||
| { | ||||||||
| current_communicator = std::get<0>(communicators[i]); | ||||||||
| current_local_result = broadcast_to(current_communicator, std::move(current_local_result), generation, this_site_arg(0)).get(); | ||||||||
| } | ||||||||
| current_communicator = std::get<0>(communicators[communicators.size()-1]); | ||||||||
| return broadcast_to(current_communicator, std::move(current_local_result), generation, this_site_arg(0)); | ||||||||
| } | ||||||||
| else if (this_site != root_site ) | ||||||||
| { | ||||||||
| return broadcast_from<T>(current_communicator, generation, this_site_arg(current_site)); | ||||||||
| } | ||||||||
| } | ||||||||
| } // namespace hpx::collectives | ||||||||
|
|
||||||||
| //////////////////////////////////////////////////////////////////////////////// | ||||||||
|
|
||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -100,6 +100,39 @@ namespace hpx { namespace collectives { | |
| num_sites_arg num_sites, this_site_arg this_site, | ||
| generation_arg generation = generation_arg(), | ||
| root_site_arg root_site = root_site_arg()); | ||
|
|
||
| /// Create a new communicator object usable with any collective operation | ||
| /// | ||
| /// This functions creates a new communicator object that can be called in | ||
| /// order to pre-allocate a communicator object usable with multiple | ||
| /// invocations of any of the collective operations (such as \a all_gather, | ||
| /// \a all_reduce, \a all_to_all, \a broadcast, etc.). | ||
| /// | ||
| /// \param basename The base name identifying the collective operation | ||
| /// \param num_sites The number of participating sites (default: all | ||
| /// localities). | ||
| /// \param this_site The sequence number of this invocation (usually | ||
| /// the locality id). This value is optional and | ||
| /// defaults to whatever hpx::get_locality_id() returns. | ||
| /// \param generation The generational counter identifying the sequence | ||
| /// number of the collective operation performed on the | ||
| /// given base name. This is optional and needs to be | ||
| /// supplied only if the collective operation on the | ||
| /// given base name has to be performed more than once. | ||
| /// \param root_site The site that is responsible for creating the | ||
| /// collective support object. This value is optional | ||
| /// and defaults to '0' (zero). | ||
| /// | ||
| /// \returns This function returns a new communicator object usable | ||
| /// with the collective operation. | ||
| /// | ||
| communicator create_hierarchical_communicator(char const* basename, | ||
| num_sites_arg num_sites = num_sites_arg(), | ||
| this_site_arg this_site = this_site_arg(), | ||
| generation_arg generation = generation_arg(), | ||
| root_site_arg root_site = root_site_arg(), | ||
| arity_arg arity = arity_arg()); | ||
|
|
||
| }} | ||
| // clang-format on | ||
|
|
||
|
|
@@ -197,6 +230,14 @@ namespace hpx::collectives { | |
| generation_arg generation = generation_arg(), | ||
| root_site_arg root_site = root_site_arg()); | ||
|
|
||
| HPX_EXPORT std::vector<std::tuple<communicator,int>> create_hierarchical_communicator(char const* basename, | ||
| num_sites_arg num_sites = num_sites_arg(), | ||
| this_site_arg this_site = this_site_arg(), | ||
| generation_arg generation = generation_arg(), | ||
| root_site_arg root_site = root_site_arg(), | ||
| int arity = 4); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please introduce a new |
||
| std::vector<std::tuple<communicator,int>> recursively_fill_communicators(std::vector<std::tuple<communicator,int>> communicators, int left, int right, std::string basename, int arity, int max_depth, int this_site, int num_sites, generation_arg generation); | ||
|
|
||
| } // namespace hpx::collectives | ||
|
|
||
| #endif // !HPX_COMPUTE_DEVICE_CODE | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -408,6 +408,114 @@ namespace hpx::collectives { | |
| this_site, generation, root_site), | ||
| HPX_FORWARD(T, local_result), this_site); | ||
| } | ||
|
|
||
| template <typename T> | ||
| std::vector<T> flatten_vector(communicator fid, std::vector<std::vector<T>>&& dimensional_vector) | ||
| { | ||
| std::vector<std::vector<T>> non_flat_vector = std::move(dimensional_vector); | ||
| std::vector<T> current_local_result; | ||
| size_t totalSize = 0; | ||
| for (const auto& row : non_flat_vector) { | ||
| totalSize += row.size(); | ||
| } | ||
| current_local_result.reserve(totalSize); | ||
| for (auto& row : non_flat_vector) { | ||
| current_local_result.insert(current_local_result.end(), std::make_move_iterator(row.begin()), std::make_move_iterator(row.end())); | ||
| } | ||
| return std::move(current_local_result); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Never use |
||
| } | ||
|
|
||
|
|
||
| template <typename T> | ||
| std::vector<T> gather_here_hierarchically( | ||
| std::vector<std::tuple<communicator,int>> communicators, | ||
| T&& local_result, | ||
| this_site_arg this_site = this_site_arg(), | ||
| generation_arg generation = generation_arg(), | ||
| root_site_arg root_site = root_site_arg(), | ||
| int arity = 2) | ||
| { | ||
| if (this_site == static_cast<std::size_t>(-1)) | ||
| { | ||
| this_site = agas::get_locality_id(); | ||
| } | ||
| /* if (generation == 0) | ||
| { | ||
| return hpx::make_exceptional_future<T>(HPX_GET_EXCEPTION( | ||
| hpx::error::bad_parameter, "hpx::collectives::scatter_to", | ||
| "the generation number shouldn't be zero")); | ||
| } */ | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please remove code that has been commented out. |
||
|
|
||
| communicator current_communicator = std::get<0>(communicators[0]); | ||
| int current_site = std::get<1>(communicators[0]); | ||
| if (this_site == root_site) | ||
| { | ||
| std::vector<T> current_local_result; | ||
| current_local_result.push_back(std::move(local_result)); | ||
| for (int i = communicators.size()-1; i > 0;i--) | ||
| { | ||
| current_communicator = std::get<0>(communicators[i]); | ||
| std::vector<std::vector<T>> in_between_result = gather_here(current_communicator, std::move(current_local_result), generation, this_site_arg(0)).get(); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We do have a synchronous version of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Moving from |
||
| size_t totalSize = 0; | ||
| for (const auto& row : in_between_result) { | ||
| totalSize += row.size(); | ||
| } | ||
| current_local_result.reserve(totalSize); | ||
| for (auto& row : in_between_result) { | ||
| current_local_result.insert(current_local_result.end(), std::make_move_iterator(row.begin()), std::make_move_iterator(row.end())); | ||
| } | ||
| } | ||
| current_communicator = std::get<0>(communicators[0]); | ||
| hpx::future<std::vector<std::vector<std::decay_t<T>>>> dimensional_vector = gather_here(current_communicator, std::move(current_local_result), generation, this_site_arg(0)); | ||
| return flatten_vector<T>(current_communicator, std::move(dimensional_vector.get())); | ||
| } | ||
| } | ||
|
|
||
|
|
||
|
|
||
| template <typename T> | ||
| hpx::future<void> gather_there_hierarchically( | ||
| std::vector<std::tuple<communicator,int>> communicators, | ||
| T&& local_result, | ||
| this_site_arg this_site = this_site_arg(), | ||
| generation_arg generation = generation_arg(), | ||
| root_site_arg root_site = root_site_arg(), | ||
| int arity = 2) | ||
| { | ||
| if (this_site == static_cast<std::size_t>(-1)) | ||
| { | ||
| this_site = agas::get_locality_id(); | ||
| } | ||
| if (generation == 0) | ||
| { | ||
| return hpx::make_exceptional_future<T>(HPX_GET_EXCEPTION( | ||
| hpx::error::bad_parameter, "hpx::collectives::scatter_to", | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The function name needs to be corrected. |
||
| "the generation number shouldn't be zero")); | ||
| } | ||
|
|
||
| communicator current_communicator = std::get<0>(communicators[0]); | ||
| int current_site = std::get<1>(communicators[0]); | ||
| if (this_site != root_site) | ||
| { | ||
| std::vector<T> current_local_result; | ||
| current_local_result.push_back(std::move(local_result)); | ||
| for (int i = communicators.size()-1; i > 0;i--) | ||
| { | ||
| current_communicator = std::get<0>(communicators[i]); | ||
| std::vector<std::vector<T>> in_between_result = gather_here(current_communicator, std::move(current_local_result), generation, this_site_arg(0)).get(); | ||
| size_t totalSize = 0; | ||
| for (const auto& row : in_between_result) { | ||
| totalSize += row.size(); | ||
| } | ||
| current_local_result.reserve(totalSize); | ||
| for (auto& row : in_between_result) { | ||
| current_local_result.insert(current_local_result.end(), std::make_move_iterator(row.begin()), std::make_move_iterator(row.end())); | ||
| } | ||
| } | ||
| current_communicator = std::get<0>(communicators[0]); | ||
| return gather_there(current_communicator, std::move(current_local_result), generation, this_site_arg(current_site)); | ||
| } | ||
| } | ||
| } // namespace hpx::collectives | ||
|
|
||
| /////////////////////////////////////////////////////////////////////////////// | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The function name needs to be corrected