Skip to content

Commit

Permalink
Recursively process nested dictionaries when setting a dynamic port s…
Browse files Browse the repository at this point in the history
…ocket (#384)

Recursively process nested dictionaries, to create input sockets and links for items inside a dynamic socket,
  • Loading branch information
superstar54 authored Dec 6, 2024
1 parent 0f42d53 commit 1fb3dc9
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 17 deletions.
45 changes: 29 additions & 16 deletions src/aiida_workgraph/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,25 +88,38 @@ def set(self, data: Dict[str, Any]) -> None:
from node_graph.socket import NodeSocket

super().set(data)

def process_nested_inputs(base_key: str, value: Any) -> None:
"""Recursive function to process nested inputs.
Creates sockets and links dynamically for nested values.
"""
if isinstance(value, dict):
keys = list(value.keys())
for sub_key in keys:
sub_value = value[sub_key]
# Form the full key for the current nested level
full_key = f"{base_key}.{sub_key}" if base_key else sub_key

# Create a new input socket if it does not exist
if full_key not in self.inputs.keys():
self.inputs.new(
"workgraph.any",
name=full_key,
metadata={"required": True},
)

if isinstance(sub_value, NodeSocket):
self.parent.links.new(sub_value, self.inputs[full_key])
value.pop(sub_key)
else:
# Recursively process nested dictionaries
process_nested_inputs(full_key, sub_value)

# create input sockets and links for items inside a dynamic socket
# TODO the input value could be nested, but we only support one level for now
for key, value in data.items():
for key in data:
if self.inputs[key].metadata.get("dynamic", False):
if isinstance(value, dict):
keys = list(value.keys())
for sub_key in keys:
# create a new input socket if it does not exist
if f"{key}.{sub_key}" not in self.inputs.keys():
self.inputs.new(
"workgraph.any",
name=f"{key}.{sub_key}",
metadata={"required": True},
)
if isinstance(value[sub_key], NodeSocket):
self.parent.links.new(
value[sub_key], self.inputs[f"{key}.{sub_key}"]
)
self.inputs[key].value.pop(sub_key)
process_nested_inputs(key, self.inputs[key].value)

def set_from_builder(self, builder: Any) -> None:
"""Set the task inputs from a AiiDA ProcessBuilder."""
Expand Down
31 changes: 30 additions & 1 deletion tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,28 @@ def test_task_wait(decorated_add: Callable) -> None:
assert "tasks ready to run: add1" in report


def test_set_namespace_socket(decorated_add) -> None:
"""Test setting the namespace of a task."""
from .utils.test_workchain import WorkChainWithNestNamespace

wg = WorkGraph(name="test_set_namespace")
task1 = wg.add_task(decorated_add)
task2 = wg.add_task(WorkChainWithNestNamespace)
task2.set(
{
"add": {"x": task1.outputs["result"], "y": orm.Int(2)},
}
)
assert len(task2.inputs["add.x"].links) == 1
assert task2.inputs["add"].value == {"y": orm.Int(2)}
assert len(wg.links) == 1


def test_set_dynamic_port_input(decorated_add) -> None:
"""Test setting dynamic port input of a task.
Use can pass AiiDA nodes as values of the dynamic port,
and the task will create the input for each item in the dynamic port.
"""
from .utils.test_workchain import WorkChainWithDynamicNamespace

wg = WorkGraph(name="test_set_dynamic_port_input")
Expand All @@ -67,6 +88,7 @@ def test_set_dynamic_port_input(decorated_add) -> None:
"input1": None,
"input2": orm.Int(2),
"input3": task1.outputs["result"],
"nested": {"input4": orm.Int(4), "input5": task1.outputs["result"]},
},
)
wg.add_link(task1.outputs["_wait"], task2.inputs["dynamic_port.input1"])
Expand All @@ -75,7 +97,14 @@ def test_set_dynamic_port_input(decorated_add) -> None:
assert "dynamic_port.input2" in task2.inputs.keys()
# if the value of the item is a Socket, then it will create a link, and pop the item
assert "dynamic_port.input3" in task2.inputs.keys()
assert task2.inputs["dynamic_port"].value == {"input1": None, "input2": orm.Int(2)}
assert "dynamic_port.nested.input4" in task2.inputs.keys()
assert "dynamic_port.nested.input5" in task2.inputs.keys()
assert task2.inputs["dynamic_port"].value == {
"input1": None,
"input2": orm.Int(2),
"nested": {"input4": orm.Int(4)},
}
assert len(wg.links) == 3


def test_set_inputs(decorated_add: Callable) -> None:
Expand Down
2 changes: 2 additions & 0 deletions tests/utils/test_workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ class WorkChainWithNestNamespace(WorkChain):
def define(cls, spec):
"""Specify inputs and outputs."""
super().define(spec)
spec.input_namespace("non_dynamic_port")
spec.input("non_dynamic_port.a")
spec.expose_inputs(
ArithmeticAddCalculation,
namespace="add",
Expand Down

0 comments on commit 1fb3dc9

Please sign in to comment.