Skip to content

Commit

Permalink
use input_links
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 committed Dec 10, 2024
1 parent 7120c0a commit d7c3745
Show file tree
Hide file tree
Showing 20 changed files with 119 additions and 152 deletions.
18 changes: 12 additions & 6 deletions src/aiida_workgraph/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,10 +381,13 @@ def build_task_from_workgraph(wg: any) -> Task:
}
)
for socket in task.inputs:
if socket.name in builtin_input_names:
if socket.socket_name in builtin_input_names:
continue
inputs.append(
{"identifier": socket.identifier, "name": f"{task.name}.{socket.name}"}
{
"identifier": socket._socket_identifier,
"name": f"{task.name}.{socket.socket_name}",
}
)
# outputs
outputs.append(
Expand All @@ -394,15 +397,18 @@ def build_task_from_workgraph(wg: any) -> Task:
}
)
for socket in task.outputs:
if socket.name in builtin_output_names:
if socket.socket_name in builtin_output_names:
continue
outputs.append(
{"identifier": socket.identifier, "name": f"{task.name}.{socket.name}"}
{
"identifier": socket._socket_identifier,
"name": f"{task.name}.{socket.socket_name}",
}
)
group_outputs.append(
{
"name": f"{task.name}.{socket.name}",
"from": f"{task.name}.{socket.name}",
"name": f"{task.name}.{socket.socket_name}",
"from": f"{task.name}.{socket.socket_name}",
}
)
# add built-in sockets
Expand Down
79 changes: 27 additions & 52 deletions src/aiida_workgraph/engine/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ def get_task(self, name: str):
task = Task.from_dict(self.ctx._tasks[name])
# update task results
for output in task.outputs:
output.value = get_nested_dict(
output.socket_value = get_nested_dict(
self.ctx._tasks[name]["results"],
output.socket_name,
default=output.value,
default=output.socket_value,
)
return task

Expand Down Expand Up @@ -588,21 +588,20 @@ def get_inputs(
var_args = None
var_kwargs = None
task = self.ctx._tasks[name]
properties = task.get("properties", {})
inputs = {}
for name, prop in task.get("properties", {}).items():
inputs[name] = self.ctx_manager.update_context_variable(prop["value"])
for name, input in task["inputs"].items():
# print(f"input: {input['name']}")
if len(input["links"]) == 0:
if input["identifier"] == "workgraph.namespace":
inputs[name] = self.ctx_manager.update_context_variable(
input["value"]
)
else:
inputs[name] = self.ctx_manager.update_context_variable(
input["property"]["value"]
)
elif len(input["links"]) == 1:
link = input["links"][0]
if input["identifier"] == "workgraph.namespace":
inputs[name] = self.ctx_manager.update_context_variable(input["value"])
else:
inputs[name] = self.ctx_manager.update_context_variable(
input["property"]["value"]
)
for name, links in task["input_links"].items():
if len(links) == 1:
link = links[0]
if self.ctx._tasks[link["from_node"]]["results"] is None:
inputs[name] = None
else:
Expand All @@ -617,9 +616,9 @@ def get_inputs(
link["from_socket"],
)
# handle the case of multiple outputs
elif len(input["links"]) > 1:
elif len(links) > 1:
value = {}
for link in input["links"]:
for link in links:
item_name = f'{link["from_node"]}_{link["from_socket"]}'
# handle the special socket _wait, _outputs
if link["from_socket"] == "_wait":
Expand All @@ -631,42 +630,18 @@ def get_inputs(
"results"
][link["from_socket"]]
inputs[name] = value
for name in task.get("args", []):
if name in inputs:
args.append(inputs[name])
args_dict[name] = inputs[name]
else:
value = self.ctx_manager.update_context_variable(
properties[name]["value"]
)
args.append(value)
args_dict[name] = value
for name in task.get("kwargs", []):
if name in inputs:
kwargs[name] = inputs[name]
else:
value = self.ctx_manager.update_context_variable(
properties[name]["value"]
)
kwargs[name] = value
if task["var_args"] is not None:
name = task["var_args"]
if name in inputs:
var_args = inputs[name]
else:
value = self.ctx_manager.update_context_variable(
properties[name]["value"]
)
var_args = value
if task["var_kwargs"] is not None:
name = task["var_kwargs"]
if name in inputs:
var_kwargs = inputs[name]
else:
value = self.ctx_manager.update_context_variable(
properties[name]["value"]
)
var_kwargs = value
for name, input in inputs.items():
# only need to check the top level key
key = name.split(".")[0]
if key in task["args"]:
args.append(input)
args_dict[name] = input
elif key in task["kwargs"]:
kwargs[name] = input
elif key == task["var_args"]:
var_args = input
elif key == task["var_kwargs"]:
var_kwargs = input
return args, kwargs, var_args, var_kwargs, args_dict

def update_task_state(self, name: str, success=True) -> None:
Expand Down
11 changes: 6 additions & 5 deletions src/aiida_workgraph/tasks/pythonjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ def serialize_pythonjob_data(cls, tdata: Dict[str, Any]):

for input in tdata["inputs"].values():
if input["metadata"].get("is_function_input", False):
input["property"]["value"] = cls.serialize_socket_data(input)
if input.get("property", {}).get("value") is not None:
input["property"]["value"] = cls.serialize_socket_data(input)

@classmethod
def deserialize_pythonjob_data(cls, tdata: Dict[str, Any]) -> None:
Expand All @@ -43,7 +44,7 @@ def deserialize_pythonjob_data(cls, tdata: Dict[str, Any]) -> None:
@classmethod
def serialize_socket_data(cls, data: Dict[str, Any]) -> Any:
if data.get("identifier", "Any").upper() == "WORKGRAPH.NAMESPACE":
if data["property"]["value"] is None:
if data.get("property", {}).get("value") is None:
return None
if isinstance(data["property"]["value"], dict):
serialized_result = {}
Expand All @@ -53,14 +54,14 @@ def serialize_socket_data(cls, data: Dict[str, Any]) -> Any:
else:
raise ValueError("Namespace socket should be a dictionary.")
else:
if isinstance(data["property"]["value"], orm.Data):
if isinstance(data.get("property", {}).get("value"), orm.Data):
return data["property"]["value"]
return general_serializer(data["property"]["value"])

@classmethod
def deserialize_socket_data(cls, data: Dict[str, Any]) -> Any:
if data.get("identifier", "Any").upper() == "WORKGRAPH.NAMESPACE":
if isinstance(data["property"]["value"], dict):
if isinstance(data.get("property", {}).get("value"), dict):
deserialized_result = {}
for key, value in data["property"]["value"].items():
if isinstance(value, orm.Data):
Expand All @@ -71,6 +72,6 @@ def deserialize_socket_data(cls, data: Dict[str, Any]) -> Any:
else:
raise ValueError("Namespace socket should be a dictionary.")
else:
if isinstance(data["property"]["value"], orm.Data):
if isinstance(data.get("property", {}).get("value"), orm.Data):
return data["property"]["value"].value
return data["property"]["value"]
37 changes: 11 additions & 26 deletions src/aiida_workgraph/utils/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,29 +82,14 @@ def build_task_link(self) -> None:
1) workgraph links
"""
# reset task input links
for name, task in self.wgdata["tasks"].items():
for _, input in task["inputs"].items():
input["links"] = []
for _, output in task["outputs"].items():
output["links"] = []
# create a `input_links` to store the input links for each task
for task in self.wgdata["tasks"].values():
task["input_links"] = {}
for link in self.wgdata["links"]:
to_socket = [
socket
for name, socket in self.wgdata["tasks"][link["to_node"]][
"inputs"
].items()
if name == link["to_socket"]
][0]
from_socket = [
socket
for name, socket in self.wgdata["tasks"][link["from_node"]][
"outputs"
].items()
if name == link["from_socket"]
][0]
to_socket["links"].append(link)
from_socket["links"].append(link)
task = self.wgdata["tasks"][link["to_node"]]
if link["to_socket"] not in task["input_links"]:
task["input_links"][link["to_socket"]] = []
task["input_links"][link["to_socket"]].append(link)

def assign_zone(self) -> None:
"""Assign zone for each task."""
Expand Down Expand Up @@ -139,8 +124,8 @@ def find_zone_inputs(self, name: str) -> None:
"""Find the input and outputs tasks for the zone."""
task = self.wgdata["tasks"][name]
input_tasks = []
for _, input in self.wgdata["tasks"][name]["inputs"].items():
for link in input["links"]:
for _, links in self.wgdata["tasks"][name]["input_links"].items():
for link in links:
input_tasks.append(link["from_node"])
# find all the input tasks
for child_task in task["children"]:
Expand All @@ -157,8 +142,8 @@ def find_zone_inputs(self, name: str) -> None:
else:
# if the child task is not a zone, get the input tasks of the child task
# find all the input tasks which outside the while zone
for _, input in self.wgdata["tasks"][child_task]["inputs"].items():
for link in input["links"]:
for _, links in self.wgdata["tasks"][child_task]["input_links"].items():
for link in links:
input_tasks.append(link["from_node"])
# find the input tasks which are not in the zone
new_input_tasks = []
Expand Down
6 changes: 3 additions & 3 deletions src/aiida_workgraph/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def run(
# set task inputs
if inputs is not None:
for name, input in inputs.items():
if name not in self.tasks.keys():
if name not in self.tasks:
raise KeyError(f"Task {name} not found in WorkGraph.")
self.tasks[name].set(input)
# One can not run again if the process is alreay created. otherwise, a new process node will
Expand Down Expand Up @@ -128,7 +128,7 @@ def submit(
# set task inputs
if inputs is not None:
for name, input in inputs.items():
if name not in self.tasks.keys():
if name not in self.tasks:
raise KeyError(f"Task {name} not found in WorkGraph.")
self.tasks[name].set(input)

Expand Down Expand Up @@ -300,7 +300,7 @@ def update(self) -> None:
)
# read results from the process outputs
elif isinstance(node, aiida.orm.Data):
self.tasks[name].outputs[0].value = node
self.tasks[name].outputs[0].socket_value = node
execution_count = getattr(self.process.outputs, "execution_count", None)
self.execution_count = execution_count if execution_count else 0
if self._widget is not None:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_pause_play_task(wg_calcjob):
# Seems the daemon is not responding to the play signal
wg.play_tasks(["add2"])
wg.wait(interval=5)
assert wg.tasks["add2"].outputs["sum"].value == 9
assert wg.tasks["add2"].outputs["sum"].socket_value == 9


def test_pause_play_error_handler(wg_calcjob, finished_process_node):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_awaitable_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async def awaitable_func(x, y):
wg.run()
report = get_workchain_report(wg.process, "REPORT")
assert "Waiting for child processes: awaitable_func1" in report
assert add1.outputs["result"].value == 4
assert add1.outputs["result"].socket_value == 4


def test_monitor_decorator():
Expand Down Expand Up @@ -60,7 +60,7 @@ def test_time_monitor(decorated_add):
wg.run()
report = get_workchain_report(wg.process, "REPORT")
assert "Waiting for child processes: monitor1" in report
assert add1.outputs["result"].value == 3
assert add1.outputs["result"].socket_value == 3


def test_file_monitor(decorated_add, tmp_path):
Expand All @@ -84,7 +84,7 @@ async def create_test_file(filepath="/tmp/test_file_monitor.txt", t=2):
wg.run()
report = get_workchain_report(wg.process, "REPORT")
assert "Waiting for child processes: monitor1" in report
assert add1.outputs["result"].value == 3
assert add1.outputs["result"].socket_value == 3


@pytest.mark.usefixtures("started_daemon_client")
Expand Down
4 changes: 2 additions & 2 deletions tests/test_calcfunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def test_run(wg_calcfunction: WorkGraph) -> None:
print("state: ", wg.state)
# print("results: ", results[])
assert wg.tasks["sumdiff2"].node.outputs.sum == 9
assert wg.tasks["sumdiff2"].outputs["sum"].value == 9
assert wg.tasks["sumdiff2"].outputs["sum"].socket_value == 9


@pytest.mark.usefixtures("started_daemon_client")
Expand All @@ -27,4 +27,4 @@ def add(**kwargs):
wg = WorkGraph("test_dynamic_inputs")
wg.add_task(add, name="add1", x=orm.Int(1), y=orm.Int(2))
wg.run()
assert wg.tasks["add1"].outputs["result"].value == 3
assert wg.tasks["add1"].outputs["result"].socket_value == 3
2 changes: 1 addition & 1 deletion tests/test_calcjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ def test_submit(wg_calcjob: WorkGraph) -> None:
wg = wg_calcjob
wg.name = "test_submit_calcjob"
wg.submit(wait=True)
assert wg.tasks["add2"].outputs["sum"].value == 9
assert wg.tasks["add2"].outputs["sum"].socket_value == 9
4 changes: 2 additions & 2 deletions tests/test_ctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_workgraph_ctx(decorated_add: Callable) -> None:
get_ctx1.waiting_on.add(add1)
add2 = wg.add_task(decorated_add, "add2", x=get_ctx1.outputs["result"], y=1)
wg.run()
assert add2.outputs["result"].value == 6
assert add2.outputs["result"].socket_value == 6


@pytest.mark.usefixtures("started_daemon_client")
Expand All @@ -40,4 +40,4 @@ def test_task_set_ctx(decorated_add: Callable) -> None:
add2 = wg.add_task(decorated_add, "add2", y="{{ sum }}")
wg.add_link(add1.outputs[0], add2.inputs["x"])
wg.submit(wait=True)
assert add2.outputs["result"].value == 10
assert add2.outputs["result"].socket_value == 10
6 changes: 3 additions & 3 deletions tests/test_data_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_data_task(identifier, data) -> None:
wg = WorkGraph("test_normal_task")
task1 = wg.add_task(identifier, name="task1", value=data)
wg.run()
assert task1.outputs["result"].value == data
assert task1.outputs["result"].socket_value == data


def test_data_dict_task():
Expand All @@ -25,7 +25,7 @@ def test_data_dict_task():
wg = WorkGraph("test_data_dict_task")
task1 = wg.add_task("workgraph.aiida_dict", name="task1", value={"a": 1})
wg.run()
assert task1.outputs["result"].value == {"a": 1}
assert task1.outputs["result"].socket_value == {"a": 1}


def test_data_list_task():
Expand All @@ -34,4 +34,4 @@ def test_data_list_task():
wg = WorkGraph("test_data_list_task")
task1 = wg.add_task("workgraph.aiida_list", name="task1", value=[1, 2, 3])
wg.run()
assert task1.outputs["result"].value == [1, 2, 3]
assert task1.outputs["result"].socket_value == [1, 2, 3]
Loading

0 comments on commit d7c3745

Please sign in to comment.