Skip to content

Allow multi-node commands in async pipeline #3439

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 64 additions & 25 deletions redis/asyncio/cluster.py
Original file line number Diff line number Diff line change
@@ -1070,12 +1070,13 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
ret = False
for cmd in commands:
try:
cmd.result = await self.parse_response(
result = await self.parse_response(
connection, cmd.args[0], **cmd.kwargs
)
except Exception as e:
cmd.result = e
result = e
ret = True
cmd.set_node_result(self.name, result)

# Release connection
self._free.append(connection)
@@ -1514,7 +1515,7 @@ async def _execute(
allow_redirections: bool = True,
) -> List[Any]:
todo = [
cmd for cmd in stack if not cmd.result or isinstance(cmd.result, Exception)
cmd for cmd in stack if not cmd.unwrap_result() or cmd.get_all_exceptions()
]

nodes = {}
@@ -1530,12 +1531,11 @@ async def _execute(
raise RedisClusterException(
f"No targets were found to execute {cmd.args} command on"
)
if len(target_nodes) > 1:
raise RedisClusterException(f"Too many targets for command {cmd.args}")
node = target_nodes[0]
if node.name not in nodes:
nodes[node.name] = (node, [])
nodes[node.name][1].append(cmd)
cmd.target_nodes = target_nodes
for node in target_nodes:
if node.name not in nodes:
nodes[node.name] = (node, [])
nodes[node.name][1].append(cmd)

errors = await asyncio.gather(
*(
@@ -1548,40 +1548,58 @@ async def _execute(
if allow_redirections:
# send each errored command individually
for cmd in todo:
if isinstance(cmd.result, (TryAgainError, MovedError, AskError)):
try:
cmd.result = await client.execute_command(
*cmd.args, **cmd.kwargs
)
except Exception as e:
cmd.result = e
for name, exc in cmd.get_all_exceptions():
if isinstance(exc, (TryAgainError, MovedError, AskError)):
try:
result = await client.execute_command(
*cmd.args, **cmd.kwargs
)
except Exception as e:
result = e

if isinstance(result, dict):
cmd.result = result
else:
cmd.set_node_result(name, result)

# We have already retried the command on all nodes.
break

if raise_on_error:
for cmd in todo:
result = cmd.result
if isinstance(result, Exception):
name_exc = cmd.get_first_exception()
if name_exc:
name, exc = name_exc
command = " ".join(map(safe_str, cmd.args))
# Note: this will only raise the first exception, but that is
# consistent with RedisCluster.execute_command.
msg = (
f"Command # {cmd.position + 1} ({command}) of pipeline "
f"caused error: {result.args}"
f"caused error on node {name}: "
f"{exc.args}"
)
result.args = (msg,) + result.args[1:]
raise result
exc.args = (msg,) + exc.args[1:]
raise exc

default_node = nodes.get(client.get_default_node().name)
if default_node is not None:
# This pipeline execution used the default node, check if we need
# to replace it.
# Note: when the error is raised we'll reset the default node in the
# caller function.
has_exc = False
for cmd in default_node[1]:
# Check if it has a command that failed with a relevant
# exception
if type(cmd.result) in self.__class__.ERRORS_ALLOW_RETRY:
client.replace_default_node()
for name, exc in cmd.get_all_exceptions():
if type(exc) in self.__class__.ERRORS_ALLOW_RETRY:
client.replace_default_node()
has_exc = True
break
if has_exc:
break

return [cmd.result for cmd in stack]
return [cmd.unwrap_result() for cmd in stack]

def _split_command_across_slots(
self, command: str, *keys: KeyT
@@ -1620,7 +1638,28 @@ def __init__(self, position: int, *args: Any, **kwargs: Any) -> None:
self.args = args
self.kwargs = kwargs
self.position = position
self.result: Union[Any, Exception] = None
self.result: Dict[str, Union[Any, Exception]] = {}
self.target_nodes = None

def set_node_result(self, node_name: str, result: Union[Any, Exception]):
self.result[node_name] = result

def unwrap_result(
self,
) -> Optional[Union[Any, Exception, Dict[str, Union[Any, Exception]]]]:
if len(self.result) == 0:
return None
if len(self.result) == 1:
return next(iter(self.result.values()))
return self.result

def get_first_exception(self) -> Optional[Tuple[str, Exception]]:
return next(
((n, r) for n, r in self.result.items() if isinstance(r, Exception)), None
)

def get_all_exceptions(self) -> List[Tuple[str, Exception]]:
return [(n, r) for n, r in self.result.items() if isinstance(r, Exception)]

def __repr__(self) -> str:
return f"[{self.position}] {self.args} ({self.kwargs})"