diff --git a/CHANGES b/CHANGES index 8750128b05..a5ce4fbe0e 100644 --- a/CHANGES +++ b/CHANGES @@ -67,6 +67,7 @@ * Close Unix sockets if the connection attempt fails. This prevents `ResourceWarning`s. (#3314) * Close SSL sockets if the connection attempt fails, or if validations fail. (#3317) * Eliminate mutable default arguments in the `redis.commands.core.Script` class. (#3332) + * Add command_timeout to async client. (#3436) * 4.1.3 (Feb 8, 2022) * Fix flushdb and flushall (#1926) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 9508849703..9a114b17f1 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -233,6 +233,7 @@ def __init__( redis_connect_func=None, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, + command_timeout: Optional[float] = None, ): """ Initialize a new Redis client. @@ -282,6 +283,7 @@ def __init__( "lib_version": lib_version, "redis_connect_func": redis_connect_func, "protocol": protocol, + "command_timeout": command_timeout, } # based on input, setup appropriate connection args if unix_socket_path is not None: diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 4e82e5448f..08baadd90b 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -270,6 +270,7 @@ def __init__( ssl_ciphers: Optional[str] = None, protocol: Optional[int] = 2, address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, + command_timeout: Optional[float] = None, ) -> None: if db: raise RedisClusterException( @@ -311,6 +312,7 @@ def __init__( "socket_keepalive": socket_keepalive, "socket_keepalive_options": socket_keepalive_options, "socket_timeout": socket_timeout, + "command_timeout": command_timeout, "retry": retry, "protocol": protocol, } diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index ddbd22c95d..c1d8545a62 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -104,6 +104,7 @@ class AbstractConnection: "credential_provider", "password", "socket_timeout", + "command_timeout", "socket_connect_timeout", "redis_connect_func", "retry_on_timeout", @@ -148,6 +149,7 @@ def __init__( encoder_class: Type[Encoder] = Encoder, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, + command_timeout: Optional[float] = None, ): if (username or password) and credential_provider is not None: raise DataError( @@ -167,6 +169,7 @@ def __init__( if socket_connect_timeout is None: socket_connect_timeout = socket_timeout self.socket_connect_timeout = socket_connect_timeout + self.command_timeout = command_timeout self.retry_on_timeout = retry_on_timeout if retry_on_error is SENTINEL: retry_on_error = [] @@ -206,6 +209,13 @@ def __init__( raise ConnectionError("protocol must be either 2 or 3") self.protocol = protocol + def _get_command_timeout(self, timeout: Optional[float] = None): + if timeout is not None: + return timeout + if self.command_timeout is not None: + return self.command_timeout + return self.socket_timeout + def __del__(self, _warnings: Any = warnings): # For some reason, the individual streams don't get properly garbage # collected and therefore produce no resource warnings. We add one @@ -466,10 +476,9 @@ async def send_packed_command( command = command.encode() if isinstance(command, bytes): command = [command] - if self.socket_timeout: - await asyncio.wait_for( - self._send_packed_command(command), self.socket_timeout - ) + timeout = self._get_command_timeout() + if timeout: + await asyncio.wait_for(self._send_packed_command(command), timeout) else: self._writer.writelines(command) await self._writer.drain() @@ -518,7 +527,7 @@ async def read_response( push_request: Optional[bool] = False, ): """Read the response from a previously sent command""" - read_timeout = timeout if timeout is not None else self.socket_timeout + read_timeout = self._get_command_timeout(timeout) host_error = self._host_error() try: if ( diff --git a/redis/cluster.py b/redis/cluster.py index 9dcbad7fc1..bbb778b581 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -140,6 +140,7 @@ def parse_cluster_myshardid(resp, **options): "credential_provider", "db", "decode_responses", + "command_timeout", "encoding", "encoding_errors", "errors",