diff --git a/CHANGES b/CHANGES index 8750128b05..5322319cd5 100644 --- a/CHANGES +++ b/CHANGES @@ -67,6 +67,7 @@ * Close Unix sockets if the connection attempt fails. This prevents `ResourceWarning`s. (#3314) * Close SSL sockets if the connection attempt fails, or if validations fail. (#3317) * Eliminate mutable default arguments in the `redis.commands.core.Script` class. (#3332) + * Turn `acquire_connection` into a context manager (#3435) * 4.1.3 (Feb 8, 2022) * Fix flushdb and flushall (#1926) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 4e82e5448f..b1ec8fb69c 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -1,13 +1,12 @@ import asyncio -import collections import random import socket import ssl import warnings +from contextlib import asynccontextmanager from typing import ( Any, Callable, - Deque, Dict, Generator, List, @@ -155,6 +154,12 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand maximum number of connections are already created, a :class:`~.MaxConnectionsError` is raised. This error may be retried as defined by :attr:`connection_error_retry_attempts` + :param connection_queueing_timeout: + | The maximum time in seconds to wait for a free node connection (as + specified by ``max_connections``). If the timeout is reached, a + :class:`~.MaxConnectionsError` is raised. If set to <= 0, no + queueing is done and a :class:`~.MaxConnectionsError` is raised + if no free connection is available. :param address_remap: | An optional callable which, when provided with an internal network address of a node, e.g. a `(host, port)` tuple, will return the address @@ -237,6 +242,7 @@ def __init__( cluster_error_retry_attempts: int = 3, connection_error_retry_attempts: int = 3, max_connections: int = 2**31, + connection_queueing_timeout: float = 0, # Client related kwargs db: Union[str, int] = 0, path: Optional[str] = None, @@ -292,6 +298,7 @@ def __init__( kwargs: Dict[str, Any] = { "max_connections": max_connections, + "connection_queueing_timeout": connection_queueing_timeout, "connection_class": Connection, "parser_class": ClusterParser, # Client related kwargs @@ -933,6 +940,7 @@ class ClusterNode: "connection_kwargs", "host", "max_connections", + "connection_queueing_timeout", "name", "port", "response_callbacks", @@ -946,6 +954,7 @@ def __init__( server_type: Optional[str] = None, *, max_connections: int = 2**31, + connection_queueing_timeout: float = 0, connection_class: Type[Connection] = Connection, **connection_kwargs: Any, ) -> None: @@ -960,12 +969,13 @@ def __init__( self.server_type = server_type self.max_connections = max_connections + self.connection_queueing_timeout = connection_queueing_timeout self.connection_class = connection_class self.connection_kwargs = connection_kwargs self.response_callbacks = connection_kwargs.pop("response_callbacks", {}) self._connections: List[Connection] = [] - self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections) + self._free: asyncio.Queue[Connection] = asyncio.Queue(maxsize=self.max_connections) def __repr__(self) -> str: return ( @@ -1006,16 +1016,32 @@ async def disconnect(self) -> None: if exc: raise exc - def acquire_connection(self) -> Connection: + @asynccontextmanager + async def acquire_connection(self) -> Connection: + """Context manager acquiring a connection on enter and automatically + freeing it on exit.""" + + # Try to get a free connection try: - return self._free.popleft() - except IndexError: + connection = self._free.get_nowait() + except asyncio.QueueEmpty: if len(self._connections) < self.max_connections: connection = self.connection_class(**self.connection_kwargs) self._connections.append(connection) - return connection + else: + if self.connection_queueing_timeout <= 0: + raise MaxConnectionsError() + try: + connection = await asyncio.wait_for( + self._free.get(), self.connection_queueing_timeout + ) + except asyncio.TimeoutError: + raise MaxConnectionsError() - raise MaxConnectionsError() + try: + yield connection + finally: + self._free.put_nowait(connection) async def parse_response( self, connection: Connection, command: str, **kwargs: Any @@ -1045,42 +1071,35 @@ async def parse_response( async def execute_command(self, *args: Any, **kwargs: Any) -> Any: # Acquire connection - connection = self.acquire_connection() + async with self.acquire_connection() as connection: - # Execute command - await connection.send_packed_command(connection.pack_command(*args), False) + # Execute command + await connection.send_packed_command(connection.pack_command(*args), False) - # Read response - try: + # Read response return await self.parse_response(connection, args[0], **kwargs) - finally: - # Release connection - self._free.append(connection) async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool: # Acquire connection - connection = self.acquire_connection() + async with self.acquire_connection() as connection: - # Execute command - await connection.send_packed_command( - connection.pack_commands(cmd.args for cmd in commands), False - ) - - # Read responses - ret = False - for cmd in commands: - try: - cmd.result = await self.parse_response( - connection, cmd.args[0], **cmd.kwargs - ) - except Exception as e: - cmd.result = e - ret = True + # Execute command + await connection.send_packed_command( + connection.pack_commands(cmd.args for cmd in commands), False + ) - # Release connection - self._free.append(connection) + # Read responses + ret = False + for cmd in commands: + try: + cmd.result = await self.parse_response( + connection, cmd.args[0], **cmd.kwargs + ) + except Exception as e: + cmd.result = e + ret = True - return ret + return ret class NodesManager: diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index f3b76b80c9..823e71edaf 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -190,9 +190,12 @@ def mock_node_resp(node: ClusterNode, response: Any) -> ClusterNode: connection = mock.AsyncMock(spec=Connection) connection.is_connected = True connection.read_response.return_value = response - while node._free: - node._free.pop() - node._free.append(connection) + while True: + try: + node._free.get_nowait() + except asyncio.QueueEmpty: + break + node._free.put_nowait(connection) return node @@ -200,9 +203,12 @@ def mock_node_resp_exc(node: ClusterNode, exc: Exception) -> ClusterNode: connection = mock.AsyncMock(spec=Connection) connection.is_connected = True connection.read_response.side_effect = exc - while node._free: - node._free.pop() - node._free.append(connection) + while True: + try: + node._free.get_nowait() + except asyncio.QueueEmpty: + break + node._free.put_nowait(connection) return node @@ -360,20 +366,20 @@ async def test_cluster_set_get_retry_object(self, request: FixtureRequest): assert n_retry._retries == retry._retries assert isinstance(n_retry._backoff, NoBackoff) rand_cluster_node = r.get_random_node() - existing_conn = rand_cluster_node.acquire_connection() - # Change retry policy - new_retry = Retry(ExponentialBackoff(), 3) - r.set_retry(new_retry) - assert r.get_retry()._retries == new_retry._retries - assert isinstance(r.get_retry()._backoff, ExponentialBackoff) - for node in r.get_nodes(): - n_retry = node.connection_kwargs.get("retry") - assert n_retry is not None - assert n_retry._retries == new_retry._retries - assert isinstance(n_retry._backoff, ExponentialBackoff) - assert existing_conn.retry._retries == new_retry._retries - new_conn = rand_cluster_node.acquire_connection() - assert new_conn.retry._retries == new_retry._retries + async with rand_cluster_node.acquire_connection() as existing_conn: + # Change retry policy + new_retry = Retry(ExponentialBackoff(), 3) + r.set_retry(new_retry) + assert r.get_retry()._retries == new_retry._retries + assert isinstance(r.get_retry()._backoff, ExponentialBackoff) + for node in r.get_nodes(): + n_retry = node.connection_kwargs.get("retry") + assert n_retry is not None + assert n_retry._retries == new_retry._retries + assert isinstance(n_retry._backoff, ExponentialBackoff) + assert existing_conn.retry._retries == new_retry._retries + async with rand_cluster_node.acquire_connection() as new_conn: + assert new_conn.retry._retries == new_retry._retries async def test_cluster_retry_object(self, request: FixtureRequest) -> None: url = request.config.getoption("--redis-url") @@ -482,10 +488,10 @@ async def test_execute_command_node_flag_primaries(self, r: RedisCluster) -> Non mock_all_nodes_resp(r, "PONG") assert await r.ping(target_nodes=RedisCluster.PRIMARIES) is True for primary in primaries: - conn = primary._free.pop() + conn = primary._free.get_nowait() assert conn.read_response.called is True for replica in replicas: - conn = replica._free.pop() + conn = replica._free.get_nowait() assert conn.read_response.called is not True async def test_execute_command_node_flag_replicas(self, r: RedisCluster) -> None: @@ -499,10 +505,10 @@ async def test_execute_command_node_flag_replicas(self, r: RedisCluster) -> None mock_all_nodes_resp(r, "PONG") assert await r.ping(target_nodes=RedisCluster.REPLICAS) is True for replica in replicas: - conn = replica._free.pop() + conn = replica._free.get_nowait() assert conn.read_response.called is True for primary in primaries: - conn = primary._free.pop() + conn = primary._free.get_nowait() assert conn.read_response.called is not True await r.aclose() @@ -514,7 +520,7 @@ async def test_execute_command_node_flag_all_nodes(self, r: RedisCluster) -> Non mock_all_nodes_resp(r, "PONG") assert await r.ping(target_nodes=RedisCluster.ALL_NODES) is True for node in r.get_nodes(): - conn = node._free.pop() + conn = node._free.get_nowait() assert conn.read_response.called is True async def test_execute_command_node_flag_random(self, r: RedisCluster) -> None: @@ -525,7 +531,7 @@ async def test_execute_command_node_flag_random(self, r: RedisCluster) -> None: assert await r.ping(target_nodes=RedisCluster.RANDOM) is True called_count = 0 for node in r.get_nodes(): - conn = node._free.pop() + conn = node._free.get_nowait() if conn.read_response.called is True: called_count += 1 assert called_count == 1 @@ -538,7 +544,7 @@ async def test_execute_command_default_node(self, r: RedisCluster) -> None: def_node = r.get_default_node() mock_node_resp(def_node, "PONG") assert await r.ping() is True - conn = def_node._free.pop() + conn = def_node._free.get_nowait() assert conn.read_response.called async def test_ask_redirection(self, r: RedisCluster) -> None: @@ -1106,8 +1112,8 @@ async def test_cluster_delslots(self) -> None: node0 = r.get_node(default_host, 7000) node1 = r.get_node(default_host, 7001) assert await r.cluster_delslots(0, 8192) == [True, True] - assert node0._free.pop().read_response.called - assert node1._free.pop().read_response.called + assert node0._free.get_nowait().read_response.called + assert node1._free.get_nowait().read_response.called await r.aclose() @@ -1119,7 +1125,7 @@ async def test_cluster_delslotsrange(self): node = r.get_random_node() await r.cluster_addslots(node, 1, 2, 3, 4, 5) assert await r.cluster_delslotsrange(1, 5) - assert node._free.pop().read_response.called + assert node._free.get_nowait().read_response.called await r.aclose() @skip_if_redis_enterprise() @@ -1279,7 +1285,7 @@ async def test_cluster_setslot_stable(self, r: RedisCluster) -> None: node = r.nodes_manager.get_node_from_slot(12182) mock_node_resp(node, "OK") assert await r.cluster_setslot_stable(12182) is True - assert node._free.pop().read_response.called + assert node._free.get_nowait().read_response.called @skip_if_redis_enterprise() async def test_cluster_replicas(self, r: RedisCluster) -> None: @@ -1328,7 +1334,7 @@ async def test_readonly(self) -> None: for res in all_replicas_results.values(): assert res is True for replica in r.get_replicas(): - assert replica._free.pop().read_response.called + assert replica._free.get_nowait().read_response.called await r.aclose() @@ -1341,7 +1347,7 @@ async def test_readwrite(self) -> None: for res in all_replicas_results.values(): assert res is True for replica in r.get_replicas(): - assert replica._free.pop().read_response.called + assert replica._free.get_nowait().read_response.called await r.aclose() @@ -2800,8 +2806,8 @@ async def test_asking_error(self, r: RedisCluster) -> None: mock_node_resp_exc(first_node, AskError(ask_msg)) mock_node_resp(ask_node, "MOCK_OK") res = await pipe.get(key).execute() - assert first_node._free.pop().read_response.await_count - assert ask_node._free.pop().read_response.await_count + assert first_node._free.get_nowait().read_response.await_count + assert ask_node._free.get_nowait().read_response.await_count assert res == ["MOCK_OK"] @skip_if_server_version_gte("7.0.0") @@ -2857,7 +2863,7 @@ async def test_readonly_pipeline_from_readonly_client( executed_on_replica = False for node in slot_nodes: if node.server_type == REPLICA: - if node._free.pop().read_response.await_count: + if node._free.get_nowait().read_response.await_count: executed_on_replica = True break assert executed_on_replica