diff --git a/src/proxy.rs b/src/proxy.rs index 828d777b..373251eb 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -35,6 +35,11 @@ pub const HTTPJAIL_HEADER: &str = "HTTPJAIL"; pub const HTTPJAIL_HEADER_VALUE: &str = "true"; pub const BLOCKED_MESSAGE: &str = "Request blocked by httpjail"; +/// Header added to outgoing requests to detect loops (Issue #84) +/// Contains comma-separated nonces of all httpjail instances in the proxy chain. +/// If we see our own nonce in an incoming request, we're in a loop. +pub const HTTPJAIL_LOOP_DETECTION_HEADER: &str = "Httpjail-Loop-Prevention"; + /// Create a raw HTTP/1.1 403 Forbidden response for CONNECT tunnels pub fn create_connect_403_response() -> &'static [u8] { b"HTTP/1.1 403 Forbidden\r\nContent-Type: text/plain\r\nContent-Length: 27\r\n\r\nRequest blocked by httpjail" @@ -166,6 +171,7 @@ static HTTPS_CLIENT: OnceLock< pub fn prepare_upstream_request( req: Request, target_uri: Uri, + loop_nonce: &str, ) -> Request> { let (mut parts, incoming_body) = req.into_parts(); @@ -178,6 +184,16 @@ pub fn prepare_upstream_request( parts.headers.remove("proxy-authorization"); parts.headers.remove("proxy-authenticate"); + // SECURITY: Add our nonce to the loop detection header (Issue #84) + // HTTP natively supports multiple values for the same header name (via append). + // This allows chaining multiple httpjail instances while still detecting self-loops. + // Each instance appends its nonce; if we see our own nonce in an incoming request, it's a loop. + parts.headers.append( + HTTPJAIL_LOOP_DETECTION_HEADER, + hyper::header::HeaderValue::from_str(loop_nonce) + .unwrap_or_else(|_| hyper::header::HeaderValue::from_static("invalid")), + ); + // SECURITY: Ensure the Host header matches the URI to prevent routing bypasses (Issue #57) // This prevents attacks where an attacker sends a request to one domain but sets // the Host header to another domain, potentially bypassing security controls in @@ -364,11 +380,19 @@ async fn bind_listener(addr: std::net::SocketAddr) -> Result { TcpListener::bind(addr).await.map_err(Into::into) } +/// Context passed to all proxy handlers - reduces argument duplication +#[derive(Clone)] +pub struct ProxyContext { + pub rule_engine: Arc, + pub cert_manager: Arc, + /// Unique nonce for this proxy instance, used for loop detection (Issue #84) + pub loop_nonce: Arc, +} + pub struct ProxyServer { http_bind: Option, https_bind: Option, - rule_engine: Arc, - cert_manager: Arc, + context: ProxyContext, } impl ProxyServer { @@ -383,11 +407,23 @@ impl ProxyServer { let ca_cert_der = cert_manager.get_ca_cert_der(); init_client_with_ca(ca_cert_der); + // Generate a unique nonce for loop detection (Issue #84) + // Use 16 random hex characters for a reasonably short but collision-resistant ID + let loop_nonce = { + let random_u64: u64 = rand::random(); + format!("{:x}", random_u64) + }; + + let context = ProxyContext { + rule_engine: Arc::new(rule_engine), + cert_manager: Arc::new(cert_manager), + loop_nonce: Arc::new(loop_nonce), + }; + ProxyServer { http_bind, https_bind, - rule_engine: Arc::new(rule_engine), - cert_manager: Arc::new(cert_manager), + context, } } @@ -403,35 +439,13 @@ impl ProxyServer { let http_port = http_listener.local_addr()?.port(); info!("Starting HTTP proxy on port {}", http_port); - let rule_engine = Arc::clone(&self.rule_engine); - let cert_manager = Arc::clone(&self.cert_manager); - // Start HTTP proxy task - tokio::spawn(async move { - loop { - match http_listener.accept().await { - Ok((stream, addr)) => { - debug!("New HTTP connection from {}", addr); - let rule_engine = Arc::clone(&rule_engine); - let cert_manager = Arc::clone(&cert_manager); - - tokio::spawn(async move { - if let Err(e) = - handle_http_connection(stream, rule_engine, cert_manager, addr) - .await - { - error!("Error handling HTTP connection: {:?}", e); - } - }); - } - Err(e) => { - error!("Failed to accept HTTP connection: {}", e); - } - } - } - }); - - // IPv6-specific listener not required; IPv4 listener suffices for jail routing + spawn_listener_task( + http_listener, + self.context.clone(), + "HTTP", + handle_http_connection, + ); // Bind HTTPS listener let https_listener = if let Some(addr) = self.https_bind { @@ -444,35 +458,13 @@ impl ProxyServer { let https_port = https_listener.local_addr()?.port(); info!("Starting HTTPS proxy on port {}", https_port); - let rule_engine = Arc::clone(&self.rule_engine); - let cert_manager = Arc::clone(&self.cert_manager); - // Start HTTPS proxy task - tokio::spawn(async move { - loop { - match https_listener.accept().await { - Ok((stream, addr)) => { - debug!("New HTTPS connection from {}", addr); - let rule_engine = Arc::clone(&rule_engine); - let cert_manager = Arc::clone(&cert_manager); - - tokio::spawn(async move { - if let Err(e) = - handle_https_connection(stream, rule_engine, cert_manager, addr) - .await - { - error!("Error handling HTTPS connection: {:?}", e); - } - }); - } - Err(e) => { - error!("Failed to accept HTTPS connection: {}", e); - } - } - } - }); - - // IPv6-specific listener not required; IPv4 listener suffices for jail routing + spawn_listener_task( + https_listener, + self.context.clone(), + "HTTPS", + handle_https_connection, + ); Ok((http_port, https_port)) } @@ -480,25 +472,50 @@ impl ProxyServer { /// Get the CA certificate for client trust #[allow(dead_code)] pub fn get_ca_cert_pem(&self) -> String { - self.cert_manager.get_ca_cert_pem() + self.context.cert_manager.get_ca_cert_pem() } } +/// Generic listener task spawner to avoid code duplication between HTTP and HTTPS +fn spawn_listener_task( + listener: TcpListener, + context: ProxyContext, + protocol: &'static str, + handler: F, +) where + F: Fn(TcpStream, ProxyContext, SocketAddr) -> Fut + Send + Sync + 'static, + Fut: std::future::Future> + Send + 'static, +{ + let handler = Arc::new(handler); + tokio::spawn(async move { + loop { + match listener.accept().await { + Ok((stream, addr)) => { + debug!("New {} connection from {}", protocol, addr); + let context = context.clone(); + let handler = Arc::clone(&handler); + + tokio::spawn(async move { + if let Err(e) = handler(stream, context, addr).await { + error!("Error handling {} connection: {:?}", protocol, e); + } + }); + } + Err(e) => { + error!("Failed to accept {} connection: {}", protocol, e); + } + } + } + }); +} + async fn handle_http_connection( stream: TcpStream, - rule_engine: Arc, - cert_manager: Arc, + context: ProxyContext, remote_addr: SocketAddr, ) -> Result<()> { let io = TokioIo::new(stream); - let service = service_fn(move |req| { - handle_http_request( - req, - Arc::clone(&rule_engine), - Arc::clone(&cert_manager), - remote_addr, - ) - }); + let service = service_fn(move |req| handle_http_request(req, context.clone(), remote_addr)); http1::Builder::new() .preserve_header_case(true) @@ -511,24 +528,41 @@ async fn handle_http_connection( async fn handle_https_connection( stream: TcpStream, - rule_engine: Arc, - cert_manager: Arc, + context: ProxyContext, remote_addr: SocketAddr, ) -> Result<()> { // Delegate to the TLS-specific module - crate::proxy_tls::handle_https_connection(stream, rule_engine, cert_manager, remote_addr).await + crate::proxy_tls::handle_https_connection(stream, context, remote_addr).await } pub async fn handle_http_request( req: Request, - rule_engine: Arc, - _cert_manager: Arc, + context: ProxyContext, remote_addr: SocketAddr, ) -> Result>, std::convert::Infallible> { let method = req.method().clone(); let uri = req.uri().clone(); let headers = req.headers().clone(); + // SECURITY: Check for loop detection header (Issue #84) + // HTTP supports multiple values for the same header name. + // Each httpjail instance adds its nonce; if we see our own, it's a loop. + let our_nonce = context.loop_nonce.as_str(); + for value in headers.get_all(HTTPJAIL_LOOP_DETECTION_HEADER).iter() { + if let Ok(nonce) = value.to_str() { + if nonce == our_nonce { + debug!( + "Loop detected: our nonce '{}' found in request to {}", + nonce, uri + ); + return create_forbidden_response(Some( + "Loop detected: request already processed by this httpjail instance" + .to_string(), + )); + } + } + } + // Check if the URI already contains the full URL (proxy request) let full_url = if uri.scheme().is_some() && uri.authority().is_some() { // This is a proxy request with absolute URL (e.g., GET http://example.com/ HTTP/1.1) @@ -551,7 +585,8 @@ pub async fn handle_http_request( // Evaluate rules with method and requester IP let requester_ip = remote_addr.ip().to_string(); - let evaluation = rule_engine + let evaluation = context + .rule_engine .evaluate_with_context_and_ip(method, &full_url, &requester_ip) .await; match evaluation.action { @@ -560,7 +595,8 @@ pub async fn handle_http_request( "Request allowed: {} (max_tx_bytes: {:?})", full_url, evaluation.max_tx_bytes ); - match proxy_request(req, &full_url, evaluation.max_tx_bytes).await { + match proxy_request(req, &full_url, evaluation.max_tx_bytes, &context.loop_nonce).await + { Ok(resp) => Ok(resp), Err(e) => { error!("Proxy error: {}", e); @@ -579,12 +615,13 @@ async fn proxy_request( req: Request, full_url: &str, max_tx_bytes: Option, + loop_nonce: &str, ) -> Result>> { // Parse the target URL let target_uri = full_url.parse::()?; // Prepare request for upstream - let prepared_req = prepare_upstream_request(req, target_uri.clone()); + let prepared_req = prepare_upstream_request(req, target_uri.clone(), loop_nonce); // Apply byte limit to outgoing request if specified, converting to BoxBody let new_req = if let Some(max_bytes) = max_tx_bytes { diff --git a/src/proxy_tls.rs b/src/proxy_tls.rs index 7043a47c..f7c1da56 100644 --- a/src/proxy_tls.rs +++ b/src/proxy_tls.rs @@ -1,8 +1,9 @@ use crate::proxy::{ - HTTPJAIL_HEADER, HTTPJAIL_HEADER_VALUE, apply_request_byte_limit, + HTTPJAIL_HEADER, HTTPJAIL_HEADER_VALUE, ProxyContext, apply_request_byte_limit, create_connect_403_response_with_context, create_forbidden_response, }; -use crate::rules::{Action, RuleEngine}; +use crate::rules::Action; +#[cfg(target_os = "macos")] use crate::tls::CertificateManager; use anyhow::Result; use bytes::Bytes; @@ -35,8 +36,7 @@ const CLIENT_HELLO_TIMEOUT: Duration = Duration::from_secs(5); /// Handle an HTTPS connection with potential CONNECT tunneling and TLS interception pub async fn handle_https_connection( stream: TcpStream, - rule_engine: Arc, - cert_manager: Arc, + context: ProxyContext, remote_addr: std::net::SocketAddr, ) -> Result<()> { debug!("Handling new HTTPS connection from {}", remote_addr); @@ -65,18 +65,18 @@ pub async fn handle_https_connection( if peek_buf[0] == 0x16 && n > 1 && (peek_buf[1] == 0x03 || peek_buf[1] == 0x02) { // This is a TLS ClientHello - we're in transparent proxy mode debug!("Detected TLS ClientHello - transparent proxy mode"); - handle_transparent_tls(stream, rule_engine, cert_manager, remote_addr).await + handle_transparent_tls(stream, context, remote_addr).await } else if peek_buf[0] >= 0x41 && peek_buf[0] <= 0x5A { // This looks like HTTP (starts with uppercase ASCII letter) // Check if it's a CONNECT request let request_str = String::from_utf8_lossy(&peek_buf); if request_str.starts_with("CONNEC") { debug!("Detected CONNECT request - explicit proxy mode"); - handle_connect_tunnel(stream, rule_engine, cert_manager, remote_addr).await + handle_connect_tunnel(stream, context, remote_addr).await } else { // Regular HTTP on HTTPS port debug!("Detected plain HTTP on HTTPS port"); - handle_plain_http(stream, rule_engine, cert_manager, remote_addr).await + handle_plain_http(stream, context, remote_addr).await } } else { warn!( @@ -158,8 +158,7 @@ async fn extract_sni_from_stream(stream: &mut TcpStream) -> Result, - cert_manager: Arc, + context: ProxyContext, remote_addr: std::net::SocketAddr, ) -> Result<()> { debug!("Handling transparent TLS connection"); @@ -182,7 +181,8 @@ async fn handle_transparent_tls( debug!("Processing transparent TLS for: {}", hostname); // Get certificate for the host - let (cert_chain, key) = cert_manager + let (cert_chain, key) = context + .cert_manager .get_cert_for_host(&hostname) .map_err(|e| anyhow::anyhow!("Failed to get certificate for {}: {}", hostname, e))?; @@ -214,7 +214,7 @@ async fn handle_transparent_tls( let io = TokioIo::new(tls_stream); let service = service_fn(move |req| { let host_clone = hostname.clone(); - handle_decrypted_https_request(req, Arc::clone(&rule_engine), host_clone, remote_addr) + handle_decrypted_https_request(req, context.clone(), host_clone, remote_addr) }); debug!("Starting HTTP/1.1 server for decrypted requests"); @@ -230,8 +230,7 @@ async fn handle_transparent_tls( /// Handle a CONNECT tunnel request with TLS interception async fn handle_connect_tunnel( stream: TcpStream, - rule_engine: Arc, - cert_manager: Arc, + context: ProxyContext, remote_addr: std::net::SocketAddr, ) -> Result<()> { debug!("Handling CONNECT tunnel"); @@ -309,7 +308,8 @@ async fn handle_connect_tunnel( // Check if this host is allowed let full_url = format!("https://{}", target); let requester_ip = remote_addr.ip().to_string(); - let evaluation = rule_engine + let evaluation = context + .rule_engine .evaluate_with_context_and_ip(Method::GET, &full_url, &requester_ip) .await; match evaluation.action { @@ -341,7 +341,7 @@ async fn handle_connect_tunnel( debug!("Sent 200 Connection Established, starting TLS handshake"); // Now perform TLS handshake with the client - perform_tls_interception(stream, rule_engine, cert_manager, host, remote_addr).await + perform_tls_interception(stream, context, host, remote_addr).await } Action::Deny => { warn!("CONNECT denied to: {}", host); @@ -373,8 +373,7 @@ async fn handle_connect_tunnel( /// Perform TLS interception on a stream async fn perform_tls_interception( stream: TcpStream, - rule_engine: Arc, - cert_manager: Arc, + context: ProxyContext, host: &str, remote_addr: std::net::SocketAddr, ) -> Result<()> { @@ -391,7 +390,8 @@ async fn perform_tls_interception( } // Get certificate for the host - let (cert_chain, key) = cert_manager + let (cert_chain, key) = context + .cert_manager .get_cert_for_host(host) .map_err(|e| anyhow::anyhow!("Failed to get certificate for {}: {}", host, e))?; @@ -422,10 +422,9 @@ async fn perform_tls_interception( // Now handle the decrypted HTTPS requests let io = TokioIo::new(tls_stream); let host_string = host.to_string(); - let remote_addr_copy = remote_addr; // Copy for the closure let service = service_fn(move |req| { let host_clone = host_string.clone(); - handle_decrypted_https_request(req, Arc::clone(&rule_engine), host_clone, remote_addr_copy) + handle_decrypted_https_request(req, context.clone(), host_clone, remote_addr) }); debug!("Starting HTTP/1.1 server for decrypted requests"); @@ -441,21 +440,14 @@ async fn perform_tls_interception( /// Handle a plain HTTP request on the HTTPS port async fn handle_plain_http( stream: TcpStream, - rule_engine: Arc, - cert_manager: Arc, + context: ProxyContext, remote_addr: std::net::SocketAddr, ) -> Result<()> { debug!("Handling plain HTTP on HTTPS port"); let io = TokioIo::new(stream); - let service = service_fn(move |req| { - crate::proxy::handle_http_request( - req, - Arc::clone(&rule_engine), - Arc::clone(&cert_manager), - remote_addr, - ) - }); + let service = + service_fn(move |req| crate::proxy::handle_http_request(req, context.clone(), remote_addr)); http1::Builder::new() .preserve_header_case(true) @@ -469,7 +461,7 @@ async fn handle_plain_http( /// Handle a decrypted HTTPS request after TLS interception async fn handle_decrypted_https_request( req: Request, - rule_engine: Arc, + context: ProxyContext, host: String, remote_addr: std::net::SocketAddr, ) -> Result>, std::convert::Infallible> { @@ -487,13 +479,16 @@ async fn handle_decrypted_https_request( // Evaluate rules with method and requester IP let requester_ip = remote_addr.ip().to_string(); - let evaluation = rule_engine + let evaluation = context + .rule_engine .evaluate_with_context_and_ip(method.clone(), &full_url, &requester_ip) .await; match evaluation.action { Action::Allow => { debug!("Request allowed: {}", full_url); - match proxy_https_request(req, &host, evaluation.max_tx_bytes).await { + match proxy_https_request(req, &host, evaluation.max_tx_bytes, &context.loop_nonce) + .await + { Ok(resp) => Ok(resp), Err(e) => { error!("Proxy error: {}", e); @@ -513,6 +508,7 @@ async fn proxy_https_request( req: Request, host: &str, max_tx_bytes: Option, + loop_nonce: &str, ) -> Result>> { // Build the target URL let path = req @@ -526,7 +522,7 @@ async fn proxy_https_request( debug!("Forwarding request to: {}", target_url); // Prepare request for upstream using common function - let prepared_req = crate::proxy::prepare_upstream_request(req, target_uri); + let prepared_req = crate::proxy::prepare_upstream_request(req, target_uri, loop_nonce); // Apply byte limit to outgoing request if specified, converting to BoxBody let new_req = if let Some(max_bytes) = max_tx_bytes { @@ -624,6 +620,8 @@ async fn proxy_https_request( #[cfg(test)] mod tests { use super::*; + use crate::rules::RuleEngine; + use crate::tls::CertificateManager; use rustls::ClientConfig; use std::sync::Arc; use tempfile::TempDir; @@ -726,7 +724,12 @@ mod tests { // Spawn proxy handler tokio::spawn(async move { let (stream, addr) = listener.accept().await.unwrap(); - let _ = handle_connect_tunnel(stream, rule_engine, cert_manager, addr).await; + let context = ProxyContext { + rule_engine, + cert_manager, + loop_nonce: Arc::new("test-nonce".to_string()), + }; + let _ = handle_connect_tunnel(stream, context, addr).await; }); // Connect to proxy @@ -761,7 +764,12 @@ mod tests { // Spawn proxy handler tokio::spawn(async move { let (stream, addr) = listener.accept().await.unwrap(); - let _ = handle_connect_tunnel(stream, rule_engine, cert_manager, addr).await; + let context = ProxyContext { + rule_engine: rule_engine.clone(), + cert_manager: Arc::clone(&cert_manager), + loop_nonce: Arc::new("test-nonce".to_string()), + }; + let _ = handle_connect_tunnel(stream, context, addr).await; }); // Connect to proxy @@ -798,7 +806,12 @@ mod tests { // Spawn proxy handler tokio::spawn(async move { let (stream, addr) = listener.accept().await.unwrap(); - let _ = handle_transparent_tls(stream, rule_engine, cert_manager, addr).await; + let context = ProxyContext { + rule_engine: rule_engine.clone(), + cert_manager: Arc::clone(&cert_manager), + loop_nonce: Arc::new("test-nonce".to_string()), + }; + let _ = handle_transparent_tls(stream, context, addr).await; }); // Connect to proxy with TLS directly (transparent mode) @@ -870,7 +883,12 @@ mod tests { let rule_engine = rule_engine.clone(); tokio::spawn(async move { let (stream, addr) = listener.accept().await.unwrap(); - let _ = handle_https_connection(stream, rule_engine, cert_manager, addr).await; + let context = ProxyContext { + rule_engine: rule_engine.clone(), + cert_manager: cert_manager.clone(), + loop_nonce: Arc::new("test-nonce".to_string()), + }; + let _ = handle_https_connection(stream, context, addr).await; }); let mut stream = TcpStream::connect(addr).await.unwrap(); @@ -904,7 +922,12 @@ mod tests { tokio::spawn(async move { let (stream, addr) = listener.accept().await.unwrap(); // Use the actual transparent TLS handler (which will extract SNI, etc.) - let _ = handle_transparent_tls(stream, rule_engine, cert_manager, addr).await; + let context = ProxyContext { + rule_engine, + cert_manager, + loop_nonce: Arc::new("test-nonce".to_string()), + }; + let _ = handle_transparent_tls(stream, context, addr).await; }); // Give the server time to start diff --git a/tests/common/mod.rs b/tests/common/mod.rs index a0c5716a..67a09ee9 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -251,3 +251,24 @@ pub fn test_https_allow(use_sudo: bool) { } } } + +// Wait until a TCP port on localhost is accepting connections +// Returns true if the port became ready before max_wait elapsed +pub async fn wait_for_server(port: u16, max_wait: std::time::Duration) -> bool { + let start = tokio::time::Instant::now(); + let poll_interval = tokio::time::Duration::from_millis(75); + let settle_time = tokio::time::Duration::from_millis(200); + + while start.elapsed() < max_wait { + if tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port)) + .await + .is_ok() + { + // Give the server a brief moment to finish initialization + tokio::time::sleep(settle_time).await; + return true; + } + tokio::time::sleep(poll_interval).await; + } + false +} diff --git a/tests/self_request_loop.rs b/tests/self_request_loop.rs new file mode 100644 index 00000000..9ecaa5ad --- /dev/null +++ b/tests/self_request_loop.rs @@ -0,0 +1,66 @@ +mod common; + +use std::net::TcpListener; +use std::process::{Command, Stdio}; +use std::time::Duration; +use tracing::debug; + +/// Test that requests to the proxy itself are blocked to prevent infinite loops (issue #84) +#[tokio::test] +async fn test_server_mode_self_request_loop_prevention() { + let http_port = find_available_port(); + let https_port = find_available_port(); + + let httpjail_path: &str = env!("CARGO_BIN_EXE_httpjail"); + let mut proxy_process = Command::new(httpjail_path) + .env("HTTPJAIL_HTTP_BIND", format!("127.0.0.1:{}", http_port)) + .env("HTTPJAIL_HTTPS_BIND", format!("127.0.0.1:{}", https_port)) + .arg("--server") + .arg("--js") + .arg("true") + .stdin(Stdio::null()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .expect("Failed to start httpjail"); + + assert!( + common::wait_for_server(http_port, Duration::from_secs(5)).await, + "Server failed to start on port {}", + http_port + ); + + let output = Command::new("curl") + .arg("--max-time") + .arg("3") + .arg("--proxy") + .arg(format!("http://127.0.0.1:{}", http_port)) + .arg(format!("http://127.0.0.1:{}/test", http_port)) + .output() + .expect("Failed to execute curl"); + + proxy_process.kill().ok(); + let _ = proxy_process.wait(); + + let stdout = String::from_utf8_lossy(&output.stdout); + let stderr = String::from_utf8_lossy(&output.stderr); + + debug!("curl exit code: {}", output.status.code().unwrap_or(-1)); + debug!("curl stdout: {}", stdout); + debug!("curl stderr: {}", stderr); + + assert!( + stdout.contains("Loop detected"), + "Expected loop detection message, got: {}", + stdout + ); +} + +/// Find an available port for testing +fn find_available_port() -> u16 { + TcpListener::bind("127.0.0.1:0") + .expect("Failed to bind to port") + .local_addr() + .expect("Failed to get local addr") + .port() +} diff --git a/tests/weak_integration.rs b/tests/weak_integration.rs index 1229eaf0..8de3ba82 100644 --- a/tests/weak_integration.rs +++ b/tests/weak_integration.rs @@ -173,7 +173,7 @@ fn test_weak_mode_appends_no_proxy() { } // Simple server start function - we know the ports we're setting -fn start_server(http_port: u16, https_port: u16) -> Result { +async fn start_server(http_port: u16, https_port: u16) -> Result { let httpjail_path: &str = env!("CARGO_BIN_EXE_httpjail"); let mut cmd = Command::new(httpjail_path); @@ -192,26 +192,13 @@ fn start_server(http_port: u16, https_port: u16) -> Result bool { - let start = std::time::Instant::now(); - while start.elapsed() < max_wait { - if std::net::TcpStream::connect(format!("127.0.0.1:{}", port)).is_ok() { - // Give the server a bit more time to fully initialize - thread::sleep(Duration::from_millis(500)); - return true; - } - thread::sleep(Duration::from_millis(100)); - } - false -} - fn test_curl_through_proxy(http_port: u16, _https_port: u16) -> Result { // First, verify the proxy port is actually listening if !verify_bind_address(http_port, "127.0.0.1") { @@ -263,13 +250,15 @@ fn verify_bind_address(port: u16, expected_ip: &str) -> bool { std::net::TcpStream::connect(format!("{}:{}", expected_ip, port)).is_ok() } -#[test] -fn test_server_mode() { +#[tokio::test] +async fn test_server_mode() { // Test server mode with specific ports let http_port = 19876; let https_port = 19877; - let mut server = start_server(http_port, https_port).expect("Failed to start server"); + let mut server = start_server(http_port, https_port) + .await + .expect("Failed to start server"); // Test HTTP proxy works match test_curl_through_proxy(http_port, https_port) { @@ -291,7 +280,7 @@ fn test_server_mode() { } // Helper to start server with custom bind config -fn start_server_with_bind(http_bind: &str, https_bind: &str) -> (std::process::Child, u16) { +async fn start_server_with_bind(http_bind: &str, https_bind: &str) -> (std::process::Child, u16) { let httpjail_path: &str = env!("CARGO_BIN_EXE_httpjail"); let mut child = Command::new(httpjail_path) @@ -319,7 +308,7 @@ fn start_server_with_bind(http_bind: &str, https_bind: &str) -> (std::process::C }; // Wait for server to bind - if !wait_for_server(expected_port, Duration::from_secs(3)) { + if !common::wait_for_server(expected_port, Duration::from_secs(3)).await { child.kill().ok(); panic!("Server failed to bind to port {}", expected_port); } @@ -327,19 +316,19 @@ fn start_server_with_bind(http_bind: &str, https_bind: &str) -> (std::process::C (child, expected_port) } -#[test] +#[tokio::test] #[serial] -fn test_server_bind_defaults() { - let (mut server, port) = start_server_with_bind("", ""); +async fn test_server_bind_defaults() { + let (mut server, port) = start_server_with_bind("", "").await; assert_eq!(port, 8080, "Server should default to port 8080"); server.kill().ok(); } -#[test] +#[tokio::test] #[serial] -fn test_server_bind_port_only() { +async fn test_server_bind_port_only() { // Port-only should bind to all interfaces (0.0.0.0) - let (mut server, port) = start_server_with_bind("19882", "19883"); + let (mut server, port) = start_server_with_bind("19882", "19883").await; assert_eq!( port, 19882, "Server should bind to specified port on all interfaces" @@ -347,11 +336,11 @@ fn test_server_bind_port_only() { server.kill().ok(); } -#[test] +#[tokio::test] #[serial] -fn test_server_bind_colon_prefix_port() { +async fn test_server_bind_colon_prefix_port() { // :port (Go-style) should bind to all interfaces (0.0.0.0) - let (mut server, port) = start_server_with_bind(":19892", ":19893"); + let (mut server, port) = start_server_with_bind(":19892", ":19893").await; assert_eq!( port, 19892, "Server should bind to specified port on all interfaces with :port format" @@ -359,10 +348,10 @@ fn test_server_bind_colon_prefix_port() { server.kill().ok(); } -#[test] +#[tokio::test] #[serial] -fn test_server_bind_all_interfaces() { - let (mut server, port) = start_server_with_bind("0.0.0.0:19884", "0.0.0.0:19885"); +async fn test_server_bind_all_interfaces() { + let (mut server, port) = start_server_with_bind("0.0.0.0:19884", "0.0.0.0:19885").await; assert_eq!( port, 19884, "Server should bind to specified port on 0.0.0.0" @@ -370,10 +359,10 @@ fn test_server_bind_all_interfaces() { server.kill().ok(); } -#[test] +#[tokio::test] #[serial] -fn test_server_bind_ip_without_port() { - let (mut server, port) = start_server_with_bind("127.0.0.1", "127.0.0.1"); +async fn test_server_bind_ip_without_port() { + let (mut server, port) = start_server_with_bind("127.0.0.1", "127.0.0.1").await; assert_eq!( port, 8080, "Server should use default port 8080 when only IP specified"