diff --git a/redis/commands/core.py b/redis/commands/core.py index b356d101ee..6f1d72796d 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -1,12 +1,10 @@ -# from __future__ import annotations - import datetime import hashlib import warnings from typing import ( TYPE_CHECKING, + Any, AsyncIterator, - Awaitable, Callable, Dict, Iterable, @@ -16,16 +14,18 @@ Mapping, Optional, Sequence, - Set, Tuple, Union, ) from redis.exceptions import ConnectionError, DataError, NoScriptError, RedisError from redis.typing import ( + OKT, AbsExpiryT, AnyKeyT, + ArrayResponseT, BitfieldOffsetT, + BulkStringResponseT, ChannelT, CommandsProtocol, ConsumerT, @@ -33,6 +33,7 @@ ExpiryT, FieldT, GroupT, + IntegerResponseT, KeysT, KeyT, PatternT, @@ -56,7 +57,9 @@ class ACLCommands(CommandsProtocol): see: https://redis.io/topics/acl """ - def acl_cat(self, category: Union[str, None] = None, **kwargs) -> ResponseT: + def acl_cat( + self, category: Union[str, None] = None, **kwargs + ) -> ResponseT[ArrayResponseT]: """ Returns a list of categories or commands within a category. @@ -66,10 +69,12 @@ def acl_cat(self, category: Union[str, None] = None, **kwargs) -> ResponseT: For more information see https://redis.io/commands/acl-cat """ - pieces: list[EncodableT] = [category] if category else [] + pieces: List[EncodableT] = [category] if category else [] return self.execute_command("ACL CAT", *pieces, **kwargs) - def acl_dryrun(self, username, *args, **kwargs): + def acl_dryrun( + self, username, *args, **kwargs + ) -> ResponseT[Union[BulkStringResponseT, OKT]]: """ Simulate the execution of a given command by a given ``username``. @@ -77,7 +82,7 @@ def acl_dryrun(self, username, *args, **kwargs): """ return self.execute_command("ACL DRYRUN", username, *args, **kwargs) - def acl_deluser(self, *username: str, **kwargs) -> ResponseT: + def acl_deluser(self, *username: str, **kwargs) -> ResponseT[IntegerResponseT]: """ Delete the ACL for the specified ``username``\\s @@ -85,7 +90,9 @@ def acl_deluser(self, *username: str, **kwargs) -> ResponseT: """ return self.execute_command("ACL DELUSER", *username, **kwargs) - def acl_genpass(self, bits: Union[int, None] = None, **kwargs) -> ResponseT: + def acl_genpass( + self, bits: Union[int, None] = None, **kwargs + ) -> ResponseT[BulkStringResponseT]: """Generate a random password value. If ``bits`` is supplied then use this number of bits, rounded to the next multiple of 4. @@ -104,7 +111,9 @@ def acl_genpass(self, bits: Union[int, None] = None, **kwargs) -> ResponseT: ) return self.execute_command("ACL GENPASS", *pieces, **kwargs) - def acl_getuser(self, username: str, **kwargs) -> ResponseT: + def acl_getuser( + self, username: str, **kwargs + ) -> ResponseT[Union[ArrayResponseT, None]]: """ Get the ACL details for the specified ``username``. @@ -114,7 +123,7 @@ def acl_getuser(self, username: str, **kwargs) -> ResponseT: """ return self.execute_command("ACL GETUSER", username, **kwargs) - def acl_help(self, **kwargs) -> ResponseT: + def acl_help(self, **kwargs) -> ResponseT[ArrayResponseT]: """The ACL HELP command returns helpful text describing the different subcommands. @@ -122,7 +131,7 @@ def acl_help(self, **kwargs) -> ResponseT: """ return self.execute_command("ACL HELP", **kwargs) - def acl_list(self, **kwargs) -> ResponseT: + def acl_list(self, **kwargs) -> ResponseT[ArrayResponseT]: """ Return a list of all ACLs on the server @@ -130,7 +139,9 @@ def acl_list(self, **kwargs) -> ResponseT: """ return self.execute_command("ACL LIST", **kwargs) - def acl_log(self, count: Union[int, None] = None, **kwargs) -> ResponseT: + def acl_log( + self, count: Union[int, None] = None, **kwargs + ) -> ResponseT[Union[ArrayResponseT, OKT]]: """ Get ACL logs as a list. :param int count: Get logs[0:count]. @@ -143,10 +154,9 @@ def acl_log(self, count: Union[int, None] = None, **kwargs) -> ResponseT: if not isinstance(count, int): raise DataError("ACL LOG count must be an integer") args.append(count) - return self.execute_command("ACL LOG", *args, **kwargs) - def acl_log_reset(self, **kwargs) -> ResponseT: + def acl_log_reset(self, **kwargs) -> ResponseT[Union[ArrayResponseT, OKT]]: """ Reset ACL logs. :rtype: Boolean. @@ -156,7 +166,7 @@ def acl_log_reset(self, **kwargs) -> ResponseT: args = [b"RESET"] return self.execute_command("ACL LOG", *args, **kwargs) - def acl_load(self, **kwargs) -> ResponseT: + def acl_load(self, **kwargs) -> ResponseT[OKT]: """ Load ACL rules from the configured ``aclfile``. @@ -167,7 +177,7 @@ def acl_load(self, **kwargs) -> ResponseT: """ return self.execute_command("ACL LOAD", **kwargs) - def acl_save(self, **kwargs) -> ResponseT: + def acl_save(self, **kwargs) -> ResponseT[OKT]: """ Save ACL rules to the configured ``aclfile``. @@ -195,7 +205,7 @@ def acl_setuser( reset_channels: bool = False, reset_passwords: bool = False, **kwargs, - ) -> ResponseT: + ) -> ResponseT[OKT]: """ Create or update an ACL user. @@ -255,29 +265,22 @@ def acl_setuser( """ encoder = self.get_encoder() pieces: List[EncodableT] = [username] - if reset: pieces.append(b"reset") - if reset_keys: pieces.append(b"resetkeys") - if reset_channels: pieces.append(b"resetchannels") - if reset_passwords: pieces.append(b"resetpass") - if enabled: pieces.append(b"on") else: pieces.append(b"off") - if (passwords or hashed_passwords) and nopass: raise DataError( "Cannot set 'nopass' and supply 'passwords' or 'hashed_passwords'" ) - if passwords: # as most users will have only one password, allow remove_passwords # to be specified as a simple string or a list @@ -293,7 +296,6 @@ def acl_setuser( f"Password {i} must be prefixed with a " f'"+" to add or a "-" to remove' ) - if hashed_passwords: # as most users will have only one password, allow remove_passwords # to be specified as a simple string or a list @@ -309,10 +311,8 @@ def acl_setuser( f"Hashed password {i} must be prefixed with a " f'"+" to add or a "-" to remove' ) - if nopass: pieces.append(b"nopass") - if categories: for category in categories: category = encoder.encode(category) @@ -339,19 +339,16 @@ def acl_setuser( 'must be prefixed with "+" or "-"' ) pieces.append(cmd) - if keys: for key in keys: key = encoder.encode(key) if not key.startswith(b"%") and not key.startswith(b"~"): key = b"~%s" % key pieces.append(key) - if channels: for channel in channels: channel = encoder.encode(channel) pieces.append(b"&%s" % channel) - if selectors: for cmd, key in selectors: cmd = encoder.encode(cmd) @@ -360,23 +357,20 @@ def acl_setuser( f'Command "{encoder.decode(cmd, force=True)}" ' 'must be prefixed with "+" or "-"' ) - key = encoder.encode(key) if not key.startswith(b"%") and not key.startswith(b"~"): key = b"~%s" % key - pieces.append(b"(%s %s)" % (cmd, key)) - return self.execute_command("ACL SETUSER", *pieces, **kwargs) - def acl_users(self, **kwargs) -> ResponseT: + def acl_users(self, **kwargs) -> ResponseT[ArrayResponseT]: """Returns a list of all registered users on the server. For more information see https://redis.io/commands/acl-users """ return self.execute_command("ACL USERS", **kwargs) - def acl_whoami(self, **kwargs) -> ResponseT: + def acl_whoami(self, **kwargs) -> ResponseT[BulkStringResponseT]: """Get the username for the current connection For more information see https://redis.io/commands/acl-whoami @@ -392,7 +386,9 @@ class ManagementCommands(CommandsProtocol): Redis management commands """ - def auth(self, password: str, username: Optional[str] = None, **kwargs): + def auth( + self, password: str, username: Optional[str] = None, **kwargs + ) -> ResponseT[OKT]: """ Authenticates the user. If you do not pass username, Redis will try to authenticate for the "default" user. If you do pass username, it will @@ -405,14 +401,14 @@ def auth(self, password: str, username: Optional[str] = None, **kwargs): pieces.append(password) return self.execute_command("AUTH", *pieces, **kwargs) - def bgrewriteaof(self, **kwargs): + def bgrewriteaof(self, **kwargs) -> ResponseT[str]: """Tell the Redis server to rewrite the AOF file from data in memory. For more information see https://redis.io/commands/bgrewriteaof """ return self.execute_command("BGREWRITEAOF", **kwargs) - def bgsave(self, schedule: bool = True, **kwargs) -> ResponseT: + def bgsave(self, schedule: bool = True, **kwargs) -> ResponseT[str]: """ Tell the Redis server to save its data to disk. Unlike save(), this method is asynchronous and returns immediately. @@ -424,7 +420,7 @@ def bgsave(self, schedule: bool = True, **kwargs) -> ResponseT: pieces.append("SCHEDULE") return self.execute_command("BGSAVE", *pieces, **kwargs) - def role(self) -> ResponseT: + def role(self) -> ResponseT[ArrayResponseT]: """ Provide information on the role of a Redis instance in the context of replication, by returning if the instance @@ -434,7 +430,9 @@ def role(self) -> ResponseT: """ return self.execute_command("ROLE") - def client_kill(self, address: str, **kwargs) -> ResponseT: + def client_kill( + self, address: str, **kwargs + ) -> ResponseT[Union[IntegerResponseT, OKT]]: """Disconnects the client at ``address`` (ip:port) For more information see https://redis.io/commands/client-kill @@ -448,10 +446,10 @@ def client_kill_filter( addr: Union[str, None] = None, skipme: Union[bool, None] = None, laddr: Union[bool, None] = None, - user: str = None, + user: Optional[str] = None, maxage: Union[int, None] = None, **kwargs, - ) -> ResponseT: + ) -> ResponseT[Union[IntegerResponseT, OKT]]: """ Disconnects client(s) using a variety of filter options :param _id: Kills a client by its unique ID field @@ -495,7 +493,7 @@ def client_kill_filter( ) return self.execute_command("CLIENT KILL", *args, **kwargs) - def client_info(self, **kwargs) -> ResponseT: + def client_info(self, **kwargs) -> ResponseT[BulkStringResponseT]: """ Returns information and statistics about the current client connection. @@ -506,7 +504,7 @@ def client_info(self, **kwargs) -> ResponseT: def client_list( self, _type: Union[str, None] = None, client_id: List[EncodableT] = [], **kwargs - ) -> ResponseT: + ) -> ResponseT[BulkStringResponseT]: """ Returns a list of currently connected clients. If type of client specified, only that type will be returned. @@ -531,7 +529,7 @@ def client_list( args.append(" ".join(client_id)) return self.execute_command("CLIENT LIST", *args, **kwargs) - def client_getname(self, **kwargs) -> ResponseT: + def client_getname(self, **kwargs) -> ResponseT[Union[BulkStringResponseT, None]]: """ Returns the current connection name @@ -539,7 +537,7 @@ def client_getname(self, **kwargs) -> ResponseT: """ return self.execute_command("CLIENT GETNAME", **kwargs) - def client_getredir(self, **kwargs) -> ResponseT: + def client_getredir(self, **kwargs) -> ResponseT[IntegerResponseT]: """ Returns the ID (an integer) of the client to whom we are redirecting tracking notifications. @@ -550,7 +548,7 @@ def client_getredir(self, **kwargs) -> ResponseT: def client_reply( self, reply: Union[Literal["ON"], Literal["OFF"], Literal["SKIP"]], **kwargs - ) -> ResponseT: + ) -> ResponseT[OKT]: """ Enable and disable redis server replies. @@ -572,7 +570,7 @@ def client_reply( raise DataError(f"CLIENT REPLY must be one of {replies!r}") return self.execute_command("CLIENT REPLY", reply, **kwargs) - def client_id(self, **kwargs) -> ResponseT: + def client_id(self, **kwargs) -> ResponseT[IntegerResponseT]: """ Returns the current connection id @@ -588,7 +586,7 @@ def client_tracking_on( optin: bool = False, optout: bool = False, noloop: bool = False, - ) -> ResponseT: + ) -> ResponseT[OKT]: """ Turn on the tracking mode. For more information about the options look at client_tracking func. @@ -607,7 +605,7 @@ def client_tracking_off( optin: bool = False, optout: bool = False, noloop: bool = False, - ) -> ResponseT: + ) -> ResponseT[OKT]: """ Turn off the tracking mode. For more information about the options look at client_tracking func. @@ -628,7 +626,7 @@ def client_tracking( optout: bool = False, noloop: bool = False, **kwargs, - ) -> ResponseT: + ) -> ResponseT[OKT]: """ Enables the tracking feature of the Redis server, that is used for server assisted client side caching. @@ -658,10 +656,8 @@ def client_tracking( See https://redis.io/commands/client-tracking """ - if len(prefix) != 0 and bcast is False: raise DataError("Prefix can only be used with bcast") - pieces = ["ON"] if on else ["OFF"] if clientid is not None: pieces.extend(["REDIRECT", clientid]) @@ -675,10 +671,9 @@ def client_tracking( pieces.append("OPTOUT") if noloop: pieces.append("NOLOOP") - return self.execute_command("CLIENT TRACKING", *pieces) - def client_trackinginfo(self, **kwargs) -> ResponseT: + def client_trackinginfo(self, **kwargs) -> ResponseT[ArrayResponseT]: """ Returns the information about the current client connection's use of the server assisted client side cache. @@ -687,7 +682,7 @@ def client_trackinginfo(self, **kwargs) -> ResponseT: """ return self.execute_command("CLIENT TRACKINGINFO", **kwargs) - def client_setname(self, name: str, **kwargs) -> ResponseT: + def client_setname(self, name: str, **kwargs) -> ResponseT[OKT]: """ Sets the current connection name @@ -701,7 +696,7 @@ def client_setname(self, name: str, **kwargs) -> ResponseT: """ return self.execute_command("CLIENT SETNAME", name, **kwargs) - def client_setinfo(self, attr: str, value: str, **kwargs) -> ResponseT: + def client_setinfo(self, attr: str, value: str, **kwargs) -> ResponseT[OKT]: """ Sets the current connection library name or version For mor information see https://redis.io/commands/client-setinfo @@ -710,7 +705,7 @@ def client_setinfo(self, attr: str, value: str, **kwargs) -> ResponseT: def client_unblock( self, client_id: int, error: bool = False, **kwargs - ) -> ResponseT: + ) -> ResponseT[IntegerResponseT]: """ Unblocks a connection by its client id. If ``error`` is True, unblocks the client with a special error message. @@ -724,7 +719,7 @@ def client_unblock( args.append(b"ERROR") return self.execute_command(*args, **kwargs) - def client_pause(self, timeout: int, all: bool = True, **kwargs) -> ResponseT: + def client_pause(self, timeout: int, all: bool = True, **kwargs) -> ResponseT[OKT]: """ Suspend all the Redis clients for the specified amount of time. @@ -749,7 +744,7 @@ def client_pause(self, timeout: int, all: bool = True, **kwargs) -> ResponseT: args.append("WRITE") return self.execute_command(*args, **kwargs) - def client_unpause(self, **kwargs) -> ResponseT: + def client_unpause(self, **kwargs) -> ResponseT[OKT]: """ Unpause all redis clients @@ -757,7 +752,7 @@ def client_unpause(self, **kwargs) -> ResponseT: """ return self.execute_command("CLIENT UNPAUSE", **kwargs) - def client_no_evict(self, mode: str) -> Union[Awaitable[str], str]: + def client_no_evict(self, mode: str) -> ResponseT[OKT]: """ Sets the client eviction mode for the current connection. @@ -765,7 +760,7 @@ def client_no_evict(self, mode: str) -> Union[Awaitable[str], str]: """ return self.execute_command("CLIENT NO-EVICT", mode) - def client_no_touch(self, mode: str) -> Union[Awaitable[str], str]: + def client_no_touch(self, mode: str) -> ResponseT[OKT]: """ # The command controls whether commands sent by the client will alter # the LRU/LFU of the keys they access. @@ -776,7 +771,7 @@ def client_no_touch(self, mode: str) -> Union[Awaitable[str], str]: """ return self.execute_command("CLIENT NO-TOUCH", mode) - def command(self, **kwargs): + def command(self, **kwargs) -> ResponseT[ArrayResponseT]: """ Returns dict reply of details about all Redis commands. @@ -789,7 +784,7 @@ def command_info(self, **kwargs) -> None: "COMMAND INFO is intentionally not implemented in the client." ) - def command_count(self, **kwargs) -> ResponseT: + def command_count(self, **kwargs) -> ResponseT[IntegerResponseT]: return self.execute_command("COMMAND COUNT", **kwargs) def command_list( @@ -797,7 +792,7 @@ def command_list( module: Optional[str] = None, category: Optional[str] = None, pattern: Optional[str] = None, - ) -> ResponseT: + ) -> ResponseT[ArrayResponseT]: """ Return an array of the server's command names. You can use one of the following filters: @@ -814,13 +809,11 @@ def command_list( pieces.extend(["ACLCAT", category]) if pattern is not None: pieces.extend(["PATTERN", pattern]) - if pieces: pieces.insert(0, "FILTERBY") - return self.execute_command("COMMAND LIST", *pieces) - def command_getkeysandflags(self, *args: List[str]) -> List[Union[str, List[str]]]: + def command_getkeysandflags(self, *args: List[str]) -> ResponseT[ArrayResponseT]: """ Returns array of keys from a full Redis command and their usage flags. @@ -828,7 +821,7 @@ def command_getkeysandflags(self, *args: List[str]) -> List[Union[str, List[str] """ return self.execute_command("COMMAND GETKEYSANDFLAGS", *args) - def command_docs(self, *args): + def command_docs(self, *args) -> None: """ This function throws a NotImplementedError since it is intentionally not supported. @@ -839,7 +832,7 @@ def command_docs(self, *args): def config_get( self, pattern: PatternT = "*", *args: List[PatternT], **kwargs - ) -> ResponseT: + ) -> ResponseT[ArrayResponseT]: """ Return a dictionary of configuration based on the ``pattern`` @@ -853,14 +846,14 @@ def config_set( value: EncodableT, *args: List[Union[KeyT, EncodableT]], **kwargs, - ) -> ResponseT: + ) -> ResponseT[OKT]: """Set config item ``name`` with ``value`` For more information see https://redis.io/commands/config-set """ return self.execute_command("CONFIG SET", name, value, *args, **kwargs) - def config_resetstat(self, **kwargs) -> ResponseT: + def config_resetstat(self, **kwargs) -> ResponseT[OKT]: """ Reset runtime statistics @@ -868,7 +861,7 @@ def config_resetstat(self, **kwargs) -> ResponseT: """ return self.execute_command("CONFIG RESETSTAT", **kwargs) - def config_rewrite(self, **kwargs) -> ResponseT: + def config_rewrite(self, **kwargs) -> ResponseT[OKT]: """ Rewrite config file with the minimal change to reflect running config. @@ -876,7 +869,7 @@ def config_rewrite(self, **kwargs) -> ResponseT: """ return self.execute_command("CONFIG REWRITE", **kwargs) - def dbsize(self, **kwargs) -> ResponseT: + def dbsize(self, **kwargs) -> ResponseT[IntegerResponseT]: """ Returns the number of keys in the current database @@ -884,7 +877,7 @@ def dbsize(self, **kwargs) -> ResponseT: """ return self.execute_command("DBSIZE", **kwargs) - def debug_object(self, key: KeyT, **kwargs) -> ResponseT: + def debug_object(self, key: KeyT, **kwargs) -> ResponseT[Any]: """ Returns version specific meta information about a given key @@ -901,7 +894,7 @@ def debug_segfault(self, **kwargs) -> None: """ ) - def echo(self, value: EncodableT, **kwargs) -> ResponseT: + def echo(self, value: EncodableT, **kwargs) -> ResponseT[BulkStringResponseT]: """ Echo the string back from the server @@ -909,7 +902,7 @@ def echo(self, value: EncodableT, **kwargs) -> ResponseT: """ return self.execute_command("ECHO", value, **kwargs) - def flushall(self, asynchronous: bool = False, **kwargs) -> ResponseT: + def flushall(self, asynchronous: bool = False, **kwargs) -> ResponseT[OKT]: """ Delete all keys in all databases on the current host. @@ -923,7 +916,7 @@ def flushall(self, asynchronous: bool = False, **kwargs) -> ResponseT: args.append(b"ASYNC") return self.execute_command("FLUSHALL", *args, **kwargs) - def flushdb(self, asynchronous: bool = False, **kwargs) -> ResponseT: + def flushdb(self, asynchronous: bool = False, **kwargs) -> ResponseT[OKT]: """ Delete all keys in the current database. @@ -937,7 +930,7 @@ def flushdb(self, asynchronous: bool = False, **kwargs) -> ResponseT: args.append(b"ASYNC") return self.execute_command("FLUSHDB", *args, **kwargs) - def sync(self) -> ResponseT: + def sync(self) -> ResponseT[Any]: """ Initiates a replication stream from the master. @@ -949,7 +942,7 @@ def sync(self) -> ResponseT: options[NEVER_DECODE] = [] return self.execute_command("SYNC", **options) - def psync(self, replicationid: str, offset: int): + def psync(self, replicationid: str, offset: int) -> ResponseT[Any]: """ Initiates a replication stream from the master. Newer version for `sync`. @@ -962,7 +955,7 @@ def psync(self, replicationid: str, offset: int): options[NEVER_DECODE] = [] return self.execute_command("PSYNC", replicationid, offset, **options) - def swapdb(self, first: int, second: int, **kwargs) -> ResponseT: + def swapdb(self, first: int, second: int, **kwargs) -> ResponseT[OKT]: """ Swap two databases @@ -970,7 +963,7 @@ def swapdb(self, first: int, second: int, **kwargs) -> ResponseT: """ return self.execute_command("SWAPDB", first, second, **kwargs) - def select(self, index: int, **kwargs) -> ResponseT: + def select(self, index: int, **kwargs) -> ResponseT[OKT]: """Select the Redis logical database at index. See: https://redis.io/commands/select @@ -979,7 +972,7 @@ def select(self, index: int, **kwargs) -> ResponseT: def info( self, section: Union[str, None] = None, *args: List[str], **kwargs - ) -> ResponseT: + ) -> ResponseT[BulkStringResponseT]: """ Returns a dictionary containing information about the Redis server @@ -996,7 +989,7 @@ def info( else: return self.execute_command("INFO", section, *args, **kwargs) - def lastsave(self, **kwargs) -> ResponseT: + def lastsave(self, **kwargs) -> ResponseT[IntegerResponseT]: """ Return a Python datetime object representing the last time the Redis database was saved to disk @@ -1005,7 +998,7 @@ def lastsave(self, **kwargs) -> ResponseT: """ return self.execute_command("LASTSAVE", **kwargs) - def latency_doctor(self): + def latency_doctor(self) -> None: """Raise a NotImplementedError, as the client will not support LATENCY DOCTOR. This funcion is best used within the redis-cli. @@ -1019,7 +1012,7 @@ def latency_doctor(self): """ ) - def latency_graph(self): + def latency_graph(self) -> None: """Raise a NotImplementedError, as the client will not support LATENCY GRAPH. This funcion is best used within the redis-cli. @@ -1033,7 +1026,9 @@ def latency_graph(self): """ ) - def lolwut(self, *version_numbers: Union[str, float], **kwargs) -> ResponseT: + def lolwut( + self, *version_numbers: Union[str, float], **kwargs + ) -> ResponseT[BulkStringResponseT]: """ Get the Redis version and a piece of generative computer art @@ -1044,7 +1039,7 @@ def lolwut(self, *version_numbers: Union[str, float], **kwargs) -> ResponseT: else: return self.execute_command("LOLWUT", **kwargs) - def reset(self) -> ResponseT: + def reset(self) -> ResponseT[str]: """Perform a full reset on the connection's server side contenxt. See: https://redis.io/commands/reset @@ -1062,7 +1057,7 @@ def migrate( replace: bool = False, auth: Union[str, None] = None, **kwargs, - ) -> ResponseT: + ) -> ResponseT[Union[str, OKT]]: """ Migrate 1 or more keys from the current Redis server to a different server specified by the ``host``, ``port`` and ``destination_db``. @@ -1099,7 +1094,7 @@ def migrate( "MIGRATE", host, port, "", destination_db, timeout, *pieces, **kwargs ) - def object(self, infotype: str, key: KeyT, **kwargs) -> ResponseT: + def object(self, infotype: str, key: KeyT, **kwargs) -> ResponseT[Any]: """ Return the encoding, idletime, or refcount about the key """ @@ -1125,7 +1120,7 @@ def memory_help(self, **kwargs) -> None: """ ) - def memory_stats(self, **kwargs) -> ResponseT: + def memory_stats(self, **kwargs) -> ResponseT[ArrayResponseT]: """ Return a dictionary of memory stats @@ -1133,7 +1128,7 @@ def memory_stats(self, **kwargs) -> ResponseT: """ return self.execute_command("MEMORY STATS", **kwargs) - def memory_malloc_stats(self, **kwargs) -> ResponseT: + def memory_malloc_stats(self, **kwargs) -> ResponseT[BulkStringResponseT]: """ Return an internal statistics report from the memory allocator. @@ -1143,7 +1138,7 @@ def memory_malloc_stats(self, **kwargs) -> ResponseT: def memory_usage( self, key: KeyT, samples: Union[int, None] = None, **kwargs - ) -> ResponseT: + ) -> ResponseT[Union[IntegerResponseT, None]]: """ Return the total memory usage for key, its value and associated administrative overheads. @@ -1159,7 +1154,7 @@ def memory_usage( args.extend([b"SAMPLES", samples]) return self.execute_command("MEMORY USAGE", key, *args, **kwargs) - def memory_purge(self, **kwargs) -> ResponseT: + def memory_purge(self, **kwargs) -> ResponseT[OKT]: """ Attempts to purge dirty pages for reclamation by allocator @@ -1167,7 +1162,7 @@ def memory_purge(self, **kwargs) -> ResponseT: """ return self.execute_command("MEMORY PURGE", **kwargs) - def latency_histogram(self, *args): + def latency_histogram(self, *args) -> None: """ This function throws a NotImplementedError since it is intentionally not supported. @@ -1176,7 +1171,7 @@ def latency_histogram(self, *args): "LATENCY HISTOGRAM is intentionally not implemented in the client." ) - def latency_history(self, event: str) -> ResponseT: + def latency_history(self, event: str) -> ResponseT[ArrayResponseT]: """ Returns the raw data of the ``event``'s latency spikes time series. @@ -1184,7 +1179,7 @@ def latency_history(self, event: str) -> ResponseT: """ return self.execute_command("LATENCY HISTORY", event) - def latency_latest(self) -> ResponseT: + def latency_latest(self) -> ResponseT[ArrayResponseT]: """ Reports the latest latency events logged. @@ -1192,7 +1187,7 @@ def latency_latest(self) -> ResponseT: """ return self.execute_command("LATENCY LATEST") - def latency_reset(self, *events: str) -> ResponseT: + def latency_reset(self, *events: str) -> ResponseT[IntegerResponseT]: """ Resets the latency spikes time series of all, or only some, events. @@ -1200,7 +1195,7 @@ def latency_reset(self, *events: str) -> ResponseT: """ return self.execute_command("LATENCY RESET", *events) - def ping(self, **kwargs) -> ResponseT: + def ping(self, **kwargs) -> ResponseT[Union[BulkStringResponseT, str]]: """ Ping the Redis server @@ -1208,7 +1203,7 @@ def ping(self, **kwargs) -> ResponseT: """ return self.execute_command("PING", **kwargs) - def quit(self, **kwargs) -> ResponseT: + def quit(self, **kwargs) -> ResponseT[OKT]: """ Ask the server to close the connection. @@ -1216,7 +1211,7 @@ def quit(self, **kwargs) -> ResponseT: """ return self.execute_command("QUIT", **kwargs) - def replicaof(self, *args, **kwargs) -> ResponseT: + def replicaof(self, *args, **kwargs) -> ResponseT[OKT]: """ Update the replication settings of a redis replica, on the fly. @@ -1229,7 +1224,7 @@ def replicaof(self, *args, **kwargs) -> ResponseT: """ return self.execute_command("REPLICAOF", *args, **kwargs) - def save(self, **kwargs) -> ResponseT: + def save(self, **kwargs) -> ResponseT[OKT]: """ Tell the Redis server to save its data to disk, blocking until the save is complete @@ -1277,12 +1272,12 @@ def shutdown( self.execute_command(*args, **kwargs) except ConnectionError: # a ConnectionError here is expected - return + return None raise RedisError("SHUTDOWN seems to have failed.") def slaveof( self, host: Union[str, None] = None, port: Union[int, None] = None, **kwargs - ) -> ResponseT: + ) -> ResponseT[OKT]: """ Set the server to be a replicated slave of the instance identified by the ``host`` and ``port``. If called without arguments, the @@ -1294,7 +1289,9 @@ def slaveof( return self.execute_command("SLAVEOF", b"NO", b"ONE", **kwargs) return self.execute_command("SLAVEOF", host, port, **kwargs) - def slowlog_get(self, num: Union[int, None] = None, **kwargs) -> ResponseT: + def slowlog_get( + self, num: Union[int, None] = None, **kwargs + ) -> ResponseT[ArrayResponseT]: """ Get the entries from the slowlog. If ``num`` is specified, get the most recent ``num`` items. @@ -1311,7 +1308,7 @@ def slowlog_get(self, num: Union[int, None] = None, **kwargs) -> ResponseT: kwargs[NEVER_DECODE] = [] return self.execute_command(*args, **kwargs) - def slowlog_len(self, **kwargs) -> ResponseT: + def slowlog_len(self, **kwargs) -> ResponseT[IntegerResponseT]: """ Get the number of items in the slowlog @@ -1319,7 +1316,7 @@ def slowlog_len(self, **kwargs) -> ResponseT: """ return self.execute_command("SLOWLOG LEN", **kwargs) - def slowlog_reset(self, **kwargs) -> ResponseT: + def slowlog_reset(self, **kwargs) -> ResponseT[OKT]: """ Remove all items in the slowlog @@ -1327,7 +1324,7 @@ def slowlog_reset(self, **kwargs) -> ResponseT: """ return self.execute_command("SLOWLOG RESET", **kwargs) - def time(self, **kwargs) -> ResponseT: + def time(self, **kwargs) -> ResponseT[ArrayResponseT]: """ Returns the server time as a 2-item tuple of ints: (seconds since epoch, microseconds into this second). @@ -1336,7 +1333,9 @@ def time(self, **kwargs) -> ResponseT: """ return self.execute_command("TIME", **kwargs) - def wait(self, num_replicas: int, timeout: int, **kwargs) -> ResponseT: + def wait( + self, num_replicas: int, timeout: int, **kwargs + ) -> ResponseT[IntegerResponseT]: """ Redis synchronous replication That returns the number of replicas that processed the query when @@ -1349,7 +1348,7 @@ def wait(self, num_replicas: int, timeout: int, **kwargs) -> ResponseT: def waitaof( self, num_local: int, num_replicas: int, timeout: int, **kwargs - ) -> ResponseT: + ) -> ResponseT[ArrayResponseT]: """ This command blocks the current client until all previous write commands by that client are acknowledged as having been fsynced @@ -1362,7 +1361,7 @@ def waitaof( "WAITAOF", num_local, num_replicas, timeout, **kwargs ) - def hello(self): + def hello(self) -> None: """ This function throws a NotImplementedError since it is intentionally not supported. @@ -1371,7 +1370,7 @@ def hello(self): "HELLO is intentionally not implemented in the client." ) - def failover(self): + def failover(self) -> None: """ This function throws a NotImplementedError since it is intentionally not supported. @@ -1382,6 +1381,7 @@ def failover(self): class AsyncManagementCommands(ManagementCommands): + async def command_info(self, **kwargs) -> None: return super().command_info(**kwargs) @@ -1428,7 +1428,7 @@ async def shutdown( await self.execute_command(*args, **kwargs) except ConnectionError: # a ConnectionError here is expected - return + return None raise RedisError("SHUTDOWN seems to have failed.") @@ -1447,7 +1447,7 @@ def __init__( self.key = key self._default_overflow = default_overflow # for typing purposes, run the following in constructor and in reset() - self.operations: list[tuple[EncodableT, ...]] = [] + self.operations: List[tuple[EncodableT, ...]] = [] self._last_overflow = "WRAP" self.reset() @@ -1494,7 +1494,6 @@ def incrby( """ if overflow is not None: self.overflow(overflow) - self.operations.append(("INCRBY", fmt, offset, increment)) return self @@ -1532,7 +1531,7 @@ def command(self): cmd.extend(ops) return cmd - def execute(self) -> ResponseT: + def execute(self) -> ResponseT[Any]: """ Execute the operation(s) in a single BITFIELD command. The return value is a list of values corresponding to each operation. If the client @@ -1549,7 +1548,7 @@ class BasicKeyCommands(CommandsProtocol): Redis basic key-based commands """ - def append(self, key: KeyT, value: EncodableT) -> ResponseT: + def append(self, key: KeyT, value: EncodableT) -> ResponseT[IntegerResponseT]: """ Appends the string ``value`` to the value at ``key``. If ``key`` doesn't already exist, create it with a value of ``value``. @@ -1565,7 +1564,7 @@ def bitcount( start: Union[int, None] = None, end: Union[int, None] = None, mode: Optional[str] = None, - ) -> ResponseT: + ) -> ResponseT[IntegerResponseT]: """ Returns the count of set bits in the value of ``key``. Optional ``start`` and ``end`` parameters indicate which bytes to consider @@ -1600,8 +1599,8 @@ def bitfield_ro( key: KeyT, encoding: str, offset: BitfieldOffsetT, - items: Optional[list] = None, - ) -> ResponseT: + items: Optional[List] = None, + ) -> ResponseT[ArrayResponseT]: """ Return an array of the specified bitfield values where the first value is found using ``encoding`` and ``offset`` @@ -1612,13 +1611,14 @@ def bitfield_ro( For more information see https://redis.io/commands/bitfield_ro """ params = [key, "GET", encoding, offset] - items = items or [] for encoding, offset in items: params.extend(["GET", encoding, offset]) return self.execute_command("BITFIELD_RO", *params, keys=[key]) - def bitop(self, operation: str, dest: KeyT, *keys: KeyT) -> ResponseT: + def bitop( + self, operation: str, dest: KeyT, *keys: KeyT + ) -> ResponseT[IntegerResponseT]: """ Perform a bitwise operation using ``operation`` between ``keys`` and store the result in ``dest``. @@ -1634,7 +1634,7 @@ def bitpos( start: Union[int, None] = None, end: Union[int, None] = None, mode: Optional[str] = None, - ) -> ResponseT: + ) -> ResponseT[IntegerResponseT]: """ Return the position of the first bit set to 1 or 0 in a string. ``start`` and ``end`` defines search range. The range is interpreted @@ -1646,14 +1646,12 @@ def bitpos( if bit not in (0, 1): raise DataError("bit must be 0 or 1") params = [key, bit] - - start is not None and params.append(start) - + if start is not None: + params.append(start) if start is not None and end is not None: params.append(end) elif start is None and end is not None: raise DataError("start argument is not set, when end is specified") - if mode is not None: params.append(mode) return self.execute_command("BITPOS", *params, keys=[key]) @@ -1664,7 +1662,7 @@ def copy( destination: str, destination_db: Union[str, None] = None, replace: bool = False, - ) -> ResponseT: + ) -> ResponseT[IntegerResponseT]: """ Copy the value stored in the ``source`` key to the ``destination`` key. @@ -1684,7 +1682,7 @@ def copy( params.append("REPLACE") return self.execute_command("COPY", *params) - def decrby(self, name: KeyT, amount: int = 1) -> ResponseT: + def decrby(self, name: KeyT, amount: int = 1) -> ResponseT[IntegerResponseT]: """ Decrements the value of ``key`` by ``amount``. If no key exists, the value will be initialized as 0 - ``amount`` @@ -1695,16 +1693,16 @@ def decrby(self, name: KeyT, amount: int = 1) -> ResponseT: decr = decrby - def delete(self, *names: KeyT) -> ResponseT: + def delete(self, *names: KeyT) -> ResponseT[IntegerResponseT]: """ Delete one or more keys specified by ``names`` """ return self.execute_command("DEL", *names) - def __delitem__(self, name: KeyT): + def __delitem__(self, name: KeyT) -> None: self.delete(name) - def dump(self, name: KeyT) -> ResponseT: + def dump(self, name: KeyT) -> ResponseT[Union[BulkStringResponseT, None]]: """ Return a serialized version of the value stored at the specified key. If key does not exist a nil bulk reply is returned. @@ -1717,7 +1715,7 @@ def dump(self, name: KeyT) -> ResponseT: options[NEVER_DECODE] = [] return self.execute_command("DUMP", name, **options) - def exists(self, *names: KeyT) -> ResponseT: + def exists(self, *names: KeyT) -> ResponseT[IntegerResponseT]: """ Returns the number of ``names`` that exist @@ -1735,7 +1733,7 @@ def expire( xx: bool = False, gt: bool = False, lt: bool = False, - ) -> ResponseT: + ) -> ResponseT[IntegerResponseT]: """ Set an expire flag on key ``name`` for ``time`` seconds with given ``option``. ``time`` can be represented by an integer or a Python timedelta @@ -1751,7 +1749,6 @@ def expire( """ if isinstance(time, datetime.timedelta): time = int(time.total_seconds()) - exp_option = list() if nx: exp_option.append("NX") @@ -1761,7 +1758,6 @@ def expire( exp_option.append("GT") if lt: exp_option.append("LT") - return self.execute_command("EXPIRE", name, time, *exp_option) def expireat( @@ -1772,7 +1768,7 @@ def expireat( xx: bool = False, gt: bool = False, lt: bool = False, - ) -> ResponseT: + ) -> ResponseT[IntegerResponseT]: """ Set an expire flag on key ``name`` with given ``option``. ``when`` can be represented as an integer indicating unix time or a Python @@ -1788,7 +1784,6 @@ def expireat( """ if isinstance(when, datetime.datetime): when = int(when.timestamp()) - exp_option = list() if nx: exp_option.append("NX") @@ -1798,10 +1793,9 @@ def expireat( exp_option.append("GT") if lt: exp_option.append("LT") - return self.execute_command("EXPIREAT", name, when, *exp_option) - def expiretime(self, key: str) -> int: + def expiretime(self, key: str) -> ResponseT[IntegerResponseT]: """ Returns the absolute Unix timestamp (since January 1, 1970) in seconds at which the given key will expire. @@ -1810,7 +1804,7 @@ def expiretime(self, key: str) -> int: """ return self.execute_command("EXPIRETIME", key) - def get(self, name: KeyT) -> ResponseT: + def get(self, name: KeyT) -> ResponseT[Union[BulkStringResponseT, None]]: """ Return the value at key ``name``, or None if the key doesn't exist @@ -1818,7 +1812,7 @@ def get(self, name: KeyT) -> ResponseT: """ return self.execute_command("GET", name, keys=[name]) - def getdel(self, name: KeyT) -> ResponseT: + def getdel(self, name: KeyT) -> ResponseT[Union[BulkStringResponseT, None]]: """ Get the value at key ``name`` and delete the key. This command is similar to GET, except for the fact that it also deletes @@ -1837,7 +1831,7 @@ def getex( exat: Union[AbsExpiryT, None] = None, pxat: Union[AbsExpiryT, None] = None, persist: bool = False, - ) -> ResponseT: + ) -> ResponseT[BulkStringResponseT]: """ Get the value of key and optionally set its expiration. GETEX is similar to GET, but is a write command with @@ -1858,15 +1852,13 @@ def getex( For more information see https://redis.io/commands/getex """ - opset = {ex, px, exat, pxat} if len(opset) > 2 or len(opset) > 1 and persist: raise DataError( "``ex``, ``px``, ``exat``, ``pxat``, " "and ``persist`` are mutually exclusive." ) - - pieces: list[EncodableT] = [] + pieces: List[EncodableT] = [] # similar to set command if ex is not None: pieces.append("EX") @@ -1891,10 +1883,9 @@ def getex( pieces.append(pxat) if persist: pieces.append("PERSIST") - return self.execute_command("GETEX", name, *pieces) - def __getitem__(self, name: KeyT): + def __getitem__(self, name: KeyT) -> ResponseT[BulkStringResponseT]: """ Return the value at key ``name``, raises a KeyError if the key doesn't exist. @@ -1904,7 +1895,7 @@ def __getitem__(self, name: KeyT): return value raise KeyError(name) - def getbit(self, name: KeyT, offset: int) -> ResponseT: + def getbit(self, name: KeyT, offset: int) -> ResponseT[IntegerResponseT]: """ Returns an integer indicating the value of ``offset`` in ``name`` @@ -1912,7 +1903,9 @@ def getbit(self, name: KeyT, offset: int) -> ResponseT: """ return self.execute_command("GETBIT", name, offset, keys=[name]) - def getrange(self, key: KeyT, start: int, end: int) -> ResponseT: + def getrange( + self, key: KeyT, start: int, end: int + ) -> ResponseT[BulkStringResponseT]: """ Returns the substring of the string value stored at ``key``, determined by the offsets ``start`` and ``end`` (both are inclusive) @@ -1921,7 +1914,9 @@ def getrange(self, key: KeyT, start: int, end: int) -> ResponseT: """ return self.execute_command("GETRANGE", key, start, end, keys=[key]) - def getset(self, name: KeyT, value: EncodableT) -> ResponseT: + def getset( + self, name: KeyT, value: EncodableT + ) -> ResponseT[Union[BulkStringResponseT, None]]: """ Sets the value at key ``name`` to ``value`` and returns the old value at key ``name`` atomically. @@ -1933,7 +1928,7 @@ def getset(self, name: KeyT, value: EncodableT) -> ResponseT: """ return self.execute_command("GETSET", name, value) - def incrby(self, name: KeyT, amount: int = 1) -> ResponseT: + def incrby(self, name: KeyT, amount: int = 1) -> ResponseT[IntegerResponseT]: """ Increments the value of ``key`` by ``amount``. If no key exists, the value will be initialized as ``amount`` @@ -1944,7 +1939,9 @@ def incrby(self, name: KeyT, amount: int = 1) -> ResponseT: incr = incrby - def incrbyfloat(self, name: KeyT, amount: float = 1.0) -> ResponseT: + def incrbyfloat( + self, name: KeyT, amount: float = 1.0 + ) -> ResponseT[BulkStringResponseT]: """ Increments the value at key ``name`` by floating ``amount``. If no key exists, the value will be initialized as ``amount`` @@ -1953,7 +1950,7 @@ def incrbyfloat(self, name: KeyT, amount: float = 1.0) -> ResponseT: """ return self.execute_command("INCRBYFLOAT", name, amount) - def keys(self, pattern: PatternT = "*", **kwargs) -> ResponseT: + def keys(self, pattern: PatternT = "*", **kwargs) -> ResponseT[ArrayResponseT]: """ Returns a list of keys matching ``pattern`` @@ -1963,7 +1960,7 @@ def keys(self, pattern: PatternT = "*", **kwargs) -> ResponseT: def lmove( self, first_list: str, second_list: str, src: str = "LEFT", dest: str = "RIGHT" - ) -> ResponseT: + ) -> ResponseT[BulkStringResponseT]: """ Atomically returns and removes the first/last element of a list, pushing it as the first/last element on the destination list. @@ -1981,7 +1978,7 @@ def blmove( timeout: int, src: str = "LEFT", dest: str = "RIGHT", - ) -> ResponseT: + ) -> ResponseT[Union[BulkStringResponseT, None]]: """ Blocking version of lmove. @@ -1990,7 +1987,7 @@ def blmove( params = [first_list, second_list, src, dest, timeout] return self.execute_command("BLMOVE", *params) - def mget(self, keys: KeysT, *args: EncodableT) -> ResponseT: + def mget(self, keys: KeysT, *args: EncodableT) -> ResponseT[ArrayResponseT]: """ Returns a list of values ordered identically to ``keys`` @@ -2005,7 +2002,7 @@ def mget(self, keys: KeysT, *args: EncodableT) -> ResponseT: options["keys"] = args return self.execute_command("MGET", *args, **options) - def mset(self, mapping: Mapping[AnyKeyT, EncodableT]) -> ResponseT: + def mset(self, mapping: Mapping[AnyKeyT, EncodableT]) -> ResponseT[OKT]: """ Sets key/values based on a mapping. Mapping is a dictionary of key/value pairs. Both keys and values should be strings or types that @@ -2018,7 +2015,9 @@ def mset(self, mapping: Mapping[AnyKeyT, EncodableT]) -> ResponseT: items.extend(pair) return self.execute_command("MSET", *items) - def msetnx(self, mapping: Mapping[AnyKeyT, EncodableT]) -> ResponseT: + def msetnx( + self, mapping: Mapping[AnyKeyT, EncodableT] + ) -> ResponseT[IntegerResponseT]: """ Sets key/values based on a mapping if none of the keys are already set. Mapping is a dictionary of key/value pairs. Both keys and values @@ -2032,7 +2031,7 @@ def msetnx(self, mapping: Mapping[AnyKeyT, EncodableT]) -> ResponseT: items.extend(pair) return self.execute_command("MSETNX", *items) - def move(self, name: KeyT, db: int) -> ResponseT: + def move(self, name: KeyT, db: int) -> ResponseT[IntegerResponseT]: """ Moves the key ``name`` to a different Redis database ``db`` @@ -2040,7 +2039,7 @@ def move(self, name: KeyT, db: int) -> ResponseT: """ return self.execute_command("MOVE", name, db) - def persist(self, name: KeyT) -> ResponseT: + def persist(self, name: KeyT) -> ResponseT[IntegerResponseT]: """ Removes an expiration on ``name`` @@ -2056,7 +2055,7 @@ def pexpire( xx: bool = False, gt: bool = False, lt: bool = False, - ) -> ResponseT: + ) -> ResponseT[IntegerResponseT]: """ Set an expire flag on key ``name`` for ``time`` milliseconds with given ``option``. ``time`` can be represented by an @@ -2072,7 +2071,6 @@ def pexpire( """ if isinstance(time, datetime.timedelta): time = int(time.total_seconds() * 1000) - exp_option = list() if nx: exp_option.append("NX") @@ -2092,7 +2090,7 @@ def pexpireat( xx: bool = False, gt: bool = False, lt: bool = False, - ) -> ResponseT: + ) -> ResponseT[IntegerResponseT]: """ Set an expire flag on key ``name`` with given ``option``. ``when`` can be represented as an integer representing unix time in @@ -2119,7 +2117,7 @@ def pexpireat( exp_option.append("LT") return self.execute_command("PEXPIREAT", name, when, *exp_option) - def pexpiretime(self, key: str) -> int: + def pexpiretime(self, key: str) -> ResponseT[IntegerResponseT]: """ Returns the absolute Unix timestamp (since January 1, 1970) in milliseconds at which the given key will expire. @@ -2128,7 +2126,7 @@ def pexpiretime(self, key: str) -> int: """ return self.execute_command("PEXPIRETIME", key) - def psetex(self, name: KeyT, time_ms: ExpiryT, value: EncodableT): + def psetex(self, name: KeyT, time_ms: ExpiryT, value: EncodableT) -> ResponseT[OKT]: """ Set the value of key ``name`` to ``value`` that expires in ``time_ms`` milliseconds. ``time_ms`` can be represented by an integer or a Python @@ -2140,7 +2138,7 @@ def psetex(self, name: KeyT, time_ms: ExpiryT, value: EncodableT): time_ms = int(time_ms.total_seconds() * 1000) return self.execute_command("PSETEX", name, time_ms, value) - def pttl(self, name: KeyT) -> ResponseT: + def pttl(self, name: KeyT) -> ResponseT[IntegerResponseT]: """ Returns the number of milliseconds until the key ``name`` will expire @@ -2149,8 +2147,8 @@ def pttl(self, name: KeyT) -> ResponseT: return self.execute_command("PTTL", name) def hrandfield( - self, key: str, count: int = None, withvalues: bool = False - ) -> ResponseT: + self, key: str, count: Optional[int] = None, withvalues: bool = False + ) -> ResponseT[Union[BulkStringResponseT, None, ArrayResponseT]]: """ Return a random field from the hash value stored at key. @@ -2169,10 +2167,9 @@ def hrandfield( params.append(count) if withvalues: params.append("WITHVALUES") - return self.execute_command("HRANDFIELD", key, *params) - def randomkey(self, **kwargs) -> ResponseT: + def randomkey(self, **kwargs) -> ResponseT[Union[BulkStringResponseT, None]]: """ Returns the name of a random key @@ -2180,7 +2177,7 @@ def randomkey(self, **kwargs) -> ResponseT: """ return self.execute_command("RANDOMKEY", **kwargs) - def rename(self, src: KeyT, dst: KeyT) -> ResponseT: + def rename(self, src: KeyT, dst: KeyT) -> ResponseT[OKT]: """ Rename key ``src`` to ``dst`` @@ -2188,7 +2185,7 @@ def rename(self, src: KeyT, dst: KeyT) -> ResponseT: """ return self.execute_command("RENAME", src, dst) - def renamenx(self, src: KeyT, dst: KeyT): + def renamenx(self, src: KeyT, dst: KeyT) -> ResponseT[IntegerResponseT]: """ Rename key ``src`` to ``dst`` if ``dst`` doesn't already exist @@ -2205,7 +2202,7 @@ def restore( absttl: bool = False, idletime: Union[int, None] = None, frequency: Union[int, None] = None, - ) -> ResponseT: + ) -> ResponseT[OKT]: """ Create a key using the provided serialized value, previously obtained using DUMP. @@ -2236,14 +2233,12 @@ def restore( params.append(int(idletime)) except ValueError: raise DataError("idletimemust be an integer") - if frequency is not None: params.append("FREQ") try: params.append(int(frequency)) except ValueError: raise DataError("frequency must be an integer") - return self.execute_command("RESTORE", *params) def set( @@ -2258,7 +2253,7 @@ def set( get: bool = False, exat: Union[AbsExpiryT, None] = None, pxat: Union[AbsExpiryT, None] = None, - ) -> ResponseT: + ) -> ResponseT[Union[BulkStringResponseT, None, OKT]]: """ Set the value at key ``name`` to ``value`` @@ -2287,7 +2282,7 @@ def set( For more information see https://redis.io/commands/set """ - pieces: list[EncodableT] = [name, value] + pieces: List[EncodableT] = [name, value] options = {} if ex is not None: pieces.append("EX") @@ -2319,22 +2314,21 @@ def set( pieces.append(pxat) if keepttl: pieces.append("KEEPTTL") - if nx: pieces.append("NX") if xx: pieces.append("XX") - if get: pieces.append("GET") options["get"] = True - return self.execute_command("SET", *pieces, **options) - def __setitem__(self, name: KeyT, value: EncodableT): + def __setitem__(self, name: KeyT, value: EncodableT) -> None: self.set(name, value) - def setbit(self, name: KeyT, offset: int, value: int) -> ResponseT: + def setbit( + self, name: KeyT, offset: int, value: int + ) -> ResponseT[IntegerResponseT]: """ Flag the ``offset`` in ``name`` as ``value``. Returns an integer indicating the previous value of ``offset``. @@ -2344,7 +2338,7 @@ def setbit(self, name: KeyT, offset: int, value: int) -> ResponseT: value = value and 1 or 0 return self.execute_command("SETBIT", name, offset, value) - def setex(self, name: KeyT, time: ExpiryT, value: EncodableT) -> ResponseT: + def setex(self, name: KeyT, time: ExpiryT, value: EncodableT) -> ResponseT[OKT]: """ Set the value of key ``name`` to ``value`` that expires in ``time`` seconds. ``time`` can be represented by an integer or a Python @@ -2356,7 +2350,7 @@ def setex(self, name: KeyT, time: ExpiryT, value: EncodableT) -> ResponseT: time = int(time.total_seconds()) return self.execute_command("SETEX", name, time, value) - def setnx(self, name: KeyT, value: EncodableT) -> ResponseT: + def setnx(self, name: KeyT, value: EncodableT) -> ResponseT[IntegerResponseT]: """ Set the value of key ``name`` to ``value`` if key doesn't exist @@ -2364,7 +2358,9 @@ def setnx(self, name: KeyT, value: EncodableT) -> ResponseT: """ return self.execute_command("SETNX", name, value) - def setrange(self, name: KeyT, offset: int, value: EncodableT) -> ResponseT: + def setrange( + self, name: KeyT, offset: int, value: EncodableT + ) -> ResponseT[IntegerResponseT]: """ Overwrite bytes in the value of ``name`` starting at ``offset`` with ``value``. If ``offset`` plus the length of ``value`` exceeds the @@ -2390,7 +2386,7 @@ def stralgo( minmatchlen: Union[int, None] = None, withmatchlen: bool = False, **kwargs, - ) -> ResponseT: + ) -> ResponseT[Any]: """ Implements complex algorithms that operate on strings. Right now the only algorithm implemented is the LCS algorithm @@ -2419,8 +2415,7 @@ def stralgo( raise DataError("specific_argument can be only keys or strings") if len and idx: raise DataError("len and idx cannot be provided together.") - - pieces: list[EncodableT] = [algo, specific_argument.upper(), value1, value2] + pieces: List[EncodableT] = [algo, specific_argument.upper(), value1, value2] if len: pieces.append(b"LEN") if idx: @@ -2432,7 +2427,6 @@ def stralgo( pass if withmatchlen: pieces.append(b"WITHMATCHLEN") - return self.execute_command( "STRALGO", *pieces, @@ -2443,7 +2437,7 @@ def stralgo( **kwargs, ) - def strlen(self, name: KeyT) -> ResponseT: + def strlen(self, name: KeyT) -> ResponseT[IntegerResponseT]: """ Return the number of bytes stored in the value of ``name`` @@ -2451,14 +2445,16 @@ def strlen(self, name: KeyT) -> ResponseT: """ return self.execute_command("STRLEN", name, keys=[name]) - def substr(self, name: KeyT, start: int, end: int = -1) -> ResponseT: + def substr( + self, name: KeyT, start: int, end: int = -1 + ) -> ResponseT[BulkStringResponseT]: """ Return a substring of the string at key ``name``. ``start`` and ``end`` are 0-based integers specifying the portion of the string to return. """ return self.execute_command("SUBSTR", name, start, end, keys=[name]) - def touch(self, *args: KeyT) -> ResponseT: + def touch(self, *args: KeyT) -> ResponseT[IntegerResponseT]: """ Alters the last access time of a key(s) ``*args``. A key is ignored if it does not exist. @@ -2467,7 +2463,7 @@ def touch(self, *args: KeyT) -> ResponseT: """ return self.execute_command("TOUCH", *args) - def ttl(self, name: KeyT) -> ResponseT: + def ttl(self, name: KeyT) -> ResponseT[IntegerResponseT]: """ Returns the number of seconds until the key ``name`` will expire @@ -2475,7 +2471,7 @@ def ttl(self, name: KeyT) -> ResponseT: """ return self.execute_command("TTL", name) - def type(self, name: KeyT) -> ResponseT: + def type(self, name: KeyT) -> ResponseT[str]: """ Returns the type of key ``name`` @@ -2499,7 +2495,7 @@ def unwatch(self) -> None: """ warnings.warn(DeprecationWarning("Call UNWATCH from a Pipeline object")) - def unlink(self, *names: KeyT) -> ResponseT: + def unlink(self, *names: KeyT) -> ResponseT[IntegerResponseT]: """ Unlink one or more keys specified by ``names`` @@ -2515,7 +2511,7 @@ def lcs( idx: Optional[bool] = False, minmatchlen: Optional[int] = 0, withmatchlen: Optional[bool] = False, - ) -> Union[str, int, list]: + ) -> ResponseT[Union[BulkStringResponseT, IntegerResponseT, ArrayResponseT]]: """ Find the longest common subsequence between ``key1`` and ``key2``. If ``len`` is true the length of the match will will be returned. @@ -2538,16 +2534,17 @@ def lcs( class AsyncBasicKeyCommands(BasicKeyCommands): - def __delitem__(self, name: KeyT): + + def __delitem__(self, name: KeyT) -> None: raise TypeError("Async Redis client does not support class deletion") - def __contains__(self, name: KeyT): + def __contains__(self, name: KeyT) -> None: raise TypeError("Async Redis client does not support class inclusion") - def __getitem__(self, name: KeyT): + def __getitem__(self, name: KeyT) -> None: raise TypeError("Async Redis client does not support class retrieval") - def __setitem__(self, name: KeyT, value: EncodableT): + def __setitem__(self, name: KeyT, value: EncodableT) -> None: raise TypeError("Async Redis client does not support class assignment") async def watch(self, *names: KeyT) -> None: @@ -2565,7 +2562,7 @@ class ListCommands(CommandsProtocol): def blpop( self, keys: List, timeout: Optional[int] = 0 - ) -> Union[Awaitable[list], list]: + ) -> ResponseT[Union[ArrayResponseT, None]]: """ LPOP a value off of the first non-empty list named in the ``keys`` list. @@ -2586,7 +2583,7 @@ def blpop( def brpop( self, keys: List, timeout: Optional[int] = 0 - ) -> Union[Awaitable[list], list]: + ) -> ResponseT[Union[ArrayResponseT, None]]: """ RPOP a value off of the first non-empty list named in the ``keys`` list. @@ -2607,7 +2604,7 @@ def brpop( def brpoplpush( self, src: str, dst: str, timeout: Optional[int] = 0 - ) -> Union[Awaitable[Optional[str]], Optional[str]]: + ) -> ResponseT[Union[BulkStringResponseT, None]]: """ Pop a value off the tail of ``src``, push it on the head of ``dst`` and then return it. @@ -2629,7 +2626,7 @@ def blmpop( *args: List[str], direction: str, count: Optional[int] = 1, - ) -> Optional[list]: + ) -> ResponseT[Union[ArrayResponseT, None]]: """ Pop ``count`` values (default 1) from first non-empty in the list of provided key names. @@ -2640,16 +2637,11 @@ def blmpop( For more information see https://redis.io/commands/blmpop """ args = [timeout, numkeys, *args, direction, "COUNT", count] - return self.execute_command("BLMPOP", *args) def lmpop( - self, - num_keys: int, - *args: List[str], - direction: str, - count: Optional[int] = 1, - ) -> Union[Awaitable[list], list]: + self, num_keys: int, *args: List[str], direction: str, count: Optional[int] = 1 + ) -> ResponseT[Union[ArrayResponseT, None]]: """ Pop ``count`` values (default 1) first non-empty list key from the list of args provided key names. @@ -2659,12 +2651,11 @@ def lmpop( args = [num_keys] + list(args) + [direction] if count != 1: args.extend(["COUNT", count]) - return self.execute_command("LMPOP", *args) def lindex( self, name: str, index: int - ) -> Union[Awaitable[Optional[str]], Optional[str]]: + ) -> ResponseT[Union[BulkStringResponseT, None]]: """ Return the item from list ``name`` at position ``index`` @@ -2677,7 +2668,7 @@ def lindex( def linsert( self, name: str, where: str, refvalue: str, value: str - ) -> Union[Awaitable[int], int]: + ) -> ResponseT[IntegerResponseT]: """ Insert ``value`` in list ``name`` either immediately before or after [``where``] ``refvalue`` @@ -2689,7 +2680,7 @@ def linsert( """ return self.execute_command("LINSERT", name, where, refvalue, value) - def llen(self, name: str) -> Union[Awaitable[int], int]: + def llen(self, name: str) -> ResponseT[IntegerResponseT]: """ Return the length of the list ``name`` @@ -2698,10 +2689,8 @@ def llen(self, name: str) -> Union[Awaitable[int], int]: return self.execute_command("LLEN", name, keys=[name]) def lpop( - self, - name: str, - count: Optional[int] = None, - ) -> Union[Awaitable[Union[str, List, None]], Union[str, List, None]]: + self, name: str, count: Optional[int] = None + ) -> ResponseT[Union[None, BulkStringResponseT, ArrayResponseT]]: """ Removes and returns the first elements of the list ``name``. @@ -2716,7 +2705,7 @@ def lpop( else: return self.execute_command("LPOP", name) - def lpush(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: + def lpush(self, name: str, *values: FieldT) -> ResponseT[IntegerResponseT]: """ Push ``values`` onto the head of the list ``name`` @@ -2724,7 +2713,7 @@ def lpush(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: """ return self.execute_command("LPUSH", name, *values) - def lpushx(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: + def lpushx(self, name: str, *values: FieldT) -> ResponseT[IntegerResponseT]: """ Push ``value`` onto the head of the list ``name`` if ``name`` exists @@ -2732,7 +2721,7 @@ def lpushx(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: """ return self.execute_command("LPUSHX", name, *values) - def lrange(self, name: str, start: int, end: int) -> Union[Awaitable[list], list]: + def lrange(self, name: str, start: int, end: int) -> ResponseT[ArrayResponseT]: """ Return a slice of the list ``name`` between position ``start`` and ``end`` @@ -2744,7 +2733,7 @@ def lrange(self, name: str, start: int, end: int) -> Union[Awaitable[list], list """ return self.execute_command("LRANGE", name, start, end, keys=[name]) - def lrem(self, name: str, count: int, value: str) -> Union[Awaitable[int], int]: + def lrem(self, name: str, count: int, value: str) -> ResponseT[IntegerResponseT]: """ Remove the first ``count`` occurrences of elements equal to ``value`` from the list stored at ``name``. @@ -2758,7 +2747,7 @@ def lrem(self, name: str, count: int, value: str) -> Union[Awaitable[int], int]: """ return self.execute_command("LREM", name, count, value) - def lset(self, name: str, index: int, value: str) -> Union[Awaitable[str], str]: + def lset(self, name: str, index: int, value: str) -> ResponseT[OKT]: """ Set element at ``index`` of list ``name`` to ``value`` @@ -2766,7 +2755,7 @@ def lset(self, name: str, index: int, value: str) -> Union[Awaitable[str], str]: """ return self.execute_command("LSET", name, index, value) - def ltrim(self, name: str, start: int, end: int) -> Union[Awaitable[str], str]: + def ltrim(self, name: str, start: int, end: int) -> ResponseT[OKT]: """ Trim the list ``name``, removing all values not within the slice between ``start`` and ``end`` @@ -2779,10 +2768,8 @@ def ltrim(self, name: str, start: int, end: int) -> Union[Awaitable[str], str]: return self.execute_command("LTRIM", name, start, end) def rpop( - self, - name: str, - count: Optional[int] = None, - ) -> Union[Awaitable[Union[str, List, None]], Union[str, List, None]]: + self, name: str, count: Optional[int] = None + ) -> ResponseT[Union[None, BulkStringResponseT, ArrayResponseT]]: """ Removes and returns the last elements of the list ``name``. @@ -2797,7 +2784,9 @@ def rpop( else: return self.execute_command("RPOP", name) - def rpoplpush(self, src: str, dst: str) -> Union[Awaitable[str], str]: + def rpoplpush( + self, src: str, dst: str + ) -> ResponseT[Union[BulkStringResponseT, None]]: """ RPOP a value off of the ``src`` list and atomically LPUSH it on to the ``dst`` list. Returns the value. @@ -2806,7 +2795,7 @@ def rpoplpush(self, src: str, dst: str) -> Union[Awaitable[str], str]: """ return self.execute_command("RPOPLPUSH", src, dst) - def rpush(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: + def rpush(self, name: str, *values: FieldT) -> ResponseT[IntegerResponseT]: """ Push ``values`` onto the tail of the list ``name`` @@ -2814,7 +2803,7 @@ def rpush(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: """ return self.execute_command("RPUSH", name, *values) - def rpushx(self, name: str, *values: str) -> Union[Awaitable[int], int]: + def rpushx(self, name: str, *values: str) -> ResponseT[IntegerResponseT]: """ Push ``value`` onto the tail of the list ``name`` if ``name`` exists @@ -2829,7 +2818,7 @@ def lpos( rank: Optional[int] = None, count: Optional[int] = None, maxlen: Optional[int] = None, - ) -> Union[str, List, None]: + ) -> ResponseT[Union[IntegerResponseT, None, ArrayResponseT]]: """ Get position of ``value`` within the list ``name`` @@ -2855,16 +2844,13 @@ def lpos( For more information see https://redis.io/commands/lpos """ - pieces: list[EncodableT] = [name, value] + pieces: List[EncodableT] = [name, value] if rank is not None: pieces.extend(["RANK", rank]) - if count is not None: pieces.extend(["COUNT", count]) - if maxlen is not None: pieces.extend(["MAXLEN", maxlen]) - return self.execute_command("LPOS", *pieces, keys=[name]) def sort( @@ -2878,7 +2864,7 @@ def sort( alpha: bool = False, store: Optional[str] = None, groups: Optional[bool] = False, - ) -> Union[List, int]: + ) -> ResponseT[ArrayResponseT]: """ Sort and return the list, set or sorted set at ``name``. @@ -2906,8 +2892,7 @@ def sort( """ if (start is not None and num is None) or (num is not None and start is None): raise DataError("``start`` and ``num`` must both be specified") - - pieces: list[EncodableT] = [name] + pieces: List[EncodableT] = [name] if by is not None: pieces.extend([b"BY", by]) if start is not None and num is not None: @@ -2935,7 +2920,6 @@ def sort( "must be specified and contain at least " "two keys" ) - options = {"groups": len(get) if groups else None} options["keys"] = [name] return self.execute_command("SORT", *pieces, **options) @@ -2949,7 +2933,7 @@ def sort_ro( get: Optional[List[str]] = None, desc: bool = False, alpha: bool = False, - ) -> list: + ) -> ResponseT[ArrayResponseT]: """ Returns the elements contained in the list, set or sorted set at key. (read-only variant of the SORT command) @@ -2990,7 +2974,7 @@ def scan( count: Union[int, None] = None, _type: Union[str, None] = None, **kwargs, - ) -> ResponseT: + ) -> ResponseT[ArrayResponseT]: """ Incrementally return lists of key names. Also return a cursor indicating the scan position. @@ -3007,7 +2991,7 @@ def scan( For more information see https://redis.io/commands/scan """ - pieces: list[EncodableT] = [cursor] + pieces: List[EncodableT] = [cursor] if match is not None: pieces.extend([b"MATCH", match]) if count is not None: @@ -3050,7 +3034,7 @@ def sscan( cursor: int = 0, match: Union[PatternT, None] = None, count: Union[int, None] = None, - ) -> ResponseT: + ) -> ResponseT[ArrayResponseT]: """ Incrementally return lists of elements in a set. Also return a cursor indicating the scan position. @@ -3061,7 +3045,7 @@ def sscan( For more information see https://redis.io/commands/sscan """ - pieces: list[EncodableT] = [name, cursor] + pieces: List[EncodableT] = [name, cursor] if match is not None: pieces.extend([b"MATCH", match]) if count is not None: @@ -3094,7 +3078,7 @@ def hscan( match: Union[PatternT, None] = None, count: Union[int, None] = None, no_values: Union[bool, None] = None, - ) -> ResponseT: + ) -> ResponseT[ArrayResponseT]: """ Incrementally return key/value slices in a hash. Also return a cursor indicating the scan position. @@ -3107,7 +3091,7 @@ def hscan( For more information see https://redis.io/commands/hscan """ - pieces: list[EncodableT] = [name, cursor] + pieces: List[EncodableT] = [name, cursor] if match is not None: pieces.extend([b"MATCH", match]) if count is not None: @@ -3150,7 +3134,7 @@ def zscan( match: Union[PatternT, None] = None, count: Union[int, None] = None, score_cast_func: Union[type, Callable] = float, - ) -> ResponseT: + ) -> ResponseT[ArrayResponseT]: """ Incrementally return lists of elements in a sorted set. Also return a cursor indicating the scan position. @@ -3201,6 +3185,7 @@ def zscan_iter( class AsyncScanCommands(ScanCommands): + async def scan_iter( self, match: Union[PatternT, None] = None, @@ -3317,7 +3302,7 @@ class SetCommands(CommandsProtocol): see: https://redis.io/topics/data-types#sets """ - def sadd(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: + def sadd(self, name: str, *values: FieldT) -> ResponseT[IntegerResponseT]: """ Add ``value(s)`` to set ``name`` @@ -3325,7 +3310,7 @@ def sadd(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: """ return self.execute_command("SADD", name, *values) - def scard(self, name: str) -> Union[Awaitable[int], int]: + def scard(self, name: str) -> ResponseT[IntegerResponseT]: """ Return the number of elements in set ``name`` @@ -3333,7 +3318,7 @@ def scard(self, name: str) -> Union[Awaitable[int], int]: """ return self.execute_command("SCARD", name, keys=[name]) - def sdiff(self, keys: List, *args: List) -> Union[Awaitable[list], list]: + def sdiff(self, keys: List, *args: List) -> ResponseT[ArrayResponseT]: """ Return the difference of sets specified by ``keys`` @@ -3344,7 +3329,7 @@ def sdiff(self, keys: List, *args: List) -> Union[Awaitable[list], list]: def sdiffstore( self, dest: str, keys: List, *args: List - ) -> Union[Awaitable[int], int]: + ) -> ResponseT[IntegerResponseT]: """ Store the difference of sets specified by ``keys`` into a new set named ``dest``. Returns the number of keys in the new set. @@ -3354,7 +3339,7 @@ def sdiffstore( args = list_or_args(keys, args) return self.execute_command("SDIFFSTORE", dest, *args) - def sinter(self, keys: List, *args: List) -> Union[Awaitable[list], list]: + def sinter(self, keys: List, *args: List) -> ResponseT[ArrayResponseT]: """ Return the intersection of sets specified by ``keys`` @@ -3365,7 +3350,7 @@ def sinter(self, keys: List, *args: List) -> Union[Awaitable[list], list]: def sintercard( self, numkeys: int, keys: List[str], limit: int = 0 - ) -> Union[Awaitable[int], int]: + ) -> ResponseT[IntegerResponseT]: """ Return the cardinality of the intersect of multiple sets specified by ``keys``. @@ -3380,7 +3365,7 @@ def sintercard( def sinterstore( self, dest: str, keys: List, *args: List - ) -> Union[Awaitable[int], int]: + ) -> ResponseT[IntegerResponseT]: """ Store the intersection of sets specified by ``keys`` into a new set named ``dest``. Returns the number of keys in the new set. @@ -3390,9 +3375,7 @@ def sinterstore( args = list_or_args(keys, args) return self.execute_command("SINTERSTORE", dest, *args) - def sismember( - self, name: str, value: str - ) -> Union[Awaitable[Union[Literal[0], Literal[1]]], Union[Literal[0], Literal[1]]]: + def sismember(self, name: str, value: str) -> ResponseT[IntegerResponseT]: """ Return whether ``value`` is a member of set ``name``: - 1 if the value is a member of the set. @@ -3402,7 +3385,7 @@ def sismember( """ return self.execute_command("SISMEMBER", name, value, keys=[name]) - def smembers(self, name: str) -> Union[Awaitable[Set], Set]: + def smembers(self, name: str) -> ResponseT[ArrayResponseT]: """ Return all members of the set ``name`` @@ -3410,10 +3393,9 @@ def smembers(self, name: str) -> Union[Awaitable[Set], Set]: """ return self.execute_command("SMEMBERS", name, keys=[name]) - def smismember(self, name: str, values: List, *args: List) -> Union[ - Awaitable[List[Union[Literal[0], Literal[1]]]], - List[Union[Literal[0], Literal[1]]], - ]: + def smismember( + self, name: str, values: List, *args: List + ) -> ResponseT[ArrayResponseT]: """ Return whether each value in ``values`` is a member of the set ``name`` as a list of ``int`` in the order of ``values``: @@ -3425,7 +3407,7 @@ def smismember(self, name: str, values: List, *args: List) -> Union[ args = list_or_args(values, args) return self.execute_command("SMISMEMBER", name, *args, keys=[name]) - def smove(self, src: str, dst: str, value: str) -> Union[Awaitable[bool], bool]: + def smove(self, src: str, dst: str, value: str) -> ResponseT[IntegerResponseT]: """ Move ``value`` from set ``src`` to set ``dst`` atomically @@ -3433,7 +3415,9 @@ def smove(self, src: str, dst: str, value: str) -> Union[Awaitable[bool], bool]: """ return self.execute_command("SMOVE", src, dst, value) - def spop(self, name: str, count: Optional[int] = None) -> Union[str, List, None]: + def spop( + self, name: str, count: Optional[int] = None + ) -> ResponseT[Union[BulkStringResponseT, None, ArrayResponseT]]: """ Remove and return a random member of set ``name`` @@ -3444,7 +3428,7 @@ def spop(self, name: str, count: Optional[int] = None) -> Union[str, List, None] def srandmember( self, name: str, number: Optional[int] = None - ) -> Union[str, List, None]: + ) -> ResponseT[Union[BulkStringResponseT, ArrayResponseT]]: """ If ``number`` is None, returns a random member of set ``name``. @@ -3457,7 +3441,7 @@ def srandmember( args = (number is not None) and [number] or [] return self.execute_command("SRANDMEMBER", name, *args) - def srem(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: + def srem(self, name: str, *values: FieldT) -> ResponseT[IntegerResponseT]: """ Remove ``values`` from set ``name`` @@ -3465,7 +3449,7 @@ def srem(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: """ return self.execute_command("SREM", name, *values) - def sunion(self, keys: List, *args: List) -> Union[Awaitable[List], List]: + def sunion(self, keys: List, *args: List) -> ResponseT[ArrayResponseT]: """ Return the union of sets specified by ``keys`` @@ -3476,7 +3460,7 @@ def sunion(self, keys: List, *args: List) -> Union[Awaitable[List], List]: def sunionstore( self, dest: str, keys: List, *args: List - ) -> Union[Awaitable[int], int]: + ) -> ResponseT[IntegerResponseT]: """ Store the union of sets specified by ``keys`` into a new set named ``dest``. Returns the number of keys in the new set. @@ -3496,7 +3480,9 @@ class StreamCommands(CommandsProtocol): see: https://redis.io/topics/streams-intro """ - def xack(self, name: KeyT, groupname: GroupT, *ids: StreamIdT) -> ResponseT: + def xack( + self, name: KeyT, groupname: GroupT, *ids: StreamIdT + ) -> ResponseT[IntegerResponseT]: """ Acknowledges the successful processing of one or more messages. @@ -3519,7 +3505,7 @@ def xadd( nomkstream: bool = False, minid: Union[StreamIdT, None] = None, limit: Union[int, None] = None, - ) -> ResponseT: + ) -> ResponseT[Union[BulkStringResponseT, None]]: """ Add to a stream. name: name of the stream @@ -3535,10 +3521,9 @@ def xadd( For more information see https://redis.io/commands/xadd """ - pieces: list[EncodableT] = [] + pieces: List[EncodableT] = [] if maxlen is not None and minid is not None: raise DataError("Only one of ```maxlen``` or ```minid``` may be specified") - if maxlen is not None: if not isinstance(maxlen, int) or maxlen < 0: raise DataError("XADD maxlen must be non-negative integer") @@ -3571,7 +3556,7 @@ def xautoclaim( start_id: StreamIdT = "0-0", count: Union[int, None] = None, justid: bool = False, - ) -> ResponseT: + ) -> ResponseT[ArrayResponseT]: """ Transfers ownership of pending stream entries that match the specified criteria. Conceptually, equivalent to calling XPENDING and then XCLAIM, @@ -3597,10 +3582,8 @@ def xautoclaim( ) except TypeError: pass - kwargs = {} pieces = [name, groupname, consumername, min_idle_time, start_id] - try: if int(count) < 0: raise DataError("XPENDING count must be a integer >= 0") @@ -3610,7 +3593,6 @@ def xautoclaim( if justid: pieces.append(b"JUSTID") kwargs["parse_justid"] = True - return self.execute_command("XAUTOCLAIM", *pieces, **kwargs) def xclaim( @@ -3625,7 +3607,7 @@ def xclaim( retrycount: Union[int, None] = None, force: bool = False, justid: bool = False, - ) -> ResponseT: + ) -> ResponseT[ArrayResponseT]: """ Changes the ownership of a pending message. @@ -3667,11 +3649,9 @@ def xclaim( "XCLAIM message_ids must be a non empty list or " "tuple of message IDs to claim" ) - kwargs = {} - pieces: list[EncodableT] = [name, groupname, consumername, str(min_idle_time)] + pieces: List[EncodableT] = [name, groupname, consumername, str(min_idle_time)] pieces.extend(list(message_ids)) - if idle is not None: if not isinstance(idle, int): raise DataError("XCLAIM idle must be an integer") @@ -3684,7 +3664,6 @@ def xclaim( if not isinstance(retrycount, int): raise DataError("XCLAIM retrycount must be an integer") pieces.extend((b"RETRYCOUNT", str(retrycount))) - if force: if not isinstance(force, bool): raise DataError("XCLAIM force must be a boolean") @@ -3696,7 +3675,7 @@ def xclaim( kwargs["parse_justid"] = True return self.execute_command("XCLAIM", *pieces, **kwargs) - def xdel(self, name: KeyT, *ids: StreamIdT) -> ResponseT: + def xdel(self, name: KeyT, *ids: StreamIdT) -> ResponseT[IntegerResponseT]: """ Deletes one or more messages from a stream. @@ -3715,7 +3694,7 @@ def xgroup_create( id: StreamIdT = "$", mkstream: bool = False, entries_read: Optional[int] = None, - ) -> ResponseT: + ) -> ResponseT[OKT]: """ Create a new consumer group associated with a stream. name: name of the stream. @@ -3724,17 +3703,16 @@ def xgroup_create( For more information see https://redis.io/commands/xgroup-create """ - pieces: list[EncodableT] = ["XGROUP CREATE", name, groupname, id] + pieces: List[EncodableT] = ["XGROUP CREATE", name, groupname, id] if mkstream: pieces.append(b"MKSTREAM") if entries_read is not None: pieces.extend(["ENTRIESREAD", entries_read]) - return self.execute_command(*pieces) def xgroup_delconsumer( self, name: KeyT, groupname: GroupT, consumername: ConsumerT - ) -> ResponseT: + ) -> ResponseT[IntegerResponseT]: """ Remove a specific consumer from a consumer group. Returns the number of pending messages that the consumer had before it @@ -3747,7 +3725,9 @@ def xgroup_delconsumer( """ return self.execute_command("XGROUP DELCONSUMER", name, groupname, consumername) - def xgroup_destroy(self, name: KeyT, groupname: GroupT) -> ResponseT: + def xgroup_destroy( + self, name: KeyT, groupname: GroupT + ) -> ResponseT[IntegerResponseT]: """ Destroy a consumer group. name: name of the stream. @@ -3759,7 +3739,7 @@ def xgroup_destroy(self, name: KeyT, groupname: GroupT) -> ResponseT: def xgroup_createconsumer( self, name: KeyT, groupname: GroupT, consumername: ConsumerT - ) -> ResponseT: + ) -> ResponseT[IntegerResponseT]: """ Consumers in a consumer group are auto-created every time a new consumer name is mentioned by some command. @@ -3780,7 +3760,7 @@ def xgroup_setid( groupname: GroupT, id: StreamIdT, entries_read: Optional[int] = None, - ) -> ResponseT: + ) -> ResponseT[OKT]: """ Set the consumer group last delivered ID to something else. name: name of the stream. @@ -3794,7 +3774,9 @@ def xgroup_setid( pieces.extend(["ENTRIESREAD", entries_read]) return self.execute_command("XGROUP SETID", *pieces) - def xinfo_consumers(self, name: KeyT, groupname: GroupT) -> ResponseT: + def xinfo_consumers( + self, name: KeyT, groupname: GroupT + ) -> ResponseT[ArrayResponseT]: """ Returns general information about the consumers in the group. name: name of the stream. @@ -3804,7 +3786,7 @@ def xinfo_consumers(self, name: KeyT, groupname: GroupT) -> ResponseT: """ return self.execute_command("XINFO CONSUMERS", name, groupname) - def xinfo_groups(self, name: KeyT) -> ResponseT: + def xinfo_groups(self, name: KeyT) -> ResponseT[ArrayResponseT]: """ Returns general information about the consumer groups of the stream. name: name of the stream. @@ -3813,7 +3795,7 @@ def xinfo_groups(self, name: KeyT) -> ResponseT: """ return self.execute_command("XINFO GROUPS", name) - def xinfo_stream(self, name: KeyT, full: bool = False) -> ResponseT: + def xinfo_stream(self, name: KeyT, full: bool = False) -> ResponseT[ArrayResponseT]: """ Returns general information about the stream. name: name of the stream. @@ -3828,7 +3810,7 @@ def xinfo_stream(self, name: KeyT, full: bool = False) -> ResponseT: options = {"full": full} return self.execute_command("XINFO STREAM", *pieces, **options) - def xlen(self, name: KeyT) -> ResponseT: + def xlen(self, name: KeyT) -> ResponseT[IntegerResponseT]: """ Returns the number of elements in a given stream. @@ -3836,7 +3818,7 @@ def xlen(self, name: KeyT) -> ResponseT: """ return self.execute_command("XLEN", name, keys=[name]) - def xpending(self, name: KeyT, groupname: GroupT) -> ResponseT: + def xpending(self, name: KeyT, groupname: GroupT) -> ResponseT[ArrayResponseT]: """ Returns information about pending messages of a group. name: name of the stream. @@ -3855,7 +3837,7 @@ def xpending_range( count: int, consumername: Union[ConsumerT, None] = None, idle: Union[int, None] = None, - ) -> ResponseT: + ) -> ResponseT[ArrayResponseT]: """ Returns information about pending messages, in a range. @@ -3876,7 +3858,6 @@ def xpending_range( " with min, max and count parameters" ) return self.xpending(name, groupname) - pieces = [name, groupname] if min is None or max is None or count is None: raise DataError( @@ -3900,7 +3881,6 @@ def xpending_range( # consumername if consumername: pieces.append(consumername) - return self.execute_command("XPENDING", *pieces, parse_detail=True) def xrange( @@ -3909,7 +3889,7 @@ def xrange( min: StreamIdT = "-", max: StreamIdT = "+", count: Union[int, None] = None, - ) -> ResponseT: + ) -> ResponseT[ArrayResponseT]: """ Read stream values within an interval. @@ -3932,7 +3912,6 @@ def xrange( raise DataError("XRANGE count must be a positive integer") pieces.append(b"COUNT") pieces.append(str(count)) - return self.execute_command("XRANGE", name, *pieces, keys=[name]) def xread( @@ -3940,7 +3919,7 @@ def xread( streams: Dict[KeyT, StreamIdT], count: Union[int, None] = None, block: Union[int, None] = None, - ) -> ResponseT: + ) -> ResponseT[Union[ArrayResponseT, None]]: """ Block and monitor multiple streams for new data. @@ -3981,7 +3960,7 @@ def xreadgroup( count: Union[int, None] = None, block: Union[int, None] = None, noack: bool = False, - ) -> ResponseT: + ) -> ResponseT[Union[ArrayResponseT, None]]: """ Read from a stream via a consumer group. @@ -4000,7 +3979,7 @@ def xreadgroup( For more information see https://redis.io/commands/xreadgroup """ - pieces: list[EncodableT] = [b"GROUP", groupname, consumername] + pieces: List[EncodableT] = [b"GROUP", groupname, consumername] if count is not None: if not isinstance(count, int) or count < 1: raise DataError("XREADGROUP count must be a positive integer") @@ -4026,7 +4005,7 @@ def xrevrange( max: StreamIdT = "+", min: StreamIdT = "-", count: Union[int, None] = None, - ) -> ResponseT: + ) -> ResponseT[ArrayResponseT]: """ Read stream values within an interval, in reverse order. @@ -4043,13 +4022,12 @@ def xrevrange( For more information see https://redis.io/commands/xrevrange """ - pieces: list[EncodableT] = [max, min] + pieces: List[EncodableT] = [max, min] if count is not None: if not isinstance(count, int) or count < 1: raise DataError("XREVRANGE count must be a positive integer") pieces.append(b"COUNT") pieces.append(str(count)) - return self.execute_command("XREVRANGE", name, *pieces, keys=[name]) def xtrim( @@ -4059,7 +4037,7 @@ def xtrim( approximate: bool = True, minid: Union[StreamIdT, None] = None, limit: Union[int, None] = None, - ) -> ResponseT: + ) -> ResponseT[IntegerResponseT]: """ Trims old messages from a stream. name: name of the stream. @@ -4072,13 +4050,11 @@ def xtrim( For more information see https://redis.io/commands/xtrim """ - pieces: list[EncodableT] = [] + pieces: List[EncodableT] = [] if maxlen is not None and minid is not None: raise DataError("Only one of ``maxlen`` or ``minid`` may be specified") - if maxlen is None and minid is None: raise DataError("One of ``maxlen`` or ``minid`` must be specified") - if maxlen is not None: pieces.append(b"MAXLEN") if minid is not None: @@ -4092,7 +4068,6 @@ def xtrim( if limit is not None: pieces.append(b"LIMIT") pieces.append(limit) - return self.execute_command("XTRIM", name, *pieces) @@ -4115,7 +4090,7 @@ def zadd( incr: bool = False, gt: bool = False, lt: bool = False, - ) -> ResponseT: + ) -> ResponseT[Union[IntegerResponseT, BulkStringResponseT, None]]: """ Set any number of element-name, score pairs to the key ``name``. Pairs are specified as a dict of element-names keys to score values. @@ -4162,8 +4137,7 @@ def zadd( ) if nx and (gt or lt): raise DataError("Only one of 'nx', 'lt', or 'gr' may be defined.") - - pieces: list[EncodableT] = [] + pieces: List[EncodableT] = [] options = {} if nx: pieces.append(b"NX") @@ -4183,7 +4157,7 @@ def zadd( pieces.append(pair[0]) return self.execute_command("ZADD", name, *pieces, **options) - def zcard(self, name: KeyT) -> ResponseT: + def zcard(self, name: KeyT) -> ResponseT[IntegerResponseT]: """ Return the number of elements in the sorted set ``name`` @@ -4191,7 +4165,9 @@ def zcard(self, name: KeyT) -> ResponseT: """ return self.execute_command("ZCARD", name, keys=[name]) - def zcount(self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT) -> ResponseT: + def zcount( + self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT + ) -> ResponseT[IntegerResponseT]: """ Returns the number of elements in the sorted set at key ``name`` with a score between ``min`` and ``max``. @@ -4200,7 +4176,7 @@ def zcount(self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT) -> ResponseT: """ return self.execute_command("ZCOUNT", name, min, max, keys=[name]) - def zdiff(self, keys: KeysT, withscores: bool = False) -> ResponseT: + def zdiff(self, keys: KeysT, withscores: bool = False) -> ResponseT[ArrayResponseT]: """ Returns the difference between the first and all successive input sorted sets provided in ``keys``. @@ -4212,7 +4188,7 @@ def zdiff(self, keys: KeysT, withscores: bool = False) -> ResponseT: pieces.append("WITHSCORES") return self.execute_command("ZDIFF", *pieces, keys=keys) - def zdiffstore(self, dest: KeyT, keys: KeysT) -> ResponseT: + def zdiffstore(self, dest: KeyT, keys: KeysT) -> ResponseT[IntegerResponseT]: """ Computes the difference between the first and all successive input sorted sets provided in ``keys`` and stores the result in ``dest``. @@ -4222,7 +4198,9 @@ def zdiffstore(self, dest: KeyT, keys: KeysT) -> ResponseT: pieces = [len(keys), *keys] return self.execute_command("ZDIFFSTORE", dest, *pieces) - def zincrby(self, name: KeyT, amount: float, value: EncodableT) -> ResponseT: + def zincrby( + self, name: KeyT, amount: float, value: EncodableT + ) -> ResponseT[BulkStringResponseT]: """ Increment the score of ``value`` in sorted set ``name`` by ``amount`` @@ -4251,7 +4229,7 @@ def zinterstore( dest: KeyT, keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]], aggregate: Union[str, None] = None, - ) -> ResponseT: + ) -> ResponseT[IntegerResponseT]: """ Intersect multiple sorted sets specified by ``keys`` into a new sorted set, ``dest``. Scores in the destination will be aggregated @@ -4267,7 +4245,7 @@ def zinterstore( def zintercard( self, numkeys: int, keys: List[str], limit: int = 0 - ) -> Union[Awaitable[int], int]: + ) -> ResponseT[IntegerResponseT]: """ Return the cardinality of the intersect of multiple sorted sets specified by ``keys``. @@ -4280,7 +4258,7 @@ def zintercard( args = [numkeys, *keys, "LIMIT", limit] return self.execute_command("ZINTERCARD", *args, keys=keys) - def zlexcount(self, name, min, max): + def zlexcount(self, name, min, max) -> ResponseT[IntegerResponseT]: """ Return the number of items in the sorted set ``name`` between the lexicographical range ``min`` and ``max``. @@ -4289,7 +4267,9 @@ def zlexcount(self, name, min, max): """ return self.execute_command("ZLEXCOUNT", name, min, max, keys=[name]) - def zpopmax(self, name: KeyT, count: Union[int, None] = None) -> ResponseT: + def zpopmax( + self, name: KeyT, count: Union[int, None] = None + ) -> ResponseT[ArrayResponseT]: """ Remove and return up to ``count`` members with the highest scores from the sorted set ``name``. @@ -4300,7 +4280,9 @@ def zpopmax(self, name: KeyT, count: Union[int, None] = None) -> ResponseT: options = {"withscores": True} return self.execute_command("ZPOPMAX", name, *args, **options) - def zpopmin(self, name: KeyT, count: Union[int, None] = None) -> ResponseT: + def zpopmin( + self, name: KeyT, count: Union[int, None] = None + ) -> ResponseT[ArrayResponseT]: """ Remove and return up to ``count`` members with the lowest scores from the sorted set ``name``. @@ -4313,7 +4295,7 @@ def zpopmin(self, name: KeyT, count: Union[int, None] = None) -> ResponseT: def zrandmember( self, key: KeyT, count: int = None, withscores: bool = False - ) -> ResponseT: + ) -> ResponseT[BulkStringResponseT]: """ Return a random element from the sorted set value stored at key. @@ -4334,10 +4316,11 @@ def zrandmember( params.append(count) if withscores: params.append("WITHSCORES") - return self.execute_command("ZRANDMEMBER", key, *params) - def bzpopmax(self, keys: KeysT, timeout: TimeoutSecT = 0) -> ResponseT: + def bzpopmax( + self, keys: KeysT, timeout: TimeoutSecT = 0 + ) -> ResponseT[Union[ArrayResponseT, None]]: """ ZPOPMAX a value off of the first non-empty sorted set named in the ``keys`` list. @@ -4356,7 +4339,9 @@ def bzpopmax(self, keys: KeysT, timeout: TimeoutSecT = 0) -> ResponseT: keys.append(timeout) return self.execute_command("BZPOPMAX", *keys) - def bzpopmin(self, keys: KeysT, timeout: TimeoutSecT = 0) -> ResponseT: + def bzpopmin( + self, keys: KeysT, timeout: TimeoutSecT = 0 + ) -> ResponseT[Union[ArrayResponseT, None]]: """ ZPOPMIN a value off of the first non-empty sorted set named in the ``keys`` list. @@ -4371,7 +4356,7 @@ def bzpopmin(self, keys: KeysT, timeout: TimeoutSecT = 0) -> ResponseT: """ if timeout is None: timeout = 0 - keys: list[EncodableT] = list_or_args(keys, None) + keys: List[EncodableT] = list_or_args(keys, None) keys.append(timeout) return self.execute_command("BZPOPMIN", *keys) @@ -4382,7 +4367,7 @@ def zmpop( min: Optional[bool] = False, max: Optional[bool] = False, count: Optional[int] = 1, - ) -> Union[Awaitable[list], list]: + ) -> ResponseT[Union[ArrayResponseT, None]]: """ Pop ``count`` values (default 1) off of the first non-empty sorted set named in the ``keys`` list. @@ -4397,7 +4382,6 @@ def zmpop( args.append("MAX") if count != 1: args.extend(["COUNT", count]) - return self.execute_command("ZMPOP", *args) def bzmpop( @@ -4408,7 +4392,7 @@ def bzmpop( min: Optional[bool] = False, max: Optional[bool] = False, count: Optional[int] = 1, - ) -> Optional[list]: + ) -> ResponseT[Union[ArrayResponseT, None]]: """ Pop ``count`` values (default 1) off of the first non-empty sorted set named in the ``keys`` list. @@ -4429,7 +4413,6 @@ def bzmpop( else: args.append("MAX") args.extend(["COUNT", count]) - return self.execute_command("BZMPOP", *args) def _zrange( @@ -4446,7 +4429,7 @@ def _zrange( score_cast_func: Union[type, Callable, None] = float, offset: Union[int, None] = None, num: Union[int, None] = None, - ) -> ResponseT: + ) -> ResponseT[Any]: if byscore and bylex: raise DataError("``byscore`` and ``bylex`` can not be specified together.") if (offset is not None and num is None) or (num is not None and offset is None): @@ -4483,9 +4466,9 @@ def zrange( score_cast_func: Union[type, Callable] = float, byscore: bool = False, bylex: bool = False, - offset: int = None, - num: int = None, - ) -> ResponseT: + offset: Optional[int] = None, + num: Optional[int] = None, + ) -> ResponseT[ArrayResponseT]: """ Return a range of values from sorted set ``name`` between ``start`` and ``end`` sorted in ascending order. @@ -4518,7 +4501,6 @@ def zrange( # because it was supported in 3.5.3 (of redis-py) if not byscore and not bylex and (offset is None and num is None) and desc: return self.zrevrange(name, start, end, withscores, score_cast_func) - return self._zrange( "ZRANGE", None, @@ -4541,7 +4523,7 @@ def zrevrange( end: int, withscores: bool = False, score_cast_func: Union[type, Callable] = float, - ) -> ResponseT: + ) -> ResponseT[ArrayResponseT]: """ Return a range of values from sorted set ``name`` between ``start`` and ``end`` sorted in descending order. @@ -4573,7 +4555,7 @@ def zrangestore( desc: bool = False, offset: Union[int, None] = None, num: Union[int, None] = None, - ) -> ResponseT: + ) -> ResponseT[IntegerResponseT]: """ Stores in ``dest`` the result of a range of values from sorted set ``name`` between ``start`` and ``end`` sorted in ascending order. @@ -4619,7 +4601,7 @@ def zrangebylex( max: EncodableT, start: Union[int, None] = None, num: Union[int, None] = None, - ) -> ResponseT: + ) -> ResponseT[ArrayResponseT]: """ Return the lexicographical range of values from sorted set ``name`` between ``min`` and ``max``. @@ -4643,7 +4625,7 @@ def zrevrangebylex( min: EncodableT, start: Union[int, None] = None, num: Union[int, None] = None, - ) -> ResponseT: + ) -> ResponseT[ArrayResponseT]: """ Return the reversed lexicographical range of values from sorted set ``name`` between ``max`` and ``min``. @@ -4669,7 +4651,7 @@ def zrangebyscore( num: Union[int, None] = None, withscores: bool = False, score_cast_func: Union[type, Callable] = float, - ) -> ResponseT: + ) -> ResponseT[ArrayResponseT]: """ Return a range of values from the sorted set ``name`` with scores between ``min`` and ``max``. @@ -4704,7 +4686,7 @@ def zrevrangebyscore( num: Union[int, None] = None, withscores: bool = False, score_cast_func: Union[type, Callable] = float, - ): + ) -> ResponseT[ArrayResponseT]: """ Return a range of values from the sorted set ``name`` with scores between ``min`` and ``max`` in descending order. @@ -4731,11 +4713,8 @@ def zrevrangebyscore( return self.execute_command(*pieces, **options) def zrank( - self, - name: KeyT, - value: EncodableT, - withscore: bool = False, - ) -> ResponseT: + self, name: KeyT, value: EncodableT, withscore: bool = False + ) -> ResponseT[Union[IntegerResponseT, None, ArrayResponseT]]: """ Returns a 0-based value indicating the rank of ``value`` in sorted set ``name``. @@ -4748,7 +4727,7 @@ def zrank( return self.execute_command("ZRANK", name, value, "WITHSCORE", keys=[name]) return self.execute_command("ZRANK", name, value, keys=[name]) - def zrem(self, name: KeyT, *values: FieldT) -> ResponseT: + def zrem(self, name: KeyT, *values: FieldT) -> ResponseT[IntegerResponseT]: """ Remove member ``values`` from sorted set ``name`` @@ -4756,7 +4735,9 @@ def zrem(self, name: KeyT, *values: FieldT) -> ResponseT: """ return self.execute_command("ZREM", name, *values) - def zremrangebylex(self, name: KeyT, min: EncodableT, max: EncodableT) -> ResponseT: + def zremrangebylex( + self, name: KeyT, min: EncodableT, max: EncodableT + ) -> ResponseT[IntegerResponseT]: """ Remove all elements in the sorted set ``name`` between the lexicographical range specified by ``min`` and ``max``. @@ -4767,7 +4748,9 @@ def zremrangebylex(self, name: KeyT, min: EncodableT, max: EncodableT) -> Respon """ return self.execute_command("ZREMRANGEBYLEX", name, min, max) - def zremrangebyrank(self, name: KeyT, min: int, max: int) -> ResponseT: + def zremrangebyrank( + self, name: KeyT, min: int, max: int + ) -> ResponseT[IntegerResponseT]: """ Remove all elements in the sorted set ``name`` with ranks between ``min`` and ``max``. Values are 0-based, ordered from smallest score @@ -4780,7 +4763,7 @@ def zremrangebyrank(self, name: KeyT, min: int, max: int) -> ResponseT: def zremrangebyscore( self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT - ) -> ResponseT: + ) -> ResponseT[IntegerResponseT]: """ Remove all elements in the sorted set ``name`` with scores between ``min`` and ``max``. Returns the number of elements removed. @@ -4790,11 +4773,8 @@ def zremrangebyscore( return self.execute_command("ZREMRANGEBYSCORE", name, min, max) def zrevrank( - self, - name: KeyT, - value: EncodableT, - withscore: bool = False, - ) -> ResponseT: + self, name: KeyT, value: EncodableT, withscore: bool = False + ) -> ResponseT[Union[IntegerResponseT, None, ArrayResponseT]]: """ Returns a 0-based value indicating the descending rank of ``value`` in sorted set ``name``. @@ -4809,7 +4789,9 @@ def zrevrank( ) return self.execute_command("ZREVRANK", name, value, keys=[name]) - def zscore(self, name: KeyT, value: EncodableT) -> ResponseT: + def zscore( + self, name: KeyT, value: EncodableT + ) -> ResponseT[Union[BulkStringResponseT, None]]: """ Return the score of element ``value`` in sorted set ``name`` @@ -4822,7 +4804,7 @@ def zunion( keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]], aggregate: Union[str, None] = None, withscores: bool = False, - ) -> ResponseT: + ) -> ResponseT[ArrayResponseT]: """ Return the union of multiple sorted sets specified by ``keys``. ``keys`` can be provided as dictionary of keys and their weights. @@ -4838,7 +4820,7 @@ def zunionstore( dest: KeyT, keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]], aggregate: Union[str, None] = None, - ) -> ResponseT: + ) -> ResponseT[IntegerResponseT]: """ Union multiple sorted sets specified by ``keys`` into a new sorted set, ``dest``. Scores in the destination will be @@ -4848,7 +4830,9 @@ def zunionstore( """ return self._zaggregate("ZUNIONSTORE", dest, keys, aggregate) - def zmscore(self, key: KeyT, members: List[str]) -> ResponseT: + def zmscore( + self, key: KeyT, members: List[str] + ) -> ResponseT[Union[ArrayResponseT, None]]: """ Returns the scores associated with the specified members in the sorted set stored at key. @@ -4871,8 +4855,8 @@ def _zaggregate( keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]], aggregate: Union[str, None] = None, **options, - ) -> ResponseT: - pieces: list[EncodableT] = [command] + ) -> ResponseT[Any]: + pieces: List[EncodableT] = [command] if dest is not None: pieces.append(dest) pieces.append(len(keys)) @@ -4905,7 +4889,7 @@ class HyperlogCommands(CommandsProtocol): see: https://redis.io/topics/data-types-intro#hyperloglogs """ - def pfadd(self, name: KeyT, *values: FieldT) -> ResponseT: + def pfadd(self, name: KeyT, *values: FieldT) -> ResponseT[IntegerResponseT]: """ Adds the specified elements to the specified HyperLogLog. @@ -4913,7 +4897,7 @@ def pfadd(self, name: KeyT, *values: FieldT) -> ResponseT: """ return self.execute_command("PFADD", name, *values) - def pfcount(self, *sources: KeyT) -> ResponseT: + def pfcount(self, *sources: KeyT) -> ResponseT[IntegerResponseT]: """ Return the approximated cardinality of the set observed by the HyperLogLog at key(s). @@ -4922,7 +4906,7 @@ def pfcount(self, *sources: KeyT) -> ResponseT: """ return self.execute_command("PFCOUNT", *sources) - def pfmerge(self, dest: KeyT, *sources: KeyT) -> ResponseT: + def pfmerge(self, dest: KeyT, *sources: KeyT) -> ResponseT[OKT]: """ Merge N different HyperLogLogs into a single one. @@ -4940,7 +4924,7 @@ class HashCommands(CommandsProtocol): see: https://redis.io/topics/data-types-intro#redis-hashes """ - def hdel(self, name: str, *keys: str) -> Union[Awaitable[int], int]: + def hdel(self, name: str, *keys: str) -> ResponseT[IntegerResponseT]: """ Delete ``keys`` from hash ``name`` @@ -4948,7 +4932,7 @@ def hdel(self, name: str, *keys: str) -> Union[Awaitable[int], int]: """ return self.execute_command("HDEL", name, *keys) - def hexists(self, name: str, key: str) -> Union[Awaitable[bool], bool]: + def hexists(self, name: str, key: str) -> ResponseT[IntegerResponseT]: """ Returns a boolean indicating if ``key`` exists within hash ``name`` @@ -4956,9 +4940,7 @@ def hexists(self, name: str, key: str) -> Union[Awaitable[bool], bool]: """ return self.execute_command("HEXISTS", name, key, keys=[name]) - def hget( - self, name: str, key: str - ) -> Union[Awaitable[Optional[str]], Optional[str]]: + def hget(self, name: str, key: str) -> ResponseT[Union[BulkStringResponseT, None]]: """ Return the value of ``key`` within the hash ``name`` @@ -4966,7 +4948,7 @@ def hget( """ return self.execute_command("HGET", name, key, keys=[name]) - def hgetall(self, name: str) -> Union[Awaitable[dict], dict]: + def hgetall(self, name: str) -> ResponseT[ArrayResponseT]: """ Return a Python dict of the hash's name/value pairs @@ -4976,7 +4958,7 @@ def hgetall(self, name: str) -> Union[Awaitable[dict], dict]: def hincrby( self, name: str, key: str, amount: int = 1 - ) -> Union[Awaitable[int], int]: + ) -> ResponseT[IntegerResponseT]: """ Increment the value of ``key`` in hash ``name`` by ``amount`` @@ -4986,7 +4968,7 @@ def hincrby( def hincrbyfloat( self, name: str, key: str, amount: float = 1.0 - ) -> Union[Awaitable[float], float]: + ) -> ResponseT[BulkStringResponseT]: """ Increment the value of ``key`` in hash ``name`` by floating ``amount`` @@ -4994,7 +4976,7 @@ def hincrbyfloat( """ return self.execute_command("HINCRBYFLOAT", name, key, amount) - def hkeys(self, name: str) -> Union[Awaitable[List], List]: + def hkeys(self, name: str) -> ResponseT[ArrayResponseT]: """ Return the list of keys within hash ``name`` @@ -5002,7 +4984,7 @@ def hkeys(self, name: str) -> Union[Awaitable[List], List]: """ return self.execute_command("HKEYS", name, keys=[name]) - def hlen(self, name: str) -> Union[Awaitable[int], int]: + def hlen(self, name: str) -> ResponseT[IntegerResponseT]: """ Return the number of elements in hash ``name`` @@ -5016,8 +4998,8 @@ def hset( key: Optional[str] = None, value: Optional[str] = None, mapping: Optional[dict] = None, - items: Optional[list] = None, - ) -> Union[Awaitable[int], int]: + items: Optional[List] = None, + ) -> ResponseT[IntegerResponseT]: """ Set ``key`` to ``value`` within hash ``name``, ``mapping`` accepts a dict of key/value pairs that will be @@ -5038,10 +5020,9 @@ def hset( if mapping: for pair in mapping.items(): pieces.extend(pair) - return self.execute_command("HSET", name, *pieces) - def hsetnx(self, name: str, key: str, value: str) -> Union[Awaitable[bool], bool]: + def hsetnx(self, name: str, key: str, value: str) -> ResponseT[IntegerResponseT]: """ Set ``key`` to ``value`` within hash ``name`` if ``key`` does not exist. Returns 1 if HSETNX created a field, otherwise 0. @@ -5050,7 +5031,7 @@ def hsetnx(self, name: str, key: str, value: str) -> Union[Awaitable[bool], bool """ return self.execute_command("HSETNX", name, key, value) - def hmset(self, name: str, mapping: dict) -> Union[Awaitable[str], str]: + def hmset(self, name: str, mapping: dict) -> ResponseT[OKT]: """ Set key to value within hash ``name`` for each corresponding key and value from the ``mapping`` dict. @@ -5070,7 +5051,7 @@ def hmset(self, name: str, mapping: dict) -> Union[Awaitable[str], str]: items.extend(pair) return self.execute_command("HMSET", name, *items) - def hmget(self, name: str, keys: List, *args: List) -> Union[Awaitable[List], List]: + def hmget(self, name: str, keys: List, *args: List) -> ResponseT[ArrayResponseT]: """ Returns a list of values ordered identically to ``keys`` @@ -5079,7 +5060,7 @@ def hmget(self, name: str, keys: List, *args: List) -> Union[Awaitable[List], Li args = list_or_args(keys, args) return self.execute_command("HMGET", name, *args, keys=[name]) - def hvals(self, name: str) -> Union[Awaitable[List], List]: + def hvals(self, name: str) -> ResponseT[ArrayResponseT]: """ Return the list of values within hash ``name`` @@ -5087,7 +5068,7 @@ def hvals(self, name: str) -> Union[Awaitable[List], List]: """ return self.execute_command("HVALS", name, keys=[name]) - def hstrlen(self, name: str, key: str) -> Union[Awaitable[int], int]: + def hstrlen(self, name: str, key: str) -> ResponseT[IntegerResponseT]: """ Return the number of bytes stored in the value of ``key`` within hash ``name`` @@ -5105,7 +5086,7 @@ def hexpire( xx: bool = False, gt: bool = False, lt: bool = False, - ) -> ResponseT: + ) -> ResponseT[Union[ArrayResponseT, IntegerResponseT]]: """ Sets or updates the expiration time for fields within a hash key, using relative time in seconds. @@ -5138,10 +5119,8 @@ def hexpire( conditions = [nx, xx, gt, lt] if sum(conditions) > 1: raise ValueError("Only one of 'nx', 'xx', 'gt', 'lt' can be specified.") - if isinstance(seconds, datetime.timedelta): seconds = int(seconds.total_seconds()) - options = [] if nx: options.append("NX") @@ -5151,7 +5130,6 @@ def hexpire( options.append("GT") if lt: options.append("LT") - return self.execute_command( "HEXPIRE", name, seconds, *options, "FIELDS", len(fields), *fields ) @@ -5165,7 +5143,7 @@ def hpexpire( xx: bool = False, gt: bool = False, lt: bool = False, - ) -> ResponseT: + ) -> ResponseT[Union[ArrayResponseT, IntegerResponseT]]: """ Sets or updates the expiration time for fields within a hash key, using relative time in milliseconds. @@ -5198,10 +5176,8 @@ def hpexpire( conditions = [nx, xx, gt, lt] if sum(conditions) > 1: raise ValueError("Only one of 'nx', 'xx', 'gt', 'lt' can be specified.") - if isinstance(milliseconds, datetime.timedelta): milliseconds = int(milliseconds.total_seconds() * 1000) - options = [] if nx: options.append("NX") @@ -5211,7 +5187,6 @@ def hpexpire( options.append("GT") if lt: options.append("LT") - return self.execute_command( "HPEXPIRE", name, milliseconds, *options, "FIELDS", len(fields), *fields ) @@ -5225,7 +5200,7 @@ def hexpireat( xx: bool = False, gt: bool = False, lt: bool = False, - ) -> ResponseT: + ) -> ResponseT[Union[ArrayResponseT, IntegerResponseT]]: """ Sets or updates the expiration time for fields within a hash key, using an absolute Unix timestamp in seconds. @@ -5258,10 +5233,8 @@ def hexpireat( conditions = [nx, xx, gt, lt] if sum(conditions) > 1: raise ValueError("Only one of 'nx', 'xx', 'gt', 'lt' can be specified.") - if isinstance(unix_time_seconds, datetime.datetime): unix_time_seconds = int(unix_time_seconds.timestamp()) - options = [] if nx: options.append("NX") @@ -5271,7 +5244,6 @@ def hexpireat( options.append("GT") if lt: options.append("LT") - return self.execute_command( "HEXPIREAT", name, @@ -5291,7 +5263,7 @@ def hpexpireat( xx: bool = False, gt: bool = False, lt: bool = False, - ) -> ResponseT: + ) -> ResponseT[Union[ArrayResponseT, IntegerResponseT]]: """ Sets or updates the expiration time for fields within a hash key, using an absolute Unix timestamp in milliseconds. @@ -5324,10 +5296,8 @@ def hpexpireat( conditions = [nx, xx, gt, lt] if sum(conditions) > 1: raise ValueError("Only one of 'nx', 'xx', 'gt', 'lt' can be specified.") - if isinstance(unix_time_milliseconds, datetime.datetime): unix_time_milliseconds = int(unix_time_milliseconds.timestamp() * 1000) - options = [] if nx: options.append("NX") @@ -5337,7 +5307,6 @@ def hpexpireat( options.append("GT") if lt: options.append("LT") - return self.execute_command( "HPEXPIREAT", name, @@ -5348,7 +5317,9 @@ def hpexpireat( *fields, ) - def hpersist(self, name: KeyT, *fields: str) -> ResponseT: + def hpersist( + self, name: KeyT, *fields: str + ) -> ResponseT[Union[ArrayResponseT, IntegerResponseT]]: """ Removes the expiration time for each specified field in a hash. @@ -5367,7 +5338,9 @@ def hpersist(self, name: KeyT, *fields: str) -> ResponseT: """ return self.execute_command("HPERSIST", name, "FIELDS", len(fields), *fields) - def hexpiretime(self, key: KeyT, *fields: str) -> ResponseT: + def hexpiretime( + self, key: KeyT, *fields: str + ) -> ResponseT[Union[ArrayResponseT, IntegerResponseT]]: """ Returns the expiration times of hash fields as Unix timestamps in seconds. @@ -5389,7 +5362,9 @@ def hexpiretime(self, key: KeyT, *fields: str) -> ResponseT: "HEXPIRETIME", key, "FIELDS", len(fields), *fields, keys=[key] ) - def hpexpiretime(self, key: KeyT, *fields: str) -> ResponseT: + def hpexpiretime( + self, key: KeyT, *fields: str + ) -> ResponseT[Union[ArrayResponseT, IntegerResponseT]]: """ Returns the expiration times of hash fields as Unix timestamps in milliseconds. @@ -5411,7 +5386,7 @@ def hpexpiretime(self, key: KeyT, *fields: str) -> ResponseT: "HPEXPIRETIME", key, "FIELDS", len(fields), *fields, keys=[key] ) - def httl(self, key: KeyT, *fields: str) -> ResponseT: + def httl(self, key: KeyT, *fields: str) -> ResponseT[ArrayResponseT]: """ Returns the TTL (Time To Live) in seconds for each specified field within a hash key. @@ -5433,7 +5408,7 @@ def httl(self, key: KeyT, *fields: str) -> ResponseT: "HTTL", key, "FIELDS", len(fields), *fields, keys=[key] ) - def hpttl(self, key: KeyT, *fields: str) -> ResponseT: + def hpttl(self, key: KeyT, *fields: str) -> ResponseT[ArrayResponseT]: """ Returns the TTL (Time To Live) in milliseconds for each specified field within a hash key. @@ -5563,7 +5538,9 @@ class PubSubCommands(CommandsProtocol): see https://redis.io/topics/pubsub """ - def publish(self, channel: ChannelT, message: EncodableT, **kwargs) -> ResponseT: + def publish( + self, channel: ChannelT, message: EncodableT, **kwargs + ) -> ResponseT[IntegerResponseT]: """ Publish ``message`` on ``channel``. Returns the number of subscribers the message was delivered to. @@ -5572,7 +5549,9 @@ def publish(self, channel: ChannelT, message: EncodableT, **kwargs) -> ResponseT """ return self.execute_command("PUBLISH", channel, message, **kwargs) - def spublish(self, shard_channel: ChannelT, message: EncodableT) -> ResponseT: + def spublish( + self, shard_channel: ChannelT, message: EncodableT + ) -> ResponseT[IntegerResponseT]: """ Posts a message to the given shard channel. Returns the number of clients that received the message @@ -5581,7 +5560,9 @@ def spublish(self, shard_channel: ChannelT, message: EncodableT) -> ResponseT: """ return self.execute_command("SPUBLISH", shard_channel, message) - def pubsub_channels(self, pattern: PatternT = "*", **kwargs) -> ResponseT: + def pubsub_channels( + self, pattern: PatternT = "*", **kwargs + ) -> ResponseT[ArrayResponseT]: """ Return a list of channels that have at least one subscriber @@ -5589,7 +5570,9 @@ def pubsub_channels(self, pattern: PatternT = "*", **kwargs) -> ResponseT: """ return self.execute_command("PUBSUB CHANNELS", pattern, **kwargs) - def pubsub_shardchannels(self, pattern: PatternT = "*", **kwargs) -> ResponseT: + def pubsub_shardchannels( + self, pattern: PatternT = "*", **kwargs + ) -> ResponseT[ArrayResponseT]: """ Return a list of shard_channels that have at least one subscriber @@ -5597,7 +5580,7 @@ def pubsub_shardchannels(self, pattern: PatternT = "*", **kwargs) -> ResponseT: """ return self.execute_command("PUBSUB SHARDCHANNELS", pattern, **kwargs) - def pubsub_numpat(self, **kwargs) -> ResponseT: + def pubsub_numpat(self, **kwargs) -> ResponseT[IntegerResponseT]: """ Returns the number of subscriptions to patterns @@ -5605,7 +5588,7 @@ def pubsub_numpat(self, **kwargs) -> ResponseT: """ return self.execute_command("PUBSUB NUMPAT", **kwargs) - def pubsub_numsub(self, *args: ChannelT, **kwargs) -> ResponseT: + def pubsub_numsub(self, *args: ChannelT, **kwargs) -> ResponseT[ArrayResponseT]: """ Return a list of (channel, number of subscribers) tuples for each channel given in ``*args`` @@ -5614,7 +5597,9 @@ def pubsub_numsub(self, *args: ChannelT, **kwargs) -> ResponseT: """ return self.execute_command("PUBSUB NUMSUB", *args, **kwargs) - def pubsub_shardnumsub(self, *args: ChannelT, **kwargs) -> ResponseT: + def pubsub_shardnumsub( + self, *args: ChannelT, **kwargs + ) -> ResponseT[ArrayResponseT]: """ Return a list of (shard_channel, number of subscribers) tuples for each channel given in ``*args`` @@ -5635,12 +5620,10 @@ class ScriptCommands(CommandsProtocol): def _eval( self, command: str, script: str, numkeys: int, *keys_and_args: str - ) -> Union[Awaitable[str], str]: + ) -> ResponseT[Any]: return self.execute_command(command, script, numkeys, *keys_and_args) - def eval( - self, script: str, numkeys: int, *keys_and_args: str - ) -> Union[Awaitable[str], str]: + def eval(self, script: str, numkeys: int, *keys_and_args: str) -> ResponseT[str]: """ Execute the Lua ``script``, specifying the ``numkeys`` the script will touch and the key names and argument values in ``keys_and_args``. @@ -5653,9 +5636,7 @@ def eval( """ return self._eval("EVAL", script, numkeys, *keys_and_args) - def eval_ro( - self, script: str, numkeys: int, *keys_and_args: str - ) -> Union[Awaitable[str], str]: + def eval_ro(self, script: str, numkeys: int, *keys_and_args: str) -> ResponseT[str]: """ The read-only variant of the EVAL command @@ -5668,13 +5649,11 @@ def eval_ro( return self._eval("EVAL_RO", script, numkeys, *keys_and_args) def _evalsha( - self, command: str, sha: str, numkeys: int, *keys_and_args: list - ) -> Union[Awaitable[str], str]: + self, command: str, sha: str, numkeys: int, *keys_and_args: List + ) -> ResponseT[Any]: return self.execute_command(command, sha, numkeys, *keys_and_args) - def evalsha( - self, sha: str, numkeys: int, *keys_and_args: str - ) -> Union[Awaitable[str], str]: + def evalsha(self, sha: str, numkeys: int, *keys_and_args: str) -> ResponseT[str]: """ Use the ``sha`` to execute a Lua script already registered via EVAL or SCRIPT LOAD. Specify the ``numkeys`` the script will touch and the @@ -5688,9 +5667,7 @@ def evalsha( """ return self._evalsha("EVALSHA", sha, numkeys, *keys_and_args) - def evalsha_ro( - self, sha: str, numkeys: int, *keys_and_args: str - ) -> Union[Awaitable[str], str]: + def evalsha_ro(self, sha: str, numkeys: int, *keys_and_args: str) -> ResponseT[str]: """ The read-only variant of the EVALSHA command @@ -5703,7 +5680,7 @@ def evalsha_ro( """ return self._evalsha("EVALSHA_RO", sha, numkeys, *keys_and_args) - def script_exists(self, *args: str) -> ResponseT: + def script_exists(self, *args: str) -> ResponseT[ArrayResponseT]: """ Check if a script exists in the script cache by specifying the SHAs of each script as ``args``. Returns a list of boolean values indicating if @@ -5720,7 +5697,7 @@ def script_debug(self, *args) -> None: def script_flush( self, sync_type: Union[Literal["SYNC"], Literal["ASYNC"]] = None - ) -> ResponseT: + ) -> ResponseT[OKT]: """Flush all scripts from the script cache. ``sync_type`` is by default SYNC (synchronous) but it can also be @@ -5742,7 +5719,7 @@ def script_flush( pieces = [sync_type] return self.execute_command("SCRIPT FLUSH", *pieces) - def script_kill(self) -> ResponseT: + def script_kill(self) -> ResponseT[OKT]: """ Kill the currently executing Lua script @@ -5750,7 +5727,7 @@ def script_kill(self) -> ResponseT: """ return self.execute_command("SCRIPT KILL") - def script_load(self, script: ScriptTextT) -> ResponseT: + def script_load(self, script: ScriptTextT) -> ResponseT[BulkStringResponseT]: """ Load a Lua ``script`` into the script cache. Returns the SHA. @@ -5769,6 +5746,7 @@ def register_script(self: "Redis", script: ScriptTextT) -> Script: class AsyncScriptCommands(ScriptCommands): + async def script_debug(self, *args) -> None: return super().script_debug() @@ -5795,7 +5773,7 @@ def geoadd( nx: bool = False, xx: bool = False, ch: bool = False, - ) -> ResponseT: + ) -> ResponseT[IntegerResponseT]: """ Add the specified geospatial items to the specified key identified by the ``name`` argument. The Geospatial items are given as ordered @@ -5832,7 +5810,7 @@ def geoadd( def geodist( self, name: KeyT, place1: FieldT, place2: FieldT, unit: Union[str, None] = None - ) -> ResponseT: + ) -> ResponseT[Union[BulkStringResponseT, None]]: """ Return the distance between ``place1`` and ``place2`` members of the ``name`` key. @@ -5841,14 +5819,14 @@ def geodist( For more information see https://redis.io/commands/geodist """ - pieces: list[EncodableT] = [name, place1, place2] + pieces: List[EncodableT] = [name, place1, place2] if unit and unit not in ("m", "km", "mi", "ft"): raise DataError("GEODIST invalid unit") elif unit: pieces.append(unit) return self.execute_command("GEODIST", *pieces, keys=[name]) - def geohash(self, name: KeyT, *values: FieldT) -> ResponseT: + def geohash(self, name: KeyT, *values: FieldT) -> ResponseT[ArrayResponseT]: """ Return the geo hash string for each item of ``values`` members of the specified key identified by the ``name`` argument. @@ -5857,7 +5835,7 @@ def geohash(self, name: KeyT, *values: FieldT) -> ResponseT: """ return self.execute_command("GEOHASH", name, *values, keys=[name]) - def geopos(self, name: KeyT, *values: FieldT) -> ResponseT: + def geopos(self, name: KeyT, *values: FieldT) -> ResponseT[ArrayResponseT]: """ Return the positions of each item of ``values`` as members of the specified key identified by the ``name`` argument. Each position @@ -5971,7 +5949,7 @@ def georadiusbymember( def _georadiusgeneric( self, command: str, *args: EncodableT, **kwargs: Union[EncodableT, None] - ) -> ResponseT: + ) -> ResponseT[Any]: pieces = list(args) if kwargs["unit"] and kwargs["unit"] not in ("m", "km", "mi", "ft"): raise DataError("GEORADIUS invalid unit") @@ -5979,10 +5957,8 @@ def _georadiusgeneric( pieces.append(kwargs["unit"]) else: pieces.append("m") - if kwargs["any"] and kwargs["count"] is None: raise DataError("``any`` can't be provided without ``count``") - for arg_name, byte_repr in ( ("withdist", "WITHDIST"), ("withcoord", "WITHCOORD"), @@ -5990,12 +5966,10 @@ def _georadiusgeneric( ): if kwargs[arg_name]: pieces.append(byte_repr) - if kwargs["count"] is not None: pieces.extend(["COUNT", kwargs["count"]]) if kwargs["any"]: pieces.append("ANY") - if kwargs["sort"]: if kwargs["sort"] == "ASC": pieces.append("ASC") @@ -6003,16 +5977,12 @@ def _georadiusgeneric( pieces.append("DESC") else: raise DataError("GEORADIUS invalid sort") - if kwargs["store"] and kwargs["store_dist"]: raise DataError("GEORADIUS store and store_dist cant be set together") - if kwargs["store"]: pieces.extend([b"STORE", kwargs["store"]]) - if kwargs["store_dist"]: pieces.extend([b"STOREDIST", kwargs["store_dist"]]) - return self.execute_command(command, *pieces, **kwargs) def geosearch( @@ -6075,7 +6045,6 @@ def geosearch( For more information see https://redis.io/commands/geosearch """ - return self._geosearchgeneric( "GEOSEARCH", name, @@ -6145,7 +6114,7 @@ def geosearchstore( def _geosearchgeneric( self, command: str, *args: EncodableT, **kwargs: Union[EncodableT, None] - ) -> ResponseT: + ) -> ResponseT[Any]: pieces = list(args) # FROMMEMBER or FROMLONLAT @@ -6204,9 +6173,7 @@ def _geosearchgeneric( ): if kwargs[arg_name]: pieces.append(byte_repr) - kwargs["keys"] = [args[0] if command == "GEOSEARCH" else args[1]] - return self.execute_command(command, *pieces, **kwargs) @@ -6219,7 +6186,7 @@ class ModuleCommands(CommandsProtocol): see: https://redis.io/topics/modules-intro """ - def module_load(self, path, *args) -> ResponseT: + def module_load(self, path, *args) -> ResponseT[OKT]: """ Loads the module from ``path``. Passes all ``*args`` to the module, during loading. @@ -6234,7 +6201,7 @@ def module_loadex( path: str, options: Optional[List[str]] = None, args: Optional[List[str]] = None, - ) -> ResponseT: + ) -> ResponseT[OKT]: """ Loads a module from a dynamic library at runtime with configuration directives. @@ -6247,10 +6214,9 @@ def module_loadex( if args is not None: pieces.append("ARGS") pieces.extend(args) - return self.execute_command("MODULE LOADEX", path, *pieces) - def module_unload(self, name) -> ResponseT: + def module_unload(self, name) -> ResponseT[OKT]: """ Unloads the module ``name``. Raises ``ModuleError`` if ``name`` is not in loaded modules. @@ -6259,7 +6225,7 @@ def module_unload(self, name) -> ResponseT: """ return self.execute_command("MODULE UNLOAD", name) - def module_list(self) -> ResponseT: + def module_list(self) -> ResponseT[ArrayResponseT]: """ Returns a list of dictionaries containing the name and version of all loaded modules. @@ -6273,13 +6239,13 @@ def command_info(self) -> None: "COMMAND INFO is intentionally not implemented in the client." ) - def command_count(self) -> ResponseT: + def command_count(self) -> ResponseT[IntegerResponseT]: return self.execute_command("COMMAND COUNT") - def command_getkeys(self, *args) -> ResponseT: + def command_getkeys(self, *args) -> ResponseT[ArrayResponseT]: return self.execute_command("COMMAND GETKEYS", *args) - def command(self) -> ResponseT: + def command(self) -> ResponseT[ArrayResponseT]: return self.execute_command("COMMAND") @@ -6340,6 +6306,7 @@ def get_encoder(self): class AsyncModuleCommands(ModuleCommands): + async def command_info(self) -> None: return super().command_info() @@ -6349,10 +6316,10 @@ class ClusterCommands(CommandsProtocol): Class for Redis Cluster commands """ - def cluster(self, cluster_arg, *args, **kwargs) -> ResponseT: + def cluster(self, cluster_arg, *args, **kwargs) -> ResponseT[Any]: return self.execute_command(f"CLUSTER {cluster_arg.upper()}", *args, **kwargs) - def readwrite(self, **kwargs) -> ResponseT: + def readwrite(self, **kwargs) -> ResponseT[OKT]: """ Disables read queries for a connection to a Redis Cluster slave node. @@ -6360,7 +6327,7 @@ def readwrite(self, **kwargs) -> ResponseT: """ return self.execute_command("READWRITE", **kwargs) - def readonly(self, **kwargs) -> ResponseT: + def readonly(self, **kwargs) -> ResponseT[OKT]: """ Enables read queries for a connection to a Redis Cluster replica node. @@ -6372,14 +6339,14 @@ def readonly(self, **kwargs) -> ResponseT: AsyncClusterCommands = ClusterCommands -class FunctionCommands: +class FunctionCommands(CommandsProtocol): """ Redis Function commands """ def function_load( self, code: str, replace: Optional[bool] = False - ) -> Union[Awaitable[str], str]: + ) -> ResponseT[BulkStringResponseT]: """ Load a library to Redis. :param code: the source code (must start with @@ -6394,7 +6361,7 @@ def function_load( pieces.append(code) return self.execute_command("FUNCTION LOAD", *pieces) - def function_delete(self, library: str) -> Union[Awaitable[str], str]: + def function_delete(self, library: str) -> ResponseT[OKT]: """ Delete the library called ``library`` and all its functions. @@ -6402,7 +6369,7 @@ def function_delete(self, library: str) -> Union[Awaitable[str], str]: """ return self.execute_command("FUNCTION DELETE", library) - def function_flush(self, mode: str = "SYNC") -> Union[Awaitable[str], str]: + def function_flush(self, mode: str = "SYNC") -> ResponseT[OKT]: """ Deletes all the libraries. @@ -6412,7 +6379,7 @@ def function_flush(self, mode: str = "SYNC") -> Union[Awaitable[str], str]: def function_list( self, library: Optional[str] = "*", withcode: Optional[bool] = False - ) -> Union[Awaitable[List], List]: + ) -> ResponseT[ArrayResponseT]: """ Return information about the functions and libraries. :param library: pecify a pattern for matching library names @@ -6426,12 +6393,12 @@ def function_list( def _fcall( self, command: str, function, numkeys: int, *keys_and_args: Optional[List] - ) -> Union[Awaitable[str], str]: + ) -> ResponseT[Any]: return self.execute_command(command, function, numkeys, *keys_and_args) def fcall( self, function, numkeys: int, *keys_and_args: Optional[List] - ) -> Union[Awaitable[str], str]: + ) -> ResponseT[str]: """ Invoke a function. @@ -6441,7 +6408,7 @@ def fcall( def fcall_ro( self, function, numkeys: int, *keys_and_args: Optional[List] - ) -> Union[Awaitable[str], str]: + ) -> ResponseT[str]: """ This is a read-only variant of the FCALL command that cannot execute commands that modify data. @@ -6450,7 +6417,7 @@ def fcall_ro( """ return self._fcall("FCALL_RO", function, numkeys, *keys_and_args) - def function_dump(self) -> Union[Awaitable[str], str]: + def function_dump(self) -> ResponseT[BulkStringResponseT]: """ Return the serialized payload of loaded libraries. @@ -6460,12 +6427,11 @@ def function_dump(self) -> Union[Awaitable[str], str]: options = {} options[NEVER_DECODE] = [] - return self.execute_command("FUNCTION DUMP", **options) def function_restore( self, payload: str, policy: Optional[str] = "APPEND" - ) -> Union[Awaitable[str], str]: + ) -> ResponseT[OKT]: """ Restore libraries from the serialized ``payload``. You can use the optional policy argument to provide a policy @@ -6475,7 +6441,7 @@ def function_restore( """ return self.execute_command("FUNCTION RESTORE", payload, policy) - def function_kill(self) -> Union[Awaitable[str], str]: + def function_kill(self) -> ResponseT[OKT]: """ Kill a function that is currently executing. @@ -6483,7 +6449,7 @@ def function_kill(self) -> Union[Awaitable[str], str]: """ return self.execute_command("FUNCTION KILL") - def function_stats(self) -> Union[Awaitable[List], List]: + def function_stats(self) -> ResponseT[ArrayResponseT]: """ Return information about the function that's currently running and information about the available execution engines. @@ -6496,7 +6462,8 @@ def function_stats(self) -> Union[Awaitable[List], List]: AsyncFunctionCommands = FunctionCommands -class GearsCommands: +class GearsCommands(CommandsProtocol): + def tfunction_load( self, lib_code: str, replace: bool = False, config: Union[str, None] = None ) -> ResponseT: @@ -6556,14 +6523,13 @@ def tfunction_list( if lib_name is not None: pieces.append("LIBRARY") pieces.append(lib_name) - return self.execute_command("TFUNCTION LIST", *pieces) def _tfcall( self, lib_name: str, func_name: str, - keys: KeysT = None, + keys: Optional[KeysT] = None, _async: bool = False, *args: List, ) -> ResponseT: @@ -6580,11 +6546,7 @@ def _tfcall( return self.execute_command("TFCALL", *pieces) def tfcall( - self, - lib_name: str, - func_name: str, - keys: KeysT = None, - *args: List, + self, lib_name: str, func_name: str, keys: Optional[KeysT] = None, *args: List ) -> ResponseT: """ Invoke a function. @@ -6599,11 +6561,7 @@ def tfcall( return self._tfcall(lib_name, func_name, keys, False, *args) def tfcall_async( - self, - lib_name: str, - func_name: str, - keys: KeysT = None, - *args: List, + self, lib_name: str, func_name: str, keys: Optional[KeysT] = None, *args: List ) -> ResponseT: """ Invoke an async function (coroutine). diff --git a/redis/typing.py b/redis/typing.py index b4d442c444..13305051e3 100644 --- a/redis/typing.py +++ b/redis/typing.py @@ -1,11 +1,11 @@ -# from __future__ import annotations - from datetime import datetime, timedelta from typing import ( TYPE_CHECKING, Any, Awaitable, Iterable, + List, + Literal, Mapping, Protocol, Type, @@ -32,7 +32,14 @@ PatternT = _StringLikeT # Patterns matched against keys, fields etc FieldT = EncodableT # Fields within hash tables, streams and geo commands KeysT = Union[KeyT, Iterable[KeyT]] -ResponseT = Union[Awaitable[Any], Any] +OldResponseT = Union[Awaitable[Any], Any] # Deprecated +AnyResponseT = TypeVar("AnyResponseT", bound=Any) +ResponseT = Union[AnyResponseT, Awaitable[AnyResponseT]] +OKT = Literal[True] +ArrayResponseT = List +IntegerResponseT = int +NullResponseT = type(None) +BulkStringResponseT = str ChannelT = _StringLikeT GroupT = _StringLikeT # Consumer group ConsumerT = _StringLikeT # Consumer name @@ -54,8 +61,10 @@ class CommandsProtocol(Protocol): connection_pool: Union["AsyncConnectionPool", "ConnectionPool"] - def execute_command(self, *args, **options) -> ResponseT: ... + def execute_command(self, *args, **options) -> ResponseT[Any]: ... class ClusterCommandsProtocol(CommandsProtocol): encoder: "Encoder" + + def execute_command(self, *args, **options) -> ResponseT[Any]: ...