Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.

## [Unreleased]

### Changed
- Add a separate type `AsyncReceiver` that implements `Future` instead of implementing it
directly on the `Receiver` type. Now the `Receiver` implements `IntoFuture` instead.
This is a breaking change. This change removes the possible panics in many recv* methods,
and it simplifies some code a bit.


## [0.1.10] - 2025-02-04
### Added
Expand Down
121 changes: 89 additions & 32 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
//! The sender's send method is non-blocking, and potentially lock- and wait-free.
//! See documentation on [Sender::send] for situations where it might not be fully wait-free.
//! The receiver supports both lock- and wait-free `try_recv` as well as indefinite and time
//! limited thread blocking receive operations. The receiver also implements `Future` and
//! limited thread blocking receive operations. The receiver also implements `IntoFuture` and
//! supports asynchronously awaiting the message.
//!
//!
Expand Down Expand Up @@ -83,10 +83,10 @@
//! that should work smoothly between the sync and async parts of the program!
//!
//! This library achieves that by having a fast and cheap send operation that can
//! be used in both sync threads and async tasks. The receiver has both thread blocking
//! receive methods for synchronous usage, and implements `Future` for asynchronous usage.
//! be used in both regular threads and async tasks. The receiver has both thread blocking
//! receive methods for synchronous usage, and implements `IntoFuture` for asynchronous usage.
//!
//! The receiving endpoint of this channel implements Rust's `Future` trait and can be waited on
//! The receiving endpoint of this channel implements Rust's `IntoFuture` trait and can be waited on
//! in an asynchronous task. This implementation is completely executor/runtime agnostic. It should
//! be possible to use this library with any executor, or even pass messages between tasks running
//! in different executors.
Expand Down Expand Up @@ -203,6 +203,21 @@ pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
)
}

/// Ergonomic shorthand for creating a channel and immediately convert the [`Receiver`] into
/// a future.
///
/// This can be useful when you need to pass the receiver to a function that expects a
/// type implementing [`Future`](core::future::Future) directly. Using this function is not necessary when
/// you are going to use `.await` on the receiver, as that will automatically call
/// [`IntoFuture::into_future`](core::future::IntoFuture::into_future) in the background.
#[cfg(feature = "async")]
#[inline(always)]
pub fn async_channel<T>() -> (Sender<T>, AsyncReceiver<T>) {
let (sender, receiver) = channel();
let async_receiver = core::future::IntoFuture::into_future(receiver);
(sender, async_receiver)
}

/// Sending end of a oneshot channel.
///
/// Created and returned from the [`channel`] function.
Expand Down Expand Up @@ -246,6 +261,20 @@ pub struct Receiver<T> {
channel_ptr: NonNull<Channel<T>>,
}

/// A version of [`Receiver`] that implements [`Future`](core::future::Future), for awaiting the
/// message in an async context.
///
/// This type is automatically created and polled in the background when awaiting a [`Receiver`].
/// But it can also be created explicitly with the [`async_channel`] function or by calling
/// [`IntoFuture::into_future`](core::future::IntoFuture::into_future) on the [`Receiver`].
#[cfg(feature = "async")]
#[derive(Debug)]
pub struct AsyncReceiver<T> {
// Covariance is the right choice here. Consider the example presented in Sender, and you'll
// see that if we replaced `rx` instead then we would get the expected behavior
channel_ptr: NonNull<Channel<T>>,
}

unsafe impl<T: Send> Send for Sender<T> {}

// SAFETY: The only methods that assumes there is only a single reference to the sender
Expand All @@ -254,7 +283,11 @@ unsafe impl<T: Send> Send for Sender<T> {}
unsafe impl<T: Sync> Sync for Sender<T> {}

unsafe impl<T: Send> Send for Receiver<T> {}
impl<T> Unpin for Receiver<T> {}

#[cfg(feature = "async")]
unsafe impl<T: Send> Send for AsyncReceiver<T> {}
#[cfg(feature = "async")]
impl<T> Unpin for AsyncReceiver<T> {}

impl<T> Sender<T> {
/// Sends `message` over the channel to the corresponding [`Receiver`].
Expand Down Expand Up @@ -484,10 +517,6 @@ impl<T> Receiver<T> {
///
/// If a sent message has already been extracted from this channel this method will return an
/// error.
///
/// # Panics
///
/// Panics if called after this receiver has been polled asynchronously.
#[cfg(feature = "std")]
pub fn recv(self) -> Result<T, RecvError> {
// Note that we don't need to worry about changing the state to disconnected or setting the
Expand Down Expand Up @@ -617,9 +646,6 @@ impl<T> Receiver<T> {

Err(RecvError)
}
// The receiver must have been `Future::poll`ed prior to this call.
#[cfg(feature = "async")]
RECEIVING | UNPARKING => panic!("{}", RECEIVER_USED_SYNC_AND_ASYNC_ERROR),
_ => unreachable!(),
}
}
Expand All @@ -630,10 +656,6 @@ impl<T> Receiver<T> {
///
/// If a message is returned, the channel is disconnected and any subsequent receive operation
/// using this receiver will return an error.
///
/// # Panics
///
/// Panics if called after this receiver has been polled asynchronously.
#[cfg(feature = "std")]
pub fn recv_ref(&self) -> Result<T, RecvError> {
self.start_recv_ref(RecvError, |channel| {
Expand Down Expand Up @@ -673,10 +695,6 @@ impl<T> Receiver<T> {
///
/// If the supplied `timeout` is so large that Rust's `Instant` type can't represent this point
/// in the future this falls back to an indefinitely blocking receive operation.
///
/// # Panics
///
/// Panics if called after this receiver has been polled asynchronously.
#[cfg(feature = "std")]
pub fn recv_timeout(&self, timeout: Duration) -> Result<T, RecvTimeoutError> {
match Instant::now().checked_add(timeout) {
Expand All @@ -693,10 +711,6 @@ impl<T> Receiver<T> {
///
/// If a message is returned, the channel is disconnected and any subsequent receive operation
/// using this receiver will return an error.
///
/// # Panics
///
/// Panics if called after this receiver has been polled asynchronously.
#[cfg(feature = "std")]
pub fn recv_deadline(&self, deadline: Instant) -> Result<T, RecvTimeoutError> {
/// # Safety
Expand Down Expand Up @@ -912,9 +926,6 @@ impl<T> Receiver<T> {
}
// The sender was dropped before sending anything, or we already received the message.
DISCONNECTED => Err(disconnected_error),
// The receiver must have been `Future::poll`ed prior to this call.
#[cfg(feature = "async")]
RECEIVING | UNPARKING => panic!("{}", RECEIVER_USED_SYNC_AND_ASYNC_ERROR),
_ => unreachable!(),
}
}
Expand Down Expand Up @@ -945,7 +956,23 @@ impl<T> Receiver<T> {
}

#[cfg(feature = "async")]
impl<T> core::future::Future for Receiver<T> {
impl<T> core::future::IntoFuture for Receiver<T> {
type Output = Result<T, RecvError>;
type IntoFuture = AsyncReceiver<T>;

#[inline(always)]
fn into_future(self) -> Self::IntoFuture {
let Receiver { channel_ptr } = self;

// Don't run our Drop implementation, since the receiver lives on as an async receiver.
mem::forget(self);

AsyncReceiver { channel_ptr }
}
}

#[cfg(feature = "async")]
impl<T> core::future::Future for AsyncReceiver<T> {
type Output = Result<T, RecvError>;

fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
Expand Down Expand Up @@ -1072,6 +1099,40 @@ impl<T> Drop for Receiver<T> {
}
}

#[cfg(feature = "async")]
impl<T> Drop for AsyncReceiver<T> {
fn drop(&mut self) {
// SAFETY: since the receiving side is still alive the sender would have observed that and
// left deallocating the channel allocation to us.
let channel = unsafe { self.channel_ptr.as_ref() };

// Set the channel state to disconnected and read what state the receiver was in
match channel.state.swap(DISCONNECTED, Acquire) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AsyncReceiver can observe UNPARKING state here?

// The sender has not sent anything, nor is it dropped.
EMPTY => (),
// The sender already sent something. We must drop it, and free the channel.
MESSAGE => {
// SAFETY: we are in the message state so the message is initialized
unsafe { channel.drop_message() };

// SAFETY: see safety comment at top of function
unsafe { dealloc(self.channel_ptr) };
}
// The receiver has been polled.
RECEIVING => {
// TODO: figure this out when async is fixed
unsafe { channel.drop_waker() };
}
// The sender was already dropped. We are responsible for freeing the channel.
DISCONNECTED => {
// SAFETY: see safety comment at top of function
unsafe { dealloc(self.channel_ptr) };
}
_ => unreachable!(),
}
}
}

/// All the values that the `Channel::state` field can have during the lifetime of a channel.
mod states {
// These values are very explicitly chosen so that we can replace some cmpxchg calls with
Expand Down Expand Up @@ -1310,10 +1371,6 @@ fn receiver_waker_size() {
assert_eq!(mem::size_of::<ReceiverWaker>(), expected);
}

#[cfg(all(feature = "std", feature = "async"))]
const RECEIVER_USED_SYNC_AND_ASYNC_ERROR: &str =
"Invalid to call a blocking receive method on oneshot::Receiver after it has been polled";

#[inline]
pub(crate) unsafe fn dealloc<T>(channel: NonNull<Channel<T>>) {
drop(Box::from_raw(channel.as_ptr()))
Expand Down
30 changes: 1 addition & 29 deletions tests/async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,37 +88,9 @@ async fn await_before_send_then_drop_sender_async_std() {
t.await;
}

// Tests that the Receiver handles being used synchronously even after being polled
#[tokio::test]
async fn poll_future_and_then_try_recv() {
use core::future::Future;
use core::pin::Pin;
use core::task::{self, Poll};

struct StupidReceiverFuture(oneshot::Receiver<()>);

impl Future for StupidReceiverFuture {
type Output = Result<(), oneshot::RecvError>;

fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
let poll_result = Future::poll(Pin::new(&mut self.0), cx);
self.0.try_recv().expect_err("Should never be a message");
poll_result
}
}

let (sender, receiver) = oneshot::channel();
let t = tokio::spawn(async {
tokio::time::sleep(Duration::from_millis(20)).await;
mem::drop(sender);
});
StupidReceiverFuture(receiver).await.unwrap_err();
t.await.unwrap();
}

#[tokio::test]
async fn poll_receiver_then_drop_it() {
let (sender, receiver) = oneshot::channel::<()>();
let (sender, receiver) = oneshot::async_channel::<()>();
// This will poll the receiver and then give up after 100 ms.
tokio::time::timeout(Duration::from_millis(100), receiver)
.await
Expand Down
2 changes: 1 addition & 1 deletion tests/future.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ fn multiple_receiver_polls_keeps_only_latest_waker() {
let waker1 = unsafe { task::Waker::from_raw(raw_waker1) };
let mut context1 = task::Context::from_waker(&waker1);

let (_sender, mut receiver) = oneshot::channel::<()>();
let (_sender, mut receiver) = oneshot::async_channel::<()>();

let poll_result = future::Future::poll(pin::Pin::new(&mut receiver), &mut context1);
assert_eq!(poll_result, task::Poll::Pending);
Expand Down
67 changes: 5 additions & 62 deletions tests/loom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use oneshot::TryRecvError;
use loom::hint;
use loom::thread;
#[cfg(feature = "async")]
use std::future::Future;
use std::future::{Future, IntoFuture};
#[cfg(feature = "async")]
use std::pin::Pin;
#[cfg(feature = "async")]
Expand Down Expand Up @@ -73,7 +73,7 @@ fn async_recv() {
let t1 = thread::spawn(move || {
sender.send(987).unwrap();
});
assert_eq!(loom::future::block_on(receiver), Ok(987));
assert_eq!(loom::future::block_on(receiver.into_future()), Ok(987));
t1.join().unwrap();
})
}
Expand All @@ -82,7 +82,7 @@ fn async_recv() {
#[test]
fn send_then_poll() {
loom::model(|| {
let (sender, mut receiver) = oneshot::channel::<u128>();
let (sender, mut receiver) = oneshot::async_channel::<u128>();
sender.send(1234).unwrap();

let (waker, waker_handle) = helpers::waker::waker();
Expand All @@ -102,7 +102,7 @@ fn send_then_poll() {
#[test]
fn poll_then_send() {
loom::model(|| {
let (sender, mut receiver) = oneshot::channel::<u128>();
let (sender, mut receiver) = oneshot::async_channel::<u128>();

let (waker, waker_handle) = helpers::waker::waker();
let mut context = task::Context::from_waker(&waker);
Expand Down Expand Up @@ -131,7 +131,7 @@ fn poll_then_send() {
#[test]
fn poll_with_different_wakers() {
loom::model(|| {
let (sender, mut receiver) = oneshot::channel::<u128>();
let (sender, mut receiver) = oneshot::async_channel::<u128>();

let (waker1, waker_handle1) = helpers::waker::waker();
let mut context1 = task::Context::from_waker(&waker1);
Expand Down Expand Up @@ -164,60 +164,3 @@ fn poll_with_different_wakers() {
assert_eq!(waker_handle2.wake_count(), 1);
})
}

#[cfg(feature = "async")]
#[test]
fn poll_then_try_recv() {
loom::model(|| {
let (_sender, mut receiver) = oneshot::channel::<u128>();

let (waker, waker_handle) = helpers::waker::waker();
let mut context = task::Context::from_waker(&waker);

assert_eq!(Pin::new(&mut receiver).poll(&mut context), Poll::Pending);
assert_eq!(waker_handle.clone_count(), 1);
assert_eq!(waker_handle.drop_count(), 0);
assert_eq!(waker_handle.wake_count(), 0);

assert_eq!(receiver.try_recv(), Err(TryRecvError::Empty));

assert_eq!(Pin::new(&mut receiver).poll(&mut context), Poll::Pending);
assert_eq!(waker_handle.clone_count(), 2);
assert_eq!(waker_handle.drop_count(), 1);
assert_eq!(waker_handle.wake_count(), 0);
})
}

#[cfg(feature = "async")]
#[test]
fn poll_then_try_recv_while_sending() {
loom::model(|| {
let (sender, mut receiver) = oneshot::channel::<u128>();

let (waker, waker_handle) = helpers::waker::waker();
let mut context = task::Context::from_waker(&waker);

assert_eq!(Pin::new(&mut receiver).poll(&mut context), Poll::Pending);
assert_eq!(waker_handle.clone_count(), 1);
assert_eq!(waker_handle.drop_count(), 0);
assert_eq!(waker_handle.wake_count(), 0);

let t = thread::spawn(move || {
sender.send(1234).unwrap();
});

let msg = loop {
match receiver.try_recv() {
Ok(msg) => break msg,
Err(TryRecvError::Empty) => hint::spin_loop(),
Err(TryRecvError::Disconnected) => panic!("Should not be disconnected"),
}
};
assert_eq!(msg, 1234);
assert_eq!(waker_handle.clone_count(), 1);
assert_eq!(waker_handle.drop_count(), 1);
assert_eq!(waker_handle.wake_count(), 1);

t.join().unwrap();
})
}
Loading