@@ -303,12 +303,6 @@ impl<T> Shard<T, [Arc<Mutex<Option<T>>>]> {
303303 Mask ( !locked_set & connected_mask)
304304 }
305305
306- /// Choose the first index that is unlocked with bit `connected`
307- #[ inline]
308- fn next_unlocked ( & self , connected : bool ) -> Option < ConnectionIndex > {
309- self . unlocked_mask ( connected) . next ( )
310- }
311-
312306 async fn acquire ( self : & Arc < Self > , connected : bool ) -> SlotGuard < T > {
313307 // Attempt an unfair acquire first, before we modify the waitlist.
314308 if let Some ( locked) = self . try_acquire ( connected) {
@@ -323,18 +317,34 @@ impl<T> Shard<T, [Arc<Mutex<Option<T>>>]> {
323317
324318 event_listener:: listener!( event_to_listen => listener) ;
325319
326- // We need to check again after creating the event listener,
327- // because in the meantime, a concurrent task may have seen that there were no listeners
328- // and just unlocked its connection.
329- if let Some ( locked) = self . try_acquire ( connected) {
330- return locked;
320+ let mut listener = pin ! ( listener) ;
321+
322+ loop {
323+ // We need to check again after creating the event listener,
324+ // because in the meantime, a concurrent task may have seen that there were no listeners
325+ // and just unlocked its connection.
326+ match rt:: timeout ( NON_LOCAL_ACQUIRE_DELAY , listener. as_mut ( ) ) . await {
327+ Ok ( slot) => return slot,
328+ Err ( _) => {
329+ if let Some ( slot) = self . try_acquire ( connected) {
330+ return slot;
331+ }
332+ }
333+ }
331334 }
332-
333- listener. await
334335 }
335336
336337 fn try_acquire ( self : & Arc < Self > , connected : bool ) -> Option < SlotGuard < T > > {
337- self . try_lock ( self . next_unlocked ( connected) ?)
338+ // If `locked_set` is constantly changing, don't loop forever.
339+ for index in self . unlocked_mask ( connected) {
340+ if let Some ( slot) = self . try_lock ( index) {
341+ return Some ( slot) ;
342+ }
343+
344+ std:: hint:: spin_loop ( ) ;
345+ }
346+
347+ None
338348 }
339349
340350 fn try_lock ( self : & Arc < Self > , index : ConnectionIndex ) -> Option < SlotGuard < T > > {
@@ -353,7 +363,7 @@ impl<T> Shard<T, [Arc<Mutex<Option<T>>>]> {
353363 }
354364
355365 fn iter_min_connections ( self : & Arc < Self > ) -> impl Iterator < Item = DisconnectedSlot < T > > + ' _ {
356- ( 0 .. self . connections . len ( ) )
366+ self . unlocked_mask ( false )
357367 . filter_map ( |index| {
358368 let slot = self . try_lock ( index) ?;
359369
@@ -493,7 +503,7 @@ impl<T> DisconnectedSlot<T> {
493503 & self . 0 . shard . leaked_set ,
494504 self . 0 . index ,
495505 true ,
496- Ordering :: Release ,
506+ Ordering :: AcqRel ,
497507 ) ;
498508
499509 self . 0 . shard . leak_event . notify ( usize:: MAX . tag ( self . 0 . index ) ) ;
@@ -627,7 +637,7 @@ impl<T> Drop for SlotGuard<T> {
627637 // but then fail to lock the mutex for it.
628638 drop ( locked) ;
629639
630- atomic_set ( & self . shard . locked_set , self . index , false , Ordering :: Release ) ;
640+ atomic_set ( & self . shard . locked_set , self . index , false , Ordering :: AcqRel ) ;
631641 }
632642}
633643
@@ -737,7 +747,7 @@ impl Iterator for Mask {
737747
738748#[ cfg( test) ]
739749mod tests {
740- use super :: { Params , MAX_SHARD_SIZE } ;
750+ use super :: { Mask , Params , MAX_SHARD_SIZE } ;
741751
742752 #[ test]
743753 fn test_params ( ) {
@@ -762,4 +772,27 @@ mod tests {
762772 }
763773 }
764774 }
775+
776+ #[ test]
777+ fn test_mask ( ) {
778+ let inputs: & [ ( usize , & [ usize ] ) ] = & [
779+ ( 0b0 , & [ ] ) ,
780+ ( 0b1 , & [ 0 ] ) ,
781+ ( 0b11 , & [ 0 , 1 ] ) ,
782+ ( 0b111 , & [ 0 , 1 , 2 ] ) ,
783+ ( 0b1000 , & [ 3 ] ) ,
784+ ( 0b1001 , & [ 0 , 3 ] ) ,
785+ ( 0b1001001 , & [ 0 , 3 , 6 ] ) ,
786+ ] ;
787+
788+ for ( mask, expected_indices) in inputs {
789+ let actual_indices = Mask ( * mask) . collect :: < Vec < _ > > ( ) ;
790+
791+ assert_eq ! (
792+ actual_indices[ ..] ,
793+ expected_indices[ ..] ,
794+ "invalid mask: {mask:b}"
795+ ) ;
796+ }
797+ }
765798}
0 commit comments