Use node graph #3

Merged: 8 commits, Oct 9, 2023
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
@@ -22,7 +22,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.8'
python-version: '3.10'

- name: Install Python dependencies
run: pip install -e .[pre-commit,tests]
3 changes: 2 additions & 1 deletion aiida_worktree/__init__.py
@@ -1,4 +1,5 @@
from .decorator import node, build_node
from .worktree import WorkTree
from .node import Node
from .decorator import node, build_node

__version__ = "0.0.1"
37 changes: 27 additions & 10 deletions aiida_worktree/decorator.py
@@ -1,5 +1,5 @@
from typing import Any
from scinode.utils.node import get_executor
from aiida_worktree.utils import get_executor
from aiida.engine.processes.functions import calcfunction, workfunction
from aiida.engine.processes.calcjobs import CalcJob
from aiida.engine.processes.workchains import WorkChain
@@ -27,7 +27,9 @@ def add_input_recursive(inputs, port, prefix=None):

def build_node(ndata):
"""Register a node from a AiiDA component."""
from scinode.utils.decorator import create_node
from node_graph.decorator import create_node
from aiida_worktree.node import Node
import cloudpickle as pickle

path, executor_name, = ndata.pop(
"path"
@@ -44,16 +46,26 @@ def build_node(ndata):
inputs = []
outputs = []
spec = executor.spec()
for key, port in spec.inputs.ports.items():
for _key, port in spec.inputs.ports.items():
add_input_recursive(inputs, port)
kwargs = [input[1] for input in inputs]
for key, port in spec.outputs.ports.items():
for _key, port in spec.outputs.ports.items():
outputs.append(["General", port.name])
# print("kwargs: ", kwargs)
ndata["node_class"] = Node
ndata["kwargs"] = kwargs
ndata["inputs"] = inputs
ndata["outputs"] = outputs
ndata["identifier"] = ndata.pop("identifier", ndata["executor"]["name"])
# TODO: In order to reload the WorkTree from the process, "is_pickle" should be True,
# so the function is pickled here, although this should not be necessary;
# node_graph needs to be updated to support the path and name of the function.
executor = {
"executor": pickle.dumps(executor),
"type": "function",
"is_pickle": True,
}
ndata["executor"] = executor
node = create_node(ndata)
return node

@@ -68,7 +80,7 @@ def decorator_node(
catalog="Others",
executor_type="function",
):
"""Generate a decorator that register a function as a SciNode node.
"""Generate a decorator that register a function as a node.

Attributes:
identifier (str): node identifier
@@ -79,13 +91,15 @@
inputs (list): node inputs
outputs (list): node outputs
"""
from aiida_worktree.node import Node

properties = properties or []
inputs = inputs or []
outputs = outputs or [["General", "result"]]

def decorator(func):
import cloudpickle as pickle
from scinode.utils.decorator import generate_input_sockets, create_node
from node_graph.decorator import generate_input_sockets, create_node

nonlocal identifier

@@ -94,10 +108,9 @@ def decorator(func):
# use cloudpickle to serialize function
executor = {
"executor": pickle.dumps(func),
"type": executor_type,
"type": "function",
"is_pickle": True,
}
#
# Get the args and kwargs of the function
args, kwargs, var_args, var_kwargs, _inputs = generate_input_sockets(
func, inputs, properties
@@ -113,6 +126,7 @@
else:
node_type = "Normal"
ndata = {
"node_class": Node,
"identifier": identifier,
"node_type": node_type,
"args": args,
@@ -151,13 +165,15 @@ def decorator_node_group(
inputs (list): node inputs
outputs (list): node outputs
"""
from aiida_worktree.node import Node

properties = properties or []
inputs = inputs or []
outputs = outputs or []

def decorator(func):
import cloudpickle as pickle
from scinode.utils.decorator import generate_input_sockets, create_node
from node_graph.decorator import generate_input_sockets, create_node

nonlocal identifier, inputs, outputs

@@ -179,11 +195,12 @@ def decorator(func):
# inputs = [[nt.nodes[input[0]].inputs[input[1]].identifier, input[2]] for input in group_inputs]
# outputs = [[nt.nodes[output[0]].outputs[output[1]].identifier, output[2]] for output in group_outputs]
# node_inputs = [["General", input[2]] for input in inputs]
node_outputs = [["General", output[2]] for output in outputs]
node_outputs = [["General", output[1]] for output in outputs]
# print(node_inputs, node_outputs)
#
node_type = "worktree"
ndata = {
"node_class": Node,
"identifier": identifier,
"args": args,
"kwargs": kwargs,
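For context, a minimal usage sketch of the decorator being reworked above. Per the diff, the decorator defaults to outputs = [["General", "result"]] and cloudpickles the wrapped function; the function name and body here are purely illustrative, and what the decorator ultimately returns is not shown in this diff.

from aiida_worktree import node

@node()
def add(x, y):
    # illustrative body; the arguments become input sockets via generate_input_sockets
    return x + y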
105 changes: 57 additions & 48 deletions aiida_worktree/engine/worktree.py
@@ -265,7 +265,7 @@ def _do_step(self) -> t.Any:
result: t.Any = None

try:
self.launch_worktree()
self.run_worktree()
except _PropagateReturn as exception:
finished, result = True, exception.exit_code
else:
@@ -367,7 +367,7 @@ def _on_awaitable_finished(self, awaitable: Awaitable) -> None:
self.resume()

def setup(self):
from scinode.utils.nt_analysis import ConnectivityAnalysis
from node_graph.analysis import ConnectivityAnalysis
from aiida_worktree.utils import build_node_link

self.ctx.new_data = dict()
@@ -392,25 +392,23 @@ def setup(self):
self.ctx.ctrl_links = ntdata["ctrl_links"]
self.ctx.worktree = ntdata
print("init")
# init
for _name, node in self.ctx.nodes.items():
node["state"] = "CREATED"
node["process"] = None
#
nc = ConnectivityAnalysis(ntdata)
self.ctx.connectivity = nc.build_connectivity()
self.ctx.msgs = []
self.node.set_process_label(f"WorkTree: {self.ctx.worktree['name']}")
# while worktree
if self.ctx.worktree["is_while"]:
if self.ctx.worktree["worktree_type"].upper() == "WHILE":
should_run = self.check_while_conditions()
if not should_run:
self.set_node_state(self.ctx.nodes.keys(), "SKIPPED")
# for worktree
if self.ctx.worktree["is_for"]:
if self.ctx.worktree["worktree_type"].upper() == "FOR":
should_run = self.check_for_conditions()
if not should_run:
self.set_node_state(self.ctx.nodes.keys(), "SKIPPED")
# init node results
self.set_node_results()

def init_ctx(self, datas):
from aiida_worktree.utils import update_nested_dict
@@ -420,13 +418,49 @@
key = key.replace("__", ".")
update_nested_dict(self.ctx, key, value)

def launch_worktree(self):
print("launch_worktree: ")
def set_node_results(self):
for _, node in self.ctx.nodes.items():
if node.get("process"):
if isinstance(node["process"], str):
node["process"] = orm.load_node(node["process"])
self.set_node_result(node)
self.set_node_result(node)

def set_node_result(self, node):
name = node["name"]
print(f"set node result: {name}")
if node.get("process"):
print(f"set node result: {name} process")
state = node["process"].process_state.value.upper()
if state == "FINISHED":
node["state"] = state
if node["metadata"]["node_type"] == "worktree":
# expose the outputs of nodetree
node["results"] = getattr(
node["process"].outputs, "group_outputs", None
)
# self.ctx.new_data[name] = outputs
else:
node["results"] = node["process"].outputs
# self.ctx.new_data[name] = node["results"]
self.ctx.nodes[name]["state"] = "FINISHED"
self.node_to_ctx(name)
print(f"Node: {name} finished.")
elif state == "EXCEPTED":
node["state"] = state
node["results"] = node["process"].outputs
# self.ctx.new_data[name] = node["results"]
self.ctx.nodes[name]["state"] = "FAILED"
# set child state to FAILED
self.set_node_state(self.ctx.connectivity["child_node"][name], "FAILED")
print(f"Node: {name} failed.")
else:
print(f"set node result: None")
node["results"] = None

def run_worktree(self):
print("run_worktree: ")
self.report("Lanch worktree.")
if len(self.ctx.worktree["starts"]) > 0:
self.run_nodes(self.ctx.worktree["starts"])
self.ctx.worktree["starts"] = []
return
node_to_run = []
for name, node in self.ctx.nodes.items():
# update node state
@@ -458,40 +492,14 @@ def is_worktree_finished(self):
]
and node["state"] == "RUNNING"
):
if node.get("process"):
state = node["process"].process_state.value.upper()
print(node["name"], state)
if state == "FINISHED":
node["state"] = state
if node["metadata"]["node_type"] == "worktree":
# expose the outputs of nodetree
node["results"] = getattr(
node["process"].outputs, "group_outputs", None
)
# self.ctx.new_data[name] = outputs
else:
node["results"] = node["process"].outputs
# self.ctx.new_data[name] = node["results"]
self.ctx.nodes[name]["state"] = "FINISHED"
self.node_to_ctx(name)
print(f"Node: {name} finished.")
elif state == "EXCEPTED":
node["state"] = state
node["results"] = node["process"].outputs
# self.ctx.new_data[name] = node["results"]
self.ctx.nodes[name]["state"] = "FAILED"
# set child state to FAILED
self.set_node_state(
self.ctx.connectivity["child_node"][name], "FAILED"
)
print(f"Node: {name} failed.")
self.set_node_result(node)
if node["state"] in ["RUNNING", "CREATED", "READY"]:
is_finished = False
if is_finished:
if self.ctx.worktree["is_while"]:
if self.ctx.worktree["worktree_type"].upper() == "WHILE":
should_run = self.check_while_conditions()
is_finished = not should_run
if self.ctx.worktree["is_for"]:
if self.ctx.worktree["worktree_type"].upper() == "FOR":
should_run = self.check_for_conditions()
is_finished = not should_run
return is_finished
@@ -743,6 +751,7 @@ def update_ctx_variable(self, value):
def node_to_ctx(self, name):
from aiida_worktree.utils import update_nested_dict

print("node to ctx: ", name)
items = self.ctx.nodes[name]["to_ctx"]
for item in items:
update_nested_dict(
@@ -835,23 +844,23 @@ def finalize(self):
from aiida_worktree.utils import get_nested_dict

# expose group outputs
print("finalize")
group_outputs = {}
print("group outputs: ", self.ctx.worktree["metadata"]["group_outputs"])
for output in self.ctx.worktree["metadata"]["group_outputs"]:
print("output: ", output)
if output[0] == "ctx":
group_outputs[output[2]] = get_nested_dict(self.ctx, output[1])
node_name, socket_name = output[0].split(".")
if node_name == "ctx":
group_outputs[output[1]] = get_nested_dict(self.ctx, socket_name)
else:
group_outputs[output[2]] = self.ctx.nodes[output[0]]["results"][
group_outputs[output[1]] = self.ctx.nodes[node_name]["results"][
output[1]
]
self.out("group_outputs", group_outputs)
self.out("new_data", self.ctx.new_data)
self.report("Finalize")
print(f"Finalize worktree {self.ctx.worktree['name']}")
for name, node in self.ctx.nodes.items():
if node["state"] == "FAILED":
print(f" Node {name} failed.")
return self.exit_codes.NODE_FAILED
print(f"Finalize worktree {self.ctx.worktree['name']}\n")
# check if all nodes are finished with nonzero exit code
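
The finalize() rewrite above implies a new shape for the group output specification: each entry is a two-element list whose first item is "<node_name>.<socket_name>" (with "ctx" acting as a pseudo node name) and whose second item is the exposed output name. A hedged illustration with made-up node and key names:

group_outputs_spec = [
    ["add1.result", "result"],  # expose a node socket (names illustrative)
    ["ctx.total", "total"],     # expose a value stored in the workchain context
]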
29 changes: 29 additions & 0 deletions aiida_worktree/node.py
@@ -0,0 +1,29 @@
from node_graph.node import Node as GraphNode


class Node(GraphNode):
"""Represent a Node in the AiiDA WorkTree.

The class extends from node_graph.node.Node and adds new
attributes to it.
"""

socket_entry = "aiida_worktree.socket"
property_entry = "aiida_worktree.property"

def __init__(self, **kwargs):
"""
Initialize a Node instance.
"""
super().__init__(**kwargs)
self.to_ctx = []
self.wait = []
self.process = None

def to_dict(self):
ndata = super().to_dict()
ndata["to_ctx"] = self.to_ctx
ndata["wait"] = self.wait
ndata["process"] = self.process.uuid if self.process else None

return ndata
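A hedged round-trip sketch of the process handling introduced here: to_dict() stores only the process uuid, and the engine's set_node_results() reloads it with orm.load_node; my_node is an illustrative Node instance.

from aiida import orm

ndata = my_node.to_dict()
if ndata["process"]:                          # uuid string when a process was attached, else None
    process = orm.load_node(ndata["process"])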
2 changes: 1 addition & 1 deletion aiida_worktree/nodes/builtin.py
@@ -1,4 +1,4 @@
from scinode.core.node import Node
from aiida_worktree.node import Node
from aiida_worktree.executors.builtin import GatherWorkChain


9 changes: 5 additions & 4 deletions aiida_worktree/nodes/qe.py
@@ -1,4 +1,5 @@
from scinode.core.node import Node
from aiida_worktree.node import Node
from aiida import orm


class AiiDAKpoint(Node):
@@ -10,8 +11,8 @@ class AiiDAKpoint(Node):
kwargs = ["mesh", "offset"]

def create_properties(self):
self.properties.new("IntVector", "mesh", default=[1, 1, 1], size=3)
self.properties.new("IntVector", "offset", default=[0, 0, 0], size=3)
self.properties.new("AiiDAIntVector", "mesh", default=[1, 1, 1], size=3)
self.properties.new("AiiDAIntVector", "offset", default=[0, 0, 0], size=3)

def create_sockets(self):
self.outputs.new("General", "Kpoint")
@@ -59,7 +60,7 @@ class AiiDAPWPseudo(Node):

def create_properties(self):
self.properties.new(
"String", "psuedo_familay", default="SSSP/1.2/PBEsol/efficiency"
"AiiDAString", "psuedo_familay", default="SSSP/1.2/PBEsol/efficiency"
)

def create_sockets(self):
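A minimal sketch of a custom node following the same pattern as AiiDAKpoint above, assuming only the APIs visible in this diff (properties.new with the AiiDA-typed property identifiers, and outputs.new); the class, property, and socket names are illustrative, and node registration details are not covered here.

from aiida_worktree.node import Node


class MyGrid(Node):
    """Illustrative node mirroring the AiiDAKpoint pattern above."""

    kwargs = ["mesh"]

    def create_properties(self):
        self.properties.new("AiiDAIntVector", "mesh", default=[2, 2, 2], size=3)

    def create_sockets(self):
        self.outputs.new("General", "grid")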