diff --git a/src/postgres/migrations/20251202002136_initial_setup.sql b/src/postgres/migrations/20251202002136_initial_setup.sql index db2325d..b4d616d 100644 --- a/src/postgres/migrations/20251202002136_initial_setup.sql +++ b/src/postgres/migrations/20251202002136_initial_setup.sql @@ -1048,18 +1048,6 @@ begin end; $$; --- Advisory lock to serialize await_event and emit_event operations on the same event. --- This prevents lost wakeups when a waiter is being set up while an emit is happening. --- Called at the top of await_event and emit_event. -create function durable.lock_event ( - p_queue_name text, - p_event_name text -) - returns void - language sql -as $$ - select pg_advisory_xact_lock(hashtext(p_queue_name), hashtext(p_event_name)); -$$; -- awaits an event for a given task's run and step name. -- this will immediately return if it the event has already returned @@ -1095,9 +1083,6 @@ begin raise exception 'event_name must be provided'; end if; - -- Serialize with concurrent emit_event calls on the same event - perform durable.lock_event(p_queue_name, p_event_name); - if p_timeout is not null then if p_timeout < 0 then raise exception 'timeout must be non-negative'; @@ -1122,6 +1107,28 @@ begin return query select false, v_checkpoint_payload; return; end if; + -- Ensure a row exists for this event so we can take a row-level lock. + -- + -- We use payload IS NULL as the sentinel for "not emitted yet". emit_event + -- always writes a non-NULL payload (at minimum JSON null). + -- + -- Lock ordering is important to avoid deadlocks: await_event locks the event + -- row first (FOR SHARE) and then the run row (FOR UPDATE). emit_event + -- naturally locks the event row via its UPSERT before touching waits/runs. 
+ execute format( + 'insert into durable.%I (event_name, payload, emitted_at) + values ($1, null, ''epoch''::timestamptz) + on conflict (event_name) do nothing', + 'e_' || p_queue_name + ) using p_event_name; + + execute format( + 'select 1 + from durable.%I + where event_name = $1 + for share', + 'e_' || p_queue_name + ) using p_event_name; -- let's get the run state, any existing event payload and wake event name execute format( @@ -1253,15 +1260,17 @@ begin raise exception 'event_name must be provided'; end if; - -- Serialize with concurrent await_event calls on the same event - perform durable.lock_event(p_queue_name, p_event_name); - -- Insert the event into the events table (first-writer-wins). -- Subsequent emits for the same event are no-ops. + -- We use DO UPDATE WHERE payload IS NULL to handle the case where await_event + -- created a placeholder row before emit_event ran. execute format( 'insert into durable.%I (event_name, payload, emitted_at) values ($1, $2, $3) - on conflict (event_name) do nothing', + on conflict (event_name) do update + set payload = excluded.payload, emitted_at = excluded.emitted_at + where durable.%I.payload is null', + 'e_' || p_queue_name, 'e_' || p_queue_name ) using p_event_name, v_payload, v_now; diff --git a/tests/event_test.rs b/tests/event_test.rs index 6eb9e20..0bc690f 100644 --- a/tests/event_test.rs +++ b/tests/event_test.rs @@ -6,8 +6,10 @@ use common::helpers::{get_task_state, wait_for_task_terminal}; use common::tasks::{EventEmitterParams, EventEmitterTask, EventWaitParams, EventWaitingTask}; use durable::{Durable, MIGRATOR, RetryStrategy, SpawnOptions, WorkerOptions}; use serde_json::json; -use sqlx::{AssertSqlSafe, PgPool}; -use std::time::Duration; +use sqlx::postgres::PgConnectOptions; +use sqlx::{AssertSqlSafe, Connection, PgConnection, PgPool}; +use std::time::{Duration, Instant}; +use uuid::Uuid; async fn create_client(pool: PgPool, queue_name: &str) -> Durable { Durable::builder() @@ -806,41 +808,10 @@ async 
fn test_emit_event_with_empty_name_fails(pool: PgPool) -> sqlx::Result<()> } // ============================================================================ -// Advisory Lock Tests +// Lock Tests // ============================================================================ -/// Test that both await_event and emit_event use advisory locks for synchronization. -/// This verifies the implementation calls lock_event() by inspecting function definitions. -#[sqlx::test(migrator = "MIGRATOR")] -async fn test_event_functions_use_advisory_locks(pool: PgPool) -> sqlx::Result<()> { - // Check that await_event calls lock_event - let await_def: (String,) = sqlx::query_as( - "SELECT pg_get_functiondef(oid) FROM pg_proc WHERE proname = 'await_event' AND pronamespace = (SELECT oid FROM pg_namespace WHERE nspname = 'durable')" - ) - .fetch_one(&pool) - .await?; - - assert!( - await_def.0.contains("lock_event"), - "await_event should call lock_event for advisory locking" - ); - - // Check that emit_event calls lock_event - let emit_def: (String,) = sqlx::query_as( - "SELECT pg_get_functiondef(oid) FROM pg_proc WHERE proname = 'emit_event' AND pronamespace = (SELECT oid FROM pg_namespace WHERE nspname = 'durable')" - ) - .fetch_one(&pool) - .await?; - - assert!( - emit_def.0.contains("lock_event"), - "emit_event should call lock_event for advisory locking" - ); - - Ok(()) -} - -/// Stress test to verify that advisory locks prevent lost wakeups. +/// Stress test to verify that locking prevent lost wakeups. /// This test spawns many tasks waiting on distinct events and emits all events /// with jittered timing to maximize race condition likelihood. #[sqlx::test(migrator = "MIGRATOR")] @@ -954,3 +925,299 @@ async fn test_event_race_stress(pool: PgPool) -> sqlx::Result<()> { worker.shutdown().await; Ok(()) } + +/// Regression test for the "lost wakeup" race between await_event() and emit_event(). 
+/// +/// We make the race deterministic by: +/// - pre-creating a dummy wait row for (run_id, step_name) +/// - holding a row lock on it so await_event blocks in the UPSERT path +/// - trying to emit the event while await_event is blocked (should block too) +#[sqlx::test(migrator = "MIGRATOR")] +async fn test_await_emit_event_race_does_not_lose_wakeup(pool: PgPool) -> sqlx::Result<()> { + let queue = "event_race_gate"; + let event_name = "race-event"; + let payload = json!({"value": 42}); + + // Setup: Create queue, spawn task, claim it + sqlx::query("SELECT durable.create_queue($1)") + .bind(queue) + .execute(&pool) + .await?; + + let (task_id, run_id): (Uuid, Uuid) = + sqlx::query_as("SELECT task_id, run_id FROM durable.spawn_task($1, $2, $3, $4)") + .bind(queue) + .bind("waiter") + .bind(json!({"step": 1})) + .bind(json!({})) + .fetch_one(&pool) + .await?; + + let claim: (Uuid, Uuid) = + sqlx::query_as("SELECT run_id, task_id FROM durable.claim_task($1, $2, $3, $4)") + .bind(queue) + .bind("worker") + .bind(60) + .bind(1) + .fetch_one(&pool) + .await?; + assert_eq!(claim.0, run_id); + assert_eq!(claim.1, task_id); + + // Create a dummy wait row so await_event hits the UPDATE path and can block. 
+ sqlx::query(AssertSqlSafe(format!( + "INSERT INTO durable.w_{} (task_id, run_id, step_name, event_name, timeout_at) + VALUES ($1, $2, $3, $4, NULL)", + queue + ))) + .bind(task_id) + .bind(run_id) + .bind("wait") + .bind("dummy") + .execute(&pool) + .await?; + + // Get connect options from pool for creating separate connections + let connect_opts: PgConnectOptions = (*pool.connect_options()).clone(); + + // Open lock connection and hold FOR UPDATE lock on the wait row + let lock_opts = connect_opts.clone().application_name("durable-locker"); + let mut lock_conn = PgConnection::connect_with(&lock_opts).await?; + + sqlx::query("BEGIN").execute(&mut lock_conn).await?; + sqlx::query(AssertSqlSafe(format!( + "SELECT 1 FROM durable.w_{} WHERE run_id = $1 AND step_name = $2 FOR UPDATE", + queue + ))) + .bind(run_id) + .bind("wait") + .execute(&mut lock_conn) + .await?; + + // Spawn async task to call await_event - it will block on the lock + let await_opts = connect_opts.clone().application_name("durable-await-race"); + let queue_clone = queue.to_string(); + let event_name_clone = event_name.to_string(); + let await_handle = tokio::spawn(async move { + let mut conn = PgConnection::connect_with(&await_opts).await?; + + let result: (bool, Option<serde_json::Value>) = sqlx::query_as( + "SELECT should_suspend, payload FROM durable.await_event($1, $2, $3, $4, $5, $6)", + ) + .bind(&queue_clone) + .bind(task_id) + .bind(run_id) + .bind("wait") + .bind(&event_name_clone) + .bind(None::<i32>) + .fetch_one(&mut conn) + .await?; + + Ok::<_, sqlx::Error>(result) + }); + + // Wait until await_event is blocked on a lock (the w_ row lock) + let deadline = Instant::now() + Duration::from_secs(5); + loop { + let row: Option<(Option<String>,)> = sqlx::query_as( + "SELECT wait_event_type FROM pg_stat_activity WHERE application_name = $1", + ) + .bind("durable-await-race") + .fetch_optional(&pool) + .await?; + + if let Some((Some(ref wait_type),)) = row + && wait_type == "Lock" + { + break; + } + assert!( 
Instant::now() < deadline, + "await_event did not block as expected" + ); + tokio::time::sleep(Duration::from_millis(10)).await; + } + + // While await_event is blocked, emit_event should block on the event-row lock. + // We use a short statement_timeout to verify it blocks. + let emit_opts = connect_opts.clone().application_name("durable-emit"); + let mut emit_conn = PgConnection::connect_with(&emit_opts).await?; + sqlx::query("SET statement_timeout = '200ms'") + .execute(&mut emit_conn) + .await?; + + let emit_result = sqlx::query("SELECT durable.emit_event($1, $2, $3)") + .bind(queue) + .bind(event_name) + .bind(&payload) + .execute(&mut emit_conn) + .await; + + // Should timeout/be cancelled because it's blocked + assert!( + emit_result.is_err(), + "emit_event should have blocked and timed out" + ); + + // Reset statement_timeout for later use + sqlx::query("SET statement_timeout = 0") + .execute(&mut emit_conn) + .await?; + + // Let await_event proceed; it should suspend (no event delivered yet). + sqlx::query("ROLLBACK").execute(&mut lock_conn).await?; + drop(lock_conn); + + let await_result = await_handle + .await + .expect("await task panicked") + .expect("await_event failed"); + let (should_suspend, got_payload) = await_result; + assert!(should_suspend, "should_suspend should be true"); + assert!(got_payload.is_none(), "payload should be null on suspend"); + + // Now emit for real; it must wake the sleeping run and create the checkpoint. 
+ sqlx::query("SELECT durable.emit_event($1, $2, $3)") + .bind(queue) + .bind(event_name) + .bind(&payload) + .execute(&pool) + .await?; + + // Run should now be pending + let (state,): (String,) = sqlx::query_as(AssertSqlSafe(format!( + "SELECT state FROM durable.r_{} WHERE run_id = $1", + queue + ))) + .bind(run_id) + .fetch_one(&pool) + .await?; + assert_eq!(state, "pending"); + + // Claim the task again + let claim2: (Uuid,) = sqlx::query_as("SELECT run_id FROM durable.claim_task($1, $2, $3, $4)") + .bind(queue) + .bind("worker") + .bind(60) + .bind(1) + .fetch_one(&pool) + .await?; + assert_eq!(claim2.0, run_id); + + // await_event should now return the payload (should_suspend = false) + let resume: (bool, Option) = sqlx::query_as( + "SELECT should_suspend, payload FROM durable.await_event($1, $2, $3, $4, $5, $6)", + ) + .bind(queue) + .bind(task_id) + .bind(run_id) + .bind("wait") + .bind(event_name) + .bind(None::) + .fetch_one(&pool) + .await?; + + assert!(!resume.0, "should_suspend should be false on resume"); + assert_eq!( + resume.1, + Some(payload), + "payload should match emitted value" + ); + + Ok(()) +} + +/// Regression test: a task that awaits an event AFTER another task already +/// registered a waiter and the event was emitted should still see the payload. +/// +/// This tests the bug where await_event creates a placeholder row with NULL payload, +/// and emit_event's ON CONFLICT DO NOTHING would fail to update it, causing late +/// joiners to see NULL and sleep forever. 
+#[sqlx::test(migrator = "MIGRATOR")] +async fn test_await_event_late_joiner_sees_payload(pool: PgPool) -> sqlx::Result<()> { + let client = create_client(pool.clone(), "late_join").await; + client.create_queue(None).await.unwrap(); + client.register::<EventWaitingTask>().await.unwrap(); + + let event_name = "late-joiner-event"; + let payload = json!({"late": "joiner"}); + + // Spawn Task A that waits for the event + let task_a = client + .spawn::<EventWaitingTask>(EventWaitParams { + event_name: event_name.to_string(), + timeout_seconds: Some(30), + }) + .await + .expect("Failed to spawn task A"); + + let worker = client + .start_worker(WorkerOptions { + poll_interval: Duration::from_millis(50), + claim_timeout: Duration::from_secs(30), + ..Default::default() + }) + .await; + + // Wait for Task A to start waiting (creates placeholder row in e_ table) + tokio::time::sleep(Duration::from_millis(300)).await; + + let state = get_task_state(&pool, "late_join", task_a.task_id).await?; + assert!( + state == Some("sleeping".to_string()) || state == Some("running".to_string()), + "Task A should be sleeping or running, got {:?}", + state + ); + + // Emit the event - this should update the placeholder row with the real payload + client + .emit_event(event_name, &payload, None) + .await + .expect("Failed to emit event"); + + // Wait for Task A to complete + let terminal_a = + wait_for_task_terminal(&pool, "late_join", task_a.task_id, Duration::from_secs(5)).await?; + assert_eq!( + terminal_a, + Some("completed".to_string()), + "Task A should complete" + ); + + // Now spawn Task B - it should see the event was already emitted and complete immediately + let task_b = client + .spawn::<EventWaitingTask>(EventWaitParams { + event_name: event_name.to_string(), + timeout_seconds: Some(2), // Short timeout to fail fast if bug exists + }) + .await + .expect("Failed to spawn task B"); + + // Task B should complete quickly since event already exists with payload + let terminal_b = + wait_for_task_terminal(&pool, "late_join", 
task_b.task_id, Duration::from_secs(3)).await?; + + worker.expect("Failed to start worker").shutdown().await; + + assert_eq!( + terminal_b, + Some("completed".to_string()), + "Task B (late joiner) should complete immediately, not sleep forever" + ); + + // Verify Task B received the correct payload + let query = AssertSqlSafe( + "SELECT completed_payload FROM durable.t_late_join WHERE task_id = $1".to_string(), + ); + let result: (serde_json::Value,) = sqlx::query_as(query) + .bind(task_b.task_id) + .fetch_one(&pool) + .await?; + + assert_eq!( + result.0, payload, + "Task B should receive the emitted payload" + ); + + Ok(()) +}