-
Notifications
You must be signed in to change notification settings - Fork 5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix recursive workgraph #336
base: main
Are you sure you want to change the base?
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #336 +/- ##
==========================================
+ Coverage 75.75% 80.61% +4.86%
==========================================
Files 70 66 -4
Lines 4615 5139 +524
==========================================
+ Hits 3496 4143 +647
+ Misses 1119 996 -123
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. |
@@ -1032,6 +1032,13 @@ def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None | |||
|
|||
self.report(f"Run task: {name}, type: {task['metadata']['node_type']}") | |||
executor, _ = get_executor(task["executor"]) | |||
# Add the executor to the globals so that it can be used in the task | |||
# in the case of recursive workgraph | |||
# We also need to rebuild the Task calss and attach it to the executor |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# We also need to rebuild the Task calss and attach it to the executor | |
# We also need to rebuild the Task calls and attach it to the executor |
@@ -1032,6 +1032,13 @@ def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None | |||
|
|||
self.report(f"Run task: {name}, type: {task['metadata']['node_type']}") | |||
executor, _ = get_executor(task["executor"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What kind of task is this? When I create add a task in WorkGraph there is no task["executor"]
but I can do task.get_executor()
if task["metadata"]["node_type"].upper() == "GRAPH_BUILDER": | ||
task_class = Task.from_dict(self.ctx._tasks[name]) | ||
executor.node = executor.task = task_class.__class__ | ||
executor.__globals__[executor.__name__] = executor |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My concern is that executors will have the same name as it is not always unique and then override each other. For calcfunctions for example it is the name of the function which is reasonable, as it is only overridden if a function with the same name is defined. But for the case of local functions, this overrides any global definition.
from aiida_workgraph import task, WorkGraph
from aiida.engine import calcfunction
from aiida import load_profile
load_profile()
@task.graph_builder()
def my_add():
@calcfunction
def add(x, y):
return x+y
wg = WorkGraph()
task = wg.add_task(add, x=1, y=1)
return wg
wg = my_add()
print(wg.tasks["add1"].get_executor()['name']) # out 'add' but better 'my_add.add'
I guess solving this issue requires much more work and since we don't have any examples defining calcfunctions locally (however we have to load codes locally in graph_builder), I think it is not very crucial, but an issue would be nice to keep this in mind.
Fix #333 .
There seems to be a bug in
cloudpickle
, that the decorated function is not in theglobals
.This PR provides a temporary solution.
There is also a bug related to
group_outputs
, which is fixed in #335 .