1212// See the License for the specific language governing permissions and
1313// limitations under the License.
1414
15+ use std:: future:: Future ;
1516use std:: future:: IntoFuture ;
17+ use std:: mem;
18+ use std:: pin:: Pin ;
19+ use std:: sync:: atomic:: AtomicU32 ;
1620use std:: sync:: atomic:: AtomicUsize ;
1721use std:: sync:: atomic:: Ordering ;
1822use std:: sync:: Arc ;
23+ use std:: task:: Context ;
24+ use std:: task:: Poll ;
25+ use std:: task:: RawWaker ;
26+ use std:: task:: RawWakerVTable ;
27+ use std:: task:: Waker ;
1928use std:: time:: Duration ;
2029
2130use crate :: oneshot;
@@ -111,17 +120,17 @@ async fn poll_receiver_then_drop_it() {
111120#[ tokio:: test]
112121async fn recv_within_select ( ) {
113122 let ( tx, rx) = oneshot:: channel :: < & ' static str > ( ) ;
114- let mut interval = tokio:: time:: interval ( Duration :: from_secs ( 100 ) ) ;
123+ let mut interval = tokio:: time:: interval ( Duration :: from_millis ( 10 ) ) ;
115124
116125 let handle = tokio:: spawn ( async move {
117- tokio:: time:: sleep ( Duration :: from_secs ( 1 ) ) . await ;
126+ tokio:: time:: sleep ( Duration :: from_millis ( 100 ) ) . await ;
118127 tx. send ( "shut down" ) . unwrap ( ) ;
119128 } ) ;
120129
121130 let mut recv = rx. into_future ( ) ;
122131 loop {
123132 tokio:: select! {
124- _ = interval. tick( ) => println!( "another 100ms " ) ,
133+ _ = interval. tick( ) => println!( "another 10ms " ) ,
125134 msg = & mut recv => {
126135 println!( "Got message: {}" , msg. unwrap( ) ) ;
127136 break ;
@@ -131,3 +140,147 @@ async fn recv_within_select() {
131140
132141 handle. await . unwrap ( ) ;
133142}
143+
144+ #[ derive( Default ) ]
145+ pub struct WakerHandle {
146+ clone_count : AtomicU32 ,
147+ drop_count : AtomicU32 ,
148+ wake_count : AtomicU32 ,
149+ }
150+
151+ impl WakerHandle {
152+ pub fn clone_count ( & self ) -> u32 {
153+ self . clone_count . load ( Ordering :: Relaxed )
154+ }
155+
156+ pub fn drop_count ( & self ) -> u32 {
157+ self . drop_count . load ( Ordering :: Relaxed )
158+ }
159+
160+ pub fn wake_count ( & self ) -> u32 {
161+ self . wake_count . load ( Ordering :: Relaxed )
162+ }
163+ }
164+
165+ fn waker ( ) -> ( Waker , Arc < WakerHandle > ) {
166+ let waker_handle = Arc :: new ( WakerHandle :: default ( ) ) ;
167+ let waker_handle_ptr = Arc :: into_raw ( waker_handle. clone ( ) ) ;
168+ let raw_waker = RawWaker :: new ( waker_handle_ptr as * const _ , waker_vtable ( ) ) ;
169+ ( unsafe { Waker :: from_raw ( raw_waker) } , waker_handle)
170+ }
171+
172+ fn waker_vtable ( ) -> & ' static RawWakerVTable {
173+ & RawWakerVTable :: new ( clone_raw, wake_raw, wake_by_ref_raw, drop_raw)
174+ }
175+
176+ unsafe fn clone_raw ( data : * const ( ) ) -> RawWaker {
177+ let handle: Arc < WakerHandle > = Arc :: from_raw ( data as * const _ ) ;
178+ handle. clone_count . fetch_add ( 1 , Ordering :: Relaxed ) ;
179+ mem:: forget ( handle. clone ( ) ) ;
180+ mem:: forget ( handle) ;
181+ RawWaker :: new ( data, waker_vtable ( ) )
182+ }
183+
184+ unsafe fn wake_raw ( data : * const ( ) ) {
185+ let handle: Arc < WakerHandle > = Arc :: from_raw ( data as * const _ ) ;
186+ handle. wake_count . fetch_add ( 1 , Ordering :: Relaxed ) ;
187+ handle. drop_count . fetch_add ( 1 , Ordering :: Relaxed ) ;
188+ }
189+
190+ unsafe fn wake_by_ref_raw ( data : * const ( ) ) {
191+ let handle: Arc < WakerHandle > = Arc :: from_raw ( data as * const _ ) ;
192+ handle. wake_count . fetch_add ( 1 , Ordering :: Relaxed ) ;
193+ mem:: forget ( handle)
194+ }
195+
196+ unsafe fn drop_raw ( data : * const ( ) ) {
197+ let handle: Arc < WakerHandle > = Arc :: from_raw ( data as * const _ ) ;
198+ handle. drop_count . fetch_add ( 1 , Ordering :: Relaxed ) ;
199+ drop ( handle)
200+ }
201+
202+ #[ test]
203+ fn poll_then_send ( ) {
204+ let ( sender, receiver) = oneshot:: channel :: < u128 > ( ) ;
205+ let mut receiver = receiver. into_future ( ) ;
206+
207+ let ( waker, waker_handle) = waker ( ) ;
208+ let mut context = Context :: from_waker ( & waker) ;
209+
210+ assert_eq ! ( Pin :: new( & mut receiver) . poll( & mut context) , Poll :: Pending ) ;
211+ assert_eq ! ( waker_handle. clone_count( ) , 1 ) ;
212+ assert_eq ! ( waker_handle. drop_count( ) , 0 ) ;
213+ assert_eq ! ( waker_handle. wake_count( ) , 0 ) ;
214+
215+ sender. send ( 1234 ) . unwrap ( ) ;
216+ assert_eq ! ( waker_handle. clone_count( ) , 1 ) ;
217+ assert_eq ! ( waker_handle. drop_count( ) , 1 ) ;
218+ assert_eq ! ( waker_handle. wake_count( ) , 1 ) ;
219+
220+ assert_eq ! (
221+ Pin :: new( & mut receiver) . poll( & mut context) ,
222+ Poll :: Ready ( Ok ( 1234 ) )
223+ ) ;
224+ assert_eq ! ( waker_handle. clone_count( ) , 1 ) ;
225+ assert_eq ! ( waker_handle. drop_count( ) , 1 ) ;
226+ assert_eq ! ( waker_handle. wake_count( ) , 1 ) ;
227+ }
228+
229+ #[ test]
230+ fn poll_with_different_wakers ( ) {
231+ let ( sender, receiver) = oneshot:: channel :: < u128 > ( ) ;
232+ let mut receiver = receiver. into_future ( ) ;
233+
234+ let ( waker1, waker_handle1) = waker ( ) ;
235+ let mut context1 = Context :: from_waker ( & waker1) ;
236+
237+ assert_eq ! ( Pin :: new( & mut receiver) . poll( & mut context1) , Poll :: Pending ) ;
238+ assert_eq ! ( waker_handle1. clone_count( ) , 1 ) ;
239+ assert_eq ! ( waker_handle1. drop_count( ) , 0 ) ;
240+ assert_eq ! ( waker_handle1. wake_count( ) , 0 ) ;
241+
242+ let ( waker2, waker_handle2) = waker ( ) ;
243+ let mut context2 = Context :: from_waker ( & waker2) ;
244+
245+ assert_eq ! ( Pin :: new( & mut receiver) . poll( & mut context2) , Poll :: Pending ) ;
246+ assert_eq ! ( waker_handle1. clone_count( ) , 1 ) ;
247+ assert_eq ! ( waker_handle1. drop_count( ) , 1 ) ;
248+ assert_eq ! ( waker_handle1. wake_count( ) , 0 ) ;
249+
250+ assert_eq ! ( waker_handle2. clone_count( ) , 1 ) ;
251+ assert_eq ! ( waker_handle2. drop_count( ) , 0 ) ;
252+ assert_eq ! ( waker_handle2. wake_count( ) , 0 ) ;
253+
254+ // Sending should cause the waker from the latest poll to be woken up
255+ sender. send ( 1234 ) . unwrap ( ) ;
256+ assert_eq ! ( waker_handle1. clone_count( ) , 1 ) ;
257+ assert_eq ! ( waker_handle1. drop_count( ) , 1 ) ;
258+ assert_eq ! ( waker_handle1. wake_count( ) , 0 ) ;
259+
260+ assert_eq ! ( waker_handle2. clone_count( ) , 1 ) ;
261+ assert_eq ! ( waker_handle2. drop_count( ) , 1 ) ;
262+ assert_eq ! ( waker_handle2. wake_count( ) , 1 ) ;
263+ }
264+
265+ #[ test]
266+ fn poll_then_drop_receiver_during_send ( ) {
267+ let ( sender, receiver) = oneshot:: channel :: < u128 > ( ) ;
268+ let mut receiver = receiver. into_future ( ) ;
269+
270+ let ( waker, _waker_handle) = waker ( ) ;
271+ let mut context = Context :: from_waker ( & waker) ;
272+
273+ // Put the channel into the receiving state
274+ assert_eq ! ( Pin :: new( & mut receiver) . poll( & mut context) , Poll :: Pending ) ;
275+
276+ // Spawn a separate thread that sends in parallel
277+ let t = std:: thread:: spawn ( move || {
278+ let _ = sender. send ( 1234 ) ;
279+ } ) ;
280+
281+ // Drop the receiver.
282+ drop ( receiver) ;
283+
284+ // The send operation should also not have panicked
285+ t. join ( ) . unwrap ( ) ;
286+ }
0 commit comments