Skip to content

Commit 838bd58

Browse files
committed
add timeout and retry logic to create_redis_pool
1 parent 92c65ef commit 838bd58

File tree

4 files changed

+41
-6
lines changed

4 files changed

+41
-6
lines changed

HISTORY.rst

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ v0.8.0 (2017-06-05)
99
* change logger name for control process log messages
1010
* use ``Semaphore`` rather than ``asyncio.wait(...return_when=asyncio.FIRST_COMPLETED)`` for improved performance
1111
* improve log display
12+
* add timeout and retry logic to ``RedisMixin.create_redis_pool``
1213

1314
v0.7.0 (2017-06-01)
1415
...................

arq/utils.py

+26-4
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,33 @@
66
"""
77
import asyncio
88
import base64
9+
import logging
910
import os
1011
from datetime import datetime, timedelta, timezone
1112
from typing import Tuple, Union
1213

1314
import aioredis
1415
from aioredis.pool import RedisPool
16+
from async_timeout import timeout
1517

1618
__all__ = ['RedisSettings', 'RedisMixin']
19+
logger = logging.getLogger('arq.utils')
1720

1821

1922
class RedisSettings:
2023
"""
2124
No-Op class used to hold redis connection redis_settings.
2225
"""
26+
__slots__ = 'host', 'port', 'database', 'password', 'conn_retries', 'conn_timeout', 'conn_retry_delay'
27+
2328
def __init__(self,
2429
host='localhost',
2530
port=6379,
2631
database=0,
27-
password=None):
32+
password=None,
33+
conn_timeout=1,
34+
conn_retries=5,
35+
conn_retry_delay=1):
2836
"""
2937
:param host: redis host
3038
:param port: redis port
@@ -35,6 +43,9 @@ def __init__(self,
3543
self.port = port
3644
self.database = database
3745
self.password = password
46+
self.conn_timeout = conn_timeout
47+
self.conn_retries = conn_retries
48+
self.conn_retry_delay = conn_retry_delay
3849

3950

4051
class RedisMixin:
@@ -56,12 +67,23 @@ def __init__(self, *,
5667
self.redis_settings = redis_settings or getattr(self, 'redis_settings', None) or RedisSettings()
5768
self._redis_pool = existing_pool
5869

59-
async def create_redis_pool(self) -> RedisPool:
70+
async def create_redis_pool(self, *, _retry=0) -> RedisPool:
6071
"""
6172
Create a new redis pool.
6273
"""
63-
return await aioredis.create_pool((self.redis_settings.host, self.redis_settings.port), loop=self.loop,
64-
db=self.redis_settings.database, password=self.redis_settings.password)
74+
addr = self.redis_settings.host, self.redis_settings.port
75+
try:
76+
with timeout(self.redis_settings.conn_timeout):
77+
return await aioredis.create_pool(addr, loop=self.loop, db=self.redis_settings.database,
78+
password=self.redis_settings.password)
79+
except (ConnectionError, OSError, aioredis.RedisError, asyncio.TimeoutError) as e:
80+
if _retry < self.redis_settings.conn_retries:
81+
logger.warning('redis connection error %s %s, %d retries remaining...',
82+
e.__class__.__name__, e, self.redis_settings.conn_retries - _retry)
83+
await asyncio.sleep(self.redis_settings.conn_retry_delay)
84+
return await self.create_redis_pool(_retry=_retry + 1)
85+
else:
86+
raise
6587

6688
async def get_redis_pool(self) -> RedisPool:
6789
"""

arq/worker.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ class BaseWorker(RedisMixin):
8181
repeat_health_check_logs = False
8282

8383
drain_class = Drain
84+
_shadow_factory_timeout = 10
8485

8586
def __init__(self, *,
8687
burst: bool=False,
@@ -178,7 +179,7 @@ async def run(self):
178179
self._stopped = False
179180
work_logger.info('Initialising work manager, burst mode: %s, creating shadows...', self._burst_mode)
180181

181-
with timeout(10):
182+
with timeout(self._shadow_factory_timeout):
182183
shadows = await self.shadow_factory()
183184
assert isinstance(shadows, list), 'shadow_factory should return a list not %s' % type(shadows)
184185
self.job_class = shadows[0].job_class

tests/test_utils.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
import os
33
from datetime import datetime
44

5-
from arq import RedisSettings
5+
import pytest
6+
7+
import arq.utils
8+
from arq import RedisMixin, RedisSettings
69
from arq.logs import ColourHandler
710
from arq.testing import MockRedis
811
from arq.utils import timestamp
@@ -49,3 +52,11 @@ async def test_mock_redis_flushdb(loop):
4952
assert 'bar' == await r.get('foo')
5053
await r.flushdb()
5154
assert None is await r.get('foo')
55+
56+
57+
async def test_redis_timeout(loop, mocker):
58+
mocker.spy(arq.utils.asyncio, 'sleep')
59+
r = RedisMixin(redis_settings=RedisSettings(port=0, conn_retry_delay=0), loop=loop)
60+
with pytest.raises(OSError):
61+
await r.get_redis_pool()
62+
assert arq.utils.asyncio.sleep.call_count == 5

0 commit comments

Comments
 (0)