diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 0ee03e6a..faf78c73 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -22,7 +22,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v2 with: - python-version: '3.8' + python-version: '3.10' - name: Install Python dependencies run: pip install -e .[pre-commit,tests] diff --git a/aiida_worktree/__init__.py b/aiida_worktree/__init__.py index 6926eae5..bf810204 100644 --- a/aiida_worktree/__init__.py +++ b/aiida_worktree/__init__.py @@ -1,4 +1,5 @@ -from .decorator import node, build_node from .worktree import WorkTree +from .node import Node +from .decorator import node, build_node __version__ = "0.0.1" diff --git a/aiida_worktree/decorator.py b/aiida_worktree/decorator.py index e1d0b5e0..6aabb06c 100644 --- a/aiida_worktree/decorator.py +++ b/aiida_worktree/decorator.py @@ -1,5 +1,5 @@ from typing import Any -from scinode.utils.node import get_executor +from aiida_worktree.utils import get_executor from aiida.engine.processes.functions import calcfunction, workfunction from aiida.engine.processes.calcjobs import CalcJob from aiida.engine.processes.workchains import WorkChain @@ -27,7 +27,9 @@ def add_input_recursive(inputs, port, prefix=None): def build_node(ndata): """Register a node from a AiiDA component.""" - from scinode.utils.decorator import create_node + from node_graph.decorator import create_node + from aiida_worktree.node import Node + import cloudpickle as pickle path, executor_name, = ndata.pop( "path" @@ -44,16 +46,26 @@ def build_node(ndata): inputs = [] outputs = [] spec = executor.spec() - for key, port in spec.inputs.ports.items(): + for _key, port in spec.inputs.ports.items(): add_input_recursive(inputs, port) kwargs = [input[1] for input in inputs] - for key, port in spec.outputs.ports.items(): + for _key, port in spec.outputs.ports.items(): outputs.append(["General", port.name]) # print("kwargs: ", kwargs) + ndata["node_class"] = Node ndata["kwargs"] = kwargs ndata["inputs"] = inputs ndata["outputs"] = outputs ndata["identifier"] = ndata.pop("identifier", ndata["executor"]["name"]) + # TODO In order to reload the WorkTree from process, "is_pickle" should be True + # so I pickled the function here, but this is not necessary + # we need to update the node_graph to support the path and name of the function + executor = { + "executor": pickle.dumps(executor), + "type": "function", + "is_pickle": True, + } + ndata["executor"] = executor node = create_node(ndata) return node @@ -68,7 +80,7 @@ def decorator_node( catalog="Others", executor_type="function", ): - """Generate a decorator that register a function as a SciNode node. + """Generate a decorator that register a function as a node. Attributes: indentifier (str): node identifier @@ -79,13 +91,15 @@ def decorator_node( inputs (list): node inputs outputs (list): node outputs """ + from aiida_worktree.node import Node + properties = properties or [] inputs = inputs or [] outputs = outputs or [["General", "result"]] def decorator(func): import cloudpickle as pickle - from scinode.utils.decorator import generate_input_sockets, create_node + from node_graph.decorator import generate_input_sockets, create_node nonlocal identifier @@ -94,10 +108,9 @@ def decorator(func): # use cloudpickle to serialize function executor = { "executor": pickle.dumps(func), - "type": executor_type, + "type": "function", "is_pickle": True, } - # # Get the args and kwargs of the function args, kwargs, var_args, var_kwargs, _inputs = generate_input_sockets( func, inputs, properties @@ -113,6 +126,7 @@ def decorator(func): else: node_type = "Normal" ndata = { + "node_class": Node, "identifier": identifier, "node_type": node_type, "args": args, @@ -151,13 +165,15 @@ def decorator_node_group( inputs (list): node inputs outputs (list): node outputs """ + from aiida_worktree.node import Node + properties = properties or [] inputs = inputs or [] outputs = outputs or [] def decorator(func): import cloudpickle as pickle - from scinode.utils.decorator import generate_input_sockets, create_node + from node_graph.decorator import generate_input_sockets, create_node nonlocal identifier, inputs, outputs @@ -179,11 +195,12 @@ def decorator(func): # inputs = [[nt.nodes[input[0]].inputs[input[1]].identifier, input[2]] for input in group_inputs] # outputs = [[nt.nodes[output[0]].outputs[output[1]].identifier, output[2]] for output in group_outputs] # node_inputs = [["General", input[2]] for input in inputs] - node_outputs = [["General", output[2]] for output in outputs] + node_outputs = [["General", output[1]] for output in outputs] # print(node_inputs, node_outputs) # node_type = "worktree" ndata = { + "node_class": Node, "identifier": identifier, "args": args, "kwargs": kwargs, diff --git a/aiida_worktree/engine/worktree.py b/aiida_worktree/engine/worktree.py index 85b2f694..99c1d80d 100644 --- a/aiida_worktree/engine/worktree.py +++ b/aiida_worktree/engine/worktree.py @@ -265,7 +265,7 @@ def _do_step(self) -> t.Any: result: t.Any = None try: - self.launch_worktree() + self.run_worktree() except _PropagateReturn as exception: finished, result = True, exception.exit_code else: @@ -367,7 +367,7 @@ def _on_awaitable_finished(self, awaitable: Awaitable) -> None: self.resume() def setup(self): - from scinode.utils.nt_analysis import ConnectivityAnalysis + from node_graph.analysis import ConnectivityAnalysis from aiida_worktree.utils import build_node_link self.ctx.new_data = dict() @@ -392,25 +392,23 @@ def setup(self): self.ctx.ctrl_links = ntdata["ctrl_links"] self.ctx.worktree = ntdata print("init") - # init - for _name, node in self.ctx.nodes.items(): - node["state"] = "CREATED" - node["process"] = None # nc = ConnectivityAnalysis(ntdata) self.ctx.connectivity = nc.build_connectivity() self.ctx.msgs = [] self.node.set_process_label(f"WorkTree: {self.ctx.worktree['name']}") # while worktree - if self.ctx.worktree["is_while"]: + if self.ctx.worktree["worktree_type"].upper() == "WHILE": should_run = self.check_while_conditions() if not should_run: self.set_node_state(self.ctx.nodes.keys(), "SKIPPED") # for worktree - if self.ctx.worktree["is_for"]: + if self.ctx.worktree["worktree_type"].upper() == "FOR": should_run = self.check_for_conditions() if not should_run: self.set_node_state(self.ctx.nodes.keys(), "SKIPPED") + # init node results + self.set_node_results() def init_ctx(self, datas): from aiida_worktree.utils import update_nested_dict @@ -420,13 +418,49 @@ def init_ctx(self, datas): key = key.replace("__", ".") update_nested_dict(self.ctx, key, value) - def launch_worktree(self): - print("launch_worktree: ") + def set_node_results(self): + for _, node in self.ctx.nodes.items(): + if node.get("process"): + if isinstance(node["process"], str): + node["process"] = orm.load_node(node["process"]) + self.set_node_result(node) + self.set_node_result(node) + + def set_node_result(self, node): + name = node["name"] + print(f"set node result: {name}") + if node.get("process"): + print(f"set node result: {name} process") + state = node["process"].process_state.value.upper() + if state == "FINISHED": + node["state"] = state + if node["metadata"]["node_type"] == "worktree": + # expose the outputs of nodetree + node["results"] = getattr( + node["process"].outputs, "group_outputs", None + ) + # self.ctx.new_data[name] = outputs + else: + node["results"] = node["process"].outputs + # self.ctx.new_data[name] = node["results"] + self.ctx.nodes[name]["state"] = "FINISHED" + self.node_to_ctx(name) + print(f"Node: {name} finished.") + elif state == "EXCEPTED": + node["state"] = state + node["results"] = node["process"].outputs + # self.ctx.new_data[name] = node["results"] + self.ctx.nodes[name]["state"] = "FAILED" + # set child state to FAILED + self.set_node_state(self.ctx.connectivity["child_node"][name], "FAILED") + print(f"Node: {name} failed.") + else: + print(f"set node result: None") + node["results"] = None + + def run_worktree(self): + print("run_worktree: ") self.report("Lanch worktree.") - if len(self.ctx.worktree["starts"]) > 0: - self.run_nodes(self.ctx.worktree["starts"]) - self.ctx.worktree["starts"] = [] - return node_to_run = [] for name, node in self.ctx.nodes.items(): # update node state @@ -458,40 +492,14 @@ def is_worktree_finished(self): ] and node["state"] == "RUNNING" ): - if node.get("process"): - state = node["process"].process_state.value.upper() - print(node["name"], state) - if state == "FINISHED": - node["state"] = state - if node["metadata"]["node_type"] == "worktree": - # expose the outputs of nodetree - node["results"] = getattr( - node["process"].outputs, "group_outputs", None - ) - # self.ctx.new_data[name] = outputs - else: - node["results"] = node["process"].outputs - # self.ctx.new_data[name] = node["results"] - self.ctx.nodes[name]["state"] = "FINISHED" - self.node_to_ctx(name) - print(f"Node: {name} finished.") - elif state == "EXCEPTED": - node["state"] = state - node["results"] = node["process"].outputs - # self.ctx.new_data[name] = node["results"] - self.ctx.nodes[name]["state"] = "FAILED" - # set child state to FAILED - self.set_node_state( - self.ctx.connectivity["child_node"][name], "FAILED" - ) - print(f"Node: {name} failed.") + self.set_node_result(node) if node["state"] in ["RUNNING", "CREATED", "READY"]: is_finished = False if is_finished: - if self.ctx.worktree["is_while"]: + if self.ctx.worktree["worktree_type"].upper() == "WHILE": should_run = self.check_while_conditions() is_finished = not should_run - if self.ctx.worktree["is_for"]: + if self.ctx.worktree["worktree_type"].upper() == "FOR": should_run = self.check_for_conditions() is_finished = not should_run return is_finished @@ -743,6 +751,7 @@ def update_ctx_variable(self, value): def node_to_ctx(self, name): from aiida_worktree.utils import update_nested_dict + print("node to ctx: ", name) items = self.ctx.nodes[name]["to_ctx"] for item in items: update_nested_dict( @@ -835,23 +844,23 @@ def finalize(self): from aiida_worktree.utils import get_nested_dict # expose group outputs - print("finalize") group_outputs = {} print("group outputs: ", self.ctx.worktree["metadata"]["group_outputs"]) for output in self.ctx.worktree["metadata"]["group_outputs"]: print("output: ", output) - if output[0] == "ctx": - group_outputs[output[2]] = get_nested_dict(self.ctx, output[1]) + node_name, socket_name = output[0].split(".") + if node_name == "ctx": + group_outputs[output[1]] = get_nested_dict(self.ctx, socket_name) else: - group_outputs[output[2]] = self.ctx.nodes[output[0]]["results"][ + group_outputs[output[1]] = self.ctx.nodes[node_name]["results"][ output[1] ] self.out("group_outputs", group_outputs) self.out("new_data", self.ctx.new_data) self.report("Finalize") - print(f"Finalize worktree {self.ctx.worktree['name']}") for name, node in self.ctx.nodes.items(): if node["state"] == "FAILED": print(f" Node {name} failed.") return self.exit_codes.NODE_FAILED + print(f"Finalize worktree {self.ctx.worktree['name']}\n") # check if all nodes are finished with nonzero exit code diff --git a/aiida_worktree/node.py b/aiida_worktree/node.py new file mode 100644 index 00000000..292809a0 --- /dev/null +++ b/aiida_worktree/node.py @@ -0,0 +1,29 @@ +from node_graph.node import Node as GraphNode + + +class Node(GraphNode): + """Represent a Node in the AiiDA WorkTree. + + The class extends from node_graph.node.Node and add new + attributes to it. + """ + + socket_entry = "aiida_worktree.socket" + property_entry = "aiida_worktree.property" + + def __init__(self, **kwargs): + """ + Initialize a Node instance. + """ + super().__init__(**kwargs) + self.to_ctx = [] + self.wait = [] + self.process = None + + def to_dict(self): + ndata = super().to_dict() + ndata["to_ctx"] = self.to_ctx + ndata["wait"] = self.wait + ndata["process"] = self.process.uuid if self.process else None + + return ndata diff --git a/aiida_worktree/nodes/builtin.py b/aiida_worktree/nodes/builtin.py index 140d9976..dc8dd495 100644 --- a/aiida_worktree/nodes/builtin.py +++ b/aiida_worktree/nodes/builtin.py @@ -1,4 +1,4 @@ -from scinode.core.node import Node +from aiida_worktree.node import Node from aiida_worktree.executors.builtin import GatherWorkChain diff --git a/aiida_worktree/nodes/qe.py b/aiida_worktree/nodes/qe.py index a3b7cf8d..203f5be6 100644 --- a/aiida_worktree/nodes/qe.py +++ b/aiida_worktree/nodes/qe.py @@ -1,4 +1,5 @@ -from scinode.core.node import Node +from aiida_worktree.node import Node +from aiida import orm class AiiDAKpoint(Node): @@ -10,8 +11,8 @@ class AiiDAKpoint(Node): kwargs = ["mesh", "offset"] def create_properties(self): - self.properties.new("IntVector", "mesh", default=[1, 1, 1], size=3) - self.properties.new("IntVector", "offset", default=[0, 0, 0], size=3) + self.properties.new("AiiDAIntVector", "mesh", default=[1, 1, 1], size=3) + self.properties.new("AiiDAIntVector", "offset", default=[0, 0, 0], size=3) def create_sockets(self): self.outputs.new("General", "Kpoint") @@ -59,7 +60,7 @@ class AiiDAPWPseudo(Node): def create_properties(self): self.properties.new( - "String", "psuedo_familay", default="SSSP/1.2/PBEsol/efficiency" + "AiiDAString", "psuedo_familay", default="SSSP/1.2/PBEsol/efficiency" ) def create_sockets(self): diff --git a/aiida_worktree/nodes/test.py b/aiida_worktree/nodes/test.py index f31f0406..080eefdd 100644 --- a/aiida_worktree/nodes/test.py +++ b/aiida_worktree/nodes/test.py @@ -1,4 +1,4 @@ -from scinode.core.node import Node +from aiida_worktree.node import Node class AiiDAInt(Node): @@ -16,7 +16,7 @@ def create_properties(self): def create_sockets(self): inp = self.inputs.new("General", "value", default=0.0) inp.add_property("AiiDAInt", default=1.0) - self.outputs.new("Int", "result") + self.outputs.new("AiiDAInt", "result") def get_executor(self): return { @@ -35,11 +35,11 @@ class AiiDAFloat(Node): kwargs = ["t"] def create_properties(self): - self.properties.new("Float", "t", default=1.0) + self.properties.new("AiiDAFloat", "t", default=1.0) def create_sockets(self): - self.inputs.new("Float", "value", default=0.0) - self.outputs.new("Float", "result") + self.inputs.new("AiiDAFloat", "value", default=0.0) + self.outputs.new("AiiDAFloat", "result") def get_executor(self): return { @@ -58,11 +58,11 @@ class AiiDAString(Node): kwargs = ["t"] def create_properties(self): - self.properties.new("Float", "t", default=1.0) + self.properties.new("AiiDAFloat", "t", default=1.0) def create_sockets(self): - self.inputs.new("String", "value", default="") - self.outputs.new("String", "result") + self.inputs.new("AiiDAString", "value", default="") + self.outputs.new("AiiDAString", "result") def get_executor(self): return { @@ -173,16 +173,16 @@ class AiiDAAdd(Node): kwargs = ["t"] def create_properties(self): - self.properties.new("Float", "t", default=1.0) + self.properties.new("AiiDAFloat", "t", default=1.0) def create_sockets(self): self.inputs.clear() self.outputs.clear() - inp = self.inputs.new("Float", "x") + inp = self.inputs.new("AiiDAFloat", "x") inp.add_property("AiiDAFloat", "x", default=0.0) - inp = self.inputs.new("Float", "y") + inp = self.inputs.new("AiiDAFloat", "y") inp.add_property("AiiDAFloat", "y", default=0.0) - self.outputs.new("Float", "sum") + self.outputs.new("AiiDAFloat", "sum") def get_executor(self): return { @@ -205,9 +205,9 @@ def create_properties(self): def create_sockets(self): self.inputs.clear() self.outputs.clear() - self.inputs.new("Float", "x") - self.inputs.new("Float", "y") - self.outputs.new("Bool", "result") + self.inputs.new("AiiDAFloat", "x") + self.inputs.new("AiiDAFloat", "y") + self.outputs.new("AiiDABool", "result") def get_executor(self): return { @@ -227,17 +227,17 @@ class AiiDASumDiff(Node): kwargs = ["t"] def create_properties(self): - self.properties.new("Float", "t", default=1.0) + self.properties.new("AiiDAFloat", "t", default=1.0) def create_sockets(self): self.inputs.clear() self.outputs.clear() - inp = self.inputs.new("Float", "x") + inp = self.inputs.new("AiiDAFloat", "x") inp.add_property("AiiDAFloat", "x", default=0.0) - inp = self.inputs.new("Float", "y") + inp = self.inputs.new("AiiDAFloat", "y") inp.add_property("AiiDAFloat", "y", default=0.0) - self.outputs.new("Float", "sum") - self.outputs.new("Float", "diff") + self.outputs.new("AiiDAFloat", "sum") + self.outputs.new("AiiDAFloat", "diff") def get_executor(self): return { @@ -261,11 +261,11 @@ def create_sockets(self): self.inputs.clear() self.outputs.clear() self.inputs.new("General", "code") - inp = self.inputs.new("Int", "x") + inp = self.inputs.new("AiiDAInt", "x") inp.add_property("AiiDAInt", "x", default=0.0) - inp = self.inputs.new("Int", "y") + inp = self.inputs.new("AiiDAInt", "y") inp.add_property("AiiDAInt", "y", default=0.0) - self.outputs.new("Int", "sum") + self.outputs.new("AiiDAInt", "sum") def get_executor(self): return { @@ -289,13 +289,13 @@ def create_sockets(self): self.inputs.clear() self.outputs.clear() self.inputs.new("General", "code") - inp = self.inputs.new("Int", "x") + inp = self.inputs.new("AiiDAInt", "x") inp.add_property("AiiDAInt", "x", default=0.0) - inp = self.inputs.new("Int", "y") + inp = self.inputs.new("AiiDAInt", "y") inp.add_property("AiiDAInt", "y", default=0.0) - inp = self.inputs.new("Int", "z") + inp = self.inputs.new("AiiDAInt", "z") inp.add_property("AiiDAInt", "z", default=0.0) - self.outputs.new("Int", "result") + self.outputs.new("AiiDAInt", "result") def get_executor(self): return { diff --git a/aiida_worktree/properties/built_in.py b/aiida_worktree/properties/built_in.py index c87a38f1..faed4a26 100644 --- a/aiida_worktree/properties/built_in.py +++ b/aiida_worktree/properties/built_in.py @@ -1,8 +1,27 @@ -from scinode.core.property import NodeProperty -from scinode.serialization.built_in import SerializeJson, SerializePickle +from node_graph.property import NodeProperty +from node_graph.serializer import SerializeJson, SerializePickle +from node_graph.properties.builtin import ( + VectorProperty, + BaseDictProperty, + BaseListProperty, + IntProperty, + BoolProperty, + FloatProperty, + StringProperty, +) from aiida import orm +class GeneralProperty(NodeProperty, SerializePickle): + """A new class for General type.""" + + identifier: str = "General" + data_type = "General" + + def __init__(self, name, description="", default=None, update=None) -> None: + super().__init__(name, description, default, update) + + class AiiDAIntProperty(NodeProperty, SerializeJson): """A new class for integer type.""" @@ -146,10 +165,110 @@ def set_value(self, value): raise Exception("{} is not a dict.".format(value)) +class AiiDAIntVectorProperty(VectorProperty): + """A new class for integer vector type.""" + + identifier: str = "AiiDAIntVector" + data_type = "AiiDAIntVector" + + def __init__( + self, name, description="", size=3, default=[0, 0, 0], update=None + ) -> None: + super().__init__(name, description, size, default, update) + + def set_value(self, value): + # run the callback function + if len(value) == self.size: + for i in range(self.size): + if isinstance(value[i], int): + self._value[i] = value[i] + if self.update is not None: + self.update() + else: + raise Exception( + f"Set property {self.name} failed. {value[i]} is not a integer." + ) + else: + raise Exception( + "Length {} is not equal to the size {}.".format(len(value), self.size) + ) + + +class AiiDAFloatVectorProperty(VectorProperty): + """A new class for float vector type.""" + + identifier: str = "AiiDAFloatVector" + data_type = "AiiDAFloatVector" + + def __init__( + self, name, description="", size=3, default=[0, 0, 0], update=None + ) -> None: + super().__init__(name, description, size, default, update) + + def set_value(self, value): + # run the callback function + if len(value) == self.size: + for i in range(self.size): + if isinstance(value[i], (int, float)): + self._value[i] = value[i] + if self.update is not None: + self.update() + else: + raise Exception("{} is not a float.".format(value[i])) + else: + raise Exception( + "Length {} is not equal to the size {}.".format(len(value), self.size) + ) + + def get_metadata(self): + metadata = {"default": self.default, "size": self.size} + return metadata + + +# ==================================== +# Vector + + +class BoolVectorProperty(VectorProperty): + """A new class for bool vector type.""" + + identifier: str = "BoolVector" + data_type = "BoolVector" + + def __init__( + self, name, description="", size=3, default=[0, 0, 0], update=None + ) -> None: + super().__init__(name, description, size, default, update) + + def set_value(self, value): + # run the callback function + if len(value) == self.size: + for i in range(self.size): + if isinstance(value[i], (bool, int)): + self._value[i] = value[i] + if self.update is not None: + self.update() + else: + raise Exception("{} is not a bool.".format(value[i])) + else: + raise Exception( + "Length {} is not equal to the size {}.".format(len(value), self.size) + ) + + property_list = [ + IntProperty, + FloatProperty, + BoolProperty, + StringProperty, + GeneralProperty, + BaseDictProperty, + BaseListProperty, AiiDAIntProperty, AiiDAFloatProperty, AiiDAStringProperty, AiiDABoolProperty, AiiDADictProperty, + AiiDAIntVectorProperty, + AiiDAFloatVectorProperty, ] diff --git a/aiida_worktree/property.py b/aiida_worktree/property.py new file mode 100644 index 00000000..39e2e688 --- /dev/null +++ b/aiida_worktree/property.py @@ -0,0 +1,7 @@ +from node_graph.property import NodeProperty as GraphNodeProperty + + +class NodeProperty(GraphNodeProperty): + """Represent a property of a Node in the AiiDA WorkTree.""" + + property_entry = "aiida_worktree.property" diff --git a/aiida_worktree/socket.py b/aiida_worktree/socket.py new file mode 100644 index 00000000..60654975 --- /dev/null +++ b/aiida_worktree/socket.py @@ -0,0 +1,8 @@ +from node_graph.socket import NodeSocket as GraphNodeSocket + + +class NodeSocket(GraphNodeSocket): + """Represent a socket of a Node in the AiiDA WorkTree.""" + + socket_entry = "aiida_worktree.socket" + property_entry = "aiida_worktree.property" diff --git a/aiida_worktree/sockets/__init__.py b/aiida_worktree/sockets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/aiida_worktree/sockets/built_in.py b/aiida_worktree/sockets/built_in.py new file mode 100644 index 00000000..527983cb --- /dev/null +++ b/aiida_worktree/sockets/built_in.py @@ -0,0 +1,103 @@ +from aiida_worktree.socket import NodeSocket +from node_graph.serializer import SerializeJson, SerializePickle +from node_graph.sockets.builtin import ( + SocketBaseDict, + SocketBaseList, +) + + +class SocketGeneral(NodeSocket, SerializePickle): + """General socket.""" + + identifier: str = "General" + + def __init__( + self, name, node=None, type="INPUT", index=0, uuid=None, **kwargs + ) -> None: + super().__init__(name, node, type, index, uuid=uuid) + self.add_property("General", name, **kwargs) + + +class SocketAiiDAFloat(NodeSocket, SerializeJson): + """AiiDAFloat socket.""" + + identifier: str = "AiiDAFloat" + + def __init__( + self, name, node=None, type="INPUT", index=0, uuid=None, **kwargs + ) -> None: + super().__init__(name, node, type, index, uuid=uuid) + self.add_property("AiiDAFloat", name, **kwargs) + + +class SocketAiiDAInt(NodeSocket, SerializeJson): + """AiiDAInt socket.""" + + identifier: str = "AiiDAInt" + + def __init__( + self, name, node=None, type="INPUT", index=0, uuid=None, **kwargs + ) -> None: + super().__init__(name, node, type, index, uuid=uuid) + self.add_property("AiiDAInt", name, **kwargs) + + +class SocketAiiDAString(NodeSocket, SerializeJson): + """AiiDAString socket.""" + + identifier: str = "AiiDAString" + + def __init__( + self, name, node=None, type="INPUT", index=0, uuid=None, **kwargs + ) -> None: + super().__init__(name, node, type, index, uuid=uuid) + self.add_property("AiiDAString", name, **kwargs) + + +class SocketAiiDABool(NodeSocket, SerializeJson): + """AiiDABool socket.""" + + identifier: str = "AiiDABool" + + def __init__( + self, name, node=None, type="INPUT", index=0, uuid=None, **kwargs + ) -> None: + super().__init__(name, node, type, index, uuid=uuid) + self.add_property("AiiDABool", name, **kwargs) + + +class SocketAiiDAIntVector(NodeSocket, SerializeJson): + """Socket with a AiiDAIntVector property.""" + + identifier: str = "AiiDAIntVector" + + def __init__( + self, name, node=None, type="INPUT", index=0, uuid=None, **kwargs + ) -> None: + super().__init__(name, node, type, index, uuid=uuid) + self.add_property("AiiDAIntVector", name, **kwargs) + + +class SocketAiiDAFloatVector(NodeSocket, SerializeJson): + """Socket with a FloatVector property.""" + + identifier: str = "FloatVector" + + def __init__( + self, name, node=None, type="INPUT", index=0, uuid=None, **kwargs + ) -> None: + super().__init__(name, node, type, index, uuid=uuid) + self.add_property("FloatVector", name, **kwargs) + + +socket_list = [ + SocketGeneral, + SocketBaseDict, + SocketBaseList, + SocketAiiDAInt, + SocketAiiDAFloat, + SocketAiiDAString, + SocketAiiDABool, + SocketAiiDAIntVector, + SocketAiiDAFloatVector, +] diff --git a/aiida_worktree/worktree.py b/aiida_worktree/worktree.py index 71d6669c..065deab1 100644 --- a/aiida_worktree/worktree.py +++ b/aiida_worktree/worktree.py @@ -1,12 +1,14 @@ -import scinode +import node_graph import aiida -class WorkTree(scinode.core.nodetree.NodeTree): - """ - Represents a working tree for AiiDA's worktree engine. The class extends from scinode's NodeTree and provides - methods to run, submit tasks, wait for tasks to finish, and update the process status. It is used to handle - various states of a worktree process and provides convenient operations to interact with it. +class WorkTree(node_graph.NodeGraph): + """Build a node-based workflow AiiDA's worktree engine. + + The class extends from NodeGraph and provides methods to run, + submit tasks, wait for tasks to finish, and update the process status. + It is used to handle various states of a worktree process and provides + convenient operations to interact with it. Attributes: process (aiida.orm.ProcessNode): The process node that represents the process status and other details. @@ -14,19 +16,19 @@ class WorkTree(scinode.core.nodetree.NodeTree): pk (int): The primary key of the process node. """ + node_entry = "aiida_worktree.node" + def __init__(self, name="WorkTree", **kwargs): """ Initialize a WorkTree instance. Args: name (str, optional): The name of the WorkTree. Defaults to 'WorkTree'. - **kwargs: Additional keyword arguments to be passed to the NodeTree class. + **kwargs: Additional keyword arguments to be passed to the WorkTree class. """ super().__init__(name, **kwargs) self.ctx = {} - self.starts = [] - self.is_while = False - self.is_for = False + self.worktree_type = "NORMAL" self.sequence = [] self.conditions = [] @@ -36,10 +38,12 @@ def run(self): the process and then calls the update method to update the state of the process. """ from aiida_worktree.engine.worktree import WorkTree + from aiida.orm.utils.serialize import serialize ntdata = self.to_dict() all = {"nt": ntdata} _result, self.process = aiida.engine.run_get_node(WorkTree, **all) + self.process.base.extras.set("nt", serialize(ntdata)) self.update() def submit(self, wait=False, timeout=60): @@ -52,27 +56,25 @@ def submit(self, wait=False, timeout=60): """ from aiida_worktree.engine.worktree import WorkTree from aiida_worktree.utils import merge_properties + from aiida.orm.utils.serialize import serialize ntdata = self.to_dict() merge_properties(ntdata) all = {"nt": ntdata} self.process = aiida.engine.submit(WorkTree, **all) + # + self.process.base.extras.set("nt", serialize(ntdata)) if wait: self.wait(timeout=timeout) def to_dict(self): ntdata = super().to_dict() - for node in self.nodes: - ntdata["nodes"][node.name]["to_ctx"] = getattr(node, "to_ctx", []) - ntdata["nodes"][node.name]["wait"] = getattr(node, "wait", []) self.ctx["sequence"] = self.sequence # only alphanumeric and underscores are allowed ntdata["ctx"] = { key.replace(".", "__"): value for key, value in self.ctx.items() } - ntdata["starts"] = self.starts - ntdata["is_while"] = self.is_while - ntdata["is_for"] = self.is_for + ntdata["worktree_type"] = self.worktree_type ntdata["conditions"] = self.conditions return ntdata @@ -107,13 +109,13 @@ def update(self): linked to the current process, and data nodes linked to the current process. """ self.state = self.process.process_state.value.upper() - self.pk = self.process.pk outgoing = self.process.base.links.get_outgoing() for link in outgoing.all(): node = link.node if isinstance(node, aiida.orm.ProcessNode) and getattr( node, "process_state", False ): + self.nodes[link.link_label].process = node self.nodes[link.link_label].state = node.process_state.value.upper() self.nodes[link.link_label].node = node self.nodes[link.link_label].pk = node.pk @@ -123,3 +125,24 @@ def update(self): self.nodes[label].state = "FINISHED" self.nodes[label].node = node self.nodes[label].pk = node.pk + + @property + def pk(self): + return self.process.pk if self.process else None + + @classmethod + def load(cls, pk): + """ + Load the process node with the given primary key. + + Args: + pk (int): The primary key of the process node. + """ + from aiida.orm.utils.serialize import deserialize_unsafe + + process = aiida.orm.load_node(pk) + wtdata = deserialize_unsafe(process.base.extras.get("nt")) + wt = cls.from_dict(wtdata) + wt.process = process + wt.update() + return wt diff --git a/docs/source/concept/worktree.rst b/docs/source/concept/worktree.rst index 9121d8b9..c423b8fb 100644 --- a/docs/source/concept/worktree.rst +++ b/docs/source/concept/worktree.rst @@ -38,6 +38,16 @@ Create and launch worktree wt.submit() +Load worktree from the AiiDA process +===================================== +WorkTree save its data as a extra attribute into its process, so that one can rebuild the WorkTree from the process. + + +.. code-block:: python + + from aiida_worktree import WorkTree + # pk is the process id of a WorkTree + WorkTree.load(pk) Execute order =============== diff --git a/docs/source/howto/continue_finished_worktree.ipynb b/docs/source/howto/continue_finished_worktree.ipynb new file mode 100644 index 00000000..78085ef0 --- /dev/null +++ b/docs/source/howto/continue_finished_worktree.ipynb @@ -0,0 +1,222 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "22d177dc-6cfb-4de2-9509-f1eb45e10cf2", + "metadata": {}, + "source": [ + "# How to continue a finished WorkTree" + ] + }, + { + "cell_type": "markdown", + "id": "58696c91", + "metadata": {}, + "source": [ + "## Introduction\n", + "`WorkTree` supports adding new nodes to a already finished WorkTree and continue the submitting the jobs. WorkTree save its data as a extra attribute into its process, so that one can rebuild the WorkTree from the process.\n", + "\n", + "Load the AiiDA profile." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "c6b83fb5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Profile" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%load_ext aiida\n", + "from aiida import load_profile\n", + "load_profile()" + ] + }, + { + "cell_type": "markdown", + "id": "0f46d277", + "metadata": {}, + "source": [ + "## Create a `add_multiply` WorkTree" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "ece10d89", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[93m\u001b[1mWarning\u001b[0m: \u001b[22mRabbitMQ v3.9.13 is not supported and will cause unexpected problems!\u001b[0m\n", + "\u001b[93m\u001b[1mWarning\u001b[0m: \u001b[22mIt can cause long-running workflows to crash and jobs to be submitted multiple times.\u001b[0m\n", + "\u001b[93m\u001b[1mWarning\u001b[0m: \u001b[22mSee https://github.com/aiidateam/aiida-core/wiki/RabbitMQ-version-to-use for details.\u001b[0m\n" + ] + } + ], + "source": [ + "from aiida_worktree import node\n", + "from aiida.engine import calcfunction\n", + "\n", + "# define add node\n", + "@node()\n", + "@calcfunction\n", + "def add(x, y):\n", + " return x + y\n", + "\n", + "# define multiply node\n", + "@node()\n", + "@calcfunction\n", + "def multiply(x, y):\n", + " return x*y\n", + "\n", + "from aiida_worktree import WorkTree\n", + "from aiida.orm import Int\n", + "x = Int(2.0)\n", + "y = Int(3.0)\n", + "z = Int(4.0)\n", + "\n", + "wt = WorkTree(\"first_workflow\")\n", + "wt.nodes.new(add, name=\"add1\", x=x, y=y)\n", + "wt.nodes.new(multiply, name=\"multiply1\", y=z)\n", + "wt.links.new(wt.nodes[\"add1\"].outputs[0], wt.nodes[\"multiply1\"].inputs[\"x\"])\n", + "\n", + "wt.submit(wait=True)" + ] + }, + { + "cell_type": "markdown", + "id": "51f3b8d1", + "metadata": {}, + "source": [ + "## Load the old WorkTree and add new nodes\n", + "Now, we want to add a new `add` node, and use the results from previous `add1` and `multiply1` nodes. Use the `load` method to load a WorkTree from a process." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "4f8bf5c3", + "metadata": {}, + "outputs": [], + "source": [ + "from aiida_worktree import WorkTree\n", + "pk = wt.pk\n", + "wt2 = WorkTree.load(pk)\n", + "wt2.nodes.new(add, name=\"add2\")\n", + "wt2.links.new(wt2.nodes[\"add1\"].outputs[0], wt2.nodes[\"add2\"].inputs[\"x\"])\n", + "wt2.links.new(wt2.nodes[\"multiply1\"].outputs[0], wt2.nodes[\"add2\"].inputs[\"y\"])\n", + "wt2.submit(wait=True)" + ] + }, + { + "cell_type": "markdown", + "id": "ea0ee6b0", + "metadata": {}, + "source": [ + "### Check status and results\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "fdb3d16e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "State of WorkTree : FINISHED\n", + "Result of add1 : 5\n", + "Result of multiply1 : 20\n", + "Result of add2 : 25\n" + ] + } + ], + "source": [ + "print(\"State of WorkTree : {}\".format(wt2.state))\n", + "print('Result of add1 : {}'.format(wt.nodes[\"add1\"].node.outputs.result.value))\n", + "print('Result of multiply1 : {}'.format(wt.nodes[\"multiply1\"].node.outputs.result.value))\n", + "print('Result of add2 : {}'.format(wt2.nodes[\"add2\"].node.outputs.result.value))" + ] + }, + { + "cell_type": "markdown", + "id": "addf94b0", + "metadata": {}, + "source": [ + "Generate node graph from the AiiDA process:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "6c83c650", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\nN10510\n\nadd (10510)\nState: finished\nExit Code: 0\n\n\n\nN10511\n\nInt (10511)\nvalue: 25\n\n\n\nN10510->N10511\n\n\nCREATE\nresult\n\n\n\nN10501\n\nInt (10501)\nvalue: 2\n\n\n\nN10504\n\nWorkTree: first_workflow (10504)\nState: finished\nExit Code: 0\n\n\n\nN10501->N10504\n\n\nINPUT_WORK\nnt__nodes__add1__properties__x__value\n\n\n\nN10505\n\nadd (10505)\nState: finished\nExit Code: 0\n\n\n\nN10501->N10505\n\n\nINPUT_CALC\nx\n\n\n\nN10509\n\nWorkTree: first_workflow (10509)\nState: finished\nExit Code: 0\n\n\n\nN10501->N10509\n\n\nINPUT_WORK\nnt__nodes__add1__properties__x__value\n\n\n\nN10502\n\nInt (10502)\nvalue: 3\n\n\n\nN10502->N10504\n\n\nINPUT_WORK\nnt__nodes__add1__properties__y__value\n\n\n\nN10502->N10505\n\n\nINPUT_CALC\ny\n\n\n\nN10502->N10509\n\n\nINPUT_WORK\nnt__nodes__add1__properties__y__value\n\n\n\nN10503\n\nInt (10503)\nvalue: 4\n\n\n\nN10503->N10504\n\n\nINPUT_WORK\nnt__nodes__multiply1__properties__y__value\n\n\n\nN10507\n\nmultiply (10507)\nState: finished\nExit Code: 0\n\n\n\nN10503->N10507\n\n\nINPUT_CALC\ny\n\n\n\nN10503->N10509\n\n\nINPUT_WORK\nnt__nodes__multiply1__properties__y__value\n\n\n\nN10504->N10505\n\n\nCALL_CALC\nadd1\n\n\n\nN10504->N10507\n\n\nCALL_CALC\nmultiply1\n\n\n\nN10506\n\nInt (10506)\nvalue: 5\n\n\n\nN10505->N10506\n\n\nCREATE\nresult\n\n\n\nN10506->N10510\n\n\nINPUT_CALC\nx\n\n\n\nN10506->N10507\n\n\nINPUT_CALC\nx\n\n\n\nN10508\n\nInt (10508)\nvalue: 20\n\n\n\nN10507->N10508\n\n\nCREATE\nresult\n\n\n\nN10508->N10510\n\n\nINPUT_CALC\ny\n\n\n\nN10509->N10510\n\n\nCALL_CALC\nadd2\n\n\n\n", + "text/plain": [ + "" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from aiida_worktree.utils import generate_node_graph\n", + "generate_node_graph(wt2.nodes['add2'].pk)" + ] + }, + { + "cell_type": "markdown", + "id": "9b1cfcb3", + "metadata": {}, + "source": [ + "The provenance graph of the WorkTree is still keeped in the AiiDA database." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.10.4 ('scinode')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + }, + "vscode": { + "interpreter": { + "hash": "2f450c1ff08798c4974437dd057310afef0de414c25d1fd960ad375311c3f6ff" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/source/howto/ctx.ipynb b/docs/source/howto/ctx.ipynb index f400df02..d1330b67 100644 --- a/docs/source/howto/ctx.ipynb +++ b/docs/source/howto/ctx.ipynb @@ -95,7 +95,7 @@ "- One can export the data from context to the node group outputs.\n", " \n", " ```python\n", - " @node.group(outputs=[[\"ctx\", \"sum\", \"result\"]])\n", + " @node.group(outputs=[[\"ctx.sum\", \"result\"]])\n", " def my_while(n, limit):\n", " nt = WorkTree(\"while_worktree\")\n", " add1 = wt.nodes.new(add, x=2, y=3)\n", diff --git a/docs/source/howto/for.ipynb b/docs/source/howto/for.ipynb index 62fed54b..865e089f 100644 --- a/docs/source/howto/for.ipynb +++ b/docs/source/howto/for.ipynb @@ -97,7 +97,7 @@ "def add_multiply_for(sequence):\n", " wt = WorkTree(\"add_multiply_for\")\n", " # tell the engine that this is a `for` worktree\n", - " wt.is_for = True\n", + " wt.worktree_type = \"FOR\"\n", " # the sequence to be iter\n", " wt.sequence = sequence\n", " # set a context variable before running.\n", diff --git a/docs/source/howto/if.ipynb b/docs/source/howto/if.ipynb index 9a8888f7..8b310e4d 100644 --- a/docs/source/howto/if.ipynb +++ b/docs/source/howto/if.ipynb @@ -98,7 +98,7 @@ "\n", "# Create a WorkTree which is dynamically generated based on the input\n", "# then we output the result of from the context (ctx)\n", - "@node.group(outputs = [[\"ctx\", \"result\", \"result\"]])\n", + "@node.group(outputs = [[\"ctx.result\", \"result\"]])\n", "def add_multiply_if(x, y):\n", " from aiida.orm import load_node\n", " wt = WorkTree(\"add_multiply_if\")\n", diff --git a/docs/source/howto/index.rst b/docs/source/howto/index.rst index 13715df2..3a078c3e 100644 --- a/docs/source/howto/index.rst +++ b/docs/source/howto/index.rst @@ -14,3 +14,4 @@ This section contains a collection of HowTos for various topics. while ctx wait + continue_finished_worktree diff --git a/docs/source/howto/while.ipynb b/docs/source/howto/while.ipynb index d65e3ec1..31290823 100644 --- a/docs/source/howto/while.ipynb +++ b/docs/source/howto/while.ipynb @@ -106,11 +106,11 @@ "\n", "# Create a WorkTree will repeat itself based on the conditions\n", "# then we output the result of from the context (ctx)\n", - "@node.group(outputs = [[\"ctx\", \"n\", \"result\"]])\n", + "@node.group(outputs = [[\"ctx.n\", \"result\"]])\n", "def add_multiply_while(n, limit):\n", " wt = WorkTree(\"add_multiply_while\")\n", " # tell the engine that this is a `while` worktree\n", - " wt.is_while = True\n", + " wt.worktree_type = \"WHILE\"\n", " # the `result` of compares1 node is used as condition\n", " wt.conditions = [[\"compare1\", \"result\"]]\n", " # set a context variable before running.\n", diff --git a/setup.py b/setup.py index 95fafda4..76be3191 100644 --- a/setup.py +++ b/setup.py @@ -31,7 +31,7 @@ def test_suite(): install_requires=[ "numpy", "aiida-core", - "scinode", + "node-graph", "cloudpickle", "aiida-pseudo", "aiida-quantumespresso", @@ -43,12 +43,15 @@ def test_suite(): "aiida.node": [ "process.workflow.worktree = aiida_worktree.orm.worktree:WorkTreeNode", ], - "scinode_node": [ + "aiida_worktree.node": [ "aiida = aiida_worktree.nodes:node_list", ], - "scinode_property": [ + "aiida_worktree.property": [ "aiida = aiida_worktree.properties.built_in:property_list", ], + "aiida_worktree.socket": [ + "aiida = aiida_worktree.sockets.built_in:socket_list", + ], }, package_data={}, python_requires=">=3.8", diff --git a/tests/conftest.py b/tests/conftest.py index 7af18457..1d780c6e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -145,7 +145,7 @@ def add_multiply(x, y, z): def decorated_add_multiply_group(decorated_add, decorated_multiply): """Generate a decorated node for test.""" - @node.group(outputs=[["multiply", "result", "result"]]) + @node.group(outputs=[["multiply.result", "result"]]) def add_multiply_group(x, y, z): wt = WorkTree("add_multiply_group") add1 = wt.nodes.new(decorated_add, name="add1", x=x, y=y) diff --git a/tests/test_for.py b/tests/test_for.py index cd16ecad..ff5b6e0f 100644 --- a/tests/test_for.py +++ b/tests/test_for.py @@ -6,11 +6,11 @@ def test_for(decorated_add, decorated_multiply): # Create a WorkTree will loop the a sequence - @node.group(outputs=[["ctx", "total", "result"]]) + @node.group(outputs=[["ctx.total", "result"]]) def add_multiply_for(sequence): wt = WorkTree("add_multiply_for") # tell the engine that this is a `for` worktree - wt.is_for = True + wt.worktree_type = "FOR" # the sequence to be iter wt.sequence = sequence # set a context variable before running. diff --git a/tests/test_qe.py b/tests/test_qe.py index b42c2d10..5187b97c 100644 --- a/tests/test_qe.py +++ b/tests/test_qe.py @@ -39,6 +39,7 @@ def test_structure(wt_structure_si): """Run simple calcfunction.""" wt = wt_structure_si wt.name = "test_structure" + print(wt.to_dict()) wt.submit(wait=True) assert len(wt_structure_si.nodes["structure1"].node.get_ase()) == 2 diff --git a/tests/test_while.py b/tests/test_while.py index 52bbbaa1..a972bf1d 100644 --- a/tests/test_while.py +++ b/tests/test_while.py @@ -6,10 +6,10 @@ def test_while(decorated_add, decorated_multiply, decorated_compare): # Create a WorkTree will repeat itself based on the conditions - @node.group(outputs=[["ctx", "n", "result"]]) + @node.group(outputs=[["ctx.n", "result"]]) def my_while(n, limit): wt = WorkTree("while_worktree") - wt.is_while = True + wt.worktree_type = "WHILE" wt.conditions = [["compare1", "result"]] wt.ctx = {"n": n} wt.nodes.new(decorated_compare, name="compare1", x="{{n}}", y=orm.Int(limit)) diff --git a/tests/test_workchain.py b/tests/test_workchain.py index ebaddcc7..80748af3 100644 --- a/tests/test_workchain.py +++ b/tests/test_workchain.py @@ -46,3 +46,6 @@ def test_build_workchain(build_workchain): wt.links.new(multiply_add1.outputs[0], multiply_add2.inputs["z"]) wt.submit(wait=True, timeout=100) assert wt.nodes["multiply_add2"].node.outputs.result == 17 + # reload wt + wt1 = WorkTree.load(wt.pk) + assert wt1.nodes["multiply_add2"].node.outputs.result == 17