Skip to content

Commit

Permalink
Add test for setting inputs for non dynamic namespace
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 committed Dec 6, 2024
1 parent cdc4f3e commit 940e344
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 6 deletions.
17 changes: 11 additions & 6 deletions src/aiida_workgraph/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ def set(self, data: Dict[str, Any]) -> None:

super().set(data)

def process_nested_inputs(base_key: str, value: Any) -> None:
def process_nested_inputs(
base_key: str, value: Any, dynamic: bool = False
) -> None:
"""Recursive function to process nested inputs.
Creates sockets and links dynamically for nested values.
"""
Expand All @@ -101,25 +103,28 @@ def process_nested_inputs(base_key: str, value: Any) -> None:
full_key = f"{base_key}.{sub_key}" if base_key else sub_key

# Create a new input socket if it does not exist
if full_key not in self.inputs.keys():
if full_key not in self.inputs.keys() and dynamic:
self.inputs.new(
"workgraph.any",
name=full_key,
metadata={"required": True},
)

if isinstance(sub_value, NodeSocket):
self.parent.links.new(sub_value, self.inputs[full_key])
value.pop(sub_key)
else:
# Recursively process nested dictionaries
process_nested_inputs(full_key, sub_value)
process_nested_inputs(full_key, sub_value, dynamic)

# create input sockets and links for items inside a dynamic socket
# TODO the input value could be nested, but we only support one level for now
for key in data:
if self.inputs[key].metadata.get("dynamic", False):
process_nested_inputs(key, self.inputs[key].value)
if self.inputs[key].identifier == "workgraph.namespace":
process_nested_inputs(
key,
self.inputs[key].value,
dynamic=self.inputs[key].metadata.get("dynamic", False),
)

def set_from_builder(self, builder: Any) -> None:
"""Set the task inputs from a AiiDA ProcessBuilder."""
Expand Down
17 changes: 17 additions & 0 deletions tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,23 @@ def test_task_wait(decorated_add: Callable) -> None:
assert "tasks ready to run: add1" in report


def test_set_non_dynamic_namespace_socket(decorated_add) -> None:
"""Test setting the namespace of a task."""
from .utils.test_workchain import WorkChainWithNestNamespace

wg = WorkGraph(name="test_set_namespace")
task1 = wg.add_task(decorated_add)
task2 = wg.add_task(WorkChainWithNestNamespace)
task2.set(
{
"non_dynamic_port": {"a": task1.outputs["result"], "b": orm.Int(2)},
}
)
assert len(task2.inputs["non_dynamic_port.a"].links) == 1
assert task2.inputs["non_dynamic_port"].value == {"b": orm.Int(2)}
assert len(wg.links) == 1


def test_set_namespace_socket(decorated_add) -> None:
"""Test setting the namespace of a task."""
from .utils.test_workchain import WorkChainWithNestNamespace
Expand Down
2 changes: 2 additions & 0 deletions tests/utils/test_workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ def define(cls, spec):
"""Specify inputs and outputs."""
super().define(spec)
spec.input_namespace("non_dynamic_port")
spec.input("non_dynamic_port.a", valid_type=Int)
spec.input("non_dynamic_port.b", valid_type=Int)
spec.input("non_dynamic_port.a")
spec.expose_inputs(
ArithmeticAddCalculation,
Expand Down

0 comments on commit 940e344

Please sign in to comment.