Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
273 changes: 251 additions & 22 deletions metaflow/plugins/argo/argo_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def __init__(

self.name = name
self.graph = graph
self._parse_conditional_branches()
self.flow = flow
self.code_package_metadata = code_package_metadata
self.code_package_sha = code_package_sha
Expand Down Expand Up @@ -920,6 +921,121 @@ def _compile_workflow_template(self):
)
)

# Visit every node and record information on conditional step structure
def _parse_conditional_branches(self):
    """Walk the flow graph once up front and record conditional (switch) structure.

    Populates three attributes consumed later while compiling DAG templates:
      - ``self.conditional_nodes``: names of steps that execute conditionally,
        i.e. steps reachable inside a split-switch branch.
      - ``self.conditional_join_nodes``: names of steps at which all branches
        of at least one split-switch converge again.
      - ``self.matching_conditional_join_dict``: maps a split-switch step name
        to the name of the step that closes its branches.
    """
    self.conditional_nodes = set()
    self.conditional_join_nodes = set()
    self.matching_conditional_join_dict = {}

    # Per-step bookkeeping, keyed by step name:
    #   node_conditional_parents[name]  -> stack of enclosing split-switch steps
    #   node_conditional_branches[name] -> path of conditional steps leading here
    node_conditional_parents = {}
    node_conditional_branches = {}

    def _visit(node, seen, conditional_branch, conditional_parents=None):
        # Depth-first traversal that only descends into conditional regions;
        # `seen` holds the names already on the current path to avoid cycles.
        if not node.type == "split-switch" and not (
            conditional_branch and conditional_parents
        ):
            # skip regular non-conditional nodes entirely
            return

        if node.type == "split-switch":
            # A switch opens (or deepens) a conditional region: it extends the
            # branch path and pushes itself onto the parent stack.
            conditional_branch = conditional_branch + [node.name]
            node_conditional_branches[node.name] = conditional_branch

            conditional_parents = (
                [node.name]
                if not conditional_parents
                else conditional_parents + [node.name]
            )
            node_conditional_parents[node.name] = conditional_parents

        if conditional_parents and not node.type == "split-switch":
            # A regular step inside a conditional region inherits the current
            # parent stack, extends the branch path and is marked conditional.
            node_conditional_parents[node.name] = conditional_parents
            conditional_branch = conditional_branch + [node.name]
            node_conditional_branches[node.name] = conditional_branch

            self.conditional_nodes.add(node.name)

        if conditional_branch and conditional_parents:
            for n in node.out_funcs:
                child = self.graph[n]
                if n not in seen:
                    _visit(
                        child, seen + [n], conditional_branch, conditional_parents
                    )

    # First we visit all nodes to determine conditional parents and branches
    for n in self.graph:
        _visit(n, [], [])

    # Then we traverse again in order to determine conditional join nodes, and matching conditional join info
    for node in self.graph:
        if node_conditional_parents.get(node.name, False):
            # do the required postprocessing for anything requiring node.in_funcs

            # check that in previous parsing we have not closed all conditional in_funcs.
            # If so, this step can not be conditional either
            is_conditional = any(
                in_func in self.conditional_nodes
                or self.graph[in_func].type == "split-switch"
                for in_func in node.in_funcs
            )
            if is_conditional:
                self.conditional_nodes.add(node.name)
            else:
                if node.name in self.conditional_nodes:
                    self.conditional_nodes.remove(node.name)

            # does this node close the latest conditional parent branches?
            conditional_in_funcs = [
                in_func
                for in_func in node.in_funcs
                if node_conditional_branches.get(in_func, False)
            ]
            closed_conditional_parents = []
            # Walk the parent stack newest-first so inner switches are
            # considered for closing before outer ones.
            for last_split_switch in node_conditional_parents.get(node.name, [])[
                ::-1
            ]:
                last_conditional_split_nodes = self.graph[
                    last_split_switch
                ].out_funcs
                # p needs to be in at least one conditional_branch for it to be closed.
                if all(
                    any(
                        p in node_conditional_branches.get(in_func, [])
                        for in_func in conditional_in_funcs
                    )
                    for p in last_conditional_split_nodes
                ):
                    closed_conditional_parents.append(last_split_switch)

                    self.conditional_join_nodes.add(node.name)
                    self.matching_conditional_join_dict[last_split_switch] = (
                        node.name
                    )

            # Did we close all conditionals? Then this branch and all its children are not conditional anymore (unless a new conditional branch is encountered).
            if not [
                p
                for p in node_conditional_parents.get(node.name, [])
                if p not in closed_conditional_parents
            ]:
                if node.name in self.conditional_nodes:
                    self.conditional_nodes.remove(node.name)
                node_conditional_parents[node.name] = []
                for p in node.out_funcs:
                    if p in self.conditional_nodes:
                        self.conditional_nodes.remove(p)
                    node_conditional_parents[p] = []

def _is_conditional_node(self, node):
return node.name in self.conditional_nodes

def _is_conditional_join_node(self, node):
return node.name in self.conditional_join_nodes

def _matching_conditional_join(self, node):
return self.matching_conditional_join_dict.get(node.name, None)

# Visit every node and yield the uber DAGTemplate(s).
def _dag_templates(self):
def _visit(
Expand All @@ -941,19 +1057,15 @@ def _visit(
dag_tasks = []
if templates is None:
templates = []

if exit_node is not None and exit_node is node.name:
return templates, dag_tasks
if node.name == "start":
# Start node has no dependencies.
dag_task = DAGTask(self._sanitize(node.name)).template(
self._sanitize(node.name)
)
if node.type == "split-switch":
raise ArgoWorkflowsException(
"Deploying flows with switch statement "
"to Argo Workflows is not supported currently."
)
elif (
if (
node.is_inside_foreach
and self.graph[node.in_funcs[0]].type == "foreach"
and not self.graph[node.in_funcs[0]].parallel_foreach
Expand Down Expand Up @@ -1044,15 +1156,32 @@ def _visit(
else:
# Every other node needs only input-paths
parameters = [
Parameter("input-paths").value(
compress_list(
[
"argo-{{workflow.name}}/%s/{{tasks.%s.outputs.parameters.task-id}}"
% (n, self._sanitize(n))
for n in node.in_funcs
],
# NOTE: We set zlibmin to infinite because zlib compression for the Argo input-paths breaks template value substitution.
zlibmin=inf,
(
Parameter("input-paths").value(
compress_list(
[
"argo-{{workflow.name}}/%s/{{tasks.%s.outputs.parameters.task-id}}"
% (n, self._sanitize(n))
for n in node.in_funcs
],
# NOTE: We set zlibmin to infinite because zlib compression for the Argo input-paths breaks template value substitution.
zlibmin=inf,
)
)
if not self._is_conditional_join_node(node)
# The value fetching for input-paths from conditional steps has to be quite involved
# in order to avoid issues with replacements due to missing step outputs.
# NOTE: we differentiate the input-path expression only for conditional joins so we can still utilize the list compression,
# but do not have to rework all decompress usage due to the need for a custom separator
else Parameter("input-paths").value(
compress_list(
[
"argo-{{workflow.name}}/%s/{{=(get(tasks['%s']?.outputs?.parameters, 'task-id') ?? 'no-task')}}"
% (n, self._sanitize(n))
for n in node.in_funcs
],
separator="%", # non-default separator is required due to commas in the value expression
)
)
)
]
Expand Down Expand Up @@ -1087,15 +1216,43 @@ def _visit(
]
)

conditional_deps = [
"%s.Succeeded" % self._sanitize(in_func)
for in_func in node.in_funcs
if self._is_conditional_node(self.graph[in_func])
]
required_deps = [
"%s.Succeeded" % self._sanitize(in_func)
for in_func in node.in_funcs
if not self._is_conditional_node(self.graph[in_func])
]
both_conditions = required_deps and conditional_deps

depends_str = "{required}{_and}{conditional}".format(
required=("(%s)" if both_conditions else "%s")
% " && ".join(required_deps),
_and=" && " if both_conditions else "",
conditional=("(%s)" if both_conditions else "%s")
% " || ".join(conditional_deps),
)
dag_task = (
DAGTask(self._sanitize(node.name))
.dependencies(
[self._sanitize(in_func) for in_func in node.in_funcs]
)
.depends(depends_str)
.template(self._sanitize(node.name))
.arguments(Arguments().parameters(parameters))
)

# Add conditional if this is the first step in a conditional branch
if (
self._is_conditional_node(node)
and self.graph[node.in_funcs[0]].type == "split-switch"
):
in_func = node.in_funcs[0]
dag_task.when(
"{{tasks.%s.outputs.parameters.switch-step}}==%s"
% (self._sanitize(in_func), node.name)
)

dag_tasks.append(dag_task)
# End the workflow if we have reached the end of the flow
if node.type == "end":
Expand All @@ -1121,6 +1278,23 @@ def _visit(
dag_tasks,
parent_foreach,
)
elif node.type == "split-switch":
for n in node.out_funcs:
_visit(
self.graph[n],
self._matching_conditional_join(node),
templates,
dag_tasks,
parent_foreach,
)

return _visit(
self.graph[self._matching_conditional_join(node)],
exit_node,
templates,
dag_tasks,
parent_foreach,
)
# For foreach nodes generate a new sub DAGTemplate
# We do this for "regular" foreaches (ie. `self.next(self.a, foreach=)`)
elif node.type == "foreach":
Expand Down Expand Up @@ -1148,7 +1322,7 @@ def _visit(
#
foreach_task = (
DAGTask(foreach_template_name)
.dependencies([self._sanitize(node.name)])
.depends(f"{self._sanitize(node.name)}.Succeeded")
.template(foreach_template_name)
.arguments(
Arguments().parameters(
Expand Down Expand Up @@ -1193,6 +1367,16 @@ def _visit(
% self._sanitize(node.name)
)
)
# Add conditional if this is the first step in a conditional branch
if self._is_conditional_node(node) and not any(
self._is_conditional_node(self.graph[in_func])
for in_func in node.in_funcs
):
in_func = node.in_funcs[0]
foreach_task.when(
"{{tasks.%s.outputs.parameters.switch-step}}==%s"
% (self._sanitize(in_func), node.name)
)
dag_tasks.append(foreach_task)
templates, dag_tasks_1 = _visit(
self.graph[node.out_funcs[0]],
Expand Down Expand Up @@ -1236,7 +1420,22 @@ def _visit(
self.graph[node.matching_join].in_funcs[0]
)
}
)
if not self._is_conditional_join_node(
self.graph[node.matching_join]
)
else
# Note: If the nodes leading to the join are conditional, then we need to use an expression to pick the outputs from the task that executed.
# ref for operators: https://github.com/expr-lang/expr/blob/master/docs/language-definition.md
{
"expression": "get((%s)?.parameters, 'task-id')"
% " ?? ".join(
f"tasks['{self._sanitize(func)}']?.outputs"
for func in self.graph[
node.matching_join
].in_funcs
)
}
),
]
if not node.parallel_foreach
else [
Expand Down Expand Up @@ -1269,7 +1468,7 @@ def _visit(
join_foreach_task = (
DAGTask(self._sanitize(self.graph[node.matching_join].name))
.template(self._sanitize(self.graph[node.matching_join].name))
.dependencies([foreach_template_name])
.depends(f"{foreach_template_name}.Succeeded")
.arguments(
Arguments().parameters(
(
Expand Down Expand Up @@ -1568,10 +1767,25 @@ def _container_templates(self):
]
)
input_paths = "%s/_parameters/%s" % (run_id, task_id_params)
# Only for static joins and conditional_joins
elif self._is_conditional_join_node(node) and not (
node.type == "join"
and self.graph[node.split_parents[-1]].type == "foreach"
):
input_paths = (
"$(python -m metaflow.plugins.argo.conditional_input_paths %s)"
% input_paths
)
elif (
node.type == "join"
and self.graph[node.split_parents[-1]].type == "foreach"
):
# foreach-joins straight out of conditional branches are not yet supported
if self._is_conditional_join_node(node):
raise ArgoWorkflowsException(
"Conditionals steps that transition directly into a join step are not currently supported. "
"As a workaround, you can add a normal step after the conditional steps that transitions to a join step."
)
# Set aggregated input-paths for a for-each join
foreach_step = next(
n for n in node.in_funcs if self.graph[n].is_inside_foreach
Expand Down Expand Up @@ -1814,7 +2028,7 @@ def _container_templates(self):
[Parameter("num-parallel"), Parameter("task-id-entropy")]
)
else:
# append this only for joins of foreaches, not static splits
# append these only for joins of foreaches, not static splits
inputs.append(Parameter("split-cardinality"))
# check if the node is a @parallel node.
elif node.parallel_step:
Expand Down Expand Up @@ -1849,6 +2063,13 @@ def _container_templates(self):
# are derived at runtime.
if not (node.name == "end" or node.parallel_step):
outputs = [Parameter("task-id").valueFrom({"path": "/mnt/out/task_id"})]

# If this step is a split-switch one, we need to output the switch step name
if node.type == "split-switch":
outputs.append(
Parameter("switch-step").valueFrom({"path": "/mnt/out/switch_step"})
)

if node.type == "foreach":
# Emit split cardinality from foreach task
outputs.append(
Expand Down Expand Up @@ -3981,6 +4202,10 @@ def dependencies(self, dependencies):
self.payload["dependencies"] = dependencies
return self

def depends(self, depends: str):
    """Set the Argo 'depends' expression for this task; returns self for chaining."""
    self.payload.update({"depends": depends})
    return self

def template(self, template):
# Template reference
self.payload["template"] = template
Expand All @@ -3992,6 +4217,10 @@ def inline(self, template):
self.payload["inline"] = template.to_json()
return self

def when(self, when: str):
    """Set the Argo 'when' condition gating this task; returns self for chaining."""
    self.payload.update({"when": when})
    return self

def with_param(self, with_param):
self.payload["withParam"] = with_param
return self
Expand Down
Loading
Loading