Skip to content

Commit eda36f8

Browse files
RB387 authored and rossmacarthur committed
Redis Streams for immediate task delivery (python-arq#492)
1 parent 7a911f3 commit eda36f8

8 files changed

+571
-100
lines changed

arq/connections.py

+67-7
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,16 @@
1313
from redis.asyncio.sentinel import Sentinel
1414
from redis.exceptions import RedisError, WatchError
1515

16-
from .constants import default_queue_name, expires_extra_ms, job_key_prefix, result_key_prefix
16+
from .constants import (
17+
default_queue_name,
18+
expires_extra_ms,
19+
job_key_prefix,
20+
job_message_id_prefix,
21+
result_key_prefix,
22+
stream_key_suffix,
23+
)
1724
from .jobs import Deserializer, Job, JobDef, JobResult, Serializer, deserialize_job, serialize_job
25+
from .lua import publish_job_lua
1826
from .utils import timestamp_ms, to_ms, to_unix_ms
1927

2028
logger = logging.getLogger('arq.connections')
@@ -165,20 +173,63 @@ async def enqueue_job(
165173
elif defer_by_ms:
166174
score = enqueue_time_ms + defer_by_ms
167175
else:
168-
score = enqueue_time_ms
176+
score = None
169177

170-
expires_ms = expires_ms or score - enqueue_time_ms + self.expires_extra_ms
178+
expires_ms = expires_ms or (score or enqueue_time_ms) - enqueue_time_ms + self.expires_extra_ms
171179

172-
job = serialize_job(function, args, kwargs, _job_try, enqueue_time_ms, serializer=self.job_serializer)
180+
job = serialize_job(
181+
function,
182+
args,
183+
kwargs,
184+
_job_try,
185+
enqueue_time_ms,
186+
serializer=self.job_serializer,
187+
)
173188
pipe.multi()
174189
pipe.psetex(job_key, expires_ms, job)
175-
pipe.zadd(_queue_name, {job_id: score})
190+
191+
if score is not None:
192+
pipe.zadd(_queue_name, {job_id: score})
193+
else:
194+
stream_key = _queue_name + stream_key_suffix
195+
job_message_id_key = job_message_id_prefix + job_id
196+
pipe.eval(
197+
publish_job_lua,
198+
2,
199+
# keys
200+
stream_key,
201+
job_message_id_key,
202+
# args
203+
job_id,
204+
str(enqueue_time_ms),
205+
str(expires_ms),
206+
)
207+
176208
try:
177209
await pipe.execute()
178210
except WatchError:
179211
# job got enqueued since we checked 'job_exists'
180212
return None
181-
return Job(job_id, redis=self, _queue_name=_queue_name, _deserializer=self.job_deserializer)
213+
return Job(
214+
job_id,
215+
redis=self,
216+
_queue_name=_queue_name,
217+
_deserializer=self.job_deserializer,
218+
)
219+
220+
async def get_queue_size(self, queue_name: str | None = None, include_delayed_tasks: bool = True) -> int:
    """Return the number of jobs currently queued.

    The count is the length of the queue's stream (immediate tasks) plus,
    unless ``include_delayed_tasks`` is False, the number of entries in the
    sorted set ``queue_name`` (tasks scheduled for later delivery).
    """
    queue = self.default_queue_name if queue_name is None else queue_name

    # Issue both size queries in one transactional pipeline round-trip.
    async with self.pipeline(transaction=True) as pipe:
        pipe.xlen(queue + stream_key_suffix)
        pipe.zcount(queue, '-inf', '+inf')
        stream_size, delayed_size = await pipe.execute()

    if include_delayed_tasks:
        return stream_size + delayed_size
    return stream_size
182233

183234
async def _get_job_result(self, key: bytes) -> JobResult:
184235
job_id = key[len(result_key_prefix) :].decode()
@@ -213,7 +264,16 @@ async def queued_jobs(self, *, queue_name: Optional[str] = None) -> List[JobDef]
213264
"""
214265
if queue_name is None:
215266
queue_name = self.default_queue_name
216-
jobs = await self.zrange(queue_name, withscores=True, start=0, end=-1)
267+
268+
async with self.pipeline(transaction=True) as pipe:
269+
pipe.zrange(queue_name, withscores=True, start=0, end=-1)
270+
pipe.xrange(queue_name + stream_key_suffix, '-', '+')
271+
delayed_jobs, stream_jobs = await pipe.execute()
272+
273+
jobs = [
274+
*delayed_jobs,
275+
*[(j[b'job_id'], int(j[b'score'])) for _, j in stream_jobs],
276+
]
217277
return await asyncio.gather(*[self._get_job_def(job_id, int(score)) for job_id, score in jobs])
218278

219279

arq/constants.py

+3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
# arq/constants.py — shared Redis key names, prefixes and tunables.
# Reconstructed from the commit diff: the extraction interleaved diff
# line-number noise with the real content; this is the post-commit file.

# Default sorted-set queue name used when the caller does not pass one.
default_queue_name = 'arq:queue'
# Prefix for the serialized job payload key (arq:job:<job_id>).
job_key_prefix = 'arq:job:'
# Prefix for the in-progress marker key set while a worker runs the job.
in_progress_key_prefix = 'arq:in-progress:'
# Prefix for the key mapping a job id to its stream message id
# (set by the publish_* Lua scripts; read to locate the stream entry).
job_message_id_prefix = 'arq:message-id:'
# Prefix for the serialized result payload key.
result_key_prefix = 'arq:result:'
# Prefix for the per-job retry counter key.
retry_key_prefix = 'arq:retry:'
# Sorted set holding ids of jobs that have been asked to abort.
abort_jobs_ss = 'arq:abort'
# Suffix appended to the queue name to form the Redis Stream key
# used for immediate (non-deferred) task delivery.
stream_key_suffix = ':stream'
# Consumer-group name workers use when reading from the stream.
default_consumer_group = 'arq:workers'
# age of items in the abort_key sorted set after which they're deleted
abort_job_max_age = 60
# Suffix appended to the queue name for the worker health-check key.
health_check_key_suffix = ':health-check'

arq/jobs.py

+39-6
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,16 @@
99

1010
from redis.asyncio import Redis
1111

12-
from .constants import abort_jobs_ss, default_queue_name, in_progress_key_prefix, job_key_prefix, result_key_prefix
12+
from .constants import (
13+
abort_jobs_ss,
14+
default_queue_name,
15+
in_progress_key_prefix,
16+
job_key_prefix,
17+
job_message_id_prefix,
18+
result_key_prefix,
19+
stream_key_suffix,
20+
)
21+
from .lua import get_job_from_stream_lua
1322
from .utils import ms_to_datetime, poll, timestamp_ms
1423

1524
logger = logging.getLogger('arq.jobs')
@@ -63,6 +72,10 @@ class JobResult(JobDef):
6372
queue_name: str
6473

6574

75+
def _list_to_dict(input_list: list[Any]) -> dict[Any, Any]:
76+
return dict(zip(input_list[::2], input_list[1::2], strict=True))
77+
78+
6679
class Job:
6780
"""
6881
Holds data a reference to a job.
@@ -105,7 +118,8 @@ async def result(
105118
async with self._redis.pipeline(transaction=True) as tr:
106119
tr.get(result_key_prefix + self.job_id)
107120
tr.zscore(self._queue_name, self.job_id)
108-
v, s = await tr.execute()
121+
tr.get(job_message_id_prefix + self.job_id)
122+
v, s, m = await tr.execute()
109123

110124
if v:
111125
info = deserialize_result(v, deserializer=self._deserializer)
@@ -115,7 +129,7 @@ async def result(
115129
raise info.result
116130
else:
117131
raise SerializationError(info.result)
118-
elif s is None:
132+
elif s is None and m is None:
119133
raise ResultNotFound(
120134
'Not waiting for job result because the job is not in queue. '
121135
'Is the worker function configured to keep result?'
@@ -134,8 +148,24 @@ async def info(self) -> Optional[JobDef]:
134148
if v:
135149
info = deserialize_job(v, deserializer=self._deserializer)
136150
if info:
137-
s = await self._redis.zscore(self._queue_name, self.job_id)
138-
info.score = None if s is None else int(s)
151+
async with self._redis.pipeline(transaction=True) as tr:
152+
tr.zscore(self._queue_name, self.job_id)
153+
tr.eval(
154+
get_job_from_stream_lua,
155+
2,
156+
self._queue_name + stream_key_suffix,
157+
job_message_id_prefix + self.job_id,
158+
)
159+
delayed_score, job_info = await tr.execute()
160+
161+
if delayed_score:
162+
info.score = int(delayed_score)
163+
elif job_info:
164+
_, job_info_payload = job_info
165+
info.score = int(_list_to_dict(job_info_payload)[b'score'])
166+
else:
167+
info.score = None
168+
139169
return info
140170

141171
async def result_info(self) -> Optional[JobResult]:
@@ -157,12 +187,15 @@ async def status(self) -> JobStatus:
157187
tr.exists(result_key_prefix + self.job_id)
158188
tr.exists(in_progress_key_prefix + self.job_id)
159189
tr.zscore(self._queue_name, self.job_id)
160-
is_complete, is_in_progress, score = await tr.execute()
190+
tr.exists(job_message_id_prefix + self.job_id)
191+
is_complete, is_in_progress, score, queued = await tr.execute()
161192

162193
if is_complete:
163194
return JobStatus.complete
164195
elif is_in_progress:
165196
return JobStatus.in_progress
197+
elif queued:
198+
return JobStatus.queued
166199
elif score:
167200
return JobStatus.deferred if score > timestamp_ms() else JobStatus.queued
168201
else:

arq/lua.py

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# arq/lua.py — Lua scripts run server-side via EVAL so the
# stream XADD and the job-id -> message-id bookkeeping happen atomically.
# Reconstructed from the commit diff (extraction noise removed); script
# text preserved, only surrounding Python comments added.

# Move one delayed job from the sorted set into the stream.
# KEYS: delayed zset, stream, message-id key. ARGV: job_id, expiry (ms).
# Returns 1 if the job was published, 0 if it was no longer in the zset.
publish_delayed_job_lua = """
local delayed_queue_key = KEYS[1]
local stream_key = KEYS[2]
local job_message_id_key = KEYS[3]

local job_id = ARGV[1]
local job_message_id_expire_ms = ARGV[2]

local score = redis.call('zscore', delayed_queue_key, job_id)
if score == nil or score == false then
    return 0
end

local message_id = redis.call('xadd', stream_key, '*', 'job_id', job_id, 'score', score)
redis.call('set', job_message_id_key, message_id, 'px', job_message_id_expire_ms)
redis.call('zrem', delayed_queue_key, job_id)
return 1
"""

# Publish an immediate job straight to the stream and remember its
# stream message id under the message-id key (with a TTL).
# Returns the new stream message id.
publish_job_lua = """
local stream_key = KEYS[1]
local job_message_id_key = KEYS[2]

local job_id = ARGV[1]
local score = ARGV[2]
local job_message_id_expire_ms = ARGV[3]

local message_id = redis.call('xadd', stream_key, '*', 'job_id', job_id, 'score', score)
redis.call('set', job_message_id_key, message_id, 'px', job_message_id_expire_ms)
return message_id
"""

# Look up a job's stream entry via its stored message id.
# Returns the [message_id, fields] pair, or nil when the message-id key
# has expired.  NOTE(review): XRANGE returns an empty table, never nil,
# so the `job == nil` guard is dead code; `job[1]` is already nil for an
# empty result.  Kept as-is to match the committed behavior.
get_job_from_stream_lua = """
local stream_key = KEYS[1]
local job_message_id_key = KEYS[2]

local message_id = redis.call('get', job_message_id_key)
if message_id == false then
    return nil
end

local job = redis.call('xrange', stream_key, message_id, message_id)
if job == nil then
    return nil
end

return job[1]
"""

0 commit comments

Comments
 (0)