diff --git a/src/aiida_workgraph/task.py b/src/aiida_workgraph/task.py index 9d61f642..4e9dc571 100644 --- a/src/aiida_workgraph/task.py +++ b/src/aiida_workgraph/task.py @@ -88,25 +88,38 @@ def set(self, data: Dict[str, Any]) -> None: from node_graph.socket import NodeSocket super().set(data) + + def process_nested_inputs(base_key: str, value: Any) -> None: + """Recursive function to process nested inputs. + Creates sockets and links dynamically for nested values. + """ + if isinstance(value, dict): + keys = list(value.keys()) + for sub_key in keys: + sub_value = value[sub_key] + # Form the full key for the current nested level + full_key = f"{base_key}.{sub_key}" if base_key else sub_key + + # Create a new input socket if it does not exist + if full_key not in self.inputs.keys(): + self.inputs.new( + "workgraph.any", + name=full_key, + metadata={"required": True}, + ) + + if isinstance(sub_value, NodeSocket): + self.parent.links.new(sub_value, self.inputs[full_key]) + value.pop(sub_key) + else: + # Recursively process nested dictionaries + process_nested_inputs(full_key, sub_value) + # create input sockets and links for items inside a dynamic socket # TODO the input value could be nested, but we only support one level for now - for key, value in data.items(): + for key in data: if self.inputs[key].metadata.get("dynamic", False): - if isinstance(value, dict): - keys = list(value.keys()) - for sub_key in keys: - # create a new input socket if it does not exist - if f"{key}.{sub_key}" not in self.inputs.keys(): - self.inputs.new( - "workgraph.any", - name=f"{key}.{sub_key}", - metadata={"required": True}, - ) - if isinstance(value[sub_key], NodeSocket): - self.parent.links.new( - value[sub_key], self.inputs[f"{key}.{sub_key}"] - ) - self.inputs[key].value.pop(sub_key) + process_nested_inputs(key, self.inputs[key].value) def set_from_builder(self, builder: Any) -> None: """Set the task inputs from a AiiDA ProcessBuilder.""" diff --git a/tests/test_tasks.py b/tests/test_tasks.py index ddf31281..75d70d6a 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -56,7 +56,28 @@ def test_task_wait(decorated_add: Callable) -> None: assert "tasks ready to run: add1" in report +def test_set_namespace_socket(decorated_add) -> None: + """Test setting the namespace of a task.""" + from .utils.test_workchain import WorkChainWithNestNamespace + + wg = WorkGraph(name="test_set_namespace") + task1 = wg.add_task(decorated_add) + task2 = wg.add_task(WorkChainWithNestNamespace) + task2.set( + { + "add": {"x": task1.outputs["result"], "y": orm.Int(2)}, + } + ) + assert len(task2.inputs["add.x"].links) == 1 + assert task2.inputs["add"].value == {"y": orm.Int(2)} + assert len(wg.links) == 1 + + def test_set_dynamic_port_input(decorated_add) -> None: + """Test setting dynamic port input of a task. + Use can pass AiiDA nodes as values of the dynamic port, + and the task will create the input for each item in the dynamic port. + """ from .utils.test_workchain import WorkChainWithDynamicNamespace wg = WorkGraph(name="test_set_dynamic_port_input") @@ -67,6 +88,7 @@ def test_set_dynamic_port_input(decorated_add) -> None: "input1": None, "input2": orm.Int(2), "input3": task1.outputs["result"], + "nested": {"input4": orm.Int(4), "input5": task1.outputs["result"]}, }, ) wg.add_link(task1.outputs["_wait"], task2.inputs["dynamic_port.input1"]) @@ -75,7 +97,14 @@ def test_set_dynamic_port_input(decorated_add) -> None: assert "dynamic_port.input2" in task2.inputs.keys() # if the value of the item is a Socket, then it will create a link, and pop the item assert "dynamic_port.input3" in task2.inputs.keys() - assert task2.inputs["dynamic_port"].value == {"input1": None, "input2": orm.Int(2)} + assert "dynamic_port.nested.input4" in task2.inputs.keys() + assert "dynamic_port.nested.input5" in task2.inputs.keys() + assert task2.inputs["dynamic_port"].value == { + "input1": None, + "input2": orm.Int(2), + "nested": {"input4": orm.Int(4)}, + } + assert len(wg.links) == 3 def test_set_inputs(decorated_add: Callable) -> None: diff --git a/tests/utils/test_workchain.py b/tests/utils/test_workchain.py index af702642..4e745a17 100644 --- a/tests/utils/test_workchain.py +++ b/tests/utils/test_workchain.py @@ -12,6 +12,8 @@ class WorkChainWithNestNamespace(WorkChain): def define(cls, spec): """Specify inputs and outputs.""" super().define(spec) + spec.input_namespace("non_dynamic_port") + spec.input("non_dynamic_port.a") spec.expose_inputs( ArithmeticAddCalculation, namespace="add",