diff --git a/hindsight-api-slim/hindsight_api/worker/poller.py b/hindsight-api-slim/hindsight_api/worker/poller.py index 9f556da9..5651192b 100644 --- a/hindsight-api-slim/hindsight_api/worker/poller.py +++ b/hindsight-api-slim/hindsight_api/worker/poller.py @@ -475,6 +475,11 @@ async def _execute_task_inner(self, task: ClaimedTask): traceback.print_exc() await self._mark_failed(task.operation_id, str(e), task.schema) + # Tasks claimed longer than this are considered abandoned by a dead worker. + # The longest observed legitimate task is ~7 minutes (large PDF extraction). + # 30 minutes provides a safe margin. + _STALE_TASK_THRESHOLD_MINUTES = 30 + async def recover_own_tasks(self) -> int: """ Recover tasks that were assigned to this worker but not completed. @@ -483,7 +488,9 @@ async def recover_own_tasks(self) -> int: On startup, we reset any tasks stuck in 'processing' for this worker_id back to 'pending' so they can be picked up again. - Also recovers batch API operations that were in-flight. + Also recovers batch API operations that were in-flight, and reclaims + stale tasks from dead workers (other worker_ids whose tasks have been + stuck in 'processing' beyond the stale threshold). If tenant_extension is configured, recovers across all tenant schemas. @@ -501,7 +508,7 @@ async def recover_own_tasks(self) -> int: batch_count = await self._recover_batch_operations(schema) total_count += batch_count - # Then reset normal worker tasks + # Then reset normal worker tasks (own worker_id) result = await self._pool.execute( f""" UPDATE {table} @@ -514,6 +521,32 @@ async def recover_own_tasks(self) -> int: # Parse "UPDATE N" to get count count = int(result.split()[-1]) if result else 0 total_count += count + + # Reclaim stale tasks from dead workers. + # When a worker pod is terminated (restart, deploy, OOM, node + # eviction), it may not release its claimed tasks. The new pod + # gets a different worker_id, so the above query won't match + # the old pod's tasks. Any task stuck in 'processing' with a + # claimed_at older than the threshold is assumed abandoned. + stale_result = await self._pool.execute( + f""" + UPDATE {table} + SET status = 'pending', worker_id = NULL, claimed_at = NULL, updated_at = now() + WHERE status = 'processing' + AND worker_id != $1 + AND claimed_at < now() - make_interval(mins => $2) + AND result_metadata->>'batch_id' IS NULL + """, + self._worker_id, + self._STALE_TASK_THRESHOLD_MINUTES, + ) + stale_count = int(stale_result.split()[-1]) if stale_result else 0 + if stale_count > 0: + logger.warning( + f"Worker {self._worker_id} reclaimed {stale_count} stale tasks " + f"from dead workers (claimed_at > {self._STALE_TASK_THRESHOLD_MINUTES}m ago)" + ) + total_count += stale_count except Exception as e: # Format schema for logging: custom schemas in quotes, None as-is schema_display = f'"{schema}"' if schema else str(schema) diff --git a/hindsight-api-slim/tests/test_worker.py b/hindsight-api-slim/tests/test_worker.py index 21926069..58141455 100644 --- a/hindsight-api-slim/tests/test_worker.py +++ b/hindsight-api-slim/tests/test_worker.py @@ -701,6 +701,73 @@ async def test_recover_own_tasks_returns_zero_when_no_stale_tasks(self, pool, cl recovered_count = await poller.recover_own_tasks() assert recovered_count == 0 + @pytest.mark.asyncio + async def test_recover_reclaims_stale_tasks_from_dead_workers(self, pool, clean_operations): + """Test that tasks stuck on dead workers for >30min are reclaimed on startup.""" + from hindsight_api.worker import WorkerPoller + + bank_id = f"test-worker-{uuid.uuid4().hex[:8]}" + await _ensure_bank(pool, bank_id) + + # Create a task claimed by a dead worker 60 minutes ago + stale_op_id = uuid.uuid4() + payload = json.dumps({"type": "consolidation", "bank_id": bank_id}) + await pool.execute( + """ + INSERT INTO async_operations + (operation_id, bank_id, operation_type, status, task_payload, + worker_id, claimed_at) + VALUES ($1, $2, 'consolidation', 'processing', $3::jsonb, + 'dead-worker-abc123', now() - interval '60 minutes') + """, + stale_op_id, + bank_id, + payload, + ) + + # Create a task claimed by a dead worker only 5 minutes ago (not stale yet) + recent_op_id = uuid.uuid4() + payload2 = json.dumps({"type": "retain", "bank_id": bank_id}) + await pool.execute( + """ + INSERT INTO async_operations + (operation_id, bank_id, operation_type, status, task_payload, + worker_id, claimed_at) + VALUES ($1, $2, 'retain', 'processing', $3::jsonb, + 'dead-worker-abc123', now() - interval '5 minutes') + """, + recent_op_id, + bank_id, + payload2, + ) + + # New worker starts up and recovers + poller = WorkerPoller( + pool=pool, + worker_id="new-worker", + executor=lambda x: None, + ) + + recovered_count = await poller.recover_own_tasks() + # Only the stale task (60 min old) should be recovered + assert recovered_count == 1 + + # Verify the stale task was reset to pending + stale_row = await pool.fetchrow( + "SELECT status, worker_id FROM async_operations WHERE operation_id = $1", + stale_op_id, + ) + assert stale_row["status"] == "pending" + assert stale_row["worker_id"] is None + + # Verify the recent task is still processing (not reclaimed) + recent_row = await pool.fetchrow( + "SELECT status, worker_id FROM async_operations WHERE operation_id = $1", + recent_op_id, + ) + assert recent_row["status"] == "processing" + assert recent_row["worker_id"] == "dead-worker-abc123" + class TestConcurrentWorkers: """Tests for concurrent worker task claiming (FOR UPDATE SKIP LOCKED)."""