From 7120c0a5c87b8956e3fcdcf34f638e2ef98ed74e Mon Sep 17 00:00:00 2001 From: superstar54 Date: Tue, 10 Dec 2024 10:29:45 +0100 Subject: [PATCH] Use type_mapping when create new socket --- src/aiida_workgraph/decorator.py | 19 +-------- src/aiida_workgraph/engine/task_manager.py | 10 +++-- src/aiida_workgraph/orm/mapping.py | 18 ++++++++ src/aiida_workgraph/socket.py | 7 +++- src/aiida_workgraph/task.py | 48 ++-------------------- src/aiida_workgraph/utils/__init__.py | 15 +------ src/aiida_workgraph/workgraph.py | 8 ++-- tests/test_tasks.py | 39 ++++++++++-------- tests/test_workchain.py | 2 +- tests/test_workgraph.py | 5 ++- tests/test_yaml.py | 10 ++--- 11 files changed, 70 insertions(+), 111 deletions(-) create mode 100644 src/aiida_workgraph/orm/mapping.py diff --git a/src/aiida_workgraph/decorator.py b/src/aiida_workgraph/decorator.py index 85c3812e..7bc27bd1 100644 --- a/src/aiida_workgraph/decorator.py +++ b/src/aiida_workgraph/decorator.py @@ -3,7 +3,6 @@ from typing import Any, Callable, Dict, List, Optional, Union, Tuple from aiida_workgraph.utils import get_executor from aiida.engine import calcfunction, workfunction, CalcJob, WorkChain -from aiida import orm from aiida.orm.nodes.process.calculation.calcfunction import CalcFunctionNode from aiida.orm.nodes.process.workflow.workfunction import WorkFunctionNode from aiida.engine.processes.ports import PortNamespace @@ -11,6 +10,8 @@ from aiida_workgraph.utils import build_callable, validate_task_inout import inspect from aiida_workgraph.config import builtin_inputs, builtin_outputs +from aiida_workgraph.orm.mapping import type_mapping + task_types = { CalcFunctionNode: "CALCFUNCTION", @@ -19,22 +20,6 @@ WorkChain: "WORKCHAIN", } -type_mapping = { - "default": "workgraph.any", - "namespace": "workgraph.namespace", - int: "workgraph.int", - float: "workgraph.float", - str: "workgraph.string", - bool: "workgraph.bool", - orm.Int: "workgraph.aiida_int", - orm.Float: "workgraph.aiida_float", - orm.Str: "workgraph.aiida_string", - orm.Bool: "workgraph.aiida_bool", - orm.List: "workgraph.aiida_list", - orm.Dict: "workgraph.aiida_dict", - orm.StructureData: "workgraph.aiida_structuredata", -} - def create_task(tdata): """Wrap create_node from node_graph to create a Task.""" diff --git a/src/aiida_workgraph/engine/task_manager.py b/src/aiida_workgraph/engine/task_manager.py index 19a0d395..2ef1f148 100644 --- a/src/aiida_workgraph/engine/task_manager.py +++ b/src/aiida_workgraph/engine/task_manager.py @@ -41,7 +41,7 @@ def get_task(self, name: str): for output in task.outputs: output.value = get_nested_dict( self.ctx._tasks[name]["results"], - output.name, + output.socket_name, default=output.value, ) return task @@ -734,7 +734,9 @@ def update_normal_task_state(self, name, results, success=True): """Set the results of a normal task. A normal task is created by decorating a function with @task(). """ - from aiida_workgraph.utils import get_sorted_names + from aiida_workgraph.config import builtin_outputs + + builtin_output_names = [output["name"] for output in builtin_outputs] if success: task = self.ctx._tasks[name] @@ -743,7 +745,9 @@ def update_normal_task_state(self, name, results, success=True): if len(task["outputs"]) - 2 != len(results): self.on_task_failed(name) return self.process.exit_codes.OUTPUS_NOT_MATCH_RESULTS - output_names = get_sorted_names(task["outputs"])[0:-2] + output_names = [ + name for name in task["outputs"] if name not in builtin_output_names + ] for i, output_name in enumerate(output_names): task["results"][output_name] = results[i] elif isinstance(results, dict): diff --git a/src/aiida_workgraph/orm/mapping.py b/src/aiida_workgraph/orm/mapping.py new file mode 100644 index 00000000..c6396447 --- /dev/null +++ b/src/aiida_workgraph/orm/mapping.py @@ -0,0 +1,18 @@ +from aiida import orm + + +type_mapping = { + "default": "workgraph.any", + "namespace": "workgraph.namespace", + int: "workgraph.int", + float: "workgraph.float", + str: "workgraph.string", + bool: "workgraph.bool", + orm.Int: "workgraph.aiida_int", + orm.Float: "workgraph.aiida_float", + orm.Str: "workgraph.aiida_string", + orm.Bool: "workgraph.aiida_bool", + orm.List: "workgraph.aiida_list", + orm.Dict: "workgraph.aiida_dict", + orm.StructureData: "workgraph.aiida_structuredata", +} diff --git a/src/aiida_workgraph/socket.py b/src/aiida_workgraph/socket.py index 2dc8345e..809d0995 100644 --- a/src/aiida_workgraph/socket.py +++ b/src/aiida_workgraph/socket.py @@ -1,9 +1,13 @@ from typing import Any, Type from aiida import orm -from node_graph.socket import NodeSocket, NodeSocketNamespace +from node_graph.socket import ( + NodeSocket, + NodeSocketNamespace, +) from aiida_workgraph.property import TaskProperty +from aiida_workgraph.orm.mapping import type_mapping class TaskSocket(NodeSocket): @@ -35,6 +39,7 @@ class TaskSocketNamespace(NodeSocketNamespace): _socket_identifier = "workgraph.namespace" _socket_property_class = TaskProperty + _type_mapping: dict = type_mapping def __init__(self, *args, **kwargs): super().__init__(*args, entry_point="aiida_workgraph.socket", **kwargs) diff --git a/src/aiida_workgraph/task.py b/src/aiida_workgraph/task.py index 6926c88a..58fb2a5a 100644 --- a/src/aiida_workgraph/task.py +++ b/src/aiida_workgraph/task.py @@ -83,48 +83,6 @@ def set_context(self, context: Dict[str, Any]) -> None: raise ValueError(msg) self.context_mapping.update(context) - def set(self, data: Dict[str, Any]) -> None: - from node_graph.socket import NodeSocket - - super().set(data) - - def process_nested_inputs( - base_key: str, value: Any, dynamic: bool = False - ) -> None: - """Recursive function to process nested inputs. - Creates sockets and links dynamically for nested values. - """ - if isinstance(value, dict): - keys = list(value.keys()) - for sub_key in keys: - sub_value = value[sub_key] - # Form the full key for the current nested level - full_key = f"{base_key}.{sub_key}" if base_key else sub_key - - # Create a new input socket if it does not exist - if full_key not in self.get_input_names() and dynamic: - self.add_input( - "workgraph.any", - name=full_key, - metadata={"required": True}, - ) - if isinstance(sub_value, NodeSocket): - self.parent.links.new(sub_value, self.inputs[full_key]) - value.pop(sub_key) - else: - # Recursively process nested dictionaries - process_nested_inputs(full_key, sub_value, dynamic) - - # create input sockets and links for items inside a dynamic socket - # TODO the input value could be nested, but we only support one level for now - for key in data: - if self.inputs[key]._socket_identifier == "workgraph.namespace": - process_nested_inputs( - key, - self.inputs[key].value, - dynamic=self.inputs[key].metadata.get("dynamic", False), - ) - def set_from_builder(self, builder: Any) -> None: """Set the task inputs from a AiiDA ProcessBuilder.""" from aiida_workgraph.utils import get_dict_from_builder @@ -230,7 +188,7 @@ def to_widget_value(self): for key in ("properties", "executor", "node_class", "process"): tdata.pop(key, None) for input in tdata["inputs"].values(): - input.pop("property") + input.pop("property", None) tdata["label"] = tdata["identifier"] @@ -289,9 +247,9 @@ def _normalize_tasks( task_objects = [] for task in tasks: if isinstance(task, str): - if task not in self.graph.tasks.keys(): + if task not in self.graph.tasks: raise ValueError( - f"Task '{task}' is not in the graph. Available tasks: {self.graph.tasks.keys()}" + f"Task '{task}' is not in the graph. Available tasks: {self.graph.tasks}" ) task_objects.append(self.graph.tasks[task]) elif isinstance(task, Task): diff --git a/src/aiida_workgraph/utils/__init__.py b/src/aiida_workgraph/utils/__init__.py index b8ee0ef6..7316d4f7 100644 --- a/src/aiida_workgraph/utils/__init__.py +++ b/src/aiida_workgraph/utils/__init__.py @@ -42,18 +42,6 @@ def build_callable(obj: Callable) -> Dict[str, Any]: return executor -def get_sorted_names(data: dict) -> list[str]: - """Get the sorted names from a dictionary.""" - sorted_names = [ - name - for name, _ in sorted( - ((name, item["list_index"]) for name, item in data.items()), - key=lambda x: x[1], - ) - ] - return sorted_names - - def store_nodes_recursely(data: Any) -> None: """Recurse through a data structure and store any unstored nodes that are found along the way :param data: a data structure potentially containing unstored nodes @@ -615,7 +603,6 @@ def validate_task_inout(inout_list: list[str | dict], list_type: str) -> list[di :raises TypeError: If wrong types are provided to the task :return: Processed `inputs`/`outputs` list. """ - from node_graph.utils import list_to_dict if not all(isinstance(item, (dict, str)) for item in inout_list): raise TypeError( @@ -630,7 +617,7 @@ def validate_task_inout(inout_list: list[str | dict], list_type: str) -> list[di elif isinstance(item, dict): processed_inout_list.append(item) - processed_inout_list = list_to_dict(processed_inout_list) + processed_inout_list = processed_inout_list return processed_inout_list diff --git a/src/aiida_workgraph/workgraph.py b/src/aiida_workgraph/workgraph.py index 4afd019c..cf379d41 100644 --- a/src/aiida_workgraph/workgraph.py +++ b/src/aiida_workgraph/workgraph.py @@ -294,12 +294,10 @@ def update(self) -> None: # even if the node.is_finished_ok is True if node.is_finished_ok: # update the output sockets - i = 0 for socket in self.tasks[name].outputs: - socket.value = get_nested_dict( + socket.socket_value = get_nested_dict( node.outputs, socket.socket_name, default=None ) - i += 1 # read results from the process outputs elif isinstance(node, aiida.orm.Data): self.tasks[name].outputs[0].value = node @@ -473,13 +471,13 @@ def extend(self, wg: "WorkGraph", prefix: str = "") -> None: for task in wg.tasks: task.name = prefix + task.name task.parent = self - self.tasks.append(task) + self.tasks._append(task) # self.sequence.extend([prefix + task for task in wg.sequence]) # self.conditions.extend(wg.conditions) self.context.update(wg.context) # links for link in wg.links: - self.links.append(link) + self.links._append(link) @property def error_handlers(self) -> Dict[str, Any]: diff --git a/tests/test_tasks.py b/tests/test_tasks.py index adf20f98..fd5a8d0d 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -18,7 +18,9 @@ def sum_diff(x, y): decorated_add, name="add", x=task1.outputs["sum"], y=task1.outputs["diff"] ) wg.run() - assert task2.outputs["result"].value == 4 + print("node: ", task2.node.outputs.result) + wg.update() + assert task2.outputs["result"].socket_value == 4 def test_task_collection(decorated_add: Callable) -> None: @@ -68,8 +70,8 @@ def test_set_non_dynamic_namespace_socket(decorated_add) -> None: "non_dynamic_port": {"a": task1.outputs["result"], "b": orm.Int(2)}, } ) - assert len(task2.inputs["non_dynamic_port.a"].links) == 1 - assert task2.inputs["non_dynamic_port"].value == {"b": orm.Int(2)} + assert len(task2.inputs["non_dynamic_port.a"].socket_links) == 1 + assert task2.inputs["non_dynamic_port"].socket_value == {"b": orm.Int(2)} assert len(wg.links) == 1 @@ -85,8 +87,12 @@ def test_set_namespace_socket(decorated_add) -> None: "add": {"x": task1.outputs["result"], "y": orm.Int(2)}, } ) - assert len(task2.inputs["add.x"].links) == 1 - assert task2.inputs["add"].value == {"y": orm.Int(2)} + assert len(task2.inputs["add.x"].socket_links) == 1 + assert task2.inputs["add"].socket_value == { + "metadata": {"options": {"stash": {}}}, + "monitors": {}, + "y": orm.Int(2), + } assert len(wg.links) == 1 @@ -110,14 +116,13 @@ def test_set_dynamic_port_input(decorated_add) -> None: ) wg.add_link(task1.outputs["_wait"], task2.inputs["dynamic_port.input1"]) # task will create input for each item in the dynamic port (nodes) - assert "dynamic_port.input1" in task2.get_input_names() - assert "dynamic_port.input2" in task2.get_input_names() + assert "dynamic_port.input1" in task2.inputs + assert "dynamic_port.input2" in task2.inputs # if the value of the item is a Socket, then it will create a link, and pop the item - assert "dynamic_port.input3" in task2.get_input_names() - assert "dynamic_port.nested.input4" in task2.get_input_names() - assert "dynamic_port.nested.input5" in task2.get_input_names() - assert task2.inputs["dynamic_port"].value == { - "input1": None, + assert "dynamic_port.input3" in task2.inputs + assert "dynamic_port.nested.input4" in task2.inputs + assert "dynamic_port.nested.input5" in task2.inputs + assert task2.inputs["dynamic_port"].socket_value == { "input2": orm.Int(2), "nested": {"input4": orm.Int(4)}, } @@ -133,9 +138,7 @@ def test_set_inputs(decorated_add: Callable) -> None: data = wg.prepare_inputs(metadata=None) assert data["wg"]["tasks"]["add1"]["inputs"]["y"]["property"]["value"] == 2 assert ( - data["wg"]["tasks"]["add1"]["inputs"]["metadata"]["property"]["value"][ - "store_provenance" - ] + data["wg"]["tasks"]["add1"]["inputs"]["metadata"]["value"]["store_provenance"] is False ) @@ -151,9 +154,9 @@ def test_set_inputs_from_builder(add_code) -> None: builder.x = 1 builder.y = 2 add1.set_from_builder(builder) - assert add1.inputs["x"].value == 1 - assert add1.inputs["y"].value == 2 - assert add1.inputs["code"].value == add_code + assert add1.inputs["x"].socket_value == 1 + assert add1.inputs["y"].socket_value == 2 + assert add1.inputs["code"].socket_value == add_code with pytest.raises( AttributeError, match=f"Executor {ArithmeticAddCalculation.__name__} does not have the get_builder_from_protocol method.", diff --git a/tests/test_workchain.py b/tests/test_workchain.py index 85f59b1f..164cddb3 100644 --- a/tests/test_workchain.py +++ b/tests/test_workchain.py @@ -19,7 +19,7 @@ def test_build_workchain_inputs_outputs(): node = build_task(MultiplyAddWorkChain)() inputs = MultiplyAddWorkChain.spec().inputs # inputs + metadata + _wait - ninput = len(inputs.ports) + len(inputs.ports["metadata"].ports) + 1 + ninput = len(inputs.ports) + 1 assert len(node.inputs) == ninput assert len(node.outputs) == 3 diff --git a/tests/test_workgraph.py b/tests/test_workgraph.py index e355aa33..2357d8aa 100644 --- a/tests/test_workgraph.py +++ b/tests/test_workgraph.py @@ -96,11 +96,12 @@ def test_organize_nested_inputs(): data = { "metadata": { "call_link_label": "nest", - "options": {"resources": {"num_cpus": 1, "num_machines": 1}}, + "options": {"resources": {"num_machines": 1}, "stash": {}}, }, + "monitors": {}, "x": "1", } - assert inputs["wg"]["tasks"]["task1"]["inputs"]["add"]["property"]["value"] == data + assert inputs["wg"]["tasks"]["task1"]["inputs"]["add"]["value"] == data @pytest.mark.usefixtures("started_daemon_client") diff --git a/tests/test_yaml.py b/tests/test_yaml.py index 8a9bb520..07a5864b 100644 --- a/tests/test_yaml.py +++ b/tests/test_yaml.py @@ -7,11 +7,11 @@ def test_calcfunction(): wg = WorkGraph.from_yaml(os.path.join(cwd, "datas/test_calcfunction.yaml")) - assert wg.tasks["float1"].inputs["value"].value == 3.0 - assert wg.tasks["sumdiff1"].inputs["x"].value == 2.0 - assert wg.tasks["sumdiff2"].inputs["x"].value == 4.0 + assert wg.tasks.float1.inputs.value.socket_value == 3.0 + assert wg.tasks.sumdiff1.inputs.x.socket_value == 2.0 + assert wg.tasks.sumdiff2.inputs.x.socket_value == 4.0 wg.run() - assert wg.tasks["sumdiff2"].node.outputs.sum == 9 + assert wg.tasks.sumdiff2.node.outputs.sum == 9 # skip this test for now @@ -19,4 +19,4 @@ def test_calcfunction(): def test_calcjob(): wg = WorkGraph.from_yaml(os.path.join(cwd, "datas/test_calcjob.yaml")) wg.submit(wait=True) - assert wg.tasks["add2"].node.outputs.sum == 9 + assert wg.tasks.add2.node.outputs.sum == 9