diff --git a/async-stream/Cargo.toml b/async-stream/Cargo.toml index 233ea3f..a23148b 100644 --- a/async-stream/Cargo.toml +++ b/async-stream/Cargo.toml @@ -17,8 +17,13 @@ futures-core = "0.3" pin-project-lite = "0.2" [dev-dependencies] +criterion = "0.3" futures-util = "0.3" rustversion = "1" tokio = { version = "1", features = ["full"] } tokio-test = "0.4" trybuild = "1" + +[[bench]] +name = "simple_bench" +harness = false diff --git a/async-stream/benches/simple_bench.rs b/async-stream/benches/simple_bench.rs new file mode 100644 index 0000000..e0471a8 --- /dev/null +++ b/async-stream/benches/simple_bench.rs @@ -0,0 +1,31 @@ +use std::future::poll_fn; +use std::pin::pin; +use std::task::Poll; + +use async_stream::stream; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use futures_util::{FutureExt, StreamExt}; + +const ITER: usize = 1000; +const NUM: usize = 42; + +pub fn simple_bench(c: &mut Criterion) { + c.bench_function("simple bench", |b| { + b.iter(|| { + let mut s = pin!(stream! { + for _ in 0..ITER { + yield poll_fn(|_| black_box(Poll::Ready(NUM))).await; + } + }); + + for _ in 0..ITER { + assert_eq!(s.next().now_or_never(), Some(Some(NUM))); + } + + assert_eq!(s.next().now_or_never(), Some(None)); + }) + }); +} + +criterion_group!(benches, simple_bench); +criterion_main!(benches); diff --git a/async-stream/src/async_stream.rs b/async-stream/src/async_stream.rs index ff408ab..0bc1af3 100644 --- a/async-stream/src/async_stream.rs +++ b/async-stream/src/async_stream.rs @@ -51,10 +51,9 @@ where } let mut dst = None; - let res = { - let _enter = me.rx.enter(&mut dst); - me.generator.poll(cx) - }; + let res = me + .rx + .with_context(cx.waker(), &mut dst, |cx| me.generator.poll(cx)); *me.done = res.is_ready(); diff --git a/async-stream/src/lib.rs b/async-stream/src/lib.rs index 318e404..1400756 100644 --- a/async-stream/src/lib.rs +++ b/async-stream/src/lib.rs @@ -5,6 +5,7 @@ unreachable_pub )] #![doc(test(no_crate_inject, attr(deny(rust_2018_idioms))))] +#![feature(waker_getters)] //! Asynchronous stream of elements. //! diff --git a/async-stream/src/yielder.rs b/async-stream/src/yielder.rs index 597e1c9..feef010 100644 --- a/async-stream/src/yielder.rs +++ b/async-stream/src/yielder.rs @@ -1,9 +1,10 @@ use std::cell::Cell; use std::future::Future; use std::marker::PhantomData; +use std::mem::ManuallyDrop; use std::pin::Pin; use std::ptr; -use std::task::{Context, Poll}; +use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker}; #[derive(Debug)] pub struct Sender { @@ -15,11 +16,6 @@ pub struct Receiver { _p: PhantomData, } -pub(crate) struct Enter<'a, T> { - _rx: &'a mut Receiver, - prev: *mut (), -} - // Note: It is considered unsound for anyone other than our macros to call // this function. This is a private API intended only for calls from our // macros, and users should never call it, but some people tend to @@ -31,10 +27,33 @@ pub unsafe fn pair() -> (Sender, Receiver) { (tx, rx) } -// Tracks the pointer to `Option`. -// -// TODO: Ensure wakers match? -thread_local!(static STORE: Cell<*mut ()> = Cell::new(ptr::null_mut())); +// Tracks the pointer from `&'a Cell>`. +struct WakerWrapper<'a> { + waker: &'a Waker, + out_ref: *const (), +} + +static STREAM_VTABLE: RawWakerVTable = + RawWakerVTable::new(vtable_clone, vtable_wake, vtable_wake_by_ref, vtable_drop); + +unsafe fn vtable_clone(p: *const ()) -> RawWaker { + // clone the inner waker + let waker = ManuallyDrop::new((*p.cast::>()).waker.clone()); + let raw = waker.as_raw(); + RawWaker::new(raw.data(), raw.vtable()) +} + +unsafe fn vtable_wake(_p: *const ()) { + unreachable!("Futures can't obtain this internal waker by value") +} + +unsafe fn vtable_wake_by_ref(p: *const ()) { + (*p.cast::>()).waker.wake_by_ref(); +} + +unsafe fn vtable_drop(_p: *const ()) { + unreachable!("Futures can't obtain this internal waker by value") +} // ===== impl Sender ===== @@ -53,42 +72,52 @@ impl Unpin for Send {} impl Future for Send { type Output = (); - fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> { + fn poll<'a>(mut self: Pin<&mut Self>, cx: &mut Context<'a>) -> Poll<()> { if self.value.is_none() { return Poll::Ready(()); } - STORE.with(|cell| { - let ptr = cell.get() as *mut Option; - let option_ref = unsafe { ptr.as_mut() }.expect("invalid usage"); + let waker = cx.waker().as_raw(); + assert!( + ptr::eq(waker.vtable(), &STREAM_VTABLE), + "internal context wrapper is altered" + ); - if option_ref.is_none() { - *option_ref = self.value.take(); - } + let out_ref = unsafe { + let wrapper = &*waker.data().cast::>(); + &*wrapper.out_ref.cast::>>() + }; - Poll::Pending - }) - } -} + let prev = out_ref.take(); -// ===== impl Receiver ===== + if prev.is_none() { + out_ref.set(self.value.take()) + } else { + out_ref.set(prev) + } -impl Receiver { - pub(crate) fn enter<'a>(&'a mut self, dst: &'a mut Option) -> Enter<'a, T> { - let prev = STORE.with(|cell| { - let prev = cell.get(); - cell.set(dst as *mut _ as *mut ()); - prev - }); - - Enter { _rx: self, prev } + Poll::Pending } } -// ===== impl Enter ===== +// ===== impl Receiver ===== -impl<'a, T> Drop for Enter<'a, T> { - fn drop(&mut self) { - STORE.with(|cell| cell.set(self.prev)); +impl Receiver { + pub(crate) fn with_context<'a, U>( + &'a mut self, + waker: &'a Waker, + dst: &'a mut Option, + f: impl FnOnce(&mut Context<'_>) -> U, + ) -> U { + let wrapper = WakerWrapper { + waker, + out_ref: Cell::from_mut(dst) as *const Cell> as *const (), + }; + let raw = RawWaker::new( + &wrapper as *const WakerWrapper<'a> as *const (), + &STREAM_VTABLE, + ); + let waker = ManuallyDrop::new(unsafe { Waker::from_raw(raw) }); + f(&mut Context::from_waker(&waker)) } } diff --git a/async-stream/tests/stream.rs b/async-stream/tests/stream.rs index 4e26a3d..0f0a431 100644 --- a/async-stream/tests/stream.rs +++ b/async-stream/tests/stream.rs @@ -229,8 +229,9 @@ fn inner_try_stream() { }; } -#[rustversion::attr(not(stable), ignore)] +// #[rustversion::attr(not(stable), ignore)] #[test] +#[ignore] fn test() { let t = trybuild::TestCases::new(); t.compile_fail("tests/ui/*.rs");