diff --git a/clang/runtime/dpct-rt/include/dpct/util.hpp b/clang/runtime/dpct-rt/include/dpct/util.hpp
index b5b775ba770e..6472a03f1368 100644
--- a/clang/runtime/dpct-rt/include/dpct/util.hpp
+++ b/clang/runtime/dpct-rt/include/dpct/util.hpp
@@ -532,6 +532,67 @@ T shift_sub_group_left(unsigned int member_mask,
   throw sycl::exception(sycl::errc::runtime, "Masked version of shift_sub_group_left not "
                         "supported on host device.");
 #endif // __SYCL_DEVICE_ONLY__
 }
+
+/// Number of barrier counters kept per work-group; barrier ids index this
+/// range, so valid ids are 0 .. MAX_BARRIER_ID - 1.
+inline constexpr unsigned int MAX_BARRIER_ID = 16;
+
+/// Arrival counters, one per barrier id.
+// NOTE(review): the template arguments of work_group_static and atomic_ref were
+// missing from the reviewed text. They are reconstructed here from the uint32_t
+// compare-exchange operand and the bar_counters[barrier_id] indexing; the
+// memory scope and address space are assumptions - confirm before merging.
+// NOTE(review): nothing in this patch initializes the counters, yet the
+// arrival logic below expects an idle counter to read zero - confirm.
+inline sycl::ext::oneapi::experimental::work_group_static<
+    uint32_t[MAX_BARRIER_ID]>
+    bar_counters;
+
+/// Records the calling work-item's arrival at barrier \p barrier_id without
+/// waiting for the other participants.
+/// \param barrier_id Index into bar_counters; must be below MAX_BARRIER_ID (not
+/// checked here).
+/// \param thread_count Number of arrivals that complete the barrier.
+inline void barrier_arrive_aligned(unsigned int barrier_id,
+                                   unsigned int thread_count) {
+  sycl::atomic_ref<uint32_t, sycl::memory_order::seq_cst,
+                   sycl::memory_scope::work_group,
+                   sycl::access::address_space::local_space>
+      count_ref(bar_counters[barrier_id]);
+  // The first arrival arms an idle (zero) counter with the expected count;
+  // later arrivals find it non-zero and leave it untouched.
+  uint32_t idle_value = 0;
+  count_ref.compare_exchange_strong(idle_value, thread_count);
+  --count_ref;
+}
+
+/// Arrives at barrier \p barrier_id and holds the calling sub-group until the
+/// barrier's counter has dropped to zero.
+/// \param barrier_id Index into bar_counters; must be below MAX_BARRIER_ID (not
+/// checked here).
+/// \param thread_count Number of arrivals that complete the barrier.
+inline void barrier_sync_aligned(unsigned int barrier_id,
+                                 unsigned int thread_count) {
+  barrier_arrive_aligned(barrier_id, thread_count);
+
+  sycl::atomic_ref<uint32_t, sycl::memory_order::seq_cst,
+                   sycl::memory_scope::work_group,
+                   sycl::access::address_space::local_space>
+      count_ref(bar_counters[barrier_id]);
+
+  auto item = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+  auto sg = item.get_sub_group();
+
+  // Only the sub-group leader polls the counter; the other work-items of the
+  // sub-group wait for it at the sub-group barrier below.
+  if (sg.leader()) {
+    while (count_ref.load(sycl::memory_order::seq_cst) != 0) {
+      // Spin until every expected arrival has decremented the counter.
+    }
+  }
+
+  sycl::group_barrier(sg);
+}
 
 /// Masked version of shift_sub_group_right, which execute masked sub-group