From f8d4b45db181668e8d7c77ed02bc40cf23429029 Mon Sep 17 00:00:00 2001 From: Nikita Zavadin Date: Sat, 14 Dec 2024 18:04:46 +0100 Subject: [PATCH 01/12] Use redis stream for tasks without delay --- arq/connections.py | 69 ++++++++++++++-- arq/constants.py | 3 + arq/jobs.py | 47 +++++++++-- arq/lua.py | 48 +++++++++++ arq/worker.py | 191 +++++++++++++++++++++++++++++++++++++------ tests/test_worker.py | 50 +++++++---- 6 files changed, 355 insertions(+), 53 deletions(-) create mode 100644 arq/lua.py diff --git a/arq/connections.py b/arq/connections.py index c1058890..c599f411 100644 --- a/arq/connections.py +++ b/arq/connections.py @@ -13,8 +13,16 @@ from redis.asyncio.sentinel import Sentinel from redis.exceptions import RedisError, WatchError -from .constants import default_queue_name, expires_extra_ms, job_key_prefix, result_key_prefix +from .constants import ( + default_queue_name, + expires_extra_ms, + job_key_prefix, + job_message_id_prefix, + result_key_prefix, + stream_key_suffix, +) from .jobs import Deserializer, Job, JobDef, JobResult, Serializer, deserialize_job, serialize_job +from .lua import publish_job_lua from .utils import timestamp_ms, to_ms, to_unix_ms logger = logging.getLogger('arq.connections') @@ -115,6 +123,7 @@ def __init__( kwargs['connection_pool'] = pool_or_conn self.expires_extra_ms = expires_extra_ms super().__init__(**kwargs) + self.publish_to_stream_script = self.register_script(publish_job_lua) async def enqueue_job( self, @@ -165,20 +174,57 @@ async def enqueue_job( elif defer_by_ms: score = enqueue_time_ms + defer_by_ms else: - score = enqueue_time_ms + score = None - expires_ms = expires_ms or score - enqueue_time_ms + self.expires_extra_ms + expires_ms = expires_ms or (score or enqueue_time_ms) - enqueue_time_ms + self.expires_extra_ms - job = serialize_job(function, args, kwargs, _job_try, enqueue_time_ms, serializer=self.job_serializer) + job = serialize_job( + function, + args, + kwargs, + _job_try, + enqueue_time_ms, + serializer=self.job_serializer, + ) pipe.multi() pipe.psetex(job_key, expires_ms, job) - pipe.zadd(_queue_name, {job_id: score}) + + if score is not None: + pipe.zadd(_queue_name, {job_id: score}) + else: + stream_key = _queue_name + stream_key_suffix + job_message_id_key = job_message_id_prefix + job_id + await self.publish_to_stream_script( + keys=[stream_key, job_message_id_key], + args=[job_id, str(enqueue_time_ms), str(expires_ms)], + client=pipe, + ) + try: await pipe.execute() except WatchError: # job got enqueued since we checked 'job_exists' return None - return Job(job_id, redis=self, _queue_name=_queue_name, _deserializer=self.job_deserializer) + return Job( + job_id, + redis=self, + _queue_name=_queue_name, + _deserializer=self.job_deserializer, + ) + + async def get_queue_size(self, queue_name: str | None = None, include_delayed_tasks: bool = True) -> int: + if queue_name is None: + queue_name = self.default_queue_name + + async with self.pipeline(transaction=True) as pipe: + pipe.xlen(queue_name + stream_key_suffix) + pipe.zcount(queue_name, '-inf', '+inf') + stream_size, delayed_queue_size = await pipe.execute() + + if not include_delayed_tasks: + return stream_size + + return stream_size + delayed_queue_size async def _get_job_result(self, key: bytes) -> JobResult: job_id = key[len(result_key_prefix) :].decode() @@ -213,7 +259,16 @@ async def queued_jobs(self, *, queue_name: Optional[str] = None) -> List[JobDef] """ if queue_name is None: queue_name = self.default_queue_name - jobs = await self.zrange(queue_name, 
withscores=True, start=0, end=-1) + + async with self.pipeline(transaction=True) as pipe: + pipe.zrange(queue_name, withscores=True, start=0, end=-1) + pipe.xrange(queue_name + stream_key_suffix, '-', '+') + delayed_jobs, stream_jobs = await pipe.execute() + + jobs = [ + *delayed_jobs, + *[(j[b'job_id'], int(j[b'score'])) for _, j in stream_jobs], + ] return await asyncio.gather(*[self._get_job_def(job_id, int(score)) for job_id, score in jobs]) diff --git a/arq/constants.py b/arq/constants.py index 84c009aa..98d6ff99 100644 --- a/arq/constants.py +++ b/arq/constants.py @@ -1,9 +1,12 @@ default_queue_name = 'arq:queue' job_key_prefix = 'arq:job:' in_progress_key_prefix = 'arq:in-progress:' +job_message_id_prefix = 'arq:message-id:' result_key_prefix = 'arq:result:' retry_key_prefix = 'arq:retry:' abort_jobs_ss = 'arq:abort' +stream_key_suffix = ':stream' +default_consumer_group = 'arq:workers' # age of items in the abort_key sorted set after which they're deleted abort_job_max_age = 60 health_check_key_suffix = ':health-check' diff --git a/arq/jobs.py b/arq/jobs.py index 15b7231e..fbd2bc1f 100644 --- a/arq/jobs.py +++ b/arq/jobs.py @@ -5,11 +5,21 @@ from dataclasses import dataclass from datetime import datetime from enum import Enum +from itertools import batched from typing import Any, Callable, Dict, Optional, Tuple from redis.asyncio import Redis -from .constants import abort_jobs_ss, default_queue_name, in_progress_key_prefix, job_key_prefix, result_key_prefix +from .constants import ( + abort_jobs_ss, + default_queue_name, + in_progress_key_prefix, + job_key_prefix, + job_message_id_prefix, + result_key_prefix, + stream_key_suffix, +) +from .lua import get_job_from_stream_lua from .utils import ms_to_datetime, poll, timestamp_ms logger = logging.getLogger('arq.jobs') @@ -63,12 +73,16 @@ class JobResult(JobDef): queue_name: str +def _list_to_dict(input_list: list[Any]) -> dict[Any, Any]: + return {key: value for key, value in batched(input_list, 2)} + + class Job: """ Holds data a reference to a job. """ - __slots__ = 'job_id', '_redis', '_queue_name', '_deserializer' + __slots__ = 'job_id', '_redis', '_queue_name', '_deserializer', '_get_job_from_stream_script' def __init__( self, @@ -81,6 +95,7 @@ def __init__( self._redis = redis self._queue_name = _queue_name self._deserializer = _deserializer + self._get_job_from_stream_script = redis.register_script(get_job_from_stream_lua) async def result( self, timeout: Optional[float] = None, *, poll_delay: float = 0.5, pole_delay: Optional[float] = None @@ -105,7 +120,8 @@ async def result( async with self._redis.pipeline(transaction=True) as tr: tr.get(result_key_prefix + self.job_id) tr.zscore(self._queue_name, self.job_id) - v, s = await tr.execute() + tr.get(job_message_id_prefix + self.job_id) + v, s, m = await tr.execute() if v: info = deserialize_result(v, deserializer=self._deserializer) @@ -115,7 +131,7 @@ async def result( raise info.result else: raise SerializationError(info.result) - elif s is None: + elif s is None and m is None: raise ResultNotFound( 'Not waiting for job result because the job is not in queue. ' 'Is the worker function configured to keep result?' 
@@ -134,8 +150,22 @@ async def info(self) -> Optional[JobDef]: if v: info = deserialize_job(v, deserializer=self._deserializer) if info: - s = await self._redis.zscore(self._queue_name, self.job_id) - info.score = None if s is None else int(s) + async with self._redis.pipeline(transaction=True) as tr: + tr.zscore(self._queue_name, self.job_id) + await self._get_job_from_stream_script( + keys=[self._queue_name + stream_key_suffix, job_message_id_prefix + self.job_id], + client=tr, + ) + delayed_score, job_info = await tr.execute() + + if delayed_score: + info.score = int(delayed_score) + elif job_info: + _, job_info_payload = job_info + info.score = int(_list_to_dict(job_info_payload)[b'score']) + else: + info.score = None + return info async def result_info(self) -> Optional[JobResult]: @@ -157,12 +187,15 @@ async def status(self) -> JobStatus: tr.exists(result_key_prefix + self.job_id) tr.exists(in_progress_key_prefix + self.job_id) tr.zscore(self._queue_name, self.job_id) - is_complete, is_in_progress, score = await tr.execute() + tr.exists(job_message_id_prefix + self.job_id) + is_complete, is_in_progress, score, queued = await tr.execute() if is_complete: return JobStatus.complete elif is_in_progress: return JobStatus.in_progress + elif queued: + return JobStatus.queued elif score: return JobStatus.deferred if score > timestamp_ms() else JobStatus.queued else: diff --git a/arq/lua.py b/arq/lua.py new file mode 100644 index 00000000..e7bd5230 --- /dev/null +++ b/arq/lua.py @@ -0,0 +1,48 @@ +publish_delayed_job_lua = """ +local delayed_queue_key = KEYS[1] +local stream_key = KEYS[2] +local job_message_id_key = KEYS[3] + +local job_id = ARGV[1] +local job_message_id_expire_ms = ARGV[2] + +local score = redis.call('zscore', delayed_queue_key, job_id) +if score == nil or score == false then + return 0 +end + +local message_id = redis.call('xadd', stream_key, '*', 'job_id', job_id, 'score', score) +redis.call('set', job_message_id_key, message_id, 'px', job_message_id_expire_ms) +redis.call('zrem', delayed_queue_key, job_id) +return 1 +""" + +publish_job_lua = """ +local stream_key = KEYS[1] +local job_message_id_key = KEYS[2] + +local job_id = ARGV[1] +local score = ARGV[2] +local job_message_id_expire_ms = ARGV[3] + +local message_id = redis.call('xadd', stream_key, '*', 'job_id', job_id, 'score', score) +redis.call('set', job_message_id_key, message_id, 'px', job_message_id_expire_ms) +return message_id +""" + +get_job_from_stream_lua = """ +local stream_key = KEYS[1] +local job_message_id_key = KEYS[2] + +local message_id = redis.call('get', job_message_id_key) +if message_id == false then + return nil +end + +local job = redis.call('xrange', stream_key, message_id, message_id) +if job == nil then + return nil +end + +return job[1] +""" diff --git a/arq/worker.py b/arq/worker.py index 8fcd5fc8..f6a1c708 100644 --- a/arq/worker.py +++ b/arq/worker.py @@ -3,12 +3,14 @@ import inspect import logging import signal +from contextlib import suppress from dataclasses import dataclass from datetime import datetime, timedelta, timezone from functools import partial from signal import Signals from time import time from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union, cast +from uuid import uuid4 from redis.exceptions import ResponseError, WatchError @@ -19,15 +21,19 @@ from .constants import ( abort_job_max_age, abort_jobs_ss, + default_consumer_group, default_queue_name, expires_extra_ms, health_check_key_suffix, in_progress_key_prefix, job_key_prefix, 
+ job_message_id_prefix, keep_cronjob_progress, result_key_prefix, retry_key_prefix, + stream_key_suffix, ) +from .lua import publish_delayed_job_lua from .utils import ( args_to_string, import_string, @@ -57,6 +63,13 @@ class Function: max_tries: Optional[int] +@dataclass +class StreamMessage: + message_id: str + job_id: str + score: int + + def func( coroutine: Union[str, Function, 'WorkerCoroutine'], *, @@ -188,6 +201,8 @@ def __init__( functions: Sequence[Union[Function, 'WorkerCoroutine']] = (), *, queue_name: Optional[str] = default_queue_name, + consumer_group_name: str = default_consumer_group, + worker_id: Optional[str] = None, cron_jobs: Optional[Sequence[CronJob]] = None, redis_settings: Optional[RedisSettings] = None, redis_pool: Optional[ArqRedis] = None, @@ -225,6 +240,8 @@ def __init__( else: raise ValueError('If queue_name is absent, redis_pool must be present.') self.queue_name = queue_name + self.consumer_group_name = consumer_group_name + self.worker_id = worker_id or str(uuid4().hex) self.cron_jobs: List[CronJob] = [] if cron_jobs is not None: if not all(isinstance(cj, CronJob) for cj in cron_jobs): @@ -357,18 +374,85 @@ async def main(self) -> None: if self.on_startup: await self.on_startup(self.ctx) - async for _ in poll(self.poll_delay_s): + await self.create_consumer_group() + + self.poller_task = asyncio.create_task(self.run_delayed_queue_poller()) + self.autoclaimer_task = asyncio.create_task(self.run_dead_task_autoclaimer()) + + try: + await self.run_stream_reader() + finally: + self.poller_task.cancel() + self.autoclaimer_task.cancel() + await asyncio.gather( + self.poller_task, + self.autoclaimer_task, + return_exceptions=True, + ) + + async def run_stream_reader(self) -> None: + while True: await self._poll_iteration() if self.burst: if 0 <= self.max_burst_jobs <= self._jobs_started(): await asyncio.gather(*self.tasks.values()) return None - queued_jobs = await self.pool.zcard(self.queue_name) + queued_jobs = await self.pool.get_queue_size(self.queue_name) if queued_jobs == 0: await asyncio.gather(*self.tasks.values()) return None + async def run_dead_task_autoclaimer(self) -> None: + async for _ in poll(self.poll_delay_s): + consumers_info = await self.pool.xinfo_consumers( + self.queue_name + stream_key_suffix, + groupname=self.consumer_group_name, + ) + for consumer_info in consumers_info: + await self.pool.xautoclaim( + self.queue_name + stream_key_suffix, + groupname=self.consumer_group_name, + consumername=consumer_info['name'], + min_idle_time=int(self.in_progress_timeout_s * 1000), + ) + + async def run_delayed_queue_poller(self) -> None: + publish_delayed_job = self.pool.register_script(publish_delayed_job_lua) + + async for _ in poll(self.poll_delay_s): + job_ids = await self.pool.zrange( + self.queue_name, + start=float('-inf'), + end=timestamp_ms(), + num=self.queue_read_limit, + offset=self._queue_read_offset, + withscores=True, + byscore=True, + ) + for job_id, score in job_ids: + expire_ms = int(score - timestamp_ms() + self.expires_extra_ms) + if expire_ms <= 0: + expire_ms = self.expires_extra_ms + + await publish_delayed_job( + keys=[ + self.queue_name, + self.queue_name + stream_key_suffix, + job_message_id_prefix + job_id.decode(), + ], + args=[job_id.decode(), expire_ms], + ) + + async def create_consumer_group(self) -> None: + with suppress(ResponseError): + await self.pool.xgroup_create( + name=self.queue_name + stream_key_suffix, + groupname=self.consumer_group_name, + id='0', + mkstream=True, + ) + async def _poll_iteration(self) -> 
None: """ Get ids of pending jobs from the main queue sorted-set data structure and start those jobs, remove @@ -382,12 +466,24 @@ async def _poll_iteration(self) -> None: count = min(burst_jobs_remaining, count) if self.allow_pick_jobs: if self.job_counter < self.max_jobs: - now = timestamp_ms() - job_ids = await self.pool.zrangebyscore( - self.queue_name, min=float('-inf'), start=self._queue_read_offset, num=count, max=now + stream_msgs = await self.pool.xreadgroup( + groupname=self.consumer_group_name, + consumername=self.worker_id, + streams={self.queue_name + stream_key_suffix: '>'}, + count=count, + block=int(max(self.poll_delay_s * 1000, 1)), ) + jobs = [] + + for _, msgs in stream_msgs: + for msg_id, job in msgs: + jobs.append( + StreamMessage( + message_id=msg_id.decode(), job_id=job[b'job_id'].decode(), score=int(job[b'score']) + ) + ) - await self.start_jobs(job_ids) + await self.start_jobs(jobs) if self.allow_abort_jobs: await self._cancel_aborted_jobs() @@ -428,11 +524,14 @@ def _release_sem_dec_counter_on_complete(self) -> None: self.job_counter = self.job_counter - 1 self.sem.release() - async def start_jobs(self, job_ids: List[bytes]) -> None: + async def start_jobs(self, jobs: list[StreamMessage]) -> None: """ For each job id, get the job definition, check it's not running and start it in a task """ - for job_id_b in job_ids: + for job in jobs: + job_id = job.job_id + score = job.score + await self.sem.acquire() if self.job_counter >= self.max_jobs: @@ -441,16 +540,13 @@ async def start_jobs(self, job_ids: List[bytes]) -> None: self.job_counter = self.job_counter + 1 - job_id = job_id_b.decode() in_progress_key = in_progress_key_prefix + job_id async with self.pool.pipeline(transaction=True) as pipe: await pipe.watch(in_progress_key) ongoing_exists = await pipe.exists(in_progress_key) - score = await pipe.zscore(self.queue_name, job_id) - if ongoing_exists or not score or score > timestamp_ms(): - # job already started elsewhere, or already finished and removed from queue - # if score > ts_now, - # it means probably the job was re-enqueued with a delay in another worker + + if ongoing_exists: + await self._redeliver_job(job) self.job_counter = self.job_counter - 1 self.sem.release() logger.debug('job %s already running elsewhere', job_id) @@ -466,11 +562,30 @@ async def start_jobs(self, job_ids: List[bytes]) -> None: self.sem.release() logger.debug('multi-exec error, job %s already started elsewhere', job_id) else: - t = self.loop.create_task(self.run_job(job_id, int(score))) + t = self.loop.create_task(self.run_job(job_id, job.message_id, score)) t.add_done_callback(lambda _: self._release_sem_dec_counter_on_complete()) self.tasks[job_id] = t - async def run_job(self, job_id: str, score: int) -> None: # noqa: C901 + async def _redeliver_job(self, job: StreamMessage) -> None: + async with self.pool.pipeline(transaction=True) as pipe: + stream_key = self.queue_name + stream_key_suffix + job_message_id_key = job_message_id_prefix + job.job_id + + pipe.xack(stream_key, self.consumer_group_name, job.message_id) + pipe.xdel(stream_key, job.message_id) + job_message_id_expire = job.score - timestamp_ms() + self.expires_extra_ms + if job_message_id_expire <= 0: + job_message_id_expire = self.expires_extra_ms + + await self.pool.publish_to_stream_script( + keys=[stream_key, job_message_id_key], + args=[job.job_id, str(job.score), str(job_message_id_expire)], + client=pipe, + ) + + await pipe.execute() + + async def run_job(self, job_id: str, message_id: str, score: int) -> None: # 
noqa: C901 start_ms = timestamp_ms() async with self.pool.pipeline(transaction=True) as pipe: pipe.get(job_key_prefix + job_id) @@ -504,7 +619,7 @@ async def job_failed(exc: BaseException) -> None: queue_name=self.queue_name, job_id=job_id, ) - await asyncio.shield(self.finish_failed_job(job_id, result_data_)) + await asyncio.shield(self.finish_failed_job(job_id, message_id, result_data_)) if not v: logger.warning('job %s expired', job_id) @@ -561,7 +676,7 @@ async def job_failed(exc: BaseException) -> None: job_id=job_id, serializer=self.job_serializer, ) - return await asyncio.shield(self.finish_failed_job(job_id, result_data)) + return await asyncio.shield(self.finish_failed_job(job_id, message_id, result_data)) result = no_result exc_extra = None @@ -662,6 +777,8 @@ async def job_failed(exc: BaseException) -> None: await asyncio.shield( self.finish_job( job_id, + message_id, + score, finish, result_data, result_timeout_s, @@ -677,6 +794,8 @@ async def job_failed(exc: BaseException) -> None: async def finish_job( self, job_id: str, + message_id: str, + score: int, finish: bool, result_data: Optional[bytes], result_timeout_s: Optional[float], @@ -687,33 +806,56 @@ async def finish_job( async with self.pool.pipeline(transaction=True) as tr: delete_keys = [] in_progress_key = in_progress_key_prefix + job_id + stream_key = self.queue_name + stream_key_suffix + job_message_id_key = job_message_id_prefix + job_id if keep_in_progress is None: delete_keys += [in_progress_key] else: tr.pexpire(in_progress_key, to_ms(keep_in_progress)) + tr.xack( + stream_key, + self.consumer_group_name, + message_id, + ) + tr.xdel(stream_key, message_id) + if finish: if result_data: expire = None if keep_result_forever else result_timeout_s tr.set(result_key_prefix + job_id, result_data, px=to_ms(expire)) - delete_keys += [retry_key_prefix + job_id, job_key_prefix + job_id] + delete_keys += [retry_key_prefix + job_id, job_key_prefix + job_id, job_message_id_key] tr.zrem(abort_jobs_ss, job_id) - tr.zrem(self.queue_name, job_id) elif incr_score: - tr.zincrby(self.queue_name, incr_score, job_id) + delete_keys += [job_message_id_key] + tr.zadd(self.queue_name, {job_id: score + incr_score}) + else: + job_message_id_expire = score - timestamp_ms() + self.expires_extra_ms + await self.pool.publish_to_stream_script( + keys=[stream_key, job_message_id_key], + args=[job_id, str(score), str(job_message_id_expire)], + client=tr, + ) if delete_keys: tr.delete(*delete_keys) await tr.execute() - async def finish_failed_job(self, job_id: str, result_data: Optional[bytes]) -> None: + async def finish_failed_job(self, job_id: str, message_id: str, result_data: Optional[bytes]) -> None: + stream_key = self.queue_name + stream_key_suffix async with self.pool.pipeline(transaction=True) as tr: tr.delete( retry_key_prefix + job_id, in_progress_key_prefix + job_id, job_key_prefix + job_id, + job_message_id_prefix + job_id, ) tr.zrem(abort_jobs_ss, job_id) - tr.zrem(self.queue_name, job_id) + tr.xack( + stream_key, + self.consumer_group_name, + message_id, + ) + tr.xdel(stream_key, message_id) # result_data would only be None if serializing the result fails keep_result = self.keep_result_forever or self.keep_result_s > 0 if result_data is not None and keep_result: # pragma: no branch @@ -722,6 +864,9 @@ async def finish_failed_job(self, job_id: str, result_data: Optional[bytes]) -> await tr.execute() async def heart_beat(self) -> None: + if self.poller_task.done(): + raise self.poller_task.exception() + now = 
datetime.now(tz=self.timezone) await self.record_health() diff --git a/tests/test_worker.py b/tests/test_worker.py index 93fbc7f0..9c23fde6 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -19,6 +19,7 @@ JobExecutionFailed, Retry, RetryJob, + StreamMessage, Worker, async_check_health, check_health, @@ -71,7 +72,7 @@ async def test_set_health_check_key(arq_redis: ArqRedis, worker): await arq_redis.enqueue_job('foobar', _job_id='testing') worker: Worker = worker(functions=[func(foobar, keep_result=0)], health_check_key='arq:test:health-check') await worker.main() - assert sorted(await arq_redis.keys('*')) == [b'arq:test:health-check'] + assert sorted(await arq_redis.keys('*')) == [b'arq:queue:stream', b'arq:test:health-check'] async def test_handle_sig(caplog, arq_redis: ArqRedis): @@ -214,6 +215,8 @@ async def retry(ctx): await arq_redis.enqueue_job('retry', _job_id='testing') worker: Worker = worker(functions=[func(retry, name='retry')]) await worker.main() + + assert await worker.pool.get_queue_size(worker.queue_name) == 0 assert worker.jobs_complete == 1 assert worker.jobs_failed == 0 assert worker.jobs_retried == 2 @@ -517,38 +520,38 @@ async def test_log_health_check(arq_redis: ArqRedis, worker, caplog): async def test_remain_keys(test_redis_settings: RedisSettings, arq_redis: ArqRedis, worker, create_pool): redis2 = await create_pool(test_redis_settings) await arq_redis.enqueue_job('foobar', _job_id='testing') - assert sorted(await redis2.keys('*')) == [b'arq:job:testing', b'arq:queue'] + assert sorted(await redis2.keys('*')) == [b'arq:job:testing', b'arq:message-id:testing', b'arq:queue:stream'] worker: Worker = worker(functions=[foobar]) await worker.main() - assert sorted(await redis2.keys('*')) == [b'arq:queue:health-check', b'arq:result:testing'] + assert sorted(await redis2.keys('*')) == [b'arq:queue:health-check', b'arq:queue:stream', b'arq:result:testing'] await worker.close() - assert sorted(await redis2.keys('*')) == [b'arq:result:testing'] + assert sorted(await redis2.keys('*')) == [b'arq:queue:stream', b'arq:result:testing'] async def test_remain_keys_no_results(arq_redis: ArqRedis, worker): await arq_redis.enqueue_job('foobar', _job_id='testing') - assert sorted(await arq_redis.keys('*')) == [b'arq:job:testing', b'arq:queue'] + assert sorted(await arq_redis.keys('*')) == [b'arq:job:testing', b'arq:message-id:testing', b'arq:queue:stream'] worker: Worker = worker(functions=[func(foobar, keep_result=0)]) await worker.main() - assert sorted(await arq_redis.keys('*')) == [b'arq:queue:health-check'] + assert sorted(await arq_redis.keys('*')) == [b'arq:queue:health-check', b'arq:queue:stream'] async def test_remain_keys_keep_results_forever_in_function(arq_redis: ArqRedis, worker): await arq_redis.enqueue_job('foobar', _job_id='testing') - assert sorted(await arq_redis.keys('*')) == [b'arq:job:testing', b'arq:queue'] + assert sorted(await arq_redis.keys('*')) == [b'arq:job:testing', b'arq:message-id:testing', b'arq:queue:stream'] worker: Worker = worker(functions=[func(foobar, keep_result_forever=True)]) await worker.main() - assert sorted(await arq_redis.keys('*')) == [b'arq:queue:health-check', b'arq:result:testing'] + assert sorted(await arq_redis.keys('*')) == [b'arq:queue:health-check', b'arq:queue:stream', b'arq:result:testing'] ttl_result = await arq_redis.ttl('arq:result:testing') assert ttl_result == -1 async def test_remain_keys_keep_results_forever(arq_redis: ArqRedis, worker): await arq_redis.enqueue_job('foobar', _job_id='testing') - assert 
sorted(await arq_redis.keys('*')) == [b'arq:job:testing', b'arq:queue'] + assert sorted(await arq_redis.keys('*')) == [b'arq:job:testing', b'arq:message-id:testing', b'arq:queue:stream'] worker: Worker = worker(functions=[func(foobar)], keep_result_forever=True) await worker.main() - assert sorted(await arq_redis.keys('*')) == [b'arq:queue:health-check', b'arq:result:testing'] + assert sorted(await arq_redis.keys('*')) == [b'arq:queue:health-check', b'arq:queue:stream', b'arq:result:testing'] ttl_result = await arq_redis.ttl('arq:result:testing') assert ttl_result == -1 @@ -644,23 +647,24 @@ async def test_queue_read_limit_equals_max_jobs(arq_redis: ArqRedis, worker): for _ in range(4): await arq_redis.enqueue_job('foobar') - assert await arq_redis.zcard(default_queue_name) == 4 + assert await arq_redis.get_queue_size(default_queue_name) == 4 worker: Worker = worker(functions=[foobar], queue_read_limit=2) assert worker.queue_read_limit == 2 assert worker.jobs_complete == 0 assert worker.jobs_failed == 0 assert worker.jobs_retried == 0 + await worker.create_consumer_group() await worker._poll_iteration() await asyncio.sleep(0.1) - assert await arq_redis.zcard(default_queue_name) == 2 + assert await arq_redis.get_queue_size(default_queue_name) == 2 assert worker.jobs_complete == 2 assert worker.jobs_failed == 0 assert worker.jobs_retried == 0 await worker._poll_iteration() await asyncio.sleep(0.1) - assert await arq_redis.zcard(default_queue_name) == 0 + assert await arq_redis.get_queue_size(default_queue_name) == 0 assert worker.jobs_complete == 4 assert worker.jobs_failed == 0 assert worker.jobs_retried == 0 @@ -677,15 +681,16 @@ async def test_custom_queue_read_limit(arq_redis: ArqRedis, worker): for _ in range(4): await arq_redis.enqueue_job('foobar') - assert await arq_redis.zcard(default_queue_name) == 4 + assert await arq_redis.get_queue_size(default_queue_name) == 4 worker: Worker = worker(functions=[foobar], max_jobs=4, queue_read_limit=2) assert worker.jobs_complete == 0 assert worker.jobs_failed == 0 assert worker.jobs_retried == 0 + await worker.create_consumer_group() await worker._poll_iteration() await asyncio.sleep(0.1) - assert await arq_redis.zcard(default_queue_name) == 2 + assert await arq_redis.get_queue_size(default_queue_name) == 2 assert worker.jobs_complete == 2 assert worker.jobs_failed == 0 assert worker.jobs_retried == 0 @@ -846,7 +851,20 @@ async def foo(ctx, v): caplog.set_level(logging.DEBUG, logger='arq.worker') await arq_redis.enqueue_job('foo', 1, _job_id='testing') worker: Worker = worker(functions=[func(foo, name='foo')]) - await asyncio.gather(*[worker.start_jobs([b'testing']) for _ in range(5)]) + await asyncio.gather( + *[ + worker.start_jobs( + [ + StreamMessage( + job_id='testing', + message_id='1', + score=1, + ) + ] + ) + for _ in range(5) + ] + ) # debug(caplog.text) await worker.main() assert c == 1 From 922ed2c3b2666b9980d139b7ccfda12b9992d73e Mon Sep 17 00:00:00 2001 From: Nikita Zavadin Date: Sat, 14 Dec 2024 21:59:05 +0100 Subject: [PATCH 02/12] fix autoclaim --- arq/worker.py | 57 +++++++++++++++++++++++++++++---------------------- 1 file changed, 33 insertions(+), 24 deletions(-) diff --git a/arq/worker.py b/arq/worker.py index f6a1c708..28963f00 100644 --- a/arq/worker.py +++ b/arq/worker.py @@ -232,6 +232,7 @@ def __init__( expires_extra_ms: int = expires_extra_ms, timezone: Optional[timezone] = None, log_results: bool = True, + max_consumer_inactivity: 'SecondsTimedelta' = 86400, ): self.functions: Dict[str, Union[Function, CronJob]] = 
{f.name: f for f in map(func, functions)} if queue_name is None: @@ -312,6 +313,7 @@ def __init__( self.job_deserializer = job_deserializer self.expires_extra_ms = expires_extra_ms self.log_results = log_results + self.max_consumer_inactivity = max_consumer_inactivity # default to system timezone self.timezone = datetime.now().astimezone().tzinfo if timezone is None else timezone @@ -377,16 +379,13 @@ async def main(self) -> None: await self.create_consumer_group() self.poller_task = asyncio.create_task(self.run_delayed_queue_poller()) - self.autoclaimer_task = asyncio.create_task(self.run_dead_task_autoclaimer()) try: await self.run_stream_reader() finally: self.poller_task.cancel() - self.autoclaimer_task.cancel() await asyncio.gather( self.poller_task, - self.autoclaimer_task, return_exceptions=True, ) @@ -403,20 +402,6 @@ async def run_stream_reader(self) -> None: await asyncio.gather(*self.tasks.values()) return None - async def run_dead_task_autoclaimer(self) -> None: - async for _ in poll(self.poll_delay_s): - consumers_info = await self.pool.xinfo_consumers( - self.queue_name + stream_key_suffix, - groupname=self.consumer_group_name, - ) - for consumer_info in consumers_info: - await self.pool.xautoclaim( - self.queue_name + stream_key_suffix, - groupname=self.consumer_group_name, - consumername=consumer_info['name'], - min_idle_time=int(self.in_progress_timeout_s * 1000), - ) - async def run_delayed_queue_poller(self) -> None: publish_delayed_job = self.pool.register_script(publish_delayed_job_lua) @@ -466,13 +451,17 @@ async def _poll_iteration(self) -> None: count = min(burst_jobs_remaining, count) if self.allow_pick_jobs: if self.job_counter < self.max_jobs: - stream_msgs = await self.pool.xreadgroup( - groupname=self.consumer_group_name, - consumername=self.worker_id, - streams={self.queue_name + stream_key_suffix: '>'}, - count=count, - block=int(max(self.poll_delay_s * 1000, 1)), - ) + stream_msgs = await self._get_idle_tasks(count) + + if not stream_msgs: + stream_msgs = await self.pool.xreadgroup( + groupname=self.consumer_group_name, + consumername=self.worker_id, + streams={self.queue_name + stream_key_suffix: '>'}, + count=count, + block=int(max(self.poll_delay_s * 1000, 1)), + ) + jobs = [] for _, msgs in stream_msgs: @@ -496,6 +485,26 @@ async def _poll_iteration(self) -> None: await self.heart_beat() + async def _get_idle_tasks(self, count: int) -> list[tuple[bytes, list]]: + resp = await self.pool.xautoclaim( + self.queue_name + stream_key_suffix, + groupname=self.consumer_group_name, + consumername=self.worker_id, + min_idle_time=self.in_progress_timeout_s * 1000, + count=count, + ) + + if not resp: + return [] + + _, msgs, __ = resp + if not msgs: + return [] + + # cast to the same format as the xreadgroup response + return [((self.queue_name + stream_key_suffix).encode(), msgs)] + + async def _cancel_aborted_jobs(self) -> None: """ Go through job_ids in the abort_jobs_ss sorted set and cancel those tasks. 
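PATCH 02 above drops the background autoclaimer task and instead calls XAUTOCLAIM inline (in _get_idle_tasks) before XREADGROUP, so a worker first reclaims messages whose original consumer has gone quiet and only then reads new ones. The two commands return differently shaped replies in redis-py, which is why the claimed messages are re-wrapped before being handed to start_jobs. Below is a minimal standalone sketch of that pattern, not arq's actual code: it assumes a recent redis-py (>= 5) against Redis >= 6.2, and the stream/group/consumer names and the 5-minute idle threshold are placeholders.

# Illustrative sketch of the reclaim-then-read pattern -- not arq's code.
import asyncio

from redis.asyncio import Redis


async def read_with_reclaim(redis: Redis, stream: str, group: str, consumer: str, count: int = 10) -> list:
    # XAUTOCLAIM replies with [next_start_id, [(message_id, {field: value}), ...]]
    # (plus a third element of deleted ids on Redis 7+), so the claimed messages
    # are re-wrapped to match the XREADGROUP reply shape used below.
    reclaimed = await redis.xautoclaim(
        stream, groupname=group, consumername=consumer, min_idle_time=300_000, count=count
    )
    messages = [(stream.encode(), reclaimed[1])] if reclaimed and reclaimed[1] else []

    remaining = count - sum(len(msgs) for _, msgs in messages)
    if remaining > 0:
        # XREADGROUP replies with [[stream_name, [(message_id, {field: value}), ...]], ...]
        messages.extend(
            await redis.xreadgroup(
                groupname=group,
                consumername=consumer,
                streams={stream: '>'},
                count=remaining,
                block=500,
            )
        )
    return messages


async def main() -> None:
    redis = Redis()  # assumes a local Redis
    for _, msgs in await read_with_reclaim(redis, 'arq:queue:stream', 'arq:workers', 'worker-1'):
        for message_id, fields in msgs:
            print(message_id, fields)  # each arq job message carries b'job_id' and b'score'
    await redis.aclose()


if __name__ == '__main__':
    asyncio.run(main())
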
From 24659e851a9b5499a0a3cb37eb393aad8ad82841 Mon Sep 17 00:00:00 2001 From: Nikita Zavadin Date: Sat, 14 Dec 2024 22:01:09 +0100 Subject: [PATCH 03/12] fix --- arq/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arq/worker.py b/arq/worker.py index 28963f00..6bdff0c4 100644 --- a/arq/worker.py +++ b/arq/worker.py @@ -490,7 +490,7 @@ async def _get_idle_tasks(self, count: int) -> list[tuple[bytes, list]]: self.queue_name + stream_key_suffix, groupname=self.consumer_group_name, consumername=self.worker_id, - min_idle_time=self.in_progress_timeout_s * 1000, + min_idle_time=int(self.in_progress_timeout_s * 1000), count=count, ) From 5fd75f1370dcfc9f08570710df07c5f96674b6e2 Mon Sep 17 00:00:00 2001 From: Nikita Zavadin Date: Sat, 14 Dec 2024 22:22:54 +0100 Subject: [PATCH 04/12] fix tests --- arq/worker.py | 39 ++++++++++++++++++++------------------- tests/test_worker.py | 38 ++------------------------------------ 2 files changed, 22 insertions(+), 55 deletions(-) diff --git a/arq/worker.py b/arq/worker.py index 6bdff0c4..42e423eb 100644 --- a/arq/worker.py +++ b/arq/worker.py @@ -64,7 +64,7 @@ class Function: @dataclass -class StreamMessage: +class JobMetaInfo: message_id: str job_id: str score: int @@ -314,6 +314,7 @@ def __init__( self.expires_extra_ms = expires_extra_ms self.log_results = log_results self.max_consumer_inactivity = max_consumer_inactivity + self.dlq_job_poller_task: asyncio.Task[None] = None # default to system timezone self.timezone = datetime.now().astimezone().tzinfo if timezone is None else timezone @@ -378,16 +379,18 @@ async def main(self) -> None: await self.create_consumer_group() - self.poller_task = asyncio.create_task(self.run_delayed_queue_poller()) + _, pending = await asyncio.wait( + [ + asyncio.ensure_future(self.run_delayed_queue_poller()), + asyncio.ensure_future(self.run_stream_reader()), + ], + return_when=asyncio.FIRST_COMPLETED, + ) - try: - await self.run_stream_reader() - finally: - self.poller_task.cancel() - await asyncio.gather( - self.poller_task, - return_exceptions=True, - ) + for task in pending: + task.cancel() + + await asyncio.gather(*pending, return_exceptions=True) async def run_stream_reader(self) -> None: while True: @@ -467,8 +470,10 @@ async def _poll_iteration(self) -> None: for _, msgs in stream_msgs: for msg_id, job in msgs: jobs.append( - StreamMessage( - message_id=msg_id.decode(), job_id=job[b'job_id'].decode(), score=int(job[b'score']) + JobMetaInfo( + message_id=msg_id.decode(), + job_id=job[b'job_id'].decode(), + score=int(job[b'score']), ) ) @@ -504,7 +509,6 @@ async def _get_idle_tasks(self, count: int) -> list[tuple[bytes, list]]: # cast to the same format as the xreadgroup response return [((self.queue_name + stream_key_suffix).encode(), msgs)] - async def _cancel_aborted_jobs(self) -> None: """ Go through job_ids in the abort_jobs_ss sorted set and cancel those tasks. 
@@ -533,7 +537,7 @@ def _release_sem_dec_counter_on_complete(self) -> None: self.job_counter = self.job_counter - 1 self.sem.release() - async def start_jobs(self, jobs: list[StreamMessage]) -> None: + async def start_jobs(self, jobs: list[JobMetaInfo]) -> None: """ For each job id, get the job definition, check it's not running and start it in a task """ @@ -555,7 +559,7 @@ async def start_jobs(self, jobs: list[StreamMessage]) -> None: ongoing_exists = await pipe.exists(in_progress_key) if ongoing_exists: - await self._redeliver_job(job) + await self._unclaim_job(job) self.job_counter = self.job_counter - 1 self.sem.release() logger.debug('job %s already running elsewhere', job_id) @@ -575,7 +579,7 @@ async def start_jobs(self, jobs: list[StreamMessage]) -> None: t.add_done_callback(lambda _: self._release_sem_dec_counter_on_complete()) self.tasks[job_id] = t - async def _redeliver_job(self, job: StreamMessage) -> None: + async def _unclaim_job(self, job: JobMetaInfo) -> None: async with self.pool.pipeline(transaction=True) as pipe: stream_key = self.queue_name + stream_key_suffix job_message_id_key = job_message_id_prefix + job.job_id @@ -873,9 +877,6 @@ async def finish_failed_job(self, job_id: str, message_id: str, result_data: Opt await tr.execute() async def heart_beat(self) -> None: - if self.poller_task.done(): - raise self.poller_task.exception() - now = datetime.now(tz=self.timezone) await self.record_health() diff --git a/tests/test_worker.py b/tests/test_worker.py index 9c23fde6..85bec5f4 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -17,9 +17,9 @@ from arq.worker import ( FailedJobs, JobExecutionFailed, + JobMetaInfo, Retry, RetryJob, - StreamMessage, Worker, async_check_health, check_health, @@ -161,40 +161,6 @@ async def test_job_successful(arq_redis: ArqRedis, worker, caplog): assert 'X.XXs → testing:foobar()\n X.XXs ← testing:foobar ● 42' in log -async def test_job_retry_race_condition(arq_redis: ArqRedis, worker): - async def retry_job(ctx): - if ctx['job_try'] == 1: - raise Retry(defer=10) - - job_id = 'testing' - await arq_redis.enqueue_job('retry_job', _job_id=job_id) - - worker_one: Worker = worker(functions=[func(retry_job, name='retry_job')]) - worker_two: Worker = worker(functions=[func(retry_job, name='retry_job')]) - - assert worker_one.jobs_complete == 0 - assert worker_one.jobs_failed == 0 - assert worker_one.jobs_retried == 0 - - assert worker_two.jobs_complete == 0 - assert worker_two.jobs_failed == 0 - assert worker_two.jobs_retried == 0 - - await worker_one.start_jobs([job_id.encode()]) - await asyncio.gather(*worker_one.tasks.values()) - - await worker_two.start_jobs([job_id.encode()]) - await asyncio.gather(*worker_two.tasks.values()) - - assert worker_one.jobs_complete == 0 - assert worker_one.jobs_failed == 0 - assert worker_one.jobs_retried == 1 - - assert worker_two.jobs_complete == 0 - assert worker_two.jobs_failed == 0 - assert worker_two.jobs_retried == 0 - - async def test_job_successful_no_result_logging(arq_redis: ArqRedis, worker, caplog): caplog.set_level(logging.INFO) await arq_redis.enqueue_job('foobar', _job_id='testing') @@ -855,7 +821,7 @@ async def foo(ctx, v): *[ worker.start_jobs( [ - StreamMessage( + JobMetaInfo( job_id='testing', message_id='1', score=1, From 186bd5e1da74d124046c67a2eaf8ed102d593f57 Mon Sep 17 00:00:00 2001 From: Nikita Zavadin Date: Sat, 14 Dec 2024 22:44:26 +0100 Subject: [PATCH 05/12] add lua tests --- tests/test_lua.py | 122 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 122 
insertions(+) create mode 100644 tests/test_lua.py diff --git a/tests/test_lua.py b/tests/test_lua.py new file mode 100644 index 00000000..53e90e14 --- /dev/null +++ b/tests/test_lua.py @@ -0,0 +1,122 @@ +import pytest +from redis.commands.core import AsyncScript + +from arq import ArqRedis +from arq.lua import get_job_from_stream_lua, publish_delayed_job_lua, publish_job_lua + + +@pytest.fixture() +def publish_delayed_job(arq_redis: ArqRedis) -> AsyncScript: + return arq_redis.register_script(publish_delayed_job_lua) + + +@pytest.fixture() +def publish_job(arq_redis: ArqRedis) -> AsyncScript: + return arq_redis.register_script(publish_job_lua) + + +@pytest.fixture() +def get_job_from_stream(arq_redis: ArqRedis) -> AsyncScript: + return arq_redis.register_script(get_job_from_stream_lua) + + +async def test_publish_delayed_job(arq_redis: ArqRedis, publish_delayed_job: AsyncScript) -> None: + await arq_redis.zadd('delayed_queue_key', {'job_id': 1000}) + await publish_delayed_job( + keys=[ + 'delayed_queue_key', + 'stream_key', + 'job_message_id_key', + ], + args=[ + 'job_id', + '1000', + ], + ) + + stream_msgs = await arq_redis.xrange('stream_key', '-', '+') + assert len(stream_msgs) == 1 + + saved_msg_id = await arq_redis.get('job_message_id_key') + + msg_id, msg = stream_msgs[0] + assert msg == {b'job_id': b'job_id', b'score': b'1000'} + assert saved_msg_id == msg_id + + assert await arq_redis.zrange('delayed_queue_key', '-inf', '+inf', byscore=True) == [] + + await publish_delayed_job( + keys=[ + 'delayed_queue_key', + 'stream_key', + 'job_message_id_key', + ], + args=[ + 'job_id', + '1000', + ], + ) + + stream_msgs = await arq_redis.xrange('stream_key', '-', '+') + assert len(stream_msgs) == 1 + + saved_msg_id = await arq_redis.get('job_message_id_key') + assert saved_msg_id == msg_id + + +async def test_publish_job(arq_redis: ArqRedis, publish_job: AsyncScript) -> None: + msg_id = await publish_job( + keys=[ + 'stream_key', + 'job_message_id_key', + ], + args=[ + 'job_id', + '1000', + '1000', + ], + ) + + stream_msgs = await arq_redis.xrange('stream_key', '-', '+') + assert len(stream_msgs) == 1 + + saved_msg_id = await arq_redis.get('job_message_id_key') + assert saved_msg_id == msg_id + + msg_id, msg = stream_msgs[0] + assert msg == {b'job_id': b'job_id', b'score': b'1000'} + assert saved_msg_id == msg_id + + +async def test_get_job_from_stream( + arq_redis: ArqRedis, publish_job: AsyncScript, get_job_from_stream: AsyncScript +) -> None: + msg_id = await publish_job( + keys=[ + 'stream_key', + 'job_message_id_key', + ], + args=[ + 'job_id', + '1000', + '1000', + ], + ) + + job = await get_job_from_stream( + keys=[ + 'stream_key', + 'job_message_id_key', + ], + ) + + assert job == [msg_id, [b'job_id', b'job_id', b'score', b'1000']] + + await arq_redis.delete('job_message_id_key') + job = await get_job_from_stream( + keys=[ + 'stream_key', + 'job_message_id_key', + ], + ) + assert job is None From 4030fed560b1e052db0da172d771e25df260a3c3 Mon Sep 17 00:00:00 2001 From: Nikita Zavadin Date: Sat, 14 Dec 2024 22:55:29 +0100 Subject: [PATCH 06/12] add idle_consumer_cleanup --- arq/worker.py | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/arq/worker.py b/arq/worker.py index 42e423eb..235af43a 100644 --- a/arq/worker.py +++ b/arq/worker.py @@ -233,6 +233,7 @@ def __init__( timezone: Optional[timezone] = None, log_results: bool = True, max_consumer_inactivity: 'SecondsTimedelta' = 86400, + idle_consumer_poll_interval: 'SecondsTimedelta' = 60, ): 
self.functions: Dict[str, Union[Function, CronJob]] = {f.name: f for f in map(func, functions)} if queue_name is None: @@ -314,7 +315,7 @@ def __init__( self.expires_extra_ms = expires_extra_ms self.log_results = log_results self.max_consumer_inactivity = max_consumer_inactivity - self.dlq_job_poller_task: asyncio.Task[None] = None + self.idle_consumer_poll_interval = idle_consumer_poll_interval # default to system timezone self.timezone = datetime.now().astimezone().tzinfo if timezone is None else timezone @@ -379,10 +380,11 @@ async def main(self) -> None: await self.create_consumer_group() - _, pending = await asyncio.wait( + done, pending = await asyncio.wait( [ asyncio.ensure_future(self.run_delayed_queue_poller()), asyncio.ensure_future(self.run_stream_reader()), + asyncio.ensure_future(self.run_idle_consumer_cleanup()), ], return_when=asyncio.FIRST_COMPLETED, ) @@ -392,6 +394,9 @@ async def main(self) -> None: await asyncio.gather(*pending, return_exceptions=True) + for task in done: + task.result() + async def run_stream_reader(self) -> None: while True: await self._poll_iteration() @@ -432,6 +437,27 @@ async def run_delayed_queue_poller(self) -> None: args=[job_id.decode(), expire_ms], ) + async def run_idle_consumer_cleanup(self) -> None: + async for _ in poll(self.idle_consumer_poll_interval): + consumers_info = await self.pool.xinfo_consumers( + self.queue_name + stream_key_suffix, + groupname=self.consumer_group_name, + ) + + for consumer_info in consumers_info: + if self.worker_id == consumer_info['name'].decode(): + continue + + idle = timedelta(milliseconds=consumer_info['idle']).seconds + pending = consumer_info['pending'] + + if pending == 0 and idle > self.max_consumer_inactivity: + await self.pool.xgroup_delconsumer( + name=self.queue_name + stream_key_suffix, + groupname=self.consumer_group_name, + consumername=consumer_info['name'], + ) + async def create_consumer_group(self) -> None: with suppress(ResponseError): await self.pool.xgroup_create( From 98d880d4f26d36f7f1b1c0e08f4ac6f09616d5da Mon Sep 17 00:00:00 2001 From: Nikita Zavadin Date: Sun, 15 Dec 2024 12:00:12 +0100 Subject: [PATCH 07/12] small refactoring --- arq/worker.py | 19 +++++++++++-------- tests/conftest.py | 18 ++++++++++++++++-- tests/test_worker.py | 18 +++++++++--------- 3 files changed, 36 insertions(+), 19 deletions(-) diff --git a/arq/worker.py b/arq/worker.py index 235af43a..2f592781 100644 --- a/arq/worker.py +++ b/arq/worker.py @@ -219,6 +219,7 @@ def __init__( keep_result: 'SecondsTimedelta' = 3600, keep_result_forever: bool = False, poll_delay: 'SecondsTimedelta' = 0.5, + stream_block: 'SecondsTimedelta' = 0.5, queue_read_limit: Optional[int] = None, max_tries: int = 5, health_check_interval: 'SecondsTimedelta' = 3600, @@ -267,6 +268,9 @@ def __init__( self.keep_result_s = to_seconds(keep_result) self.keep_result_forever = keep_result_forever self.poll_delay_s = to_seconds(poll_delay) + self.stream_block_s = to_seconds(stream_block) + self.max_consumer_inactivity_s = to_seconds(max_consumer_inactivity) + self.idle_consumer_poll_interval_s = to_seconds(idle_consumer_poll_interval) self.queue_read_limit = queue_read_limit or max(max_jobs * 5, 100) self._queue_read_offset = 0 self.max_tries = max_tries @@ -314,8 +318,6 @@ def __init__( self.job_deserializer = job_deserializer self.expires_extra_ms = expires_extra_ms self.log_results = log_results - self.max_consumer_inactivity = max_consumer_inactivity - self.idle_consumer_poll_interval = idle_consumer_poll_interval # default to system timezone 
self.timezone = datetime.now().astimezone().tzinfo if timezone is None else timezone @@ -399,7 +401,7 @@ async def main(self) -> None: async def run_stream_reader(self) -> None: while True: - await self._poll_iteration() + await self._read_stream_iteration() if self.burst: if 0 <= self.max_burst_jobs <= self._jobs_started(): @@ -438,7 +440,7 @@ async def run_delayed_queue_poller(self) -> None: ) async def run_idle_consumer_cleanup(self) -> None: - async for _ in poll(self.idle_consumer_poll_interval): + async for _ in poll(self.idle_consumer_poll_interval_s): consumers_info = await self.pool.xinfo_consumers( self.queue_name + stream_key_suffix, groupname=self.consumer_group_name, @@ -451,7 +453,7 @@ async def run_idle_consumer_cleanup(self) -> None: idle = timedelta(milliseconds=consumer_info['idle']).seconds pending = consumer_info['pending'] - if pending == 0 and idle > self.max_consumer_inactivity: + if pending == 0 and idle > self.max_consumer_inactivity_s: await self.pool.xgroup_delconsumer( name=self.queue_name + stream_key_suffix, groupname=self.consumer_group_name, @@ -467,7 +469,7 @@ async def create_consumer_group(self) -> None: mkstream=True, ) - async def _poll_iteration(self) -> None: + async def _read_stream_iteration(self) -> None: """ Get ids of pending jobs from the main queue sorted-set data structure and start those jobs, remove any finished tasks from self.tasks. @@ -481,14 +483,15 @@ async def _poll_iteration(self) -> None: if self.allow_pick_jobs: if self.job_counter < self.max_jobs: stream_msgs = await self._get_idle_tasks(count) + count = count - len(stream_msgs) - if not stream_msgs: + if count > 0: stream_msgs = await self.pool.xreadgroup( groupname=self.consumer_group_name, consumername=self.worker_id, streams={self.queue_name + stream_key_suffix: '>'}, count=count, - block=int(max(self.poll_delay_s * 1000, 1)), + block=int(max(self.stream_block_s * 1000, 1)), ) jobs = [] diff --git a/tests/conftest.py b/tests/conftest.py index 9b6b7f5b..7362bdea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -94,10 +94,24 @@ async def arq_redis_retry(test_redis_host: str, test_redis_port: int): async def worker(arq_redis): worker_: Worker = None - def create(functions=[], burst=True, poll_delay=0, max_jobs=10, arq_redis=arq_redis, **kwargs): + def create( + functions=[], + burst=True, + poll_delay=0, + stream_block=0, + max_jobs=10, + arq_redis=arq_redis, + **kwargs, + ): nonlocal worker_ worker_ = Worker( - functions=functions, redis_pool=arq_redis, burst=burst, poll_delay=poll_delay, max_jobs=max_jobs, **kwargs + functions=functions, + redis_pool=arq_redis, + burst=burst, + poll_delay=poll_delay, + max_jobs=max_jobs, + stream_block=stream_block, + **kwargs ) return worker_ diff --git a/tests/test_worker.py b/tests/test_worker.py index 85bec5f4..cbc3102e 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -621,14 +621,14 @@ async def test_queue_read_limit_equals_max_jobs(arq_redis: ArqRedis, worker): assert worker.jobs_retried == 0 await worker.create_consumer_group() - await worker._poll_iteration() + await worker._read_stream_iteration() await asyncio.sleep(0.1) assert await arq_redis.get_queue_size(default_queue_name) == 2 assert worker.jobs_complete == 2 assert worker.jobs_failed == 0 assert worker.jobs_retried == 0 - await worker._poll_iteration() + await worker._read_stream_iteration() await asyncio.sleep(0.1) assert await arq_redis.get_queue_size(default_queue_name) == 0 assert worker.jobs_complete == 4 @@ -654,14 +654,14 @@ async def 
test_custom_queue_read_limit(arq_redis: ArqRedis, worker): assert worker.jobs_retried == 0 await worker.create_consumer_group() - await worker._poll_iteration() + await worker._read_stream_iteration() await asyncio.sleep(0.1) assert await arq_redis.get_queue_size(default_queue_name) == 2 assert worker.jobs_complete == 2 assert worker.jobs_failed == 0 assert worker.jobs_retried == 0 - await worker._poll_iteration() + await worker._read_stream_iteration() await asyncio.sleep(0.1) assert await arq_redis.zcard(default_queue_name) == 0 assert worker.jobs_complete == 4 @@ -785,7 +785,7 @@ async def foo(ctx, v): worker.max_burst_jobs = 0 assert len(worker.tasks) == 0 - await worker._poll_iteration() + await worker._read_stream_iteration() assert len(worker.tasks) == 0 @@ -1062,7 +1062,7 @@ async def test_worker_retry(mocker, worker_retry, exception_thrown): # baseline await worker.main() - await worker._poll_iteration() + await worker._read_stream_iteration() # spy method handling call_with_retry failure spy = mocker.spy(worker.pool, '_disconnect_raise') @@ -1073,7 +1073,7 @@ async def test_worker_retry(mocker, worker_retry, exception_thrown): # assert exception thrown with pytest.raises(type(exception_thrown)): - await worker._poll_iteration() + await worker._read_stream_iteration() # assert retry counts and no exception thrown during '_disconnect_raise' assert spy.call_count == 4 # retries setting + 1 @@ -1100,7 +1100,7 @@ async def test_worker_crash(mocker, worker, exception_thrown): # baseline await worker.main() - await worker._poll_iteration() + await worker._read_stream_iteration() # spy method handling call_with_retry failure spy = mocker.spy(worker.pool, '_disconnect_raise') @@ -1111,7 +1111,7 @@ async def test_worker_crash(mocker, worker, exception_thrown): # assert exception thrown with pytest.raises(type(exception_thrown)): - await worker._poll_iteration() + await worker._read_stream_iteration() # assert no retry counts and exception thrown during '_disconnect_raise' assert spy.call_count == 1 From 5b3f256d47ad4414e5ac46c358df2af58adfd441 Mon Sep 17 00:00:00 2001 From: Nikita Zavadin Date: Sun, 15 Dec 2024 12:15:14 +0100 Subject: [PATCH 08/12] use pipe in delayed_queue_poller --- arq/worker.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/arq/worker.py b/arq/worker.py index 2f592781..9a8b622e 100644 --- a/arq/worker.py +++ b/arq/worker.py @@ -425,19 +425,23 @@ async def run_delayed_queue_poller(self) -> None: withscores=True, byscore=True, ) - for job_id, score in job_ids: - expire_ms = int(score - timestamp_ms() + self.expires_extra_ms) - if expire_ms <= 0: - expire_ms = self.expires_extra_ms - - await publish_delayed_job( - keys=[ - self.queue_name, - self.queue_name + stream_key_suffix, - job_message_id_prefix + job_id.decode(), - ], - args=[job_id.decode(), expire_ms], - ) + async with self.pool.pipeline(transaction=False) as pipe: + for job_id, score in job_ids: + expire_ms = int(score - timestamp_ms() + self.expires_extra_ms) + if expire_ms <= 0: + expire_ms = self.expires_extra_ms + + await publish_delayed_job( + keys=[ + self.queue_name, + self.queue_name + stream_key_suffix, + job_message_id_prefix + job_id.decode(), + ], + args=[job_id.decode(), expire_ms], + client=pipe, + ) + + await pipe.execute() async def run_idle_consumer_cleanup(self) -> None: async for _ in poll(self.idle_consumer_poll_interval_s): From e8ea12ceda851f0db0fcf7cacbf345fc073b877b Mon Sep 17 00:00:00 2001 From: Nikita Zavadin Date: Sun, 15 Dec 2024 
12:18:22 +0100 Subject: [PATCH 09/12] fmt --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 7362bdea..b2123ed5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -111,7 +111,7 @@ def create( poll_delay=poll_delay, max_jobs=max_jobs, stream_block=stream_block, - **kwargs + **kwargs, ) return worker_ From d712e3d2d6013e4b12a649deb8101e5eb6d6309e Mon Sep 17 00:00:00 2001 From: Nikita Zavadin Date: Sun, 15 Dec 2024 14:23:31 +0100 Subject: [PATCH 10/12] optimize stream write --- arq/connections.py | 15 +++++++---- arq/jobs.py | 11 ++++---- arq/worker.py | 63 ++++++++++++++++++++++++++++------------------ 3 files changed, 55 insertions(+), 34 deletions(-) diff --git a/arq/connections.py b/arq/connections.py index c599f411..e2f2746d 100644 --- a/arq/connections.py +++ b/arq/connections.py @@ -123,7 +123,6 @@ def __init__( kwargs['connection_pool'] = pool_or_conn self.expires_extra_ms = expires_extra_ms super().__init__(**kwargs) - self.publish_to_stream_script = self.register_script(publish_job_lua) async def enqueue_job( self, @@ -194,10 +193,16 @@ async def enqueue_job( else: stream_key = _queue_name + stream_key_suffix job_message_id_key = job_message_id_prefix + job_id - await self.publish_to_stream_script( - keys=[stream_key, job_message_id_key], - args=[job_id, str(enqueue_time_ms), str(expires_ms)], - client=pipe, + pipe.eval( + publish_job_lua, + 2, + # keys + stream_key, + job_message_id_key, + # args + job_id, + str(enqueue_time_ms), + str(expires_ms), ) try: diff --git a/arq/jobs.py b/arq/jobs.py index fbd2bc1f..63dd3976 100644 --- a/arq/jobs.py +++ b/arq/jobs.py @@ -82,7 +82,7 @@ class Job: Holds data a reference to a job. """ - __slots__ = 'job_id', '_redis', '_queue_name', '_deserializer', '_get_job_from_stream_script' + __slots__ = 'job_id', '_redis', '_queue_name', '_deserializer' def __init__( self, @@ -95,7 +95,6 @@ def __init__( self._redis = redis self._queue_name = _queue_name self._deserializer = _deserializer - self._get_job_from_stream_script = redis.register_script(get_job_from_stream_lua) async def result( self, timeout: Optional[float] = None, *, poll_delay: float = 0.5, pole_delay: Optional[float] = None @@ -152,9 +151,11 @@ async def info(self) -> Optional[JobDef]: if info: async with self._redis.pipeline(transaction=True) as tr: tr.zscore(self._queue_name, self.job_id) - await self._get_job_from_stream_script( - keys=[self._queue_name + stream_key_suffix, job_message_id_prefix + self.job_id], - client=tr, + tr.eval( + get_job_from_stream_lua, + 2, + self._queue_name + stream_key_suffix, + job_message_id_prefix + self.job_id, ) delayed_score, job_info = await tr.execute() diff --git a/arq/worker.py b/arq/worker.py index 9a8b622e..542401f3 100644 --- a/arq/worker.py +++ b/arq/worker.py @@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union, cast from uuid import uuid4 +from redis.asyncio.client import Pipeline from redis.exceptions import ResponseError, WatchError from arq.cron import CronJob @@ -33,7 +34,7 @@ retry_key_prefix, stream_key_suffix, ) -from .lua import publish_delayed_job_lua +from .lua import publish_delayed_job_lua, publish_job_lua from .utils import ( args_to_string, import_string, @@ -592,7 +593,9 @@ async def start_jobs(self, jobs: list[JobMetaInfo]) -> None: ongoing_exists = await pipe.exists(in_progress_key) if ongoing_exists: - await self._unclaim_job(job) + await pipe.unwatch() + await 
self._unclaim_job(job, pipe) + await pipe.execute() self.job_counter = self.job_counter - 1 self.sem.release() logger.debug('job %s already running elsewhere', job_id) @@ -604,6 +607,9 @@ async def start_jobs(self, jobs: list[JobMetaInfo]) -> None: await pipe.execute() except (ResponseError, WatchError): # job already started elsewhere since we got 'existing' + pipe.multi() + await self._unclaim_job(job, pipe) + await pipe.execute() self.job_counter = self.job_counter - 1 self.sem.release() logger.debug('multi-exec error, job %s already started elsewhere', job_id) @@ -612,24 +618,27 @@ async def start_jobs(self, jobs: list[JobMetaInfo]) -> None: t.add_done_callback(lambda _: self._release_sem_dec_counter_on_complete()) self.tasks[job_id] = t - async def _unclaim_job(self, job: JobMetaInfo) -> None: - async with self.pool.pipeline(transaction=True) as pipe: - stream_key = self.queue_name + stream_key_suffix - job_message_id_key = job_message_id_prefix + job.job_id - - pipe.xack(stream_key, self.consumer_group_name, job.message_id) - pipe.xdel(stream_key, job.message_id) - job_message_id_expire = job.score - timestamp_ms() + self.expires_extra_ms - if job_message_id_expire <= 0: - job_message_id_expire = self.expires_extra_ms - - await self.pool.publish_to_stream_script( - keys=[stream_key, job_message_id_key], - args=[job.job_id, str(job.score), str(job_message_id_expire)], - client=pipe, - ) - - await pipe.execute() + async def _unclaim_job(self, job: JobMetaInfo, pipe: Pipeline) -> None: + stream_key = self.queue_name + stream_key_suffix + job_message_id_key = job_message_id_prefix + job.job_id + + pipe.xack(stream_key, self.consumer_group_name, job.message_id) + pipe.xdel(stream_key, job.message_id) + job_message_id_expire = job.score - timestamp_ms() + self.expires_extra_ms + if job_message_id_expire <= 0: + job_message_id_expire = self.expires_extra_ms + + pipe.eval( + publish_job_lua, + 2, + # keys + stream_key, + job_message_id_key, + # args + job.job_id, + str(job.score), + str(job_message_id_expire), + ) async def run_job(self, job_id: str, message_id: str, score: int) -> None: # noqa: C901 start_ms = timestamp_ms() @@ -877,10 +886,16 @@ async def finish_job( tr.zadd(self.queue_name, {job_id: score + incr_score}) else: job_message_id_expire = score - timestamp_ms() + self.expires_extra_ms - await self.pool.publish_to_stream_script( - keys=[stream_key, job_message_id_key], - args=[job_id, str(score), str(job_message_id_expire)], - client=tr, + tr.eval( + publish_job_lua, + 2, + # keys + stream_key, + job_message_id_key, + # args + job_id, + str(score), + str(job_message_id_expire), ) if delete_keys: tr.delete(*delete_keys) From 6a61863c26432214b98d15c4ec4180bab1400368 Mon Sep 17 00:00:00 2001 From: Nikita Zavadin Date: Sun, 15 Dec 2024 16:41:33 +0100 Subject: [PATCH 11/12] bug fix --- arq/worker.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/arq/worker.py b/arq/worker.py index 542401f3..880dd48d 100644 --- a/arq/worker.py +++ b/arq/worker.py @@ -488,15 +488,19 @@ async def _read_stream_iteration(self) -> None: if self.allow_pick_jobs: if self.job_counter < self.max_jobs: stream_msgs = await self._get_idle_tasks(count) - count = count - len(stream_msgs) + msgs_count = sum([len(msgs) for _, msgs in stream_msgs]) + + count -= msgs_count if count > 0: - stream_msgs = await self.pool.xreadgroup( - groupname=self.consumer_group_name, - consumername=self.worker_id, - streams={self.queue_name + stream_key_suffix: '>'}, - count=count, - 
block=int(max(self.stream_block_s * 1000, 1)), + stream_msgs.extend( + await self.pool.xreadgroup( + groupname=self.consumer_group_name, + consumername=self.worker_id, + streams={self.queue_name + stream_key_suffix: '>'}, + count=count, + block=int(max(self.stream_block_s * 1000, 1)), + ) ) jobs = [] From 8e76b586387fd7d559bb81c4fd490b87dc6947f2 Mon Sep 17 00:00:00 2001 From: Nikita Zavadin Date: Tue, 17 Dec 2024 16:52:28 +0100 Subject: [PATCH 12/12] replace batched from itertools --- arq/jobs.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/arq/jobs.py b/arq/jobs.py index 63dd3976..71e8eef3 100644 --- a/arq/jobs.py +++ b/arq/jobs.py @@ -5,7 +5,6 @@ from dataclasses import dataclass from datetime import datetime from enum import Enum -from itertools import batched from typing import Any, Callable, Dict, Optional, Tuple from redis.asyncio import Redis @@ -74,7 +73,7 @@ class JobResult(JobDef): def _list_to_dict(input_list: list[Any]) -> dict[Any, Any]: - return {key: value for key, value in batched(input_list, 2)} + return dict(zip(input_list[::2], input_list[1::2], strict=True)) class Job:
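
Taken together, the series changes the delivery path: enqueue_job without a delay now XADDs job_id and score to 'arq:queue:stream' and stores the stream message id under 'arq:message-id:<job_id>', while deferred jobs stay in the 'arq:queue' sorted set until the worker's delayed-queue poller promotes them onto the stream. The sketch below shows this branch from the caller's side; it is not a released arq API guarantee and assumes a local Redis, an initially empty queue, no worker running, and a made-up 'download_content' task name. get_queue_size() is the helper added in PATCH 01.

# Minimal caller-side usage sketch for this branch (assumptions noted above).
import asyncio
from datetime import timedelta

from arq import create_pool
from arq.connections import RedisSettings


async def main() -> None:
    pool = await create_pool(RedisSettings())

    # No delay: the job is XADDed to 'arq:queue:stream' and its stream message id
    # is stored under 'arq:message-id:<job_id>'.
    await pool.enqueue_job('download_content', 'https://example.com')

    # Deferred: the job id goes into the 'arq:queue' sorted set and is only moved
    # to the stream by the worker's delayed-queue poller once it is due.
    await pool.enqueue_job('download_content', 'https://example.com', _defer_by=timedelta(minutes=5))

    print(await pool.get_queue_size())                             # 2 -> stream + delayed
    print(await pool.get_queue_size(include_delayed_tasks=False))  # 1 -> stream only

    await pool.close()


if __name__ == '__main__':
    asyncio.run(main())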