Skip to content

Commit e3c25ea

Browse files
committed
Merge branch 'feat/socket-api-debug' into feat/socket-api
2 parents 4ec838f + 13681b1 commit e3c25ea

File tree

7 files changed

+65
-74
lines changed

7 files changed

+65
-74
lines changed

crates/tls/client/src/conn.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,7 @@ impl ConnectionCommon {
549549
if let Ok(0) = res {
550550
self.common_state.has_seen_eof = true;
551551
}
552+
let res = res.inspect(|v| println!("read tls: {v} bytes"));
552553
res
553554
}
554555

crates/tlsn/src/prover.rs

Lines changed: 45 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ pub mod state;
1010
pub use conn::TlsConnection;
1111
pub use control::ProverControl;
1212
pub use error::ProverError;
13-
use futures_plex::DuplexStream;
1413
pub use tlsn_core::ProverOutput;
1514

1615
use crate::{
@@ -25,7 +24,7 @@ use crate::{
2524
utils::{await_with_copy_io, build_mt_context},
2625
};
2726

28-
use futures::{AsyncRead, AsyncWrite, FutureExt, TryFutureExt};
27+
use futures::{AsyncRead, AsyncWrite, FutureExt, TryFutureExt, ready};
2928
use rustls_pki_types::CertificateDer;
3029
use serio::{SinkExt, stream::IoStreamExt};
3130
use std::{
@@ -270,6 +269,8 @@ impl Prover<state::Setup> {
270269
S: AsyncRead + AsyncWrite + Send + Unpin,
271270
T: AsyncRead + AsyncWrite + Send + Unpin,
272271
{
272+
let (client_to_server, server_to_client) = futures_plex::duplex(BUF_CAP);
273+
273274
Prover {
274275
config: self.config,
275276
span: self.span,
@@ -281,6 +282,8 @@ impl Prover<state::Setup> {
281282
output: None,
282283
client_io: self.state.client_io,
283284
verifier_io: self.state.verifier_io,
285+
tls_client_duplex: client_to_server,
286+
server_duplex: server_to_client,
284287
server_socket: server_io,
285288
verifier_socket: verifier_io,
286289
client_closed: false,
@@ -323,13 +326,13 @@ where
323326
fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
324327
let mut state = Pin::new(&mut self.state).project();
325328

326-
let (mut duplex_1, mut duplex_2) = futures_plex::duplex(BUF_CAP);
327329
loop {
328330
let mut progress = false;
329331

330332
progress |= Self::io_client_conn(&mut state, cx)?;
331-
progress |= Self::io_client_server(&mut state, cx, &mut duplex_1, &mut duplex_2)?;
333+
progress |= Self::io_client_server(&mut state, cx)?;
332334
progress |= Self::io_client_verifier(&mut state, cx)?;
335+
333336
_ = state.mux_fut.poll_unpin(cx)?;
334337

335338
if state.output.is_none()
@@ -339,8 +342,14 @@ where
339342
}
340343

341344
if !progress {
345+
cx.waker().wake_by_ref();
342346
return Poll::Pending;
343-
} else if *state.server_closed && *state.client_closed && state.output.is_some() {
347+
}
348+
349+
if *state.server_closed && state.output.is_some() {
350+
ready!(state.client_io.poll_close(cx))?;
351+
ready!(state.server_socket.poll_close(cx))?;
352+
344353
return Poll::Ready(Ok(()));
345354
}
346355
}
@@ -359,10 +368,11 @@ where
359368
let mut progress = false;
360369

361370
// tls_conn -> tls_client
362-
if let Poll::Ready(mut simplex) = state.client_io.as_mut().poll_lock_read(cx)
371+
if state.tls_client.wants_write()
372+
&& let Poll::Ready(mut simplex) = state.client_io.as_mut().poll_lock_read(cx)
363373
&& let Poll::Ready(buf) = simplex.poll_get(cx)?
364374
{
365-
if buf.len() > 0 {
375+
if !buf.is_empty() {
366376
let write = state.tls_client.write(buf)?;
367377
if write > 0 {
368378
progress = true;
@@ -376,72 +386,64 @@ where
376386
}
377387

378388
// tls_client -> tls_conn
379-
if let Poll::Ready(mut simplex) = state.client_io.as_mut().poll_lock_write(cx)
389+
if state.tls_client.wants_read()
390+
&& let Poll::Ready(mut simplex) = state.client_io.as_mut().poll_lock_write(cx)
380391
&& let Poll::Ready(buf) = simplex.poll_mut(cx)?
392+
&& let read = state.tls_client.read(buf)?
393+
&& read > 0
381394
{
382-
if buf.len() > 0
383-
&& let read = state.tls_client.read(buf)?
384-
&& read > 0
385-
{
386-
progress = true;
387-
simplex.advance_mut(read);
388-
}
395+
progress = true;
396+
simplex.advance_mut(read);
389397
}
390-
391398
Ok(progress)
392399
}
393400

394401
fn io_client_server(
395402
state: &mut ConnectedProj<S, T>,
396403
cx: &mut Context,
397-
duplex_1: &mut DuplexStream,
398-
duplex_2: &mut DuplexStream,
399404
) -> Result<bool, ProverError> {
400405
let mut progress = false;
401-
let mut duplex_1 = Pin::new(duplex_1);
402-
let mut duplex_2 = Pin::new(duplex_2);
403406

404407
// server_socket -> duplex
405-
if let Poll::Ready(write) = duplex_1.poll_write_from(cx, state.server_socket.as_mut())? {
408+
if let Poll::Ready(write) = state
409+
.server_duplex
410+
.poll_write_from(cx, state.server_socket.as_mut())?
411+
{
406412
if write > 0 {
407413
progress = true;
408-
} else if let Poll::Ready(()) = duplex_1.as_mut().poll_close(cx)? {
414+
} else if !*state.server_closed {
409415
progress = true;
416+
*state.server_closed = true;
417+
state.tls_client.server_close()?;
410418
}
411419
}
412420

413421
// duplex -> tls_client
414-
if let Poll::Ready(mut simplex) = duplex_2.as_mut().poll_lock_read(cx)
422+
if state.tls_client.wants_read_tls()
423+
&& let Poll::Ready(mut simplex) = state.tls_client_duplex.as_mut().poll_lock_read(cx)
415424
&& let Poll::Ready(buf) = simplex.poll_get(cx)?
425+
&& let read = state.tls_client.read_tls(buf)?
426+
&& read > 0
416427
{
417-
if buf.len() > 0
418-
&& let read = state.tls_client.read_tls(buf)?
419-
&& read > 0
420-
{
421-
progress = true;
422-
simplex.advance(read);
423-
} else if !*state.server_closed {
424-
progress = true;
425-
*state.server_closed = true;
426-
state.tls_client.server_close()?;
427-
}
428+
progress = true;
429+
simplex.advance(read);
428430
}
429431

430432
// tls_client -> duplex
431-
if let Poll::Ready(mut simplex) = duplex_2.as_mut().poll_lock_write(cx)
433+
if state.tls_client.wants_write_tls()
434+
&& let Poll::Ready(mut simplex) = state.tls_client_duplex.as_mut().poll_lock_write(cx)
432435
&& let Poll::Ready(buf) = simplex.poll_mut(cx)?
436+
&& let write = state.tls_client.write_tls(buf)?
437+
&& write > 0
433438
{
434-
if buf.len() > 0
435-
&& let write = state.tls_client.write_tls(buf)?
436-
&& write > 0
437-
{
438-
progress = true;
439-
simplex.advance_mut(write);
440-
}
439+
progress = true;
440+
simplex.advance_mut(write);
441441
}
442442

443443
// duplex -> server_socket
444-
if let Poll::Ready(read) = duplex_1.poll_read_to(cx, state.server_socket.as_mut())?
444+
if let Poll::Ready(read) = state
445+
.server_duplex
446+
.poll_read_to(cx, state.server_socket.as_mut())?
445447
&& read > 0
446448
{
447449
progress = true;

crates/tlsn/src/prover/client/mpc.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ impl TlsClient for MpcTlsClient {
237237
receiver,
238238
} => {
239239
trace!("inner client is active");
240+
//println!("client is handshaking: {}", inner.tls.is_handshaking());
240241

241242
if !inner.tls.is_handshaking()
242243
&& let Ok(cmd) = receiver.try_recv()
@@ -442,8 +443,7 @@ impl InnerState {
442443
}
443444
self.client_closed = true;
444445
}
445-
self.tls.process_new_packets().await?;
446-
Ok(self)
446+
self.run().await
447447
}
448448

449449
#[instrument(parent = &self.span, level = "debug", skip_all, err)]

crates/tlsn/src/prover/conn.rs

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,25 +18,18 @@ use std::{
1818
/// connection.
1919
pub struct TlsConnection {
2020
duplex: DuplexStream,
21-
closed: bool,
2221
}
2322

2423
impl TlsConnection {
2524
pub(crate) fn new(duplex: DuplexStream) -> Self {
26-
Self {
27-
duplex,
28-
closed: false,
29-
}
25+
Self { duplex }
3026
}
3127
}
3228

3329
impl Drop for TlsConnection {
3430
fn drop(&mut self) {
35-
if !self.closed {
36-
if let Err(err) = futures::executor::block_on(self.duplex.close()) {
37-
tracing::error!("error closing connection: {}", err);
38-
}
39-
self.closed = true;
31+
if let Err(err) = futures::executor::block_on(self.duplex.close()) {
32+
tracing::error!("error closing connection: {}", err);
4033
}
4134
}
4235
}
@@ -47,12 +40,8 @@ impl AsyncRead for TlsConnection {
4740
cx: &mut Context<'_>,
4841
buf: &mut [u8],
4942
) -> Poll<std::io::Result<usize>> {
50-
if !self.closed {
51-
let duplex = Pin::new(&mut self.duplex);
52-
duplex.poll_read(cx, buf)
53-
} else {
54-
Poll::Ready(Ok(0))
55-
}
43+
let duplex = Pin::new(&mut self.duplex);
44+
duplex.poll_read(cx, buf)
5645
}
5746
}
5847

@@ -71,14 +60,7 @@ impl AsyncWrite for TlsConnection {
7160
}
7261

7362
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
74-
if !self.closed {
75-
let duplex = Pin::new(&mut self.duplex);
76-
if let Poll::Ready(()) = duplex.poll_close(cx)? {
77-
self.closed = true;
78-
} else {
79-
return Poll::Pending;
80-
};
81-
}
82-
Poll::Ready(Ok(()))
63+
let duplex = Pin::new(&mut self.duplex);
64+
duplex.poll_close(cx)
8365
}
8466
}

crates/tlsn/src/prover/state.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ pin_project_lite::pin_project! {
6666
pub(crate) server_socket: S,
6767
#[pin]
6868
pub(crate) verifier_socket: T,
69+
#[pin]
70+
pub(crate) tls_client_duplex: DuplexStream,
71+
#[pin]
72+
pub(crate) server_duplex: DuplexStream,
6973
pub(crate) client_closed: bool,
7074
pub(crate) server_closed: bool
7175
}

crates/tlsn/src/utils.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use futures::{AsyncRead, AsyncWrite, future::FusedFuture, ready};
99
use futures_plex::DuplexStream;
1010
use mpz_common::context::Multithread;
1111

12-
use crate::{BUF_CAP, mux::MuxControl};
12+
use crate::mux::MuxControl;
1313

1414
/// Maximum concurrency for multi-threaded context.
1515
pub(crate) const MAX_CONCURRENCY: usize = 8;
@@ -53,7 +53,6 @@ pin_project_lite::pin_project! {
5353
io: S,
5454
#[pin]
5555
duplex: &'a mut DuplexStream,
56-
buf: Vec<u8>,
5756
closed: bool
5857
}
5958
}
@@ -63,7 +62,6 @@ impl<'a, S> CopyFlush<'a, S> {
6362
Self {
6463
io,
6564
duplex,
66-
buf: Vec::with_capacity(BUF_CAP),
6765
closed: false,
6866
}
6967
}
@@ -85,19 +83,23 @@ where
8583
if let Poll::Ready(read) = read
8684
&& read == 0
8785
{
86+
println!("copy flush: read 0 from duplex: closing connection");
8887
ready!(this.io.as_mut().poll_close(cx))?;
8988
*this.closed = true;
9089
}
9190

9291
if let Poll::Ready(write) = write
9392
&& write == 0
9493
{
94+
println!("copy flush: read 0 from io: closing connection");
9595
ready!(this.duplex.as_mut().poll_close(cx))?;
9696
*this.closed = true;
9797
}
9898

9999
if matches!(read, Poll::Pending) && matches!(write, Poll::Pending) {
100100
return Poll::Pending;
101+
} else if *this.closed {
102+
return Poll::Ready(Ok(()));
101103
}
102104
}
103105
}

crates/tlsn/tests/test.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,11 +153,11 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
153153
.write_all(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n")
154154
.await
155155
.unwrap();
156-
tls_connection.close().await.unwrap();
157156

158157
let mut response = vec![0u8; 1024];
159158
tls_connection.read_to_end(&mut response).await.unwrap();
160159

160+
tls_connection.close().await.unwrap();
161161
let _ = server_task.await.unwrap();
162162

163163
let (mut prover, _, mut verifier_socket) = prover_task.await.unwrap().unwrap();

0 commit comments

Comments
 (0)