diff --git a/libs/core/algorithms/include/hpx/parallel/algorithms/nth_element.hpp b/libs/core/algorithms/include/hpx/parallel/algorithms/nth_element.hpp index 6ecb9ae73e6..27f9dbf97c6 100644 --- a/libs/core/algorithms/include/hpx/parallel/algorithms/nth_element.hpp +++ b/libs/core/algorithms/include/hpx/parallel/algorithms/nth_element.hpp @@ -136,6 +136,7 @@ namespace hpx { #include #include #include +#include #include #include #include @@ -262,84 +263,149 @@ namespace hpx::parallel { template - static util::detail::algorithm_result_t - parallel(ExPolicy&& policy, RandomIt first, RandomIt nth, Sent last, - Pred&& pred, Proj&& proj) + static decltype(auto) parallel(ExPolicy&& policy, RandomIt first, + RandomIt nth, Sent last, Pred&& pred, Proj&& proj) { using value_type = typename std::iterator_traits::value_type; - RandomIt partition_iter, return_last; + constexpr bool has_scheduler_executor = + hpx::execution_policy_has_scheduler_executor_v; - if (first == last) + if constexpr (has_scheduler_executor) { - return util::detail::algorithm_result::get(HPX_MOVE(first)); - } + namespace ex = hpx::execution::experimental; + return ex::just(first, nth, last) | + ex::then([policy = HPX_FORWARD(ExPolicy, policy), + pred = HPX_FORWARD(Pred, pred), + proj = HPX_FORWARD(Proj, proj)]( + RandomIt first, RandomIt nth, + RandomIt last) mutable -> RandomIt { + auto last_iter = + detail::advance_to_sentinel(first, last); + + while (first != last_iter) + { + detail::pivot9(first, last_iter, pred); + + RandomIt partition_iter = + hpx::parallel::detail::partition() + .sequential( + hpx::execution::seq, first + 1, + last_iter, + [val = HPX_INVOKE(proj, *first), + &pred](value_type const& elem) { + return HPX_INVOKE( + pred, elem, val); + }, + proj); + + --partition_iter; + + // swap first element and partitionIter + // (ending element of first group) +#if defined(HPX_HAVE_CXX20_STD_RANGES_ITER_SWAP) + std::ranges::iter_swap(first, partition_iter); +#else + std::iter_swap(first, partition_iter); +#endif - if (nth == last) - { - return util::detail::algorithm_result::get(HPX_MOVE(nth)); + // if nth element < partitioned index, + // it lies in [first, partitionIter) + if (partition_iter < nth) + { + first = partition_iter + 1; + } + // else it lies in [partitionIter + 1, last) + else if (partition_iter > nth) + { + last_iter = partition_iter; + } + // partitionIter == nth + else + { + break; + } + } + + return last_iter; + }); } - - try + else { - RandomIt last_iter = - detail::advance_to_sentinel(first, last); - return_last = last_iter; + RandomIt partition_iter, return_last; + + if (first == last) + { + return util::detail::algorithm_result::get(HPX_MOVE(first)); + } + + if (nth == last) + { + return util::detail::algorithm_result::get(HPX_MOVE(nth)); + } - while (first != last_iter) + try { - detail::pivot9(first, last_iter, pred); - - partition_iter = - hpx::parallel::detail::partition().call( - policy(hpx::execution::non_task), first + 1, - last_iter, - [val = HPX_INVOKE(proj, *first), &pred]( - value_type const& elem) { - return HPX_INVOKE(pred, elem, val); - }, - proj); - - --partition_iter; - - // swap first element and partitionIter - // (ending element of first group) + RandomIt last_iter = + detail::advance_to_sentinel(first, last); + return_last = last_iter; + + while (first != last_iter) + { + detail::pivot9(first, last_iter, pred); + + partition_iter = + hpx::parallel::detail::partition() + .call( + policy(hpx::execution::non_task), + first + 1, last_iter, + [val = HPX_INVOKE(proj, *first), &pred]( + value_type const& elem) { + return HPX_INVOKE(pred, elem, val); + }, + proj); + + --partition_iter; + + // swap first element and partitionIter + // (ending element of first group) #if defined(HPX_HAVE_CXX20_STD_RANGES_ITER_SWAP) - std::ranges::iter_swap(first, partition_iter); + std::ranges::iter_swap(first, partition_iter); #else - std::iter_swap(first, partition_iter); + std::iter_swap(first, partition_iter); #endif - // if nth element < partitioned index, - // it lies in [first, partitionIter) - if (partition_iter < nth) - { - first = partition_iter + 1; - } - // else it lies in [partitionIter + 1, last) - else if (partition_iter > nth) - { - last_iter = partition_iter; - } - // partitionIter == nth - else - { - break; + // if nth element < partitioned index, + // it lies in [first, partitionIter) + if (partition_iter < nth) + { + first = partition_iter + 1; + } + // else it lies in [partitionIter + 1, last) + else if (partition_iter > nth) + { + last_iter = partition_iter; + } + // partitionIter == nth + else + { + break; + } } } - } - catch (...) - { + catch (...) + { + return util::detail::algorithm_result::get(detail::handle_exception::call(std::current_exception())); + } + return util::detail::algorithm_result::get(detail::handle_exception::call(std::current_exception())); + RandomIt>::get(HPX_MOVE(return_last)); } - - return util::detail::algorithm_result::get( - HPX_MOVE(return_last)); } }; /// \endcond @@ -387,9 +453,9 @@ namespace hpx { > )> // clang-format on - friend parallel::util::detail::algorithm_result_t - tag_fallback_invoke(hpx::nth_element_t, ExPolicy&& policy, - RandomIt first, RandomIt nth, RandomIt last, Pred pred = Pred()) + friend decltype(auto) tag_fallback_invoke(hpx::nth_element_t, + ExPolicy&& policy, RandomIt first, RandomIt nth, RandomIt last, + Pred pred = Pred()) { static_assert(hpx::traits::is_random_access_iterator_v, "Requires at least random iterator."); diff --git a/libs/core/algorithms/tests/unit/algorithms/CMakeLists.txt b/libs/core/algorithms/tests/unit/algorithms/CMakeLists.txt index 559ee830030..cb42d17b039 100644 --- a/libs/core/algorithms/tests/unit/algorithms/CMakeLists.txt +++ b/libs/core/algorithms/tests/unit/algorithms/CMakeLists.txt @@ -196,6 +196,7 @@ if(HPX_WITH_STDEXEC) mismatch_sender mismatch_binary_sender move_sender + nth_element_sender none_of_sender reduce_sender remove_sender diff --git a/libs/core/algorithms/tests/unit/algorithms/nth_element_sender.cpp b/libs/core/algorithms/tests/unit/algorithms/nth_element_sender.cpp new file mode 100644 index 00000000000..b35361689da --- /dev/null +++ b/libs/core/algorithms/tests/unit/algorithms/nth_element_sender.cpp @@ -0,0 +1,117 @@ +// Copyright (c) 2024 Tobias Wukovitsch +// +// SPDX-License-Identifier: BSL-1.0 +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "test_utils.hpp" + +//////////////////////////////////////////////////////////////////////////// +int seed = std::random_device{}(); +std::mt19937 gen(seed); + +constexpr std::size_t SIZE{10007}; + +template +void test_nth_element_sender( + LnPolicy ln_policy, ExPolicy&& ex_policy, IteratorTag) +{ + static_assert(hpx::is_async_execution_policy_v, + "hpx::is_async_execution_policy_v"); + + using base_iterator = std::vector::iterator; + using iterator = test::test_iterator; + + namespace ex = hpx::execution::experimental; + namespace tt = hpx::this_thread::experimental; + using scheduler_t = ex::thread_pool_policy_scheduler; + + std::vector c(SIZE); + std::generate( + std::begin(c), std::end(c), []() { return std::rand() % SIZE; }); + std::vector d = c; + + auto rand_index = std::rand() % SIZE; + + auto exec = ex::explicit_scheduler_executor(scheduler_t(ln_policy)); + + tt::sync_wait( + ex::just(iterator(std::begin(c)), iterator(std::begin(c) + rand_index), + iterator(std::end(c))) | + hpx::nth_element(ex_policy.on(exec))); + + std::nth_element(std::begin(d), std::begin(d) + rand_index, std::end(d)); + + HPX_TEST(*(std::begin(c) + rand_index) == *(std::begin(d) + rand_index)); + + for (size_t k = 0; k < rand_index; k++) + { + HPX_TEST(c[k] <= c[rand_index]); + } + + for (size_t k = rand_index + 1; k < SIZE; k++) + { + HPX_TEST(c[k] >= c[rand_index]); + } +} + +template +void nth_element_sender_test() +{ + using namespace hpx::execution; + test_nth_element_sender(hpx::launch::sync, seq(task), IteratorTag()); + test_nth_element_sender(hpx::launch::sync, unseq(task), IteratorTag()); + + test_nth_element_sender(hpx::launch::async, par(task), IteratorTag()); + test_nth_element_sender(hpx::launch::async, par_unseq(task), IteratorTag()); +} + +int hpx_main(hpx::program_options::variables_map& vm) +{ + unsigned int seed = (unsigned int) std::time(nullptr); + if (vm.count("seed")) + seed = vm["seed"].as(); + + std::cout << "using seed: " << seed << std::endl; + std::srand(seed); + + nth_element_sender_test(); + + return hpx::local::finalize(); +} + +int main(int argc, char* argv[]) +{ + // add command line option which controls the random number generator seed + using namespace hpx::program_options; + options_description desc_commandline( + "Usage: " HPX_APPLICATION_STRING " [options]"); + + desc_commandline.add_options()("seed,s", value(), + "the random number generator seed to use for this run"); + + // By default this test should run on all available cores + std::vector const cfg = {"hpx.os_threads=all"}; + + // Initialize and run HPX + hpx::local::init_params init_args; + init_args.desc_cmdline = desc_commandline; + init_args.cfg = cfg; + + HPX_TEST_EQ_MSG(hpx::local::init(hpx_main, argc, argv, init_args), 0, + "HPX main exited with non-zero status"); + + return hpx::util::report_errors(); +}