Skip to content

Commit

Permalink
fix task from workgraph
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 committed Dec 10, 2024
1 parent d7c3745 commit 9dbbe2f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
8 changes: 5 additions & 3 deletions src/aiida_workgraph/engine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@ def prepare_for_workgraph_task(task: dict, kwargs: dict) -> tuple:
# because kwargs is updated using update_nested_dict_with_special_keys
# which means the data is grouped by the task name
for socket_name, value in data.items():
wgdata["tasks"][task_name]["inputs"][socket_name]["property"][
"value"
] = value
input = wgdata["tasks"][task_name]["inputs"][socket_name]
if input["identifier"] == "workgraph.namespace":
input["value"] = value
else:
input["property"]["value"] = value
# merge the properties
# organize_nested_inputs(wgdata)
# serialize_workgraph_inputs(wgdata)
Expand Down
10 changes: 5 additions & 5 deletions tests/test_task_from_workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def test_build_task_from_workgraph(decorated_add: Callable) -> None:
wg.add_task(decorated_add, name="add2", y=3)
wg.add_link(add1_task.outputs["result"], wg_task.inputs["add1.x"])
wg.add_link(wg_task.outputs["add2.result"], wg.tasks["add2"].inputs["x"])
assert len(wg_task.inputs) == 21
assert len(wg_task.outputs) == 6
wg.submit(wait=True)
# wg.run()
assert wg.tasks["add2"].outputs["result"].socket_value.value == 12
assert len(wg_task.inputs) == 3
assert len(wg_task.outputs) == 4
# wg.submit(wait=True)
wg.run()
assert wg.tasks.add2.outputs.result.socket_value.value == 12

0 comments on commit 9dbbe2f

Please sign in to comment.