Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
name = "bandwhich"
version = "0.23.1"
authors = [
"Aram Drevekenin <[email protected]>",
"Eduardo Toledo <[email protected]>",
"Eduardo Broto <[email protected]>",
"Kelvin Zhang <[email protected]>",
"Brooks Rady <[email protected]>",
"cyqsimon <[email protected]>",
"Aram Drevekenin <[email protected]>",
"Eduardo Toledo <[email protected]>",
"Eduardo Broto <[email protected]>",
"Kelvin Zhang <[email protected]>",
"Brooks Rady <[email protected]>",
"cyqsimon <[email protected]>",
]
categories = ["network-programming", "command-line-utilities"]
edition = "2021"
Expand Down Expand Up @@ -42,7 +42,7 @@ ratatui = "0.29.0"
resolv-conf = "0.7.4"
simplelog = "0.12.2"
thiserror = "2.0.12"
tokio = { version = "1.46", features = ["rt", "sync"] }
tokio = { version = "1.46", features = ["rt", "sync", "macros"] }
trust-dns-resolver = "0.23.2"
unicode-width = "0.2.0"
strum = { version = "0.27.1", features = ["derive"] }
Expand Down
82 changes: 69 additions & 13 deletions src/network/dns/resolver.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
use std::{
net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4},
time::Duration,
};

use async_trait::async_trait;
use log::warn;
use tokio::time::sleep;
use trust_dns_resolver::{
config::{NameServerConfig, Protocol, ResolverConfig, ResolverOpts},
error::ResolveErrorKind,
TokioAsyncResolver,
};

Expand Down Expand Up @@ -40,18 +44,70 @@ impl Resolver {
#[async_trait]
impl Lookup for Resolver {
async fn lookup(&self, ip: IpAddr) -> Option<String> {
let lookup_future = self.0.reverse_lookup(ip);
match lookup_future.await {
Ok(names) => {
// Take the first result and convert it to a string
names.into_iter().next().map(|name| name.to_string())
}
Err(e) => match e.kind() {
// If the IP is not associated with a hostname, store the IP
// so that we don't retry indefinitely
ResolveErrorKind::NoRecordsFound { .. } => Some(ip.to_string()),
_ => None,
let retry_config = RetryPolicy {
max_retries: 3,
..Default::default()
};

retry_with_backoff(
|| {
let resolver = &self.0;
async move {
resolver
.reverse_lookup(ip)
.await
.ok()
.and_then(|names| names.iter().next().map(|n| n.to_string()))
.or_else(|| Some(ip.to_string()))
}
},
retry_config.max_retries,
retry_config.base_delay,
)
.await
.or_else(|| Some("DNS lookup timeout.".into()))
}
}

struct RetryPolicy {
max_retries: u8,
base_delay: tokio::time::Duration,
}

impl Default for RetryPolicy {
fn default() -> Self {
Self {
max_retries: 2,
base_delay: Duration::from_millis(1000),
}
}
}

pub async fn retry_with_backoff<F, Fut, T>(
mut operation: F,
max_reties: u8,
inittial_delay: tokio::time::Duration,
) -> Option<T>
where
F: FnMut() -> Fut,
Fut: std::future::Future<Output = Option<T>>,
{
let mut delay = inittial_delay;
for attemp in 0..=max_reties {
match operation().await {
Some(value) => return Some(value),
None if attemp < max_reties => {
warn!(
"Retrying.. attemp: {}/{} (waiting {:?})",
attemp + 1,
max_reties,
delay
);
sleep(delay).await;
delay *= 2;
}
None => return None,
}
}
None
}
1 change: 1 addition & 0 deletions src/tests/cases/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod network;
pub mod raw_mode;
pub mod test_utils;
#[cfg(feature = "ui_test")]
Expand Down
63 changes: 63 additions & 0 deletions src/tests/cases/network.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#[cfg(test)]
mod tests {
use super::*;
use crate::network::dns;
use std::sync::{
atomic::{AtomicU8, Ordering},
Arc,
};
use tokio::time::Instant;

#[tokio::test]
async fn retry_should_succeed_after_3_attempts() {
let counter = Arc::new(AtomicU8::new(0));
let counter_clone = counter.clone();

let start = Instant::now();

let result = dns::retry_with_backoff(
move || {
let counter = counter_clone.clone();
async move {
let attempt = counter.fetch_add(1, Ordering::SeqCst);
if attempt >= 2 {
Some("Success".to_string())
} else {
None
}
}
},
5,
std::time::Duration::from_millis(50),
)
.await;

let duration = start.elapsed();

assert_eq!(result, Some("Success".to_string()));
assert!(duration >= std::time::Duration::from_millis(50 + 100)); // 2 delays
assert!(counter.load(Ordering::SeqCst) == 3); // called 3 times
}

#[tokio::test]
async fn retry_should_fail_after_max_retries() {
let counter = Arc::new(AtomicU8::new(0));
let counter_clone = counter.clone();

let result: Option<()> = dns::retry_with_backoff(
move || {
let counter = counter_clone.clone();
async move {
counter.fetch_add(1, Ordering::SeqCst);
None
}
},
3,
std::time::Duration::from_millis(10),
)
.await;

assert_eq!(result, None);
assert_eq!(counter.load(Ordering::SeqCst), 4); // initial try + 3 retries
}
}