diff --git a/src/aiida_workgraph/config.py b/src/aiida_workgraph/config.py index ab316c38..94876cd4 100644 --- a/src/aiida_workgraph/config.py +++ b/src/aiida_workgraph/config.py @@ -5,7 +5,9 @@ WORKGRAPH_SHORT_EXTRA_KEY = "_workgraph_short" -builtin_inputs = [{"name": "_wait", "link_limit": 1e6, "arg_type": "none"}] +builtin_inputs = [ + {"name": "_wait", "link_limit": 1e6, "metadata": {"arg_type": "none"}} +] builtin_outputs = [{"name": "_wait"}, {"name": "_outputs"}] diff --git a/src/aiida_workgraph/decorator.py b/src/aiida_workgraph/decorator.py index 626c2023..85c3812e 100644 --- a/src/aiida_workgraph/decorator.py +++ b/src/aiida_workgraph/decorator.py @@ -39,9 +39,14 @@ def create_task(tdata): """Wrap create_node from node_graph to create a Task.""" from node_graph.decorator import create_node + from node_graph.utils import list_to_dict tdata["type_mapping"] = type_mapping tdata["metadata"]["node_type"] = tdata["metadata"].pop("task_type") + tdata["properties"] = list_to_dict(tdata.get("properties", {})) + tdata["inputs"] = list_to_dict(tdata.get("inputs", {})) + tdata["outputs"] = list_to_dict(tdata.get("outputs", {})) + return create_node(tdata) @@ -67,8 +72,11 @@ def add_input_recursive( { "identifier": "workgraph.namespace", "name": port_name, - "arg_type": "kwargs", - "metadata": {"required": required, "dynamic": port.dynamic}, + "metadata": { + "arg_type": "kwargs", + "required": required, + "dynamic": port.dynamic, + }, } ) for value in port.values(): @@ -87,8 +95,7 @@ def add_input_recursive( { "identifier": socket_type, "name": port_name, - "arg_type": "kwargs", - "metadata": {"required": required}, + "metadata": {"arg_type": "kwargs", "required": required}, } ) return inputs @@ -249,8 +256,7 @@ def build_task_from_AiiDA( { "identifier": "workgraph.namespace", "name": name, - "arg_type": "var_kwargs", - "metadata": {"dynamic": True}, + "metadata": {"arg_type": "var_kwargs", "dynamic": True}, } ) diff --git a/src/aiida_workgraph/engine/workgraph.py b/src/aiida_workgraph/engine/workgraph.py index 582cad94..4590ecfb 100644 --- a/src/aiida_workgraph/engine/workgraph.py +++ b/src/aiida_workgraph/engine/workgraph.py @@ -314,11 +314,10 @@ def read_wgdata_from_base(self) -> t.Dict[str, t.Any]: for name, task in wgdata["tasks"].items(): wgdata["tasks"][name] = deserialize_unsafe(task) for _, input in wgdata["tasks"][name]["inputs"].items(): - if input["property"] is None: - continue - prop = input["property"] - if isinstance(prop["value"], PickledLocalFunction): - prop["value"] = prop["value"].value + if input.get("property"): + prop = input["property"] + if isinstance(prop["value"], PickledLocalFunction): + prop["value"] = prop["value"].value wgdata["error_handlers"] = deserialize_unsafe(wgdata["error_handlers"]) wgdata["context"] = deserialize_unsafe(wgdata["context"]) return wgdata diff --git a/src/aiida_workgraph/task.py b/src/aiida_workgraph/task.py index 2b550ffc..6926c88a 100644 --- a/src/aiida_workgraph/task.py +++ b/src/aiida_workgraph/task.py @@ -118,7 +118,7 @@ def process_nested_inputs( # create input sockets and links for items inside a dynamic socket # TODO the input value could be nested, but we only support one level for now for key in data: - if self.inputs[key].identifier == "workgraph.namespace": + if self.inputs[key]._socket_identifier == "workgraph.namespace": process_nested_inputs( key, self.inputs[key].value, diff --git a/src/aiida_workgraph/tasks/builtins.py b/src/aiida_workgraph/tasks/builtins.py index 92550035..03d83aca 100644 --- a/src/aiida_workgraph/tasks/builtins.py +++ b/src/aiida_workgraph/tasks/builtins.py @@ -19,7 +19,9 @@ def __init__(self, *args, **kwargs): def create_sockets(self) -> None: self.inputs._clear() self.outputs._clear() - self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000) + self.add_input( + "workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"} + ) self.add_output("workgraph.any", "_wait") def to_dict(self, short: bool = False) -> Dict[str, Any]: @@ -43,7 +45,9 @@ class While(Zone): def create_sockets(self) -> None: self.inputs._clear() self.outputs._clear() - self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000) + self.add_input( + "workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"} + ) self.add_input( "node_graph.int", "max_iterations", property_data={"default": 10000} ) @@ -62,7 +66,9 @@ class If(Zone): def create_sockets(self) -> None: self.inputs._clear() self.outputs._clear() - self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000) + self.add_input( + "workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"} + ) self.add_input("workgraph.any", "conditions") self.add_input("workgraph.any", "invert_condition") self.add_output("workgraph.any", "_wait") @@ -81,7 +87,9 @@ def create_sockets(self) -> None: self.outputs._clear() self.add_input("workgraph.any", "key") self.add_input("workgraph.any", "value") - self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000) + self.add_input( + "workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"} + ) self.add_output("workgraph.any", "_wait") @@ -97,7 +105,9 @@ def create_sockets(self) -> None: self.inputs._clear() self.outputs._clear() self.add_input("workgraph.any", "key") - self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000) + self.add_input( + "workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"} + ) self.add_output("workgraph.any", "result") self.add_output("workgraph.any", "_wait") @@ -115,7 +125,9 @@ class AiiDAInt(Task): def create_sockets(self) -> None: self.add_input("workgraph.any", "value", property_data={"default": 0.0}) - self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000) + self.add_input( + "workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"} + ) self.add_output("workgraph.aiida_int", "result") self.add_output("workgraph.any", "_wait") @@ -135,7 +147,9 @@ def create_sockets(self) -> None: self.inputs._clear() self.outputs._clear() self.add_input("workgraph.float", "value", property_data={"default": 0.0}) - self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000) + self.add_input( + "workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"} + ) self.add_output("workgraph.aiida_float", "result") self.add_output("workgraph.any", "_wait") @@ -155,7 +169,9 @@ def create_sockets(self) -> None: self.inputs._clear() self.outputs._clear() self.add_input("workgraph.string", "value", property_data={"default": ""}) - self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000) + self.add_input( + "workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"} + ) self.add_output("workgraph.aiida_string", "result") self.add_output("workgraph.any", "_wait") @@ -175,7 +191,9 @@ def create_sockets(self) -> None: self.inputs._clear() self.outputs._clear() self.add_input("workgraph.any", "value", property_data={"default": []}) - self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000) + self.add_input( + "workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"} + ) self.add_output("workgraph.aiida_list", "result") self.add_output("workgraph.any", "_wait") @@ -195,7 +213,9 @@ def create_sockets(self) -> None: self.inputs._clear() self.outputs._clear() self.add_input("workgraph.any", "value", property_data={"default": {}}) - self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000) + self.add_input( + "workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"} + ) self.add_output("workgraph.aiida_dict", "result") self.add_output("workgraph.any", "_wait") @@ -223,7 +243,9 @@ def create_sockets(self) -> None: self.add_input("workgraph.any", "pk") self.add_input("workgraph.any", "uuid") self.add_input("workgraph.any", "label") - self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000) + self.add_input( + "workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"} + ) self.add_output("workgraph.any", "node") self.add_output("workgraph.any", "_wait") @@ -248,7 +270,9 @@ def create_sockets(self) -> None: self.add_input("workgraph.any", "pk") self.add_input("workgraph.any", "uuid") self.add_input("workgraph.any", "label") - self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000) + self.add_input( + "workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"} + ) self.add_output("workgraph.any", "Code") self.add_output("workgraph.any", "_wait") @@ -272,6 +296,8 @@ def create_sockets(self) -> None: self.add_input("workgraph.any", "condition") self.add_input("workgraph.any", "true") self.add_input("workgraph.any", "false") - self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000) + self.add_input( + "workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"} + ) self.add_output("workgraph.any", "result") self.add_output("workgraph.any", "_wait") diff --git a/src/aiida_workgraph/tasks/monitors.py b/src/aiida_workgraph/tasks/monitors.py index 029741c3..9371ab3b 100644 --- a/src/aiida_workgraph/tasks/monitors.py +++ b/src/aiida_workgraph/tasks/monitors.py @@ -22,7 +22,9 @@ def create_sockets(self) -> None: inp.add_property("workgraph.any", default=1.0) inp = self.add_input("workgraph.any", "timeout") inp.add_property("workgraph.any", default=86400.0) - self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000) + self.add_input( + "workgraph.any", "_wait", metadata={"arg_type": "none"}, link_limit=100000 + ) inp.socket_link_limit = 100000 self.add_output("workgraph.any", "result") self.add_output("workgraph.any", "_wait") @@ -49,7 +51,9 @@ def create_sockets(self) -> None: inp.add_property("workgraph.any", default=1.0) inp = self.add_input("workgraph.any", "timeout") inp.add_property("workgraph.any", default=86400.0) - self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000) + self.add_input( + "workgraph.any", "_wait", metadata={"arg_type": "none"}, link_limit=100000 + ) inp.socket_link_limit = 100000 self.add_output("workgraph.any", "result") self.add_output("workgraph.any", "_wait") @@ -78,7 +82,9 @@ def create_sockets(self) -> None: inp.add_property("workgraph.any", default=1.0) inp = self.add_input("workgraph.any", "timeout") inp.add_property("workgraph.any", default=86400.0) - self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000) + self.add_input( + "workgraph.any", "_wait", metadata={"arg_type": "none"}, link_limit=100000 + ) inp.socket_link_limit = 100000 self.add_output("workgraph.any", "result") self.add_output("workgraph.any", "_wait") diff --git a/src/aiida_workgraph/tasks/test.py b/src/aiida_workgraph/tasks/test.py index 766812fb..77adeff9 100644 --- a/src/aiida_workgraph/tasks/test.py +++ b/src/aiida_workgraph/tasks/test.py @@ -23,7 +23,9 @@ def create_sockets(self) -> None: inp.add_property("workgraph.aiida_float", "x", default=0.0) inp = self.add_input("workgraph.aiida_float", "y") inp.add_property("workgraph.aiida_float", "y", default=0.0) - self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000) + self.add_input( + "workgraph.any", "_wait", metadata={"arg_type": "none"}, link_limit=100000 + ) self.add_output("workgraph.aiida_float", "sum") self.add_output("workgraph.any", "_wait") self.add_output("workgraph.any", "_outputs") @@ -51,7 +53,9 @@ def create_sockets(self) -> None: inp.add_property("workgraph.aiida_float", "x", default=0.0) inp = self.add_input("workgraph.aiida_float", "y") inp.add_property("workgraph.aiida_float", "y", default=0.0) - self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000) + self.add_input( + "workgraph.any", "_wait", metadata={"arg_type": "none"}, link_limit=100000 + ) self.add_output("workgraph.aiida_float", "sum") self.add_output("workgraph.aiida_float", "diff") self.add_output("workgraph.any", "_wait") @@ -83,7 +87,9 @@ def create_sockets(self) -> None: inp.add_property("workgraph.aiida_int", "y", default=0.0) inp = self.add_input("workgraph.aiida_int", "z") inp.add_property("workgraph.aiida_int", "z", default=0.0) - self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000) + self.add_input( + "workgraph.any", "_wait", metadata={"arg_type": "none"}, link_limit=100000 + ) self.add_output("workgraph.aiida_int", "result") self.add_output("workgraph.any", "_wait") self.add_output("workgraph.any", "_outputs") diff --git a/src/aiida_workgraph/utils/__init__.py b/src/aiida_workgraph/utils/__init__.py index 70c94a72..b8ee0ef6 100644 --- a/src/aiida_workgraph/utils/__init__.py +++ b/src/aiida_workgraph/utils/__init__.py @@ -252,17 +252,16 @@ def organize_nested_inputs(wgdata: Dict[str, Any]) -> None: update_nested_dict(root_prop["value"], key, prop["value"]) prop["value"] = None for key, input in task["inputs"].items(): - if input["property"] is None: - continue - prop = input["property"] - if "." in key and prop["value"] not in [None, {}]: - root, key = key.split(".", 1) - root_prop = task["inputs"][root]["property"] - # update the root property - root_prop["value"] = update_nested_dict( - root_prop["value"], key, prop["value"] - ) - prop["value"] = None + if input.get("property"): + prop = input["property"] + if "." in key and prop["value"] not in [None, {}]: + root, key = key.split(".", 1) + root_prop = task["inputs"][root]["property"] + # update the root property + root_prop["value"] = update_nested_dict( + root_prop["value"], key, prop["value"] + ) + prop["value"] = None def generate_node_graph( @@ -466,11 +465,10 @@ def serialize_workgraph_inputs(wgdata): if task["metadata"]["node_type"].upper() == "PYTHONJOB": PythonJob.serialize_pythonjob_data(task) for _, input in task["inputs"].items(): - if input["property"] is None: - continue - prop = input["property"] - if inspect.isfunction(prop["value"]): - prop["value"] = PickledLocalFunction(prop["value"]).store() + if input.get("property"): + prop = input["property"] + if inspect.isfunction(prop["value"]): + prop["value"] = PickledLocalFunction(prop["value"]).store() # error_handlers of the workgraph for _, data in wgdata["error_handlers"].items(): if not data["handler"]["use_module_path"]: @@ -548,7 +546,7 @@ def process_properties(task: Dict) -> Dict: } # for name, input in task["inputs"].items(): - if input["property"] is not None: + if input.get("property"): prop = input["property"] identifier = prop["identifier"] value = prop.get("value") @@ -617,6 +615,7 @@ def validate_task_inout(inout_list: list[str | dict], list_type: str) -> list[di :raises TypeError: If wrong types are provided to the task :return: Processed `inputs`/`outputs` list. """ + from node_graph.utils import list_to_dict if not all(isinstance(item, (dict, str)) for item in inout_list): raise TypeError( @@ -631,6 +630,8 @@ def validate_task_inout(inout_list: list[str | dict], list_type: str) -> list[di elif isinstance(item, dict): processed_inout_list.append(item) + processed_inout_list = list_to_dict(processed_inout_list) + return processed_inout_list diff --git a/src/aiida_workgraph/utils/analysis.py b/src/aiida_workgraph/utils/analysis.py index 8ce218bf..ef7112ba 100644 --- a/src/aiida_workgraph/utils/analysis.py +++ b/src/aiida_workgraph/utils/analysis.py @@ -215,11 +215,10 @@ def insert_workgraph_to_db(self) -> None: self.save_task_states() for name, task in self.wgdata["tasks"].items(): for _, input in task["inputs"].items(): - if input["property"] is None: - continue - prop = input["property"] - if inspect.isfunction(prop["value"]): - prop["value"] = PickledLocalFunction(prop["value"]).store() + if input.get("property"): + prop = input["property"] + if inspect.isfunction(prop["value"]): + prop["value"] = PickledLocalFunction(prop["value"]).store() self.wgdata["tasks"][name] = serialize(task) # nodes is a copy of tasks, so we need to pop it out self.wgdata["error_handlers"] = serialize(self.wgdata["error_handlers"]) diff --git a/src/aiida_workgraph/workgraph.py b/src/aiida_workgraph/workgraph.py index 08a111ba..4afd019c 100644 --- a/src/aiida_workgraph/workgraph.py +++ b/src/aiida_workgraph/workgraph.py @@ -297,7 +297,7 @@ def update(self) -> None: i = 0 for socket in self.tasks[name].outputs: socket.value = get_nested_dict( - node.outputs, socket.name, default=None + node.outputs, socket.socket_name, default=None ) i += 1 # read results from the process outputs