
Commit 77031fd

Yolley and samuelcolvin authored
support aioredis v2 (#259)
* support aioredis 2
* use psetex for float ttl
* health_check_interval + 1 sec
* fix health_check_interval psetex
* encoder context to encode/decode on per-request basis
* fix hiredis parser
* remove excessive comments
* update docstrings, pipe commands in run_job, gather in enqueue_job, expire in ms
* fix mypy, tweak tests
* using latest aioredis
* switch how latest aioredis is installed
* fix install
* fixing worker tests removing encoder_context usage
* removing usage of encoder_context
* removing parser and encoder_context
* tweak test_abort_job
* tests at last passing!
* switch to redis from aioredis
* cleanup

Co-authored-by: Samuel Colvin <[email protected]>
1 parent 763826a commit 77031fd

13 files changed (+180, -180 lines changed)

.gitignore (+1)

@@ -17,3 +17,4 @@ __pycache__/
 .vscode/
 .venv/
 /.auto-format
+/scratch/

arq/connections.py (+54, -58)
@@ -9,9 +9,10 @@
 from urllib.parse import urlparse
 from uuid import uuid4

-import aioredis
-from aioredis import MultiExecError, Redis
 from pydantic.validators import make_arbitrary_type_validator
+from redis.asyncio import ConnectionPool, Redis
+from redis.asyncio.sentinel import Sentinel
+from redis.exceptions import RedisError, WatchError

 from .constants import default_queue_name, job_key_prefix, result_key_prefix
 from .jobs import Deserializer, Job, JobDef, JobResult, Serializer, deserialize_job, serialize_job
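
The import swap above is the visible tip of the behavioural differences between aioredis v1 and redis-py's redis.asyncio that the rest of this diff adapts to. A rough, illustrative cheat sheet (not part of the commit; assumes redis-py >= 4.2):

    from redis.asyncio import Redis
    from redis.exceptions import WatchError

    async def api_differences(redis: Redis) -> None:
        # zadd takes a mapping of member -> score instead of positional score/member arguments
        await redis.zadd('some-zset', {'member': 123.0})
        # psetex expects the TTL in milliseconds, hence "expire in ms" in the commit message
        await redis.psetex('some-key', 5_000, b'value')
        # pipelines are async context managers; a broken WATCH raises WatchError, not MultiExecError
        async with redis.pipeline(transaction=True) as pipe:
            await pipe.watch('some-key')
            pipe.multi()
            pipe.set('some-key', b'new-value')
            try:
                await pipe.execute()
            except WatchError:
                pass  # another client touched 'some-key' after WATCH
        # GET returns bytes; there is no per-call encoding argument
        raw = await redis.get('some-key')
        assert raw is None or isinstance(raw, bytes)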
@@ -70,20 +71,20 @@ def __repr__(self) -> str:
 expires_extra_ms = 86_400_000


-class ArqRedis(Redis):  # type: ignore
+class ArqRedis(Redis):  # type: ignore[misc]
     """
-    Thin subclass of ``aioredis.Redis`` which adds :func:`arq.connections.enqueue_job`.
+    Thin subclass of ``redis.asyncio.Redis`` which adds :func:`arq.connections.enqueue_job`.

     :param redis_settings: an instance of ``arq.connections.RedisSettings``.
     :param job_serializer: a function that serializes Python objects to bytes, defaults to pickle.dumps
     :param job_deserializer: a function that deserializes bytes into Python objects, defaults to pickle.loads
     :param default_queue_name: the default queue name to use, defaults to ``arq.queue``.
-    :param kwargs: keyword arguments directly passed to ``aioredis.Redis``.
+    :param kwargs: keyword arguments directly passed to ``redis.asyncio.Redis``.
     """

     def __init__(
         self,
-        pool_or_conn: Any,
+        pool_or_conn: Optional[ConnectionPool] = None,
         job_serializer: Optional[Serializer] = None,
         job_deserializer: Optional[Deserializer] = None,
         default_queue_name: str = default_queue_name,
@@ -92,7 +93,9 @@ def __init__(
         self.job_serializer = job_serializer
         self.job_deserializer = job_deserializer
         self.default_queue_name = default_queue_name
-        super().__init__(pool_or_conn, **kwargs)
+        if pool_or_conn:
+            kwargs['connection_pool'] = pool_or_conn
+        super().__init__(**kwargs)

     async def enqueue_job(
         self,
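
Illustrative use of the new pool_or_conn keyword (a usage sketch, not taken from the diff): an existing redis.asyncio.ConnectionPool can be handed to ArqRedis, which forwards it as the standard connection_pool argument.

    # sketch only; assumes a Redis server reachable at the URL below
    from arq.connections import ArqRedis
    from redis.asyncio import ConnectionPool

    async def make_client() -> ArqRedis:
        pool = ConnectionPool.from_url('redis://localhost:6379/0')
        return ArqRedis(pool_or_conn=pool)  # equivalent to ArqRedis(connection_pool=pool)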
@@ -129,14 +132,10 @@ async def enqueue_job(
         defer_by_ms = to_ms(_defer_by)
         expires_ms = to_ms(_expires)

-        with await self as conn:
-            pipe = conn.pipeline()
-            pipe.unwatch()
-            pipe.watch(job_key)
-            job_exists = pipe.exists(job_key)
-            job_result_exists = pipe.exists(result_key_prefix + job_id)
-            await pipe.execute()
-            if await job_exists or await job_result_exists:
+        async with self.pipeline(transaction=True) as pipe:
+            await pipe.watch(job_key)
+            if any(await asyncio.gather(pipe.exists(job_key), pipe.exists(result_key_prefix + job_id))):
+                await pipe.reset()
                 return None

             enqueue_time_ms = timestamp_ms()
@@ -150,24 +149,22 @@
             expires_ms = expires_ms or score - enqueue_time_ms + expires_extra_ms

             job = serialize_job(function, args, kwargs, _job_try, enqueue_time_ms, serializer=self.job_serializer)
-            tr = conn.multi_exec()
-            tr.psetex(job_key, expires_ms, job)
-            tr.zadd(_queue_name, score, job_id)
+            pipe.multi()
+            pipe.psetex(job_key, expires_ms, job)
+            pipe.zadd(_queue_name, {job_id: score})
             try:
-                await tr.execute()
-            except MultiExecError:
+                await pipe.execute()
+            except WatchError:
                 # job got enqueued since we checked 'job_exists'
-                # https://github.com/samuelcolvin/arq/issues/131, avoid warnings in log
-                await asyncio.gather(*tr._results, return_exceptions=True)
                 return None
         return Job(job_id, redis=self, _queue_name=_queue_name, _deserializer=self.job_deserializer)

-    async def _get_job_result(self, key: str) -> JobResult:
-        job_id = key[len(result_key_prefix) :]
+    async def _get_job_result(self, key: bytes) -> JobResult:
+        job_id = key[len(result_key_prefix) :].decode()
         job = Job(job_id, self, _deserializer=self.job_deserializer)
         r = await job.result_info()
         if r is None:
-            raise KeyError(f'job "{key}" not found')
+            raise KeyError(f'job "{key.decode()}" not found')
         r.job_id = job_id
         return r
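
The WATCH/MULTI block above is what makes _job_id deduplication work: if the job key or its result key already exists, enqueue_job returns None instead of enqueuing a duplicate. A minimal usage sketch (the function name and job id are hypothetical):

    from arq import create_pool
    from arq.connections import RedisSettings

    async def main() -> None:
        pool = await create_pool(RedisSettings())
        job = await pool.enqueue_job('download_content', 'https://example.com', _job_id='user-123')
        # a second call with the same _job_id hits the existing job key and returns None
        dup = await pool.enqueue_job('download_content', 'https://example.com', _job_id='user-123')
        assert job is not None and dup is None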

@@ -179,8 +176,8 @@ async def all_job_results(self) -> List[JobResult]:
         results = await asyncio.gather(*[self._get_job_result(k) for k in keys])
         return sorted(results, key=attrgetter('enqueue_time'))

-    async def _get_job_def(self, job_id: str, score: int) -> JobDef:
-        v = await self.get(job_key_prefix + job_id, encoding=None)
+    async def _get_job_def(self, job_id: bytes, score: int) -> JobDef:
+        v = await self.get(job_key_prefix + job_id.decode())
         jd = deserialize_job(v, deserializer=self.job_deserializer)
         jd.score = score
         return jd
@@ -189,8 +186,8 @@ async def queued_jobs(self, *, queue_name: str = default_queue_name) -> List[Job
         """
         Get information about queued, mostly useful when testing.
         """
-        jobs = await self.zrange(queue_name, withscores=True)
-        return await asyncio.gather(*[self._get_job_def(job_id, score) for job_id, score in jobs])
+        jobs = await self.zrange(queue_name, withscores=True, start=0, end=-1)
+        return await asyncio.gather(*[self._get_job_def(job_id, int(score)) for job_id, score in jobs])


 async def create_pool(
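
Side note on the zrange change above (not part of the diff): redis-py requires explicit start/end arguments and, with withscores=True, returns (member_bytes, float_score) pairs, hence start=0, end=-1 and the int(score) coercion. A tiny sketch of the raw return value, assuming the default arq:queue key:

    from redis.asyncio import Redis

    async def peek(redis: Redis) -> None:
        jobs = await redis.zrange('arq:queue', start=0, end=-1, withscores=True)
        # e.g. [(b'some-job-id', 1650000000000.0), ...] -- members are bytes, scores are floats
        for job_id, score in jobs:
            print(job_id.decode(), int(score))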
@@ -204,8 +201,7 @@ async def create_pool(
     """
     Create a new redis pool, retrying up to ``conn_retries`` times if the connection fails.

-    Similar to ``aioredis.create_redis_pool`` except it returns a :class:`arq.connections.ArqRedis` instance,
-    thus allowing job enqueuing.
+    Returns a :class:`arq.connections.ArqRedis` instance, thus allowing job enqueuing.
     """
     settings: RedisSettings = RedisSettings() if settings_ is None else settings_

@@ -214,32 +210,33 @@
     ), "str provided for 'host' but 'sentinel' is true; list of sentinels expected"

     if settings.sentinel:
-        addr: Any = settings.host

-        async def pool_factory(*args: Any, **kwargs: Any) -> Redis:
-            client = await aioredis.sentinel.create_sentinel_pool(*args, ssl=settings.ssl, **kwargs)
-            return client.master_for(settings.sentinel_master)
+        def pool_factory(*args: Any, **kwargs: Any) -> ArqRedis:
+            client = Sentinel(*args, sentinels=settings.host, ssl=settings.ssl, **kwargs)
+            return client.master_for(settings.sentinel_master, redis_class=ArqRedis)

     else:
         pool_factory = functools.partial(
-            aioredis.create_pool, create_connection_timeout=settings.conn_timeout, ssl=settings.ssl
+            ArqRedis,
+            host=settings.host,
+            port=settings.port,
+            socket_connect_timeout=settings.conn_timeout,
+            ssl=settings.ssl,
         )
-        addr = settings.host, settings.port

     try:
-        pool = await pool_factory(addr, db=settings.database, password=settings.password, encoding='utf8')
-        pool = ArqRedis(
-            pool,
-            job_serializer=job_serializer,
-            job_deserializer=job_deserializer,
-            default_queue_name=default_queue_name,
-        )
+        pool = pool_factory(db=settings.database, password=settings.password, encoding='utf8')
+        pool.job_serializer = job_serializer
+        pool.job_deserializer = job_deserializer
+        pool.default_queue_name = default_queue_name
+        await pool.ping()

-    except (ConnectionError, OSError, aioredis.RedisError, asyncio.TimeoutError) as e:
+    except (ConnectionError, OSError, RedisError, asyncio.TimeoutError) as e:
         if retry < settings.conn_retries:
             logger.warning(
-                'redis connection error %s %s %s, %d retries remaining...',
-                addr,
+                'redis connection error %s:%s %s %s, %d retries remaining...',
+                settings.host,
+                settings.port,
                 e.__class__.__name__,
                 e,
                 settings.conn_retries - retry,
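
For orientation, a hedged sketch of how the rewritten create_pool is called, both directly and via Sentinel (host names, ports and the 'mymaster' name are placeholders):

    from arq import create_pool
    from arq.connections import RedisSettings

    async def plain_pool():
        return await create_pool(RedisSettings(host='localhost', port=6379, database=0))

    async def sentinel_pool():
        settings = RedisSettings(
            host=[('sentinel-1', 26379), ('sentinel-2', 26379)],  # list of sentinel addresses
            sentinel=True,
            sentinel_master='mymaster',
        )
        return await create_pool(settings)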
@@ -264,17 +261,16 @@ async def pool_factory(*args: Any, **kwargs: Any) -> Redis:


 async def log_redis_info(redis: Redis, log_func: Callable[[str], Any]) -> None:
-    with await redis as r:
-        info_server, info_memory, info_clients, key_count = await asyncio.gather(
-            r.info(section='Server'),
-            r.info(section='Memory'),
-            r.info(section='Clients'),
-            r.dbsize(),
-        )
-
-    redis_version = info_server.get('server', {}).get('redis_version', '?')
-    mem_usage = info_memory.get('memory', {}).get('used_memory_human', '?')
-    clients_connected = info_clients.get('clients', {}).get('connected_clients', '?')
+    async with redis.pipeline(transaction=True) as pipe:
+        pipe.info(section='Server')
+        pipe.info(section='Memory')
+        pipe.info(section='Clients')
+        pipe.dbsize()
+        info_server, info_memory, info_clients, key_count = await pipe.execute()
+
+    redis_version = info_server.get('redis_version', '?')
+    mem_usage = info_memory.get('used_memory_human', '?')
+    clients_connected = info_clients.get('connected_clients', '?')

     log_func(
         f'redis_version={redis_version} '
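
The flattened .get(...) lookups above reflect that redis-py returns INFO sections as flat dicts rather than nesting them under a section key as aioredis v1 did. An illustrative call of the helper (the logger name is arbitrary):

    import logging
    from arq.connections import RedisSettings, create_pool, log_redis_info

    logger = logging.getLogger('arq.sketch')

    async def report() -> None:
        pool = await create_pool(RedisSettings())
        await log_redis_info(pool, logger.info)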

arq/jobs.py (+9, -5)
@@ -7,7 +7,7 @@
 from enum import Enum
 from typing import Any, Callable, Dict, Optional, Tuple

-from aioredis import Redis
+from redis.asyncio import Redis

 from .constants import abort_jobs_ss, default_queue_name, in_progress_key_prefix, job_key_prefix, result_key_prefix
 from .utils import ms_to_datetime, poll, timestamp_ms
@@ -44,6 +44,10 @@ class JobDef:
     enqueue_time: datetime
     score: Optional[int]

+    def __post_init__(self) -> None:
+        if isinstance(self.score, float):
+            self.score = int(self.score)
+

 @dataclass
 class JobResult(JobDef):
110114
"""
111115
info: Optional[JobDef] = await self.result_info()
112116
if not info:
113-
v = await self._redis.get(job_key_prefix + self.job_id, encoding=None)
117+
v = await self._redis.get(job_key_prefix + self.job_id)
114118
if v:
115119
info = deserialize_job(v, deserializer=self._deserializer)
116120
if info:
@@ -122,7 +126,7 @@ async def result_info(self) -> Optional[JobResult]:
122126
Information about the job result if available, does not wait for the result. Does not raise an exception
123127
even if the job raised one.
124128
"""
125-
v = await self._redis.get(result_key_prefix + self.job_id, encoding=None)
129+
v = await self._redis.get(result_key_prefix + self.job_id)
126130
if v:
127131
return deserialize_result(v, deserializer=self._deserializer)
128132
else:
@@ -151,7 +155,7 @@ async def abort(self, *, timeout: Optional[float] = None, poll_delay: float = 0.
151155
:param poll_delay: how often to poll redis for the job result
152156
:return: True if the job aborted properly, False otherwise
153157
"""
154-
await self._redis.zadd(abort_jobs_ss, timestamp_ms(), self.job_id)
158+
await self._redis.zadd(abort_jobs_ss, {self.job_id: timestamp_ms()})
155159
try:
156160
await self.result(timeout=timeout, poll_delay=poll_delay)
157161
except asyncio.CancelledError:
@@ -179,7 +183,7 @@ def serialize_job(
179183
enqueue_time_ms: int,
180184
*,
181185
serializer: Optional[Serializer] = None,
182-
) -> Optional[bytes]:
186+
) -> bytes:
183187
data = {'t': job_try, 'f': function_name, 'a': args, 'k': kwargs, 'et': enqueue_time_ms}
184188
if serializer is None:
185189
serializer = pickle.dumps
