Skip to content

Commit

Permalink
Add more test for workgraph.py
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 committed Dec 4, 2024
1 parent dd6fc9b commit efa5f3d
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 8 deletions.
2 changes: 2 additions & 0 deletions src/aiida_workgraph/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
from aiida.manage.configuration.settings import AIIDA_CONFIG_FOLDER

WORKGRAPH_EXTRA_KEY = "_workgraph"


def load_config() -> dict:
"""Load the configuration from the config file."""
Expand Down
3 changes: 2 additions & 1 deletion src/aiida_workgraph/engine/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,9 @@ def setup_ctx_workgraph(self, wgdata: t.Dict[str, t.Any]) -> None:
def read_wgdata_from_base(self) -> t.Dict[str, t.Any]:
"""Read workgraph data from base.extras."""
from aiida_workgraph.orm.function_data import PickledLocalFunction
from aiida_workgraph.config import WORKGRAPH_EXTRA_KEY

wgdata = self.node.base.extras.get("_workgraph")
wgdata = self.node.base.extras.get(WORKGRAPH_EXTRA_KEY)
for name, task in wgdata["tasks"].items():
wgdata["tasks"][name] = deserialize_unsafe(task)
for _, input in wgdata["tasks"][name]["inputs"].items():
Expand Down
3 changes: 2 additions & 1 deletion src/aiida_workgraph/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,10 +322,11 @@ def get_workgraph_data(process: Union[int, orm.Node]) -> Optional[Dict[str, Any]
"""Get the workgraph data from the process node."""
from aiida.orm.utils.serialize import deserialize_unsafe
from aiida.orm import load_node
from aiida_workgraph.config import WORKGRAPH_EXTRA_KEY

if isinstance(process, int):
process = load_node(process)
wgdata = process.base.extras.get("_workgraph", None)
wgdata = process.base.extras.get(WORKGRAPH_EXTRA_KEY, None)
if wgdata is None:
return
for name, task in wgdata["tasks"].items():
Expand Down
7 changes: 4 additions & 3 deletions src/aiida_workgraph/utils/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# import datetime
from aiida.orm import ProcessNode
from aiida.orm.utils.serialize import serialize, deserialize_unsafe
from aiida_workgraph.config import WORKGRAPH_EXTRA_KEY


class WorkGraphSaver:
Expand Down Expand Up @@ -223,7 +224,7 @@ def insert_workgraph_to_db(self) -> None:
# nodes is a copy of tasks, so we need to pop it out
self.wgdata["error_handlers"] = serialize(self.wgdata["error_handlers"])
self.wgdata["context"] = serialize(self.wgdata["context"])
self.process.base.extras.set("_workgraph", self.wgdata)
self.process.base.extras.set(WORKGRAPH_EXTRA_KEY, self.wgdata)

def save_task_states(self) -> Dict:
"""Get task states."""
Expand Down Expand Up @@ -277,7 +278,7 @@ def get_wgdata_from_db(
) -> Optional[Dict]:

process = self.process if process is None else process
wgdata = process.base.extras.get("_workgraph", None)
wgdata = process.base.extras.get(WORKGRAPH_EXTRA_KEY, None)
if wgdata is None:
print("No workgraph data found in the process node.")
return
Expand Down Expand Up @@ -318,7 +319,7 @@ def exist_in_db(self) -> bool:
Returns:
bool: _description_
"""
if self.process.base.extras.get("_workgraph", None) is not None:
if self.process.base.extras.get(WORKGRAPH_EXTRA_KEY, None) is not None:
return True
return False

Expand Down
3 changes: 1 addition & 2 deletions src/aiida_workgraph/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,7 @@ def load(cls, pk: int) -> Optional["WorkGraph"]:
process = aiida.orm.load_node(pk)
wgdata = get_workgraph_data(process)
if wgdata is None:
print("No workgraph data found in the process node.")
return
raise ValueError(f"WorkGraph data not found for process {pk}.")
wg = cls.from_dict(wgdata)
wg.process = process
wg.update()
Expand Down
1 change: 1 addition & 0 deletions tests/test_error_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def handle_negative_sum(task: Task):
}
},
)
assert len(wg.error_handlers) == 1
wg.submit(
inputs={
"add1": {"code": add_code, "x": orm.Int(1), "y": orm.Int(-2)},
Expand Down
33 changes: 32 additions & 1 deletion tests/test_workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,29 @@ def test_add_task():
assert len(wg.links) == 1


def test_show_state(wg_calcfunction):
from io import StringIO
import sys

# Redirect stdout to capture prints
captured_output = StringIO()
sys.stdout = captured_output
# Call the method
wg_calcfunction.name = "test_show_state"
wg_calcfunction.show()
# Reset stdout
sys.stdout = sys.__stdout__
# Check the output
output = captured_output.getvalue()
assert "WorkGraph: test_show_state, PK: None, State: CREATED" in output
assert "sumdiff1" in output
assert "PLANNED" in output


def test_save_load(wg_calcfunction):
"""Save the workgraph"""
from aiida_workgraph.config import WORKGRAPH_EXTRA_KEY

wg = wg_calcfunction
wg.name = "test_save_load"
wg.save()
Expand All @@ -38,6 +59,12 @@ def test_save_load(wg_calcfunction):
assert wg.process.label == "test_save_load"
wg2 = WorkGraph.load(wg.process.pk)
assert len(wg.tasks) == len(wg2.tasks)
# remove the extra
wg.process.base.extras.delete(WORKGRAPH_EXTRA_KEY)
with pytest.raises(
ValueError, match=f"WorkGraph data not found for process {wg.process.pk}."
):
WorkGraph.load(wg.process.pk)


def test_organize_nested_inputs():
Expand Down Expand Up @@ -86,7 +113,7 @@ def test_reset_message(wg_calcjob):
assert "Action: reset. {'add2'}" in report


def test_restart(wg_calcfunction):
def test_restart_and_reset(wg_calcfunction):
"""Restart from a finished workgraph.
Load the workgraph, modify the task, and restart the workgraph.
Only the modified node and its child tasks will be rerun."""
Expand All @@ -109,6 +136,10 @@ def test_restart(wg_calcfunction):
assert wg1.tasks["sumdiff2"].node.pk != wg.tasks["sumdiff2"].pk
assert wg1.tasks["sumdiff3"].node.pk != wg.tasks["sumdiff3"].pk
assert wg1.tasks["sumdiff3"].node.outputs.sum == 19
wg1.reset()
assert wg1.process is None
assert wg1.tasks["sumdiff3"].process is None
assert wg1.tasks["sumdiff3"].state == "PLANNED"


def test_extend_workgraph(decorated_add_multiply_group):
Expand Down
4 changes: 4 additions & 0 deletions tests/widget/test_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ def test_workgraph_widget(wg_calcfunction):
# to_html
data = wg.to_html()
assert isinstance(data, IFrame)
# check _repr_mimebundle_ is working
data = wg._repr_mimebundle_()


def test_workgraph_task(wg_calcfunction):
Expand All @@ -26,3 +28,5 @@ def test_workgraph_task(wg_calcfunction):
# to html
data = wg.tasks["sumdiff2"].to_html()
assert isinstance(data, IFrame)
# check _repr_mimebundle_ is working
data = wg._repr_mimebundle_()

0 comments on commit efa5f3d

Please sign in to comment.