diff --git a/libs/core/algorithms/include/hpx/parallel/algorithms/partial_sort.hpp b/libs/core/algorithms/include/hpx/parallel/algorithms/partial_sort.hpp index e19710ba910..aac2d6ac779 100644 --- a/libs/core/algorithms/include/hpx/parallel/algorithms/partial_sort.hpp +++ b/libs/core/algorithms/include/hpx/parallel/algorithms/partial_sort.hpp @@ -113,6 +113,7 @@ namespace hpx { #include #include #include +#include #include #include #include @@ -531,26 +532,43 @@ namespace hpx::parallel { template - static util::detail::algorithm_result_t parallel( - ExPolicy&& policy, Iter first, Iter middle, Sent last, Comp&& comp, - Proj&& proj) + static decltype(auto) parallel(ExPolicy&& policy, Iter first, + Iter middle, Sent last, Comp&& comp, Proj&& proj) { - using algorithm_result = - util::detail::algorithm_result; + constexpr bool has_scheduler_executor = + hpx::execution_policy_has_scheduler_executor_v; - try + if constexpr (has_scheduler_executor) { - // call the sort routine and return the right type, - // depending on execution policy - return algorithm_result::get(parallel_partial_sort( - HPX_FORWARD(ExPolicy, policy), first, middle, last, - util::compare_projected(comp, proj))); + namespace ex = hpx::execution::experimental; + return ex::just(first, middle, last) | + ex::then([comp = HPX_FORWARD(Comp, comp), + proj = HPX_FORWARD(Proj, proj)]( + Iter first, Iter middle, Iter last) -> Iter { + return sequential_partial_sort(first, middle, last, + util::compare_projected, + std::decay_t>(comp, proj)); + }); } - catch (...) + else { - return algorithm_result::get( - detail::handle_exception::call( - std::current_exception())); + using algorithm_result = + util::detail::algorithm_result; + + try + { + // call the sort routine and return the right type, + // depending on execution policy + return algorithm_result::get(parallel_partial_sort( + HPX_FORWARD(ExPolicy, policy), first, middle, last, + util::compare_projected(comp, proj))); + } + catch (...) + { + return algorithm_result::get( + detail::handle_exception::call( + std::current_exception())); + } } } }; @@ -595,9 +613,9 @@ namespace hpx { > )> // clang-format on - friend parallel::util::detail::algorithm_result_t - tag_fallback_invoke(hpx::partial_sort_t, ExPolicy&& policy, - RandIter first, RandIter middle, RandIter last, Comp comp = Comp()) + friend decltype(auto) tag_fallback_invoke(hpx::partial_sort_t, + ExPolicy&& policy, RandIter first, RandIter middle, RandIter last, + Comp comp = Comp()) { static_assert(hpx::traits::is_random_access_iterator_v, "Requires at least random access iterator."); diff --git a/libs/core/algorithms/tests/unit/algorithms/CMakeLists.txt b/libs/core/algorithms/tests/unit/algorithms/CMakeLists.txt index 559ee830030..22173e38800 100644 --- a/libs/core/algorithms/tests/unit/algorithms/CMakeLists.txt +++ b/libs/core/algorithms/tests/unit/algorithms/CMakeLists.txt @@ -197,6 +197,7 @@ if(HPX_WITH_STDEXEC) mismatch_binary_sender move_sender none_of_sender + partial_sort_sender reduce_sender remove_sender remove_if_sender diff --git a/libs/core/algorithms/tests/unit/algorithms/partial_sort_sender.cpp b/libs/core/algorithms/tests/unit/algorithms/partial_sort_sender.cpp new file mode 100644 index 00000000000..97942665ed9 --- /dev/null +++ b/libs/core/algorithms/tests/unit/algorithms/partial_sort_sender.cpp @@ -0,0 +1,117 @@ +// Copyright (c) 2024 Tobias Wukovitsch +// +// SPDX-License-Identifier: BSL-1.0 +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "test_utils.hpp" + +//////////////////////////////////////////////////////////////////////////// +unsigned int seed = std::random_device{}(); +std::mt19937 gen(seed); +constexpr std::uint64_t SIZE{1007}; + +template +void test_partial_sort_sender( + LnPolicy ln_policy, ExPolicy&& ex_policy, IteratorTag) +{ + static_assert(hpx::is_async_execution_policy_v, + "hpx::is_async_execution_policy_v"); + + using compare_t = std::less; + using base_iterator = std::vector::iterator; + using iterator = test::test_iterator; + + namespace ex = hpx::execution::experimental; + namespace tt = hpx::this_thread::experimental; + using scheduler_t = ex::thread_pool_policy_scheduler; + + std::vector A, B; + A.reserve(SIZE); + B.reserve(SIZE); + + for (std::uint64_t i = 0; i < SIZE; ++i) + { + A.emplace_back(i); + } + std::shuffle(A.begin(), A.end(), gen); + + for (std::uint64_t i = 1; i < SIZE; ++i) + { + B = A; + + auto exec = ex::explicit_scheduler_executor(scheduler_t(ln_policy)); + + tt::sync_wait( + ex::just(iterator(std::begin(B)), iterator(std::begin(B) + i), + iterator(std::end(B)), compare_t{}) | + hpx::partial_sort(ex_policy.on(exec))); + + for (std::uint64_t j = 0; j < i; ++j) + { + HPX_TEST(B[j] == j); + } + } +} + +template +void partial_sort_sender_test() +{ + using namespace hpx::execution; + test_partial_sort_sender(hpx::launch::sync, seq(task), IteratorTag()); + test_partial_sort_sender(hpx::launch::sync, unseq(task), IteratorTag()); + + test_partial_sort_sender(hpx::launch::async, par(task), IteratorTag()); + test_partial_sort_sender( + hpx::launch::async, par_unseq(task), IteratorTag()); +} + +int hpx_main(hpx::program_options::variables_map& vm) +{ + unsigned int seed = (unsigned int) std::time(nullptr); + if (vm.count("seed")) + seed = vm["seed"].as(); + + std::cout << "using seed: " << seed << std::endl; + std::srand(seed); + + partial_sort_sender_test(); + + return hpx::local::finalize(); +} + +int main(int argc, char* argv[]) +{ + // add command line option which controls the random number generator seed + using namespace hpx::program_options; + options_description desc_commandline( + "Usage: " HPX_APPLICATION_STRING " [options]"); + + desc_commandline.add_options()("seed,s", value(), + "the random number generator seed to use for this run"); + + // By default this test should run on all available cores + std::vector const cfg = {"hpx.os_threads=all"}; + + // Initialize and run HPX + hpx::local::init_params init_args; + init_args.desc_cmdline = desc_commandline; + init_args.cfg = cfg; + + HPX_TEST_EQ_MSG(hpx::local::init(hpx_main, argc, argv, init_args), 0, + "HPX main exited with non-zero status"); + + return hpx::util::report_errors(); +}