From 87f3bf52a0cdbf4c307c9c6c17828b245b5505b7 Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Fri, 8 Aug 2025 18:53:30 +0300 Subject: [PATCH 01/16] wip --- metaflow/plugins/argo/argo_workflows.py | 54 ++++++++++++++++--- .../plugins/argo/argo_workflows_decorator.py | 7 +++ 2 files changed, 55 insertions(+), 6 deletions(-) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index 7e6ac43ce95..51f7c678adb 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -928,6 +928,7 @@ def _visit( templates=None, dag_tasks=None, parent_foreach=None, + visited_nodes=None, ): # Returns Tuple[List[Template], List[DAGTask]] """ """ # Every for-each node results in a separate subDAG and an equivalent @@ -937,10 +938,23 @@ def _visit( # of the for-each node. # Emit if we have reached the end of the sub workflow + if visited_nodes is None: + visited_nodes = set() if dag_tasks is None: + print("RESET DAG_TASKS") dag_tasks = [] if templates is None: + print("RESET TEMPLATES") templates = [] + + # Break early if we have reached a node we already visited. Happens when parsing through all conditional branches of split-switch + if node.name in visited_nodes: + print(f"BROKE EARLY on step :{node.name}") + return templates, dag_tasks + else: + print(f"added to visited :{node.name}") + visited_nodes.add(node.name) + if exit_node is not None and exit_node is node.name: return templates, dag_tasks if node.name == "start": @@ -948,12 +962,7 @@ def _visit( dag_task = DAGTask(self._sanitize(node.name)).template( self._sanitize(node.name) ) - if node.type == "split-switch": - raise ArgoWorkflowsException( - "Deploying flows with switch statement " - "to Argo Workflows is not supported currently." - ) - elif ( + if ( node.is_inside_foreach and self.graph[node.in_funcs[0]].type == "foreach" and not self.graph[node.in_funcs[0]].parallel_foreach @@ -1113,6 +1122,7 @@ def _visit( templates, dag_tasks, parent_foreach, + visited_nodes, ) return _visit( self.graph[node.matching_join], @@ -1120,6 +1130,27 @@ def _visit( templates, dag_tasks, parent_foreach, + visited_nodes, + ) + elif node.type == "split-switch": + # Traverse all branches of a switch split. This should work as all branches lead to 'exit_node' + for n in node.out_funcs[:-1]: + _visit( + self.graph[n], + exit_node, + templates, + dag_tasks, + parent_foreach, + visited_nodes, + ) + + return _visit( + self.graph[node.out_funcs[-1:][0]], + exit_node, + templates, + dag_tasks, + parent_foreach, + visited_nodes, ) # For foreach nodes generate a new sub DAGTemplate # We do this for "regular" foreaches (ie. `self.next(self.a, foreach=)`) @@ -1200,6 +1231,7 @@ def _visit( templates, [], node.name, + visited_nodes, ) # How do foreach's work on Argo: @@ -1318,6 +1350,7 @@ def _visit( templates, dag_tasks, parent_foreach, + visited_nodes, ) # For linear nodes continue traversing to the next node if node.type in ("linear", "join", "start"): @@ -1327,6 +1360,7 @@ def _visit( templates, dag_tasks, parent_foreach, + visited_nodes, ) else: raise ArgoWorkflowsException( @@ -1849,6 +1883,14 @@ def _container_templates(self): # are derived at runtime. if not (node.name == "end" or node.parallel_step): outputs = [Parameter("task-id").valueFrom({"path": "/mnt/out/task_id"})] + + # If this step is a split-switch one, we need to output the switch step name + # Note we can not use node.type for this, as the start step can also be a switching one + if node.type == "split-switch": + outputs.append( + Parameter("switch-step").valueFrom({"path": "/mnt/out/switch_step"}) + ) + if node.type == "foreach": # Emit split cardinality from foreach task outputs.append( diff --git a/metaflow/plugins/argo/argo_workflows_decorator.py b/metaflow/plugins/argo/argo_workflows_decorator.py index ce92d34b5b4..1a020b96db5 100644 --- a/metaflow/plugins/argo/argo_workflows_decorator.py +++ b/metaflow/plugins/argo/argo_workflows_decorator.py @@ -123,6 +123,13 @@ def task_finished( with open("/mnt/out/split_cardinality", "w") as file: json.dump(flow._foreach_num_splits, file) + # For conditional branches we need to record the value of the switch to disk, in order to pass it as an + # output from the switching step to be used further down the DAG + if graph[step_name].type == "switch-split": + switch_step_name = getattr(self, graph[step_name].condition) + with open("/mnt/out/switch_step", "w") as file: + json.dump(switch_step_name, file) + # For steps that have a `@parallel` decorator set to them, we will be relying on Jobsets # to run the task. In this case, we cannot set anything in the # `/mnt/out` directory, since such form of output mounts are not available to Jobset executions. From a00d57d7cb0f4e2d1f8f310cb9a5870a0e74c37a Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Mon, 11 Aug 2025 14:17:55 +0300 Subject: [PATCH 02/16] add conditional info to graph parsing --- metaflow/graph.py | 60 +++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 58 insertions(+), 2 deletions(-) diff --git a/metaflow/graph.py b/metaflow/graph.py index 5013971eb28..fd56d8000fb 100644 --- a/metaflow/graph.py +++ b/metaflow/graph.py @@ -3,6 +3,7 @@ import re from itertools import chain +from typing import List, Optional from .util import to_pod @@ -80,6 +81,9 @@ def __init__( self.split_parents = [] self.split_branches = [] self.matching_join = None + self.is_conditional = False # will this node always be executed, or is it in a conditional branch? + self.conditional_branch = [] + self.conditional_join = None # Node where conditional branches end, and further nodes always execute. # these attributes are populated by _postprocess self.is_inside_foreach = False @@ -297,7 +301,14 @@ def _postprocess(self): node.is_inside_foreach = True def _traverse_graph(self): - def traverse(node, seen, split_parents, split_branches): + def traverse( + node, + seen, + split_parents, + split_branches, + conditional_branch: List[str], + conditional_root_nodes: Optional[List[List[str]]] = None, + ): add_split_branch = False try: self.sorted_nodes.remove(node.name) @@ -312,6 +323,14 @@ def traverse(node, seen, split_parents, split_branches): elif node.type == "split-switch": node.split_parents = split_parents node.split_branches = split_branches + + conditional_branch = conditional_branch + [node.name] + node.conditional_branch = conditional_branch + conditional_root_nodes = ( + [node.out_funcs] + if not conditional_root_nodes + else conditional_root_nodes + [node.out_funcs] + ) elif node.type == "join": # ignore joins without splits if split_parents: @@ -324,6 +343,41 @@ def traverse(node, seen, split_parents, split_branches): node.split_parents = split_parents node.split_branches = split_branches + if conditional_root_nodes and not node.type == "split-switch": + conditional_branch = conditional_branch + [node.name] + node.conditional_branch = conditional_branch + # Multiple cases for conditional branching. TODO: describe the structure + # 1. we are in only one conditional branch + # 2. we are in a nested conditional branch + + *root_nodes, last_root_nodes = conditional_root_nodes + # Check if the node is joining all of the conditional root nodes branches. + is_conditional_join = all( + any(p in last_root_nodes for p in self[in_func].conditional_branch) + for in_func in node.in_funcs + ) + + if is_conditional_join: + conditional_root_nodes = root_nodes + + # we are in a conditional branch if we have conditional root nodes left open, and + # we did not join the most recent conditional branches. + is_in_conditional_branch = ( + bool(conditional_root_nodes) and not is_conditional_join + ) + + if not is_in_conditional_branch: + conditional_branch = [] + # add the conditional join step info + for n in set( + step + for in_func in node.in_funcs + for step in self[in_func].conditional_branch + ): + self[n].conditional_join = node.name + + node.is_conditional = is_in_conditional_branch + for n in node.out_funcs: # graph may contain loops - ignore them if n not in seen: @@ -336,10 +390,12 @@ def traverse(node, seen, split_parents, split_branches): seen + [n], split_parents, split_branches + ([n] if add_split_branch else []), + conditional_branch, + conditional_root_nodes, ) if "start" in self: - traverse(self["start"], [], [], []) + traverse(self["start"], [], [], [], []) # fix the order of in_funcs for node in self.nodes.values(): From 3ee6557f5ce6f4c07a105da88d7a943bb3fcff0a Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Mon, 11 Aug 2025 18:05:35 +0300 Subject: [PATCH 03/16] correctly dump chosen step to disk for argo --- metaflow/plugins/argo/argo_workflows_decorator.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/metaflow/plugins/argo/argo_workflows_decorator.py b/metaflow/plugins/argo/argo_workflows_decorator.py index 1a020b96db5..67c8fd91363 100644 --- a/metaflow/plugins/argo/argo_workflows_decorator.py +++ b/metaflow/plugins/argo/argo_workflows_decorator.py @@ -125,10 +125,12 @@ def task_finished( # For conditional branches we need to record the value of the switch to disk, in order to pass it as an # output from the switching step to be used further down the DAG - if graph[step_name].type == "switch-split": - switch_step_name = getattr(self, graph[step_name].condition) + if graph[step_name].type == "split-switch": + # TODO: A nicer way to access the chosen step? + _out_funcs, _ = flow._transition + chosen_step = _out_funcs[0] with open("/mnt/out/switch_step", "w") as file: - json.dump(switch_step_name, file) + file.write(chosen_step) # For steps that have a `@parallel` decorator set to them, we will be relying on Jobsets # to run the task. In this case, we cannot set anything in the From 0e99798a8937dc373d40a1a3a4513f6bbf9cd9b9 Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Mon, 11 Aug 2025 18:53:39 +0300 Subject: [PATCH 04/16] fix conditional usage for argo DAG --- metaflow/plugins/argo/argo_workflows.py | 66 +++++++++++++++---------- 1 file changed, 41 insertions(+), 25 deletions(-) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index 51f7c678adb..a2770e9687c 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -928,7 +928,6 @@ def _visit( templates=None, dag_tasks=None, parent_foreach=None, - visited_nodes=None, ): # Returns Tuple[List[Template], List[DAGTask]] """ """ # Every for-each node results in a separate subDAG and an equivalent @@ -938,23 +937,11 @@ def _visit( # of the for-each node. # Emit if we have reached the end of the sub workflow - if visited_nodes is None: - visited_nodes = set() if dag_tasks is None: - print("RESET DAG_TASKS") dag_tasks = [] if templates is None: - print("RESET TEMPLATES") templates = [] - # Break early if we have reached a node we already visited. Happens when parsing through all conditional branches of split-switch - if node.name in visited_nodes: - print(f"BROKE EARLY on step :{node.name}") - return templates, dag_tasks - else: - print(f"added to visited :{node.name}") - visited_nodes.add(node.name) - if exit_node is not None and exit_node is node.name: return templates, dag_tasks if node.name == "start": @@ -1096,15 +1083,42 @@ def _visit( ] ) + conditional_deps = [ + "%s.Succeeded" % self._sanitize(in_func) + for in_func in node.in_funcs + if self.graph[in_func].is_conditional + ] + required_deps = [ + "%s.Succeeded" % self._sanitize(in_func) + for in_func in node.in_funcs + if not self.graph[in_func].is_conditional + ] + both_conditions = required_deps and conditional_deps + + depends_str = "{required}{_and}{conditional}".format( + required=("(%s)" if both_conditions else "%s") + % " && ".join(required_deps), + _and=" && " if both_conditions else "", + conditional=("(%s)" if both_conditions else "%s") + % " || ".join(conditional_deps), + ) dag_task = ( DAGTask(self._sanitize(node.name)) - .dependencies( - [self._sanitize(in_func) for in_func in node.in_funcs] - ) + .depends(depends_str) .template(self._sanitize(node.name)) .arguments(Arguments().parameters(parameters)) ) + # Add conditional if this is the first step in a conditional branch + if node.is_conditional and not any( + self.graph[in_func].is_conditional for in_func in node.in_funcs + ): + in_func = node.in_funcs[0] + dag_task.when( + "{{tasks.%s.outputs.parameters.switch-step}}==%s" + % (self._sanitize(in_func), node.name) + ) + dag_tasks.append(dag_task) # End the workflow if we have reached the end of the flow if node.type == "end": @@ -1133,24 +1147,21 @@ def _visit( visited_nodes, ) elif node.type == "split-switch": - # Traverse all branches of a switch split. This should work as all branches lead to 'exit_node' - for n in node.out_funcs[:-1]: + for n in node.out_funcs: _visit( self.graph[n], - exit_node, + node.conditional_join, templates, dag_tasks, parent_foreach, - visited_nodes, ) return _visit( - self.graph[node.out_funcs[-1:][0]], + self.graph[node.conditional_join], exit_node, templates, dag_tasks, parent_foreach, - visited_nodes, ) # For foreach nodes generate a new sub DAGTemplate # We do this for "regular" foreaches (ie. `self.next(self.a, foreach=)`) @@ -1231,7 +1242,6 @@ def _visit( templates, [], node.name, - visited_nodes, ) # How do foreach's work on Argo: @@ -1350,7 +1360,6 @@ def _visit( templates, dag_tasks, parent_foreach, - visited_nodes, ) # For linear nodes continue traversing to the next node if node.type in ("linear", "join", "start"): @@ -1360,7 +1369,6 @@ def _visit( templates, dag_tasks, parent_foreach, - visited_nodes, ) else: raise ArgoWorkflowsException( @@ -4023,6 +4031,10 @@ def dependencies(self, dependencies): self.payload["dependencies"] = dependencies return self + def depends(self, depends: str): + self.payload["depends"] = depends + return self + def template(self, template): # Template reference self.payload["template"] = template @@ -4034,6 +4046,10 @@ def inline(self, template): self.payload["inline"] = template.to_json() return self + def when(self, when: str): + self.payload["when"] = when + return self + def with_param(self, with_param): self.payload["withParam"] = with_param return self From d0a03fd8f6370913a74f19f488b355b7e754fb01 Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Tue, 12 Aug 2025 12:27:20 +0300 Subject: [PATCH 05/16] introduce script for parsing conditional input paths. rename and introduce more properties to graph --- metaflow/graph.py | 13 ++++++++++--- metaflow/plugins/argo/argo_workflows.py | 9 +++++++-- .../plugins/argo/conditional_input_paths.py | 18 ++++++++++++++++++ 3 files changed, 35 insertions(+), 5 deletions(-) create mode 100644 metaflow/plugins/argo/conditional_input_paths.py diff --git a/metaflow/graph.py b/metaflow/graph.py index fd56d8000fb..26cb4ec2470 100644 --- a/metaflow/graph.py +++ b/metaflow/graph.py @@ -81,9 +81,15 @@ def __init__( self.split_parents = [] self.split_branches = [] self.matching_join = None + # Conditional info, also populated in _traverse_graph self.is_conditional = False # will this node always be executed, or is it in a conditional branch? - self.conditional_branch = [] - self.conditional_join = None # Node where conditional branches end, and further nodes always execute. + self.is_conditional_join = ( + False # Does this node 'join' some set of conditional branches? + ) + self.conditional_branch = ( + [] + ) # All the steps leading to this node that depends on a condition, starting from the split-switch + self.conditional_end_node = None # Node where conditional branches end, and further nodes always execute. # these attributes are populated by _postprocess self.is_inside_foreach = False @@ -358,6 +364,7 @@ def traverse( ) if is_conditional_join: + node.is_conditional_join = True conditional_root_nodes = root_nodes # we are in a conditional branch if we have conditional root nodes left open, and @@ -374,7 +381,7 @@ def traverse( for in_func in node.in_funcs for step in self[in_func].conditional_branch ): - self[n].conditional_join = node.name + self[n].conditional_end_node = node.name node.is_conditional = is_in_conditional_branch diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index a2770e9687c..dc80754fd9a 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -1150,14 +1150,14 @@ def _visit( for n in node.out_funcs: _visit( self.graph[n], - node.conditional_join, + node.conditional_end_node, templates, dag_tasks, parent_foreach, ) return _visit( - self.graph[node.conditional_join], + self.graph[node.conditional_end_node], exit_node, templates, dag_tasks, @@ -1610,6 +1610,11 @@ def _container_templates(self): ] ) input_paths = "%s/_parameters/%s" % (run_id, task_id_params) + elif node.is_conditional_join: + input_paths = ( + "$(python -m metaflow.plugins.argo.conditional_input_paths %s)" + % input_paths + ) elif ( node.type == "join" and self.graph[node.split_parents[-1]].type == "foreach" diff --git a/metaflow/plugins/argo/conditional_input_paths.py b/metaflow/plugins/argo/conditional_input_paths.py new file mode 100644 index 00000000000..b224faf35f8 --- /dev/null +++ b/metaflow/plugins/argo/conditional_input_paths.py @@ -0,0 +1,18 @@ +from math import inf +import sys +from metaflow.util import decompress_list, compress_list + + +def generate_input_paths(input_paths): + # => run_id/step/:foo,bar + paths = decompress_list(input_paths) + + # some of the paths are going to be malformed due to never having executed per conditional. + # strip these out of the list. + + trimmed = [path for path in paths if not "{{" in path] + return compress_list(trimmed, zlibmin=inf) + + +if __name__ == "__main__": + print(generate_input_paths(sys.argv[1])) From 2d8de9212c34c49a6f6826ad1228fa8e62fa95cf Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Wed, 13 Aug 2025 12:45:57 +0300 Subject: [PATCH 06/16] cleanup --- metaflow/plugins/argo/argo_workflows.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index dc80754fd9a..b5e49d6c23f 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -1136,7 +1136,6 @@ def _visit( templates, dag_tasks, parent_foreach, - visited_nodes, ) return _visit( self.graph[node.matching_join], @@ -1144,7 +1143,6 @@ def _visit( templates, dag_tasks, parent_foreach, - visited_nodes, ) elif node.type == "split-switch": for n in node.out_funcs: From dca2f378050048b882700db56053c0d07b340b61 Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Thu, 14 Aug 2025 23:31:33 +0300 Subject: [PATCH 07/16] reworking conditional graph parsing --- metaflow/graph.py | 107 +++++++++++++----------- metaflow/plugins/argo/argo_workflows.py | 4 +- 2 files changed, 60 insertions(+), 51 deletions(-) diff --git a/metaflow/graph.py b/metaflow/graph.py index 26cb4ec2470..1ef03d9ac24 100644 --- a/metaflow/graph.py +++ b/metaflow/graph.py @@ -81,17 +81,18 @@ def __init__( self.split_parents = [] self.split_branches = [] self.matching_join = None - # Conditional info, also populated in _traverse_graph + # these attributes are populated by _postprocess + self.is_inside_foreach = False + # Conditional info self.is_conditional = False # will this node always be executed, or is it in a conditional branch? + self.matching_conditional_join = None # which step joins the conditional branches. filled for split-switch only. self.is_conditional_join = ( False # Does this node 'join' some set of conditional branches? ) + self.conditional_parents = [] self.conditional_branch = ( [] ) # All the steps leading to this node that depends on a condition, starting from the split-switch - self.conditional_end_node = None # Node where conditional branches end, and further nodes always execute. - # these attributes are populated by _postprocess - self.is_inside_foreach = False def _expr_str(self, expr): return "%s.%s" % (expr.value.id, expr.attr) @@ -306,14 +307,54 @@ def _postprocess(self): if [f for f in foreaches if self.nodes[f].matching_join != node.name]: node.is_inside_foreach = True + # Fill in conditionals related info. + if node.conditional_parents: + # do the required postprocessing for anything requiring node.in_funcs + + # does this node close the latest conditional parent branches? + conditional_in_funcs = [ + in_func + for in_func in node.in_funcs + if self[in_func].conditional_branch + ] + closed_conditional_parents = [] + for last_split_switch in node.conditional_parents[::-1]: + # last_split_switch = node.conditional_parents[-1] + last_conditional_split_nodes = self[last_split_switch].out_funcs + # p needs to be in at least one conditional_branch for it to be closed. + if all( + any( + p in self[in_func].conditional_branch + for in_func in conditional_in_funcs + ) + for p in last_conditional_split_nodes + ): + closed_conditional_parents.append(last_split_switch) + + node.is_conditional_join = True + self[last_split_switch].matching_conditional_join = node.name + + # Did we close all conditionals? Then this branch and all its children are not conditional anymore. + if not [ + p + for p in node.conditional_parents + if p not in closed_conditional_parents + ]: + node.is_conditional = False + node.conditional_parents = [] + for p in node.out_funcs: + child = self[p] + child.is_conditional = False + child.conditional_parents = [] + def _traverse_graph(self): def traverse( - node, + node: DAGNode, seen, split_parents, split_branches, conditional_branch: List[str], - conditional_root_nodes: Optional[List[List[str]]] = None, + conditional_parents: Optional[List[str]] = None, ): add_split_branch = False try: @@ -330,12 +371,11 @@ def traverse( node.split_parents = split_parents node.split_branches = split_branches - conditional_branch = conditional_branch + [node.name] - node.conditional_branch = conditional_branch - conditional_root_nodes = ( - [node.out_funcs] - if not conditional_root_nodes - else conditional_root_nodes + [node.out_funcs] + node.conditional_branch = conditional_branch + [node.name] + node.conditional_parents = ( + [node.name] + if not conditional_parents + else conditional_parents + [node.name] ) elif node.type == "join": # ignore joins without splits @@ -349,41 +389,10 @@ def traverse( node.split_parents = split_parents node.split_branches = split_branches - if conditional_root_nodes and not node.type == "split-switch": - conditional_branch = conditional_branch + [node.name] - node.conditional_branch = conditional_branch - # Multiple cases for conditional branching. TODO: describe the structure - # 1. we are in only one conditional branch - # 2. we are in a nested conditional branch - - *root_nodes, last_root_nodes = conditional_root_nodes - # Check if the node is joining all of the conditional root nodes branches. - is_conditional_join = all( - any(p in last_root_nodes for p in self[in_func].conditional_branch) - for in_func in node.in_funcs - ) - - if is_conditional_join: - node.is_conditional_join = True - conditional_root_nodes = root_nodes - - # we are in a conditional branch if we have conditional root nodes left open, and - # we did not join the most recent conditional branches. - is_in_conditional_branch = ( - bool(conditional_root_nodes) and not is_conditional_join - ) - - if not is_in_conditional_branch: - conditional_branch = [] - # add the conditional join step info - for n in set( - step - for in_func in node.in_funcs - for step in self[in_func].conditional_branch - ): - self[n].conditional_end_node = node.name - - node.is_conditional = is_in_conditional_branch + if conditional_parents and not node.type == "split-switch": + node.conditional_parents = conditional_parents + node.conditional_branch = conditional_branch + [node.name] + node.is_conditional = True for n in node.out_funcs: # graph may contain loops - ignore them @@ -397,8 +406,8 @@ def traverse( seen + [n], split_parents, split_branches + ([n] if add_split_branch else []), - conditional_branch, - conditional_root_nodes, + node.conditional_branch, + node.conditional_parents, ) if "start" in self: diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index b5e49d6c23f..dfe8b681538 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -1148,14 +1148,14 @@ def _visit( for n in node.out_funcs: _visit( self.graph[n], - node.conditional_end_node, + node.matching_conditional_join, templates, dag_tasks, parent_foreach, ) return _visit( - self.graph[node.conditional_end_node], + self.graph[node.matching_conditional_join], exit_node, templates, dag_tasks, From 511bd5101fa6a5c5855eced6e1647a328c2a93c5 Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Fri, 15 Aug 2025 02:29:11 +0300 Subject: [PATCH 08/16] fix foreaches --- metaflow/plugins/argo/argo_workflows.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index dfe8b681538..be1db3c32cd 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -1188,7 +1188,7 @@ def _visit( # foreach_task = ( DAGTask(foreach_template_name) - .dependencies([self._sanitize(node.name)]) + .depends(f"{self._sanitize(node.name)}.Succeeded") .template(foreach_template_name) .arguments( Arguments().parameters( @@ -1233,6 +1233,15 @@ def _visit( % self._sanitize(node.name) ) ) + # Add conditional if this is the first step in a conditional branch + if node.is_conditional and not any( + self.graph[in_func].is_conditional for in_func in node.in_funcs + ): + in_func = node.in_funcs[0] + foreach_task.when( + "{{tasks.%s.outputs.parameters.switch-step}}==%s" + % (self._sanitize(in_func), node.name) + ) dag_tasks.append(foreach_task) templates, dag_tasks_1 = _visit( self.graph[node.out_funcs[0]], @@ -1309,7 +1318,7 @@ def _visit( join_foreach_task = ( DAGTask(self._sanitize(self.graph[node.matching_join].name)) .template(self._sanitize(self.graph[node.matching_join].name)) - .dependencies([foreach_template_name]) + .depends(f"{foreach_template_name}.Succeeded") .arguments( Arguments().parameters( ( From 0a4f3c4c1fcf20e927296b405a72268ce6212b50 Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Fri, 15 Aug 2025 03:03:42 +0300 Subject: [PATCH 09/16] cleanup --- metaflow/plugins/argo/argo_workflows.py | 1 - 1 file changed, 1 deletion(-) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index be1db3c32cd..d4adcc51548 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -1905,7 +1905,6 @@ def _container_templates(self): outputs = [Parameter("task-id").valueFrom({"path": "/mnt/out/task_id"})] # If this step is a split-switch one, we need to output the switch step name - # Note we can not use node.type for this, as the start step can also be a switching one if node.type == "split-switch": outputs.append( Parameter("switch-step").valueFrom({"path": "/mnt/out/switch_step"}) From 923f1d12bf3fa495b29e635c0044910da0495094 Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Fri, 15 Aug 2025 14:45:04 +0300 Subject: [PATCH 10/16] fix argo foreach template task-id output for conditional steps --- metaflow/plugins/argo/argo_workflows.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index d4adcc51548..7436da59574 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -1285,6 +1285,21 @@ def _visit( self.graph[node.matching_join].in_funcs[0] ) } + if not self.graph[ + node.matching_join + ].is_conditional_join + else + # Note: If the nodes leading to the join are conditional, then we need to use an expression to pick the outputs from the task that executed. + # ref for operators: https://github.com/expr-lang/expr/blob/master/docs/language-definition.md + { + "expression": "get((%s)?.outputs?.parameters, 'task-id')" + % " ?? ".join( + f"tasks['{self._sanitize(func)}']" + for func in self.graph[ + node.matching_join + ].in_funcs + ) + } ) ] if not node.parallel_foreach From 16f492f0b34abb733e723fc2f4b1c25dcbc10841 Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Fri, 15 Aug 2025 18:25:33 +0300 Subject: [PATCH 11/16] WIP: fix foreach pathspecs for join step with conditional parents --- metaflow/plugins/argo/argo_workflows.py | 41 +++++++++++++++++++------ 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index 7436da59574..25bf81b2c81 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -1292,15 +1292,32 @@ def _visit( # Note: If the nodes leading to the join are conditional, then we need to use an expression to pick the outputs from the task that executed. # ref for operators: https://github.com/expr-lang/expr/blob/master/docs/language-definition.md { - "expression": "get((%s)?.outputs?.parameters, 'task-id')" + "expression": "get((%s)?.parameters, 'task-id')" % " ?? ".join( - f"tasks['{self._sanitize(func)}']" + f"tasks['{self._sanitize(func)}']?.outputs" for func in self.graph[ node.matching_join ].in_funcs ) } - ) + ), + # Add the out step for all foreach templates to keep things simpler. + # This is used to be able to create the input-paths correctly for the join step + Parameter("foreach-out-step").valueFrom( + { + "expression": "filter([%s], {#[1]=='Succeeded'})[0][0]" + % ",".join( + '["%s", tasks["%s"].status]' + % ( + self._sanitize(func), + self._sanitize(func), + ) + for func in self.graph[ + node.matching_join + ].in_funcs + ) + } + ), ] if not node.parallel_foreach else [ @@ -1346,6 +1363,12 @@ def _visit( "{{tasks.%s.outputs.parameters.split-cardinality}}" % self._sanitize(node.name) ), + # Only pick the output step from the first iteration of the foreach task, as it should be identical for all. + # TODO: This still needs fixing. + Parameter("foreach-out-step").value( + "{{= toJson(tasks['%s'].outputs.parameters['foreach-out-step'])[0] }}" + % foreach_template_name + ), ] if not node.parallel_foreach else [ @@ -1632,7 +1655,7 @@ def _container_templates(self): ] ) input_paths = "%s/_parameters/%s" % (run_id, task_id_params) - elif node.is_conditional_join: + elif node.is_conditional_join and not node.type == "join": input_paths = ( "$(python -m metaflow.plugins.argo.conditional_input_paths %s)" % input_paths @@ -1647,11 +1670,8 @@ def _container_templates(self): ) if not self.graph[node.split_parents[-1]].parallel_foreach: input_paths = ( - "$(python -m metaflow.plugins.argo.generate_input_paths %s {{workflow.creationTimestamp}} %s {{inputs.parameters.split-cardinality}})" - % ( - foreach_step, - input_paths, - ) + "$(python -m metaflow.plugins.argo.generate_input_paths {{inputs.parameters.foreach-out-step}} {{workflow.creationTimestamp}} %s {{inputs.parameters.split-cardinality}})" + % (input_paths,) ) else: # Handle @parallel where output from volume mount isn't accessible @@ -1883,8 +1903,9 @@ def _container_templates(self): [Parameter("num-parallel"), Parameter("task-id-entropy")] ) else: - # append this only for joins of foreaches, not static splits + # append these only for joins of foreaches, not static splits inputs.append(Parameter("split-cardinality")) + inputs.append(Parameter("foreach-out-step")) # check if the node is a @parallel node. elif node.parallel_step: inputs.extend( From 2c1d760f9d9e6d3aba24a2df2ea3030d683d9ff0 Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Fri, 15 Aug 2025 22:24:56 +0300 Subject: [PATCH 12/16] fix static joins and conditional_join again --- metaflow/plugins/argo/argo_workflows.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index 25bf81b2c81..639d1b1e01b 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -1655,7 +1655,11 @@ def _container_templates(self): ] ) input_paths = "%s/_parameters/%s" % (run_id, task_id_params) - elif node.is_conditional_join and not node.type == "join": + # Only for static joins and conditional_joins + elif node.is_conditional_join and not ( + node.type == "join" + and self.graph[node.split_parents[-1]].type == "foreach" + ): input_paths = ( "$(python -m metaflow.plugins.argo.conditional_input_paths %s)" % input_paths From 158f725ad9717b3f3056d9da0cc780eba1b7dd61 Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Sat, 16 Aug 2025 01:55:01 +0300 Subject: [PATCH 13/16] revert foreach case parsing and opt for exception for now. fix graph parsing for odd cases --- metaflow/graph.py | 6 ++++ metaflow/plugins/argo/argo_workflows.py | 42 +++++++++---------------- 2 files changed, 20 insertions(+), 28 deletions(-) diff --git a/metaflow/graph.py b/metaflow/graph.py index 1ef03d9ac24..3d199959359 100644 --- a/metaflow/graph.py +++ b/metaflow/graph.py @@ -311,6 +311,12 @@ def _postprocess(self): if node.conditional_parents: # do the required postprocessing for anything requiring node.in_funcs + # check that in previous parsing we have not closed all conditional in_funcs. + # If so, this step can not be conditional either + node.is_conditional = any( + self[in_func].is_conditional or self[in_func].type == "split-switch" + for in_func in node.in_funcs + ) # does this node close the latest conditional parent branches? conditional_in_funcs = [ in_func diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index 639d1b1e01b..67e9d659825 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -1110,8 +1110,9 @@ def _visit( ) # Add conditional if this is the first step in a conditional branch - if node.is_conditional and not any( - self.graph[in_func].is_conditional for in_func in node.in_funcs + if ( + node.is_conditional + and self.graph[node.in_funcs[0]].type == "split-switch" ): in_func = node.in_funcs[0] dag_task.when( @@ -1301,23 +1302,6 @@ def _visit( ) } ), - # Add the out step for all foreach templates to keep things simpler. - # This is used to be able to create the input-paths correctly for the join step - Parameter("foreach-out-step").valueFrom( - { - "expression": "filter([%s], {#[1]=='Succeeded'})[0][0]" - % ",".join( - '["%s", tasks["%s"].status]' - % ( - self._sanitize(func), - self._sanitize(func), - ) - for func in self.graph[ - node.matching_join - ].in_funcs - ) - } - ), ] if not node.parallel_foreach else [ @@ -1363,12 +1347,6 @@ def _visit( "{{tasks.%s.outputs.parameters.split-cardinality}}" % self._sanitize(node.name) ), - # Only pick the output step from the first iteration of the foreach task, as it should be identical for all. - # TODO: This still needs fixing. - Parameter("foreach-out-step").value( - "{{= toJson(tasks['%s'].outputs.parameters['foreach-out-step'])[0] }}" - % foreach_template_name - ), ] if not node.parallel_foreach else [ @@ -1668,14 +1646,23 @@ def _container_templates(self): node.type == "join" and self.graph[node.split_parents[-1]].type == "foreach" ): + # foreach-joins straight out of conditional branches are not yet supported + if node.is_conditional_join: + raise ArgoWorkflowsException( + "Foreach steps with a conditional step as the last one are not yet supported with Argo Workflows." + "For now, you can add a merging step after the conditional ones that will be then joined by the foreach-join" + ) # Set aggregated input-paths for a for-each join foreach_step = next( n for n in node.in_funcs if self.graph[n].is_inside_foreach ) if not self.graph[node.split_parents[-1]].parallel_foreach: input_paths = ( - "$(python -m metaflow.plugins.argo.generate_input_paths {{inputs.parameters.foreach-out-step}} {{workflow.creationTimestamp}} %s {{inputs.parameters.split-cardinality}})" - % (input_paths,) + "$(python -m metaflow.plugins.argo.generate_input_paths %s {{workflow.creationTimestamp}} %s {{inputs.parameters.split-cardinality}})" + % ( + foreach_step, + input_paths, + ) ) else: # Handle @parallel where output from volume mount isn't accessible @@ -1909,7 +1896,6 @@ def _container_templates(self): else: # append these only for joins of foreaches, not static splits inputs.append(Parameter("split-cardinality")) - inputs.append(Parameter("foreach-out-step")) # check if the node is a @parallel node. elif node.parallel_step: inputs.extend( From 6cc281c4aca658325852d6193a3f4041f8a9eb8c Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Mon, 18 Aug 2025 20:53:58 +0300 Subject: [PATCH 14/16] reword foreach exception --- metaflow/plugins/argo/argo_workflows.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index 67e9d659825..8bc31037f82 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -1649,8 +1649,8 @@ def _container_templates(self): # foreach-joins straight out of conditional branches are not yet supported if node.is_conditional_join: raise ArgoWorkflowsException( - "Foreach steps with a conditional step as the last one are not yet supported with Argo Workflows." - "For now, you can add a merging step after the conditional ones that will be then joined by the foreach-join" + "Conditionals steps that transition directly into a join step are not currently supported. " + "As a workaround, you can add a normal step after the conditional steps that transitions to a join step." ) # Set aggregated input-paths for a for-each join foreach_step = next( From 51457677da412827414b8cce69e160527fe6359a Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Tue, 19 Aug 2025 00:47:46 +0300 Subject: [PATCH 15/16] revert all changes to graph.py, moving all conditional parsing to argo_workflows implementation --- metaflow/graph.py | 82 +------------- metaflow/plugins/argo/argo_workflows.py | 141 ++++++++++++++++++++++-- 2 files changed, 131 insertions(+), 92 deletions(-) diff --git a/metaflow/graph.py b/metaflow/graph.py index 3d199959359..5013971eb28 100644 --- a/metaflow/graph.py +++ b/metaflow/graph.py @@ -3,7 +3,6 @@ import re from itertools import chain -from typing import List, Optional from .util import to_pod @@ -83,16 +82,6 @@ def __init__( self.matching_join = None # these attributes are populated by _postprocess self.is_inside_foreach = False - # Conditional info - self.is_conditional = False # will this node always be executed, or is it in a conditional branch? - self.matching_conditional_join = None # which step joins the conditional branches. filled for split-switch only. - self.is_conditional_join = ( - False # Does this node 'join' some set of conditional branches? - ) - self.conditional_parents = [] - self.conditional_branch = ( - [] - ) # All the steps leading to this node that depends on a condition, starting from the split-switch def _expr_str(self, expr): return "%s.%s" % (expr.value.id, expr.attr) @@ -307,61 +296,8 @@ def _postprocess(self): if [f for f in foreaches if self.nodes[f].matching_join != node.name]: node.is_inside_foreach = True - # Fill in conditionals related info. - if node.conditional_parents: - # do the required postprocessing for anything requiring node.in_funcs - - # check that in previous parsing we have not closed all conditional in_funcs. - # If so, this step can not be conditional either - node.is_conditional = any( - self[in_func].is_conditional or self[in_func].type == "split-switch" - for in_func in node.in_funcs - ) - # does this node close the latest conditional parent branches? - conditional_in_funcs = [ - in_func - for in_func in node.in_funcs - if self[in_func].conditional_branch - ] - closed_conditional_parents = [] - for last_split_switch in node.conditional_parents[::-1]: - # last_split_switch = node.conditional_parents[-1] - last_conditional_split_nodes = self[last_split_switch].out_funcs - # p needs to be in at least one conditional_branch for it to be closed. - if all( - any( - p in self[in_func].conditional_branch - for in_func in conditional_in_funcs - ) - for p in last_conditional_split_nodes - ): - closed_conditional_parents.append(last_split_switch) - - node.is_conditional_join = True - self[last_split_switch].matching_conditional_join = node.name - - # Did we close all conditionals? Then this branch and all its children are not conditional anymore. - if not [ - p - for p in node.conditional_parents - if p not in closed_conditional_parents - ]: - node.is_conditional = False - node.conditional_parents = [] - for p in node.out_funcs: - child = self[p] - child.is_conditional = False - child.conditional_parents = [] - def _traverse_graph(self): - def traverse( - node: DAGNode, - seen, - split_parents, - split_branches, - conditional_branch: List[str], - conditional_parents: Optional[List[str]] = None, - ): + def traverse(node, seen, split_parents, split_branches): add_split_branch = False try: self.sorted_nodes.remove(node.name) @@ -376,13 +312,6 @@ def traverse( elif node.type == "split-switch": node.split_parents = split_parents node.split_branches = split_branches - - node.conditional_branch = conditional_branch + [node.name] - node.conditional_parents = ( - [node.name] - if not conditional_parents - else conditional_parents + [node.name] - ) elif node.type == "join": # ignore joins without splits if split_parents: @@ -395,11 +324,6 @@ def traverse( node.split_parents = split_parents node.split_branches = split_branches - if conditional_parents and not node.type == "split-switch": - node.conditional_parents = conditional_parents - node.conditional_branch = conditional_branch + [node.name] - node.is_conditional = True - for n in node.out_funcs: # graph may contain loops - ignore them if n not in seen: @@ -412,12 +336,10 @@ def traverse( seen + [n], split_parents, split_branches + ([n] if add_split_branch else []), - node.conditional_branch, - node.conditional_parents, ) if "start" in self: - traverse(self["start"], [], [], [], []) + traverse(self["start"], [], [], []) # fix the order of in_funcs for node in self.nodes.values(): diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index 8bc31037f82..b2ff5b3884c 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -143,6 +143,7 @@ def __init__( self.name = name self.graph = graph + self._parse_conditional_branches() self.flow = flow self.code_package_metadata = code_package_metadata self.code_package_sha = code_package_sha @@ -920,6 +921,121 @@ def _compile_workflow_template(self): ) ) + # Visit every node and record information on conditional step structure + def _parse_conditional_branches(self): + self.conditional_nodes = set() + self.conditional_join_nodes = set() + self.matching_conditional_join_dict = {} + + node_conditional_parents = {} + node_conditional_branches = {} + + def _visit(node, seen, conditional_branch, conditional_parents=None): + if not node.type == "split-switch" and not ( + conditional_branch and conditional_parents + ): + # skip regular non-conditional nodes entirely + return + + if node.type == "split-switch": + conditional_branch = conditional_branch + [node.name] + node_conditional_branches[node.name] = conditional_branch + + conditional_parents = ( + [node.name] + if not conditional_parents + else conditional_parents + [node.name] + ) + node_conditional_parents[node.name] = conditional_parents + + if conditional_parents and not node.type == "split-switch": + node_conditional_parents[node.name] = conditional_parents + conditional_branch = conditional_branch + [node.name] + node_conditional_branches[node.name] = conditional_branch + + self.conditional_nodes.add(node.name) + + if conditional_branch and conditional_parents: + for n in node.out_funcs: + child = self.graph[n] + if n not in seen: + _visit( + child, seen + [n], conditional_branch, conditional_parents + ) + + # First we visit all nodes to determine conditional parents and branches + for n in self.graph: + _visit(n, [], []) + + # Then we traverse again in order to determine conditional join nodes, and matching conditional join info + for node in self.graph: + if node_conditional_parents.get(node.name, False): + # do the required postprocessing for anything requiring node.in_funcs + + # check that in previous parsing we have not closed all conditional in_funcs. + # If so, this step can not be conditional either + is_conditional = any( + in_func in self.conditional_nodes + or self.graph[in_func].type == "split-switch" + for in_func in node.in_funcs + ) + if is_conditional: + self.conditional_nodes.add(node.name) + else: + if node.name in self.conditional_nodes: + self.conditional_nodes.remove(node.name) + + # does this node close the latest conditional parent branches? + conditional_in_funcs = [ + in_func + for in_func in node.in_funcs + if node_conditional_branches.get(in_func, False) + ] + closed_conditional_parents = [] + for last_split_switch in node_conditional_parents.get(node.name, [])[ + ::-1 + ]: + last_conditional_split_nodes = self.graph[ + last_split_switch + ].out_funcs + # p needs to be in at least one conditional_branch for it to be closed. + if all( + any( + p in node_conditional_branches.get(in_func, []) + for in_func in conditional_in_funcs + ) + for p in last_conditional_split_nodes + ): + closed_conditional_parents.append(last_split_switch) + + self.conditional_join_nodes.add(node.name) + self.matching_conditional_join_dict[last_split_switch] = ( + node.name + ) + + # Did we close all conditionals? Then this branch and all its children are not conditional anymore (unless a new conditional branch is encountered). + if not [ + p + for p in node_conditional_parents.get(node.name, []) + if p not in closed_conditional_parents + ]: + if node.name in self.conditional_nodes: + self.conditional_nodes.remove(node.name) + node_conditional_parents[node.name] = [] + for p in node.out_funcs: + if p in self.conditional_nodes: + self.conditional_nodes.remove(p) + node_conditional_parents[p] = [] + + def _is_conditional_node(self, node): + return node.name in self.conditional_nodes + + def _is_conditional_join_node(self, node): + return node.name in self.conditional_join_nodes + + def _matching_conditional_join(self, node): + return self.matching_conditional_join_dict.get(node.name, None) + # Visit every node and yield the uber DAGTemplate(s). def _dag_templates(self): def _visit( @@ -1086,12 +1202,12 @@ def _visit( conditional_deps = [ "%s.Succeeded" % self._sanitize(in_func) for in_func in node.in_funcs - if self.graph[in_func].is_conditional + if self._is_conditional_node(self.graph[in_func]) ] required_deps = [ "%s.Succeeded" % self._sanitize(in_func) for in_func in node.in_funcs - if not self.graph[in_func].is_conditional + if not self._is_conditional_node(self.graph[in_func]) ] both_conditions = required_deps and conditional_deps @@ -1111,7 +1227,7 @@ def _visit( # Add conditional if this is the first step in a conditional branch if ( - node.is_conditional + self._is_conditional_node(node) and self.graph[node.in_funcs[0]].type == "split-switch" ): in_func = node.in_funcs[0] @@ -1149,14 +1265,14 @@ def _visit( for n in node.out_funcs: _visit( self.graph[n], - node.matching_conditional_join, + self._matching_conditional_join(node), templates, dag_tasks, parent_foreach, ) return _visit( - self.graph[node.matching_conditional_join], + self.graph[self._matching_conditional_join(node)], exit_node, templates, dag_tasks, @@ -1235,8 +1351,9 @@ def _visit( ) ) # Add conditional if this is the first step in a conditional branch - if node.is_conditional and not any( - self.graph[in_func].is_conditional for in_func in node.in_funcs + if self._is_conditional_node(node) and not any( + self._is_conditional_node(self.graph[in_func]) + for in_func in node.in_funcs ): in_func = node.in_funcs[0] foreach_task.when( @@ -1286,9 +1403,9 @@ def _visit( self.graph[node.matching_join].in_funcs[0] ) } - if not self.graph[ - node.matching_join - ].is_conditional_join + if not self._is_conditional_join_node( + self.graph[node.matching_join] + ) else # Note: If the nodes leading to the join are conditional, then we need to use an expression to pick the outputs from the task that executed. # ref for operators: https://github.com/expr-lang/expr/blob/master/docs/language-definition.md @@ -1634,7 +1751,7 @@ def _container_templates(self): ) input_paths = "%s/_parameters/%s" % (run_id, task_id_params) # Only for static joins and conditional_joins - elif node.is_conditional_join and not ( + elif self._is_conditional_join_node(node) and not ( node.type == "join" and self.graph[node.split_parents[-1]].type == "foreach" ): @@ -1647,7 +1764,7 @@ def _container_templates(self): and self.graph[node.split_parents[-1]].type == "foreach" ): # foreach-joins straight out of conditional branches are not yet supported - if node.is_conditional_join: + if self._is_conditional_join_node(node): raise ArgoWorkflowsException( "Conditionals steps that transition directly into a join step are not currently supported. " "As a workaround, you can add a normal step after the conditional steps that transitions to a join step." From d1e2707b6ce3837ef66248a4ae851a17ea14bf67 Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Tue, 19 Aug 2025 22:39:20 +0300 Subject: [PATCH 16/16] fix conditional input-paths parsing failures --- metaflow/plugins/argo/argo_workflows.py | 35 ++++++++++++++----- .../plugins/argo/conditional_input_paths.py | 6 ++-- 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index b2ff5b3884c..d778912c8cf 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -1156,15 +1156,32 @@ def _visit( else: # Every other node needs only input-paths parameters = [ - Parameter("input-paths").value( - compress_list( - [ - "argo-{{workflow.name}}/%s/{{tasks.%s.outputs.parameters.task-id}}" - % (n, self._sanitize(n)) - for n in node.in_funcs - ], - # NOTE: We set zlibmin to infinite because zlib compression for the Argo input-paths breaks template value substitution. - zlibmin=inf, + ( + Parameter("input-paths").value( + compress_list( + [ + "argo-{{workflow.name}}/%s/{{tasks.%s.outputs.parameters.task-id}}" + % (n, self._sanitize(n)) + for n in node.in_funcs + ], + # NOTE: We set zlibmin to infinite because zlib compression for the Argo input-paths breaks template value substitution. + zlibmin=inf, + ) + ) + if not self._is_conditional_join_node(node) + # The value fetching for input-paths from conditional steps has to be quite involved + # in order to avoid issues with replacements due to missing step outputs. + # NOTE: we differentiate the input-path expression only for conditional joins so we can still utilize the list compression, + # but do not have to rework all decompress usage due to the need for a custom separator + else Parameter("input-paths").value( + compress_list( + [ + "argo-{{workflow.name}}/%s/{{=(get(tasks['%s']?.outputs?.parameters, 'task-id') ?? 'no-task')}}" + % (n, self._sanitize(n)) + for n in node.in_funcs + ], + separator="%", # non-default separator is required due to commas in the value expression + ) ) ) ] diff --git a/metaflow/plugins/argo/conditional_input_paths.py b/metaflow/plugins/argo/conditional_input_paths.py index b224faf35f8..d03714f54e3 100644 --- a/metaflow/plugins/argo/conditional_input_paths.py +++ b/metaflow/plugins/argo/conditional_input_paths.py @@ -4,13 +4,13 @@ def generate_input_paths(input_paths): - # => run_id/step/:foo,bar - paths = decompress_list(input_paths) + # Note the non-default separator due to difficulties setting parameter values from conditional step outputs. + paths = decompress_list(input_paths, separator="%") # some of the paths are going to be malformed due to never having executed per conditional. # strip these out of the list. - trimmed = [path for path in paths if not "{{" in path] + trimmed = [path for path in paths if not path.endswith("/no-task")] return compress_list(trimmed, zlibmin=inf)