Skip to content

Commit

Permalink
Syntactic sugar to add a task to the zone
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 committed Dec 13, 2024
1 parent 8089a67 commit 73567f7
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 13 deletions.
6 changes: 6 additions & 0 deletions src/aiida_workgraph/tasks/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.children = TaskCollection(parent=self)

def add_task(self, *args, **kwargs):
"""Syntactic sugar to add a task to the zone."""
task = self.parent.add_task(*args, **kwargs)
self.children.add(task)
return task

def create_sockets(self) -> None:
self.inputs._clear()
self.outputs._clear()
Expand Down
5 changes: 2 additions & 3 deletions tests/test_if.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@ def test_if_task(decorated_add, decorated_multiply, decorated_compare):
wg = WorkGraph("test_if")
add1 = wg.add_task(decorated_add, name="add1", x=1, y=1)
condition1 = wg.add_task(decorated_compare, name="condition1", x=1, y=0)
add2 = wg.add_task(decorated_add, name="add2", x=add1.outputs.result, y=2)
if1 = wg.add_task("If", name="if_true", conditions=condition1.outputs.result)
if1.children.add("add2")
if_zone = wg.add_task("If", name="if_true", conditions=condition1.outputs.result)
add2 = if_zone.add_task(decorated_add, name="add2", x=add1.outputs.result, y=2)
multiply1 = wg.add_task(
decorated_multiply, name="multiply1", x=add1.outputs.result, y=2
)
Expand Down
16 changes: 6 additions & 10 deletions tests/test_zone.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,13 @@ def test_zone_task(decorated_add):
"""Test the zone task."""

wg = WorkGraph("test_zone")
wg.context = {}
add1 = wg.add_task(decorated_add, name="add1", x=1, y=1)
wg.add_task(decorated_add, name="add2", x=1, y=1)
add3 = wg.add_task(decorated_add, name="add3", x=1, y=add1.outputs.result)
wg.add_task(decorated_add, name="add4", x=1, y=add3.outputs.result)
wg.add_task(decorated_add, name="add5", x=1, y=add3.outputs.result)
zone1 = wg.add_task("workgraph.zone", name="Zone1")
zone1.children.add(["add2", "add3"])
zone1 = wg.add_task("workgraph.zone", name="zone1")
zone1.add_task(decorated_add, name="add2", x=1, y=1)
zone1.add_task(decorated_add, name="add3", x=1, y=add1.outputs.result)
wg.add_task(decorated_add, name="add4", x=1, y=wg.tasks.add2.outputs.result)
wg.add_task(decorated_add, name="add5", x=1, y=wg.tasks.add3.outputs.result)
wg.run()
report = get_workchain_report(wg.process, "REPORT")
assert "tasks ready to run: add1" in report
assert "tasks ready to run: add2,add3" in report
assert "tasks ready to run: add4" in report
assert "tasks ready to run: add5" in report
assert "tasks ready to run: add4,add5" in report

0 comments on commit 73567f7

Please sign in to comment.