
Commit

Use type_mapping when creating a new socket
superstar54 committed Dec 10, 2024
1 parent 5e38ab4 commit 7120c0a
Showing 11 changed files with 70 additions and 111 deletions.
19 changes: 2 additions & 17 deletions src/aiida_workgraph/decorator.py
@@ -3,14 +3,15 @@
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
from aiida_workgraph.utils import get_executor
from aiida.engine import calcfunction, workfunction, CalcJob, WorkChain
from aiida import orm
from aiida.orm.nodes.process.calculation.calcfunction import CalcFunctionNode
from aiida.orm.nodes.process.workflow.workfunction import WorkFunctionNode
from aiida.engine.processes.ports import PortNamespace
from aiida_workgraph.task import Task
from aiida_workgraph.utils import build_callable, validate_task_inout
import inspect
from aiida_workgraph.config import builtin_inputs, builtin_outputs
from aiida_workgraph.orm.mapping import type_mapping


task_types = {
CalcFunctionNode: "CALCFUNCTION",
@@ -19,22 +20,6 @@
WorkChain: "WORKCHAIN",
}

type_mapping = {
"default": "workgraph.any",
"namespace": "workgraph.namespace",
int: "workgraph.int",
float: "workgraph.float",
str: "workgraph.string",
bool: "workgraph.bool",
orm.Int: "workgraph.aiida_int",
orm.Float: "workgraph.aiida_float",
orm.Str: "workgraph.aiida_string",
orm.Bool: "workgraph.aiida_bool",
orm.List: "workgraph.aiida_list",
orm.Dict: "workgraph.aiida_dict",
orm.StructureData: "workgraph.aiida_structuredata",
}


def create_task(tdata):
"""Wrap create_node from node_graph to create a Task."""
10 changes: 7 additions & 3 deletions src/aiida_workgraph/engine/task_manager.py
@@ -41,7 +41,7 @@ def get_task(self, name: str):
for output in task.outputs:
output.value = get_nested_dict(
self.ctx._tasks[name]["results"],
output.name,
output.socket_name,
default=output.value,
)
return task
@@ -734,7 +734,9 @@ def update_normal_task_state(self, name, results, success=True):
"""Set the results of a normal task.
A normal task is created by decorating a function with @task().
"""
from aiida_workgraph.utils import get_sorted_names
from aiida_workgraph.config import builtin_outputs

builtin_output_names = [output["name"] for output in builtin_outputs]

if success:
task = self.ctx._tasks[name]
Expand All @@ -743,7 +745,9 @@ def update_normal_task_state(self, name, results, success=True):
if len(task["outputs"]) - 2 != len(results):
self.on_task_failed(name)
return self.process.exit_codes.OUTPUS_NOT_MATCH_RESULTS
output_names = get_sorted_names(task["outputs"])[0:-2]
output_names = [
name for name in task["outputs"] if name not in builtin_output_names
]
for i, output_name in enumerate(output_names):
task["results"][output_name] = results[i]
elif isinstance(results, dict):
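The hunk above replaces the index-based `get_sorted_names` ordering with a filter against the built-in output names. A minimal standalone sketch of that filter, assuming `builtin_outputs` is a list of dicts carrying a `"name"` key and that the built-in outputs are `_wait` and `_outputs` (both assumptions, not taken from this diff):

```python
# Sketch only: mirrors the filtering introduced in update_normal_task_state.
# The concrete entries of builtin_outputs are an assumption here.
builtin_outputs = [{"name": "_wait"}, {"name": "_outputs"}]
builtin_output_names = [output["name"] for output in builtin_outputs]

# Hypothetical task outputs, in declaration order.
task_outputs = ["sum", "diff", "_wait", "_outputs"]

# Positional results map onto the user-defined outputs only.
output_names = [name for name in task_outputs if name not in builtin_output_names]
assert output_names == ["sum", "diff"]
```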
18 changes: 18 additions & 0 deletions src/aiida_workgraph/orm/mapping.py
@@ -0,0 +1,18 @@
from aiida import orm


type_mapping = {
"default": "workgraph.any",
"namespace": "workgraph.namespace",
int: "workgraph.int",
float: "workgraph.float",
str: "workgraph.string",
bool: "workgraph.bool",
orm.Int: "workgraph.aiida_int",
orm.Float: "workgraph.aiida_float",
orm.Str: "workgraph.aiida_string",
orm.Bool: "workgraph.aiida_bool",
orm.List: "workgraph.aiida_list",
orm.Dict: "workgraph.aiida_dict",
orm.StructureData: "workgraph.aiida_structuredata",
}
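The new `aiida_workgraph/orm/mapping.py` centralises the type-to-socket mapping so that other modules (the decorator and the socket namespace below) can import it. As orientation, a minimal sketch of the kind of lookup such a mapping supports; the helper `resolve_socket_identifier` is hypothetical and not part of this commit:

```python
from aiida import orm
from aiida_workgraph.orm.mapping import type_mapping


def resolve_socket_identifier(annotation) -> str:
    """Return the workgraph socket identifier for a Python or AiiDA type,
    falling back to the generic ``workgraph.any`` identifier."""
    return type_mapping.get(annotation, type_mapping["default"])


assert resolve_socket_identifier(int) == "workgraph.int"
assert resolve_socket_identifier(orm.Dict) == "workgraph.aiida_dict"
assert resolve_socket_identifier(tuple) == "workgraph.any"  # unmapped type
```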
7 changes: 6 additions & 1 deletion src/aiida_workgraph/socket.py
@@ -1,9 +1,13 @@
from typing import Any, Type

from aiida import orm
from node_graph.socket import NodeSocket, NodeSocketNamespace
from node_graph.socket import (
NodeSocket,
NodeSocketNamespace,
)

from aiida_workgraph.property import TaskProperty
from aiida_workgraph.orm.mapping import type_mapping


class TaskSocket(NodeSocket):
@@ -35,6 +39,7 @@ class TaskSocketNamespace(NodeSocketNamespace):

_socket_identifier = "workgraph.namespace"
_socket_property_class = TaskProperty
_type_mapping: dict = type_mapping

def __init__(self, *args, **kwargs):
super().__init__(*args, entry_point="aiida_workgraph.socket", **kwargs)
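With `_type_mapping` attached to `TaskSocketNamespace`, code that creates sockets under the namespace can resolve identifiers from the class itself instead of importing a module-level constant. A small sketch of that lookup (illustration only, assuming an installed `aiida-workgraph`):

```python
from aiida_workgraph.socket import TaskSocketNamespace

# The namespace class now carries the shared mapping, so a new socket's
# identifier can be derived directly from the value's Python type.
identifier = TaskSocketNamespace._type_mapping.get(float, "workgraph.any")
assert identifier == "workgraph.float"
```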
48 changes: 3 additions & 45 deletions src/aiida_workgraph/task.py
@@ -83,48 +83,6 @@ def set_context(self, context: Dict[str, Any]) -> None:
raise ValueError(msg)
self.context_mapping.update(context)

def set(self, data: Dict[str, Any]) -> None:
from node_graph.socket import NodeSocket

super().set(data)

def process_nested_inputs(
base_key: str, value: Any, dynamic: bool = False
) -> None:
"""Recursive function to process nested inputs.
Creates sockets and links dynamically for nested values.
"""
if isinstance(value, dict):
keys = list(value.keys())
for sub_key in keys:
sub_value = value[sub_key]
# Form the full key for the current nested level
full_key = f"{base_key}.{sub_key}" if base_key else sub_key

# Create a new input socket if it does not exist
if full_key not in self.get_input_names() and dynamic:
self.add_input(
"workgraph.any",
name=full_key,
metadata={"required": True},
)
if isinstance(sub_value, NodeSocket):
self.parent.links.new(sub_value, self.inputs[full_key])
value.pop(sub_key)
else:
# Recursively process nested dictionaries
process_nested_inputs(full_key, sub_value, dynamic)

# create input sockets and links for items inside a dynamic socket
# TODO the input value could be nested, but we only support one level for now
for key in data:
if self.inputs[key]._socket_identifier == "workgraph.namespace":
process_nested_inputs(
key,
self.inputs[key].value,
dynamic=self.inputs[key].metadata.get("dynamic", False),
)

def set_from_builder(self, builder: Any) -> None:
"""Set the task inputs from a AiiDA ProcessBuilder."""
from aiida_workgraph.utils import get_dict_from_builder
@@ -230,7 +188,7 @@ def to_widget_value(self):
for key in ("properties", "executor", "node_class", "process"):
tdata.pop(key, None)
for input in tdata["inputs"].values():
input.pop("property")
input.pop("property", None)

tdata["label"] = tdata["identifier"]

@@ -289,9 +247,9 @@ def _normalize_tasks(
task_objects = []
for task in tasks:
if isinstance(task, str):
if task not in self.graph.tasks.keys():
if task not in self.graph.tasks:
raise ValueError(
f"Task '{task}' is not in the graph. Available tasks: {self.graph.tasks.keys()}"
f"Task '{task}' is not in the graph. Available tasks: {self.graph.tasks}"
)
task_objects.append(self.graph.tasks[task])
elif isinstance(task, Task):
15 changes: 1 addition & 14 deletions src/aiida_workgraph/utils/__init__.py
@@ -42,18 +42,6 @@ def build_callable(obj: Callable) -> Dict[str, Any]:
return executor


def get_sorted_names(data: dict) -> list[str]:
"""Get the sorted names from a dictionary."""
sorted_names = [
name
for name, _ in sorted(
((name, item["list_index"]) for name, item in data.items()),
key=lambda x: x[1],
)
]
return sorted_names


def store_nodes_recursely(data: Any) -> None:
"""Recurse through a data structure and store any unstored nodes that are found along the way
:param data: a data structure potentially containing unstored nodes
@@ -615,7 +603,6 @@ def validate_task_inout(inout_list: list[str | dict], list_type: str) -> list[di
:raises TypeError: If wrong types are provided to the task
:return: Processed `inputs`/`outputs` list.
"""
from node_graph.utils import list_to_dict

if not all(isinstance(item, (dict, str)) for item in inout_list):
raise TypeError(
@@ -630,7 +617,7 @@
elif isinstance(item, dict):
processed_inout_list.append(item)

processed_inout_list = list_to_dict(processed_inout_list)
processed_inout_list = processed_inout_list

return processed_inout_list

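With `list_to_dict` dropped, `validate_task_inout` now returns the processed list directly. A sketch of the resulting shape, assuming the string branch (not visible in this hunk) wraps bare names as `{"name": ...}` dicts:

```python
# Illustration of the list-of-dicts shape returned after this change.
inout_list = ["sum", {"name": "diff", "identifier": "workgraph.any"}]

processed_inout_list = []
for item in inout_list:
    if isinstance(item, str):
        # Assumed wrapping for bare names; the real branch is outside the hunk.
        processed_inout_list.append({"name": item})
    elif isinstance(item, dict):
        processed_inout_list.append(item)

assert processed_inout_list[0] == {"name": "sum"}
```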
8 changes: 3 additions & 5 deletions src/aiida_workgraph/workgraph.py
@@ -294,12 +294,10 @@ def update(self) -> None:
# even if the node.is_finished_ok is True
if node.is_finished_ok:
# update the output sockets
i = 0
for socket in self.tasks[name].outputs:
socket.value = get_nested_dict(
socket.socket_value = get_nested_dict(
node.outputs, socket.socket_name, default=None
)
i += 1
# read results from the process outputs
elif isinstance(node, aiida.orm.Data):
self.tasks[name].outputs[0].value = node
@@ -473,13 +471,13 @@ def extend(self, wg: "WorkGraph", prefix: str = "") -> None:
for task in wg.tasks:
task.name = prefix + task.name
task.parent = self
self.tasks.append(task)
self.tasks._append(task)
# self.sequence.extend([prefix + task for task in wg.sequence])
# self.conditions.extend(wg.conditions)
self.context.update(wg.context)
# links
for link in wg.links:
self.links.append(link)
self.links._append(link)

@property
def error_handlers(self) -> Dict[str, Any]:
39 changes: 21 additions & 18 deletions tests/test_tasks.py
@@ -18,7 +18,9 @@ def sum_diff(x, y):
decorated_add, name="add", x=task1.outputs["sum"], y=task1.outputs["diff"]
)
wg.run()
assert task2.outputs["result"].value == 4
print("node: ", task2.node.outputs.result)
wg.update()
assert task2.outputs["result"].socket_value == 4


def test_task_collection(decorated_add: Callable) -> None:
@@ -68,8 +70,8 @@ def test_set_non_dynamic_namespace_socket(decorated_add) -> None:
"non_dynamic_port": {"a": task1.outputs["result"], "b": orm.Int(2)},
}
)
assert len(task2.inputs["non_dynamic_port.a"].links) == 1
assert task2.inputs["non_dynamic_port"].value == {"b": orm.Int(2)}
assert len(task2.inputs["non_dynamic_port.a"].socket_links) == 1
assert task2.inputs["non_dynamic_port"].socket_value == {"b": orm.Int(2)}
assert len(wg.links) == 1


@@ -85,8 +87,12 @@ def test_set_namespace_socket(decorated_add) -> None:
"add": {"x": task1.outputs["result"], "y": orm.Int(2)},
}
)
assert len(task2.inputs["add.x"].links) == 1
assert task2.inputs["add"].value == {"y": orm.Int(2)}
assert len(task2.inputs["add.x"].socket_links) == 1
assert task2.inputs["add"].socket_value == {
"metadata": {"options": {"stash": {}}},
"monitors": {},
"y": orm.Int(2),
}
assert len(wg.links) == 1


@@ -110,14 +116,13 @@ def test_set_dynamic_port_input(decorated_add) -> None:
)
wg.add_link(task1.outputs["_wait"], task2.inputs["dynamic_port.input1"])
# task will create input for each item in the dynamic port (nodes)
assert "dynamic_port.input1" in task2.get_input_names()
assert "dynamic_port.input2" in task2.get_input_names()
assert "dynamic_port.input1" in task2.inputs
assert "dynamic_port.input2" in task2.inputs
# if the value of the item is a Socket, then it will create a link, and pop the item
assert "dynamic_port.input3" in task2.get_input_names()
assert "dynamic_port.nested.input4" in task2.get_input_names()
assert "dynamic_port.nested.input5" in task2.get_input_names()
assert task2.inputs["dynamic_port"].value == {
"input1": None,
assert "dynamic_port.input3" in task2.inputs
assert "dynamic_port.nested.input4" in task2.inputs
assert "dynamic_port.nested.input5" in task2.inputs
assert task2.inputs["dynamic_port"].socket_value == {
"input2": orm.Int(2),
"nested": {"input4": orm.Int(4)},
}
@@ -133,9 +138,7 @@ def test_set_inputs(decorated_add: Callable) -> None:
data = wg.prepare_inputs(metadata=None)
assert data["wg"]["tasks"]["add1"]["inputs"]["y"]["property"]["value"] == 2
assert (
data["wg"]["tasks"]["add1"]["inputs"]["metadata"]["property"]["value"][
"store_provenance"
]
data["wg"]["tasks"]["add1"]["inputs"]["metadata"]["value"]["store_provenance"]
is False
)

@@ -151,9 +154,9 @@ def test_set_inputs_from_builder(add_code) -> None:
builder.x = 1
builder.y = 2
add1.set_from_builder(builder)
assert add1.inputs["x"].value == 1
assert add1.inputs["y"].value == 2
assert add1.inputs["code"].value == add_code
assert add1.inputs["x"].socket_value == 1
assert add1.inputs["y"].socket_value == 2
assert add1.inputs["code"].socket_value == add_code
with pytest.raises(
AttributeError,
match=f"Executor {ArithmeticAddCalculation.__name__} does not have the get_builder_from_protocol method.",
2 changes: 1 addition & 1 deletion tests/test_workchain.py
@@ -19,7 +19,7 @@ def test_build_workchain_inputs_outputs():
node = build_task(MultiplyAddWorkChain)()
inputs = MultiplyAddWorkChain.spec().inputs
# inputs + metadata + _wait
ninput = len(inputs.ports) + len(inputs.ports["metadata"].ports) + 1
ninput = len(inputs.ports) + 1
assert len(node.inputs) == ninput
assert len(node.outputs) == 3

5 changes: 3 additions & 2 deletions tests/test_workgraph.py
@@ -96,11 +96,12 @@ def test_organize_nested_inputs():
data = {
"metadata": {
"call_link_label": "nest",
"options": {"resources": {"num_cpus": 1, "num_machines": 1}},
"options": {"resources": {"num_machines": 1}, "stash": {}},
},
"monitors": {},
"x": "1",
}
assert inputs["wg"]["tasks"]["task1"]["inputs"]["add"]["property"]["value"] == data
assert inputs["wg"]["tasks"]["task1"]["inputs"]["add"]["value"] == data


@pytest.mark.usefixtures("started_daemon_client")
10 changes: 5 additions & 5 deletions tests/test_yaml.py
@@ -7,16 +7,16 @@

def test_calcfunction():
wg = WorkGraph.from_yaml(os.path.join(cwd, "datas/test_calcfunction.yaml"))
assert wg.tasks["float1"].inputs["value"].value == 3.0
assert wg.tasks["sumdiff1"].inputs["x"].value == 2.0
assert wg.tasks["sumdiff2"].inputs["x"].value == 4.0
assert wg.tasks.float1.inputs.value.socket_value == 3.0
assert wg.tasks.sumdiff1.inputs.x.socket_value == 2.0
assert wg.tasks.sumdiff2.inputs.x.socket_value == 4.0
wg.run()
assert wg.tasks["sumdiff2"].node.outputs.sum == 9
assert wg.tasks.sumdiff2.node.outputs.sum == 9


# skip this test for now
@pytest.mark.skip(reason="need to fix the identifier for a node from build_task")
def test_calcjob():
wg = WorkGraph.from_yaml(os.path.join(cwd, "datas/test_calcjob.yaml"))
wg.submit(wait=True)
assert wg.tasks["add2"].node.outputs.sum == 9
assert wg.tasks.add2.node.outputs.sum == 9
