diff --git a/include/oneapi/dpl/pstl/glue_algorithm_ranges_impl.h b/include/oneapi/dpl/pstl/glue_algorithm_ranges_impl.h index 77a4875ccde..94fb8455ec8 100644 --- a/include/oneapi/dpl/pstl/glue_algorithm_ranges_impl.h +++ b/include/oneapi/dpl/pstl/glue_algorithm_ranges_impl.h @@ -48,6 +48,53 @@ namespace dpl namespace ranges { +namespace __internal +{ + +template +concept __is_subscriptable = requires(_T&& __a) { __a[0]; }; + +template +concept __is_not_subscriptable = !__is_subscriptable<_T>; + +template +concept __is_sizeable = requires(_T&& __a) { __a.size(); }; + +template +concept __is_empty_method = requires(_T&& __a) { __a.empty(); }; + +template +struct _WrapperRAR: public _R +{ + template + _WrapperRAR(_Base&& __r): _R(std::forward<_Base>(__r)) {} + decltype(auto) operator[](auto __i) { return _R::begin()[__i]; } + decltype(auto) operator[](auto __i) const { return _R::begin()[__i]; } + + std::enable_if_t, std::ranges::range_size_t<_R>> + size() const { return this->_R::end() - this->_R::begin(); } + + std::enable_if_t, bool> + empty() const { return this->_R::end() - this->_R::begin() <= 0; } +}; + +template <__is_not_subscriptable _R> +constexpr decltype (auto) +__get_r(_R&& __r) +{ + using _T = std::remove_reference_t<_R>; + return _WrapperRAR<_T>(std::forward<_R>(__r)); +} + +template <__is_subscriptable _R> +constexpr decltype (auto) +__get_r(_R&& __r) +{ + return std::forward<_R>(__r); +} + +} //__internal + // [alg.foreach] namespace __internal @@ -64,7 +111,7 @@ struct __for_each_fn { const auto __dispatch_tag = oneapi::dpl::__ranges::__select_backend(__exec); oneapi::dpl::__internal::__ranges::__pattern_for_each( - __dispatch_tag, std::forward<_ExecutionPolicy>(__exec), __r, __f, __proj); + __dispatch_tag, std::forward<_ExecutionPolicy>(__exec), __get_r(__r), __f, __proj); return {std::ranges::begin(__r) + std::ranges::size(__r)}; } @@ -149,8 +196,11 @@ struct __find_if_fn operator()(_ExecutionPolicy&& __exec, _R&& __r, _Pred __pred, _Proj __proj = {}) const { const auto __dispatch_tag = oneapi::dpl::__ranges::__select_backend(__exec); - return oneapi::dpl::__internal::__ranges::__pattern_find_if(__dispatch_tag, - std::forward<_ExecutionPolicy>(__exec), std::forward<_R>(__r), __pred, __proj); + + auto __ra = __get_r(__r); + auto __res = oneapi::dpl::__internal::__ranges::__pattern_find_if(__dispatch_tag, + std::forward<_ExecutionPolicy>(__exec), __ra, __pred, __proj) - __ra.begin(); + return __r.begin() + __res; } }; //__find_if_fn } //__internal @@ -213,7 +263,7 @@ struct __any_of_fn { const auto __dispatch_tag = oneapi::dpl::__ranges::__select_backend(__exec); return oneapi::dpl::__internal::__ranges::__pattern_any_of(__dispatch_tag, - std::forward<_ExecutionPolicy>(__exec), std::forward<_R>(__r), __pred, __proj); + std::forward<_ExecutionPolicy>(__exec), __get_r(std::forward<_R>(__r)), __pred, __proj); } }; //__any_of_fn } //__internal diff --git a/test/parallel_api/ranges/std_ranges_for_each.pass.cpp b/test/parallel_api/ranges/std_ranges_for_each.pass.cpp index 91f7f17c791..a998d365896 100644 --- a/test/parallel_api/ranges/std_ranges_for_each.pass.cpp +++ b/test/parallel_api/ranges/std_ranges_for_each.pass.cpp @@ -32,6 +32,7 @@ main() test_range_algo<1>{}(dpl_ranges::for_each, for_each_checker, f_mutuable, proj_mutuable); test_range_algo<2, P2>{}(dpl_ranges::for_each, for_each_checker, f_mutuable, &P2::x); test_range_algo<3, P2>{}(dpl_ranges::for_each, for_each_checker, f_mutuable, &P2::proj); + #endif //_ENABLE_STD_RANGES_TESTING return TestUtils::done(_ENABLE_STD_RANGES_TESTING); diff --git a/test/parallel_api/ranges/std_ranges_test.h b/test/parallel_api/ranges/std_ranges_test.h index b022dea294c..9f5b35acedc 100644 --- a/test/parallel_api/ranges/std_ranges_test.h +++ b/test/parallel_api/ranges/std_ranges_test.h @@ -144,6 +144,28 @@ template static constexpr bool is_range().begin())>> = true; +//a random access range, but without operator[] and size() method +template +struct RangeRA +{ + RangeRA(R r): m_r(r) {} + R m_r; + auto begin() { return m_r.begin(); } + auto end() { return m_r.end(); } + auto begin() const { return m_r.begin(); } + auto end() const { return m_r.end(); } +}; + +struct __range_ra_fn +{ + template + RangeRA + operator()(R r) + { + return RangeRA(r); + } +} __range_ra_wr; + template struct test { @@ -536,6 +558,7 @@ struct test_range_algo test, mode>{max_n}(host_policies(), algo, checker, subrange_view, std::identity{}, args...); test, mode>{max_n}(host_policies(), algo, checker, std::views::all, std::identity{}, args...); test, mode>{max_n}(host_policies(), algo, checker, std::views::all, std::identity{}, args...); + test, mode>{max_n}(host_policies(), algo, checker, __range_ra_wr, __range_ra_wr, args...); #if TEST_CPP20_SPAN_PRESENT test, mode>{max_n}(host_policies(), algo, checker, span_view, std::identity{}, args...); test, mode>{max_n}(host_policies(), algo, checker, std::views::all, std::identity{}, args...); @@ -551,6 +574,7 @@ struct test_range_algo { test, mode>{max_n}(dpcpp_policy(), algo, checker, subrange_view, subrange_view, args...); test, mode>{max_n}(dpcpp_policy(), algo, checker, std::identity{}, std::identity{}, args...); + test, mode>{max_n}(dpcpp_policy(), algo, checker, __range_ra_wr, __range_ra_wr, args...); #if TEST_CPP20_SPAN_PRESENT test, mode>{max_n}(dpcpp_policy(), algo, checker, span_view, subrange_view, args...); test, mode>{max_n}(dpcpp_policy(), algo, checker, std::identity{}, std::identity{}, args...);