Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ edition = '2021'
repository = 'https://github.com/narrowlink/udp-stream'
license = 'MIT'
name = 'udp-stream'
version = '0.0.12'
version = '0.1.0'
keywords = ["stream", "udp", "dtls", "tokio"]

[dependencies]
Expand All @@ -14,7 +14,7 @@ log = "0.4"
tokio = { version = "1", features = ["rt", "sync", "net", "macros", "io-util"] }

[dev-dependencies]
env_logger = "0.10"
env_logger = "0.11"
openssl = { version = "0.10", features = ["vendored"] }
tokio = { version = "1", features = ["time", "rt-multi-thread"] }
tokio-openssl = '0.6'
15 changes: 5 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,20 @@

To use `udp-stream` in your Rust project, simply add it as a dependency in your `Cargo.toml` file:

toml

```[dependencies]
udp-stream = "0.0.12"
```toml
[dependencies]
udp-stream = "0.1"
```

Then, you can import and use the library in your Rust code:

rust

```
```rust,no_run
use std::{net::SocketAddr, str::FromStr};

use tokio::io::{AsyncReadExt, AsyncWriteExt};

use udp_stream::UdpStream;

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let mut stream = UdpStream::connect(SocketAddr::from_str("127.0.0.1:8080")?).await?;
println!("Ready to Connected to {}", &stream.peer_addr()?);
let mut buffer = String::new();
Expand Down
20 changes: 11 additions & 9 deletions examples/echo-dtls.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{error::Error, net::SocketAddr, pin::Pin, str::FromStr, time::Duration};
use std::{net::SocketAddr, pin::Pin, str::FromStr, time::Duration};
use udp_stream::UdpListener;

use openssl::{
Expand Down Expand Up @@ -29,28 +29,30 @@ fn ssl_acceptor(certificate: &[u8], private_key: &[u8]) -> std::io::Result<SslCo
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let listener = UdpListener::bind(SocketAddr::from_str("127.0.0.1:8080")?).await?;
let acceptor = ssl_acceptor(SERVER_CERT, SERVER_KEY)?;
loop {
let (socket, _) = listener.accept().await?;
let acceptor = acceptor.clone();
tokio::spawn(async move {
let ssl = Ssl::new(&acceptor).unwrap();
let mut stream = tokio_openssl::SslStream::new(ssl, socket).unwrap();
Pin::new(&mut stream).accept().await.unwrap();
let ssl = Ssl::new(&acceptor).map_err(std::io::Error::other)?;
let mut stream = tokio_openssl::SslStream::new(ssl, socket).map_err(std::io::Error::other)?;
Pin::new(&mut stream).accept().await.map_err(std::io::Error::other)?;
let mut buf = vec![0u8; UDP_BUFFER_SIZE];
loop {
let duration = Duration::from_millis(UDP_TIMEOUT);
let n = match timeout(duration, stream.read(&mut buf)).await.unwrap() {
Ok(len) => len,
let n = match timeout(duration, stream.read(&mut buf)).await {
Ok(len) => len?,
Err(_) => {
return;
stream.shutdown().await?;
break;
}
};

stream.write_all(&buf[0..n]).await.unwrap();
stream.write_all(&buf[0..n]).await?;
}
Ok::<(), std::io::Error>(())
});
}
}
14 changes: 10 additions & 4 deletions examples/echo-udp.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{error::Error, net::SocketAddr, str::FromStr, time::Duration};
use std::{net::SocketAddr, str::FromStr, time::Duration};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
time::timeout,
Expand All @@ -8,7 +8,7 @@ const UDP_BUFFER_SIZE: usize = 17480; // 17kb
const UDP_TIMEOUT: u64 = 10 * 1000; // 10sec

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let level = format!("{}={}", module_path!(), "trace");
env_logger::Builder::from_env(env_logger::Env::default().default_filter_or(level)).init();

Expand All @@ -21,11 +21,17 @@ async fn main() -> Result<(), Box<dyn Error>> {
let mut buf = vec![0u8; UDP_BUFFER_SIZE];
let duration = Duration::from_millis(UDP_TIMEOUT);
loop {
let n = timeout(duration, stream.read(&mut buf)).await??;
let n = match timeout(duration, stream.read(&mut buf)).await {
Err(err) => {
log::debug!("{id:?} {err}");
stream.shutdown().await?;
break;
}
Ok(val) => val?,
};
stream.write_all(&buf[0..n]).await?;
log::trace!("{:?} echoed {:?} for {} bytes", id, stream.peer_addr(), n);
}
#[allow(unreachable_code)]
Ok::<(), std::io::Error>(())
};
if let Err(e) = block.await {
Expand Down
1 change: 1 addition & 0 deletions rustfmt.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
max_width = 140
71 changes: 39 additions & 32 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#![doc = include_str!("../README.md")]

use bytes::{Buf, Bytes, BytesMut};
use std::{
collections::HashMap,
future::Future,
io,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
pin::Pin,
sync::Arc,
Expand Down Expand Up @@ -55,12 +56,13 @@ impl Drop for UdpListener {

impl UdpListener {
/// Binds the `UdpListener` to the given local address.
pub async fn bind(local_addr: SocketAddr) -> io::Result<Self> {
pub async fn bind(local_addr: SocketAddr) -> std::io::Result<Self> {
let udp_socket = UdpSocket::bind(local_addr).await?;
Self::from_tokio(udp_socket).await
}

/// Creates a `UdpListener` from an existing `tokio::net::UdpSocket`.
pub async fn from_tokio(udp_socket: UdpSocket) -> io::Result<Self> {
pub async fn from_tokio(udp_socket: UdpSocket) -> std::io::Result<Self> {
let (tx, rx) = mpsc::channel(CHANNEL_LEN);
let local_addr = udp_socket.local_addr()?;

Expand Down Expand Up @@ -101,6 +103,7 @@ impl UdpListener {
socket: socket.clone(),
handler: None,
drop: Some(drop_tx.clone()),
drop_sent: false,
remaining: None,
};
if let Err(err) = tx.send((udp_stream, peer_addr)).await {
Expand All @@ -122,18 +125,18 @@ impl UdpListener {
}

///Returns the local address that this socket is bound to.
pub fn local_addr(&self) -> io::Result<SocketAddr> {
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
Ok(self.local_addr)
}

/// Accepts a new incoming UDP connection.
pub async fn accept(&self) -> io::Result<(UdpStream, SocketAddr)> {
pub async fn accept(&self) -> std::io::Result<(UdpStream, SocketAddr)> {
self.receiver
.lock()
.await
.recv()
.await
.ok_or(io::Error::from(io::ErrorKind::BrokenPipe))
.ok_or(std::io::Error::from(std::io::ErrorKind::BrokenPipe))
}
}

Expand All @@ -153,6 +156,7 @@ pub struct UdpStream {
socket: Arc<tokio::net::UdpSocket>,
handler: Option<tokio::task::JoinHandle<()>>,
drop: Option<mpsc::Sender<SocketAddr>>,
drop_sent: bool,
remaining: Option<Bytes>,
}

Expand All @@ -162,9 +166,13 @@ impl Drop for UdpStream {
handler.abort()
}

if let Some(drop) = &self.drop {
let _ = drop.try_send(self.peer_addr);
};
if !self.drop_sent {
if let Some(drop) = &self.drop {
if drop.try_send(self.peer_addr).is_ok() {
self.drop_sent = true;
}
}
}
}
}

Expand All @@ -187,10 +195,7 @@ impl UdpStream {
/// Creates a new UdpStream from a tokio::net::UdpSocket.
/// This function is intended to be used to wrap a UDP socket from the tokio library.
/// Note: The UdpSocket must have the UdpSocket::connect method called before invoking this function.
pub async fn from_tokio(
socket: UdpSocket,
peer_addr: SocketAddr,
) -> Result<Self, tokio::io::Error> {
pub async fn from_tokio(socket: UdpSocket, peer_addr: SocketAddr) -> Result<Self, tokio::io::Error> {
let socket = Arc::new(socket);

let local_addr = socket.local_addr()?;
Expand All @@ -201,8 +206,7 @@ impl UdpStream {

let handler = tokio::spawn(async move {
let mut buf = BytesMut::with_capacity(UDP_BUFFER_SIZE);
while let Ok((len, received_addr)) = socket_inner.clone().recv_buf_from(&mut buf).await
{
while let Ok((len, received_addr)) = socket_inner.clone().recv_buf_from(&mut buf).await {
if received_addr != peer_addr {
continue;
}
Expand All @@ -224,6 +228,7 @@ impl UdpStream {
socket: socket.clone(),
handler: Some(handler),
drop: None,
drop_sent: false,
remaining: None,
})
}
Expand All @@ -234,19 +239,10 @@ impl UdpStream {
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
Ok(self.local_addr)
}
pub fn shutdown(&self) {
if let Some(drop) = &self.drop {
let _ = drop.try_send(self.peer_addr);
};
}
}

impl AsyncRead for UdpStream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut ReadBuf,
) -> Poll<io::Result<()>> {
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut ReadBuf) -> Poll<std::io::Result<()>> {
if let Some(remaining) = self.remaining.as_mut() {
if buf.remaining() < remaining.len() {
buf.put_slice(&remaining.split_to(buf.remaining())[..]);
Expand All @@ -271,29 +267,40 @@ impl AsyncRead for UdpStream {
buf.put_slice(&inner_buf[..]);
Poll::Ready(Ok(()))
}
Poll::Ready(None) => Poll::Ready(Err(io::Error::from(io::ErrorKind::BrokenPipe))),
Poll::Ready(None) => Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::BrokenPipe))),
Poll::Pending => Poll::Pending,
}
}
}

impl AsyncWrite for UdpStream {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<std::io::Result<usize>> {
match self.socket.poll_send_to(cx, buf, self.peer_addr) {
Poll::Ready(Ok(r)) => Poll::Ready(Ok(r)),
Poll::Ready(Err(e)) => {
if let Some(drop) = &self.drop {
let _ = drop.try_send(self.peer_addr);
};
if !self.drop_sent {
if let Some(drop) = &self.drop {
if drop.try_send(self.peer_addr).is_ok() {
self.drop_sent = true;
}
}
}
Poll::Ready(Err(e))
}
Poll::Pending => Poll::Pending,
}
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
fn poll_shutdown(mut self: Pin<&mut Self>, _cx: &mut Context) -> Poll<std::io::Result<()>> {
if !self.drop_sent {
if let Some(drop) = &self.drop {
if drop.try_send(self.peer_addr).is_ok() {
self.drop_sent = true;
}
}
}
Poll::Ready(Ok(()))
}
}
Loading