Skip to content

Commit

Permalink
Add more test for workgraph, task, and property (#378)
Browse files Browse the repository at this point in the history
* Add more tests for workgraph.py
* add set_from_builder for task
* Add orm.List ad orm.Dict socket and into type_mapping
* Update test for vector property
  • Loading branch information
superstar54 authored Dec 4, 2024
1 parent dd6fc9b commit b182af0
Show file tree
Hide file tree
Showing 18 changed files with 150 additions and 16 deletions.
16 changes: 16 additions & 0 deletions docs/gallery/autogen/quick_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,22 @@ def multiply(x, y):
generate_node_graph(wg.pk)


######################################################################
# One can also set task inputs from an AiiDA process builder directly.
#

from aiida.calculations.arithmetic.add import ArithmeticAddCalculation

builder = ArithmeticAddCalculation.get_builder()
builder.code = code
builder.x = Int(2)
builder.y = Int(3)

wg = WorkGraph("test_set_inputs_from_builder")
add1 = wg.add_task(ArithmeticAddCalculation, name="add1")
add1.set_from_builder(builder)


######################################################################
# Graph builder
# -------------
Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ workgraph = "aiida_workgraph.cli.cmd_workgraph:workgraph"
"workgraph.aiida_bool" = "aiida_workgraph.properties.builtins:PropertyAiiDABool"
"workgraph.aiida_int_vector" = "aiida_workgraph.properties.builtins:PropertyAiiDAIntVector"
"workgraph.aiida_float_vector" = "aiida_workgraph.properties.builtins:PropertyAiiDAFloatVector"
"workgraph.aiida_aiida_dict" = "aiida_workgraph.properties.builtins:PropertyAiiDADict"
"workgraph.aiida_list" = "aiida_workgraph.properties.builtins:PropertyAiiDAList"
"workgraph.aiida_dict" = "aiida_workgraph.properties.builtins:PropertyAiiDADict"
"workgraph.aiida_structuredata" = "aiida_workgraph.properties.builtins:PropertyStructureData"

[project.entry-points."aiida_workgraph.socket"]
Expand All @@ -138,6 +139,8 @@ workgraph = "aiida_workgraph.cli.cmd_workgraph:workgraph"
"workgraph.aiida_bool" = "aiida_workgraph.sockets.builtins:SocketAiiDABool"
"workgraph.aiida_int_vector" = "aiida_workgraph.sockets.builtins:SocketAiiDAIntVector"
"workgraph.aiida_float_vector" = "aiida_workgraph.sockets.builtins:SocketAiiDAFloatVector"
"workgraph.aiida_list" = "aiida_workgraph.sockets.builtins:SocketAiiDAList"
"workgraph.aiida_dict" = "aiida_workgraph.sockets.builtins:SocketAiiDADict"
"workgraph.aiida_structuredata" = "aiida_workgraph.sockets.builtins:SocketStructureData"


Expand Down
2 changes: 2 additions & 0 deletions src/aiida_workgraph/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
from aiida.manage.configuration.settings import AIIDA_CONFIG_FOLDER

WORKGRAPH_EXTRA_KEY = "_workgraph"


def load_config() -> dict:
"""Load the configuration from the config file."""
Expand Down
2 changes: 2 additions & 0 deletions src/aiida_workgraph/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
orm.Float: "workgraph.aiida_float",
orm.Str: "workgraph.aiida_string",
orm.Bool: "workgraph.aiida_bool",
orm.List: "workgraph.aiida_list",
orm.Dict: "workgraph.aiida_dict",
orm.StructureData: "workgraph.aiida_structuredata",
}

Expand Down
3 changes: 2 additions & 1 deletion src/aiida_workgraph/engine/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,9 @@ def setup_ctx_workgraph(self, wgdata: t.Dict[str, t.Any]) -> None:
def read_wgdata_from_base(self) -> t.Dict[str, t.Any]:
"""Read workgraph data from base.extras."""
from aiida_workgraph.orm.function_data import PickledLocalFunction
from aiida_workgraph.config import WORKGRAPH_EXTRA_KEY

wgdata = self.node.base.extras.get("_workgraph")
wgdata = self.node.base.extras.get(WORKGRAPH_EXTRA_KEY)
for name, task in wgdata["tasks"].items():
wgdata["tasks"][name] = deserialize_unsafe(task)
for _, input in wgdata["tasks"][name]["inputs"].items():
Expand Down
12 changes: 12 additions & 0 deletions src/aiida_workgraph/properties/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,18 @@ def validate(self, value: any) -> None:
)


class PropertyAiiDAList(TaskProperty):
"""A new class for List type."""

identifier: str = "workgraph.aiida_list"
allowed_types = (list, orm.List, str, type(None))

def set_value(self, value: Union[list, orm.List, str] = None) -> None:
if isinstance(value, (list)):
value = orm.List(list=value)
super().set_value(value)


class PropertyAiiDADict(TaskProperty):
"""A new class for Dict type."""

Expand Down
14 changes: 14 additions & 0 deletions src/aiida_workgraph/sockets/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,20 @@ class SocketAiiDABool(TaskSocket):
property_identifier: str = "workgraph.aiida_bool"


class SocketAiiDAList(TaskSocket):
"""AiiDAList socket."""

identifier: str = "workgraph.aiida_list"
property_identifier: str = "workgraph.aiida_list"


class SocketAiiDADict(TaskSocket):
"""AiiDADict socket."""

identifier: str = "workgraph.aiida_dict"
property_identifier: str = "workgraph.aiida_dict"


class SocketAiiDAIntVector(TaskSocket):
"""Socket with a AiiDAIntVector property."""

Expand Down
17 changes: 14 additions & 3 deletions src/aiida_workgraph/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,25 @@ def set_context(self, context: Dict[str, Any]) -> None:
raise ValueError(msg)
self.context_mapping.update(context)

def set_from_builder(self, builder: Any) -> None:
"""Set the task inputs from a AiiDA ProcessBuilder."""
from aiida_workgraph.utils import get_dict_from_builder

data = get_dict_from_builder(builder)
self.set(data)

def set_from_protocol(self, *args: Any, **kwargs: Any) -> None:
"""Set the task inputs from protocol data."""
from aiida_workgraph.utils import get_executor, get_dict_from_builder
from aiida_workgraph.utils import get_executor

executor = get_executor(self.get_executor())[0]
# check if the executor has the get_builder_from_protocol method
if not hasattr(executor, "get_builder_from_protocol"):
raise AttributeError(
f"Executor {executor.__name__} does not have the get_builder_from_protocol method."
)
builder = executor.get_builder_from_protocol(*args, **kwargs)
data = get_dict_from_builder(builder)
self.set(data)
self.set_from_builder(builder)

@classmethod
def new(
Expand Down
3 changes: 2 additions & 1 deletion src/aiida_workgraph/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,10 +322,11 @@ def get_workgraph_data(process: Union[int, orm.Node]) -> Optional[Dict[str, Any]
"""Get the workgraph data from the process node."""
from aiida.orm.utils.serialize import deserialize_unsafe
from aiida.orm import load_node
from aiida_workgraph.config import WORKGRAPH_EXTRA_KEY

if isinstance(process, int):
process = load_node(process)
wgdata = process.base.extras.get("_workgraph", None)
wgdata = process.base.extras.get(WORKGRAPH_EXTRA_KEY, None)
if wgdata is None:
return
for name, task in wgdata["tasks"].items():
Expand Down
7 changes: 4 additions & 3 deletions src/aiida_workgraph/utils/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# import datetime
from aiida.orm import ProcessNode
from aiida.orm.utils.serialize import serialize, deserialize_unsafe
from aiida_workgraph.config import WORKGRAPH_EXTRA_KEY


class WorkGraphSaver:
Expand Down Expand Up @@ -223,7 +224,7 @@ def insert_workgraph_to_db(self) -> None:
# nodes is a copy of tasks, so we need to pop it out
self.wgdata["error_handlers"] = serialize(self.wgdata["error_handlers"])
self.wgdata["context"] = serialize(self.wgdata["context"])
self.process.base.extras.set("_workgraph", self.wgdata)
self.process.base.extras.set(WORKGRAPH_EXTRA_KEY, self.wgdata)

def save_task_states(self) -> Dict:
"""Get task states."""
Expand Down Expand Up @@ -277,7 +278,7 @@ def get_wgdata_from_db(
) -> Optional[Dict]:

process = self.process if process is None else process
wgdata = process.base.extras.get("_workgraph", None)
wgdata = process.base.extras.get(WORKGRAPH_EXTRA_KEY, None)
if wgdata is None:
print("No workgraph data found in the process node.")
return
Expand Down Expand Up @@ -318,7 +319,7 @@ def exist_in_db(self) -> bool:
Returns:
bool: _description_
"""
if self.process.base.extras.get("_workgraph", None) is not None:
if self.process.base.extras.get(WORKGRAPH_EXTRA_KEY, None) is not None:
return True
return False

Expand Down
3 changes: 1 addition & 2 deletions src/aiida_workgraph/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,7 @@ def load(cls, pk: int) -> Optional["WorkGraph"]:
process = aiida.orm.load_node(pk)
wgdata = get_workgraph_data(process)
if wgdata is None:
print("No workgraph data found in the process node.")
return
raise ValueError(f"WorkGraph data not found for process {pk}.")
wg = cls.from_dict(wgdata)
wg.process = process
wg.update()
Expand Down
5 changes: 4 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ def add_code(fixture_localhost):
from aiida.orm import InstalledCode

code = InstalledCode(
label="add", computer=fixture_localhost, filepath_executable="/bin/bash"
label="add",
computer=fixture_localhost,
filepath_executable="/bin/bash",
default_calc_job_plugin="arithmetic.add",
)
code.store()
return code
Expand Down
6 changes: 6 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
def test_load_config():
from aiida_workgraph.config import load_config

config = load_config()
assert isinstance(config, dict)
assert config == {}
1 change: 1 addition & 0 deletions tests/test_error_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def handle_negative_sum(task: Task):
}
},
)
assert len(wg.error_handlers) == 1
wg.submit(
inputs={
"add1": {"code": add_code, "x": orm.Int(1), "y": orm.Int(-2)},
Expand Down
12 changes: 9 additions & 3 deletions tests/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
(orm.Str, "abc", "workgraph.aiida_string"),
(orm.Bool, True, "workgraph.aiida_bool"),
(orm.Bool, "{{variable}}", "workgraph.aiida_bool"),
(orm.List, [1, 2, 3], "workgraph.aiida_list"),
(orm.Dict, {"a": 1}, "workgraph.aiida_dict"),
),
)
def test_type_mapping(data_type, data, identifier) -> None:
Expand Down Expand Up @@ -46,10 +48,14 @@ def test_vector_socket() -> None:
"vector2d",
property_data={"size": 2, "default": [1, 2]},
)
try:
assert t.inputs["vector2d"].property.get_metadata() == {
"size": 2,
"default": [1, 2],
}
with pytest.raises(ValueError, match="Invalid size: Expected 2, got 3 instead."):
t.inputs["vector2d"].value = [1, 2, 3]
except Exception as e:
assert "Invalid size: Expected 2, got 3 instead." in str(e)
with pytest.raises(ValueError, match="Invalid item type: Expected "):
t.inputs["vector2d"].value = [1.1, 2.2]


def test_aiida_data_socket() -> None:
Expand Down
21 changes: 21 additions & 0 deletions tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,24 @@ def test_set_inputs(decorated_add: Callable) -> None:
]
is False
)


def test_set_inputs_from_builder(add_code) -> None:
"""Test setting inputs of a task from a builder function."""
from aiida.calculations.arithmetic.add import ArithmeticAddCalculation

wg = WorkGraph(name="test_set_inputs_from_builder")
add1 = wg.add_task(ArithmeticAddCalculation, "add1")
# create the builder
builder = add_code.get_builder()
builder.x = 1
builder.y = 2
add1.set_from_builder(builder)
assert add1.inputs["x"].value == 1
assert add1.inputs["y"].value == 2
assert add1.inputs["code"].value == add_code
with pytest.raises(
AttributeError,
match=f"Executor {ArithmeticAddCalculation.__name__} does not have the get_builder_from_protocol method.",
):
add1.set_from_protocol(code=add_code, protocol="fast")
33 changes: 32 additions & 1 deletion tests/test_workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,29 @@ def test_add_task():
assert len(wg.links) == 1


def test_show_state(wg_calcfunction):
from io import StringIO
import sys

# Redirect stdout to capture prints
captured_output = StringIO()
sys.stdout = captured_output
# Call the method
wg_calcfunction.name = "test_show_state"
wg_calcfunction.show()
# Reset stdout
sys.stdout = sys.__stdout__
# Check the output
output = captured_output.getvalue()
assert "WorkGraph: test_show_state, PK: None, State: CREATED" in output
assert "sumdiff1" in output
assert "PLANNED" in output


def test_save_load(wg_calcfunction):
"""Save the workgraph"""
from aiida_workgraph.config import WORKGRAPH_EXTRA_KEY

wg = wg_calcfunction
wg.name = "test_save_load"
wg.save()
Expand All @@ -38,6 +59,12 @@ def test_save_load(wg_calcfunction):
assert wg.process.label == "test_save_load"
wg2 = WorkGraph.load(wg.process.pk)
assert len(wg.tasks) == len(wg2.tasks)
# remove the extra
wg.process.base.extras.delete(WORKGRAPH_EXTRA_KEY)
with pytest.raises(
ValueError, match=f"WorkGraph data not found for process {wg.process.pk}."
):
WorkGraph.load(wg.process.pk)


def test_organize_nested_inputs():
Expand Down Expand Up @@ -86,7 +113,7 @@ def test_reset_message(wg_calcjob):
assert "Action: reset. {'add2'}" in report


def test_restart(wg_calcfunction):
def test_restart_and_reset(wg_calcfunction):
"""Restart from a finished workgraph.
Load the workgraph, modify the task, and restart the workgraph.
Only the modified node and its child tasks will be rerun."""
Expand All @@ -109,6 +136,10 @@ def test_restart(wg_calcfunction):
assert wg1.tasks["sumdiff2"].node.pk != wg.tasks["sumdiff2"].pk
assert wg1.tasks["sumdiff3"].node.pk != wg.tasks["sumdiff3"].pk
assert wg1.tasks["sumdiff3"].node.outputs.sum == 19
wg1.reset()
assert wg1.process is None
assert wg1.tasks["sumdiff3"].process is None
assert wg1.tasks["sumdiff3"].state == "PLANNED"


def test_extend_workgraph(decorated_add_multiply_group):
Expand Down
4 changes: 4 additions & 0 deletions tests/widget/test_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ def test_workgraph_widget(wg_calcfunction):
# to_html
data = wg.to_html()
assert isinstance(data, IFrame)
# check _repr_mimebundle_ is working
data = wg._repr_mimebundle_()


def test_workgraph_task(wg_calcfunction):
Expand All @@ -26,3 +28,5 @@ def test_workgraph_task(wg_calcfunction):
# to html
data = wg.tasks["sumdiff2"].to_html()
assert isinstance(data, IFrame)
# check _repr_mimebundle_ is working
data = wg.tasks["sumdiff2"]._repr_mimebundle_()

0 comments on commit b182af0

Please sign in to comment.