
Commit b197719

Swatinem authored and Amanieu committed
Do not early return on null bucket_ptr
Buckets are allocated on demand based on `Thread::bucket`. This means that when only threads with a high `id` (and thus a high `bucket`) write entries into the `ThreadLocal`, only the higher `buckets` are allocated, and the lower buckets stay `null`. Thus we must not early-return when encountering a `null` bucket.
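
To make the failure mode concrete, here is a minimal sketch (not the crate's code) of the bucket math from src/thread_id.rs below, using `usize::BITS` in place of the crate's `POINTER_WIDTH`: a thread id maps to the bucket whose power-of-two range contains it, so a single high id populates only a high bucket while every lower bucket stays `null`.

    fn bucket_for(id: usize) -> usize {
        // bucket = POINTER_WIDTH - (id + 1).leading_zeros() - 1
        usize::BITS as usize - ((id + 1).leading_zeros() as usize) - 1
    }

    fn main() {
        assert_eq!(bucket_for(0), 0); // id 0      -> bucket 0 (size 1)
        assert_eq!(bucket_for(2), 1); // ids 1..=2 -> bucket 1 (size 2)
        // The id used by the new test lands in bucket 10, so an iterator
        // or Drop that stops at the first null bucket would never reach it.
        assert_eq!(bucket_for(1234), 10);
    }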
1 parent b285630 commit b197719

File tree: 2 files changed, +48 -26 lines


src/lib.rs

Lines changed: 47 additions & 25 deletions
@@ -143,7 +143,7 @@ impl<T: Send> Drop for ThreadLocal<T> {
             let this_bucket_size = 1 << i;
 
             if bucket_ptr.is_null() {
-                break;
+                continue;
             }
 
             unsafe { deallocate_bucket(bucket_ptr, this_bucket_size) };
@@ -205,7 +205,7 @@ impl<T: Send> ThreadLocal<T> {
             return Ok(val);
         }
 
-        Ok(self.insert(create()?))
+        Ok(self.insert(thread, create()?))
     }
 
     fn get_inner(&self, thread: Thread) -> Option<&T> {
@@ -226,8 +226,7 @@ impl<T: Send> ThreadLocal<T> {
     }
 
     #[cold]
-    fn insert(&self, data: T) -> &T {
-        let thread = thread_id::get();
+    fn insert(&self, thread: Thread, data: T) -> &T {
         let bucket_atomic_ptr = unsafe { self.buckets.get_unchecked(thread.bucket) };
         let bucket_ptr: *const _ = bucket_atomic_ptr.load(Ordering::Acquire);
 
@@ -372,16 +371,14 @@ impl RawIter {
             let bucket = unsafe { thread_local.buckets.get_unchecked(self.bucket) };
             let bucket = bucket.load(Ordering::Acquire);
 
-            if bucket.is_null() {
-                return None;
-            }
-
-            while self.index < self.bucket_size {
-                let entry = unsafe { &*bucket.add(self.index) };
-                self.index += 1;
-                if entry.present.load(Ordering::Acquire) {
-                    self.yielded += 1;
-                    return Some(unsafe { &*(&*entry.value.get()).as_ptr() });
+            if !bucket.is_null() {
+                while self.index < self.bucket_size {
+                    let entry = unsafe { &*bucket.add(self.index) };
+                    self.index += 1;
+                    if entry.present.load(Ordering::Acquire) {
+                        self.yielded += 1;
+                        return Some(unsafe { &*(&*entry.value.get()).as_ptr() });
+                    }
                 }
             }
 
@@ -401,16 +398,14 @@ impl RawIter {
             let bucket = unsafe { thread_local.buckets.get_unchecked_mut(self.bucket) };
             let bucket = *bucket.get_mut();
 
-            if bucket.is_null() {
-                return None;
-            }
-
-            while self.index < self.bucket_size {
-                let entry = unsafe { &mut *bucket.add(self.index) };
-                self.index += 1;
-                if *entry.present.get_mut() {
-                    self.yielded += 1;
-                    return Some(entry);
+            if !bucket.is_null() {
+                while self.index < self.bucket_size {
+                    let entry = unsafe { &mut *bucket.add(self.index) };
+                    self.index += 1;
+                    if *entry.present.get_mut() {
+                        self.yielded += 1;
+                        return Some(entry);
+                    }
                 }
             }
 
@@ -525,7 +520,8 @@ unsafe fn deallocate_bucket<T>(bucket: *mut Entry<T>, size: usize) {
 
 #[cfg(test)]
 mod tests {
-    use super::ThreadLocal;
+    use super::*;
+
     use std::cell::RefCell;
     use std::sync::atomic::AtomicUsize;
     use std::sync::atomic::Ordering::Relaxed;
@@ -627,6 +623,32 @@ mod tests {
         assert_eq!(dropped.load(Relaxed), 1);
     }
 
+    #[test]
+    fn test_earlyreturn_buckets() {
+        struct Dropped(Arc<AtomicUsize>);
+        impl Drop for Dropped {
+            fn drop(&mut self) {
+                self.0.fetch_add(1, Relaxed);
+            }
+        }
+        let dropped = Arc::new(AtomicUsize::new(0));
+
+        // We use a high `id` here to guarantee that a lazily allocated bucket somewhere in the middle is used.
+        // Neither iteration nor `Drop` may early-return on the `null` buckets below it.
+        let thread = Thread::new(1234);
+        assert!(thread.bucket > 1);
+
+        let mut local = ThreadLocal::new();
+        local.insert(thread, Dropped(dropped.clone()));
+
+        let item = local.iter().next().unwrap();
+        assert_eq!(item.0.load(Relaxed), 0);
+        let item = local.iter_mut().next().unwrap();
+        assert_eq!(item.0.load(Relaxed), 0);
+        drop(local);
+        assert_eq!(dropped.load(Relaxed), 1);
+    }
+
     #[test]
     fn is_sync() {
         fn foo<T: Sync>() {}
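
A side note on the `Drop` fix above: a `null` bucket only means "no thread id in this bucket's range ever inserted", not "end of storage", so cleanup has to keep scanning. The following is a self-contained, hypothetical stand-in (simplified types, not the crate's actual `Drop` impl) showing why `continue` is correct where `break` would leak the higher buckets:

    use std::sync::atomic::{AtomicPtr, Ordering};

    // Bucket i holds 1 << i entries; any bucket may be null.
    fn drop_buckets(buckets: &mut [AtomicPtr<u32>]) {
        for (i, bucket) in buckets.iter_mut().enumerate() {
            let bucket_ptr = *bucket.get_mut();
            if bucket_ptr.is_null() {
                continue; // never allocated; later buckets may still be live
            }
            let size = 1 << i;
            // Reconstruct and drop the allocation (sketch only).
            drop(unsafe { Vec::from_raw_parts(bucket_ptr, size, size) });
        }
    }

    fn main() {
        // Simulate "only a high-id thread ever inserted": bucket 2 is
        // allocated, buckets 0 and 1 are null. With `break` instead of
        // `continue`, bucket 2 would never be freed.
        let mut buckets: [AtomicPtr<u32>; 3] =
            std::array::from_fn(|_| AtomicPtr::new(std::ptr::null_mut()));
        let high = vec![0u32; 4].into_boxed_slice();
        buckets[2].store(Box::leak(high).as_mut_ptr(), Ordering::Release);
        drop_buckets(&mut buckets);
    }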

src/thread_id.rs

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ pub(crate) struct Thread {
     pub(crate) index: usize,
 }
 impl Thread {
-    fn new(id: usize) -> Self {
+    pub(crate) fn new(id: usize) -> Self {
         let bucket = usize::from(POINTER_WIDTH) - ((id + 1).leading_zeros() as usize) - 1;
         let bucket_size = 1 << bucket;
         let index = id - (bucket_size - 1);
