From 940e34463bd934cd566aaa5993b528baf686fdf9 Mon Sep 17 00:00:00 2001 From: superstar54 Date: Fri, 6 Dec 2024 14:08:38 +0100 Subject: [PATCH] Add test for setting inputs for non dynamic namespace --- src/aiida_workgraph/task.py | 17 +++++++++++------ tests/test_tasks.py | 17 +++++++++++++++++ tests/utils/test_workchain.py | 2 ++ 3 files changed, 30 insertions(+), 6 deletions(-) diff --git a/src/aiida_workgraph/task.py b/src/aiida_workgraph/task.py index 4e9dc571..882be22c 100644 --- a/src/aiida_workgraph/task.py +++ b/src/aiida_workgraph/task.py @@ -89,7 +89,9 @@ def set(self, data: Dict[str, Any]) -> None: super().set(data) - def process_nested_inputs(base_key: str, value: Any) -> None: + def process_nested_inputs( + base_key: str, value: Any, dynamic: bool = False + ) -> None: """Recursive function to process nested inputs. Creates sockets and links dynamically for nested values. """ @@ -101,25 +103,28 @@ def process_nested_inputs(base_key: str, value: Any) -> None: full_key = f"{base_key}.{sub_key}" if base_key else sub_key # Create a new input socket if it does not exist - if full_key not in self.inputs.keys(): + if full_key not in self.inputs.keys() and dynamic: self.inputs.new( "workgraph.any", name=full_key, metadata={"required": True}, ) - if isinstance(sub_value, NodeSocket): self.parent.links.new(sub_value, self.inputs[full_key]) value.pop(sub_key) else: # Recursively process nested dictionaries - process_nested_inputs(full_key, sub_value) + process_nested_inputs(full_key, sub_value, dynamic) # create input sockets and links for items inside a dynamic socket # TODO the input value could be nested, but we only support one level for now for key in data: - if self.inputs[key].metadata.get("dynamic", False): - process_nested_inputs(key, self.inputs[key].value) + if self.inputs[key].identifier == "workgraph.namespace": + process_nested_inputs( + key, + self.inputs[key].value, + dynamic=self.inputs[key].metadata.get("dynamic", False), + ) def set_from_builder(self, builder: Any) -> None: """Set the task inputs from a AiiDA ProcessBuilder.""" diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 75d70d6a..4044ecee 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -56,6 +56,23 @@ def test_task_wait(decorated_add: Callable) -> None: assert "tasks ready to run: add1" in report +def test_set_non_dynamic_namespace_socket(decorated_add) -> None: + """Test setting the namespace of a task.""" + from .utils.test_workchain import WorkChainWithNestNamespace + + wg = WorkGraph(name="test_set_namespace") + task1 = wg.add_task(decorated_add) + task2 = wg.add_task(WorkChainWithNestNamespace) + task2.set( + { + "non_dynamic_port": {"a": task1.outputs["result"], "b": orm.Int(2)}, + } + ) + assert len(task2.inputs["non_dynamic_port.a"].links) == 1 + assert task2.inputs["non_dynamic_port"].value == {"b": orm.Int(2)} + assert len(wg.links) == 1 + + def test_set_namespace_socket(decorated_add) -> None: """Test setting the namespace of a task.""" from .utils.test_workchain import WorkChainWithNestNamespace diff --git a/tests/utils/test_workchain.py b/tests/utils/test_workchain.py index 4e745a17..bbac9711 100644 --- a/tests/utils/test_workchain.py +++ b/tests/utils/test_workchain.py @@ -13,6 +13,8 @@ def define(cls, spec): """Specify inputs and outputs.""" super().define(spec) spec.input_namespace("non_dynamic_port") + spec.input("non_dynamic_port.a", valid_type=Int) + spec.input("non_dynamic_port.b", valid_type=Int) spec.input("non_dynamic_port.a") spec.expose_inputs( ArithmeticAddCalculation,