diff --git a/Cargo.lock b/Cargo.lock index fb2ab35a..8fc6df4d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -357,6 +357,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a18ed336352031311f4e0b4dd2ff392d4fbb370777c9d18d7fc9d7359f73871" dependencies = [ "axum-core", + "axum-macros", "base64 0.22.1", "bytes", "form_urlencoded", @@ -405,6 +406,17 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum-macros" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "604fde5e028fea851ce1d8570bbdc034bec850d157f7569d10f347d06808c05c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "backtrace" version = "0.3.76" @@ -1852,6 +1864,7 @@ version = "3.0.0" dependencies = [ "ahash", "async-trait", + "axum", "bumpalo", "bytes", "criterion", diff --git a/docs/README.md b/docs/README.md index abf6278d..ab535892 100644 --- a/docs/README.md +++ b/docs/README.md @@ -11,7 +11,7 @@ |[**log**](#log)|`object`|The router logger configuration.
Default: `{"filter":null,"format":"json","level":"info"}`
|| |[**query\_planner**](#query_planner)|`object`|Query planning configuration.
Default: `{"allow_expose":false,"timeout":"10s"}`
|| |[**supergraph**](#supergraph)|`object`|Configuration for the Federation supergraph source. By default, the router will use a local file-based supergraph source (`./supergraph.graphql`).
Default: `{"path":"supergraph.graphql","source":"file"}`
|| -|[**traffic\_shaping**](#traffic_shaping)|`object`|Configuration for the traffic-shaper executor. Use these configurations to control how requests are being executed to subgraphs.
Default: `{"dedupe_enabled":true,"max_connections_per_host":100,"pool_idle_timeout_seconds":50}`
|| +|[**traffic\_shaping**](#traffic_shaping)|`object`|Configuration for the traffic-shaper executor. Use these configurations to control how requests are being executed to subgraphs.
Default: `{"all":{"dedupe_enabled":true,"pool_idle_timeout_seconds":50},"max_connections_per_host":100}`
|| **Additional Properties:** not allowed **Example** @@ -64,9 +64,10 @@ supergraph: path: supergraph.graphql source: file traffic_shaping: - dedupe_enabled: true + all: + dedupe_enabled: true + pool_idle_timeout_seconds: 50 max_connections_per_host: 100 - pool_idle_timeout_seconds: 50 ``` @@ -1366,15 +1367,69 @@ Configuration for the traffic-shaper executor. Use these configurations to contr |Name|Type|Description|Required| |----|----|-----------|--------| -|**dedupe\_enabled**|`boolean`|Enables/disables request deduplication to subgraphs.

When requests exactly matches the hashing mechanism (e.g., subgraph name, URL, headers, query, variables), and are executed at the same time, they will
be deduplicated by sharing the response of other in-flight requests.
Default: `true`
|| +|[**all**](#traffic_shapingall)|`object`|The default configuration that will be applied to all subgraphs, unless overridden by a specific subgraph configuration.
Default: `{"dedupe_enabled":true,"pool_idle_timeout_seconds":50}`
|| |**max\_connections\_per\_host**|`integer`|Limits the concurrent amount of requests/connections per host/subgraph.
Default: `100`
Format: `"uint"`
Minimum: `0`
|| +|[**subgraphs**](#traffic_shapingsubgraphs)|`object`|Optional per-subgraph configurations that will override the default configuration for specific subgraphs.
|| + +**Example** + +```yaml +all: + dedupe_enabled: true + pool_idle_timeout_seconds: 50 +max_connections_per_host: 100 + +``` + + +### traffic\_shaping\.all: object + +The default configuration that will be applied to all subgraphs, unless overridden by a specific subgraph configuration. + + +**Properties** + +|Name|Type|Description|Required| +|----|----|-----------|--------| +|**dedupe\_enabled**|`boolean`|Enables/disables request deduplication to subgraphs.

When requests exactly match the hashing mechanism (e.g., subgraph name, URL, headers, query, variables), and are executed at the same time, they will<br/>
be deduplicated by sharing the response of other in-flight requests.
Default: `true`
|| |**pool\_idle\_timeout\_seconds**|`integer`|Timeout for idle sockets being kept-alive.
Default: `50`
Format: `"uint64"`
Minimum: `0`
|| +|**timeout**||Optional timeout configuration for requests to subgraphs.

Example with a fixed duration:
```yaml
timeout:
duration: 5s
```

Or with a VRL expression that can return a duration based on the operation kind:
```yaml
timeout:
expression: \|
if (.request.operation.type == "mutation") {
10000
} else {
5000
}
```
|| + +**Example** + +```yaml +dedupe_enabled: true +pool_idle_timeout_seconds: 50 + +``` + + +### traffic\_shaping\.subgraphs: object + +Optional per-subgraph configurations that will override the default configuration for specific subgraphs. + + +**Additional Properties** + +|Name|Type|Description|Required| +|----|----|-----------|--------| +|[**Additional Properties**](#traffic_shapingsubgraphsadditionalproperties)|`object`||| + + +#### traffic\_shaping\.subgraphs\.additionalProperties: object + +**Properties** + +|Name|Type|Description|Required| +|----|----|-----------|--------| +|**dedupe\_enabled**|`boolean`|Enables/disables request deduplication to subgraphs.

When requests exactly match the hashing mechanism (e.g., subgraph name, URL, headers, query, variables), and are executed at the same time, they will<br/>
be deduplicated by sharing the response of other in-flight requests.
Default: `true`
|| +|**pool\_idle\_timeout\_seconds**|`integer`|Timeout for idle sockets being kept-alive.
Default: `50`
Format: `"uint64"`
Minimum: `0`
|| +|**timeout**||Optional timeout configuration for requests to subgraphs.

Example with a fixed duration:
```yaml
timeout:
duration: 5s
```

Or with a VRL expression that can return a duration based on the operation kind:
```yaml
timeout:
expression: \|
if (.request.operation.type == "mutation") {
10000
} else {
5000
}
```
|| **Example** ```yaml dedupe_enabled: true -max_connections_per_host: 100 pool_idle_timeout_seconds: 50 ``` diff --git a/lib/executor/Cargo.toml b/lib/executor/Cargo.toml index 16d3ed33..03f41bbd 100644 --- a/lib/executor/Cargo.toml +++ b/lib/executor/Cargo.toml @@ -53,6 +53,7 @@ subgraphs = { path = "../../bench/subgraphs" } criterion = { workspace = true } tokio = { workspace = true } insta = { workspace = true } +axum = { version = "0.8.6", features = ["macros", "tokio", "json"]} [[bench]] name = "executor_benches" diff --git a/lib/executor/src/execution/plan.rs b/lib/executor/src/execution/plan.rs index cb1edde2..3e2e434c 100644 --- a/lib/executor/src/execution/plan.rs +++ b/lib/executor/src/execution/plan.rs @@ -726,6 +726,7 @@ impl<'exec> Executor<'exec> { variables: variable_refs, representations, headers: headers_map, + client_request: self.client_request, }, ) .await diff --git a/lib/executor/src/executors/common.rs b/lib/executor/src/executors/common.rs index 6a053bdd..f85fc49d 100644 --- a/lib/executor/src/executors/common.rs +++ b/lib/executor/src/executors/common.rs @@ -4,6 +4,8 @@ use async_trait::async_trait; use bytes::Bytes; use http::HeaderMap; +use crate::execution::plan::ClientRequestDetails; + #[async_trait] pub trait SubgraphExecutor { async fn execute<'a>( @@ -30,6 +32,7 @@ pub struct HttpExecutionRequest<'a> { pub variables: Option>, pub headers: HeaderMap, pub representations: Option>, + pub client_request: &'a ClientRequestDetails<'a>, } pub struct HttpExecutionResponse { diff --git a/lib/executor/src/executors/error.rs b/lib/executor/src/executors/error.rs index 50185522..68a53434 100644 --- a/lib/executor/src/executors/error.rs +++ b/lib/executor/src/executors/error.rs @@ -1,3 +1,9 @@ +use std::time::Duration; + +use bytes::{BufMut, Bytes, BytesMut}; + +use crate::response::graphql_error::GraphQLError; + #[derive(thiserror::Error, Debug, Clone)] pub enum SubgraphExecutorError { #[error("Failed to parse endpoint \"{0}\" as URI: {1}")] @@ 
-8,4 +14,20 @@ pub enum SubgraphExecutorError { RequestFailure(String, String), #[error("Failed to serialize variable \"{0}\": {1}")] VariablesSerializationFailure(String, String), + #[error("Failed to parse timeout duration from expression: {0}")] + TimeoutExpressionParseFailure(String), + #[error("Request timed out after {0:?}")] + RequestTimeout(Duration), +} +pub fn error_to_graphql_bytes(endpoint: &http::Uri, e: SubgraphExecutorError) -> Bytes { + let graphql_error: GraphQLError = + format!("Failed to execute request to subgraph {}: {}", endpoint, e).into(); + let errors = vec![graphql_error]; + // This unwrap is safe as GraphQLError serialization shouldn't fail. + let errors_bytes = sonic_rs::to_vec(&errors).unwrap(); + let mut buffer = BytesMut::new(); + buffer.put_slice(b"{\"errors\":"); + buffer.put_slice(&errors_bytes); + buffer.put_slice(b"}"); + buffer.freeze() } diff --git a/lib/executor/src/executors/http.rs b/lib/executor/src/executors/http.rs index e1a28543..b1297a9b 100644 --- a/lib/executor/src/executors/http.rs +++ b/lib/executor/src/executors/http.rs @@ -1,14 +1,14 @@ use std::sync::Arc; -use crate::executors::common::HttpExecutionResponse; use crate::executors::dedupe::{request_fingerprint, ABuildHasher, SharedResponse}; use dashmap::DashMap; +use futures::TryFutureExt; use hive_router_config::traffic_shaping::TrafficShapingExecutorConfig; use tokio::sync::OnceCell; use async_trait::async_trait; -use bytes::{BufMut, Bytes, BytesMut}; +use bytes::{BufMut, Bytes}; use http::HeaderMap; use http::HeaderValue; use http_body_util::BodyExt; @@ -18,9 +18,8 @@ use hyper_tls::HttpsConnector; use hyper_util::client::legacy::{connect::HttpConnector, Client}; use tokio::sync::Semaphore; -use crate::executors::common::HttpExecutionRequest; -use crate::executors::error::SubgraphExecutorError; -use crate::response::graphql_error::GraphQLError; +use crate::executors::common::{HttpExecutionRequest, HttpExecutionResponse}; +use 
crate::executors::error::{error_to_graphql_bytes, SubgraphExecutorError}; use crate::utils::consts::CLOSE_BRACE; use crate::utils::consts::COLON; use crate::utils::consts::COMMA; @@ -132,9 +131,13 @@ impl HTTPSubgraphExecutor { *req.headers_mut() = headers; - let res = self.http_client.request(req).await.map_err(|e| { - SubgraphExecutorError::RequestFailure(self.endpoint.to_string(), e.to_string()) - })?; + let res = self + .http_client + .request(req) + .map_err(|e| { + SubgraphExecutorError::RequestFailure(self.endpoint.to_string(), e.to_string()) + }) + .await?; let (parts, body) = res.into_parts(); @@ -150,22 +153,6 @@ impl HTTPSubgraphExecutor { headers: parts.headers, }) } - - fn error_to_graphql_bytes(&self, e: SubgraphExecutorError) -> Bytes { - let graphql_error: GraphQLError = format!( - "Failed to execute request to subgraph {}: {}", - self.endpoint, e - ) - .into(); - let errors = vec![graphql_error]; - // This unwrap is safe as GraphQLError serialization shouldn't fail. - let errors_bytes = sonic_rs::to_vec(&errors).unwrap(); - let mut buffer = BytesMut::new(); - buffer.put_slice(b"{\"errors\":"); - buffer.put_slice(&errors_bytes); - buffer.put_slice(b"}"); - buffer.freeze() - } } #[async_trait] @@ -178,7 +165,7 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { Ok(body) => body, Err(e) => { return HttpExecutionResponse { - body: self.error_to_graphql_bytes(e), + body: error_to_graphql_bytes(&self.endpoint, e), headers: Default::default(), } } @@ -199,7 +186,7 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { headers: shared_response.headers, }, Err(e) => HttpExecutionResponse { - body: self.error_to_graphql_bytes(e), + body: error_to_graphql_bytes(&self.endpoint, e), headers: Default::default(), }, }; @@ -238,7 +225,7 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { headers: shared_response.headers.clone(), }, Err(e) => HttpExecutionResponse { - body: self.error_to_graphql_bytes(e.clone()), + body: error_to_graphql_bytes(&self.endpoint, 
e.clone()), headers: Default::default(), }, } diff --git a/lib/executor/src/executors/map.rs b/lib/executor/src/executors/map.rs index ff18d118..555fa4ba 100644 --- a/lib/executor/src/executors/map.rs +++ b/lib/executor/src/executors/map.rs @@ -1,12 +1,15 @@ use std::{collections::HashMap, sync::Arc, time::Duration}; -use bytes::{BufMut, BytesMut}; +use bytes::{BufMut, Bytes, BytesMut}; use dashmap::DashMap; -use hive_router_config::traffic_shaping::TrafficShapingExecutorConfig; +use hive_router_config::{ + traffic_shaping::TrafficShapingConfig, traffic_shaping::TrafficShapingExecutorConfig, +}; use http::Uri; +use http_body_util::Full; use hyper_tls::HttpsConnector; use hyper_util::{ - client::legacy::Client, + client::legacy::{connect::HttpConnector, Client}, rt::{TokioExecutor, TokioTimer}, }; use tokio::sync::{OnceCell, Semaphore}; @@ -19,6 +22,7 @@ use crate::{ dedupe::{ABuildHasher, SharedResponse}, error::SubgraphExecutorError, http::HTTPSubgraphExecutor, + timeout::TimeoutExecutor, }, response::graphql_error::GraphQLError, }; @@ -74,21 +78,16 @@ impl SubgraphExecutorMap { pub fn from_http_endpoint_map( subgraph_endpoint_map: HashMap, - config: TrafficShapingExecutorConfig, + config: TrafficShapingConfig, ) -> Result { - let https = HttpsConnector::new(); - let client = Client::builder(TokioExecutor::new()) - .pool_timer(TokioTimer::new()) - .pool_idle_timeout(Duration::from_secs(config.pool_idle_timeout_seconds)) - .pool_max_idle_per_host(config.max_connections_per_host) - .build(https); - - let client_arc = Arc::new(client); - let semaphores_by_origin: DashMap> = DashMap::new(); let max_connections_per_host = config.max_connections_per_host; - let config_arc = Arc::new(config); - let in_flight_requests: Arc>, ABuildHasher>> = - Arc::new(DashMap::with_hasher(ABuildHasher::default())); + let global_client_arc = + from_traffic_shaping_config_to_client(&config.all, max_connections_per_host); + let semaphores_by_origin: DashMap> = DashMap::new(); + let 
global_config_arc = Arc::new(config.all); + let global_in_flight_requests: Arc< + DashMap>, ABuildHasher>, + > = Arc::new(DashMap::with_hasher(ABuildHasher::default())); let executor_map = subgraph_endpoint_map .into_iter() @@ -110,20 +109,44 @@ impl SubgraphExecutorMap { }) ); + let subgraph_config = config.subgraphs.get(&subgraph_name); + + let http_client = get_http_client_for_subgraph( + subgraph_config, + &global_config_arc, + &global_client_arc, + max_connections_per_host, + ); + + // TODO: Maybe reuse the in-flight requests map in some cases ??? + let inflight_requests = subgraph_config + .map(|_| Arc::new(DashMap::with_hasher(ABuildHasher::default()))) + .unwrap_or_else(|| global_in_flight_requests.clone()); + + let config_arc = subgraph_config + .map(|cfg| Arc::new(cfg.clone())) + .unwrap_or_else(|| global_config_arc.clone()); + let semaphore = semaphores_by_origin - .entry(origin) + .entry(origin.to_string()) .or_insert_with(|| Arc::new(Semaphore::new(max_connections_per_host))) .clone(); - let executor = HTTPSubgraphExecutor::new( - endpoint_uri, - client_arc.clone(), + let mut executor = HTTPSubgraphExecutor::new( + endpoint_uri.clone(), + http_client, semaphore, config_arc.clone(), - in_flight_requests.clone(), - ); + inflight_requests, + ) + .to_boxed_arc(); + + if let Some(timeout_config) = &config_arc.timeout { + executor = TimeoutExecutor::try_new(endpoint_uri, timeout_config, executor)? 
+ .to_boxed_arc(); + } - Ok((subgraph_name, executor.to_boxed_arc())) + Ok((subgraph_name, executor)) }) .collect::, SubgraphExecutorError>>()?; @@ -132,3 +155,37 @@ impl SubgraphExecutorMap { }) } } + +// Create a new hyper client based on the traffic shaping config +pub fn from_traffic_shaping_config_to_client( + config: &TrafficShapingExecutorConfig, + max_connections_per_host: usize, +) -> Arc, Full>> { + Arc::new( + Client::builder(TokioExecutor::new()) + .pool_timer(TokioTimer::new()) + .pool_idle_timeout(Duration::from_secs(config.pool_idle_timeout_seconds)) + .pool_max_idle_per_host(max_connections_per_host) + .build(HttpsConnector::new()), + ) +} + +// Reuse the global client if the subgraph config is the same as the global config +// Otherwise, create a new client based on the subgraph config +fn get_http_client_for_subgraph( + subgraph_config: Option<&TrafficShapingExecutorConfig>, + global_config: &TrafficShapingExecutorConfig, + global_client: &Arc, Full>>, + max_connections_per_host: usize, +) -> Arc, Full>> { + match subgraph_config { + Some(cfg) => { + if global_config.pool_idle_timeout_seconds == cfg.pool_idle_timeout_seconds { + global_client.clone() + } else { + from_traffic_shaping_config_to_client(cfg, max_connections_per_host) + } + } + None => global_client.clone(), + } +} diff --git a/lib/executor/src/executors/mod.rs b/lib/executor/src/executors/mod.rs index 520ff5f9..4b64bda5 100644 --- a/lib/executor/src/executors/mod.rs +++ b/lib/executor/src/executors/mod.rs @@ -3,3 +3,4 @@ pub mod dedupe; pub mod error; pub mod http; pub mod map; +pub mod timeout; diff --git a/lib/executor/src/executors/timeout.rs b/lib/executor/src/executors/timeout.rs new file mode 100644 index 00000000..1f0d276b --- /dev/null +++ b/lib/executor/src/executors/timeout.rs @@ -0,0 +1,428 @@ +use std::collections::BTreeMap; +use std::time::Duration; + +use async_trait::async_trait; +use hive_router_config::traffic_shaping::SubgraphTimeoutConfig; +use tracing::warn; +use 
vrl::compiler::Program as VrlProgram; + +use crate::executors::common::{ + HttpExecutionRequest, HttpExecutionResponse, SubgraphExecutor, SubgraphExecutorBoxedArc, +}; +use crate::executors::error::error_to_graphql_bytes; +use crate::{execution::plan::ClientRequestDetails, executors::error::SubgraphExecutorError}; +use vrl::{ + compiler::TargetValue as VrlTargetValue, + core::Value as VrlValue, + prelude::{state::RuntimeState as VrlState, Context as VrlContext, TimeZone as VrlTimeZone}, + value::Secrets as VrlSecrets, +}; + +use vrl::{compiler::compile as vrl_compile, stdlib::all as vrl_build_functions}; + +#[derive(Debug)] +pub enum TimeoutSource { + Expression(Box), + Duration(Duration), +} + +pub struct ExpressionContext<'a> { + pub client_request: &'a ClientRequestDetails<'a>, +} + +impl From<&ExpressionContext<'_>> for VrlValue { + fn from(ctx: &ExpressionContext) -> Self { + // .request + let request_value: Self = ctx.client_request.into(); + + Self::Object(BTreeMap::from([("request".into(), request_value)])) + } +} + +fn warn_unsupported_conversion_option(type_name: &str) -> Option { + warn!( + "Cannot convert VRL {} value to a Duration value. 
Please convert it to a number first.", + type_name + ); + + None +} + +fn vrl_value_to_duration(value: VrlValue) -> Option { + match value { + VrlValue::Integer(i) => { + if i < 0 { + warn!("Cannot convert negative integer ({}) to Duration.", i); + None + } else { + Some(Duration::from_millis(i as u64)) + } + } + VrlValue::Bytes(_) => warn_unsupported_conversion_option("Bytes"), + VrlValue::Float(_) => warn_unsupported_conversion_option("Float"), + VrlValue::Boolean(_) => warn_unsupported_conversion_option("Boolean"), + VrlValue::Array(_) => warn_unsupported_conversion_option("Array"), + VrlValue::Regex(_) => warn_unsupported_conversion_option("Regex"), + VrlValue::Timestamp(_) => warn_unsupported_conversion_option("Timestamp"), + VrlValue::Object(_) => warn_unsupported_conversion_option("Object"), + VrlValue::Null => { + warn!("Cannot convert VRL Null value to a Duration value."); + None + } + } +} + +pub struct TimeoutExecutor { + pub endpoint: http::Uri, + pub timeout: TimeoutSource, + pub executor: SubgraphExecutorBoxedArc, +} + +impl TimeoutExecutor { + pub fn try_new( + endpoint: http::Uri, + timeout_config: &SubgraphTimeoutConfig, + executor: SubgraphExecutorBoxedArc, + ) -> Result { + let timeout = match timeout_config { + SubgraphTimeoutConfig::Duration(dur) => TimeoutSource::Duration(*dur), + SubgraphTimeoutConfig::Expression(expr) => { + // Compile the VRL expression into a Program + let functions = vrl_build_functions(); + let compilation_result = vrl_compile(expr, &functions).map_err(|diagnostics| { + SubgraphExecutorError::TimeoutExpressionParseFailure( + diagnostics + .errors() + .into_iter() + .map(|d| d.code.to_string() + ": " + &d.message) + .collect::>() + .join(", "), + ) + })?; + TimeoutSource::Expression(Box::new(compilation_result.program)) + } + }; + Ok(Self { + endpoint, + timeout, + executor, + }) + } + pub fn get_timeout_duration<'a>( + &self, + client_request: &'a ClientRequestDetails<'a>, + ) -> Option { + let expression_context = 
ExpressionContext { client_request }; + + match &self.timeout { + TimeoutSource::Duration(dur) => Some(*dur), + TimeoutSource::Expression(program) => { + let mut target = VrlTargetValue { + value: VrlValue::from(&expression_context), + metadata: VrlValue::Object(BTreeMap::new()), + secrets: VrlSecrets::default(), + }; + + let mut state = VrlState::default(); + let timezone = VrlTimeZone::default(); + let mut ctx = VrlContext::new(&mut target, &mut state, &timezone); + match program.resolve(&mut ctx) { + Ok(resolved) => vrl_value_to_duration(resolved), + Err(err) => { + warn!( + "Failed to evaluate timeout expression: {:#?}, falling back to no timeout.", + err + ); + None + } + } + } + } + } +} + +#[async_trait] +impl SubgraphExecutor for TimeoutExecutor { + async fn execute<'a>( + &self, + execution_request: HttpExecutionRequest<'a>, + ) -> HttpExecutionResponse { + let timeout = self.get_timeout_duration(execution_request.client_request); + let execution = self.executor.execute(execution_request); + if let Some(timeout) = timeout { + match tokio::time::timeout(timeout, execution).await { + Ok(response) => response, + Err(_) => HttpExecutionResponse { + body: error_to_graphql_bytes( + &self.endpoint, + SubgraphExecutorError::RequestTimeout(timeout), + ), + headers: Default::default(), + }, + } + } else { + execution.await + } + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use async_trait::async_trait; + use axum::{extract::State, http::Response, Router}; + use hive_router_config::parse_yaml_config; + use http::Method; + use ntex_http::HeaderMap; + + use crate::{ + execution::plan::{ClientRequestDetails, OperationDetails}, + executors::{ + common::{HttpExecutionRequest, HttpExecutionResponse, SubgraphExecutor}, + map::from_traffic_shaping_config_to_client, + timeout::TimeoutExecutor, + }, + }; + + struct MockExecutor {} + + #[async_trait] + impl SubgraphExecutor for MockExecutor { + async fn execute<'a>( + &self, + _execution_request: 
HttpExecutionRequest<'a>, + ) -> HttpExecutionResponse { + HttpExecutionResponse { + body: Default::default(), + headers: Default::default(), + } + } + } + + #[test] + fn get_timeout_duration_from_expression() { + use std::time::Duration; + + use hive_router_config::traffic_shaping::SubgraphTimeoutConfig; + + let timeout_config = SubgraphTimeoutConfig::Expression( + r#" + if .request.operation.type == "mutation" { + 10000 + } else { + 5000 + } + "# + .to_string(), + ); + + let mock_executor = MockExecutor {}.to_boxed_arc(); + + let timeout_executor = TimeoutExecutor::try_new( + "http://example.com/graphql".parse().unwrap(), + &timeout_config, + mock_executor, + ) + .unwrap(); + + let headers = HeaderMap::new(); + + let client_request_query = ClientRequestDetails { + operation: OperationDetails { + name: Some("TestQuery".to_string()), + kind: "query", + query: "query TestQuery { field }".into(), + }, + url: "http://example.com/graphql".parse().unwrap(), + headers: &headers, + method: Method::POST, + }; + let duration_query = timeout_executor.get_timeout_duration(&client_request_query); + assert_eq!( + duration_query, + Some(Duration::from_millis(5000)), + "Expected 5000ms for query" + ); + + let client_request_mutation = crate::execution::plan::ClientRequestDetails { + operation: OperationDetails { + name: Some("TestMutation".to_string()), + kind: "mutation", + query: "mutation TestMutation { doSomething }".into(), + }, + url: "http://example.com/graphql".parse().unwrap(), + headers: &headers, + method: Method::POST, + }; + + let duration_mutation = timeout_executor.get_timeout_duration(&client_request_mutation); + assert_eq!( + duration_mutation, + Some(Duration::from_millis(10000)), + "Expected 10000ms for mutation" + ); + } + + #[test] + fn get_timeout_duration_from_fixed_duration() { + let yaml_str = r#" + traffic_shaping: + all: + timeout: + duration: 7s + "#; + let config = parse_yaml_config(yaml_str.to_string()).unwrap(); + let mock_executor = MockExecutor 
{}.to_boxed_arc(); + let timeout_executor = TimeoutExecutor::try_new( + "http://example.com/graphql".parse().unwrap(), + &config.traffic_shaping.all.timeout.unwrap(), + mock_executor, + ) + .unwrap(); + + let headers = HeaderMap::new(); + let client_request = ClientRequestDetails { + operation: OperationDetails { + name: Some("TestQuery".to_string()), + kind: "query", + query: "query TestQuery { field }".into(), + }, + url: "http://example.com/graphql".parse().unwrap(), + headers: &headers, + method: Method::POST, + }; + let duration = timeout_executor.get_timeout_duration(&client_request); + assert_eq!(duration, Some(std::time::Duration::from_millis(7000))); + } + + #[tokio::test] + async fn cancels_http_request_when_timeout_expires() { + /** + * We will test here that when the timeout expires, the request is cancelled on the server-end as well. + * For that, we will create a server that sets a flag when the request is dropped/cancelled. + */ + use std::sync::Arc; + + use http::Method; + + let (tx, mut rx) = tokio::sync::broadcast::channel(16); + + struct AppState { + tx: Arc>, + } + + let app_state = AppState { tx: Arc::new(tx) }; + + let app_state_arc = Arc::new(app_state); + + struct CancelOnDrop { + start: std::time::Instant, + tx: Arc>, + } + + impl Drop for CancelOnDrop { + fn drop(&mut self) { + self.tx.send(self.start.elapsed()).unwrap(); + } + } + + #[axum::debug_handler] + async fn handler(State(state): State>) -> Response { + let _cancel_on_drop = CancelOnDrop { + start: std::time::Instant::now(), + tx: state.tx.clone(), + }; + // Never resolve the request, just wait until it's cancelled + let fut = futures::future::pending::>(); + fut.await + } + + println!("Starting server..."); + let app = Router::new() + .fallback(handler) + .with_state(app_state_arc.clone()); + println!("Router created, binding to port..."); + let listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap(); + println!("Listener bound, starting server..."); + let addr = 
listener.local_addr().unwrap(); + tokio::spawn(async move { + if let Err(e) = axum::serve(listener, app).await { + eprintln!("Server error: {}", e); + } + }); + println!("Server started on {}", addr); + let graphql_path = "graphql"; + let endpoint: http::Uri = format!("http://{}/{}", addr, graphql_path).parse().unwrap(); + println!("Endpoint: {}", endpoint); + + let config = r#" + traffic_shaping: + all: + timeout: + duration: 5s + "#; + + let config = hive_router_config::parse_yaml_config(config.to_string()).unwrap(); + let http_client = from_traffic_shaping_config_to_client(&config.traffic_shaping.all, 10); + let http_executor = crate::executors::http::HTTPSubgraphExecutor::new( + endpoint.clone(), + http_client, + Arc::new(tokio::sync::Semaphore::new(10)), + Arc::new(config.traffic_shaping.all.clone()), + Default::default(), + ); + let timeout_executor = TimeoutExecutor::try_new( + endpoint, + &config.traffic_shaping.all.timeout.unwrap(), + http_executor.to_boxed_arc(), + ) + .unwrap(); + + let headers = HeaderMap::new(); + let client_request = ClientRequestDetails { + operation: OperationDetails { + name: Some("TestQuery".to_string()), + kind: "query", + query: "query TestQuery { field }".into(), + }, + url: "http://example.com/graphql".parse().unwrap(), + headers: &headers, + method: Method::POST, + }; + + let execution_request = HttpExecutionRequest { + operation_name: Some("TestQuery"), + query: r#"{ field }"#, + variables: None, + representations: None, + headers: http::HeaderMap::new(), + client_request: &client_request, + dedupe: true, + }; + + println!("Sending request to executor with 5s timeout..."); + let response = timeout_executor.execute(execution_request).await; + + println!("Received response from executor."); + assert!( + response + .body + .starts_with(b"{\"errors\":[{\"message\":\"Failed to execute request to subgraph"), + "Expected error response due to timeout" + ); + + println!("Waiting to see if server was notified of cancellation..."); + 
+ // Wait for the server to be notified that the request was cancelled + let elapsed = rx.recv().await.unwrap(); + println!("Server was notified of cancellation after {:?}", elapsed); + assert!( + elapsed >= Duration::from_secs_f32(4.9), + "Expected server to be notified of cancellation after at least 5s, but was {:?}", + elapsed + ); + + println!("Test completed."); + } +} diff --git a/lib/router-config/src/lib.rs b/lib/router-config/src/lib.rs index 38eefdfd..1a738bbc 100644 --- a/lib/router-config/src/lib.rs +++ b/lib/router-config/src/lib.rs @@ -14,7 +14,7 @@ use serde::{Deserialize, Serialize}; use crate::{ http_server::HttpServerConfig, log::LoggingConfig, query_planner::QueryPlannerConfig, - supergraph::SupergraphSource, traffic_shaping::TrafficShapingExecutorConfig, + supergraph::SupergraphSource, traffic_shaping::TrafficShapingConfig, }; #[derive(Deserialize, Serialize, JsonSchema)] @@ -42,7 +42,7 @@ pub struct HiveRouterConfig { /// Configuration for the traffic-shaper executor. Use these configurations to control how requests are being executed to subgraphs. #[serde(default)] - pub traffic_shaping: TrafficShapingExecutorConfig, + pub traffic_shaping: TrafficShapingConfig, /// Configuration for the headers. #[serde(default)] diff --git a/lib/router-config/src/traffic_shaping.rs b/lib/router-config/src/traffic_shaping.rs index 595112e3..d634cf83 100644 --- a/lib/router-config/src/traffic_shaping.rs +++ b/lib/router-config/src/traffic_shaping.rs @@ -1,12 +1,25 @@ +use std::time::Duration; + use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -#[derive(Debug, Deserialize, Serialize, JsonSchema, Clone)] -pub struct TrafficShapingExecutorConfig { +use std::collections::HashMap; + +#[derive(Clone, Deserialize, Serialize, JsonSchema)] +pub struct TrafficShapingConfig { + /// The default configuration that will be applied to all subgraphs, unless overridden by a specific subgraph configuration. 
+ #[serde(default)] + pub all: TrafficShapingExecutorConfig, + /// Optional per-subgraph configurations that will override the default configuration for specific subgraphs. + #[serde(default, skip_serializing_if = "HashMap::is_empty")] + pub subgraphs: HashMap, /// Limits the concurrent amount of requests/connections per host/subgraph. #[serde(default = "default_max_connections_per_host")] pub max_connections_per_host: usize, +} +#[derive(Debug, Deserialize, Serialize, JsonSchema, Clone)] +pub struct TrafficShapingExecutorConfig { /// Timeout for idle sockets being kept-alive. #[serde(default = "default_pool_idle_timeout_seconds")] pub pool_idle_timeout_seconds: u64, @@ -17,14 +30,60 @@ pub struct TrafficShapingExecutorConfig { /// be deduplicated by sharing the response of other in-flight requests. #[serde(default = "default_dedupe_enabled")] pub dedupe_enabled: bool, + + /// Optional timeout configuration for requests to subgraphs. + /// + /// Example with a fixed duration: + /// ```yaml + /// timeout: + /// duration: 5s + /// ``` + /// + /// Or with a VRL expression that can return a duration based on the operation kind: + /// ```yaml + /// timeout: + /// expression: | + /// if (.request.operation.type == "mutation") { + /// 10000 + /// } else { + /// 5000 + /// } + /// ``` + #[serde(default, skip_serializing_if = "Option::is_none")] + pub timeout: Option, +} + +#[derive(Debug, Deserialize, Serialize, JsonSchema, Clone)] +#[serde(rename_all = "camelCase")] +pub enum SubgraphTimeoutConfig { + Expression(String), + #[serde(deserialize_with = "humantime_serde")] + Duration(Duration), +} + +fn humantime_serde<'de, D>(deserializer: D) -> Result +where + D: serde::Deserializer<'de>, +{ + humantime_serde::deserialize(deserializer) } impl Default for TrafficShapingExecutorConfig { fn default() -> Self { Self { - max_connections_per_host: default_max_connections_per_host(), pool_idle_timeout_seconds: default_pool_idle_timeout_seconds(), dedupe_enabled: 
default_dedupe_enabled(), + timeout: None, + } + } +} + +impl Default for TrafficShapingConfig { + fn default() -> Self { + Self { + all: TrafficShapingExecutorConfig::default(), + subgraphs: HashMap::new(), + max_connections_per_host: default_max_connections_per_host(), } } }