diff --git a/src/aiida_workgraph/config.py b/src/aiida_workgraph/config.py index 519747ec..ab316c38 100644 --- a/src/aiida_workgraph/config.py +++ b/src/aiida_workgraph/config.py @@ -2,6 +2,11 @@ from aiida.manage.configuration.settings import AIIDA_CONFIG_FOLDER WORKGRAPH_EXTRA_KEY = "_workgraph" +WORKGRAPH_SHORT_EXTRA_KEY = "_workgraph_short" + + +builtin_inputs = [{"name": "_wait", "link_limit": 1e6, "arg_type": "none"}] +builtin_outputs = [{"name": "_wait"}, {"name": "_outputs"}] def load_config() -> dict: diff --git a/src/aiida_workgraph/decorator.py b/src/aiida_workgraph/decorator.py index aa448157..0ab2d2df 100644 --- a/src/aiida_workgraph/decorator.py +++ b/src/aiida_workgraph/decorator.py @@ -10,6 +10,7 @@ from aiida_workgraph.task import Task from aiida_workgraph.utils import build_callable, validate_task_inout import inspect +from aiida_workgraph.config import builtin_inputs, builtin_outputs task_types = { CalcFunctionNode: "CALCFUNCTION", @@ -265,16 +266,10 @@ def build_task_from_AiiDA( else outputs ) # add built-in sockets - outputs.append({"identifier": "workgraph.any", "name": "_outputs"}) - outputs.append({"identifier": "workgraph.any", "name": "_wait"}) - inputs.append( - { - "identifier": "workgraph.any", - "name": "_wait", - "link_limit": 1e6, - "arg_type": "none", - } - ) + for output in builtin_outputs: + outputs.append(output.copy()) + for input in builtin_inputs: + inputs.append(input.copy()) tdata["metadata"]["node_class"] = {"module": "aiida_workgraph.task", "name": "Task"} tdata["inputs"] = inputs tdata["outputs"] = outputs @@ -301,9 +296,6 @@ def build_pythonjob_task(func: Callable) -> Task: } _, tdata_py = build_task_from_AiiDA(tdata) tdata = deepcopy(func.tdata) - function_inputs = [ - name for name in tdata["inputs"] if name not in ["_wait", "_outputs"] - ] # merge the inputs and outputs from the PythonJob task to the function task # skip the already existed inputs and outputs for input in [ @@ -333,7 +325,6 @@ def build_pythonjob_task(func: Callable) -> Task: } task = create_task(tdata) task.is_aiida_component = True - task.function_inputs = function_inputs return task, tdata @@ -390,6 +381,8 @@ def build_task_from_workgraph(wg: any) -> Task: outputs = [] group_outputs = [] # add all the inputs/outputs from the tasks in the workgraph + builtin_input_names = [input["name"] for input in builtin_inputs] + builtin_output_names = [output["name"] for output in builtin_outputs] for task in wg.tasks: # inputs inputs.append( @@ -399,7 +392,7 @@ def build_task_from_workgraph(wg: any) -> Task: } ) for socket in task.inputs: - if socket.name == "_wait": + if socket.name in builtin_input_names: continue inputs.append( {"identifier": socket.identifier, "name": f"{task.name}.{socket.name}"} @@ -412,7 +405,7 @@ def build_task_from_workgraph(wg: any) -> Task: } ) for socket in task.outputs: - if socket.name in ["_wait", "_outputs"]: + if socket.name in builtin_output_names: continue outputs.append( {"identifier": socket.identifier, "name": f"{task.name}.{socket.name}"} @@ -424,16 +417,10 @@ def build_task_from_workgraph(wg: any) -> Task: } ) # add built-in sockets - outputs.append({"identifier": "workgraph.any", "name": "_outputs"}) - outputs.append({"identifier": "workgraph.any", "name": "_wait"}) - inputs.append( - { - "identifier": "workgraph.any", - "name": "_wait", - "link_limit": 1e6, - "arg_type": "none", - } - ) + for output in builtin_outputs: + outputs.append(output.copy()) + for input in builtin_inputs: + inputs.append(input.copy()) tdata["metadata"]["node_class"] = {"module": "aiida_workgraph.task", "name": "Task"} tdata["inputs"] = inputs tdata["outputs"] = outputs @@ -507,16 +494,10 @@ def generate_tdata( input["metadata"]["is_function_input"] = True task_outputs = outputs # add built-in sockets - task_inputs.append( - { - "identifier": "workgraph.any", - "name": "_wait", - "link_limit": 1e6, - "arg_type": "none", - } - ) - task_outputs.append({"identifier": "workgraph.any", "name": "_wait"}) - task_outputs.append({"identifier": "workgraph.any", "name": "_outputs"}) + for output in builtin_outputs: + task_outputs.append(output.copy()) + for input in builtin_inputs: + task_inputs.append(input.copy()) tdata = { "identifier": identifier, "metadata": { diff --git a/src/aiida_workgraph/tasks/pythonjob.py b/src/aiida_workgraph/tasks/pythonjob.py index 0b74ff31..b86a5106 100644 --- a/src/aiida_workgraph/tasks/pythonjob.py +++ b/src/aiida_workgraph/tasks/pythonjob.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any, Dict from aiida import orm from aiida_pythonjob.data.serializer import general_serializer from aiida_workgraph.task import Task @@ -9,28 +9,18 @@ class PythonJob(Task): identifier = "workgraph.pythonjob" - function_inputs: List = None - def update_from_dict(self, data: Dict[str, Any], **kwargs) -> "PythonJob": """Overwrite the update_from_dict method to handle the PythonJob data.""" - self.function_inputs = data.get("function_inputs", []) self.deserialize_pythonjob_data(data) super().update_from_dict(data) - def to_dict(self, short: bool = False) -> Dict[str, Any]: - data = super().to_dict(short=short) - data["function_inputs"] = self.function_inputs - return data - @classmethod def serialize_pythonjob_data(cls, tdata: Dict[str, Any]): """Serialize the properties for PythonJob.""" - input_kwargs = tdata.get("function_inputs", []) - for name in input_kwargs: - tdata["inputs"][name]["property"]["value"] = cls.serialize_socket_data( - tdata["inputs"][name] - ) + for input in tdata["inputs"].values(): + if input["metadata"].get("is_function_input", False): + input["property"]["value"] = cls.serialize_socket_data(input) @classmethod def deserialize_pythonjob_data(cls, tdata: Dict[str, Any]) -> None: @@ -45,13 +35,10 @@ def deserialize_pythonjob_data(cls, tdata: Dict[str, Any]) -> None: Returns: Dict[str, Any]: The processed data dictionary. """ - input_kwargs = tdata.get("function_inputs", []) - for name in input_kwargs: - if name in tdata["inputs"]: - tdata["inputs"][name]["property"][ - "value" - ] = cls.deserialize_socket_data(tdata["inputs"][name]) + for input in tdata["inputs"].values(): + if input["metadata"].get("is_function_input", False): + input["property"]["value"] = cls.deserialize_socket_data(input) @classmethod def serialize_socket_data(cls, data: Dict[str, Any]) -> Any: diff --git a/src/aiida_workgraph/tasks/test.py b/src/aiida_workgraph/tasks/test.py index 166fd131..e4d0bd77 100644 --- a/src/aiida_workgraph/tasks/test.py +++ b/src/aiida_workgraph/tasks/test.py @@ -26,6 +26,7 @@ def create_sockets(self) -> None: self.inputs.new("workgraph.any", "_wait", arg_type="none", link_limit=100000) self.outputs.new("workgraph.aiida_float", "sum") self.outputs.new("workgraph.any", "_wait") + self.outputs.new("workgraph.any", "_outputs") class TestSumDiff(Task): @@ -54,6 +55,7 @@ def create_sockets(self) -> None: self.outputs.new("workgraph.aiida_float", "sum") self.outputs.new("workgraph.aiida_float", "diff") self.outputs.new("workgraph.any", "_wait") + self.outputs.new("workgraph.any", "_outputs") class TestArithmeticMultiplyAdd(Task): @@ -84,3 +86,4 @@ def create_sockets(self) -> None: self.inputs.new("workgraph.any", "_wait", arg_type="none", link_limit=100000) self.outputs.new("workgraph.aiida_int", "result") self.outputs.new("workgraph.any", "_wait") + self.outputs.new("workgraph.any", "_outputs") diff --git a/src/aiida_workgraph/utils/analysis.py b/src/aiida_workgraph/utils/analysis.py index f9d43716..9ab5c198 100644 --- a/src/aiida_workgraph/utils/analysis.py +++ b/src/aiida_workgraph/utils/analysis.py @@ -3,7 +3,7 @@ # import datetime from aiida.orm import ProcessNode from aiida.orm.utils.serialize import serialize, deserialize_unsafe -from aiida_workgraph.config import WORKGRAPH_EXTRA_KEY +from aiida_workgraph.config import WORKGRAPH_EXTRA_KEY, WORKGRAPH_SHORT_EXTRA_KEY class WorkGraphSaver: @@ -211,7 +211,7 @@ def insert_workgraph_to_db(self) -> None: # self.wgdata["created"] = datetime.datetime.utcnow() # self.wgdata["lastUpdate"] = datetime.datetime.utcnow() short_wgdata = workgraph_to_short_json(self.wgdata) - self.process.base.extras.set("_workgraph_short", short_wgdata) + self.process.base.extras.set(WORKGRAPH_SHORT_EXTRA_KEY, short_wgdata) self.save_task_states() for name, task in self.wgdata["tasks"].items(): for _, input in task["inputs"].items(): diff --git a/tests/test_decorator.py b/tests/test_decorator.py index 6cfb92e0..5fd44293 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -52,7 +52,7 @@ def test_decorators_calcfunction_args(task_calcfunction) -> None: assert set(tdata["kwargs"]) == set(kwargs) assert tdata["var_args"] is None assert tdata["var_kwargs"] == "c" - assert n.outputs.keys() == ["result", "_outputs", "_wait"] + assert set(n.outputs.keys()) == set(["result", "_outputs", "_wait"]) @pytest.fixture(params=["decorator_factory", "decorator"]) @@ -123,7 +123,7 @@ def test_decorators_workfunction_args(task_workfunction) -> None: assert set(tdata["kwargs"]) == set(kwargs) assert tdata["var_args"] is None assert tdata["var_kwargs"] == "c" - assert n.outputs.keys() == ["result", "_outputs", "_wait"] + assert set(n.outputs.keys()) == set(["result", "_outputs", "_wait"]) def test_decorators_parameters() -> None: diff --git a/tests/test_task_from_workgraph.py b/tests/test_task_from_workgraph.py new file mode 100644 index 00000000..c6af4b6b --- /dev/null +++ b/tests/test_task_from_workgraph.py @@ -0,0 +1,44 @@ +import pytest +from aiida_workgraph import WorkGraph +from typing import Callable + + +def test_inputs_outptus(wg_calcfunction: WorkGraph) -> None: + """Test the inputs and outputs of the WorkGraph.""" + wg = WorkGraph(name="test_inputs_outptus") + task1 = wg.add_task(wg_calcfunction, name="add1") + ninput = 0 + for sub_task in wg_calcfunction.tasks: + # remove _wait, but add the namespace + ninput += len(sub_task.inputs) - 1 + 1 + noutput = 0 + for sub_task in wg_calcfunction.tasks: + noutput += len(sub_task.outputs) - 2 + 1 + assert len(task1.inputs) == ninput + 1 + assert len(task1.outputs) == noutput + 2 + assert "sumdiff1.x" in task1.inputs.keys() + assert "sumdiff1.sum" in task1.outputs.keys() + + +@pytest.mark.usefixtures("started_daemon_client") +def test_build_task_from_workgraph(decorated_add: Callable) -> None: + # create a sub workgraph + sub_wg = WorkGraph("build_task_from_workgraph") + sub_wg.add_task(decorated_add, name="add1", x=1, y=3) + sub_wg.add_task( + decorated_add, name="add2", x=2, y=sub_wg.tasks["add1"].outputs["result"] + ) + # + wg = WorkGraph("build_task_from_workgraph") + add1_task = wg.add_task(decorated_add, name="add1", x=1, y=3) + wg_task = wg.add_task(sub_wg, name="sub_wg") + # the default value of the namespace is None + assert wg_task.inputs["add1"].value is None + wg.add_task(decorated_add, name="add2", y=3) + wg.add_link(add1_task.outputs["result"], wg_task.inputs["add1.x"]) + wg.add_link(wg_task.outputs["add2.result"], wg.tasks["add2"].inputs["x"]) + assert len(wg_task.inputs) == 21 + assert len(wg_task.outputs) == 6 + wg.submit(wait=True) + # wg.run() + assert wg.tasks["add2"].outputs["result"].value.value == 12 diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 6915c03e..ddf31281 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -43,30 +43,6 @@ def test_task_collection(decorated_add: Callable) -> None: assert len(task1.waiting_on) == 0 -@pytest.mark.usefixtures("started_daemon_client") -def test_build_task_from_workgraph(decorated_add: Callable) -> None: - # create a sub workgraph - sub_wg = WorkGraph("build_task_from_workgraph") - sub_wg.add_task(decorated_add, name="add1", x=1, y=3) - sub_wg.add_task( - decorated_add, name="add2", x=2, y=sub_wg.tasks["add1"].outputs["result"] - ) - # - wg = WorkGraph("build_task_from_workgraph") - add1_task = wg.add_task(decorated_add, name="add1", x=1, y=3) - wg_task = wg.add_task(sub_wg, name="sub_wg") - # the default value of the namespace is None - assert wg_task.inputs["add1"].value is None - wg.add_task(decorated_add, name="add2", y=3) - wg.add_link(add1_task.outputs["result"], wg_task.inputs["add1.x"]) - wg.add_link(wg_task.outputs["add2.result"], wg.tasks["add2"].inputs["x"]) - assert len(wg_task.inputs) == 21 - assert len(wg_task.outputs) == 6 - wg.submit(wait=True) - # wg.run() - assert wg.tasks["add2"].outputs["result"].value.value == 12 - - @pytest.mark.usefixtures("started_daemon_client") def test_task_wait(decorated_add: Callable) -> None: """Run a WorkGraph with a task that waits on other tasks."""