diff --git a/.gitignore b/.gitignore index 7fb613a..656bd25 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,4 @@ target/ .env .bacon-locations .claude/settings.local.json +.mcp.json diff --git a/.sqlx/query-0816823ba7784f105d62568db8496a39a8470041474201d1fde1ba9def30b6d1.json b/.sqlx/query-0816823ba7784f105d62568db8496a39a8470041474201d1fde1ba9def30b6d1.json new file mode 100644 index 0000000..b9b1328 --- /dev/null +++ b/.sqlx/query-0816823ba7784f105d62568db8496a39a8470041474201d1fde1ba9def30b6d1.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "\n DELETE FROM job_executions\n WHERE id = $1 AND poller_instance_id = $2\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid" + ] + }, + "nullable": [] + }, + "hash": "0816823ba7784f105d62568db8496a39a8470041474201d1fde1ba9def30b6d1" +} diff --git a/.sqlx/query-3dd7b8b5f00c51b3a39204d396c7c218894f26cf6aede055be3ae92a43bcb0ba.json b/.sqlx/query-3dd7b8b5f00c51b3a39204d396c7c218894f26cf6aede055be3ae92a43bcb0ba.json new file mode 100644 index 0000000..0909227 --- /dev/null +++ b/.sqlx/query-3dd7b8b5f00c51b3a39204d396c7c218894f26cf6aede055be3ae92a43bcb0ba.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "UPDATE jobs SET cancelled_at = $2 WHERE id = $1 AND cancelled_at IS NULL", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz" + ] + }, + "nullable": [] + }, + "hash": "3dd7b8b5f00c51b3a39204d396c7c218894f26cf6aede055be3ae92a43bcb0ba" +} diff --git a/Cargo.lock b/Cargo.lock index 9f99069..e55ff81 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -352,6 +352,20 @@ dependencies = [ "syn", ] +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "der" 
version = "0.7.10" @@ -706,6 +720,12 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "hashbrown" version = "0.15.5" @@ -960,6 +980,7 @@ dependencies = [ "anyhow", "async-trait", "chrono", + "dashmap", "derive_builder", "es-entity", "futures", @@ -971,6 +992,7 @@ dependencies = [ "sqlx", "thiserror", "tokio", + "tokio-util", "tracing", "uuid", ] @@ -2199,6 +2221,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-util" +version = "0.7.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + [[package]] name = "tracing" version = "0.1.44" diff --git a/Cargo.toml b/Cargo.toml index 9b4cd1c..5c2823d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,6 +35,8 @@ tokio = { workspace = true } uuid = { workspace = true } rand = { workspace = true } +dashmap = { workspace = true } +tokio-util = { workspace = true } schemars = { workspace = true, optional = true } [dev-dependencies] @@ -58,4 +60,6 @@ tokio = { version = "1.50", features = ["rt-multi-thread", "macros"] } uuid = { version = "1.22", features = ["serde", "v7"] } futures = "0.3" rand = "0.10" +dashmap = "6.1" +tokio-util = "0.7" schemars = { version = "1.0", features = ["derive", "chrono04", "rust_decimal1"] } diff --git a/migrations/20250904065521_job_setup.sql b/migrations/20250904065521_job_setup.sql index f856f18..a4e6279 100644 --- a/migrations/20250904065521_job_setup.sql +++ b/migrations/20250904065521_job_setup.sql @@ -2,6 +2,7 @@ CREATE TABLE jobs ( id UUID PRIMARY KEY, 
unique_per_type BOOLEAN NOT NULL, job_type VARCHAR NOT NULL, + cancelled_at TIMESTAMPTZ, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); CREATE UNIQUE INDEX idx_unique_job_type ON jobs (job_type) WHERE unique_per_type = TRUE; diff --git a/src/cancellation_tokens.rs b/src/cancellation_tokens.rs new file mode 100644 index 0000000..063106c --- /dev/null +++ b/src/cancellation_tokens.rs @@ -0,0 +1,42 @@ +//! Shared store of per-job cancellation tokens. + +use dashmap::DashMap; +use tokio_util::sync::CancellationToken; + +use crate::JobId; + +/// Thread-safe store mapping running job IDs to their cancellation tokens. +/// +/// Tokens are inserted when a job is dispatched and removed when it completes +/// (or is cancelled). The notification router calls [`cancel`] when a +/// `job_cancel` event arrives; the poller sweep also calls it as a safety net. +pub(crate) struct CancellationTokens { + tokens: DashMap<JobId, CancellationToken>, +} + +impl CancellationTokens { + pub fn new() -> Self { + Self { + tokens: DashMap::new(), + } + } + + /// Insert a new token for `job_id` and return a clone the runner can observe. + pub fn insert(&self, job_id: JobId) -> CancellationToken { + let token = CancellationToken::new(); + self.tokens.insert(job_id, token.clone()); + token + } + + /// Remove the token without cancelling it (used on normal completion). + pub fn remove(&self, job_id: &JobId) { + self.tokens.remove(job_id); + } + + /// Cancel the token for `job_id`, signalling the running job to stop. + pub fn cancel(&self, job_id: &JobId) { + if let Some((_, token)) = self.tokens.remove(job_id) { + token.cancel(); + } + } +} diff --git a/src/config.rs b/src/config.rs index a20c0e5..877c567 100644 --- a/src/config.rs +++ b/src/config.rs @@ -24,6 +24,10 @@ pub struct JobPollerConfig { #[serde(default = "default_shutdown_timeout")] /// How long to wait for jobs to complete gracefully during shutdown before rescheduling them. 
pub shutdown_timeout: Duration, + #[serde_as(as = "serde_with::DurationSeconds")] + #[serde(default = "default_cancel_timeout")] + /// Grace period after cancellation is requested before force-aborting a running job. + pub cancel_timeout: Duration, } impl Default for JobPollerConfig { @@ -33,6 +37,7 @@ impl Default for JobPollerConfig { max_jobs_per_process: default_max_jobs_per_process(), min_jobs_per_process: default_min_jobs_per_process(), shutdown_timeout: default_shutdown_timeout(), + cancel_timeout: default_cancel_timeout(), } } } @@ -161,3 +166,7 @@ fn default_min_jobs_per_process() -> usize { fn default_shutdown_timeout() -> Duration { Duration::from_secs(5) } + +fn default_cancel_timeout() -> Duration { + Duration::from_secs(30) +} diff --git a/src/current.rs b/src/current.rs index 422c32b..827a5cf 100644 --- a/src/current.rs +++ b/src/current.rs @@ -3,6 +3,9 @@ use es_entity::clock::ClockHandle; use serde::{Serialize, de::DeserializeOwned}; use sqlx::PgPool; +use tokio_util::sync::CancellationToken; + +use std::sync::{Arc, Mutex}; use super::{JobId, error::JobError}; @@ -16,9 +19,12 @@ pub struct CurrentJob { tokio::sync::mpsc::Sender>, >, clock: ClockHandle, + result: Arc>>, + cancel_token: CancellationToken, } impl CurrentJob { + #[allow(clippy::too_many_arguments)] pub(super) fn new( id: JobId, attempt: u32, @@ -28,6 +34,8 @@ impl CurrentJob { tokio::sync::mpsc::Sender>, >, clock: ClockHandle, + result: Arc>>, + cancel_token: CancellationToken, ) -> Self { Self { id, @@ -36,6 +44,8 @@ impl CurrentJob { execution_state_json: execution_state, shutdown_rx, clock, + result, + cancel_token, } } @@ -122,6 +132,22 @@ impl CurrentJob { Ok(ret) } + /// Attach or update the result value for this job execution. + /// + /// The result is serialized to JSON and will be available to callers via + /// [`Jobs::await_completion`](crate::Jobs::await_completion). 
Each call + /// overwrites the previous value — the **last** value set before the job + /// completes (or errors) is what gets persisted. This allows incremental + /// progress updates; for example, a batch job can call `set_result` after + /// each chunk so that partial progress is preserved even on failure. + pub fn set_result<T: Serialize>(&self, result: &T) -> Result<(), JobError> { + let json = + serde_json::to_value(result).map_err(JobError::CouldNotSerializeExecutionState)?; + let mut guard = self.result.lock().expect("result mutex poisoned"); + *guard = Some(json); + Ok(()) + } + /// Wait for a shutdown signal. Returns `true` if shutdown was requested. /// /// Job runners can use this to detect when the application is shutting down @@ -155,4 +181,20 @@ impl CurrentJob { pub fn is_shutdown_requested(&mut self) -> bool { self.shutdown_rx.try_recv().is_ok() } + + /// Non-blocking check if cancellation has been requested for this job. + /// + /// Returns `true` once the job has been cancelled via [`Jobs::cancel_job`](crate::Jobs::cancel_job). + /// Job runners should check this periodically and return + /// [`JobCompletion::Cancelled`](crate::JobCompletion::Cancelled) when `true`. + pub fn cancellation_requested(&self) -> bool { + self.cancel_token.is_cancelled() + } + + /// Returns a future that resolves when cancellation is requested. + /// + /// Useful in `tokio::select!` branches for cooperative cancellation. 
+ pub async fn cancellation_notified(&self) { + self.cancel_token.cancelled().await; + } } diff --git a/src/dispatcher.rs b/src/dispatcher.rs index ced1fcf..f5abd0c 100644 --- a/src/dispatcher.rs +++ b/src/dispatcher.rs @@ -5,11 +5,14 @@ use futures::FutureExt; use serde_json::Value as JsonValue; use tracing::{Span, instrument}; -use std::{panic::AssertUnwindSafe, sync::Arc}; +use std::{ + panic::AssertUnwindSafe, + sync::{Arc, Mutex}, +}; use super::{ - JobId, current::CurrentJob, entity::RetryPolicy, error::JobError, repo::JobRepo, runner::*, - tracker::JobTracker, + JobId, cancellation_tokens::CancellationTokens, current::CurrentJob, entity::RetryPolicy, + error::JobError, repo::JobRepo, runner::*, tracker::JobTracker, }; #[derive(Debug)] @@ -24,11 +27,14 @@ pub(crate) struct JobDispatcher { retry_settings: RetrySettings, runner: Option>, tracker: Arc, + cancellation_tokens: Arc, + cancel_token: tokio_util::sync::CancellationToken, rescheduled: bool, instance_id: uuid::Uuid, clock: ClockHandle, } impl JobDispatcher { + #[allow(clippy::too_many_arguments)] pub fn new( repo: Arc, tracker: Arc, @@ -37,12 +43,16 @@ impl JobDispatcher { runner: Box, instance_id: uuid::Uuid, clock: ClockHandle, + cancellation_tokens: Arc, + cancel_token: tokio_util::sync::CancellationToken, ) -> Self { Self { repo, retry_settings, runner: Some(runner), tracker, + cancellation_tokens, + cancel_token, rescheduled: false, instance_id, clock, @@ -82,6 +92,7 @@ impl JobDispatcher { ) .expect("EventContext insert job data"); } + let result_holder = Arc::new(Mutex::new(None::)); let current_job = CurrentJob::new( polled_job.id, polled_job.attempt, @@ -89,28 +100,44 @@ impl JobDispatcher { polled_job.data_json, shutdown_rx, self.clock.clone(), + Arc::clone(&result_holder), + self.cancel_token.clone(), ); self.tracker.dispatch_job(); + let extract_result = + |holder: Arc>>| -> Option { + holder.lock().expect("result mutex poisoned").take() + }; match 
Self::dispatch_job(self.runner.take().expect("runner"), current_job).await { Err(e) => { span.record("conclusion", "Error"); - self.fail_job(job.id, e, polled_job.attempt).await? + let result = extract_result(result_holder); + self.fail_job(job.id, e, polled_job.attempt, result).await? } Ok(JobCompletion::Complete) => { span.record("conclusion", "Complete"); + let result = extract_result(result_holder); let mut op = self.repo.begin_op_with_clock(&self.clock).await?; - self.complete_job(&mut op, job.id).await?; + self.complete_job(&mut op, job.id, result).await?; + op.commit().await?; + } + Ok(JobCompletion::Cancelled) => { + span.record("conclusion", "Cancelled"); + let mut op = self.repo.begin_op_with_clock(&self.clock).await?; + self.cancel_running_job(&mut op, job.id).await?; op.commit().await?; } #[cfg(feature = "es-entity")] Ok(JobCompletion::CompleteWithOp(mut op)) => { span.record("conclusion", "CompleteWithOp"); - self.complete_job(&mut op, job.id).await?; + let result = extract_result(result_holder); + self.complete_job(&mut op, job.id, result).await?; op.commit().await?; } Ok(JobCompletion::CompleteWithTx(mut tx)) => { span.record("conclusion", "CompleteWithTx"); - self.complete_job(&mut tx, job.id).await?; + let result = extract_result(result_holder); + self.complete_job(&mut tx, job.id, result).await?; tx.commit().await?; } Ok(JobCompletion::RescheduleNow) => { @@ -241,7 +268,13 @@ impl JobDispatcher { error.message = tracing::field::Empty ) )] - async fn fail_job(&mut self, id: JobId, error: JobError, attempt: u32) -> Result<(), JobError> { + async fn fail_job( + &mut self, + id: JobId, + error: JobError, + attempt: u32, + result: Option, + ) -> Result<(), JobError> { let mut op = self.repo.begin_op_with_clock(&self.clock).await?; let mut job = self.repo.find_by_id(id).await?; @@ -256,7 +289,7 @@ impl JobDispatcher { let retry_policy = RetryPolicy::from(&self.retry_settings); if let Some((reschedule_at, next_attempt)) = - 
job.maybe_schedule_retry(self.clock.now(), attempt, &retry_policy, error_str) + job.maybe_schedule_retry(self.clock.now(), attempt, &retry_policy, error_str, result) { let exceeded_warn_attempts = self .retry_settings @@ -309,11 +342,35 @@ impl JobDispatcher { Ok(()) } - #[instrument(name = "job.complete_job", skip(self, op), fields(id = %id))] + #[instrument(name = "job.cancel_running_job", skip(self, op), fields(id = %id))] + async fn cancel_running_job( + &mut self, + op: &mut impl es_entity::AtomicOperation, + id: JobId, + ) -> Result<(), JobError> { + let mut job = self.repo.find_by_id(&id).await?; + sqlx::query!( + r#" + DELETE FROM job_executions + WHERE id = $1 AND poller_instance_id = $2 + "#, + id as JobId, + self.instance_id + ) + .execute(op.as_executor()) + .await?; + job.cancel_execution(); + self.repo.update_in_op(op, &mut job).await?; + self.cancellation_tokens.remove(&id); + Ok(()) + } + + #[instrument(name = "job.complete_job", skip(self, op, result), fields(id = %id))] async fn complete_job( &mut self, op: &mut impl es_entity::AtomicOperation, id: JobId, + result: Option, ) -> Result<(), JobError> { let mut job = self.repo.find_by_id(&id).await?; sqlx::query!( @@ -326,8 +383,9 @@ impl JobDispatcher { ) .execute(op.as_executor()) .await?; - job.complete_job(); + job.complete_job(result); self.repo.update_in_op(op, &mut job).await?; + self.cancellation_tokens.remove(&id); Ok(()) } diff --git a/src/entity.rs b/src/entity.rs index 72ef566..5050ccc 100644 --- a/src/entity.rs +++ b/src/entity.rs @@ -11,6 +11,40 @@ use es_entity::{context::TracingContext, *}; use crate::{JobId, error::JobError}; +/// Outcome returned by [`Jobs::await_completion`](crate::Jobs::await_completion), +/// carrying both the terminal state and an optional result value. 
+#[derive(Debug, Clone)] +pub struct JobCompletionResult { + state: JobTerminalState, + result: Option<serde_json::Value>, +} + +impl JobCompletionResult { + pub(crate) fn new(state: JobTerminalState, result: Option<serde_json::Value>) -> Self { + Self { state, result } + } + + /// The terminal state the job reached. + pub fn state(&self) -> JobTerminalState { + self.state + } + + /// Returns the raw JSON result value, if any. + pub fn result(&self) -> Option<&serde_json::Value> { + self.result.as_ref() + } + + /// Deserialize the result value into a typed struct. + pub fn typed_result<T: serde::de::DeserializeOwned>( + &self, + ) -> Result<Option<T>, serde_json::Error> { + match &self.result { + Some(v) => serde_json::from_value(v.clone()).map(Some), + None => Ok(None), + } + } +} + /// Terminal outcome of a job lifecycle. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum JobTerminalState { @@ -79,7 +113,11 @@ pub enum JobEvent { ExecutionErrored { error: String, }, - JobCompleted, + ExecutionCancelled, + JobCompleted { + #[serde(default)] + result: Option<serde_json::Value>, + }, Cancelled, AttemptCounterReset, } @@ -193,7 +231,7 @@ impl Job { self.events .iter_all() .rev() - .any(|event| matches!(event, JobEvent::JobCompleted | JobEvent::Cancelled)) + .any(|event| matches!(event, JobEvent::JobCompleted { .. } | JobEvent::Cancelled)) } /// Returns `true` if the job was cancelled. @@ -215,7 +253,7 @@ impl Job { let mut rev = self.events.iter_all().rev(); match rev.next()? { JobEvent::Cancelled => Some(JobTerminalState::Cancelled), - JobEvent::JobCompleted => match rev.next() { + JobEvent::JobCompleted { .. } => match rev.next() { Some(JobEvent::ExecutionErrored { .. }) => Some(JobTerminalState::Errored), _ => Some(JobTerminalState::Completed), }, @@ -223,6 +261,27 @@ impl Job { } } + /// Returns the result value attached to this job, if any. 
+ pub fn result(&self) -> Option<&serde_json::Value> { + self.events.iter_all().rev().find_map(|event| { + if let JobEvent::JobCompleted { result } = event { + result.as_ref() + } else { + None + } + }) + } + + /// Deserialize the result value into a typed struct. + pub fn typed_result<T: serde::de::DeserializeOwned>( + &self, + ) -> Result<Option<T>, serde_json::Error> { + match self.result() { + Some(v) => serde_json::from_value(v.clone()).map(Some), + None => Ok(None), + } + } + pub(crate) fn inject_tracing_parent(&self) { if let JobEvent::Initialized { tracing_context: Some(tracing_context), @@ -261,9 +320,9 @@ impl Job { }); } - pub(super) fn complete_job(&mut self) { + pub(super) fn complete_job(&mut self, result: Option<serde_json::Value>) { self.events.push(JobEvent::ExecutionCompleted); - self.events.push(JobEvent::JobCompleted); + self.events.push(JobEvent::JobCompleted { result }); } pub(crate) fn cancel(&mut self) -> es_entity::Idempotent<()> { @@ -274,6 +333,15 @@ impl Job { es_entity::Idempotent::Executed(()) } + /// Record cancellation of a running execution and mark the job as cancelled. + /// + /// Used by the dispatcher when a running job is cooperatively or forcibly + /// cancelled. Pushes both `ExecutionCancelled` and `Cancelled` events. 
+ pub(super) fn cancel_execution(&mut self) { + self.events.push(JobEvent::ExecutionCancelled); + self.events.push(JobEvent::Cancelled); + } + pub(super) fn schedule_retry( &mut self, error: String, @@ -287,9 +355,9 @@ impl Job { }); } - pub(super) fn error_job(&mut self, error: String) { + pub(super) fn error_job(&mut self, error: String, result: Option) { self.events.push(JobEvent::ExecutionErrored { error }); - self.events.push(JobEvent::JobCompleted); + self.events.push(JobEvent::JobCompleted { result }); } pub(super) fn maybe_schedule_retry( @@ -298,6 +366,7 @@ impl Job { attempt: u32, retry_policy: &RetryPolicy, error: String, + result: Option, ) -> Option<(DateTime, u32)> { let mut current_attempt = attempt.max(1); if self @@ -312,7 +381,7 @@ impl Job { let next_attempt = current_attempt.saturating_add(1); let max_attempts = retry_policy.max_attempts.unwrap_or(u32::MAX); if next_attempt > max_attempts { - self.error_job(error); + self.error_job(error, result); return None; } @@ -362,7 +431,8 @@ impl TryFromEvents for Job { JobEvent::ExecutionCompleted => {} JobEvent::ExecutionAborted { .. } => {} JobEvent::ExecutionErrored { .. } => {} - JobEvent::JobCompleted => {} + JobEvent::ExecutionCancelled => {} + JobEvent::JobCompleted { .. 
} => {} JobEvent::Cancelled => {} JobEvent::AttemptCounterReset => {} } @@ -532,7 +602,7 @@ mod tests { let retry_policy = build_retry_policy(Some(3)); let (_, next_attempt) = job - .maybe_schedule_retry(Clock::now(), 1, &retry_policy, "boom".to_string()) + .maybe_schedule_retry(Clock::now(), 1, &retry_policy, "boom".to_string(), None) .expect("retry expected"); assert_eq!(next_attempt, 2); @@ -565,7 +635,7 @@ mod tests { let retry_policy = build_retry_policy(Some(3)); let (_, next_attempt) = job - .maybe_schedule_retry(Clock::now(), 0, &retry_policy, "boom".to_string()) + .maybe_schedule_retry(Clock::now(), 0, &retry_policy, "boom".to_string(), None) .expect("retry expected when attempt starts at zero"); assert_eq!(next_attempt, 2); @@ -602,7 +672,7 @@ mod tests { let retry_policy = build_retry_policy(Some(2)); assert!( - job.maybe_schedule_retry(Clock::now(), 2, &retry_policy, "boom".to_string()) + job.maybe_schedule_retry(Clock::now(), 2, &retry_policy, "boom".to_string(), None) .is_none(), "should stop retrying when attempts exhausted" ); @@ -612,7 +682,7 @@ mod tests { events[events.len() - 2], JobEvent::ExecutionErrored { .. } )); - assert!(matches!(events.last(), Some(JobEvent::JobCompleted))); + assert!(matches!(events.last(), Some(JobEvent::JobCompleted { .. 
}))); } #[test] @@ -637,7 +707,7 @@ mod tests { let retry_policy = build_retry_policy(Some(5)); let (_, next_attempt) = job - .maybe_schedule_retry(Clock::now(), 2, &retry_policy, "boom".to_string()) + .maybe_schedule_retry(Clock::now(), 2, &retry_policy, "boom".to_string(), None) .expect("retry expected"); assert_eq!( @@ -681,7 +751,13 @@ mod tests { let retry_policy = build_retry_policy(Some(3)); let (_, next_attempt) = job - .maybe_schedule_retry(Clock::now(), 2, &retry_policy, "second failure".to_string()) + .maybe_schedule_retry( + Clock::now(), + 2, + &retry_policy, + "second failure".to_string(), + None, + ) .expect("final retry should still be scheduled"); assert_eq!(next_attempt, 3); @@ -720,7 +796,13 @@ mod tests { let retry_policy = build_retry_policy(Some(3)); let (_, next_attempt) = job - .maybe_schedule_retry(Clock::now(), 3, &retry_policy, "third failure".to_string()) + .maybe_schedule_retry( + Clock::now(), + 3, + &retry_policy, + "third failure".to_string(), + None, + ) .expect("a healthy gap should reset attempt even at limit"); assert_eq!(next_attempt, 2); @@ -760,7 +842,13 @@ mod tests { let retry_policy = build_retry_policy(None); let (_, next_attempt) = job - .maybe_schedule_retry(Clock::now(), attempt, &retry_policy, "overflow".to_string()) + .maybe_schedule_retry( + Clock::now(), + attempt, + &retry_policy, + "overflow".to_string(), + None, + ) .expect("unbounded retries should permit another schedule"); assert_eq!(next_attempt, u32::MAX); diff --git a/src/lib.rs b/src/lib.rs index 8f073c0..8e4dff8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -205,6 +205,7 @@ #![cfg_attr(feature = "fail-on-warnings", deny(clippy::all))] #![forbid(unsafe_code)] +mod cancellation_tokens; mod config; mod current; mod dispatcher; @@ -228,7 +229,7 @@ use std::sync::{Arc, Mutex}; pub use config::*; pub use current::*; -pub use entity::{Job, JobTerminalState, JobType}; +pub use entity::{Job, JobCompletionResult, JobTerminalState, JobType}; pub use 
es_entity::clock::{ ArtificialClockConfig, ArtificialMode, Clock, ClockController, ClockHandle, }; @@ -237,6 +238,7 @@ pub use registry::*; pub use runner::*; pub use spawner::*; +use cancellation_tokens::*; use error::*; use notification_router::*; use poller::*; @@ -245,6 +247,17 @@ use tracker::*; es_entity::entity_id! { JobId } +/// Outcome of a [`Jobs::cancel_job`] call. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CancelResult { + /// The job was successfully cancelled (pending or running). + Cancelled, + /// The job had already reached a terminal state. + AlreadyCompleted, + /// No job with the given id was found. + NotFound, +} + #[derive(Clone)] /// Primary entry point for interacting with the Job crate. Provides APIs to register job /// handlers, manage configuration, and control scheduling and execution. @@ -253,6 +266,7 @@ pub struct Jobs { repo: Arc, registry: Arc>>, router: Arc, + cancellation_tokens: Arc, poller_handle: Option>, clock: ClockHandle, } @@ -283,12 +297,14 @@ impl Jobs { let repo = Arc::new(JobRepo::new(&pool)); let registry = Arc::new(Mutex::new(Some(JobRegistry::new()))); let router = Arc::new(JobNotificationRouter::new(&pool, Arc::clone(&repo))); + let cancellation_tokens = Arc::new(CancellationTokens::new()); let clock = config.clock.clone(); Ok(Self { repo, config, registry, router, + cancellation_tokens, poller_handle: None, clock, }) @@ -456,13 +472,20 @@ impl Jobs { Arc::clone(&self.repo), registry, Arc::clone(&tracker), + Arc::clone(&self.cancellation_tokens), self.clock.clone(), ); let job_types = poller.registered_job_types(); - let (listener_handle, waiter_handle) = - self.router.start(Arc::clone(&tracker), job_types).await?; + let (listener_handle, waiter_handle) = self + .router + .start( + Arc::clone(&tracker), + job_types, + Arc::clone(&self.cancellation_tokens), + ) + .await?; let poller_handle = poller.start(listener_handle, waiter_handle); self.poller_handle = Some(Arc::new(poller_handle)); @@ -498,33 
+521,72 @@ impl Jobs { Ok(self.repo.find_by_id(id).await?) } - /// Cancel a pending job, removing it from the execution queue. + /// Cancel a job, whether it is pending or currently running. + /// + /// - **Pending jobs** are cancelled atomically: the execution row is deleted + /// and cancel events are recorded in the same transaction. + /// - **Running jobs** have their `cancelled_at` column set, which triggers a + /// PG NOTIFY that routes through the unified router to signal the job's + /// cancellation token. The runner can observe this via + /// [`CurrentJob::cancellation_requested`] and return + /// [`JobCompletion::Cancelled`]. If the runner does not cooperate within + /// the configured `cancel_timeout`, the task is force-aborted. /// /// This operation is idempotent — calling it on an already cancelled or - /// completed job is a no-op. If the job exists but is currently running - /// (not pending), returns [`JobError::CannotCancelJob`]. + /// completed job returns [`CancelResult::AlreadyCompleted`]. #[instrument(name = "job.cancel_job", skip(self))] - pub async fn cancel_job(&self, id: JobId) -> Result<(), JobError> { - let mut op = self.repo.begin_op_with_clock(&self.clock).await?; - let mut job = self.repo.find_by_id(id).await?; + pub async fn cancel_job(&self, id: JobId) -> Result { + let job = match self.repo.find_by_id(id).await { + Ok(j) => j, + Err(_) => return Ok(CancelResult::NotFound), + }; - if job.cancel().did_execute() { - let result = sqlx::query!( - r#"DELETE FROM job_executions WHERE id = $1 AND state = 'pending'"#, - id as JobId, - ) - .execute(op.as_executor()) - .await?; + // Already in a terminal state — nothing to do. + if job.completed() { + return Ok(CancelResult::AlreadyCompleted); + } - if result.rows_affected() == 0 { - return Err(JobError::CannotCancelJob); - } + let mut op = self.repo.begin_op_with_clock(&self.clock).await?; + // Try to cancel a pending execution atomically. 
+ let result = sqlx::query!( + r#"DELETE FROM job_executions WHERE id = $1 AND state = 'pending'"#, + id as JobId, + ) + .execute(op.as_executor()) + .await?; + + if result.rows_affected() > 0 { + // Pending job — record cancel events in the same transaction. + let mut job = job; + let _ = job.cancel(); self.repo.update_in_op(&mut op, &mut job).await?; op.commit().await?; + return Ok(CancelResult::Cancelled); } - Ok(()) + // Job is running — set cancelled_at and send application-level NOTIFY + // to signal the cancellation token via the unified router. + let now = self.clock.now(); + sqlx::query!( + r#"UPDATE jobs SET cancelled_at = $2 WHERE id = $1 AND cancelled_at IS NULL"#, + id as JobId, + now, + ) + .execute(op.as_executor()) + .await?; + + let payload = serde_json::json!({ + "type": "job_cancel", + "execution_id": id.to_string() + }); + sqlx::query("SELECT pg_notify('job_events', $1)") + .bind(payload.to_string()) + .execute(op.as_executor()) + .await?; + + op.commit().await?; + Ok(CancelResult::Cancelled) } /// Returns a reference to the clock used by this job service. @@ -533,7 +595,8 @@ impl Jobs { } /// Block until the given job reaches a terminal state (completed, errored, or - /// cancelled) and return the outcome. + /// cancelled) and return the outcome together with any result value the + /// runner attached via [`CurrentJob::set_result`]. /// /// # Errors /// @@ -541,12 +604,17 @@ impl Jobs { /// Returns [`JobError::AwaitCompletionShutdown`] if the notification channel is /// dropped (e.g., during shutdown) before delivering the terminal state. #[instrument(name = "job.await_completion", skip(self))] - pub async fn await_completion(&self, id: JobId) -> Result { + pub async fn await_completion(&self, id: JobId) -> Result { // Fail fast if the job doesn't exist — avoids a 5-minute silent hang // in the waiter manager for a JobId that will never resolve. 
self.find(id).await?; let rx = self.router.wait_for_terminal(id); - rx.await.map_err(|_| JobError::AwaitCompletionShutdown(id)) + let state = rx + .await + .map_err(|_| JobError::AwaitCompletionShutdown(id))?; + // Load job to retrieve any result value set by the runner + let job = self.find(id).await?; + Ok(JobCompletionResult::new(state, job.result().cloned())) } /// Gracefully shut down the job poller. diff --git a/src/notification_router.rs b/src/notification_router.rs index 2217cd5..c61df1d 100644 --- a/src/notification_router.rs +++ b/src/notification_router.rs @@ -3,6 +3,7 @@ use std::sync::{Arc, OnceLock}; use std::time::Duration; use crate::JobId; +use crate::cancellation_tokens::CancellationTokens; use crate::entity::{JobTerminalState, JobType}; use crate::handle::OwnedTaskHandle; use crate::repo::JobRepo; @@ -16,6 +17,7 @@ use tokio::sync::{broadcast, mpsc, oneshot}; enum JobNotification { ExecutionReady { job_type: String }, JobTerminal { job_id: JobId }, + JobCancel { execution_id: String }, } type WaiterRegistration = (JobId, oneshot::Sender); @@ -54,13 +56,16 @@ impl JobNotificationRouter { &self, tracker: Arc, job_types: Vec, + cancellation_tokens: Arc, ) -> Result<(OwnedTaskHandle, OwnedTaskHandle), sqlx::Error> { let (register_tx, register_rx) = mpsc::unbounded_channel(); self.register_tx .set(register_tx) .expect("router started more than once"); - let listener_handle = self.start_listener(tracker, job_types).await?; + let listener_handle = self + .start_listener(tracker, job_types, cancellation_tokens) + .await?; let waiter_handle = Self::start_waiter_manager( register_rx, self.terminal_tx.subscribe(), @@ -74,6 +79,7 @@ impl JobNotificationRouter { &self, tracker: Arc, job_types: Vec, + cancellation_tokens: Arc, ) -> Result { let mut listener = PgListener::connect_with(&self.pool).await?; listener.listen("job_events").await?; @@ -94,6 +100,11 @@ impl JobNotificationRouter { Ok(JobNotification::JobTerminal { job_id }) => { let _ = 
terminal_tx.send(job_id); } + Ok(JobNotification::JobCancel { execution_id }) => { + if let Ok(uuid) = execution_id.parse::() { + cancellation_tokens.cancel(&JobId::from(uuid)); + } + } Err(e) => { tracing::warn!( error = %e, diff --git a/src/poller.rs b/src/poller.rs index 1e6706c..4ebf1f7 100644 --- a/src/poller.rs +++ b/src/poller.rs @@ -13,8 +13,9 @@ use std::{ }; use super::{ - JobId, config::JobPollerConfig, dispatcher::*, entity::JobType, error::JobError, - handle::OwnedTaskHandle, registry::JobRegistry, repo::JobRepo, tracker::JobTracker, + JobId, cancellation_tokens::CancellationTokens, config::JobPollerConfig, dispatcher::*, + entity::JobType, error::JobError, handle::OwnedTaskHandle, registry::JobRegistry, + repo::JobRepo, tracker::JobTracker, }; /// Helper macro to spawn tasks with optional names based on the tokio-task-names feature @@ -41,6 +42,7 @@ pub(crate) struct JobPoller { repo: Arc, registry: JobRegistry, tracker: Arc, + cancellation_tokens: Arc, instance_id: uuid::Uuid, shutdown_tx: tokio::sync::broadcast::Sender< tokio::sync::mpsc::Sender>, @@ -76,6 +78,7 @@ impl JobPoller { repo: Arc, registry: JobRegistry, tracker: Arc, + cancellation_tokens: Arc, clock: ClockHandle, ) -> Self { let (shutdown_tx, _) = tokio::sync::broadcast::channel::< @@ -86,6 +89,7 @@ impl JobPoller { repo, config, registry, + cancellation_tokens, instance_id: uuid::Uuid::now_v7(), shutdown_tx, clock, @@ -254,6 +258,7 @@ impl JobPoller { let pool = self.repo.pool().clone(); let instance_id = self.instance_id; let clock = self.clock.clone(); + let cancellation_tokens = Arc::clone(&self.cancellation_tokens); OwnedTaskHandle::new(spawn_named_task!( "job-poller-keep-alive-handler", async move { @@ -291,6 +296,25 @@ impl JobPoller { Duration::from_millis(50 << failures) } }; + + // Safety net: cancel tokens for jobs whose cancelled_at is set + // but the NOTIFY may have been missed. 
+ if let Ok(rows) = sqlx::query_scalar::<_, uuid::Uuid>( + "SELECT je.id FROM job_executions je \ + JOIN jobs j ON je.id = j.id \ + WHERE je.poller_instance_id = $1 \ + AND je.state = 'running' \ + AND j.cancelled_at IS NOT NULL", + ) + .bind(instance_id) + .fetch_all(&pool) + .await + { + for uuid in rows { + cancellation_tokens.cancel(&JobId::from(uuid)); + } + } + drop(_guard); clock.sleep(timeout).await; } @@ -321,6 +345,11 @@ impl JobPoller { span.record("now", tracing::field::display(clock.now())); span.record("poller_id", tracing::field::display(instance_id)); + // Create a cancellation token for this job + let cancel_token = self.cancellation_tokens.insert(polled_job.id); + let cancel_monitor_token = cancel_token.clone(); + let cancellation_tokens = Arc::clone(&self.cancellation_tokens); + let shutdown_rx = self.shutdown_tx.subscribe(); let job_id = job.id; let job_type = job.job_type.clone(); @@ -341,6 +370,8 @@ impl JobPoller { runner, instance_id, clock, + cancellation_tokens, + cancel_token, ) .execute_job(polled_job, shutdown_rx) .await @@ -351,6 +382,10 @@ impl JobPoller { let mut shutdown_rx = self.shutdown_tx.subscribe(); let shutdown_timeout = self.config.shutdown_timeout; + let cancel_timeout = self.config.cancel_timeout; + let cancel_monitor_repo = Arc::clone(&self.repo); + let cancel_monitor_clock = self.clock.clone(); + let cancel_monitor_tokens = Arc::clone(&self.cancellation_tokens); #[cfg_attr( not(all(feature = "tokio-task-names", tokio_unstable)), allow(unused_variables) @@ -364,7 +399,7 @@ impl JobPoller { tokio::select! 
{ _ = &mut job_handle => { - // Job completed - no need for shutdown coordination + // Job completed - no need for shutdown or cancel coordination } Ok(shutdown_notifier) = shutdown_rx.recv() => { let (send, recv) = tokio::sync::oneshot::channel(); @@ -404,6 +439,30 @@ impl JobPoller { ) ).await; } + _ = cancel_monitor_token.cancelled() => { + // Cancellation requested — give the runner a grace period + async { + if tokio::time::timeout(cancel_timeout, &mut job_handle).await.is_err() { + tracing::warn!("Job exceeded cancel timeout, force-aborting"); + job_handle.abort(); + let _ = job_handle.await; + // Clean up: delete execution, record cancel events + let _ = force_cancel_cleanup( + &cancel_monitor_repo, + job_id, + instance_id, + &cancel_monitor_clock, + ).await; + } + cancel_monitor_tokens.remove(&job_id); + }.instrument(tracing::info_span!( + parent: None, + "job.cancel_coordination", + job_id = %job_id, + job_type = %job_type, + ) + ).await; + } } }); @@ -411,6 +470,39 @@ impl JobPoller { } } +/// Clean up a force-cancelled job: delete its execution row and record cancel events. +/// +/// Called by the monitor task after aborting a job that didn't finish within +/// the cancel timeout. 
+#[instrument(name = "job.force_cancel_cleanup", skip(repo, clock), fields(job_id = %job_id))] +async fn force_cancel_cleanup( + repo: &JobRepo, + job_id: JobId, + instance_id: uuid::Uuid, + clock: &ClockHandle, +) -> Result<(), JobError> { + let mut op = repo.begin_op_with_clock(clock).await?; + let mut job = repo.find_by_id(job_id).await?; + + sqlx::query!( + r#" + DELETE FROM job_executions + WHERE id = $1 AND poller_instance_id = $2 + "#, + job_id as JobId, + instance_id + ) + .execute(op.as_executor()) + .await?; + + job.cancel_execution(); + repo.update_in_op(&mut op, &mut job).await?; + op.commit().await?; + + tracing::info!("Force-cancelled job cleaned up"); + Ok(()) +} + #[instrument(name = "job.poll_jobs", level = "debug", skip(pool, supported_job_types, clock), fields(n_jobs_to_poll, instance_id = %instance_id, n_jobs_found = tracing::field::Empty), err)] async fn poll_jobs( pool: &PgPool, diff --git a/src/runner.rs b/src/runner.rs index 4938579..f2b677c 100644 --- a/src/runner.rs +++ b/src/runner.rs @@ -52,6 +52,8 @@ pub trait JobInitializer: Send + Sync + 'static { pub enum JobCompletion { /// Job finished successfully; mark the record as completed. Complete, + /// Job was cancelled cooperatively; mark the record as cancelled. + Cancelled, #[cfg(feature = "es-entity")] /// Job finished and returns an `EsEntity` operation that the job service will commit. 
CompleteWithOp(es_entity::DbOp<'static>), diff --git a/tests/job.rs b/tests/job.rs index 149bb77..cf35619 100644 --- a/tests/job.rs +++ b/tests/job.rs @@ -3,9 +3,9 @@ mod helpers; use async_trait::async_trait; use chrono::{DateTime, Utc}; use job::{ - ArtificialClockConfig, ClockHandle, CurrentJob, Job, JobCompletion, JobId, JobInitializer, - JobRunner, JobSpawner, JobSpec, JobSvcConfig, JobTerminalState, JobType, Jobs, RetrySettings, - error::JobError, + ArtificialClockConfig, CancelResult, ClockHandle, CurrentJob, Job, JobCompletion, + JobCompletionResult, JobId, JobInitializer, JobRunner, JobSpawner, JobSpec, JobSvcConfig, + JobTerminalState, JobType, Jobs, RetrySettings, error::JobError, }; use serde::{Deserialize, Serialize}; use std::sync::Arc; @@ -671,7 +671,8 @@ async fn test_cancel_pending_job() -> anyhow::Result<()> { .await?; // Cancel the pending job - jobs.cancel_job(job_id).await?; + let result = jobs.cancel_job(job_id).await?; + assert_eq!(result, CancelResult::Cancelled); // Verify it's findable as a cancelled/completed entity let found = jobs.find(job_id).await?; @@ -682,7 +683,7 @@ async fn test_cancel_pending_job() -> anyhow::Result<()> { } #[tokio::test] -async fn test_cancel_running_job_fails() -> anyhow::Result<()> { +async fn test_cancel_running_job_succeeds() -> anyhow::Result<()> { let pool = helpers::init_pool().await?; let config = JobSvcConfig::builder() .pool(pool) @@ -722,15 +723,11 @@ async fn test_cancel_running_job_fails() -> anyhow::Result<()> { assert!(attempts < 100, "Job never started"); } - // Cancel on a running job should fail - let result = jobs.cancel_job(job_id).await; - assert!( - matches!(result, Err(JobError::CannotCancelJob)), - "Cancelling a running job should return JobNotPending, got err: {:?}", - result.err(), - ); + // Cancel on a running job should succeed (sets cancelled_at) + let result = jobs.cancel_job(job_id).await?; + assert_eq!(result, CancelResult::Cancelled); - // Release the job so it completes normally 
+ // Release the job so the force-cancel monitor can clean up release.notify_one(); // Wait for completion @@ -782,8 +779,9 @@ async fn test_cancel_already_completed_job_is_idempotent() -> anyhow::Result<()> assert!(attempts < 100, "Job never completed"); } - // Cancel on an already completed job is a no-op - jobs.cancel_job(job_id).await?; + // Cancel on an already completed job returns AlreadyCompleted + let result = jobs.cancel_job(job_id).await?; + assert_eq!(result, CancelResult::AlreadyCompleted); let job = jobs.find(job_id).await?; assert!( @@ -860,8 +858,8 @@ async fn test_await_completion_on_success() -> anyhow::Result<()> { let jobs_clone = jobs.clone(); let handle = tokio::spawn(async move { jobs_clone.await_completion(job_id).await }); - let state = handle.await??; - assert_eq!(state, JobTerminalState::Completed); + let outcome = handle.await??; + assert_eq!(outcome.state(), JobTerminalState::Completed); Ok(()) } @@ -884,8 +882,8 @@ async fn test_await_completion_on_error() -> anyhow::Result<()> { let jobs_clone = jobs.clone(); let handle = tokio::spawn(async move { jobs_clone.await_completion(job_id).await }); - let state = handle.await??; - assert_eq!(state, JobTerminalState::Errored); + let outcome = handle.await??; + assert_eq!(outcome.state(), JobTerminalState::Errored); Ok(()) } @@ -918,8 +916,8 @@ async fn test_await_completion_on_cancel() -> anyhow::Result<()> { tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; jobs.cancel_job(job_id).await?; - let state = handle.await??; - assert_eq!(state, JobTerminalState::Cancelled); + let outcome = handle.await??; + assert_eq!(outcome.state(), JobTerminalState::Cancelled); Ok(()) } @@ -956,8 +954,349 @@ async fn test_await_completion_already_completed() -> anyhow::Result<()> { } // Now call await_completion — should return immediately - let state = jobs.await_completion(job_id).await?; - assert_eq!(state, JobTerminalState::Completed); + let outcome = jobs.await_completion(job_id).await?; + 
assert_eq!(outcome.state(), JobTerminalState::Completed); + + Ok(()) +} + +// -- Result passing tests -- + +#[derive(Debug, Serialize, Deserialize, PartialEq)] +struct MyResult { + value: i32, +} + +#[derive(Debug, Serialize, Deserialize)] +struct ResultJobConfig; + +struct ResultJobInitializer; + +impl JobInitializer for ResultJobInitializer { + type Config = ResultJobConfig; + + fn job_type(&self) -> JobType { + JobType::new("result-job") + } + + fn init( + &self, + _job: &Job, + _: JobSpawner, + ) -> Result, Box> { + Ok(Box::new(ResultJobRunner)) + } +} + +struct ResultJobRunner; + +#[async_trait] +impl JobRunner for ResultJobRunner { + async fn run( + &self, + current_job: CurrentJob, + ) -> Result> { + // Incremental updates — only the last value is persisted + current_job.set_result(&MyResult { value: 1 })?; + current_job.set_result(&MyResult { value: 42 })?; + Ok(JobCompletion::Complete) + } +} + +#[tokio::test] +async fn test_await_completion_returns_result() -> anyhow::Result<()> { + let pool = helpers::init_pool().await?; + let config = JobSvcConfig::builder() + .pool(pool) + .build() + .expect("Failed to build JobsConfig"); + + let mut jobs = Jobs::init(config).await?; + let spawner = jobs.add_initializer(ResultJobInitializer); + jobs.start_poll().await?; + + let job_id = JobId::new(); + spawner.spawn(job_id, ResultJobConfig).await?; + + let outcome = jobs.await_completion(job_id).await?; + assert_eq!(outcome.state(), JobTerminalState::Completed); + let result: MyResult = outcome + .typed_result() + .expect("deserialize result") + .expect("result should be Some"); + assert_eq!(result, MyResult { value: 42 }); + + Ok(()) +} + +struct PartialResultThenErrorInitializer; + +impl JobInitializer for PartialResultThenErrorInitializer { + type Config = ResultJobConfig; + + fn job_type(&self) -> JobType { + JobType::new("partial-result-error-job") + } + + fn retry_on_error_settings(&self) -> RetrySettings { + RetrySettings { + n_attempts: Some(1), + 
..Default::default() + } + } + + fn init( + &self, + _job: &Job, + _: JobSpawner, + ) -> Result, Box> { + Ok(Box::new(PartialResultThenErrorRunner)) + } +} + +struct PartialResultThenErrorRunner; + +#[async_trait] +impl JobRunner for PartialResultThenErrorRunner { + async fn run( + &self, + current_job: CurrentJob, + ) -> Result> { + // Simulate processing 50 items then failing — partial progress preserved + current_job.set_result(&MyResult { value: 50 })?; + current_job.set_result(&MyResult { value: 99 })?; + Err("intentional failure after setting result".into()) + } +} + +#[tokio::test] +async fn test_await_completion_returns_partial_result_on_error() -> anyhow::Result<()> { + let pool = helpers::init_pool().await?; + let config = JobSvcConfig::builder() + .pool(pool) + .build() + .expect("Failed to build JobsConfig"); + + let mut jobs = Jobs::init(config).await?; + let spawner = jobs.add_initializer(PartialResultThenErrorInitializer); + jobs.start_poll().await?; + + let job_id = JobId::new(); + spawner.spawn(job_id, ResultJobConfig).await?; + + let outcome = jobs.await_completion(job_id).await?; + assert_eq!(outcome.state(), JobTerminalState::Errored); + let result: MyResult = outcome + .typed_result() + .expect("deserialize result") + .expect("partial result should be Some"); + assert_eq!(result, MyResult { value: 99 }); + + Ok(()) +} + +struct NoResultJobInitializer; + +impl JobInitializer for NoResultJobInitializer { + type Config = ResultJobConfig; + + fn job_type(&self) -> JobType { + JobType::new("no-result-job") + } + + fn init( + &self, + _job: &Job, + _: JobSpawner, + ) -> Result, Box> { + Ok(Box::new(NoResultJobRunner)) + } +} + +struct NoResultJobRunner; + +#[async_trait] +impl JobRunner for NoResultJobRunner { + async fn run( + &self, + _current_job: CurrentJob, + ) -> Result> { + Ok(JobCompletion::Complete) + } +} + +#[tokio::test] +async fn test_await_completion_no_result() -> anyhow::Result<()> { + let pool = helpers::init_pool().await?; + let 
config = JobSvcConfig::builder() + .pool(pool) + .build() + .expect("Failed to build JobsConfig"); + + let mut jobs = Jobs::init(config).await?; + let spawner = jobs.add_initializer(NoResultJobInitializer); + jobs.start_poll().await?; + + let job_id = JobId::new(); + spawner.spawn(job_id, ResultJobConfig).await?; + + let outcome = jobs.await_completion(job_id).await?; + assert_eq!(outcome.state(), JobTerminalState::Completed); + assert!(outcome.result().is_none()); + + Ok(()) +} + +// -- Incremental set_result tests -- + +#[derive(Debug, Serialize, Deserialize, PartialEq)] +struct BatchProgress { + processed: u32, + total: u32, +} + +struct IncrementalResultInitializer; + +impl JobInitializer for IncrementalResultInitializer { + type Config = ResultJobConfig; + + fn job_type(&self) -> JobType { + JobType::new("incremental-result-job") + } + + fn init( + &self, + _job: &Job, + _: JobSpawner, + ) -> Result, Box> { + Ok(Box::new(IncrementalResultRunner)) + } +} + +struct IncrementalResultRunner; + +#[async_trait] +impl JobRunner for IncrementalResultRunner { + async fn run( + &self, + current_job: CurrentJob, + ) -> Result> { + let total = 5; + for i in 1..=total { + current_job.set_result(&BatchProgress { + processed: i, + total, + })?; + } + Ok(JobCompletion::Complete) + } +} + +#[tokio::test] +async fn test_set_result_multiple_calls_keeps_last() -> anyhow::Result<()> { + let pool = helpers::init_pool().await?; + let config = JobSvcConfig::builder() + .pool(pool) + .build() + .expect("Failed to build JobsConfig"); + + let mut jobs = Jobs::init(config).await?; + let spawner = jobs.add_initializer(IncrementalResultInitializer); + jobs.start_poll().await?; + + let job_id = JobId::new(); + spawner.spawn(job_id, ResultJobConfig).await?; + + let outcome = jobs.await_completion(job_id).await?; + assert_eq!(outcome.state(), JobTerminalState::Completed); + let progress: BatchProgress = outcome + .typed_result() + .expect("deserialize result") + .expect("result should be 
Some"); + assert_eq!( + progress, + BatchProgress { + processed: 5, + total: 5 + } + ); + + Ok(()) +} + +struct IncrementalResultThenErrorInitializer; + +impl JobInitializer for IncrementalResultThenErrorInitializer { + type Config = ResultJobConfig; + + fn job_type(&self) -> JobType { + JobType::new("incremental-error-result-job") + } + + fn retry_on_error_settings(&self) -> RetrySettings { + RetrySettings { + n_attempts: Some(1), + ..Default::default() + } + } + + fn init( + &self, + _job: &Job, + _: JobSpawner, + ) -> Result, Box> { + Ok(Box::new(IncrementalResultThenErrorRunner)) + } +} + +struct IncrementalResultThenErrorRunner; + +#[async_trait] +impl JobRunner for IncrementalResultThenErrorRunner { + async fn run( + &self, + current_job: CurrentJob, + ) -> Result> { + let total = 100; + for i in 1..=50 { + current_job.set_result(&BatchProgress { + processed: i, + total, + })?; + } + Err("failed at item 51".into()) + } +} + +#[tokio::test] +async fn test_set_result_partial_progress_preserved_on_error() -> anyhow::Result<()> { + let pool = helpers::init_pool().await?; + let config = JobSvcConfig::builder() + .pool(pool) + .build() + .expect("Failed to build JobsConfig"); + + let mut jobs = Jobs::init(config).await?; + let spawner = jobs.add_initializer(IncrementalResultThenErrorInitializer); + jobs.start_poll().await?; + + let job_id = JobId::new(); + spawner.spawn(job_id, ResultJobConfig).await?; + + let outcome = jobs.await_completion(job_id).await?; + assert_eq!(outcome.state(), JobTerminalState::Errored); + let progress: BatchProgress = outcome + .typed_result() + .expect("deserialize result") + .expect("partial result should be Some"); + assert_eq!( + progress, + BatchProgress { + processed: 50, + total: 100 + }, + "partial progress from before the error should be preserved" + ); Ok(()) }