Skip to content

Commit 462e7a6

Browse files
Enable specifying load-balancing strategy on per-command basis
1 parent 04589d4 commit 462e7a6

File tree

6 files changed

+62
-61
lines changed

6 files changed

+62
-61
lines changed

redis/asyncio/cluster.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@
3737
REPLICA,
3838
SLOT_ID,
3939
AbstractRedisCluster,
40-
LoadBalancer,
41-
LoadBalancingStrategy,
4240
block_pipeline_command,
4341
get_node_name,
4442
parse_cluster_slots,
@@ -63,6 +61,7 @@
6361
TimeoutError,
6462
TryAgainError,
6563
)
64+
from redis.load_balancer import LoadBalancer, LoadBalancingStrategy
6665
from redis.typing import AnyKeyT, EncodableT, KeyT
6766
from redis.utils import (
6867
SSL_AVAILABLE,

redis/cluster.py

Lines changed: 1 addition & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import threading
55
import time
66
from collections import OrderedDict
7-
from enum import Enum
87
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
98

109
from redis._parsers import CommandsParser, Encoder
@@ -37,6 +36,7 @@
3736
TimeoutError,
3837
TryAgainError,
3938
)
39+
from redis.load_balancer import LoadBalancer, LoadBalancingStrategy
4040
from redis.lock import Lock
4141
from redis.retry import Retry
4242
from redis.utils import (
@@ -1328,54 +1328,6 @@ def __del__(self):
13281328
self.redis_connection.close()
13291329

13301330

1331-
class LoadBalancingStrategy(Enum):
1332-
ROUND_ROBIN = "round_robin"
1333-
ROUND_ROBIN_REPLICAS = "round_robin_replicas"
1334-
RANDOM_REPLICA = "random_replica"
1335-
1336-
1337-
class LoadBalancer:
1338-
"""
1339-
Round-Robin Load Balancing
1340-
"""
1341-
1342-
def __init__(self, start_index: int = 0) -> None:
1343-
self.primary_to_idx = {}
1344-
self.start_index = start_index
1345-
1346-
def get_server_index(
1347-
self,
1348-
primary: str,
1349-
list_size: int,
1350-
load_balancing_strategy: LoadBalancingStrategy = LoadBalancingStrategy.ROUND_ROBIN,
1351-
) -> int:
1352-
if load_balancing_strategy == LoadBalancingStrategy.RANDOM_REPLICA:
1353-
return self._get_random_replica_index(list_size)
1354-
else:
1355-
return self._get_round_robin_index(
1356-
primary,
1357-
list_size,
1358-
load_balancing_strategy == LoadBalancingStrategy.ROUND_ROBIN_REPLICAS,
1359-
)
1360-
1361-
def reset(self) -> None:
1362-
self.primary_to_idx.clear()
1363-
1364-
def _get_random_replica_index(self, list_size: int) -> int:
1365-
return random.randint(1, list_size - 1)
1366-
1367-
def _get_round_robin_index(
1368-
self, primary: str, list_size: int, replicas_only: bool
1369-
) -> int:
1370-
server_index = self.primary_to_idx.setdefault(primary, self.start_index)
1371-
if replicas_only and server_index == 0:
1372-
# skip the primary node index
1373-
server_index = 1
1374-
# Update the index for the next round
1375-
self.primary_to_idx[primary] = (server_index + 1) % list_size
1376-
return server_index
1377-
1378-
13791331
class NodesManager:
13801332
def __init__(
13811333
self,

redis/commands/cluster.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from redis.crc import key_slot
1818
from redis.exceptions import RedisClusterException, RedisError
19+
from redis.load_balancer import LoadBalancingStrategy
1920
from redis.typing import (
2021
AnyKeyT,
2122
ClusterCommandsProtocol,
@@ -124,7 +125,7 @@ def _partition_pairs_by_slot(
124125
return slots_to_pairs
125126

126127
def _execute_pipeline_by_slot(
127-
self, command: str, slots_to_args: Mapping[int, Iterable[EncodableT]]
128+
self, command: str, slots_to_args: Mapping[int, Iterable[EncodableT]], *, load_balancing_strategy: Optional["LoadBalancingStrategy"] = None
128129
) -> List[Any]:
129130
read_from_replicas = self.read_from_replicas and command in READ_COMMANDS
130131
pipe = self.pipeline()
@@ -133,7 +134,7 @@ def _execute_pipeline_by_slot(
133134
command,
134135
*slot_args,
135136
target_nodes=[
136-
self.nodes_manager.get_node_from_slot(slot, read_from_replicas)
137+
self.nodes_manager.get_node_from_slot(slot, read_from_replicas, load_balancing_strategy)
137138
],
138139
)
139140
for slot, slot_args in slots_to_args.items()
@@ -153,7 +154,7 @@ def _reorder_keys_by_command(
153154
}
154155
return [results[key] for key in keys]
155156

156-
def mget_nonatomic(self, keys: KeysT, *args: KeyT) -> List[Optional[Any]]:
157+
def mget_nonatomic(self, keys: KeysT, *args: KeyT, load_balancing_strategy: Optional["LoadBalancingStrategy"] = None) -> List[Optional[Any]]:
157158
"""
158159
Splits the keys into different slots and then calls MGET
159160
for the keys of every slot. This operation will not be atomic
@@ -171,7 +172,7 @@ def mget_nonatomic(self, keys: KeysT, *args: KeyT) -> List[Optional[Any]]:
171172
slots_to_keys = self._partition_keys_by_slot(keys)
172173

173174
# Execute commands using a pipeline
174-
res = self._execute_pipeline_by_slot("MGET", slots_to_keys)
175+
res = self._execute_pipeline_by_slot("MGET", slots_to_keys, load_balancing_strategy=load_balancing_strategy)
175176

176177
# Reorder keys in the order the user provided & return
177178
return self._reorder_keys_by_command(keys, slots_to_keys, res)
@@ -265,7 +266,7 @@ class AsyncClusterMultiKeyCommands(ClusterMultiKeyCommands):
265266
A class containing commands that handle more than one key
266267
"""
267268

268-
async def mget_nonatomic(self, keys: KeysT, *args: KeyT) -> List[Optional[Any]]:
269+
async def mget_nonatomic(self, keys: KeysT, *args: KeyT, load_balancing_strategy: Optional["LoadBalancingStrategy"] = None) -> List[Optional[Any]]:
269270
"""
270271
Splits the keys into different slots and then calls MGET
271272
for the keys of every slot. This operation will not be atomic
@@ -283,7 +284,7 @@ async def mget_nonatomic(self, keys: KeysT, *args: KeyT) -> List[Optional[Any]]:
283284
slots_to_keys = self._partition_keys_by_slot(keys)
284285

285286
# Execute commands using a pipeline
286-
res = await self._execute_pipeline_by_slot("MGET", slots_to_keys)
287+
res = await self._execute_pipeline_by_slot("MGET", slots_to_keys, load_balancing_strategy=load_balancing_strategy)
287288

288289
# Reorder keys in the order the user provided & return
289290
return self._reorder_keys_by_command(keys, slots_to_keys, res)
@@ -320,7 +321,7 @@ async def _split_command_across_slots(self, command: str, *keys: KeyT) -> int:
320321
return sum(await self._execute_pipeline_by_slot(command, slots_to_keys))
321322

322323
async def _execute_pipeline_by_slot(
323-
self, command: str, slots_to_args: Mapping[int, Iterable[EncodableT]]
324+
self, command: str, slots_to_args: Mapping[int, Iterable[EncodableT]], *, load_balancing_strategy: Optional["LoadBalancingStrategy"] = None
324325
) -> List[Any]:
325326
if self._initialize:
326327
await self.initialize()
@@ -331,7 +332,7 @@ async def _execute_pipeline_by_slot(
331332
command,
332333
*slot_args,
333334
target_nodes=[
334-
self.nodes_manager.get_node_from_slot(slot, read_from_replicas)
335+
self.nodes_manager.get_node_from_slot(slot, read_from_replicas, load_balancing_strategy)
335336
],
336337
)
337338
for slot, slot_args in slots_to_args.items()

redis/load_balancer.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from enum import Enum
2+
3+
4+
class LoadBalancingStrategy(Enum):
5+
ROUND_ROBIN = "round_robin"
6+
ROUND_ROBIN_REPLICAS = "round_robin_replicas"
7+
RANDOM_REPLICA = "random_replica"
8+
9+
10+
class LoadBalancer:
11+
"""
12+
Round-Robin Load Balancing
13+
"""
14+
15+
def __init__(self, start_index: int = 0) -> None:
16+
self.primary_to_idx = {}
17+
self.start_index = start_index
18+
19+
def get_server_index(
20+
self,
21+
primary: str,
22+
list_size: int,
23+
load_balancing_strategy: LoadBalancingStrategy = LoadBalancingStrategy.ROUND_ROBIN,
24+
) -> int:
25+
if load_balancing_strategy == LoadBalancingStrategy.RANDOM_REPLICA:
26+
return self._get_random_replica_index(list_size)
27+
else:
28+
return self._get_round_robin_index(
29+
primary,
30+
list_size,
31+
load_balancing_strategy == LoadBalancingStrategy.ROUND_ROBIN_REPLICAS,
32+
)
33+
34+
def reset(self) -> None:
35+
self.primary_to_idx.clear()
36+
37+
def _get_random_replica_index(self, list_size: int) -> int:
38+
return random.randint(1, list_size - 1)
39+
40+
def _get_round_robin_index(
41+
self, primary: str, list_size: int, replicas_only: bool
42+
) -> int:
43+
server_index = self.primary_to_idx.setdefault(primary, self.start_index)
44+
if replicas_only and server_index == 0:
45+
# skip the primary node index
46+
server_index = 1
47+
# Update the index for the next round
48+
self.primary_to_idx[primary] = (server_index + 1) % list_size
49+
return server_index

tests/test_asyncio/test_cluster.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
PIPELINE_BLOCKED_COMMANDS,
1919
PRIMARY,
2020
REPLICA,
21-
LoadBalancingStrategy,
2221
get_node_name,
2322
)
2423
from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot
@@ -34,6 +33,7 @@
3433
RedisError,
3534
ResponseError,
3635
)
36+
from redis.load_balancer import LoadBalancingStrategy
3737
from redis.utils import str_if_bytes
3838
from tests.conftest import (
3939
assert_resp_response,

tests/test_cluster.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
REDIS_CLUSTER_HASH_SLOTS,
2121
REPLICA,
2222
ClusterNode,
23-
LoadBalancingStrategy,
2423
NodesManager,
2524
RedisCluster,
2625
get_node_name,
@@ -39,6 +38,7 @@
3938
ResponseError,
4039
TimeoutError,
4140
)
41+
from redis.load_balancer import LoadBalancingStrategy
4242
from redis.retry import Retry
4343
from redis.utils import str_if_bytes
4444
from tests.test_pubsub import wait_for_message

0 commit comments

Comments
 (0)