diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index 7e6ac43ce95..d778912c8cf 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -143,6 +143,7 @@ def __init__( self.name = name self.graph = graph + self._parse_conditional_branches() self.flow = flow self.code_package_metadata = code_package_metadata self.code_package_sha = code_package_sha @@ -920,6 +921,121 @@ def _compile_workflow_template(self): ) ) + # Visit every node and record information on conditional step structure + def _parse_conditional_branches(self): + self.conditional_nodes = set() + self.conditional_join_nodes = set() + self.matching_conditional_join_dict = {} + + node_conditional_parents = {} + node_conditional_branches = {} + + def _visit(node, seen, conditional_branch, conditional_parents=None): + if not node.type == "split-switch" and not ( + conditional_branch and conditional_parents + ): + # skip regular non-conditional nodes entirely + return + + if node.type == "split-switch": + conditional_branch = conditional_branch + [node.name] + node_conditional_branches[node.name] = conditional_branch + + conditional_parents = ( + [node.name] + if not conditional_parents + else conditional_parents + [node.name] + ) + node_conditional_parents[node.name] = conditional_parents + + if conditional_parents and not node.type == "split-switch": + node_conditional_parents[node.name] = conditional_parents + conditional_branch = conditional_branch + [node.name] + node_conditional_branches[node.name] = conditional_branch + + self.conditional_nodes.add(node.name) + + if conditional_branch and conditional_parents: + for n in node.out_funcs: + child = self.graph[n] + if n not in seen: + _visit( + child, seen + [n], conditional_branch, conditional_parents + ) + + # First we visit all nodes to determine conditional parents and branches + for n in self.graph: + _visit(n, [], []) + + # Then we traverse again in order to 
determine conditional join nodes, and matching conditional join info + for node in self.graph: + if node_conditional_parents.get(node.name, False): + # do the required postprocessing for anything requiring node.in_funcs + + # check that in previous parsing we have not closed all conditional in_funcs. + # If so, this step can not be conditional either + is_conditional = any( + in_func in self.conditional_nodes + or self.graph[in_func].type == "split-switch" + for in_func in node.in_funcs + ) + if is_conditional: + self.conditional_nodes.add(node.name) + else: + if node.name in self.conditional_nodes: + self.conditional_nodes.remove(node.name) + + # does this node close the latest conditional parent branches? + conditional_in_funcs = [ + in_func + for in_func in node.in_funcs + if node_conditional_branches.get(in_func, False) + ] + closed_conditional_parents = [] + for last_split_switch in node_conditional_parents.get(node.name, [])[ + ::-1 + ]: + last_conditional_split_nodes = self.graph[ + last_split_switch + ].out_funcs + # p needs to be in at least one conditional_branch for it to be closed. + if all( + any( + p in node_conditional_branches.get(in_func, []) + for in_func in conditional_in_funcs + ) + for p in last_conditional_split_nodes + ): + closed_conditional_parents.append(last_split_switch) + + self.conditional_join_nodes.add(node.name) + self.matching_conditional_join_dict[last_split_switch] = ( + node.name + ) + + # Did we close all conditionals? Then this branch and all its children are not conditional anymore (unless a new conditional branch is encountered). 
+ if not [ + p + for p in node_conditional_parents.get(node.name, []) + if p not in closed_conditional_parents + ]: + if node.name in self.conditional_nodes: + self.conditional_nodes.remove(node.name) + node_conditional_parents[node.name] = [] + for p in node.out_funcs: + if p in self.conditional_nodes: + self.conditional_nodes.remove(p) + node_conditional_parents[p] = [] + + def _is_conditional_node(self, node): + return node.name in self.conditional_nodes + + def _is_conditional_join_node(self, node): + return node.name in self.conditional_join_nodes + + def _matching_conditional_join(self, node): + return self.matching_conditional_join_dict.get(node.name, None) + # Visit every node and yield the uber DAGTemplate(s). def _dag_templates(self): def _visit( @@ -941,6 +1057,7 @@ def _visit( dag_tasks = [] if templates is None: templates = [] + if exit_node is not None and exit_node is node.name: return templates, dag_tasks if node.name == "start": @@ -948,12 +1065,7 @@ def _visit( dag_task = DAGTask(self._sanitize(node.name)).template( self._sanitize(node.name) ) - if node.type == "split-switch": - raise ArgoWorkflowsException( - "Deploying flows with switch statement " - "to Argo Workflows is not supported currently." - ) - elif ( + if ( node.is_inside_foreach and self.graph[node.in_funcs[0]].type == "foreach" and not self.graph[node.in_funcs[0]].parallel_foreach @@ -1044,15 +1156,32 @@ def _visit( else: # Every other node needs only input-paths parameters = [ - Parameter("input-paths").value( - compress_list( - [ - "argo-{{workflow.name}}/%s/{{tasks.%s.outputs.parameters.task-id}}" - % (n, self._sanitize(n)) - for n in node.in_funcs - ], - # NOTE: We set zlibmin to infinite because zlib compression for the Argo input-paths breaks template value substitution. 
- zlibmin=inf, + ( + Parameter("input-paths").value( + compress_list( + [ + "argo-{{workflow.name}}/%s/{{tasks.%s.outputs.parameters.task-id}}" + % (n, self._sanitize(n)) + for n in node.in_funcs + ], + # NOTE: We set zlibmin to infinite because zlib compression for the Argo input-paths breaks template value substitution. + zlibmin=inf, + ) + ) + if not self._is_conditional_join_node(node) + # The value fetching for input-paths from conditional steps has to be quite involved + # in order to avoid issues with replacements due to missing step outputs. + # NOTE: we differentiate the input-path expression only for conditional joins so we can still utilize the list compression, + # but do not have to rework all decompress usage due to the need for a custom separator + else Parameter("input-paths").value( + compress_list( + [ + "argo-{{workflow.name}}/%s/{{=(get(tasks['%s']?.outputs?.parameters, 'task-id') ?? 'no-task')}}" + % (n, self._sanitize(n)) + for n in node.in_funcs + ], + separator="%", # non-default separator is required due to commas in the value expression + ) ) ) ] @@ -1087,15 +1216,43 @@ def _visit( ] ) + conditional_deps = [ + "%s.Succeeded" % self._sanitize(in_func) + for in_func in node.in_funcs + if self._is_conditional_node(self.graph[in_func]) + ] + required_deps = [ + "%s.Succeeded" % self._sanitize(in_func) + for in_func in node.in_funcs + if not self._is_conditional_node(self.graph[in_func]) + ] + both_conditions = required_deps and conditional_deps + + depends_str = "{required}{_and}{conditional}".format( + required=("(%s)" if both_conditions else "%s") + % " && ".join(required_deps), + _and=" && " if both_conditions else "", + conditional=("(%s)" if both_conditions else "%s") + % " || ".join(conditional_deps), + ) dag_task = ( DAGTask(self._sanitize(node.name)) - .dependencies( - [self._sanitize(in_func) for in_func in node.in_funcs] - ) + .depends(depends_str) .template(self._sanitize(node.name)) .arguments(Arguments().parameters(parameters)) ) 
+ # Add conditional if this is the first step in a conditional branch + if ( + self._is_conditional_node(node) + and self.graph[node.in_funcs[0]].type == "split-switch" + ): + in_func = node.in_funcs[0] + dag_task.when( + "{{tasks.%s.outputs.parameters.switch-step}}==%s" + % (self._sanitize(in_func), node.name) + ) + dag_tasks.append(dag_task) # End the workflow if we have reached the end of the flow if node.type == "end": @@ -1121,6 +1278,23 @@ def _visit( dag_tasks, parent_foreach, ) + elif node.type == "split-switch": + for n in node.out_funcs: + _visit( + self.graph[n], + self._matching_conditional_join(node), + templates, + dag_tasks, + parent_foreach, + ) + + return _visit( + self.graph[self._matching_conditional_join(node)], + exit_node, + templates, + dag_tasks, + parent_foreach, + ) # For foreach nodes generate a new sub DAGTemplate # We do this for "regular" foreaches (ie. `self.next(self.a, foreach=)`) elif node.type == "foreach": @@ -1148,7 +1322,7 @@ def _visit( # foreach_task = ( DAGTask(foreach_template_name) - .dependencies([self._sanitize(node.name)]) + .depends(f"{self._sanitize(node.name)}.Succeeded") .template(foreach_template_name) .arguments( Arguments().parameters( @@ -1193,6 +1367,16 @@ def _visit( % self._sanitize(node.name) ) ) + # Add conditional if this is the first step in a conditional branch + if self._is_conditional_node(node) and not any( + self._is_conditional_node(self.graph[in_func]) + for in_func in node.in_funcs + ): + in_func = node.in_funcs[0] + foreach_task.when( + "{{tasks.%s.outputs.parameters.switch-step}}==%s" + % (self._sanitize(in_func), node.name) + ) dag_tasks.append(foreach_task) templates, dag_tasks_1 = _visit( self.graph[node.out_funcs[0]], @@ -1236,7 +1420,22 @@ def _visit( self.graph[node.matching_join].in_funcs[0] ) } - ) + if not self._is_conditional_join_node( + self.graph[node.matching_join] + ) + else + # Note: If the nodes leading to the join are conditional, then we need to use an expression to pick the 
outputs from the task that executed.
+                            # ref for operators: https://github.com/expr-lang/expr/blob/master/docs/language-definition.md
+                            {
+                                "expression": "get((%s)?.parameters, 'task-id')"
+                                % " ?? ".join(
+                                    f"tasks['{self._sanitize(func)}']?.outputs"
+                                    for func in self.graph[
+                                        node.matching_join
+                                    ].in_funcs
+                                )
+                            }
+                        ),
                     ]
                     if not node.parallel_foreach
                     else [
@@ -1269,7 +1468,7 @@ def _visit(
                 join_foreach_task = (
                     DAGTask(self._sanitize(self.graph[node.matching_join].name))
                     .template(self._sanitize(self.graph[node.matching_join].name))
-                    .dependencies([foreach_template_name])
+                    .depends(f"{foreach_template_name}.Succeeded")
                     .arguments(
                         Arguments().parameters(
                             (
@@ -1568,10 +1767,25 @@ def _container_templates(self):
                        ]
                    )
                    input_paths = "%s/_parameters/%s" % (run_id, task_id_params)
+                # Only for static joins and conditional_joins
+                elif self._is_conditional_join_node(node) and not (
+                    node.type == "join"
+                    and self.graph[node.split_parents[-1]].type == "foreach"
+                ):
+                    input_paths = (
+                        "$(python -m metaflow.plugins.argo.conditional_input_paths %s)"
+                        % input_paths
+                    )
                 elif (
                     node.type == "join"
                     and self.graph[node.split_parents[-1]].type == "foreach"
                 ):
+                    # foreach-joins straight out of conditional branches are not yet supported
+                    if self._is_conditional_join_node(node):
+                        raise ArgoWorkflowsException(
+                            "Conditional steps that transition directly into a join step are not currently supported. "
+                            "As a workaround, you can add a normal step after the conditional steps that transitions to a join step."
+                        )
                     # Set aggregated input-paths for a for-each join
                     foreach_step = next(
                         n for n in node.in_funcs if self.graph[n].is_inside_foreach
@@ -1814,7 +2028,7 @@ def _container_templates(self):
                         [Parameter("num-parallel"), Parameter("task-id-entropy")]
                     )
                 else:
-                    # append this only for joins of foreaches, not static splits
+                    # append these only for joins of foreaches, not static splits
                     inputs.append(Parameter("split-cardinality"))
 
                 # check if the node is a @parallel node.
elif node.parallel_step: @@ -1849,6 +2063,13 @@ def _container_templates(self): # are derived at runtime. if not (node.name == "end" or node.parallel_step): outputs = [Parameter("task-id").valueFrom({"path": "/mnt/out/task_id"})] + + # If this step is a split-switch one, we need to output the switch step name + if node.type == "split-switch": + outputs.append( + Parameter("switch-step").valueFrom({"path": "/mnt/out/switch_step"}) + ) + if node.type == "foreach": # Emit split cardinality from foreach task outputs.append( @@ -3981,6 +4202,10 @@ def dependencies(self, dependencies): self.payload["dependencies"] = dependencies return self + def depends(self, depends: str): + self.payload["depends"] = depends + return self + def template(self, template): # Template reference self.payload["template"] = template @@ -3992,6 +4217,10 @@ def inline(self, template): self.payload["inline"] = template.to_json() return self + def when(self, when: str): + self.payload["when"] = when + return self + def with_param(self, with_param): self.payload["withParam"] = with_param return self diff --git a/metaflow/plugins/argo/argo_workflows_decorator.py b/metaflow/plugins/argo/argo_workflows_decorator.py index ce92d34b5b4..67c8fd91363 100644 --- a/metaflow/plugins/argo/argo_workflows_decorator.py +++ b/metaflow/plugins/argo/argo_workflows_decorator.py @@ -123,6 +123,15 @@ def task_finished( with open("/mnt/out/split_cardinality", "w") as file: json.dump(flow._foreach_num_splits, file) + # For conditional branches we need to record the value of the switch to disk, in order to pass it as an + # output from the switching step to be used further down the DAG + if graph[step_name].type == "split-switch": + # TODO: A nicer way to access the chosen step? 
+            _out_funcs, _ = flow._transition
+            chosen_step = _out_funcs[0]
+            with open("/mnt/out/switch_step", "w") as file:
+                file.write(chosen_step)
+
         # For steps that have a `@parallel` decorator set to them, we will be relying on Jobsets
         # to run the task. In this case, we cannot set anything in the
         # `/mnt/out` directory, since such form of output mounts are not available to Jobset executions.
diff --git a/metaflow/plugins/argo/conditional_input_paths.py b/metaflow/plugins/argo/conditional_input_paths.py
new file mode 100644
index 00000000000..d03714f54e3
--- /dev/null
+++ b/metaflow/plugins/argo/conditional_input_paths.py
@@ -0,0 +1,18 @@
+from math import inf
+import sys
+from metaflow.util import decompress_list, compress_list
+
+
+def generate_input_paths(input_paths):
+    # Note the non-default separator due to difficulties setting parameter values from conditional step outputs.
+    paths = decompress_list(input_paths, separator="%")
+
+    # Some of the paths will be malformed because the corresponding conditional
+    # branch never executed; strip those out of the list.
+
+    trimmed = [path for path in paths if not path.endswith("/no-task")]
+    return compress_list(trimmed, zlibmin=inf)
+
+
+if __name__ == "__main__":
+    print(generate_input_paths(sys.argv[1]))