Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 39 additions & 1 deletion procrastinate/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,39 @@ def _start_side_tasks(self) -> list[asyncio.Task]:
side_tasks.append(asyncio.create_task(listener_coro, name="listener"))
return side_tasks

async def _monitor_side_tasks(self, side_tasks: list[asyncio.Task]):
    """
    Monitor side tasks and stop the worker if any of them fails.

    Waits until one of ``side_tasks`` raises an exception (or until all of
    them are finished), logs the first failure found with the
    ``side_task_failed`` action and calls ``self.stop()``.

    Side tasks that finish cleanly or that were cancelled are not treated
    as failures and do not end the monitoring of the remaining tasks.

    Any unexpected error raised by the monitor itself is logged with the
    ``side_task_monitor_failed`` action and re-raised.
    """
    if not side_tasks:
        # asyncio.wait() raises ValueError when given an empty collection;
        # with nothing to watch there is simply nothing to do.
        return

    try:
        # FIRST_EXCEPTION (rather than FIRST_COMPLETED) keeps watching the
        # remaining tasks when one of them merely returns or is cancelled.
        done, _ = await asyncio.wait(
            side_tasks, return_when=asyncio.FIRST_EXCEPTION
        )
        for task in done:
            # Task.exception() raises CancelledError on a cancelled task,
            # so cancelled tasks must be filtered out before calling it.
            if task.cancelled():
                continue
            exc = task.exception()
            if exc is None:
                continue
            self.logger.error(
                f"Side task {task.get_name()} failed with exception: {exc}, stopping worker",
                extra=self._log_extra(
                    action="side_task_failed",
                    context=None,
                    job_result=None,
                    task_name=task.get_name(),
                    exception=str(exc),
                ),
                exc_info=exc,
            )
            self.stop()
            # One failure is enough to stop the worker.
            return
    except Exception as exc:
        self.logger.exception(
            f"Side task monitor failed: {exc}",
            extra=self._log_extra(
                action="side_task_monitor_failed",
                context=None,
                job_result=None,
                exception=str(exc),
            ),
        )
        raise

async def _run_loop(self):
"""
Run all side coroutines, then start fetching/processing jobs in a loop
Expand All @@ -560,6 +593,9 @@ async def _run_loop(self):
self._running_jobs = {}
self._job_semaphore = asyncio.Semaphore(self.concurrency)
side_tasks = self._start_side_tasks()
side_tasks_monitor = asyncio.create_task(
self._monitor_side_tasks(side_tasks), name="side_tasks_monitor"
)

context = (
signals.on_stop(self.stop)
Expand All @@ -583,12 +619,14 @@ async def _run_loop(self):
self._stop_event.set()

while not self._stop_event.is_set():
# wait for a new job notification, a stop even or the next polling interval
# wait for a new job notification, a stop event or the next polling interval
await utils.wait_any(
self._new_job_event.wait(),
asyncio.sleep(self.fetch_job_polling_interval),
self._stop_event.wait(),
)
await self._fetch_and_process_jobs()
finally:
if not side_tasks_monitor.done():
side_tasks_monitor.cancel()
await self._shutdown(side_tasks=side_tasks)
27 changes: 27 additions & 0 deletions tests/unit/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,3 +834,30 @@ async def test_worker_prunes_stalled_workers(app: App):

assert worker1_id in connector.workers
assert worker2_id not in connector.workers


async def test_worker_stops_when_side_task_fails(
    app: App, caplog, mocker: MockerFixture
):
    """A failing side task is logged exactly once and lets ``run()`` return."""
    caplog.set_level("INFO")

    async def raise_on_heartbeat(self):
        raise ValueError("Simulated heartbeat failure")

    # Replace the heartbeat coroutine so the side task fails immediately.
    mocker.patch.object(Worker, "_update_heartbeat", raise_on_heartbeat)

    # run() is expected to return on its own once the failure is detected.
    await Worker(app).run()

    failures = []
    for log_record in caplog.records:
        if getattr(log_record, "action", None) == "side_task_failed":
            failures.append(log_record)

    assert len(failures) == 1
    failure = failures[0]
    message = failure.message
    assert "update_heartbeats failed with exception" in message
    assert "Simulated heartbeat failure" in message
    assert "stopping worker" in message
    assert failure.task_name == "update_heartbeats"
Loading