56 changes: 49 additions & 7 deletions src/client.rs
@@ -11,7 +11,8 @@ use uuid::Uuid;
use crate::error::{DurableError, DurableResult};
use crate::task::{Task, TaskRegistry};
use crate::types::{
CancellationPolicy, RetryStrategy, SpawnOptions, SpawnResult, SpawnResultRow, WorkerOptions,
CancellationPolicy, RetryStrategy, SpawnDefaults, SpawnOptions, SpawnResult, SpawnResultRow,
WorkerOptions,
};

/// Internal struct for serializing spawn options to the database.
@@ -110,7 +111,7 @@
pool: PgPool,
owns_pool: bool,
queue_name: String,
default_max_attempts: u32,
spawn_defaults: SpawnDefaults,
registry: Arc<RwLock<TaskRegistry<State>>>,
state: State,
}
@@ -120,11 +121,23 @@
/// # Example
///
/// ```ignore
/// use std::time::Duration;
/// use durable::{Durable, RetryStrategy, CancellationPolicy};
///
/// // Without state
/// let client = Durable::builder()
/// .database_url("postgres://localhost/myapp")
/// .queue_name("orders")
/// .default_max_attempts(3)
/// .default_retry_strategy(RetryStrategy::Exponential {
/// base_delay: Duration::from_secs(5),
/// factor: 2.0,
/// max_backoff: Duration::from_secs(300),
/// })
/// .default_cancellation(CancellationPolicy {
/// max_pending_time: Some(Duration::from_secs(3600)),
/// max_running_time: None,
/// })
/// .build()
/// .await?;
///
Expand All @@ -138,7 +151,7 @@ pub struct DurableBuilder {
database_url: Option<String>,
pool: Option<PgPool>,
queue_name: String,
default_max_attempts: u32,
spawn_defaults: SpawnDefaults,
}

impl DurableBuilder {
@@ -147,7 +160,11 @@ impl DurableBuilder {
database_url: None,
pool: None,
queue_name: "default".to_string(),
default_max_attempts: 5,
spawn_defaults: SpawnDefaults {
max_attempts: 5,
retry_strategy: None,
cancellation: None,
},
}
}

@@ -171,7 +188,19 @@

/// Set default max attempts for spawned tasks (default: 5)
pub fn default_max_attempts(mut self, attempts: u32) -> Self {
self.default_max_attempts = attempts;
self.spawn_defaults.max_attempts = attempts;
self
}

/// Set default retry strategy for spawned tasks (default: Fixed with 5s delay)
pub fn default_retry_strategy(mut self, strategy: RetryStrategy) -> Self {
self.spawn_defaults.retry_strategy = Some(strategy);
self
}

/// Set default cancellation policy for spawned tasks (default: no auto-cancellation)
pub fn default_cancellation(mut self, policy: CancellationPolicy) -> Self {
self.spawn_defaults.cancellation = Some(policy);
self
}

@@ -226,7 +255,7 @@ impl DurableBuilder {
pool,
owns_pool,
queue_name: self.queue_name,
default_max_attempts: self.default_max_attempts,
spawn_defaults: self.spawn_defaults,
registry: Arc::new(RwLock::new(HashMap::new())),
state,
})
@@ -471,7 +500,19 @@ where
#[cfg(feature = "telemetry")]
tracing::Span::current().record("queue", &self.queue_name);

let max_attempts = options.max_attempts.unwrap_or(self.default_max_attempts);
// Apply defaults if not set
let max_attempts = options
.max_attempts
.unwrap_or(self.spawn_defaults.max_attempts);
let options = SpawnOptions {
retry_strategy: options
.retry_strategy
.or_else(|| self.spawn_defaults.retry_strategy.clone()),
cancellation: options
.cancellation
.or_else(|| self.spawn_defaults.cancellation.clone()),
..options
};

let db_options = Self::serialize_spawn_options(&options, max_attempts)?;

@@ -649,6 +690,7 @@ where
self.registry.clone(),
options,
self.state.clone(),
self.spawn_defaults.clone(),
)
.await)
}
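For reference, a sketch of the precedence the client changes above introduce: fields set on a per-call SpawnOptions win, and anything left unset falls back to the builder-level SpawnDefaults. The spawn call signature, the task name "send_email", the payload variable, and the use of SpawnOptions::default() are assumptions for illustration only and are not taken from this diff.

// Assumed client-side spawn signature: (task name, payload, options).
let handle = client
    .spawn(
        "send_email",               // hypothetical task name
        &payload,                   // hypothetical serializable payload
        SpawnOptions {
            max_attempts: Some(10), // explicit value overrides default_max_attempts(...)
            ..Default::default()    // retry_strategy and cancellation inherit SpawnDefaults
        },
    )
    .await?;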
31 changes: 29 additions & 2 deletions src/context.rs
@@ -11,8 +11,8 @@ use uuid::Uuid;
use crate::error::{ControlFlow, TaskError, TaskResult};
use crate::task::{Task, TaskRegistry};
use crate::types::{
AwaitEventResult, CheckpointRow, ChildCompletePayload, ChildStatus, ClaimedTask, SpawnOptions,
SpawnResultRow, TaskHandle,
AwaitEventResult, CheckpointRow, ChildCompletePayload, ChildStatus, ClaimedTask, SpawnDefaults,
SpawnOptions, SpawnResultRow, TaskHandle,
};
use crate::worker::LeaseExtender;

@@ -72,6 +72,9 @@ where

/// Task registry for validating spawn_by_name calls.
registry: Arc<RwLock<TaskRegistry<State>>>,

/// Default settings for subtasks spawned via spawn/spawn_by_name.
spawn_defaults: SpawnDefaults,
}

/// Validate that a user-provided step name doesn't use reserved prefix.
@@ -90,6 +93,7 @@
{
/// Create a new TaskContext. Called by the worker before executing a task.
/// Loads all existing checkpoints into the cache.
#[allow(clippy::too_many_arguments)]
pub(crate) async fn create(
pool: PgPool,
queue_name: String,
@@ -98,6 +102,7 @@
lease_extender: LeaseExtender,
registry: Arc<RwLock<TaskRegistry<State>>>,
state: State,
spawn_defaults: SpawnDefaults,
) -> Result<Self, sqlx::Error> {
// Load all checkpoints for this task into cache
let checkpoints: Vec<CheckpointRow> = sqlx::query_as(
@@ -127,6 +132,7 @@ where
lease_extender,
registry,
state,
spawn_defaults,
})
}

@@ -668,6 +674,22 @@ where
}
}

// Apply defaults if not set
let options = SpawnOptions {
max_attempts: Some(
options
.max_attempts
.unwrap_or(self.spawn_defaults.max_attempts),
),
retry_strategy: options
.retry_strategy
.or_else(|| self.spawn_defaults.retry_strategy.clone()),
cancellation: options
.cancellation
.or_else(|| self.spawn_defaults.cancellation.clone()),
..options
};

// Build options JSON, merging user options with parent_task_id
#[derive(Serialize)]
struct SubtaskOptions<'a> {
@@ -844,6 +866,11 @@ mod tests {
LeaseExtender::dummy_for_tests(),
Arc::new(RwLock::new(TaskRegistry::new())),
(),
SpawnDefaults {
max_attempts: 5,
retry_strategy: None,
cancellation: None,
},
)
.await
.unwrap();
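The same precedence applies inside a task body. A minimal sketch under stated assumptions: ctx.spawn is assumed to take a task name, payload, and SpawnOptions, and the "process_item" task and item payload are invented for illustration; none of this is taken from the diff.

// An otherwise-unconfigured subtask now inherits the client-level
// SpawnDefaults (max_attempts, retry_strategy, cancellation) instead of
// hard-coded fallbacks; explicit per-spawn values still win.
let child = ctx
    .spawn("process_item", &item, SpawnOptions::default())
    .await?;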
4 changes: 2 additions & 2 deletions src/lib.rs
@@ -107,8 +107,8 @@ pub use context::TaskContext;
pub use error::{ControlFlow, DurableError, DurableResult, TaskError, TaskResult};
pub use task::Task;
pub use types::{
CancellationPolicy, ClaimedTask, RetryStrategy, SpawnOptions, SpawnResult, TaskHandle,
WorkerOptions,
CancellationPolicy, ClaimedTask, RetryStrategy, SpawnDefaults, SpawnOptions, SpawnResult,
TaskHandle, WorkerOptions,
};
pub use worker::Worker;

15 changes: 15 additions & 0 deletions src/types.rs
@@ -315,6 +315,21 @@ impl<T> TaskHandle<T> {
}
}

/// Default settings for spawned tasks.
///
/// Groups the default `max_attempts`, `retry_strategy`, and `cancellation`
/// settings that are applied when spawning tasks (either from the client
/// or from within a task context).
#[derive(Debug, Clone, Default)]
pub struct SpawnDefaults {
/// Default max attempts for spawned tasks (default: 5)
pub max_attempts: u32,
/// Default retry strategy for spawned tasks
pub retry_strategy: Option<RetryStrategy>,
/// Default cancellation policy for spawned tasks
pub cancellation: Option<CancellationPolicy>,
}

/// Terminal status of a child task.
///
/// This enum represents the possible terminal states a subtask can be in
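Written out as a standalone sketch, the resolution order SpawnDefaults encodes (field names match this diff; the free-standing resolve function is illustrative and does not exist in the crate):

use durable::{CancellationPolicy, RetryStrategy, SpawnDefaults, SpawnOptions};

// Per-spawn values take precedence; unset fields fall back to the defaults.
fn resolve(
    options: &SpawnOptions,
    defaults: &SpawnDefaults,
) -> (u32, Option<RetryStrategy>, Option<CancellationPolicy>) {
    (
        options.max_attempts.unwrap_or(defaults.max_attempts),
        options
            .retry_strategy
            .clone()
            .or_else(|| defaults.retry_strategy.clone()),
        options
            .cancellation
            .clone()
            .or_else(|| defaults.cancellation.clone()),
    )
}

This mirrors what the client and context changes above do inline with unwrap_or and or_else.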
14 changes: 13 additions & 1 deletion src/worker.rs
@@ -11,7 +11,7 @@ use uuid::Uuid;
use crate::context::TaskContext;
use crate::error::{ControlFlow, TaskError, serialize_task_error};
use crate::task::TaskRegistry;
use crate::types::{ClaimedTask, ClaimedTaskRow, WorkerOptions};
use crate::types::{ClaimedTask, ClaimedTaskRow, SpawnDefaults, WorkerOptions};

/// Notifies the worker that the lease has been extended.
/// Used by TaskContext to reset warning/fatal timers.
@@ -67,6 +67,7 @@ impl Worker {
registry: Arc<RwLock<TaskRegistry<State>>>,
options: WorkerOptions,
state: State,
spawn_defaults: SpawnDefaults,
) -> Self
where
State: Clone + Send + Sync + 'static,
@@ -92,6 +93,7 @@ impl Worker {
worker_id,
shutdown_rx,
state,
spawn_defaults,
));

Self {
@@ -109,6 +111,7 @@
let _ = self.handle.await;
}

#[allow(clippy::too_many_arguments)]
async fn run_loop<State>(
pool: PgPool,
queue_name: String,
@@ -117,6 +120,7 @@ impl Worker {
worker_id: String,
mut shutdown_rx: broadcast::Receiver<()>,
state: State,
spawn_defaults: SpawnDefaults,
) where
State: Clone + Send + Sync + 'static,
{
@@ -190,6 +194,7 @@ impl Worker {
let registry = registry.clone();
let done_tx = done_tx.clone();
let state = state.clone();
let spawn_defaults = spawn_defaults.clone();

tokio::spawn(async move {
Self::execute_task(
@@ -200,6 +205,7 @@
claim_timeout,
fatal_on_lease_timeout,
state,
spawn_defaults,
).await;

drop(permit);
@@ -258,6 +264,7 @@ impl Worker {
Ok(tasks)
}

#[allow(clippy::too_many_arguments)]
async fn execute_task<State>(
pool: PgPool,
queue_name: String,
@@ -266,6 +273,7 @@
claim_timeout: Duration,
fatal_on_lease_timeout: bool,
state: State,
spawn_defaults: SpawnDefaults,
) where
State: Clone + Send + Sync + 'static,
{
@@ -295,11 +303,13 @@ impl Worker {
claim_timeout,
fatal_on_lease_timeout,
state,
spawn_defaults,
)
.instrument(span)
.await
}

#[allow(clippy::too_many_arguments)]
async fn execute_task_inner<State>(
pool: PgPool,
queue_name: String,
@@ -308,6 +318,7 @@
claim_timeout: Duration,
fatal_on_lease_timeout: bool,
state: State,
spawn_defaults: SpawnDefaults,
) where
State: Clone + Send + Sync + 'static,
{
@@ -333,6 +344,7 @@
lease_extender,
registry.clone(),
state.clone(),
spawn_defaults,
)
.await
{
25 changes: 22 additions & 3 deletions tests/crash_test.rs
@@ -1,7 +1,7 @@
#![allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]

use sqlx::{AssertSqlSafe, PgPool};
use std::time::Duration;
use std::time::{Duration, Instant};

mod common;

@@ -538,14 +538,33 @@ async fn test_slow_task_outlives_lease(pool: PgPool) -> sqlx::Result<()> {
// Wait for real time to pass the lease timeout
tokio::time::sleep(claim_timeout + Duration::from_secs(2)).await;

// Verify a new run was created (reclaim happened)
let run_count = count_runs_for_task(&pool, "crash_slow", spawn_result.task_id).await?;
// Second worker polls to reclaim the expired lease.
let worker2 = client
.start_worker(WorkerOptions {
poll_interval: Duration::from_millis(50),
claim_timeout: Duration::from_secs(10),
..Default::default()
})
.await
.unwrap();

// Verify a new run was created (reclaim happened), with bounded polling.
let deadline = Instant::now() + Duration::from_secs(5);
let mut run_count = 0;
while Instant::now() < deadline {
run_count = count_runs_for_task(&pool, "crash_slow", spawn_result.task_id).await?;
if run_count >= 2 {
break;
}
tokio::time::sleep(Duration::from_millis(50)).await;
}
assert!(
run_count >= 2,
"Should have at least 2 runs after lease expiration, got {}",
run_count
);

worker2.shutdown().await;
worker.shutdown().await;

Ok(())
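The bounded-polling pattern the updated test uses, extracted into a reusable sketch (the wait_for helper is illustrative and is not part of the test suite):

use std::time::{Duration, Instant};

// Poll a condition until it holds or the deadline passes, instead of relying
// on a single fixed sleep: fast when the condition is met early, bounded when
// it never is.
async fn wait_for<F, Fut>(timeout: Duration, mut check: F) -> bool
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = bool>,
{
    let deadline = Instant::now() + timeout;
    while Instant::now() < deadline {
        if check().await {
            return true;
        }
        tokio::time::sleep(Duration::from_millis(50)).await;
    }
    false
}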