From 462e7a6d3c004e65364ee4acc22a8145707b71fd Mon Sep 17 00:00:00 2001 From: Travis Hunter Date: Fri, 11 Apr 2025 12:27:47 -0400 Subject: [PATCH] Enable specifying load-balancing strategy on per-command basis --- redis/asyncio/cluster.py | 3 +- redis/cluster.py | 50 +----------------------------- redis/commands/cluster.py | 17 +++++----- redis/load_balancer.py | 49 +++++++++++++++++++++++++++++ tests/test_asyncio/test_cluster.py | 2 +- tests/test_cluster.py | 2 +- 6 files changed, 62 insertions(+), 61 deletions(-) create mode 100644 redis/load_balancer.py diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 28fcd3aa23..abb9d1e319 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -37,8 +37,6 @@ REPLICA, SLOT_ID, AbstractRedisCluster, - LoadBalancer, - LoadBalancingStrategy, block_pipeline_command, get_node_name, parse_cluster_slots, @@ -63,6 +61,7 @@ TimeoutError, TryAgainError, ) +from redis.load_balancer import LoadBalancer, LoadBalancingStrategy from redis.typing import AnyKeyT, EncodableT, KeyT from redis.utils import ( SSL_AVAILABLE, diff --git a/redis/cluster.py b/redis/cluster.py index 4ec03ac98f..55a3bf7758 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -4,7 +4,6 @@ import threading import time from collections import OrderedDict -from enum import Enum from typing import Any, Callable, Dict, List, Optional, Tuple, Union from redis._parsers import CommandsParser, Encoder @@ -37,6 +36,7 @@ TimeoutError, TryAgainError, ) +from redis.load_balancer import LoadBalancer, LoadBalancingStrategy from redis.lock import Lock from redis.retry import Retry from redis.utils import ( @@ -1328,54 +1328,6 @@ def __del__(self): self.redis_connection.close() -class LoadBalancingStrategy(Enum): - ROUND_ROBIN = "round_robin" - ROUND_ROBIN_REPLICAS = "round_robin_replicas" - RANDOM_REPLICA = "random_replica" - - -class LoadBalancer: - """ - Round-Robin Load Balancing - """ - - def __init__(self, start_index: int = 0) -> None: - self.primary_to_idx = {} - self.start_index = start_index - - def get_server_index( - self, - primary: str, - list_size: int, - load_balancing_strategy: LoadBalancingStrategy = LoadBalancingStrategy.ROUND_ROBIN, - ) -> int: - if load_balancing_strategy == LoadBalancingStrategy.RANDOM_REPLICA: - return self._get_random_replica_index(list_size) - else: - return self._get_round_robin_index( - primary, - list_size, - load_balancing_strategy == LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, - ) - - def reset(self) -> None: - self.primary_to_idx.clear() - - def _get_random_replica_index(self, list_size: int) -> int: - return random.randint(1, list_size - 1) - - def _get_round_robin_index( - self, primary: str, list_size: int, replicas_only: bool - ) -> int: - server_index = self.primary_to_idx.setdefault(primary, self.start_index) - if replicas_only and server_index == 0: - # skip the primary node index - server_index = 1 - # Update the index for the next round - self.primary_to_idx[primary] = (server_index + 1) % list_size - return server_index - - class NodesManager: def __init__( self, diff --git a/redis/commands/cluster.py b/redis/commands/cluster.py index 13f2035265..ddc993fae3 100644 --- a/redis/commands/cluster.py +++ b/redis/commands/cluster.py @@ -16,6 +16,7 @@ from redis.crc import key_slot from redis.exceptions import RedisClusterException, RedisError +from redis.load_balancer import LoadBalancingStrategy from redis.typing import ( AnyKeyT, ClusterCommandsProtocol, @@ -124,7 +125,7 @@ def _partition_pairs_by_slot( return slots_to_pairs def _execute_pipeline_by_slot( - self, command: str, slots_to_args: Mapping[int, Iterable[EncodableT]] + self, command: str, slots_to_args: Mapping[int, Iterable[EncodableT]], *, load_balancing_strategy: Optional["LoadBalancingStrategy"] = None ) -> List[Any]: read_from_replicas = self.read_from_replicas and command in READ_COMMANDS pipe = self.pipeline() @@ -133,7 +134,7 @@ def _execute_pipeline_by_slot( command, *slot_args, target_nodes=[ - self.nodes_manager.get_node_from_slot(slot, read_from_replicas) + self.nodes_manager.get_node_from_slot(slot, read_from_replicas, load_balancing_strategy) ], ) for slot, slot_args in slots_to_args.items() @@ -153,7 +154,7 @@ def _reorder_keys_by_command( } return [results[key] for key in keys] - def mget_nonatomic(self, keys: KeysT, *args: KeyT) -> List[Optional[Any]]: + def mget_nonatomic(self, keys: KeysT, *args: KeyT, load_balancing_strategy: Optional["LoadBalancingStrategy"] = None) -> List[Optional[Any]]: """ Splits the keys into different slots and then calls MGET for the keys of every slot. This operation will not be atomic @@ -171,7 +172,7 @@ def mget_nonatomic(self, keys: KeysT, *args: KeyT) -> List[Optional[Any]]: slots_to_keys = self._partition_keys_by_slot(keys) # Execute commands using a pipeline - res = self._execute_pipeline_by_slot("MGET", slots_to_keys) + res = self._execute_pipeline_by_slot("MGET", slots_to_keys, load_balancing_strategy=load_balancing_strategy) # Reorder keys in the order the user provided & return return self._reorder_keys_by_command(keys, slots_to_keys, res) @@ -265,7 +266,7 @@ class AsyncClusterMultiKeyCommands(ClusterMultiKeyCommands): A class containing commands that handle more than one key """ - async def mget_nonatomic(self, keys: KeysT, *args: KeyT) -> List[Optional[Any]]: + async def mget_nonatomic(self, keys: KeysT, *args: KeyT, load_balancing_strategy: Optional["LoadBalancingStrategy"] = None) -> List[Optional[Any]]: """ Splits the keys into different slots and then calls MGET for the keys of every slot. This operation will not be atomic @@ -283,7 +284,7 @@ async def mget_nonatomic(self, keys: KeysT, *args: KeyT) -> List[Optional[Any]]: slots_to_keys = self._partition_keys_by_slot(keys) # Execute commands using a pipeline - res = await self._execute_pipeline_by_slot("MGET", slots_to_keys) + res = await self._execute_pipeline_by_slot("MGET", slots_to_keys, load_balancing_strategy=load_balancing_strategy) # Reorder keys in the order the user provided & return return self._reorder_keys_by_command(keys, slots_to_keys, res) @@ -320,7 +321,7 @@ async def _split_command_across_slots(self, command: str, *keys: KeyT) -> int: return sum(await self._execute_pipeline_by_slot(command, slots_to_keys)) async def _execute_pipeline_by_slot( - self, command: str, slots_to_args: Mapping[int, Iterable[EncodableT]] + self, command: str, slots_to_args: Mapping[int, Iterable[EncodableT]], *, load_balancing_strategy: Optional["LoadBalancingStrategy"] = None ) -> List[Any]: if self._initialize: await self.initialize() @@ -331,7 +332,7 @@ async def _execute_pipeline_by_slot( command, *slot_args, target_nodes=[ - self.nodes_manager.get_node_from_slot(slot, read_from_replicas) + self.nodes_manager.get_node_from_slot(slot, read_from_replicas, load_balancing_strategy) ], ) for slot, slot_args in slots_to_args.items() diff --git a/redis/load_balancer.py b/redis/load_balancer.py new file mode 100644 index 0000000000..dd3be19e09 --- /dev/null +++ b/redis/load_balancer.py @@ -0,0 +1,49 @@ +from enum import Enum + + +class LoadBalancingStrategy(Enum): + ROUND_ROBIN = "round_robin" + ROUND_ROBIN_REPLICAS = "round_robin_replicas" + RANDOM_REPLICA = "random_replica" + + +class LoadBalancer: + """ + Round-Robin Load Balancing + """ + + def __init__(self, start_index: int = 0) -> None: + self.primary_to_idx = {} + self.start_index = start_index + + def get_server_index( + self, + primary: str, + list_size: int, + load_balancing_strategy: LoadBalancingStrategy = LoadBalancingStrategy.ROUND_ROBIN, + ) -> int: + if load_balancing_strategy == LoadBalancingStrategy.RANDOM_REPLICA: + return self._get_random_replica_index(list_size) + else: + return self._get_round_robin_index( + primary, + list_size, + load_balancing_strategy == LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, + ) + + def reset(self) -> None: + self.primary_to_idx.clear() + + def _get_random_replica_index(self, list_size: int) -> int: + return random.randint(1, list_size - 1) + + def _get_round_robin_index( + self, primary: str, list_size: int, replicas_only: bool + ) -> int: + server_index = self.primary_to_idx.setdefault(primary, self.start_index) + if replicas_only and server_index == 0: + # skip the primary node index + server_index = 1 + # Update the index for the next round + self.primary_to_idx[primary] = (server_index + 1) % list_size + return server_index \ No newline at end of file diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index a0429152ec..4de01896dc 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -18,7 +18,6 @@ PIPELINE_BLOCKED_COMMANDS, PRIMARY, REPLICA, - LoadBalancingStrategy, get_node_name, ) from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot @@ -34,6 +33,7 @@ RedisError, ResponseError, ) +from redis.load_balancer import LoadBalancingStrategy from redis.utils import str_if_bytes from tests.conftest import ( assert_resp_response, diff --git a/tests/test_cluster.py b/tests/test_cluster.py index d96342f87a..b4869d64ff 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -20,7 +20,6 @@ REDIS_CLUSTER_HASH_SLOTS, REPLICA, ClusterNode, - LoadBalancingStrategy, NodesManager, RedisCluster, get_node_name, @@ -39,6 +38,7 @@ ResponseError, TimeoutError, ) +from redis.load_balancer import LoadBalancingStrategy from redis.retry import Retry from redis.utils import str_if_bytes from tests.test_pubsub import wait_for_message