from urllib.parse import urlparse
from uuid import uuid4

- import aioredis
- from aioredis import MultiExecError, Redis
from pydantic.validators import make_arbitrary_type_validator
+ from redis.asyncio import ConnectionPool, Redis
+ from redis.asyncio.sentinel import Sentinel
+ from redis.exceptions import RedisError, WatchError

from .constants import default_queue_name, job_key_prefix, result_key_prefix
from .jobs import Deserializer, Job, JobDef, JobResult, Serializer, deserialize_job, serialize_job
@@ -70,20 +71,20 @@ def __repr__(self) -> str:
expires_extra_ms = 86_400_000


- class ArqRedis(Redis):  # type: ignore
+ class ArqRedis(Redis):  # type: ignore[misc]
    """
-     Thin subclass of ``aioredis.Redis`` which adds :func:`arq.connections.enqueue_job`.
+     Thin subclass of ``redis.asyncio.Redis`` which adds :func:`arq.connections.enqueue_job`.

    :param redis_settings: an instance of ``arq.connections.RedisSettings``.
    :param job_serializer: a function that serializes Python objects to bytes, defaults to pickle.dumps
    :param job_deserializer: a function that deserializes bytes into Python objects, defaults to pickle.loads
    :param default_queue_name: the default queue name to use, defaults to ``arq.queue``.
-     :param kwargs: keyword arguments directly passed to ``aioredis.Redis``.
+     :param kwargs: keyword arguments directly passed to ``redis.asyncio.Redis``.
    """

    def __init__(
        self,
-         pool_or_conn: Any,
+         pool_or_conn: Optional[ConnectionPool] = None,
        job_serializer: Optional[Serializer] = None,
        job_deserializer: Optional[Deserializer] = None,
        default_queue_name: str = default_queue_name,
@@ -92,7 +93,9 @@ def __init__(
        self.job_serializer = job_serializer
        self.job_deserializer = job_deserializer
        self.default_queue_name = default_queue_name
-         super().__init__(pool_or_conn, **kwargs)
+         if pool_or_conn:
+             kwargs['connection_pool'] = pool_or_conn
+         super().__init__(**kwargs)

    async def enqueue_job(
        self,
@@ -129,14 +132,10 @@ async def enqueue_job(
        defer_by_ms = to_ms(_defer_by)
        expires_ms = to_ms(_expires)

-         with await self as conn:
-             pipe = conn.pipeline()
-             pipe.unwatch()
-             pipe.watch(job_key)
-             job_exists = pipe.exists(job_key)
-             job_result_exists = pipe.exists(result_key_prefix + job_id)
-             await pipe.execute()
-             if await job_exists or await job_result_exists:
+         async with self.pipeline(transaction=True) as pipe:
+             await pipe.watch(job_key)
+             if any(await asyncio.gather(pipe.exists(job_key), pipe.exists(result_key_prefix + job_id))):
+                 await pipe.reset()
                return None

            enqueue_time_ms = timestamp_ms()
@@ -150,24 +149,22 @@ async def enqueue_job(
            expires_ms = expires_ms or score - enqueue_time_ms + expires_extra_ms

            job = serialize_job(function, args, kwargs, _job_try, enqueue_time_ms, serializer=self.job_serializer)
-             tr = conn.multi_exec()
-             tr.psetex(job_key, expires_ms, job)
-             tr.zadd(_queue_name, score, job_id)
+             pipe.multi()
+             pipe.psetex(job_key, expires_ms, job)
+             pipe.zadd(_queue_name, {job_id: score})
            try:
-                 await tr.execute()
-             except MultiExecError:
+                 await pipe.execute()
+             except WatchError:
                # job got enqueued since we checked 'job_exists'
-                 # https://github.com/samuelcolvin/arq/issues/131, avoid warnings in log
-                 await asyncio.gather(*tr._results, return_exceptions=True)
                return None
        return Job(job_id, redis=self, _queue_name=_queue_name, _deserializer=self.job_deserializer)

-     async def _get_job_result(self, key: str) -> JobResult:
-         job_id = key[len(result_key_prefix):]
+     async def _get_job_result(self, key: bytes) -> JobResult:
+         job_id = key[len(result_key_prefix):].decode()
        job = Job(job_id, self, _deserializer=self.job_deserializer)
        r = await job.result_info()
        if r is None:
-             raise KeyError(f'job "{key}" not found')
+             raise KeyError(f'job "{key.decode()}" not found')
        r.job_id = job_id
        return r
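For reference, a minimal standalone sketch (not part of this change) of the optimistic-locking pattern the new enqueue_job relies on: a redis.asyncio pipeline with WATCH/MULTI and WatchError handling. The function, key, and value names below are made up for illustration:

from redis.asyncio import Redis
from redis.exceptions import WatchError


async def set_if_absent(redis: Redis, key: str, value: str) -> bool:
    async with redis.pipeline(transaction=True) as pipe:
        await pipe.watch(key)  # transaction will abort if `key` changes before EXEC
        if await pipe.exists(key):  # after watch(), commands execute immediately
            await pipe.reset()
            return False
        pipe.multi()  # switch back to buffered, transactional mode
        pipe.set(key, value)
        try:
            await pipe.execute()  # raises WatchError if the watched key was modified
        except WatchError:
            return False
    return True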
@@ -179,8 +176,8 @@ async def all_job_results(self) -> List[JobResult]:
        results = await asyncio.gather(*[self._get_job_result(k) for k in keys])
        return sorted(results, key=attrgetter('enqueue_time'))

-     async def _get_job_def(self, job_id: str, score: int) -> JobDef:
-         v = await self.get(job_key_prefix + job_id, encoding=None)
+     async def _get_job_def(self, job_id: bytes, score: int) -> JobDef:
+         v = await self.get(job_key_prefix + job_id.decode())
        jd = deserialize_job(v, deserializer=self.job_deserializer)
        jd.score = score
        return jd
@@ -189,8 +186,8 @@ async def queued_jobs(self, *, queue_name: str = default_queue_name) -> List[Job
        """
        Get information about queued, mostly useful when testing.
        """
-         jobs = await self.zrange(queue_name, withscores=True)
-         return await asyncio.gather(*[self._get_job_def(job_id, score) for job_id, score in jobs])
+         jobs = await self.zrange(queue_name, withscores=True, start=0, end=-1)
+         return await asyncio.gather(*[self._get_job_def(job_id, int(score)) for job_id, score in jobs])


async def create_pool(
@@ -204,8 +201,7 @@ async def create_pool(
    """
    Create a new redis pool, retrying up to ``conn_retries`` times if the connection fails.

-     Similar to ``aioredis.create_redis_pool`` except it returns a :class:`arq.connections.ArqRedis` instance,
-     thus allowing job enqueuing.
+     Returns a :class:`arq.connections.ArqRedis` instance, thus allowing job enqueuing.
    """
    settings: RedisSettings = RedisSettings() if settings_ is None else settings_
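A usage sketch of create_pool from the caller's side, which this change keeps intact; 'download_content' is a placeholder worker function name, not something defined in this diff:

import asyncio

from arq import create_pool
from arq.connections import RedisSettings


async def main() -> None:
    redis = await create_pool(RedisSettings())
    # enqueue a job by function name; the worker-side function is assumed to exist
    await redis.enqueue_job('download_content', 'https://example.com')


if __name__ == '__main__':
    asyncio.run(main())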
@@ -214,32 +210,33 @@ async def create_pool(
    ), "str provided for 'host' but 'sentinel' is true; list of sentinels expected"

    if settings.sentinel:
-         addr: Any = settings.host

-         async def pool_factory(*args: Any, **kwargs: Any) -> Redis:
-             client = await aioredis.sentinel.create_sentinel_pool(*args, ssl=settings.ssl, **kwargs)
-             return client.master_for(settings.sentinel_master)
+         def pool_factory(*args: Any, **kwargs: Any) -> ArqRedis:
+             client = Sentinel(*args, sentinels=settings.host, ssl=settings.ssl, **kwargs)
+             return client.master_for(settings.sentinel_master, redis_class=ArqRedis)

    else:
        pool_factory = functools.partial(
-             aioredis.create_pool, create_connection_timeout=settings.conn_timeout, ssl=settings.ssl
+             ArqRedis,
+             host=settings.host,
+             port=settings.port,
+             socket_connect_timeout=settings.conn_timeout,
+             ssl=settings.ssl,
        )
-         addr = settings.host, settings.port

    try:
-         pool = await pool_factory(addr, db=settings.database, password=settings.password, encoding='utf8')
-         pool = ArqRedis(
-             pool,
-             job_serializer=job_serializer,
-             job_deserializer=job_deserializer,
-             default_queue_name=default_queue_name,
-         )
+         pool = pool_factory(db=settings.database, password=settings.password, encoding='utf8')
+         pool.job_serializer = job_serializer
+         pool.job_deserializer = job_deserializer
+         pool.default_queue_name = default_queue_name
+         await pool.ping()

-     except (ConnectionError, OSError, aioredis.RedisError, asyncio.TimeoutError) as e:
+     except (ConnectionError, OSError, RedisError, asyncio.TimeoutError) as e:
        if retry < settings.conn_retries:
            logger.warning(
-                 'redis connection error %s %s %s, %d retries remaining...',
-                 addr,
+                 'redis connection error %s:%s %s %s, %d retries remaining...',
+                 settings.host,
+                 settings.port,
                e.__class__.__name__,
                e,
                settings.conn_retries - retry,
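For context, a hedged sketch of the redis.asyncio Sentinel API the new sentinel branch builds on; the sentinel addresses and master name below are placeholders:

from redis.asyncio.sentinel import Sentinel


async def ping_master() -> None:
    sentinel = Sentinel([('sentinel-1', 26379), ('sentinel-2', 26379)])
    master = sentinel.master_for('mymaster')  # Redis client that follows the current master
    await master.ping()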
@@ -264,17 +261,16 @@ async def pool_factory(*args: Any, **kwargs: Any) -> Redis:
async def log_redis_info(redis: Redis, log_func: Callable[[str], Any]) -> None:
-     with await redis as r:
-         info_server, info_memory, info_clients, key_count = await asyncio.gather(
-             r.info(section='Server'),
-             r.info(section='Memory'),
-             r.info(section='Clients'),
-             r.dbsize(),
-         )
-
-     redis_version = info_server.get('server', {}).get('redis_version', '?')
-     mem_usage = info_memory.get('memory', {}).get('used_memory_human', '?')
-     clients_connected = info_clients.get('clients', {}).get('connected_clients', '?')
+     async with redis.pipeline(transaction=True) as pipe:
+         pipe.info(section='Server')
+         pipe.info(section='Memory')
+         pipe.info(section='Clients')
+         pipe.dbsize()
+         info_server, info_memory, info_clients, key_count = await pipe.execute()
+
+     redis_version = info_server.get('redis_version', '?')
+     mem_usage = info_memory.get('used_memory_human', '?')
+     clients_connected = info_clients.get('connected_clients', '?')

    log_func(
        f'redis_version={redis_version} '
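A small sketch, separate from this change, of the same batching idea used by the new log_redis_info: buffer several read-only commands in one redis.asyncio pipeline and unpack execute()'s results in command order. The function name and summary format are illustrative only:

from redis.asyncio import Redis


async def server_summary(redis: Redis) -> str:
    async with redis.pipeline(transaction=True) as pipe:
        pipe.info(section='Server')
        pipe.dbsize()
        info_server, key_count = await pipe.execute()  # results arrive in command order
    return f"redis {info_server.get('redis_version', '?')}, {key_count} keys"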