1212// See the License for the specific language governing permissions and
1313// limitations under the License.
1414
15- use std:: future:: IntoFuture ;
16- use std:: sync:: atomic:: AtomicUsize ;
15+ use std:: future:: { Future , IntoFuture } ;
16+ use std:: mem;
17+ use std:: pin:: Pin ;
1718use std:: sync:: atomic:: Ordering ;
19+ use std:: sync:: atomic:: { AtomicU32 , AtomicUsize } ;
1820use std:: sync:: Arc ;
21+ use std:: task:: { Context , Poll , RawWaker , RawWakerVTable , Waker } ;
1922use std:: time:: Duration ;
2023
2124use crate :: oneshot;
@@ -111,17 +114,17 @@ async fn poll_receiver_then_drop_it() {
111114#[ tokio:: test]
112115async fn recv_within_select ( ) {
113116 let ( tx, rx) = oneshot:: channel :: < & ' static str > ( ) ;
114- let mut interval = tokio:: time:: interval ( Duration :: from_secs ( 100 ) ) ;
117+ let mut interval = tokio:: time:: interval ( Duration :: from_millis ( 10 ) ) ;
115118
116119 let handle = tokio:: spawn ( async move {
117- tokio:: time:: sleep ( Duration :: from_secs ( 1 ) ) . await ;
120+ tokio:: time:: sleep ( Duration :: from_millis ( 100 ) ) . await ;
118121 tx. send ( "shut down" ) . unwrap ( ) ;
119122 } ) ;
120123
121124 let mut recv = rx. into_future ( ) ;
122125 loop {
123126 tokio:: select! {
124- _ = interval. tick( ) => println!( "another 100ms " ) ,
127+ _ = interval. tick( ) => println!( "another 10ms " ) ,
125128 msg = & mut recv => {
126129 println!( "Got message: {}" , msg. unwrap( ) ) ;
127130 break ;
@@ -131,3 +134,147 @@ async fn recv_within_select() {
131134
132135 handle. await . unwrap ( ) ;
133136}
137+
138+ #[ derive( Default ) ]
139+ pub struct WakerHandle {
140+ clone_count : AtomicU32 ,
141+ drop_count : AtomicU32 ,
142+ wake_count : AtomicU32 ,
143+ }
144+
145+ impl WakerHandle {
146+ pub fn clone_count ( & self ) -> u32 {
147+ self . clone_count . load ( Ordering :: Relaxed )
148+ }
149+
150+ pub fn drop_count ( & self ) -> u32 {
151+ self . drop_count . load ( Ordering :: Relaxed )
152+ }
153+
154+ pub fn wake_count ( & self ) -> u32 {
155+ self . wake_count . load ( Ordering :: Relaxed )
156+ }
157+ }
158+
159+ fn waker ( ) -> ( Waker , Arc < WakerHandle > ) {
160+ let waker_handle = Arc :: new ( WakerHandle :: default ( ) ) ;
161+ let waker_handle_ptr = Arc :: into_raw ( waker_handle. clone ( ) ) ;
162+ let raw_waker = RawWaker :: new ( waker_handle_ptr as * const _ , waker_vtable ( ) ) ;
163+ ( unsafe { Waker :: from_raw ( raw_waker) } , waker_handle)
164+ }
165+
166+ fn waker_vtable ( ) -> & ' static RawWakerVTable {
167+ & RawWakerVTable :: new ( clone_raw, wake_raw, wake_by_ref_raw, drop_raw)
168+ }
169+
170+ unsafe fn clone_raw ( data : * const ( ) ) -> RawWaker {
171+ let handle: Arc < WakerHandle > = Arc :: from_raw ( data as * const _ ) ;
172+ handle. clone_count . fetch_add ( 1 , Ordering :: Relaxed ) ;
173+ mem:: forget ( handle. clone ( ) ) ;
174+ mem:: forget ( handle) ;
175+ RawWaker :: new ( data, waker_vtable ( ) )
176+ }
177+
178+ unsafe fn wake_raw ( data : * const ( ) ) {
179+ let handle: Arc < WakerHandle > = Arc :: from_raw ( data as * const _ ) ;
180+ handle. wake_count . fetch_add ( 1 , Ordering :: Relaxed ) ;
181+ handle. drop_count . fetch_add ( 1 , Ordering :: Relaxed ) ;
182+ }
183+
184+ unsafe fn wake_by_ref_raw ( data : * const ( ) ) {
185+ let handle: Arc < WakerHandle > = Arc :: from_raw ( data as * const _ ) ;
186+ handle. wake_count . fetch_add ( 1 , Ordering :: Relaxed ) ;
187+ mem:: forget ( handle)
188+ }
189+
190+ unsafe fn drop_raw ( data : * const ( ) ) {
191+ let handle: Arc < WakerHandle > = Arc :: from_raw ( data as * const _ ) ;
192+ handle. drop_count . fetch_add ( 1 , Ordering :: Relaxed ) ;
193+ drop ( handle)
194+ }
195+
196+ #[ test]
197+ fn poll_then_send ( ) {
198+ let ( sender, receiver) = oneshot:: channel :: < u128 > ( ) ;
199+ let mut receiver = receiver. into_future ( ) ;
200+
201+ let ( waker, waker_handle) = waker ( ) ;
202+ let mut context = Context :: from_waker ( & waker) ;
203+
204+ assert_eq ! ( Pin :: new( & mut receiver) . poll( & mut context) , Poll :: Pending ) ;
205+ assert_eq ! ( waker_handle. clone_count( ) , 1 ) ;
206+ assert_eq ! ( waker_handle. drop_count( ) , 0 ) ;
207+ assert_eq ! ( waker_handle. wake_count( ) , 0 ) ;
208+
209+ sender. send ( 1234 ) . unwrap ( ) ;
210+ assert_eq ! ( waker_handle. clone_count( ) , 1 ) ;
211+ assert_eq ! ( waker_handle. drop_count( ) , 1 ) ;
212+ assert_eq ! ( waker_handle. wake_count( ) , 1 ) ;
213+
214+ assert_eq ! (
215+ Pin :: new( & mut receiver) . poll( & mut context) ,
216+ Poll :: Ready ( Ok ( 1234 ) )
217+ ) ;
218+ assert_eq ! ( waker_handle. clone_count( ) , 1 ) ;
219+ assert_eq ! ( waker_handle. drop_count( ) , 1 ) ;
220+ assert_eq ! ( waker_handle. wake_count( ) , 1 ) ;
221+ }
222+
223+ #[ test]
224+ fn poll_with_different_wakers ( ) {
225+ let ( sender, receiver) = oneshot:: channel :: < u128 > ( ) ;
226+ let mut receiver = receiver. into_future ( ) ;
227+
228+ let ( waker1, waker_handle1) = waker ( ) ;
229+ let mut context1 = Context :: from_waker ( & waker1) ;
230+
231+ assert_eq ! ( Pin :: new( & mut receiver) . poll( & mut context1) , Poll :: Pending ) ;
232+ assert_eq ! ( waker_handle1. clone_count( ) , 1 ) ;
233+ assert_eq ! ( waker_handle1. drop_count( ) , 0 ) ;
234+ assert_eq ! ( waker_handle1. wake_count( ) , 0 ) ;
235+
236+ let ( waker2, waker_handle2) = waker ( ) ;
237+ let mut context2 = Context :: from_waker ( & waker2) ;
238+
239+ assert_eq ! ( Pin :: new( & mut receiver) . poll( & mut context2) , Poll :: Pending ) ;
240+ assert_eq ! ( waker_handle1. clone_count( ) , 1 ) ;
241+ assert_eq ! ( waker_handle1. drop_count( ) , 1 ) ;
242+ assert_eq ! ( waker_handle1. wake_count( ) , 0 ) ;
243+
244+ assert_eq ! ( waker_handle2. clone_count( ) , 1 ) ;
245+ assert_eq ! ( waker_handle2. drop_count( ) , 0 ) ;
246+ assert_eq ! ( waker_handle2. wake_count( ) , 0 ) ;
247+
248+ // Sending should cause the waker from the latest poll to be woken up
249+ sender. send ( 1234 ) . unwrap ( ) ;
250+ assert_eq ! ( waker_handle1. clone_count( ) , 1 ) ;
251+ assert_eq ! ( waker_handle1. drop_count( ) , 1 ) ;
252+ assert_eq ! ( waker_handle1. wake_count( ) , 0 ) ;
253+
254+ assert_eq ! ( waker_handle2. clone_count( ) , 1 ) ;
255+ assert_eq ! ( waker_handle2. drop_count( ) , 1 ) ;
256+ assert_eq ! ( waker_handle2. wake_count( ) , 1 ) ;
257+ }
258+
259+ #[ test]
260+ fn poll_then_drop_receiver_during_send ( ) {
261+ let ( sender, receiver) = oneshot:: channel :: < u128 > ( ) ;
262+ let mut receiver = receiver. into_future ( ) ;
263+
264+ let ( waker, _waker_handle) = waker ( ) ;
265+ let mut context = Context :: from_waker ( & waker) ;
266+
267+ // Put the channel into the receiving state
268+ assert_eq ! ( Pin :: new( & mut receiver) . poll( & mut context) , Poll :: Pending ) ;
269+
270+ // Spawn a separate thread that sends in parallel
271+ let t = std:: thread:: spawn ( move || {
272+ let _ = sender. send ( 1234 ) ;
273+ } ) ;
274+
275+ // Drop the receiver.
276+ drop ( receiver) ;
277+
278+ // The send operation should also not have panicked
279+ t. join ( ) . unwrap ( ) ;
280+ }
0 commit comments