diff --git a/async-stream/src/async_stream.rs b/async-stream/src/async_stream.rs index ff408ab..090a1bb 100644 --- a/async-stream/src/async_stream.rs +++ b/async-stream/src/async_stream.rs @@ -1,3 +1,4 @@ +use crate::sync_wrapper::SyncWrapper; use crate::yielder::Receiver; use futures_core::{FusedStream, Stream}; @@ -13,7 +14,7 @@ pin_project! { rx: Receiver<T>, done: bool, #[pin] - generator: U, + generator: SyncWrapper<U>, } } @@ -23,7 +24,7 @@ impl<T, U> AsyncStream<T, U> { AsyncStream { rx, done: false, - generator, + generator: SyncWrapper::new(generator), } } } diff --git a/async-stream/src/lib.rs b/async-stream/src/lib.rs index 318e404..c1401ef 100644 --- a/async-stream/src/lib.rs +++ b/async-stream/src/lib.rs @@ -158,6 +158,7 @@ mod async_stream; mod next; +mod sync_wrapper; mod yielder; /// Asynchronous stream diff --git a/async-stream/src/sync_wrapper.rs b/async-stream/src/sync_wrapper.rs new file mode 100644 index 0000000..ebef215 --- /dev/null +++ b/async-stream/src/sync_wrapper.rs @@ -0,0 +1,51 @@ +use std::fmt; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// A wrapper around `T` that only allows mutable access. +/// +/// This allows it to unconditionally implement `Sync`, since there is nothing +/// you can do with an `&SyncWrapper<T>`. +pub(crate) struct SyncWrapper<T> { + inner: T, +} + +impl<T> SyncWrapper<T> { + pub(crate) fn new(value: T) -> Self { + Self { inner: value } + } + + pub(crate) fn get_pinned_mut(self: Pin<&mut Self>) -> Pin<&mut T> { + // We can't use pin_project! for this because it generates a project_ref + // method which would allow accessing the inner element + // + // SAFETY: this.inner is guaranteed not to move as long as this lives. + unsafe { self.map_unchecked_mut(|this| &mut this.inner) } + } +} + +// SAFETY: It is not possible to do anything with an &SyncWrapper<T> so it is +// safe for it to be shared between threads. +// +// See [0] for more details. +// +// [0]: https://internals.rust-lang.org/t/what-shall-sync-mean-across-an-await/12020/2 +unsafe impl<T> Sync for SyncWrapper<T> {} + +impl<T: Future> Future for SyncWrapper<T> { + type Output = T::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + self.get_pinned_mut().poll(cx) + } +} + +impl<T> fmt::Debug for SyncWrapper<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + // We can't format the inner value (since that would create an &T reference) + // so we just print a placeholder string. + + f.write_str("<opaque future>") + } +} diff --git a/async-stream/tests/stream.rs b/async-stream/tests/stream.rs index abfd1fc..23195e5 100644 --- a/async-stream/tests/stream.rs +++ b/async-stream/tests/stream.rs @@ -1,3 +1,5 @@ +use std::cell::Cell; + use async_stream::stream; use futures_core::stream::{FusedStream, Stream}; @@ -229,6 +231,18 @@ fn inner_try_stream() { }; } +#[test] +fn stream_is_sync() { + fn assert_sync<T: Sync>(_: T) {} + + // The stream should be sync even if it contains a non-sync value. + assert_sync(stream! { + let cell = Cell::new(true); + yield 5; + drop(cell); + }); +} + #[rustversion::attr(not(stable), ignore)] #[test] fn test() {