From c74bf4e76b1339e3defbcd4993e31aa0da85a0cf Mon Sep 17 00:00:00 2001 From: superstar54 Date: Tue, 10 Dec 2024 06:03:37 +0100 Subject: [PATCH] rename socket's attribute's name --- docs/gallery/howto/autogen/aggregate.py | 6 +-- pyproject.toml | 2 +- src/aiida_workgraph/collection.py | 34 ------------ src/aiida_workgraph/decorator.py | 4 +- src/aiida_workgraph/engine/task_manager.py | 12 +++-- src/aiida_workgraph/socket.py | 22 +++++--- src/aiida_workgraph/sockets/builtins.py | 63 ++++++++++------------ src/aiida_workgraph/task.py | 7 ++- src/aiida_workgraph/tasks/monitors.py | 6 +-- src/aiida_workgraph/utils/graph.py | 8 +-- tests/test_decorator.py | 58 ++++++++++---------- tests/test_link.py | 2 +- tests/test_pythonjob.py | 2 +- tests/test_socket.py | 19 ++++--- 14 files changed, 107 insertions(+), 138 deletions(-) diff --git a/docs/gallery/howto/autogen/aggregate.py b/docs/gallery/howto/autogen/aggregate.py index fc3138fc..770c9172 100644 --- a/docs/gallery/howto/autogen/aggregate.py +++ b/docs/gallery/howto/autogen/aggregate.py @@ -109,7 +109,7 @@ def aggregate( aggregate_task = wg.add_task(aggregate, name="aggregate_task") # we have to increase the link limit because by default workgraph only supports one link per input socket -aggregate_task.inputs["collected_values"].link_limit = 50 +aggregate_task.inputs["collected_values"].socket_link_limit = 50 for i in range(2): # this can be chosen as wanted generator_task = wg.add_task(generator, name=f"generator{i}", seed=Int(i)) @@ -188,8 +188,8 @@ def aggregate(**collected_values): # we have to increase the link limit because by default workgraph only supports # one link per input socket. -aggregate_task.inputs["collected_ints"].link_limit = 50 -aggregate_task.inputs["collected_floats"].link_limit = 50 +aggregate_task.inputs["collected_ints"].socket_link_limit = 50 +aggregate_task.inputs["collected_floats"].socket_link_limit = 50 for i in range(2): # this can be chosen as wanted diff --git a/pyproject.toml b/pyproject.toml index 3679dde5..e01006ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -127,7 +127,7 @@ workgraph = "aiida_workgraph.cli.cmd_workgraph:workgraph" [project.entry-points."aiida_workgraph.socket"] "workgraph.any" = "aiida_workgraph.sockets.builtins:SocketAny" -"workgraph.namespace" = "aiida_workgraph.sockets.builtins:SocketNamespace" +"workgraph.namespace" = "aiida_workgraph.socket:TaskSocketNamespace" "workgraph.int" = "aiida_workgraph.sockets.builtins:SocketInt" "workgraph.float" = "aiida_workgraph.sockets.builtins:SocketFloat" "workgraph.string" = "aiida_workgraph.sockets.builtins:SocketString" diff --git a/src/aiida_workgraph/collection.py b/src/aiida_workgraph/collection.py index 5e657c99..e73dbf59 100644 --- a/src/aiida_workgraph/collection.py +++ b/src/aiida_workgraph/collection.py @@ -1,8 +1,6 @@ from node_graph.collection import ( NodeCollection, PropertyCollection, - InputSocketCollection, - OutputSocketCollection, ) from typing import Any, Callable, Optional, Union @@ -60,35 +58,3 @@ def _new( identifier = build_property_from_AiiDA(identifier) # Call the original new method return super()._new(identifier, name, **kwargs) - - -class WorkGraphInputSocketCollection(InputSocketCollection): - def _new( - self, - identifier: Union[Callable, str], - name: Optional[str] = None, - **kwargs: Any - ) -> Any: - from aiida_workgraph.socket import build_socket_from_AiiDA - - # build the socket on the fly if the identifier is a callable - if callable(identifier): - identifier = build_socket_from_AiiDA(identifier) - # Call the original new method - return super()._new(identifier, name, **kwargs) - - -class WorkGraphOutputSocketCollection(OutputSocketCollection): - def _new( - self, - identifier: Union[Callable, str], - name: Optional[str] = None, - **kwargs: Any - ) -> Any: - from aiida_workgraph.socket import build_socket_from_AiiDA - - # build the socket on the fly if the identifier is a callable - if callable(identifier): - identifier = build_socket_from_AiiDA(identifier) - # Call the original new method - return super()._new(identifier, name, **kwargs) diff --git a/src/aiida_workgraph/decorator.py b/src/aiida_workgraph/decorator.py index 0ab2d2df..626c2023 100644 --- a/src/aiida_workgraph/decorator.py +++ b/src/aiida_workgraph/decorator.py @@ -69,7 +69,6 @@ def add_input_recursive( "name": port_name, "arg_type": "kwargs", "metadata": {"required": required, "dynamic": port.dynamic}, - "property": {"identifier": "workgraph.any", "default": None}, } ) for value in port.values(): @@ -248,11 +247,10 @@ def build_task_from_AiiDA( if name not in [input["name"] for input in inputs]: inputs.append( { - "identifier": "workgraph.any", + "identifier": "workgraph.namespace", "name": name, "arg_type": "var_kwargs", "metadata": {"dynamic": True}, - "property": {"identifier": "workgraph.any", "default": {}}, } ) diff --git a/src/aiida_workgraph/engine/task_manager.py b/src/aiida_workgraph/engine/task_manager.py index d0175dfe..19a0d395 100644 --- a/src/aiida_workgraph/engine/task_manager.py +++ b/src/aiida_workgraph/engine/task_manager.py @@ -249,6 +249,7 @@ def run_tasks(self, names: List[str], continue_workgraph: bool = True) -> None: kwargs[key] = args[i] # update the port namespace kwargs = update_nested_dict_with_special_keys(kwargs) + print("kwargs: ", kwargs) # kwargs["meta.label"] = name # output must be a Data type or a mapping of {string: Data} task["results"] = {} @@ -592,9 +593,14 @@ def get_inputs( for name, input in task["inputs"].items(): # print(f"input: {input['name']}") if len(input["links"]) == 0: - inputs[name] = self.ctx_manager.update_context_variable( - input["property"]["value"] - ) + if input["identifier"] == "workgraph.namespace": + inputs[name] = self.ctx_manager.update_context_variable( + input["value"] + ) + else: + inputs[name] = self.ctx_manager.update_context_variable( + input["property"]["value"] + ) elif len(input["links"]) == 1: link = input["links"][0] if self.ctx._tasks[link["from_node"]]["results"] is None: diff --git a/src/aiida_workgraph/socket.py b/src/aiida_workgraph/socket.py index ff9f677a..2dc8345e 100644 --- a/src/aiida_workgraph/socket.py +++ b/src/aiida_workgraph/socket.py @@ -1,7 +1,7 @@ from typing import Any, Type from aiida import orm -from node_graph.socket import NodeSocket +from node_graph.socket import NodeSocket, NodeSocketNamespace from aiida_workgraph.property import TaskProperty @@ -11,7 +11,7 @@ class TaskSocket(NodeSocket): # use TaskProperty from aiida_workgraph.property # to override the default NodeProperty from node_graph - node_property = TaskProperty + _socket_property_class = TaskProperty @property def node_value(self): @@ -19,15 +19,25 @@ def node_value(self): def get_node_value(self): """Obtain the actual Python `value` of the object attached to the Socket.""" - if isinstance(self.value, orm.Data): - if hasattr(self.value, "value"): - return self.value.value + if isinstance(self.socket_value, orm.Data): + if hasattr(self.socket_value, "value"): + return self.socket_value.value else: raise ValueError( "Data node does not have a value attribute. We do not know how to extract the raw Python value." ) else: - return self.value + return self.socket_value + + +class TaskSocketNamespace(NodeSocketNamespace): + """Represent a namespace of a Task in the AiiDA WorkGraph.""" + + _socket_identifier = "workgraph.namespace" + _socket_property_class = TaskProperty + + def __init__(self, *args, **kwargs): + super().__init__(*args, entry_point="aiida_workgraph.socket", **kwargs) def build_socket_from_AiiDA(DataClass: Type[Any]) -> Type[TaskSocket]: diff --git a/src/aiida_workgraph/sockets/builtins.py b/src/aiida_workgraph/sockets/builtins.py index 3d31e5a2..981923ed 100644 --- a/src/aiida_workgraph/sockets/builtins.py +++ b/src/aiida_workgraph/sockets/builtins.py @@ -4,103 +4,96 @@ class SocketAny(TaskSocket): """Any socket.""" - identifier: str = "workgraph.any" - property_identifier: str = "workgraph.any" - - -class SocketNamespace(TaskSocket): - """Namespace socket.""" - - identifier: str = "workgraph.namespace" - property_identifier: str = "workgraph.any" + _socket_identifier: str = "workgraph.any" + _socket_property_identifier: str = "workgraph.any" class SocketFloat(TaskSocket): """Float socket.""" - identifier: str = "workgraph.float" - property_identifier: str = "workgraph.float" + _socket_identifier: str = "workgraph.float" + _socket_property_identifier: str = "workgraph.float" class SocketInt(TaskSocket): """Int socket.""" - identifier: str = "workgraph.int" - property_identifier: str = "workgraph.int" + _socket_identifier: str = "workgraph.int" + _socket_property_identifier: str = "workgraph.int" class SocketString(TaskSocket): """String socket.""" - identifier: str = "workgraph.string" - property_identifier: str = "workgraph.string" + _socket_identifier: str = "workgraph.string" + _socket_property_identifier: str = "workgraph.string" class SocketBool(TaskSocket): """Bool socket.""" - identifier: str = "workgraph.bool" - property_identifier: str = "workgraph.bool" + _socket_identifier: str = "workgraph.bool" + _socket_property_identifier: str = "workgraph.bool" class SocketAiiDAFloat(TaskSocket): """AiiDAFloat socket.""" - identifier: str = "workgraph.aiida_float" - property_identifier: str = "workgraph.aiida_float" + _socket_identifier: str = "workgraph.aiida_float" + _socket_property_identifier: str = "workgraph.aiida_float" class SocketAiiDAInt(TaskSocket): """AiiDAInt socket.""" - identifier: str = "workgraph.aiida_int" - property_identifier: str = "workgraph.aiida_int" + _socket_identifier: str = "workgraph.aiida_int" + _socket_property_identifier: str = "workgraph.aiida_int" class SocketAiiDAString(TaskSocket): """AiiDAString socket.""" - identifier: str = "workgraph.aiida_string" - property_identifier: str = "workgraph.aiida_string" + _socket_identifier: str = "workgraph.aiida_string" + _socket_property_identifier: str = "workgraph.aiida_string" class SocketAiiDABool(TaskSocket): """AiiDABool socket.""" - identifier: str = "workgraph.aiida_bool" - property_identifier: str = "workgraph.aiida_bool" + _socket_identifier: str = "workgraph.aiida_bool" + _socket_property_identifier: str = "workgraph.aiida_bool" class SocketAiiDAList(TaskSocket): """AiiDAList socket.""" - identifier: str = "workgraph.aiida_list" - property_identifier: str = "workgraph.aiida_list" + _socket_identifier: str = "workgraph.aiida_list" + _socket_property_identifier: str = "workgraph.aiida_list" class SocketAiiDADict(TaskSocket): """AiiDADict socket.""" - identifier: str = "workgraph.aiida_dict" - property_identifier: str = "workgraph.aiida_dict" + _socket_identifier: str = "workgraph.aiida_dict" + _socket_property_identifier: str = "workgraph.aiida_dict" class SocketAiiDAIntVector(TaskSocket): """Socket with a AiiDAIntVector property.""" - identifier: str = "workgraph.aiida_int_vector" - property_identifier: str = "workgraph.aiida_int_vector" + _socket_identifier: str = "workgraph.aiida_int_vector" + _socket_property_identifier: str = "workgraph.aiida_int_vector" class SocketAiiDAFloatVector(TaskSocket): """Socket with a FloatVector property.""" - identifier: str = "workgraph.aiida_float_vector" - property_identifier: str = "workgraph.aiida_float_vector" + _socket_identifier: str = "workgraph.aiida_float_vector" + _socket_property_identifier: str = "workgraph.aiida_float_vector" class SocketStructureData(TaskSocket): """Any socket.""" - identifier: str = "workgraph.aiida_structuredata" - property_identifier: str = "workgraph.aiida_structuredata" + _socket_identifier: str = "workgraph.aiida_structuredata" + _socket_property_identifier: str = "workgraph.aiida_structuredata" diff --git a/src/aiida_workgraph/task.py b/src/aiida_workgraph/task.py index b0abf062..2b550ffc 100644 --- a/src/aiida_workgraph/task.py +++ b/src/aiida_workgraph/task.py @@ -4,11 +4,10 @@ from aiida_workgraph.properties import property_pool from aiida_workgraph.sockets import socket_pool +from aiida_workgraph.socket import NodeSocketNamespace from node_graph_widget import NodeGraphWidget from aiida_workgraph.collection import ( WorkGraphPropertyCollection, - WorkGraphInputSocketCollection, - WorkGraphOutputSocketCollection, ) import aiida from typing import Any, Dict, Optional, Union, Callable, List, Set, Iterable @@ -38,8 +37,8 @@ def __init__( """ super().__init__( property_collection_class=WorkGraphPropertyCollection, - input_collection_class=WorkGraphInputSocketCollection, - output_collection_class=WorkGraphOutputSocketCollection, + input_collection_class=NodeSocketNamespace, + output_collection_class=NodeSocketNamespace, **kwargs, ) self.context_mapping = {} if context_mapping is None else context_mapping diff --git a/src/aiida_workgraph/tasks/monitors.py b/src/aiida_workgraph/tasks/monitors.py index 63ded5e4..029741c3 100644 --- a/src/aiida_workgraph/tasks/monitors.py +++ b/src/aiida_workgraph/tasks/monitors.py @@ -23,7 +23,7 @@ def create_sockets(self) -> None: inp = self.add_input("workgraph.any", "timeout") inp.add_property("workgraph.any", default=86400.0) self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000) - inp.link_limit = 100000 + inp.socket_link_limit = 100000 self.add_output("workgraph.any", "result") self.add_output("workgraph.any", "_wait") @@ -50,7 +50,7 @@ def create_sockets(self) -> None: inp = self.add_input("workgraph.any", "timeout") inp.add_property("workgraph.any", default=86400.0) self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000) - inp.link_limit = 100000 + inp.socket_link_limit = 100000 self.add_output("workgraph.any", "result") self.add_output("workgraph.any", "_wait") @@ -79,6 +79,6 @@ def create_sockets(self) -> None: inp = self.add_input("workgraph.any", "timeout") inp.add_property("workgraph.any", default=86400.0) self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000) - inp.link_limit = 100000 + inp.socket_link_limit = 100000 self.add_output("workgraph.any", "result") self.add_output("workgraph.any", "_wait") diff --git a/src/aiida_workgraph/utils/graph.py b/src/aiida_workgraph/utils/graph.py index 2e3b668f..0ac91cdb 100644 --- a/src/aiida_workgraph/utils/graph.py +++ b/src/aiida_workgraph/utils/graph.py @@ -45,9 +45,9 @@ def link_creation_hook(self, link: Any) -> None: "type": "add_link", "data": { "from_node": link.from_node.name, - "from_socket": link.from_socket.name, + "from_socket": link.from_socket.socket_name, "to_node": link.to_node.name, - "to_socket": link.to_socket.name, + "to_socket": link.to_socket.socket_name, }, } ) @@ -65,9 +65,9 @@ def link_deletion_hook(self, link: Any) -> None: "type": "delete_link", "data": { "from_node": link.from_node.name, - "from_socket": link.from_socket.name, + "from_socket": link.from_socket.socket_name, "to_node": link.to_node.name, - "to_socket": link.to_socket.name, + "to_socket": link.to_socket.socket_name, }, } ) diff --git a/tests/test_decorator.py b/tests/test_decorator.py index 38e79912..5bbb995a 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -1,5 +1,6 @@ import pytest from aiida_workgraph import WorkGraph, task +from aiida_workgraph.socket import TaskSocketNamespace from typing import Callable @@ -15,44 +16,42 @@ def add_multiply(x, y): assert "product" in n.outputs -@pytest.fixture(params=["decorator_factory", "decorator"]) -def task_calcfunction(request): - if request.param == "decorator_factory": - - @task.calcfunction() - def test(a, b=1, **c): - print(a, b, c) - - elif request.param == "decorator": +def test_decorators_args() -> None: + @task() + def test(a, b=1, **c): + print(a, b, c) - @task.calcfunction - def test(a, b=1, **c): - print(a, b, c) + n = test.task() + tdata = n.to_dict() + assert tdata["args"] == [] + assert set(tdata["kwargs"]) == set(["a", "b"]) + assert tdata["var_args"] is None + assert tdata["var_kwargs"] == "c" + assert set(n.get_output_names()) == set(["result", "_outputs", "_wait"]) + assert isinstance(n.inputs.c, TaskSocketNamespace) - else: - raise ValueError(f"{request.param} not supported.") - return test +def test_decorators_calcfunction_args() -> None: + @task.calcfunction() + def test(a, b=1, **c): + print(a, b, c) -def test_decorators_calcfunction_args(task_calcfunction) -> None: metadata_kwargs = set( [ - f"metadata.{key}" - for key in task_calcfunction.process_class.spec() - .inputs.ports["metadata"] - .ports.keys() + f"{key}" + for key in test.process_class.spec().inputs.ports["metadata"].ports.keys() ] ) - kwargs = set(task_calcfunction.process_class.spec().inputs.ports.keys()).union( - metadata_kwargs - ) - n = task_calcfunction.task() + kwargs = set(test.process_class.spec().inputs.ports.keys()) + n = test.task() tdata = n.to_dict() assert tdata["args"] == [] assert set(tdata["kwargs"]) == set(kwargs) assert tdata["var_args"] is None assert tdata["var_kwargs"] == "c" assert set(n.get_output_names()) == set(["result", "_outputs", "_wait"]) + assert isinstance(n.inputs.c, TaskSocketNamespace) + assert set(n.inputs.metadata._keys()) == metadata_kwargs @pytest.fixture(params=["decorator_factory", "decorator"]) @@ -107,15 +106,13 @@ def test(a, b=1, **c): def test_decorators_workfunction_args(task_workfunction) -> None: metadata_kwargs = set( [ - f"metadata.{key}" + f"{key}" for key in task_workfunction.process_class.spec() .inputs.ports["metadata"] .ports.keys() ] ) - kwargs = set(task_workfunction.process_class.spec().inputs.ports.keys()).union( - metadata_kwargs - ) + kwargs = set(task_workfunction.process_class.spec().inputs.ports.keys()) # n = task_workfunction.task() tdata = n.to_dict() @@ -124,6 +121,7 @@ def test_decorators_workfunction_args(task_workfunction) -> None: assert tdata["var_args"] is None assert tdata["var_kwargs"] == "c" assert set(n.get_output_names()) == set(["result", "_outputs", "_wait"]) + assert set(n.inputs.metadata._keys()) == metadata_kwargs def test_decorators_parameters() -> None: @@ -137,7 +135,7 @@ def test(a, b=1, **c): return {"sum": a + b, "product": a * b} test1 = test.task() - assert test1.inputs["c"].link_limit == 1000 + assert test1.inputs["c"].socket_link_limit == 1000 assert "sum" in test1.get_output_names() assert "product" in test1.get_output_names() @@ -183,7 +181,7 @@ def test_inputs_outputs_workchain() -> None: wg = WorkGraph() task = wg.add_task(MultiplyAddWorkChain) assert "metadata" in task.get_input_names() - assert "metadata.call_link_label" in task.get_input_names() + assert "call_link_label" in task.inputs.metadata._keys() assert "result" in task.get_output_names() diff --git a/tests/test_link.py b/tests/test_link.py index e7f2a3ba..244c86b2 100644 --- a/tests/test_link.py +++ b/tests/test_link.py @@ -20,7 +20,7 @@ def sum(**datas): float2 = wg.add_task("workgraph.aiida_node", pk=Float(2.0).store().pk) float3 = wg.add_task("workgraph.aiida_node", pk=Float(3.0).store().pk) sum1 = wg.add_task(sum, "sum1") - sum1.inputs["datas"].link_limit = 100 + sum1.inputs["datas"].socket_link_limit = 100 wg.add_link(float1.outputs[0], sum1.inputs["datas"]) wg.add_link(float2.outputs[0], sum1.inputs["datas"]) wg.add_link(float3.outputs[0], sum1.inputs["datas"]) diff --git a/tests/test_pythonjob.py b/tests/test_pythonjob.py index d7793e22..8fe7369b 100644 --- a/tests/test_pythonjob.py +++ b/tests/test_pythonjob.py @@ -71,7 +71,7 @@ def add(x, y=1, **kwargs): }, ) # data inside the kwargs should be serialized separately - wg.process.inputs.wg.tasks.add.inputs.kwargs.property.value.m.value == 2 + wg.process.inputs.wg.tasks.add.inputs.kwargs.socket_property.value.m.value == 2 assert wg.tasks["add"].outputs["result"].value.value == 8 # load the workgraph wg = WorkGraph.load(wg.pk) diff --git a/tests/test_socket.py b/tests/test_socket.py index d9192854..d96ed634 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -30,8 +30,8 @@ def test_type_mapping(data_type, data, identifier) -> None: def add(x: data_type): pass - assert add.task().inputs["x"].identifier == identifier - assert add.task().inputs["x"].property.identifier == identifier + assert add.task().inputs["x"]._socket_identifier == identifier + assert add.task().inputs["x"].socket_property.identifier == identifier add_task = add.task() add_task.set({"x": data}) # test set data from context @@ -48,14 +48,14 @@ def test_vector_socket() -> None: "vector2d", property_data={"size": 2, "default": [1, 2]}, ) - assert t.inputs["vector2d"].property.get_metadata() == { + assert t.inputs["vector2d"].socket_property.get_metadata() == { "size": 2, "default": [1, 2], } with pytest.raises(ValueError, match="Invalid size: Expected 2, got 3 instead."): - t.inputs["vector2d"].value = [1, 2, 3] + t.inputs["vector2d"].socket_value = [1, 2, 3] with pytest.raises(ValueError, match="Invalid item type: Expected "): - t.inputs["vector2d"].value = [1.1, 2.2] + t.inputs["vector2d"].socket_value = [1.1, 2.2] def test_aiida_data_socket() -> None: @@ -71,8 +71,8 @@ def test_aiida_data_socket() -> None: def add(x: data_type): pass - assert add.task().inputs["x"].identifier == identifier - assert add.task().inputs["x"].property.identifier == identifier + assert add.task().inputs["x"]._socket_identifier == identifier + assert add.task().inputs["x"].socket_property.identifier == identifier add_task = add.task() add_task.set({"x": data}) # test set data from context @@ -129,9 +129,8 @@ def test(a, b=1, **kwargs): return {"sum": a + b, "product": a * b} test1 = test.node() - assert test1.inputs["kwargs"].link_limit == 1e6 - assert test1.inputs["kwargs"].identifier == "workgraph.namespace" - assert test1.inputs["kwargs"].property.value is None + assert test1.inputs["kwargs"].socket_link_limit == 1e6 + assert test1.inputs["kwargs"]._socket_identifier == "workgraph.namespace" @pytest.mark.parametrize(