Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 83 additions & 47 deletions crates/engineioxide/src/transport/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
//! Other functions are used internally to handle the websocket connection through tasks and channels
//! and to handle upgrade from polling to ws

use std::sync::Arc;
use std::{ops::ControlFlow, sync::Arc};

use futures_util::{
SinkExt, StreamExt, TryStreamExt,
Expand All @@ -15,7 +15,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tokio_tungstenite::{
WebSocketStream,
tungstenite::{
Message,
Message, Utf8Bytes,
handshake::derive_accept_key,
protocol::{Role, WebSocketConfig},
},
Expand All @@ -33,6 +33,7 @@ use crate::{
packet::{OpenPacket, Packet},
service::ProtocolVersion,
service::TransportType,
socket::PacketBuf,
};

/// Create a response for websocket upgrade
Expand Down Expand Up @@ -250,57 +251,92 @@ async fn forward_to_socket<H: EngineIoHandler, S>(
{
let mut internal_rx = socket.internal_rx.try_lock().unwrap();

// map a packet to a websocket message
// It is declared as a macro rather than a closure to avoid ownership issues
macro_rules! map_fn {
($item:ident) => {
let res = match $item {
Packet::Binary(bin) | Packet::BinaryV3(bin) => {
if socket.protocol == ProtocolVersion::V3 {
// v3 protocol requires packet type as the first byte.
// This requires a new buffer. This is OK as it is only for the V3 protocol.
let mut buff = Vec::with_capacity(bin.len() + 1);
buff.push(0x04);
buff.extend(bin);
tx.feed(Message::Binary(buff.into())).await
} else {
tx.feed(Message::Binary(bin)).await
}
}
Packet::Close => {
tx.send(Message::Close(None)).await.ok();
internal_rx.close();
break;
},
// A Noop Packet maybe sent by the server to upgrade from a polling connection
// In the case that the packet was not poll in time it will remain in the buffer and therefore
// it should be discarded here
Packet::Noop => Ok(()),
_ => {
let packet: String = $item.try_into().unwrap();
tx.feed(Message::Text(packet.into())).await
}
};
if let Err(_e) = res {
#[cfg(feature = "tracing")]
tracing::debug!("[sid={}] error sending packet: {}", socket.id, _e);
}
};
}
loop {
let Some(packets) = internal_rx.recv().await else {
break;
};

while let Some(items) = internal_rx.recv().await {
for item in items {
map_fn!(item);
}
let mut should_close = feed_all(&mut tx, packets, &socket).await.is_break();
// For every available packet we continue to send until the channel is drained
while let Ok(items) = internal_rx.try_recv() {
for item in items {
map_fn!(item);
while !should_close {
match internal_rx.try_recv() {
Ok(packets) => should_close = feed_all(&mut tx, packets, &socket).await.is_break(),
Err(_) => break,
}
}

tx.flush().await.ok();

// A `Packet::Close` was sent: close the channel so that pending senders
// are notified and stop forwarding.
if should_close {
internal_rx.close();
break;
}
}
}

/// Feeds a batch of packets to the sink, stopping early and returning
/// [`ControlFlow::Break`] as soon as a [`Packet::Close`] is encountered.
async fn feed_all<S, D>(
tx: &mut SplitSink<WebSocketStream<S>, Message>,
packets: PacketBuf,
socket: &Socket<D>,
) -> ControlFlow<()>
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
D: Default + Send + Sync + 'static,
{
for packet in packets {
feed_tx(tx, packet, socket).await?;
}
ControlFlow::Continue(())
}

/// Helper that will feed the sink with the current packet.
///
/// Return [`ControlFlow::Break`] if we return a [`Packet::Close`] and
/// that we should stop everything.
async fn feed_tx<S, D>(
tx: &mut SplitSink<WebSocketStream<S>, Message>,
packet: Packet,
socket: &Socket<D>,
) -> ControlFlow<()>
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
D: Default + Send + Sync + 'static,
{
let res = match packet {
Packet::Binary(bin) | Packet::BinaryV3(bin) => {
if socket.protocol == ProtocolVersion::V3 {
// v3 protocol requires packet type as the first byte.
// This requires a new buffer. This is OK as it is only for the V3 protocol.
let mut buff = Vec::with_capacity(bin.len() + 1);
buff.push(0x04);
buff.extend(bin);
tx.feed(Message::Binary(buff.into())).await
} else {
tx.feed(Message::Binary(bin)).await
}
}
Packet::Close => {
tx.send(Message::Close(None)).await.ok();
return ControlFlow::Break(());
}
// A Noop Packet maybe sent by the server to upgrade from a polling connection
// In the case that the packet was not poll in time it will remain in the buffer and therefore
// it should be discarded here
Packet::Noop => Ok(()),
_ => {
tx.feed(Message::Text(Utf8Bytes::from(String::from(packet))))
.await
}
};
if let Err(_e) = res {
#[cfg(feature = "tracing")]
tracing::debug!(sid = %socket.id, "failed to send packet to websocket: {}", _e);
}

ControlFlow::Continue(())
}
/// Send a Engine.IO [`OpenPacket`] to initiate a websocket connection
async fn init_handshake<S>(
Expand Down