Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 93 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ mod unreachable;
#[allow(deprecated)]
pub use cached::{CachedIntoIter, CachedIterMut, CachedThreadLocal};

use std::cell::UnsafeCell;
use std::{cell::UnsafeCell, thread::AccessError};
use std::fmt;
use std::iter::FusedIterator;
use std::mem;
Expand Down Expand Up @@ -219,6 +219,50 @@ impl<T: Send> ThreadLocal<T> {
}
}

/// Returns the element for the current thread, if it exists.
///
/// If the thread key has been destroyed (which may happen if this is called
/// in a destructor), this function will return an [`AccessError`].
pub fn try_get(&self) -> Result<Option<&T>, AccessError> {
let thread = thread_id::try_get()?;
Ok(self.get_inner(thread))
}

/// Returns the element for the current thread, or creates it if it doesn't
/// exist.
///
/// If the thread key has been destroyed (which may happen if this is called
/// in a destructor), this function will return an [`AccessError`].
pub fn try_get_or<F>(&self, create: F) -> Result<&T, AccessError>
where
F: FnOnce() -> T,
{
unsafe {
self.try_get_or_try(|| Ok::<T, ()>(create()))
.map(|r| r.unchecked_unwrap_ok())
}
}

/// Returns the element for the current thread, or creates it if it doesn't
/// exist. If `create` fails, that error is returned and no element is
/// added.
///
/// If the thread key has been destroyed (which may happen if this is called
/// in a destructor), this function will return an [`AccessError`].
pub fn try_get_or_try<F, E>(&self, create: F) -> Result<Result<&T, E>, AccessError>
where
F: FnOnce() -> Result<T, E>,
{
let thread = thread_id::try_get()?;
match self.get_inner(thread) {
Some(x) => Ok(Ok(x)),
None => match create() {
Ok(v) => Ok(Ok(self.insert(thread, v))),
Err(e) => Ok(Err(e)),
},
}
}

fn get_inner(&self, thread: Thread) -> Option<&T> {
let bucket_ptr =
unsafe { self.buckets.get_unchecked(thread.bucket) }.load(Ordering::Acquire);
Expand Down Expand Up @@ -349,6 +393,15 @@ impl<T: Send + Default> ThreadLocal<T> {
pub fn get_or_default(&self) -> &T {
self.get_or(Default::default)
}

/// Returns the element for the current thread, or creates a default one if
/// it doesn't exist.
///
/// If the thread key has been destroyed (which may happen if this is called
/// in a destructor), this function will return an [`AccessError`].
pub fn try_get_or_default(&self) -> Result<&T, AccessError> {
self.try_get_or(Default::default)
}
}

impl<T: Send + fmt::Debug> fmt::Debug for ThreadLocal<T> {
Expand Down Expand Up @@ -584,6 +637,45 @@ mod tests {
assert_eq!(0, *tls.get_or(|| create()));
}

#[test]
fn same_thread_try() {
let create = make_create();
let mut tls = ThreadLocal::new();
assert_eq!(Ok(None), tls.try_get());
assert_eq!("ThreadLocal { local_data: None }", format!("{:?}", &tls));
assert_eq!(Ok(&0), tls.try_get_or(|| create()));
assert_eq!(Ok(Some(&0)), tls.try_get());
assert_eq!(Ok(&0), tls.try_get_or(|| create()));
assert_eq!(Ok(Some(&0)), tls.try_get());
assert_eq!(Ok(&0), tls.try_get_or(|| create()));
assert_eq!(Ok(Some(&0)), tls.try_get());
assert_eq!("ThreadLocal { local_data: Some(0) }", format!("{:?}", &tls));
tls.clear();
assert_eq!(Ok(None), tls.try_get());
}

#[test]
fn different_thread_try() {
let create = make_create();
let tls = Arc::new(ThreadLocal::new());
assert_eq!(Ok(None), tls.try_get());
assert_eq!(Ok(&0), tls.try_get_or(|| create()));
assert_eq!(Ok(Some(&0)), tls.try_get());

let tls2 = tls.clone();
let create2 = create.clone();
thread::spawn(move || {
assert_eq!(Ok(None), tls2.try_get());
assert_eq!(Ok(&1), tls2.try_get_or(|| create2()));
assert_eq!(Ok(Some(&1)), tls2.try_get());
})
.join()
.unwrap();

assert_eq!(Ok(Some(&0)), tls.try_get());
assert_eq!(Ok(&0), tls.try_get_or(|| create()));
}

#[test]
fn iter() {
let tls = Arc::new(ThreadLocal::new());
Expand Down
12 changes: 10 additions & 2 deletions src/thread_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

use crate::POINTER_WIDTH;
use once_cell::sync::Lazy;
use std::cmp::Reverse;
use std::{cmp::Reverse, thread::AccessError};
use std::collections::BinaryHeap;
use std::sync::Mutex;
use std::usize;
Expand Down Expand Up @@ -90,7 +90,15 @@ thread_local!(static THREAD_HOLDER: ThreadHolder = ThreadHolder::new());

/// Get the current thread.
pub(crate) fn get() -> Thread {
THREAD_HOLDER.with(|holder| holder.0)
try_get().unwrap()
}

/// Get the current thread.
///
/// If the key has been destroyed (which may happen if this is called
/// in a destructor), this function will return an [`AccessError`].
pub(crate) fn try_get() -> Result<Thread, AccessError> {
THREAD_HOLDER.try_with(|holder| holder.0)
}

#[test]
Expand Down