From df0145e07bebf5acc7b27aaeeaa5bcbc69b92bed Mon Sep 17 00:00:00 2001 From: superstar54 Date: Thu, 5 Dec 2024 01:00:24 +0100 Subject: [PATCH] check if input is required --- pyproject.toml | 2 +- src/aiida_workgraph/decorator.py | 8 +++++++- src/aiida_workgraph/utils/__init__.py | 13 +++++++------ 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 493fbca6..cd11f0e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ dependencies = [ "numpy~=1.21", "scipy", "ase", - "node-graph==0.1.4", + "node-graph==0.1.5", "node-graph-widget", "aiida-core>=2.3", "cloudpickle", diff --git a/src/aiida_workgraph/decorator.py b/src/aiida_workgraph/decorator.py index 8cf193c8..ea7dda73 100644 --- a/src/aiida_workgraph/decorator.py +++ b/src/aiida_workgraph/decorator.py @@ -356,7 +356,13 @@ def build_shelljob_task( nodes = {} if nodes is None else nodes keys = list(nodes.keys()) for key in keys: - inputs.append({"identifier": "workgraph.any", "name": f"nodes.{key}"}) + inputs.append( + { + "identifier": "workgraph.any", + "name": f"nodes.{key}", + "metadata": {"required": True}, + } + ) # input is a output of another task, we make a link if isinstance(nodes[key], NodeSocket): links[f"nodes.{key}"] = nodes[key] diff --git a/src/aiida_workgraph/utils/__init__.py b/src/aiida_workgraph/utils/__init__.py index 49e5ddbf..d457c857 100644 --- a/src/aiida_workgraph/utils/__init__.py +++ b/src/aiida_workgraph/utils/__init__.py @@ -558,12 +558,13 @@ def workgraph_to_short_json( # for name, task in wgdata["tasks"].items(): # Add required inputs to nodes - inputs = [ - {"name": name, "identifier": input["identifier"]} - for name, input in task["inputs"].items() - if name in task["args"] - or (task["identifier"].upper() == "SHELLJOB" and name.startswith("nodes.")) - ] + inputs = [] + for input in task["inputs"].values(): + metadata = input.get("metadata", {}) or {} + if metadata.get("required", False): + inputs.append( + {"name": input["name"], "identifier": input["identifier"]} + ) properties = process_properties(task) wgdata_short["nodes"][name] = {