@@ -80,6 +80,12 @@ struct __subgroup_radix_sort
         {
             return __dpl_sycl::__local_accessor<_KeyT>(__buf_size, __cgh);
         }
+
+        inline static constexpr auto
+        get_fence()
+        {
+            return __dpl_sycl::__fence_space_local;
+        }
     };
 
     template <typename _KeyT>
@@ -94,6 +100,12 @@ struct __subgroup_radix_sort
         {
             return sycl::accessor(__buf, __cgh, sycl::read_write, __dpl_sycl::__no_init{});
         }
+
+        inline constexpr static auto
+        get_fence()
+        {
+            return __dpl_sycl::__fence_space_global;
+        }
     };
 
     template <typename _ValueT, typename _Wi, typename _Src, typename _Values>
@@ -175,8 +187,9 @@ struct __subgroup_radix_sort
 
                     // copy(move) values construction
                     __block_load<_ValT>(__wi, __src, __values.__v, __n);
+                    // TODO: check if the barrier can be removed
+                    __dpl_sycl::__group_barrier(__it, decltype(__buf_val)::get_fence());
 
-                    __dpl_sycl::__group_barrier(__it);
                     while (true)
                     {
                         uint16_t __indices[__block_size]; // indices for indirect access in the "re-order" phase
@@ -205,7 +218,7 @@ struct __subgroup_radix_sort
                                 __indices[__i] = *__counters[__i];
                                 *__counters[__i] = __indices[__i] + 1;
                             }
-                            __dpl_sycl::__group_barrier(__it);
+                            __dpl_sycl::__group_barrier(__it, decltype(__buf_count)::get_fence());
 
                             // 2. scan phase
                             {
@@ -218,8 +231,8 @@ struct __subgroup_radix_sort
                                 _ONEDPL_PRAGMA_UNROLL
                                 for (uint16_t __i = 1; __i < __bin_count; ++__i)
                                     __bin_sum[__i] = __bin_sum[__i - 1] + __counter_lacc[__wi * __bin_count + __i];
+                                __dpl_sycl::__group_barrier(__it, decltype(__buf_count)::get_fence());
 
-                                __dpl_sycl::__group_barrier(__it);
                                 // exclusive scan local sum
                                 uint16_t __sum_scan = __dpl_sycl::__exclusive_scan_over_group(
                                     __it.get_group(), __bin_sum[__bin_count - 1], __dpl_sycl::__plus<uint16_t>());
@@ -230,7 +243,7 @@ struct __subgroup_radix_sort
 
                                 if (__wi == 0)
                                     __counter_lacc[0] = 0;
-                                __dpl_sycl::__group_barrier(__it);
+                                __dpl_sycl::__group_barrier(__it, decltype(__buf_count)::get_fence());
                             }
 
                             _ONEDPL_PRAGMA_UNROLL
@@ -244,7 +257,7 @@ struct __subgroup_radix_sort
                         __begin_bit += __radix;
 
                         // 3. "re-order" phase
-                        __dpl_sycl::__group_barrier(__it);
+                        __dpl_sycl::__group_barrier(__it, decltype(__buf_val)::get_fence());
                         if (__begin_bit >= __end_bit)
                         {
                             // the last iteration - writing out the result
@@ -268,7 +281,6 @@ struct __subgroup_radix_sort
                                 if (__idx < __n)
                                     __exchange_lacc[__idx].~_ValT();
                             }
-
                             return;
                         }
 
@@ -293,8 +305,7 @@ struct __subgroup_radix_sort
                                 __exchange_lacc[__r] = ::std::move(__values.__v[__i]);
                             }
                         }
-
-                        __dpl_sycl::__group_barrier(__it);
+                        __dpl_sycl::__group_barrier(__it, decltype(__buf_val)::get_fence());
 
                         _ONEDPL_PRAGMA_UNROLL
                         for (uint16_t __i = 0; __i < __block_size; ++__i)
@@ -303,8 +314,7 @@ struct __subgroup_radix_sort
                             if (__idx < __n)
                                 __values.__v[__i] = ::std::move(__exchange_lacc[__idx]);
                         }
-
-                        __dpl_sycl::__group_barrier(__it);
+                        __dpl_sycl::__group_barrier(__it, decltype(__buf_val)::get_fence());
                     }
                 }));
            });
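
The structural change, restated outside the diff: each temporary-storage type now exposes get_fence(), so every __group_barrier names the fence space of the memory it actually synchronizes (local for the SLM path, global for the buffer fallback) instead of relying on the default. A minimal sketch of that dispatch, assuming a SYCL 2020 compiler; storage_slm, storage_global and barrier_for are illustrative names only, not oneDPL identifiers.

// Minimal sketch of the fence-space dispatch pattern the patch introduces.
#include <sycl/sycl.hpp>

// Storage placed in local (shared) memory: barriers guarding it only need a
// local fence.
struct storage_slm
{
    static constexpr auto
    get_fence()
    {
        return sycl::access::fence_space::local_space;
    }
};

// Fallback storage placed in a global-memory buffer: barriers guarding it
// need a global fence.
struct storage_global
{
    static constexpr auto
    get_fence()
    {
        return sycl::access::fence_space::global_space;
    }
};

// The barrier takes its fence space from the storage type it protects,
// mirroring __dpl_sycl::__group_barrier(__it, decltype(__buf_val)::get_fence()).
template <typename Storage>
void
barrier_for(sycl::nd_item<1> it)
{
    it.barrier(Storage::get_fence());
}

On hardware where a local fence is cheaper than the default combined fence, taking the space from the storage type can reduce the cost of the per-iteration barriers without changing the synchronization the algorithm relies on.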