diff --git a/src/client.rs b/src/client.rs
index 7d22472..ad7a23d 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -11,7 +11,8 @@ use uuid::Uuid;
 use crate::error::{DurableError, DurableResult};
 use crate::task::{Task, TaskRegistry};
 use crate::types::{
-    CancellationPolicy, RetryStrategy, SpawnOptions, SpawnResult, SpawnResultRow, WorkerOptions,
+    CancellationPolicy, RetryStrategy, SpawnDefaults, SpawnOptions, SpawnResult, SpawnResultRow,
+    WorkerOptions,
 };
 
 /// Internal struct for serializing spawn options to the database.
@@ -110,7 +111,7 @@ where
     pool: PgPool,
     owns_pool: bool,
     queue_name: String,
-    default_max_attempts: u32,
+    spawn_defaults: SpawnDefaults,
     registry: Arc<RwLock<TaskRegistry<State>>>,
     state: State,
 }
@@ -120,11 +121,23 @@ where
 /// # Example
 ///
 /// ```ignore
+/// use std::time::Duration;
+/// use durable::{Durable, RetryStrategy, CancellationPolicy};
+///
 /// // Without state
 /// let client = Durable::builder()
 ///     .database_url("postgres://localhost/myapp")
 ///     .queue_name("orders")
 ///     .default_max_attempts(3)
+///     .default_retry_strategy(RetryStrategy::Exponential {
+///         base_delay: Duration::from_secs(5),
+///         factor: 2.0,
+///         max_backoff: Duration::from_secs(300),
+///     })
+///     .default_cancellation(CancellationPolicy {
+///         max_pending_time: Some(Duration::from_secs(3600)),
+///         max_running_time: None,
+///     })
 ///     .build()
 ///     .await?;
 ///
@@ -138,7 +151,7 @@ pub struct DurableBuilder {
     database_url: Option<String>,
     pool: Option<PgPool>,
     queue_name: String,
-    default_max_attempts: u32,
+    spawn_defaults: SpawnDefaults,
 }
 
 impl DurableBuilder {
@@ -147,7 +160,11 @@ impl DurableBuilder {
             database_url: None,
             pool: None,
             queue_name: "default".to_string(),
-            default_max_attempts: 5,
+            spawn_defaults: SpawnDefaults {
+                max_attempts: 5,
+                retry_strategy: None,
+                cancellation: None,
+            },
         }
     }
 
@@ -171,7 +188,19 @@ impl DurableBuilder {
     /// Set default max attempts for spawned tasks (default: 5)
     pub fn default_max_attempts(mut self, attempts: u32) -> Self {
-        self.default_max_attempts = attempts;
+        self.spawn_defaults.max_attempts = attempts;
+        self
+    }
+
+    /// Set default retry strategy for spawned tasks (default: Fixed with 5s delay)
+    pub fn default_retry_strategy(mut self, strategy: RetryStrategy) -> Self {
+        self.spawn_defaults.retry_strategy = Some(strategy);
+        self
+    }
+
+    /// Set default cancellation policy for spawned tasks (default: no auto-cancellation)
+    pub fn default_cancellation(mut self, policy: CancellationPolicy) -> Self {
+        self.spawn_defaults.cancellation = Some(policy);
         self
     }
 
@@ -226,7 +255,7 @@ impl DurableBuilder {
             pool,
             owns_pool,
             queue_name: self.queue_name,
-            default_max_attempts: self.default_max_attempts,
+            spawn_defaults: self.spawn_defaults,
             registry: Arc::new(RwLock::new(HashMap::new())),
             state,
         })
@@ -471,7 +500,19 @@ where
         #[cfg(feature = "telemetry")]
         tracing::Span::current().record("queue", &self.queue_name);
 
-        let max_attempts = options.max_attempts.unwrap_or(self.default_max_attempts);
+        // Apply defaults if not set
+        let max_attempts = options
+            .max_attempts
+            .unwrap_or(self.spawn_defaults.max_attempts);
+        let options = SpawnOptions {
+            retry_strategy: options
+                .retry_strategy
+                .or_else(|| self.spawn_defaults.retry_strategy.clone()),
+            cancellation: options
+                .cancellation
+                .or_else(|| self.spawn_defaults.cancellation.clone()),
+            ..options
+        };
 
         let db_options = Self::serialize_spawn_options(&options, max_attempts)?;
 
@@ -649,6 +690,7 @@ where
             self.registry.clone(),
             options,
             self.state.clone(),
+            self.spawn_defaults.clone(),
         )
         .await)
     }
diff --git a/src/context.rs b/src/context.rs
index 3bb942d..825944b 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -11,8 +11,8 @@ use uuid::Uuid;
 use crate::error::{ControlFlow, TaskError, TaskResult};
 use crate::task::{Task, TaskRegistry};
 use crate::types::{
-    AwaitEventResult, CheckpointRow, ChildCompletePayload, ChildStatus, ClaimedTask, SpawnOptions,
-    SpawnResultRow, TaskHandle,
+    AwaitEventResult, CheckpointRow, ChildCompletePayload, ChildStatus, ClaimedTask, SpawnDefaults,
+    SpawnOptions, SpawnResultRow, TaskHandle,
 };
 use crate::worker::LeaseExtender;
 
@@ -72,6 +72,9 @@ where
     /// Task registry for validating spawn_by_name calls.
     registry: Arc<RwLock<TaskRegistry<State>>>,
+
+    /// Default settings for subtasks spawned via spawn/spawn_by_name.
+    spawn_defaults: SpawnDefaults,
 }
 
 /// Validate that a user-provided step name doesn't use reserved prefix.
@@ -90,6 +93,7 @@ where
 {
     /// Create a new TaskContext. Called by the worker before executing a task.
     /// Loads all existing checkpoints into the cache.
+    #[allow(clippy::too_many_arguments)]
     pub(crate) async fn create(
         pool: PgPool,
         queue_name: String,
@@ -98,6 +102,7 @@ where
         lease_extender: LeaseExtender,
         registry: Arc<RwLock<TaskRegistry<State>>>,
         state: State,
+        spawn_defaults: SpawnDefaults,
     ) -> Result {
         // Load all checkpoints for this task into cache
        let checkpoints: Vec<CheckpointRow> = sqlx::query_as(
@@ -127,6 +132,7 @@ where
             lease_extender,
             registry,
             state,
+            spawn_defaults,
         })
     }
 
@@ -668,6 +674,22 @@ where
             }
         }
 
+        // Apply defaults if not set
+        let options = SpawnOptions {
+            max_attempts: Some(
+                options
+                    .max_attempts
+                    .unwrap_or(self.spawn_defaults.max_attempts),
+            ),
+            retry_strategy: options
+                .retry_strategy
+                .or_else(|| self.spawn_defaults.retry_strategy.clone()),
+            cancellation: options
+                .cancellation
+                .or_else(|| self.spawn_defaults.cancellation.clone()),
+            ..options
+        };
+
         // Build options JSON, merging user options with parent_task_id
         #[derive(Serialize)]
         struct SubtaskOptions<'a> {
@@ -844,6 +866,11 @@ mod tests {
             LeaseExtender::dummy_for_tests(),
             Arc::new(RwLock::new(TaskRegistry::new())),
             (),
+            SpawnDefaults {
+                max_attempts: 5,
+                retry_strategy: None,
+                cancellation: None,
+            },
         )
         .await
         .unwrap();
diff --git a/src/lib.rs b/src/lib.rs
index b56ad69..1c5d404 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -107,8 +107,8 @@ pub use context::TaskContext;
 pub use error::{ControlFlow, DurableError, DurableResult, TaskError, TaskResult};
 pub use task::Task;
 pub use types::{
-    CancellationPolicy, ClaimedTask, RetryStrategy, SpawnOptions, SpawnResult, TaskHandle,
-    WorkerOptions,
+    CancellationPolicy, ClaimedTask, RetryStrategy, SpawnDefaults, SpawnOptions, SpawnResult,
+    TaskHandle, WorkerOptions,
 };
 
 pub use worker::Worker;
diff --git a/src/types.rs b/src/types.rs
index 5861750..5319dcf 100644
--- a/src/types.rs
+++ b/src/types.rs
@@ -315,6 +315,21 @@ impl TaskHandle {
     }
 }
 
+/// Default settings for spawned tasks.
+///
+/// Groups the default `max_attempts`, `retry_strategy`, and `cancellation`
+/// settings that are applied when spawning tasks (either from the client
+/// or from within a task context).
+#[derive(Debug, Clone, Default)]
+pub struct SpawnDefaults {
+    /// Default max attempts for spawned tasks (default: 5)
+    pub max_attempts: u32,
+    /// Default retry strategy for spawned tasks
+    pub retry_strategy: Option<RetryStrategy>,
+    /// Default cancellation policy for spawned tasks
+    pub cancellation: Option<CancellationPolicy>,
+}
+
 /// Terminal status of a child task.
 ///
 /// This enum represents the possible terminal states a subtask can be in
diff --git a/src/worker.rs b/src/worker.rs
index 8714715..e04bd9b 100644
--- a/src/worker.rs
+++ b/src/worker.rs
@@ -11,7 +11,7 @@ use uuid::Uuid;
 use crate::context::TaskContext;
 use crate::error::{ControlFlow, TaskError, serialize_task_error};
 use crate::task::TaskRegistry;
-use crate::types::{ClaimedTask, ClaimedTaskRow, WorkerOptions};
+use crate::types::{ClaimedTask, ClaimedTaskRow, SpawnDefaults, WorkerOptions};
 
 /// Notifies the worker that the lease has been extended.
 /// Used by TaskContext to reset warning/fatal timers.
@@ -67,6 +67,7 @@ impl Worker {
         registry: Arc<RwLock<TaskRegistry<State>>>,
         options: WorkerOptions,
         state: State,
+        spawn_defaults: SpawnDefaults,
     ) -> Self
     where
         State: Clone + Send + Sync + 'static,
@@ -92,6 +93,7 @@ impl Worker {
             worker_id,
             shutdown_rx,
             state,
+            spawn_defaults,
         ));
 
         Self {
@@ -109,6 +111,7 @@ impl Worker {
         let _ = self.handle.await;
     }
 
+    #[allow(clippy::too_many_arguments)]
     async fn run_loop(
         pool: PgPool,
         queue_name: String,
@@ -117,6 +120,7 @@ impl Worker {
         worker_id: String,
         mut shutdown_rx: broadcast::Receiver<()>,
         state: State,
+        spawn_defaults: SpawnDefaults,
     ) where
         State: Clone + Send + Sync + 'static,
     {
@@ -190,6 +194,7 @@ impl Worker {
             let registry = registry.clone();
             let done_tx = done_tx.clone();
             let state = state.clone();
+            let spawn_defaults = spawn_defaults.clone();
 
             tokio::spawn(async move {
                 Self::execute_task(
@@ -200,6 +205,7 @@ impl Worker {
                     claim_timeout,
                     fatal_on_lease_timeout,
                     state,
+                    spawn_defaults,
                 ).await;
 
                 drop(permit);
@@ -258,6 +264,7 @@ impl Worker {
         Ok(tasks)
     }
 
+    #[allow(clippy::too_many_arguments)]
     async fn execute_task(
         pool: PgPool,
         queue_name: String,
@@ -266,6 +273,7 @@ impl Worker {
         claim_timeout: Duration,
         fatal_on_lease_timeout: bool,
         state: State,
+        spawn_defaults: SpawnDefaults,
     ) where
         State: Clone + Send + Sync + 'static,
     {
@@ -295,11 +303,13 @@ impl Worker {
             claim_timeout,
             fatal_on_lease_timeout,
             state,
+            spawn_defaults,
         )
         .instrument(span)
         .await
     }
 
+    #[allow(clippy::too_many_arguments)]
     async fn execute_task_inner(
         pool: PgPool,
         queue_name: String,
@@ -308,6 +318,7 @@ impl Worker {
         claim_timeout: Duration,
         fatal_on_lease_timeout: bool,
         state: State,
+        spawn_defaults: SpawnDefaults,
     ) where
         State: Clone + Send + Sync + 'static,
     {
@@ -333,6 +344,7 @@ impl Worker {
             lease_extender,
             registry.clone(),
             state.clone(),
+            spawn_defaults,
         )
         .await
         {
diff --git a/tests/crash_test.rs b/tests/crash_test.rs
index bedca5d..b6eef1c 100644
--- a/tests/crash_test.rs
+++ b/tests/crash_test.rs
@@ -1,7 +1,7 @@
 #![allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
 
 use sqlx::{AssertSqlSafe, PgPool};
-use std::time::Duration;
+use std::time::{Duration, Instant};
 
 mod common;
 
@@ -538,14 +538,33 @@ async fn test_slow_task_outlives_lease(pool: PgPool) -> sqlx::Result<()> {
     // Wait for real time to pass the lease timeout
     tokio::time::sleep(claim_timeout + Duration::from_secs(2)).await;
 
-    // Verify a new run was created (reclaim happened)
-    let run_count = count_runs_for_task(&pool, "crash_slow", spawn_result.task_id).await?;
+    // Second worker polls to reclaim the expired lease.
+    let worker2 = client
+        .start_worker(WorkerOptions {
+            poll_interval: Duration::from_millis(50),
+            claim_timeout: Duration::from_secs(10),
+            ..Default::default()
+        })
+        .await
+        .unwrap();
+
+    // Verify a new run was created (reclaim happened), with bounded polling.
+    let deadline = Instant::now() + Duration::from_secs(5);
+    let mut run_count = 0;
+    while Instant::now() < deadline {
+        run_count = count_runs_for_task(&pool, "crash_slow", spawn_result.task_id).await?;
+        if run_count >= 2 {
+            break;
+        }
+        tokio::time::sleep(Duration::from_millis(50)).await;
+    }
     assert!(
         run_count >= 2,
         "Should have at least 2 runs after lease expiration, got {}",
         run_count
     );
 
+    worker2.shutdown().await;
     worker.shutdown().await;
 
     Ok(())
diff --git a/tests/fanout_test.rs b/tests/fanout_test.rs
index 4a41267..9247516 100644
--- a/tests/fanout_test.rs
+++ b/tests/fanout_test.rs
@@ -8,7 +8,8 @@ use common::tasks::{
     SpawnByNameParams, SpawnByNameTask, SpawnFailingChildTask, SpawnSlowChildParams,
     SpawnSlowChildTask,
 };
-use durable::{Durable, MIGRATOR, WorkerOptions};
+use durable::{CancellationPolicy, Durable, MIGRATOR, RetryStrategy, WorkerOptions};
+use serde_json::Value as JsonValue;
 use sqlx::{AssertSqlSafe, PgPool};
 use std::time::Duration;
 
@@ -453,9 +454,68 @@ async fn test_cascade_cancel_when_parent_auto_cancelled_by_max_duration(
 // spawn_by_name Tests
 // ============================================================================
 
+/// Helper to query max_attempts from the database.
+async fn get_task_max_attempts(
+    pool: &PgPool,
+    queue_name: &str,
+    task_id: uuid::Uuid,
+) -> Option<i32> {
+    #[derive(sqlx::FromRow)]
+    struct TaskMaxAttempts {
+        max_attempts: Option<i32>,
+    }
+    let query = AssertSqlSafe(format!(
+        "SELECT max_attempts FROM durable.t_{queue_name} WHERE task_id = $1"
+    ));
+    let result: TaskMaxAttempts = sqlx::query_as(query)
+        .bind(task_id)
+        .fetch_one(pool)
+        .await
+        .expect("Failed to query task max_attempts");
+    result.max_attempts
+}
+
+/// Helper to query retry_strategy from the database.
+async fn get_task_retry_strategy(
+    pool: &PgPool,
+    queue_name: &str,
+    task_id: uuid::Uuid,
+) -> Option<JsonValue> {
+    #[derive(sqlx::FromRow)]
+    struct TaskRetryStrategy {
+        retry_strategy: Option<JsonValue>,
+    }
+    let query = AssertSqlSafe(format!(
+        "SELECT retry_strategy FROM durable.t_{queue_name} WHERE task_id = $1"
+    ));
+    let result: TaskRetryStrategy = sqlx::query_as(query)
+        .bind(task_id)
+        .fetch_one(pool)
+        .await
+        .expect("Failed to query task retry_strategy");
+    result.retry_strategy
+}
+
 #[sqlx::test(migrator = "MIGRATOR")]
 async fn test_spawn_by_name_from_task_context(pool: PgPool) -> sqlx::Result<()> {
-    let client = create_client(pool.clone(), "fanout_by_name").await;
+    // Use custom defaults to verify subtasks inherit them
+    let client = Durable::builder()
+        .pool(pool.clone())
+        .queue_name("fanout_by_name")
+        .default_max_attempts(7)
+        .default_retry_strategy(RetryStrategy::Exponential {
+            base_delay: Duration::from_secs(10),
+            factor: 3.0,
+            max_backoff: Duration::from_secs(600),
+        })
+        .default_cancellation(CancellationPolicy {
+            max_pending_time: Some(Duration::from_secs(3600)),
+            max_running_time: None,
+        })
+        .build()
+        .await
+        .expect("Failed to create client");
+    client.create_queue(None).await.unwrap();
 
     client.register::().await.unwrap();
     client.register::().await.unwrap();
 
@@ -497,6 +557,48 @@ async fn test_spawn_by_name_from_task_context(pool: PgPool) -> sqlx::Result<()>
         "Child should have doubled 21 to 42 (spawned via spawn_by_name)"
     );
 
+    // Find the child task and verify it inherited the default_max_attempts
+    let child_query = "SELECT task_id FROM durable.t_fanout_by_name WHERE parent_task_id = $1";
+    let child_ids: Vec<(uuid::Uuid,)> = sqlx::query_as(child_query)
+        .bind(spawn_result.task_id)
+        .fetch_all(&pool)
+        .await?;
+
+    assert_eq!(child_ids.len(), 1, "Should have exactly one child task");
+    let child_task_id = child_ids[0].0;
+
+    // Verify child task has the default max_attempts from the client config
+    let child_max_attempts = get_task_max_attempts(&pool, "fanout_by_name", child_task_id).await;
+    assert_eq!(
+        child_max_attempts,
+        Some(7),
+        "Child task spawned via spawn_by_name should inherit default_max_attempts=7"
+    );
+
+    // Verify child task has the default retry_strategy from the client config
+    let child_retry_strategy =
+        get_task_retry_strategy(&pool, "fanout_by_name", child_task_id).await;
+    assert!(
+        child_retry_strategy.is_some(),
+        "Child task should have a retry_strategy"
+    );
+    let strategy = child_retry_strategy.unwrap();
+    assert_eq!(
+        strategy.get("kind").and_then(|v| v.as_str()),
+        Some("exponential"),
+        "Child task should inherit exponential retry strategy"
+    );
+    assert_eq!(
+        strategy.get("base_seconds").and_then(|v| v.as_u64()),
+        Some(10),
+        "Child task should inherit base_delay=10s"
+    );
+    assert_eq!(
+        strategy.get("factor").and_then(|v| v.as_f64()),
+        Some(3.0),
+        "Child task should inherit factor=3.0"
+    );
+
     Ok(())
 }
 
@@ -728,9 +830,11 @@ async fn test_join_timeout_when_parent_claim_expires(pool: PgPool) -> sqlx::Result<()> {
     let error_name = failed_payload.get("name").and_then(|v| v.as_str());
 
     // Could be Timeout or other error depending on how the timeout manifests
+    // ChildFailed is also valid when child tasks have bounded max_attempts
     assert!(
         error_name == Some("Timeout")
             || error_name == Some("ChildCancelled")
+            || error_name == Some("ChildFailed")
             || error_name == Some("TaskInternal"),
         "Expected timeout-related error, got: {:?}",
         error_name