@@ -10,7 +10,6 @@ pub mod state;
1010pub use conn:: TlsConnection ;
1111pub use control:: ProverControl ;
1212pub use error:: ProverError ;
13- use futures_plex:: DuplexStream ;
1413pub use tlsn_core:: ProverOutput ;
1514
1615use crate :: {
@@ -25,7 +24,7 @@ use crate::{
2524 utils:: { await_with_copy_io, build_mt_context} ,
2625} ;
2726
28- use futures:: { AsyncRead , AsyncWrite , FutureExt , TryFutureExt } ;
27+ use futures:: { AsyncRead , AsyncWrite , FutureExt , TryFutureExt , ready } ;
2928use rustls_pki_types:: CertificateDer ;
3029use serio:: { SinkExt , stream:: IoStreamExt } ;
3130use std:: {
@@ -270,6 +269,8 @@ impl Prover<state::Setup> {
270269 S : AsyncRead + AsyncWrite + Send + Unpin ,
271270 T : AsyncRead + AsyncWrite + Send + Unpin ,
272271 {
272+ let ( client_to_server, server_to_client) = futures_plex:: duplex ( BUF_CAP ) ;
273+
273274 Prover {
274275 config : self . config ,
275276 span : self . span ,
@@ -281,6 +282,8 @@ impl Prover<state::Setup> {
281282 output : None ,
282283 client_io : self . state . client_io ,
283284 verifier_io : self . state . verifier_io ,
285+ tls_client_duplex : client_to_server,
286+ server_duplex : server_to_client,
284287 server_socket : server_io,
285288 verifier_socket : verifier_io,
286289 client_closed : false ,
@@ -323,13 +326,13 @@ where
323326 fn poll ( mut self : std:: pin:: Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
324327 let mut state = Pin :: new ( & mut self . state ) . project ( ) ;
325328
326- let ( mut duplex_1, mut duplex_2) = futures_plex:: duplex ( BUF_CAP ) ;
327329 loop {
328330 let mut progress = false ;
329331
330332 progress |= Self :: io_client_conn ( & mut state, cx) ?;
331- progress |= Self :: io_client_server ( & mut state, cx, & mut duplex_1 , & mut duplex_2 ) ?;
333+ progress |= Self :: io_client_server ( & mut state, cx) ?;
332334 progress |= Self :: io_client_verifier ( & mut state, cx) ?;
335+
333336 _ = state. mux_fut . poll_unpin ( cx) ?;
334337
335338 if state. output . is_none ( )
@@ -339,8 +342,14 @@ where
339342 }
340343
341344 if !progress {
345+ cx. waker ( ) . wake_by_ref ( ) ;
342346 return Poll :: Pending ;
343- } else if * state. server_closed && * state. client_closed && state. output . is_some ( ) {
347+ }
348+
349+ if * state. server_closed && state. output . is_some ( ) {
350+ ready ! ( state. client_io. poll_close( cx) ) ?;
351+ ready ! ( state. server_socket. poll_close( cx) ) ?;
352+
344353 return Poll :: Ready ( Ok ( ( ) ) ) ;
345354 }
346355 }
@@ -359,10 +368,11 @@ where
359368 let mut progress = false ;
360369
361370 // tls_conn -> tls_client
362- if let Poll :: Ready ( mut simplex) = state. client_io . as_mut ( ) . poll_lock_read ( cx)
371+ if state. tls_client . wants_write ( )
372+ && let Poll :: Ready ( mut simplex) = state. client_io . as_mut ( ) . poll_lock_read ( cx)
363373 && let Poll :: Ready ( buf) = simplex. poll_get ( cx) ?
364374 {
365- if buf. len ( ) > 0 {
375+ if ! buf. is_empty ( ) {
366376 let write = state. tls_client . write ( buf) ?;
367377 if write > 0 {
368378 progress = true ;
@@ -376,72 +386,64 @@ where
376386 }
377387
378388 // tls_client -> tls_conn
379- if let Poll :: Ready ( mut simplex) = state. client_io . as_mut ( ) . poll_lock_write ( cx)
389+ if state. tls_client . wants_read ( )
390+ && let Poll :: Ready ( mut simplex) = state. client_io . as_mut ( ) . poll_lock_write ( cx)
380391 && let Poll :: Ready ( buf) = simplex. poll_mut ( cx) ?
392+ && let read = state. tls_client . read ( buf) ?
393+ && read > 0
381394 {
382- if buf. len ( ) > 0
383- && let read = state. tls_client . read ( buf) ?
384- && read > 0
385- {
386- progress = true ;
387- simplex. advance_mut ( read) ;
388- }
395+ progress = true ;
396+ simplex. advance_mut ( read) ;
389397 }
390-
391398 Ok ( progress)
392399 }
393400
394401 fn io_client_server (
395402 state : & mut ConnectedProj < S , T > ,
396403 cx : & mut Context ,
397- duplex_1 : & mut DuplexStream ,
398- duplex_2 : & mut DuplexStream ,
399404 ) -> Result < bool , ProverError > {
400405 let mut progress = false ;
401- let mut duplex_1 = Pin :: new ( duplex_1) ;
402- let mut duplex_2 = Pin :: new ( duplex_2) ;
403406
404407 // server_socket -> duplex
405- if let Poll :: Ready ( write) = duplex_1. poll_write_from ( cx, state. server_socket . as_mut ( ) ) ? {
408+ if let Poll :: Ready ( write) = state
409+ . server_duplex
410+ . poll_write_from ( cx, state. server_socket . as_mut ( ) ) ?
411+ {
406412 if write > 0 {
407413 progress = true ;
408- } else if let Poll :: Ready ( ( ) ) = duplex_1 . as_mut ( ) . poll_close ( cx ) ? {
414+ } else if ! * state . server_closed {
409415 progress = true ;
416+ * state. server_closed = true ;
417+ state. tls_client . server_close ( ) ?;
410418 }
411419 }
412420
413421 // duplex -> tls_client
414- if let Poll :: Ready ( mut simplex) = duplex_2. as_mut ( ) . poll_lock_read ( cx)
422+ if state. tls_client . wants_read_tls ( )
423+ && let Poll :: Ready ( mut simplex) = state. tls_client_duplex . as_mut ( ) . poll_lock_read ( cx)
415424 && let Poll :: Ready ( buf) = simplex. poll_get ( cx) ?
425+ && let read = state. tls_client . read_tls ( buf) ?
426+ && read > 0
416427 {
417- if buf. len ( ) > 0
418- && let read = state. tls_client . read_tls ( buf) ?
419- && read > 0
420- {
421- progress = true ;
422- simplex. advance ( read) ;
423- } else if !* state. server_closed {
424- progress = true ;
425- * state. server_closed = true ;
426- state. tls_client . server_close ( ) ?;
427- }
428+ progress = true ;
429+ simplex. advance ( read) ;
428430 }
429431
430432 // tls_client -> duplex
431- if let Poll :: Ready ( mut simplex) = duplex_2. as_mut ( ) . poll_lock_write ( cx)
433+ if state. tls_client . wants_write_tls ( )
434+ && let Poll :: Ready ( mut simplex) = state. tls_client_duplex . as_mut ( ) . poll_lock_write ( cx)
432435 && let Poll :: Ready ( buf) = simplex. poll_mut ( cx) ?
436+ && let write = state. tls_client . write_tls ( buf) ?
437+ && write > 0
433438 {
434- if buf. len ( ) > 0
435- && let write = state. tls_client . write_tls ( buf) ?
436- && write > 0
437- {
438- progress = true ;
439- simplex. advance_mut ( write) ;
440- }
439+ progress = true ;
440+ simplex. advance_mut ( write) ;
441441 }
442442
443443 // duplex -> server_socket
444- if let Poll :: Ready ( read) = duplex_1. poll_read_to ( cx, state. server_socket . as_mut ( ) ) ?
444+ if let Poll :: Ready ( read) = state
445+ . server_duplex
446+ . poll_read_to ( cx, state. server_socket . as_mut ( ) ) ?
445447 && read > 0
446448 {
447449 progress = true ;
0 commit comments