Skip to content

Commit

Permalink
add builtin_inputs and outputs in config
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 committed Dec 5, 2024
1 parent d44c9e4 commit f9e5344
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 84 deletions.
5 changes: 5 additions & 0 deletions src/aiida_workgraph/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
from aiida.manage.configuration.settings import AIIDA_CONFIG_FOLDER

WORKGRAPH_EXTRA_KEY = "_workgraph"
WORKGRAPH_SHORT_EXTRA_KEY = "_workgraph_short"


builtin_inputs = [{"name": "_wait", "link_limit": 1e6, "arg_type": "none"}]
builtin_outputs = [{"name": "_wait"}, {"name": "_outputs"}]


def load_config() -> dict:
Expand Down
53 changes: 17 additions & 36 deletions src/aiida_workgraph/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from aiida_workgraph.task import Task
from aiida_workgraph.utils import build_callable, validate_task_inout
import inspect
from aiida_workgraph.config import builtin_inputs, builtin_outputs

task_types = {
CalcFunctionNode: "CALCFUNCTION",
Expand Down Expand Up @@ -265,16 +266,10 @@ def build_task_from_AiiDA(
else outputs
)
# add built-in sockets
outputs.append({"identifier": "workgraph.any", "name": "_outputs"})
outputs.append({"identifier": "workgraph.any", "name": "_wait"})
inputs.append(
{
"identifier": "workgraph.any",
"name": "_wait",
"link_limit": 1e6,
"arg_type": "none",
}
)
for output in builtin_outputs:
outputs.append(output.copy())
for input in builtin_inputs:
inputs.append(input.copy())
tdata["metadata"]["node_class"] = {"module": "aiida_workgraph.task", "name": "Task"}
tdata["inputs"] = inputs
tdata["outputs"] = outputs
Expand All @@ -301,9 +296,6 @@ def build_pythonjob_task(func: Callable) -> Task:
}
_, tdata_py = build_task_from_AiiDA(tdata)
tdata = deepcopy(func.tdata)
function_inputs = [
name for name in tdata["inputs"] if name not in ["_wait", "_outputs"]
]
# merge the inputs and outputs from the PythonJob task to the function task
# skip the already existed inputs and outputs
for input in [
Expand Down Expand Up @@ -333,7 +325,6 @@ def build_pythonjob_task(func: Callable) -> Task:
}
task = create_task(tdata)
task.is_aiida_component = True
task.function_inputs = function_inputs
return task, tdata


Expand Down Expand Up @@ -390,6 +381,8 @@ def build_task_from_workgraph(wg: any) -> Task:
outputs = []
group_outputs = []
# add all the inputs/outputs from the tasks in the workgraph
builtin_input_names = [input["name"] for input in builtin_inputs]
builtin_output_names = [output["name"] for output in builtin_outputs]
for task in wg.tasks:
# inputs
inputs.append(
Expand All @@ -399,7 +392,7 @@ def build_task_from_workgraph(wg: any) -> Task:
}
)
for socket in task.inputs:
if socket.name == "_wait":
if socket.name in builtin_input_names:
continue
inputs.append(
{"identifier": socket.identifier, "name": f"{task.name}.{socket.name}"}
Expand All @@ -412,7 +405,7 @@ def build_task_from_workgraph(wg: any) -> Task:
}
)
for socket in task.outputs:
if socket.name in ["_wait", "_outputs"]:
if socket.name in builtin_output_names:
continue
outputs.append(
{"identifier": socket.identifier, "name": f"{task.name}.{socket.name}"}
Expand All @@ -424,16 +417,10 @@ def build_task_from_workgraph(wg: any) -> Task:
}
)
# add built-in sockets
outputs.append({"identifier": "workgraph.any", "name": "_outputs"})
outputs.append({"identifier": "workgraph.any", "name": "_wait"})
inputs.append(
{
"identifier": "workgraph.any",
"name": "_wait",
"link_limit": 1e6,
"arg_type": "none",
}
)
for output in builtin_outputs:
outputs.append(output.copy())
for input in builtin_inputs:
inputs.append(input.copy())
tdata["metadata"]["node_class"] = {"module": "aiida_workgraph.task", "name": "Task"}
tdata["inputs"] = inputs
tdata["outputs"] = outputs
Expand Down Expand Up @@ -507,16 +494,10 @@ def generate_tdata(
input["metadata"]["is_function_input"] = True
task_outputs = outputs
# add built-in sockets
task_inputs.append(
{
"identifier": "workgraph.any",
"name": "_wait",
"link_limit": 1e6,
"arg_type": "none",
}
)
task_outputs.append({"identifier": "workgraph.any", "name": "_wait"})
task_outputs.append({"identifier": "workgraph.any", "name": "_outputs"})
for output in builtin_outputs:
task_outputs.append(output.copy())
for input in builtin_inputs:
task_inputs.append(input.copy())
tdata = {
"identifier": identifier,
"metadata": {
Expand Down
27 changes: 7 additions & 20 deletions src/aiida_workgraph/tasks/pythonjob.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any, Dict
from aiida import orm
from aiida_pythonjob.data.serializer import general_serializer
from aiida_workgraph.task import Task
Expand All @@ -9,28 +9,18 @@ class PythonJob(Task):

identifier = "workgraph.pythonjob"

function_inputs: List = None

def update_from_dict(self, data: Dict[str, Any], **kwargs) -> "PythonJob":
"""Overwrite the update_from_dict method to handle the PythonJob data."""
self.function_inputs = data.get("function_inputs", [])
self.deserialize_pythonjob_data(data)
super().update_from_dict(data)

def to_dict(self, short: bool = False) -> Dict[str, Any]:
data = super().to_dict(short=short)
data["function_inputs"] = self.function_inputs
return data

@classmethod
def serialize_pythonjob_data(cls, tdata: Dict[str, Any]):
"""Serialize the properties for PythonJob."""

input_kwargs = tdata.get("function_inputs", [])
for name in input_kwargs:
tdata["inputs"][name]["property"]["value"] = cls.serialize_socket_data(
tdata["inputs"][name]
)
for input in tdata["inputs"].values():
if input["metadata"].get("is_function_input", False):
input["property"]["value"] = cls.serialize_socket_data(input)

@classmethod
def deserialize_pythonjob_data(cls, tdata: Dict[str, Any]) -> None:
Expand All @@ -45,13 +35,10 @@ def deserialize_pythonjob_data(cls, tdata: Dict[str, Any]) -> None:
Returns:
Dict[str, Any]: The processed data dictionary.
"""
input_kwargs = tdata.get("function_inputs", [])

for name in input_kwargs:
if name in tdata["inputs"]:
tdata["inputs"][name]["property"][
"value"
] = cls.deserialize_socket_data(tdata["inputs"][name])
for input in tdata["inputs"].values():
if input["metadata"].get("is_function_input", False):
input["property"]["value"] = cls.deserialize_socket_data(input)

@classmethod
def serialize_socket_data(cls, data: Dict[str, Any]) -> Any:
Expand Down
3 changes: 3 additions & 0 deletions src/aiida_workgraph/tasks/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def create_sockets(self) -> None:
self.inputs.new("workgraph.any", "_wait", arg_type="none", link_limit=100000)
self.outputs.new("workgraph.aiida_float", "sum")
self.outputs.new("workgraph.any", "_wait")
self.outputs.new("workgraph.any", "_outputs")


class TestSumDiff(Task):
Expand Down Expand Up @@ -54,6 +55,7 @@ def create_sockets(self) -> None:
self.outputs.new("workgraph.aiida_float", "sum")
self.outputs.new("workgraph.aiida_float", "diff")
self.outputs.new("workgraph.any", "_wait")
self.outputs.new("workgraph.any", "_outputs")


class TestArithmeticMultiplyAdd(Task):
Expand Down Expand Up @@ -84,3 +86,4 @@ def create_sockets(self) -> None:
self.inputs.new("workgraph.any", "_wait", arg_type="none", link_limit=100000)
self.outputs.new("workgraph.aiida_int", "result")
self.outputs.new("workgraph.any", "_wait")
self.outputs.new("workgraph.any", "_outputs")
4 changes: 2 additions & 2 deletions src/aiida_workgraph/utils/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# import datetime
from aiida.orm import ProcessNode
from aiida.orm.utils.serialize import serialize, deserialize_unsafe
from aiida_workgraph.config import WORKGRAPH_EXTRA_KEY
from aiida_workgraph.config import WORKGRAPH_EXTRA_KEY, WORKGRAPH_SHORT_EXTRA_KEY


class WorkGraphSaver:
Expand Down Expand Up @@ -211,7 +211,7 @@ def insert_workgraph_to_db(self) -> None:
# self.wgdata["created"] = datetime.datetime.utcnow()
# self.wgdata["lastUpdate"] = datetime.datetime.utcnow()
short_wgdata = workgraph_to_short_json(self.wgdata)
self.process.base.extras.set("_workgraph_short", short_wgdata)
self.process.base.extras.set(WORKGRAPH_SHORT_EXTRA_KEY, short_wgdata)
self.save_task_states()
for name, task in self.wgdata["tasks"].items():
for _, input in task["inputs"].items():
Expand Down
4 changes: 2 additions & 2 deletions tests/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_decorators_calcfunction_args(task_calcfunction) -> None:
assert set(tdata["kwargs"]) == set(kwargs)
assert tdata["var_args"] is None
assert tdata["var_kwargs"] == "c"
assert n.outputs.keys() == ["result", "_outputs", "_wait"]
assert set(n.outputs.keys()) == set(["result", "_outputs", "_wait"])


@pytest.fixture(params=["decorator_factory", "decorator"])
Expand Down Expand Up @@ -123,7 +123,7 @@ def test_decorators_workfunction_args(task_workfunction) -> None:
assert set(tdata["kwargs"]) == set(kwargs)
assert tdata["var_args"] is None
assert tdata["var_kwargs"] == "c"
assert n.outputs.keys() == ["result", "_outputs", "_wait"]
assert set(n.outputs.keys()) == set(["result", "_outputs", "_wait"])


def test_decorators_parameters() -> None:
Expand Down
44 changes: 44 additions & 0 deletions tests/test_task_from_workgraph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import pytest
from aiida_workgraph import WorkGraph
from typing import Callable


def test_inputs_outptus(wg_calcfunction: WorkGraph) -> None:
"""Test the inputs and outputs of the WorkGraph."""
wg = WorkGraph(name="test_inputs_outptus")
task1 = wg.add_task(wg_calcfunction, name="add1")
ninput = 0
for sub_task in wg_calcfunction.tasks:
# remove _wait, but add the namespace
ninput += len(sub_task.inputs) - 1 + 1
noutput = 0
for sub_task in wg_calcfunction.tasks:
noutput += len(sub_task.outputs) - 2 + 1
assert len(task1.inputs) == ninput + 1
assert len(task1.outputs) == noutput + 2
assert "sumdiff1.x" in task1.inputs.keys()
assert "sumdiff1.sum" in task1.outputs.keys()


@pytest.mark.usefixtures("started_daemon_client")
def test_build_task_from_workgraph(decorated_add: Callable) -> None:
# create a sub workgraph
sub_wg = WorkGraph("build_task_from_workgraph")
sub_wg.add_task(decorated_add, name="add1", x=1, y=3)
sub_wg.add_task(
decorated_add, name="add2", x=2, y=sub_wg.tasks["add1"].outputs["result"]
)
#
wg = WorkGraph("build_task_from_workgraph")
add1_task = wg.add_task(decorated_add, name="add1", x=1, y=3)
wg_task = wg.add_task(sub_wg, name="sub_wg")
# the default value of the namespace is None
assert wg_task.inputs["add1"].value is None
wg.add_task(decorated_add, name="add2", y=3)
wg.add_link(add1_task.outputs["result"], wg_task.inputs["add1.x"])
wg.add_link(wg_task.outputs["add2.result"], wg.tasks["add2"].inputs["x"])
assert len(wg_task.inputs) == 21
assert len(wg_task.outputs) == 6
wg.submit(wait=True)
# wg.run()
assert wg.tasks["add2"].outputs["result"].value.value == 12
24 changes: 0 additions & 24 deletions tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,30 +43,6 @@ def test_task_collection(decorated_add: Callable) -> None:
assert len(task1.waiting_on) == 0


@pytest.mark.usefixtures("started_daemon_client")
def test_build_task_from_workgraph(decorated_add: Callable) -> None:
# create a sub workgraph
sub_wg = WorkGraph("build_task_from_workgraph")
sub_wg.add_task(decorated_add, name="add1", x=1, y=3)
sub_wg.add_task(
decorated_add, name="add2", x=2, y=sub_wg.tasks["add1"].outputs["result"]
)
#
wg = WorkGraph("build_task_from_workgraph")
add1_task = wg.add_task(decorated_add, name="add1", x=1, y=3)
wg_task = wg.add_task(sub_wg, name="sub_wg")
# the default value of the namespace is None
assert wg_task.inputs["add1"].value is None
wg.add_task(decorated_add, name="add2", y=3)
wg.add_link(add1_task.outputs["result"], wg_task.inputs["add1.x"])
wg.add_link(wg_task.outputs["add2.result"], wg.tasks["add2"].inputs["x"])
assert len(wg_task.inputs) == 21
assert len(wg_task.outputs) == 6
wg.submit(wait=True)
# wg.run()
assert wg.tasks["add2"].outputs["result"].value.value == 12


@pytest.mark.usefixtures("started_daemon_client")
def test_task_wait(decorated_add: Callable) -> None:
"""Run a WorkGraph with a task that waits on other tasks."""
Expand Down

0 comments on commit f9e5344

Please sign in to comment.