@@ -19,6 +19,7 @@ use trust_dns_resolver::TokioAsyncResolver;
1919#[ derive( Debug , Clone ) ]
2020struct SpfChecker {
2121 resolver : Arc < TokioAsyncResolver > ,
22+ max_depth : usize ,
2223}
2324
2425#[ derive( Debug , Deserialize ) ]
@@ -46,11 +47,15 @@ struct ErrorResponse {
4647
4748impl SpfChecker {
4849 async fn new ( ) -> Result < Self > {
49- let resolver =
50- TokioAsyncResolver :: tokio ( ResolverConfig :: default ( ) , ResolverOpts :: default ( ) ) ;
50+ let mut opts = ResolverOpts :: default ( ) ;
51+ opts. timeout = std:: time:: Duration :: from_secs ( 2 ) ;
52+ opts. attempts = 2 ;
53+
54+ let resolver = TokioAsyncResolver :: tokio ( ResolverConfig :: default ( ) , opts) ;
5155
5256 Ok ( Self {
5357 resolver : Arc :: new ( resolver) ,
58+ max_depth : 10 ,
5459 } )
5560 }
5661
@@ -77,6 +82,15 @@ impl SpfChecker {
7782 target : String ,
7883 visited : & mut HashSet < String > ,
7984 ) -> Result < ( bool , Option < String > , Option < Vec < String > > ) > {
85+ if visited. len ( ) >= self . max_depth {
86+ log_message ( & format ! (
87+ "Maximum recursion depth of {} reached. Visited domains: {:?}" ,
88+ self . max_depth,
89+ visited. iter( ) . collect:: <Vec <_>>( )
90+ ) ) ;
91+ return Ok ( ( false , None , None ) ) ;
92+ }
93+
8094 if !visited. insert ( domain. clone ( ) ) {
8195 return Ok ( ( false , None , None ) ) ;
8296 }
@@ -151,8 +165,6 @@ async fn check_spf(
151165 Query ( params) : Query < SpfCheckParams > ,
152166 checker : axum:: extract:: State < Arc < SpfChecker > > ,
153167) -> impl IntoResponse {
154- log_message ( & format ! ( "Request to check \" {}\" for \" {}\" " , params. domain, params. target) ) ;
155-
156168 let start = std:: time:: Instant :: now ( ) ;
157169 let mut visited = HashSet :: new ( ) ;
158170
@@ -161,19 +173,36 @@ async fn check_spf(
161173 . await
162174 {
163175 Ok ( ( found, spf_record, included_domains) ) => {
176+ let elapsed_ms = start. elapsed ( ) . as_millis ( ) as u64 ;
177+ log_message ( & format ! (
178+ "Successfully checked \" {}\" for \" {}\" ({}ms)" ,
179+ params. domain,
180+ params. target,
181+ elapsed_ms
182+ ) ) ;
183+
164184 let response = SpfCheckResponse {
165185 found,
166186 checked_domains : visited. len ( ) ,
167187 domain : params. domain ,
168188 target : params. target ,
169- elapsed_ms : start . elapsed ( ) . as_millis ( ) as u64 ,
189+ elapsed_ms,
170190 has_spf_record : spf_record. is_some ( ) ,
171191 spf_record,
172192 included_domains,
173193 } ;
174194 ( StatusCode :: OK , Json ( response) ) . into_response ( )
175195 }
176196 Err ( err) => {
197+ let elapsed_ms = start. elapsed ( ) . as_millis ( ) as u64 ;
198+ log_message ( & format ! (
199+ "Failed to check \" {}\" for \" {}\" : {} ({}ms)" ,
200+ params. domain,
201+ params. target,
202+ err,
203+ elapsed_ms
204+ ) ) ;
205+
177206 let error = ErrorResponse {
178207 error : err. to_string ( ) ,
179208 } ;
0 commit comments