Skip to content

Commit eae8c69

Browse files
committed
🐳 Add limits and timeouts. Fixes #1
1 parent e7f3ca7 commit eae8c69

File tree

1 file changed

+34
-5
lines changed

1 file changed

+34
-5
lines changed

src/main.rs

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use trust_dns_resolver::TokioAsyncResolver;
1919
#[derive(Debug, Clone)]
2020
struct SpfChecker {
2121
resolver: Arc<TokioAsyncResolver>,
22+
max_depth: usize,
2223
}
2324

2425
#[derive(Debug, Deserialize)]
@@ -46,11 +47,15 @@ struct ErrorResponse {
4647

4748
impl SpfChecker {
4849
async fn new() -> Result<Self> {
49-
let resolver =
50-
TokioAsyncResolver::tokio(ResolverConfig::default(), ResolverOpts::default());
50+
let mut opts = ResolverOpts::default();
51+
opts.timeout = std::time::Duration::from_secs(2);
52+
opts.attempts = 2;
53+
54+
let resolver = TokioAsyncResolver::tokio(ResolverConfig::default(), opts);
5155

5256
Ok(Self {
5357
resolver: Arc::new(resolver),
58+
max_depth: 10,
5459
})
5560
}
5661

@@ -77,6 +82,15 @@ impl SpfChecker {
7782
target: String,
7883
visited: &mut HashSet<String>,
7984
) -> Result<(bool, Option<String>, Option<Vec<String>>)> {
85+
if visited.len() >= self.max_depth {
86+
log_message(&format!(
87+
"Maximum recursion depth of {} reached. Visited domains: {:?}",
88+
self.max_depth,
89+
visited.iter().collect::<Vec<_>>()
90+
));
91+
return Ok((false, None, None));
92+
}
93+
8094
if !visited.insert(domain.clone()) {
8195
return Ok((false, None, None));
8296
}
@@ -151,8 +165,6 @@ async fn check_spf(
151165
Query(params): Query<SpfCheckParams>,
152166
checker: axum::extract::State<Arc<SpfChecker>>,
153167
) -> impl IntoResponse {
154-
log_message(&format!("Request to check \"{}\" for \"{}\"", params.domain, params.target));
155-
156168
let start = std::time::Instant::now();
157169
let mut visited = HashSet::new();
158170

@@ -161,19 +173,36 @@ async fn check_spf(
161173
.await
162174
{
163175
Ok((found, spf_record, included_domains)) => {
176+
let elapsed_ms = start.elapsed().as_millis() as u64;
177+
log_message(&format!(
178+
"Successfully checked \"{}\" for \"{}\" ({}ms)",
179+
params.domain,
180+
params.target,
181+
elapsed_ms
182+
));
183+
164184
let response = SpfCheckResponse {
165185
found,
166186
checked_domains: visited.len(),
167187
domain: params.domain,
168188
target: params.target,
169-
elapsed_ms: start.elapsed().as_millis() as u64,
189+
elapsed_ms,
170190
has_spf_record: spf_record.is_some(),
171191
spf_record,
172192
included_domains,
173193
};
174194
(StatusCode::OK, Json(response)).into_response()
175195
}
176196
Err(err) => {
197+
let elapsed_ms = start.elapsed().as_millis() as u64;
198+
log_message(&format!(
199+
"Failed to check \"{}\" for \"{}\": {} ({}ms)",
200+
params.domain,
201+
params.target,
202+
err,
203+
elapsed_ms
204+
));
205+
177206
let error = ErrorResponse {
178207
error: err.to_string(),
179208
};

0 commit comments

Comments
 (0)