Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ edition = '2021'
repository = 'https://github.com/narrowlink/udp-stream'
license = 'MIT'
name = 'udp-stream'
version = '0.0.12'
version = '0.1.0'
keywords = ["stream", "udp", "dtls", "tokio"]

[dependencies]
Expand All @@ -14,7 +14,7 @@ log = "0.4"
tokio = { version = "1", features = ["rt", "sync", "net", "macros", "io-util"] }

[dev-dependencies]
env_logger = "0.10"
env_logger = "0.11"
openssl = { version = "0.10", features = ["vendored"] }
tokio = { version = "1", features = ["time", "rt-multi-thread"] }
tokio-openssl = '0.6'
15 changes: 5 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,20 @@

To use `udp-stream` in your Rust project, simply add it as a dependency in your `Cargo.toml` file:

toml

```[dependencies]
udp-stream = "0.0.12"
```toml
[dependencies]
udp-stream = "0.1"
```

Then, you can import and use the library in your Rust code:

rust

```
```rust,no_run
use std::{net::SocketAddr, str::FromStr};

use tokio::io::{AsyncReadExt, AsyncWriteExt};

use udp_stream::UdpStream;

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let mut stream = UdpStream::connect(SocketAddr::from_str("127.0.0.1:8080")?).await?;
println!("Ready to Connected to {}", &stream.peer_addr()?);
let mut buffer = String::new();
Expand Down
20 changes: 11 additions & 9 deletions examples/echo-dtls.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{error::Error, net::SocketAddr, pin::Pin, str::FromStr, time::Duration};
use std::{net::SocketAddr, pin::Pin, str::FromStr, time::Duration};
use udp_stream::UdpListener;

use openssl::{
Expand Down Expand Up @@ -29,28 +29,30 @@ fn ssl_acceptor(certificate: &[u8], private_key: &[u8]) -> std::io::Result<SslCo
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let listener = UdpListener::bind(SocketAddr::from_str("127.0.0.1:8080")?).await?;
let acceptor = ssl_acceptor(SERVER_CERT, SERVER_KEY)?;
loop {
let (socket, _) = listener.accept().await?;
let acceptor = acceptor.clone();
tokio::spawn(async move {
let ssl = Ssl::new(&acceptor).unwrap();
let mut stream = tokio_openssl::SslStream::new(ssl, socket).unwrap();
Pin::new(&mut stream).accept().await.unwrap();
let ssl = Ssl::new(&acceptor).map_err(std::io::Error::other)?;
let mut stream = tokio_openssl::SslStream::new(ssl, socket).map_err(std::io::Error::other)?;
Pin::new(&mut stream).accept().await.map_err(std::io::Error::other)?;
let mut buf = vec![0u8; UDP_BUFFER_SIZE];
loop {
let duration = Duration::from_millis(UDP_TIMEOUT);
let n = match timeout(duration, stream.read(&mut buf)).await.unwrap() {
Ok(len) => len,
let n = match timeout(duration, stream.read(&mut buf)).await {
Ok(len) => len?,
Err(_) => {
return;
stream.shutdown().await?;
break;
}
};

stream.write_all(&buf[0..n]).await.unwrap();
stream.write_all(&buf[0..n]).await?;
}
Ok::<(), std::io::Error>(())
});
}
}
14 changes: 10 additions & 4 deletions examples/echo-udp.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{error::Error, net::SocketAddr, str::FromStr, time::Duration};
use std::{net::SocketAddr, str::FromStr, time::Duration};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
time::timeout,
Expand All @@ -8,7 +8,7 @@ const UDP_BUFFER_SIZE: usize = 17480; // 17kb
const UDP_TIMEOUT: u64 = 10 * 1000; // 10sec

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let level = format!("{}={}", module_path!(), "trace");
env_logger::Builder::from_env(env_logger::Env::default().default_filter_or(level)).init();

Expand All @@ -21,11 +21,17 @@ async fn main() -> Result<(), Box<dyn Error>> {
let mut buf = vec![0u8; UDP_BUFFER_SIZE];
let duration = Duration::from_millis(UDP_TIMEOUT);
loop {
let n = timeout(duration, stream.read(&mut buf)).await??;
let n = match timeout(duration, stream.read(&mut buf)).await {
Err(err) => {
log::debug!("{id:?} {err}");
stream.shutdown().await?;
break;
}
Ok(val) => val?,
};
stream.write_all(&buf[0..n]).await?;
log::trace!("{:?} echoed {:?} for {} bytes", id, stream.peer_addr(), n);
}
#[allow(unreachable_code)]
Ok::<(), std::io::Error>(())
};
if let Err(e) = block.await {
Expand Down
1 change: 1 addition & 0 deletions rustfmt.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
max_width = 140
67 changes: 36 additions & 31 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#![doc = include_str!("../README.md")]

use bytes::{Buf, Bytes, BytesMut};
use std::{
collections::HashMap,
future::Future,
io,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
pin::Pin,
sync::Arc,
Expand Down Expand Up @@ -55,12 +56,13 @@ impl Drop for UdpListener {

impl UdpListener {
/// Binds the `UdpListener` to the given local address.
pub async fn bind(local_addr: SocketAddr) -> io::Result<Self> {
pub async fn bind(local_addr: SocketAddr) -> std::io::Result<Self> {
let udp_socket = UdpSocket::bind(local_addr).await?;
Self::from_tokio(udp_socket).await
}

/// Creates a `UdpListener` from an existing `tokio::net::UdpSocket`.
pub async fn from_tokio(udp_socket: UdpSocket) -> io::Result<Self> {
pub async fn from_tokio(udp_socket: UdpSocket) -> std::io::Result<Self> {
let (tx, rx) = mpsc::channel(CHANNEL_LEN);
let local_addr = udp_socket.local_addr()?;

Expand Down Expand Up @@ -101,6 +103,7 @@ impl UdpListener {
socket: socket.clone(),
handler: None,
drop: Some(drop_tx.clone()),
drop_sent: false,
remaining: None,
};
if let Err(err) = tx.send((udp_stream, peer_addr)).await {
Expand All @@ -122,18 +125,18 @@ impl UdpListener {
}

///Returns the local address that this socket is bound to.
pub fn local_addr(&self) -> io::Result<SocketAddr> {
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
Ok(self.local_addr)
}

/// Accepts a new incoming UDP connection.
pub async fn accept(&self) -> io::Result<(UdpStream, SocketAddr)> {
pub async fn accept(&self) -> std::io::Result<(UdpStream, SocketAddr)> {
self.receiver
.lock()
.await
.recv()
.await
.ok_or(io::Error::from(io::ErrorKind::BrokenPipe))
.ok_or(std::io::Error::from(std::io::ErrorKind::BrokenPipe))
}
}

Expand All @@ -153,6 +156,7 @@ pub struct UdpStream {
socket: Arc<tokio::net::UdpSocket>,
handler: Option<tokio::task::JoinHandle<()>>,
drop: Option<mpsc::Sender<SocketAddr>>,
drop_sent: bool,
remaining: Option<Bytes>,
}

Expand All @@ -162,9 +166,9 @@ impl Drop for UdpStream {
handler.abort()
}

if let Some(drop) = &self.drop {
let _ = drop.try_send(self.peer_addr);
};
if let Err(e) = self.send_drop_helper() {
log::error!("drop send_drop_helper {:?}", e);
}
}
}

Expand All @@ -187,10 +191,7 @@ impl UdpStream {
/// Creates a new UdpStream from a tokio::net::UdpSocket.
/// This function is intended to be used to wrap a UDP socket from the tokio library.
/// Note: The UdpSocket must have the UdpSocket::connect method called before invoking this function.
pub async fn from_tokio(
socket: UdpSocket,
peer_addr: SocketAddr,
) -> Result<Self, tokio::io::Error> {
pub async fn from_tokio(socket: UdpSocket, peer_addr: SocketAddr) -> Result<Self, tokio::io::Error> {
let socket = Arc::new(socket);

let local_addr = socket.local_addr()?;
Expand All @@ -201,8 +202,7 @@ impl UdpStream {

let handler = tokio::spawn(async move {
let mut buf = BytesMut::with_capacity(UDP_BUFFER_SIZE);
while let Ok((len, received_addr)) = socket_inner.clone().recv_buf_from(&mut buf).await
{
while let Ok((len, received_addr)) = socket_inner.clone().recv_buf_from(&mut buf).await {
if received_addr != peer_addr {
continue;
}
Expand All @@ -224,6 +224,7 @@ impl UdpStream {
socket: socket.clone(),
handler: Some(handler),
drop: None,
drop_sent: false,
remaining: None,
})
}
Expand All @@ -234,19 +235,22 @@ impl UdpStream {
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
Ok(self.local_addr)
}
pub fn shutdown(&self) {
if let Some(drop) = &self.drop {
let _ = drop.try_send(self.peer_addr);
};

fn send_drop_helper(&mut self) -> std::io::Result<()> {
if !self.drop_sent {
if let Some(drop) = &self.drop {
match drop.try_send(self.peer_addr) {
Ok(_) => self.drop_sent = true,
Err(err) => return Err(std::io::Error::other(err)),
}
}
}
Ok(())
}
}

impl AsyncRead for UdpStream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut ReadBuf,
) -> Poll<io::Result<()>> {
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut ReadBuf) -> Poll<std::io::Result<()>> {
if let Some(remaining) = self.remaining.as_mut() {
if buf.remaining() < remaining.len() {
buf.put_slice(&remaining.split_to(buf.remaining())[..]);
Expand All @@ -271,29 +275,30 @@ impl AsyncRead for UdpStream {
buf.put_slice(&inner_buf[..]);
Poll::Ready(Ok(()))
}
Poll::Ready(None) => Poll::Ready(Err(io::Error::from(io::ErrorKind::BrokenPipe))),
Poll::Ready(None) => Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::BrokenPipe))),
Poll::Pending => Poll::Pending,
}
}
}

impl AsyncWrite for UdpStream {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<std::io::Result<usize>> {
match self.socket.poll_send_to(cx, buf, self.peer_addr) {
Poll::Ready(Ok(r)) => Poll::Ready(Ok(r)),
Poll::Ready(Err(e)) => {
if let Some(drop) = &self.drop {
let _ = drop.try_send(self.peer_addr);
};
if let Err(err) = self.send_drop_helper() {
log::error!("poll_write send_drop_helper {:?}", err);
}
Poll::Ready(Err(e))
}
Poll::Pending => Poll::Pending,
}
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
fn poll_shutdown(mut self: Pin<&mut Self>, _cx: &mut Context) -> Poll<std::io::Result<()>> {
self.send_drop_helper()?;
Poll::Ready(Ok(()))
}
}
Loading