
Commit 07b1bf5

Improve can_consume check
1 parent 87cee67 commit 07b1bf5

1 file changed

bwosqueue/src/lib.rs

Lines changed: 130 additions & 104 deletions
@@ -87,10 +87,17 @@ pub struct Stealer<E, const NUM_BLOCKS: usize, const ENTRIES_PER_BLOCK: usize> {
     queue: Pin<Arc<BwsQueue<E, NUM_BLOCKS, ENTRIES_PER_BLOCK>>>,
 }
 
-/// An iterator over elements of one Block
+/// An iterator over elements of one Block.
+///
+/// The iterator borrows all elements up to `committed` to allow batched
+/// operations on the elements. When the iterator is dropped, the entries
+/// are marked as consumed in one atomic operation.
 pub struct BlockIter<'a, E, const ENTRIES_PER_BLOCK: usize> {
     buffer: &'a [UnsafeCell<MaybeUninit<E>>; ENTRIES_PER_BLOCK],
+    /// Index of the next entry in the buffer to be consumed.
     i: usize,
+    /// Number of committed entries in the buffer.
+    committed: usize,
 }
 
 /// An iterator over elements of one Block of a stealer
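
For context, here is a minimal, self-contained sketch of the batched-consume idea behind the new `committed` field (hypothetical names and a single fixed-size block; this is not the crate's actual API). The committed entries are claimed with one relaxed store, and the returned iterator only yields indices below `committed`, even though the backing buffer is a fixed-size array:

use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};

// Hypothetical single-block model; `MiniBlock` and its fields are
// illustrative stand-ins, not the real bwosqueue types.
const ENTRIES: usize = 8;

struct MiniBlock {
    entries: [u32; ENTRIES],
    committed: AtomicUsize, // entries below this index hold valid data
    consumed: AtomicUsize,  // entries below this index were already handed out
}

struct MiniBlockIter<'a> {
    buffer: &'a [u32; ENTRIES],
    i: usize,
    committed: usize,
}

impl<'a> Iterator for MiniBlockIter<'a> {
    type Item = u32;
    fn next(&mut self) -> Option<u32> {
        // Stop at `committed`, not at the end of the fixed-size buffer.
        if self.i < self.committed {
            let item = self.buffer[self.i];
            self.i += 1;
            Some(item)
        } else {
            None
        }
    }
}

impl MiniBlock {
    /// Claim every committed-but-unconsumed entry with one relaxed store,
    /// then iterate over exactly that range.
    fn dequeue_batch(&self) -> MiniBlockIter<'_> {
        let consumed = self.consumed.load(Relaxed);
        let committed = self.committed.load(Relaxed);
        self.consumed.store(committed, Relaxed);
        MiniBlockIter { buffer: &self.entries, i: consumed, committed }
    }
}

fn main() {
    let blk = MiniBlock {
        entries: [1, 2, 3, 4, 0, 0, 0, 0],
        committed: AtomicUsize::new(4),
        consumed: AtomicUsize::new(1),
    };
    // Yields 2, 3, 4: indices 1..4, bounded by `committed` rather than by the array length.
    assert_eq!(blk.dequeue_batch().collect::<Vec<_>>(), vec![2, 3, 4]);
}
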
@@ -263,107 +270,61 @@ impl<E, const NUM_BLOCKS: usize, const ENTRIES_PER_BLOCK: usize>
     /// Try to dequeue the oldest element in the queue.
     #[inline(always)]
     pub fn dequeue(&mut self) -> Option<E> {
-        loop {
-            // SAFETY: `ccache` always points to a valid `Block` in the queue. We never create a mutable reference
-            // to a Block, so it is safe to construct a shared reference here.
-            let blk = unsafe { &**self.ccache };
 
-            // check if the block is fully consumed already
-            let consumed = blk.consumed.load(Relaxed);
-            let consumed_idx = consumed.raw_index();
+        let (blk, consumed_opt) = self.next_dequeable_entry();
+        *self.ccache = blk;
+        let consumed = consumed_opt?;
 
-            // Fastpath (Block is not fully consumed yet)
-            if let Some(entry_cell) = blk.entries.get(consumed_idx) {
-                // we know the block is not full, but most first check if there is an entry to
-                // dequeue.
-                let committed_idx = blk.committed.load(Relaxed).raw_index();
-                if consumed_idx == committed_idx {
-                    return None;
-                }
+        // To assign to `ccache` we need to end the borrow of the previous `blk` variable.
+        let blk = unsafe { &**self.ccache };
 
-                /* There is an entry to dequeue */
+        // We trust that the correct index is passed to us here.
+        let entry_cell = &blk.entries[consumed.raw_index()];
+        // SAFETY: We know there is an entry to dequeue, so we know the entry is a valid initialized `E`.
+        let item = unsafe { entry_cell.with(|entry| entry.read().assume_init()) };
+        // SAFETY: We already checked that `consumed_idx < ENTRIES_PER_BLOCK`.
+        let new_consumed = unsafe { consumed.index_add_unchecked(1) };
+        blk.consumed.store(new_consumed, Relaxed);
+        #[cfg(feature = "stats")]
+        self.queue.stats.increment_dequeued(1);
+        return Some(item);
 
-                // SAFETY: We know there is an entry to dequeue, so we know the entry is a valid initialized `E`.
-                let item = unsafe { entry_cell.with(|entry| entry.read().assume_init()) };
-                // SAFETY: We already checked that `consumed_idx < ENTRIES_PER_BLOCK`.
-                let new_consumed = unsafe { consumed.index_add_unchecked(1) };
-                blk.consumed.store(new_consumed, Relaxed);
-                #[cfg(feature = "stats")]
-                self.queue.stats.increment_dequeued(1);
-                return Some(item);
-            }
-
-            /* Slow-path */
-
-            /* Consumer head may never pass the Producer head and Consumer/Stealer tail */
-            let nblk = unsafe { &*blk.next() };
-            if self.try_advance_consumer_block(nblk, consumed).is_err() {
-                return None;
-            }
-            /* We advanced to the next block - loop around and try again */
-        }
     }
 
-    /// Tru to dequeue a whole block
+    /// Try to dequeue all remaining committed entries in the current block.
     pub fn dequeue_block(&mut self) -> Option<BlockIter<'_, E, ENTRIES_PER_BLOCK>> {
-        loop {
-            // SAFETY: `ccache` always points to a valid `Block` in the queue. We never create a mutable reference
-            // to a Block, so it is safe to construct a shared reference here.
-            let blk = unsafe { &**self.ccache };
-
-            // check if the block is fully consumed already
-            let consumed = blk.consumed.load(Relaxed);
-            let consumed_idx = consumed.raw_index();
-
-            if consumed_idx < ENTRIES_PER_BLOCK {
-                // for now just return none. We could also return consumed_idx..committed_idx
-                if !(blk.committed.load(Relaxed).index().is_full()) {
-                    return None;
-                }
-
-                // We are claiming the tasks **before** reading them out of the buffer.
-                // This is safe because only the **current** thread is able to push new
-                // tasks.
-                //
-                // There isn't really any need for memory ordering... Relaxed would
-                // work. This is because all tasks are pushed into the queue from the
-                // current thread (or memory has been acquired if the local queue handle
-                // moved).
-                let new_consumed = consumed.set_full();
-                blk.consumed.store(new_consumed, Relaxed);
-                #[cfg(feature = "stats")]
-                self.queue
-                    .stats
-                    .increment_dequeued(new_consumed.raw_index() - consumed_idx);
-
-                // Pre-advance ccache for the next time
-                let nblk = unsafe { &*blk.next() };
-                // We don't care if this fails. The consumer can try again next time.
-                let _ = self.try_advance_consumer_block(nblk, new_consumed);
-
-                return Some(BlockIter {
-                    buffer: &blk.entries,
-                    i: consumed_idx,
-                });
-            }
-
-            /* Slow-path */
-
-            /* Consumer head may never pass the Producer head and Consumer/Stealer tail */
-            let nblk = unsafe { &*blk.next() };
-            if self.try_advance_consumer_block(nblk, consumed).is_err() {
-                return None;
-            }
-
-            /* We advanced to the next block - loop around and try again */
-        }
+        let (blk, consumed_opt) = self.next_dequeable_entry();
+        *self.ccache = blk;
+        let consumed = consumed_opt?;
+
+        // To assign to `ccache` we need to end the borrow of the previous `blk` variable.
+        let blk = unsafe { &**self.ccache };
+
+        let committed = blk.committed.load(Relaxed);
+
+        // We are claiming the tasks **before** reading them out of the buffer.
+        // This is safe because only the **current** thread is able to push new
+        // tasks.
+        //
+        // There isn't really any need for memory ordering... Relaxed would
+        // work. This is because all tasks are pushed into the queue from the
+        // current thread (or memory has been acquired if the local queue handle
+        // moved).
+        blk.consumed.store(committed, Relaxed);
+
+        return Some(BlockIter {
+            buffer: &blk.entries,
+            i: consumed.raw_index(),
+            committed: committed.raw_index()
+        });
     }
-    /// Advance consumer to the next block, unless the producer has not reached the block yet.
-    fn try_advance_consumer_block(
-        &mut self,
+
+    /// Advance consumer to the next block, unless the producer has not reached the block yet.
+    fn can_advance_consumer_block(
+        &self,
         next_block: &Block<E, ENTRIES_PER_BLOCK>,
         curr_consumed: IndexAndVersion<ENTRIES_PER_BLOCK>,
-    ) -> Result<(), ()> {
+    ) -> bool {
         let next_cons_vsn = curr_consumed
             .version()
             .wrapping_add(next_block.is_head() as usize);
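
The rewritten `dequeue` above reduces to the following shape on a single hypothetical block (an illustrative sketch, not the crate's code): a read-only helper reports the next index, `?` turns "nothing to dequeue" into `None`, and the consumption is published with one relaxed store:

use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};

// `OneBlock` and its fields are hypothetical stand-ins for one queue block.
struct OneBlock {
    entries: [u32; 4],
    committed: AtomicUsize, // entries below this index hold valid data
    consumed: AtomicUsize,  // entries below this index were already taken
}

impl OneBlock {
    /// Read-only part: report the index that could be dequeued, if any.
    fn next_entry(&self) -> Option<usize> {
        let consumed = self.consumed.load(Relaxed);
        let committed = self.committed.load(Relaxed);
        // `consumed == committed` means everything committed was already taken.
        if consumed < committed {
            Some(consumed)
        } else {
            None
        }
    }

    fn dequeue(&self) -> Option<u32> {
        let idx = self.next_entry()?; // early return, like `consumed_opt?` above
        let item = self.entries[idx];
        // Publish the consumption with a single relaxed store.
        self.consumed.store(idx + 1, Relaxed);
        Some(item)
    }
}

fn main() {
    let blk = OneBlock {
        entries: [10, 20, 30, 40],
        committed: AtomicUsize::new(2), // only the first two entries are committed
        consumed: AtomicUsize::new(0),
    };
    assert_eq!(blk.dequeue(), Some(10));
    assert_eq!(blk.dequeue(), Some(20));
    assert_eq!(blk.dequeue(), None); // nothing committed beyond index 1
}
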
@@ -373,7 +334,7 @@ impl<E, const NUM_BLOCKS: usize, const ENTRIES_PER_BLOCK: usize>
         // and we must wait.
         let next_reserved_vsn = next_block.reserved.load(Relaxed).version();
         if next_reserved_vsn != next_cons_vsn {
-            return Err(());
+            return false;
         }
 
         /* stop stealers */
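
One way to read the version comparison above, as a stand-alone sketch (hypothetical names, not the crate's types): each block carries a lap/version stamp, and the consumer may only advance into the next block once the producer side has reserved it in the lap the consumer expects:

// Generic illustration of a versioned ring-slot check.
#[derive(Clone, Copy, PartialEq)]
struct Lap(usize);

struct NextBlockState {
    reserved_lap: Lap, // lap in which the producer side reserved this block
}

fn consumer_may_advance(curr_lap: Lap, next_is_ring_head: bool, next: &NextBlockState) -> bool {
    // Crossing the wrap-around point of the ring bumps the lap the consumer expects.
    let expected = Lap(curr_lap.0.wrapping_add(next_is_ring_head as usize));
    // If the producer has not reached this block in the expected lap yet,
    // the consumer must wait instead of overtaking it.
    next.reserved_lap == expected
}

fn main() {
    let next = NextBlockState { reserved_lap: Lap(3) };
    assert!(consumer_may_advance(Lap(3), false, &next));
    assert!(!consumer_may_advance(Lap(3), true, &next)); // producer is still one lap behind
}
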
@@ -394,10 +355,24 @@ impl<E, const NUM_BLOCKS: usize, const ENTRIES_PER_BLOCK: usize>
         /* advance the block and try again */
         // The consumer must skip already reserved entries.
         next_block.consumed.store(reserved_old, Relaxed);
-        *self.ccache = next_block;
-        Ok(())
+        true
     }
 
+    // /// Advance consumer to the next block, unless the producer has not reached the block yet.
+    // fn try_advance_consumer_block(
+    //     &mut self,
+    //     next_block: &Block<E, ENTRIES_PER_BLOCK>,
+    //     curr_consumed: IndexAndVersion<ENTRIES_PER_BLOCK>,
+    // ) -> Result<(), ()> {
+    //     if self.can_advance_consumer_block(next_block, curr_consumed) {
+    //         *self.ccache = next_block;
+    //         Ok(())
+    //     } else {
+    //         Err(())
+    //     }
+    // }
+
+    /// Todo: Ideally we would not have this function.
     pub fn has_stealers(&self) -> bool {
         let curr_spos = self.spos.load(Relaxed);
         // spos increments beyond NUM_BLOCKS to prevent ABA problems.
@@ -419,6 +394,8 @@ impl<E, const NUM_BLOCKS: usize, const ENTRIES_PER_BLOCK: usize>
     /// Check if there is a block available for stealing in the queue.
     ///
    /// Note that stealing may still fail for a number of reasons even if this function returned true
+    /// Todo: the overhead could be reduced if we allow this function to return false in some
+    /// cases when the queue size is low.
     #[cfg(feature = "stats")]
     pub fn has_stealable_block(&self) -> bool {
         let n = self.queue.stats.curr_enqueued();
@@ -430,12 +407,57 @@ impl<E, const NUM_BLOCKS: usize, const ENTRIES_PER_BLOCK: usize>
         n > (committed_idx - consumed_idx)
     }
 
-    /// Check if there are items in the queue available for the consumer.
+    /// `true` if there is at least one entry that can be dequeued.
     ///
-    /// This function may sporadically provide a wrong result.
-    #[cfg(feature = "stats")]
+    /// This function is accurate and not racy, since stealers can not steal from the same block
+    /// as the consumer.
     pub fn can_consume(&self) -> bool {
-        self.queue.stats.curr_enqueued() > 0
+        self.next_dequeable_entry().1.is_some()
+    }
+
+    /// Find the next dequeuable entry without advancing the consumer cache.
+    ///
+    /// To allow this function to be usable without mutable access to `self`, it is left to the caller
+    /// to update `self.ccache` with the returned block.
+    /// If a dequeuable value is found, the second tuple element will be `Some` and contain the index
+    /// of the element in the block.
+    fn next_dequeable_entry(&self) -> (&Block<E, ENTRIES_PER_BLOCK>, Option<IndexAndVersion<ENTRIES_PER_BLOCK>>) {
+        // SAFETY: `ccache` always points to a valid `Block` in the queue. We never create a mutable reference
+        // to a Block, so it is safe to construct a shared reference here.
+        let current_blk_cache = unsafe { &**self.ccache };
+        let mut blk = current_blk_cache;
+        for _ in 0..NUM_BLOCKS {
+            // check if the block is fully consumed already
+            let consumed = blk.consumed.load(Relaxed);
+            let consumed_idx = consumed.raw_index();
+
+            // Fastpath (Block is not fully consumed yet)
+            if consumed_idx < ENTRIES_PER_BLOCK {
+                // we know the block is not full, but we must first check if there is an entry to
+                // dequeue.
+                let committed_idx = blk.committed.load(Relaxed).raw_index();
+                if consumed_idx == committed_idx {
+                    return (blk, None);
+                }
+
+                /* There is an entry to dequeue */
+                return (blk, Some(consumed));
+            }
+
+            /* Slow-path */
+
+            /* Consumer head may never pass the Producer head and Consumer/Stealer tail */
+            let nblk = unsafe { &*blk.next() };
+            if self.can_advance_consumer_block(nblk, consumed) {
+                blk = nblk;
+            } else {
+                return (blk, None);
+            }
+            /* We advanced to the next block - loop around and try again */
+        }
+        // Since there is no concurrent enqueuing and the buffer is bounded, we should reach
+        // one of the exit conditions in at most NUM_BLOCKS iterations.
+        unreachable!()
     }
 }
 
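A simplified stand-in for how the new pieces fit together (the `Ring` type and its fields are hypothetical, not the crate's API): the read-only scan plays the role of `next_dequeable_entry`, `can_consume` reuses it without needing `&mut self`, and the mutable caller is the one that stores the returned block into its cache:

// Hypothetical model: each inner Vec stands in for one block of the ring.
struct Ring {
    blocks: Vec<Vec<u32>>,
    ccache: usize, // index of the cached "current" consumer block
}

impl Ring {
    /// Read-only scan: returns the block the consumer should cache next and,
    /// if present, the next value that could be dequeued from it.
    fn scan(&self) -> (usize, Option<u32>) {
        for offset in 0..self.blocks.len() {
            let idx = (self.ccache + offset) % self.blocks.len();
            if let Some(&value) = self.blocks[idx].first() {
                return (idx, Some(value));
            }
        }
        (self.ccache, None)
    }

    /// Exact readiness check that shares the scan with `dequeue`.
    fn can_consume(&self) -> bool {
        self.scan().1.is_some()
    }

    /// The `&mut self` caller updates the cache with the block the scan returned.
    fn dequeue(&mut self) -> Option<u32> {
        let (idx, found) = self.scan();
        self.ccache = idx;
        found?;
        Some(self.blocks[idx].remove(0))
    }
}

fn main() {
    let mut ring = Ring { blocks: vec![vec![], vec![7, 8]], ccache: 0 };
    assert!(ring.can_consume());
    assert_eq!(ring.dequeue(), Some(7));
    assert_eq!(ring.ccache, 1); // the cache advanced past the empty block
}
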
@@ -604,12 +626,16 @@ impl<'a, E, const ENTRIES_PER_BLOCK: usize> Iterator for BlockIter<'a, E, ENTRIES_PER_BLOCK>
     fn next(&mut self) -> Option<E> {
         let i = self.i;
         self.i += 1;
-        self.buffer.get(i).map(|entry_cell| {
-            entry_cell.with(|entry| {
-                // SAFETY: we claimed the entries
-                unsafe { entry.read().assume_init() }
+        if i < self.committed {
+            self.buffer.get(i).map(|entry_cell| {
+                entry_cell.with(|entry| {
+                    // SAFETY: we claimed the entries
+                    unsafe { entry.read().assume_init() }
+                })
             })
-        })
+        } else {
+            None
+        }
     }
 }
 