Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test for setting inputs for non dynamic namespace #386

Merged
merged 1 commit into from
Dec 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions src/aiida_workgraph/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ def set(self, data: Dict[str, Any]) -> None:

super().set(data)

def process_nested_inputs(base_key: str, value: Any) -> None:
def process_nested_inputs(
base_key: str, value: Any, dynamic: bool = False
) -> None:
"""Recursive function to process nested inputs.
Creates sockets and links dynamically for nested values.
"""
Expand All @@ -101,25 +103,28 @@ def process_nested_inputs(base_key: str, value: Any) -> None:
full_key = f"{base_key}.{sub_key}" if base_key else sub_key

# Create a new input socket if it does not exist
if full_key not in self.inputs.keys():
if full_key not in self.inputs.keys() and dynamic:
self.inputs.new(
"workgraph.any",
name=full_key,
metadata={"required": True},
)

if isinstance(sub_value, NodeSocket):
self.parent.links.new(sub_value, self.inputs[full_key])
value.pop(sub_key)
else:
# Recursively process nested dictionaries
process_nested_inputs(full_key, sub_value)
process_nested_inputs(full_key, sub_value, dynamic)

# create input sockets and links for items inside a dynamic socket
# TODO the input value could be nested, but we only support one level for now
for key in data:
if self.inputs[key].metadata.get("dynamic", False):
process_nested_inputs(key, self.inputs[key].value)
if self.inputs[key].identifier == "workgraph.namespace":
process_nested_inputs(
key,
self.inputs[key].value,
dynamic=self.inputs[key].metadata.get("dynamic", False),
)

def set_from_builder(self, builder: Any) -> None:
"""Set the task inputs from a AiiDA ProcessBuilder."""
Expand Down
17 changes: 17 additions & 0 deletions tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,23 @@ def test_task_wait(decorated_add: Callable) -> None:
assert "tasks ready to run: add1" in report


def test_set_non_dynamic_namespace_socket(decorated_add) -> None:
"""Test setting the namespace of a task."""
from .utils.test_workchain import WorkChainWithNestNamespace

wg = WorkGraph(name="test_set_namespace")
task1 = wg.add_task(decorated_add)
task2 = wg.add_task(WorkChainWithNestNamespace)
task2.set(
{
"non_dynamic_port": {"a": task1.outputs["result"], "b": orm.Int(2)},
}
)
assert len(task2.inputs["non_dynamic_port.a"].links) == 1
assert task2.inputs["non_dynamic_port"].value == {"b": orm.Int(2)}
assert len(wg.links) == 1


def test_set_namespace_socket(decorated_add) -> None:
"""Test setting the namespace of a task."""
from .utils.test_workchain import WorkChainWithNestNamespace
Expand Down
2 changes: 2 additions & 0 deletions tests/utils/test_workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ def define(cls, spec):
"""Specify inputs and outputs."""
super().define(spec)
spec.input_namespace("non_dynamic_port")
spec.input("non_dynamic_port.a", valid_type=Int)
spec.input("non_dynamic_port.b", valid_type=Int)
spec.input("non_dynamic_port.a")
spec.expose_inputs(
ArithmeticAddCalculation,
Expand Down
Loading