Skip to content

Commit

Permalink
check if input is required
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 committed Dec 5, 2024
1 parent b586105 commit df0145e
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 8 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ dependencies = [
"numpy~=1.21",
"scipy",
"ase",
"node-graph==0.1.4",
"node-graph==0.1.5",
"node-graph-widget",
"aiida-core>=2.3",
"cloudpickle",
Expand Down
8 changes: 7 additions & 1 deletion src/aiida_workgraph/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,13 @@ def build_shelljob_task(
nodes = {} if nodes is None else nodes
keys = list(nodes.keys())
for key in keys:
inputs.append({"identifier": "workgraph.any", "name": f"nodes.{key}"})
inputs.append(
{
"identifier": "workgraph.any",
"name": f"nodes.{key}",
"metadata": {"required": True},
}
)
# input is a output of another task, we make a link
if isinstance(nodes[key], NodeSocket):
links[f"nodes.{key}"] = nodes[key]
Expand Down
13 changes: 7 additions & 6 deletions src/aiida_workgraph/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,12 +558,13 @@ def workgraph_to_short_json(
#
for name, task in wgdata["tasks"].items():
# Add required inputs to nodes
inputs = [
{"name": name, "identifier": input["identifier"]}
for name, input in task["inputs"].items()
if name in task["args"]
or (task["identifier"].upper() == "SHELLJOB" and name.startswith("nodes."))
]
inputs = []
for input in task["inputs"].values():
metadata = input.get("metadata", {}) or {}
if metadata.get("required", False):
inputs.append(
{"name": input["name"], "identifier": input["identifier"]}
)

properties = process_properties(task)
wgdata_short["nodes"][name] = {
Expand Down

0 comments on commit df0145e

Please sign in to comment.