Skip to content

Commit

Permalink
change task's data inputs, outputs and properties to dict
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 committed Dec 10, 2024
1 parent c74bf4e commit 5e38ab4
Show file tree
Hide file tree
Showing 10 changed files with 100 additions and 55 deletions.
4 changes: 3 additions & 1 deletion src/aiida_workgraph/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
WORKGRAPH_SHORT_EXTRA_KEY = "_workgraph_short"


builtin_inputs = [{"name": "_wait", "link_limit": 1e6, "arg_type": "none"}]
builtin_inputs = [
{"name": "_wait", "link_limit": 1e6, "metadata": {"arg_type": "none"}}
]
builtin_outputs = [{"name": "_wait"}, {"name": "_outputs"}]


Expand Down
18 changes: 12 additions & 6 deletions src/aiida_workgraph/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,14 @@
def create_task(tdata):
"""Wrap create_node from node_graph to create a Task."""
from node_graph.decorator import create_node
from node_graph.utils import list_to_dict

tdata["type_mapping"] = type_mapping
tdata["metadata"]["node_type"] = tdata["metadata"].pop("task_type")
tdata["properties"] = list_to_dict(tdata.get("properties", {}))
tdata["inputs"] = list_to_dict(tdata.get("inputs", {}))
tdata["outputs"] = list_to_dict(tdata.get("outputs", {}))

return create_node(tdata)


Expand All @@ -67,8 +72,11 @@ def add_input_recursive(
{
"identifier": "workgraph.namespace",
"name": port_name,
"arg_type": "kwargs",
"metadata": {"required": required, "dynamic": port.dynamic},
"metadata": {
"arg_type": "kwargs",
"required": required,
"dynamic": port.dynamic,
},
}
)
for value in port.values():
Expand All @@ -87,8 +95,7 @@ def add_input_recursive(
{
"identifier": socket_type,
"name": port_name,
"arg_type": "kwargs",
"metadata": {"required": required},
"metadata": {"arg_type": "kwargs", "required": required},
}
)
return inputs
Expand Down Expand Up @@ -249,8 +256,7 @@ def build_task_from_AiiDA(
{
"identifier": "workgraph.namespace",
"name": name,
"arg_type": "var_kwargs",
"metadata": {"dynamic": True},
"metadata": {"arg_type": "var_kwargs", "dynamic": True},
}
)

Expand Down
9 changes: 4 additions & 5 deletions src/aiida_workgraph/engine/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,11 +314,10 @@ def read_wgdata_from_base(self) -> t.Dict[str, t.Any]:
for name, task in wgdata["tasks"].items():
wgdata["tasks"][name] = deserialize_unsafe(task)
for _, input in wgdata["tasks"][name]["inputs"].items():
if input["property"] is None:
continue
prop = input["property"]
if isinstance(prop["value"], PickledLocalFunction):
prop["value"] = prop["value"].value
if input.get("property"):
prop = input["property"]
if isinstance(prop["value"], PickledLocalFunction):
prop["value"] = prop["value"].value
wgdata["error_handlers"] = deserialize_unsafe(wgdata["error_handlers"])
wgdata["context"] = deserialize_unsafe(wgdata["context"])
return wgdata
Expand Down
2 changes: 1 addition & 1 deletion src/aiida_workgraph/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def process_nested_inputs(
# create input sockets and links for items inside a dynamic socket
# TODO the input value could be nested, but we only support one level for now
for key in data:
if self.inputs[key].identifier == "workgraph.namespace":
if self.inputs[key]._socket_identifier == "workgraph.namespace":
process_nested_inputs(
key,
self.inputs[key].value,
Expand Down
52 changes: 39 additions & 13 deletions src/aiida_workgraph/tasks/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ def __init__(self, *args, **kwargs):
def create_sockets(self) -> None:
self.inputs._clear()
self.outputs._clear()
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
self.add_input(
"workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"}
)
self.add_output("workgraph.any", "_wait")

def to_dict(self, short: bool = False) -> Dict[str, Any]:
Expand All @@ -43,7 +45,9 @@ class While(Zone):
def create_sockets(self) -> None:
self.inputs._clear()
self.outputs._clear()
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
self.add_input(
"workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"}
)
self.add_input(
"node_graph.int", "max_iterations", property_data={"default": 10000}
)
Expand All @@ -62,7 +66,9 @@ class If(Zone):
def create_sockets(self) -> None:
self.inputs._clear()
self.outputs._clear()
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
self.add_input(
"workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"}
)
self.add_input("workgraph.any", "conditions")
self.add_input("workgraph.any", "invert_condition")
self.add_output("workgraph.any", "_wait")
Expand All @@ -81,7 +87,9 @@ def create_sockets(self) -> None:
self.outputs._clear()
self.add_input("workgraph.any", "key")
self.add_input("workgraph.any", "value")
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
self.add_input(
"workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"}
)
self.add_output("workgraph.any", "_wait")


Expand All @@ -97,7 +105,9 @@ def create_sockets(self) -> None:
self.inputs._clear()
self.outputs._clear()
self.add_input("workgraph.any", "key")
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
self.add_input(
"workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"}
)
self.add_output("workgraph.any", "result")
self.add_output("workgraph.any", "_wait")

Expand All @@ -115,7 +125,9 @@ class AiiDAInt(Task):

def create_sockets(self) -> None:
self.add_input("workgraph.any", "value", property_data={"default": 0.0})
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
self.add_input(
"workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"}
)
self.add_output("workgraph.aiida_int", "result")
self.add_output("workgraph.any", "_wait")

Expand All @@ -135,7 +147,9 @@ def create_sockets(self) -> None:
self.inputs._clear()
self.outputs._clear()
self.add_input("workgraph.float", "value", property_data={"default": 0.0})
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
self.add_input(
"workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"}
)
self.add_output("workgraph.aiida_float", "result")
self.add_output("workgraph.any", "_wait")

Expand All @@ -155,7 +169,9 @@ def create_sockets(self) -> None:
self.inputs._clear()
self.outputs._clear()
self.add_input("workgraph.string", "value", property_data={"default": ""})
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
self.add_input(
"workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"}
)
self.add_output("workgraph.aiida_string", "result")
self.add_output("workgraph.any", "_wait")

Expand All @@ -175,7 +191,9 @@ def create_sockets(self) -> None:
self.inputs._clear()
self.outputs._clear()
self.add_input("workgraph.any", "value", property_data={"default": []})
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
self.add_input(
"workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"}
)
self.add_output("workgraph.aiida_list", "result")
self.add_output("workgraph.any", "_wait")

Expand All @@ -195,7 +213,9 @@ def create_sockets(self) -> None:
self.inputs._clear()
self.outputs._clear()
self.add_input("workgraph.any", "value", property_data={"default": {}})
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
self.add_input(
"workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"}
)
self.add_output("workgraph.aiida_dict", "result")
self.add_output("workgraph.any", "_wait")

Expand Down Expand Up @@ -223,7 +243,9 @@ def create_sockets(self) -> None:
self.add_input("workgraph.any", "pk")
self.add_input("workgraph.any", "uuid")
self.add_input("workgraph.any", "label")
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
self.add_input(
"workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"}
)
self.add_output("workgraph.any", "node")
self.add_output("workgraph.any", "_wait")

Expand All @@ -248,7 +270,9 @@ def create_sockets(self) -> None:
self.add_input("workgraph.any", "pk")
self.add_input("workgraph.any", "uuid")
self.add_input("workgraph.any", "label")
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
self.add_input(
"workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"}
)
self.add_output("workgraph.any", "Code")
self.add_output("workgraph.any", "_wait")

Expand All @@ -272,6 +296,8 @@ def create_sockets(self) -> None:
self.add_input("workgraph.any", "condition")
self.add_input("workgraph.any", "true")
self.add_input("workgraph.any", "false")
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
self.add_input(
"workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"}
)
self.add_output("workgraph.any", "result")
self.add_output("workgraph.any", "_wait")
12 changes: 9 additions & 3 deletions src/aiida_workgraph/tasks/monitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ def create_sockets(self) -> None:
inp.add_property("workgraph.any", default=1.0)
inp = self.add_input("workgraph.any", "timeout")
inp.add_property("workgraph.any", default=86400.0)
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
self.add_input(
"workgraph.any", "_wait", metadata={"arg_type": "none"}, link_limit=100000
)
inp.socket_link_limit = 100000
self.add_output("workgraph.any", "result")
self.add_output("workgraph.any", "_wait")
Expand All @@ -49,7 +51,9 @@ def create_sockets(self) -> None:
inp.add_property("workgraph.any", default=1.0)
inp = self.add_input("workgraph.any", "timeout")
inp.add_property("workgraph.any", default=86400.0)
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
self.add_input(
"workgraph.any", "_wait", metadata={"arg_type": "none"}, link_limit=100000
)
inp.socket_link_limit = 100000
self.add_output("workgraph.any", "result")
self.add_output("workgraph.any", "_wait")
Expand Down Expand Up @@ -78,7 +82,9 @@ def create_sockets(self) -> None:
inp.add_property("workgraph.any", default=1.0)
inp = self.add_input("workgraph.any", "timeout")
inp.add_property("workgraph.any", default=86400.0)
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
self.add_input(
"workgraph.any", "_wait", metadata={"arg_type": "none"}, link_limit=100000
)
inp.socket_link_limit = 100000
self.add_output("workgraph.any", "result")
self.add_output("workgraph.any", "_wait")
12 changes: 9 additions & 3 deletions src/aiida_workgraph/tasks/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ def create_sockets(self) -> None:
inp.add_property("workgraph.aiida_float", "x", default=0.0)
inp = self.add_input("workgraph.aiida_float", "y")
inp.add_property("workgraph.aiida_float", "y", default=0.0)
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
self.add_input(
"workgraph.any", "_wait", metadata={"arg_type": "none"}, link_limit=100000
)
self.add_output("workgraph.aiida_float", "sum")
self.add_output("workgraph.any", "_wait")
self.add_output("workgraph.any", "_outputs")
Expand Down Expand Up @@ -51,7 +53,9 @@ def create_sockets(self) -> None:
inp.add_property("workgraph.aiida_float", "x", default=0.0)
inp = self.add_input("workgraph.aiida_float", "y")
inp.add_property("workgraph.aiida_float", "y", default=0.0)
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
self.add_input(
"workgraph.any", "_wait", metadata={"arg_type": "none"}, link_limit=100000
)
self.add_output("workgraph.aiida_float", "sum")
self.add_output("workgraph.aiida_float", "diff")
self.add_output("workgraph.any", "_wait")
Expand Down Expand Up @@ -83,7 +87,9 @@ def create_sockets(self) -> None:
inp.add_property("workgraph.aiida_int", "y", default=0.0)
inp = self.add_input("workgraph.aiida_int", "z")
inp.add_property("workgraph.aiida_int", "z", default=0.0)
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
self.add_input(
"workgraph.any", "_wait", metadata={"arg_type": "none"}, link_limit=100000
)
self.add_output("workgraph.aiida_int", "result")
self.add_output("workgraph.any", "_wait")
self.add_output("workgraph.any", "_outputs")
35 changes: 18 additions & 17 deletions src/aiida_workgraph/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,17 +252,16 @@ def organize_nested_inputs(wgdata: Dict[str, Any]) -> None:
update_nested_dict(root_prop["value"], key, prop["value"])
prop["value"] = None
for key, input in task["inputs"].items():
if input["property"] is None:
continue
prop = input["property"]
if "." in key and prop["value"] not in [None, {}]:
root, key = key.split(".", 1)
root_prop = task["inputs"][root]["property"]
# update the root property
root_prop["value"] = update_nested_dict(
root_prop["value"], key, prop["value"]
)
prop["value"] = None
if input.get("property"):
prop = input["property"]
if "." in key and prop["value"] not in [None, {}]:
root, key = key.split(".", 1)
root_prop = task["inputs"][root]["property"]
# update the root property
root_prop["value"] = update_nested_dict(
root_prop["value"], key, prop["value"]
)
prop["value"] = None


def generate_node_graph(
Expand Down Expand Up @@ -466,11 +465,10 @@ def serialize_workgraph_inputs(wgdata):
if task["metadata"]["node_type"].upper() == "PYTHONJOB":
PythonJob.serialize_pythonjob_data(task)
for _, input in task["inputs"].items():
if input["property"] is None:
continue
prop = input["property"]
if inspect.isfunction(prop["value"]):
prop["value"] = PickledLocalFunction(prop["value"]).store()
if input.get("property"):
prop = input["property"]
if inspect.isfunction(prop["value"]):
prop["value"] = PickledLocalFunction(prop["value"]).store()
# error_handlers of the workgraph
for _, data in wgdata["error_handlers"].items():
if not data["handler"]["use_module_path"]:
Expand Down Expand Up @@ -548,7 +546,7 @@ def process_properties(task: Dict) -> Dict:
}
#
for name, input in task["inputs"].items():
if input["property"] is not None:
if input.get("property"):
prop = input["property"]
identifier = prop["identifier"]
value = prop.get("value")
Expand Down Expand Up @@ -617,6 +615,7 @@ def validate_task_inout(inout_list: list[str | dict], list_type: str) -> list[di
:raises TypeError: If wrong types are provided to the task
:return: Processed `inputs`/`outputs` list.
"""
from node_graph.utils import list_to_dict

if not all(isinstance(item, (dict, str)) for item in inout_list):
raise TypeError(
Expand All @@ -631,6 +630,8 @@ def validate_task_inout(inout_list: list[str | dict], list_type: str) -> list[di
elif isinstance(item, dict):
processed_inout_list.append(item)

processed_inout_list = list_to_dict(processed_inout_list)

return processed_inout_list


Expand Down
9 changes: 4 additions & 5 deletions src/aiida_workgraph/utils/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,10 @@ def insert_workgraph_to_db(self) -> None:
self.save_task_states()
for name, task in self.wgdata["tasks"].items():
for _, input in task["inputs"].items():
if input["property"] is None:
continue
prop = input["property"]
if inspect.isfunction(prop["value"]):
prop["value"] = PickledLocalFunction(prop["value"]).store()
if input.get("property"):
prop = input["property"]
if inspect.isfunction(prop["value"]):
prop["value"] = PickledLocalFunction(prop["value"]).store()
self.wgdata["tasks"][name] = serialize(task)
# nodes is a copy of tasks, so we need to pop it out
self.wgdata["error_handlers"] = serialize(self.wgdata["error_handlers"])
Expand Down
2 changes: 1 addition & 1 deletion src/aiida_workgraph/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def update(self) -> None:
i = 0
for socket in self.tasks[name].outputs:
socket.value = get_nested_dict(
node.outputs, socket.name, default=None
node.outputs, socket.socket_name, default=None
)
i += 1
# read results from the process outputs
Expand Down

0 comments on commit 5e38ab4

Please sign in to comment.