From 6bf5c50381243c6320707b667e9754160e4894fe Mon Sep 17 00:00:00 2001 From: Xing Wang Date: Wed, 11 Dec 2024 13:58:29 +0100 Subject: [PATCH] Namespace socket and tab completion (#388) AiiDA has a namespace port and supports nested ports, e.g., the "base.pw.metadata" in the `PwRelaxWorkChain.` Previously, WorkGraph flat the nested port, and make every port on the top-level. There are many disadvantages: - one needs to use the dict-style code to access the socket, e.g., `outputs["relax.output_structure"]`. - tab-completion is not possible, because "." is in the key. This PR uses the `NodeSocketNamespace` from the latest `node-graph` to support the nested sockets, i.e., the namespace in AiiDA. The top-level `inputs` and `outputs` are also namespace sockets now. The auto-completion is also supported. --- README.md | 8 +- docs/gallery/autogen/quick_start.py | 184 +---------------- docs/gallery/built-in/autogen/shelljob.py | 8 +- docs/gallery/concept/autogen/socket.py | 13 +- docs/gallery/concept/autogen/task.py | 30 ++- docs/gallery/concept/autogen/workgraph.py | 2 +- docs/gallery/howto/autogen/aggregate.py | 18 +- docs/gallery/howto/autogen/graph_builder.py | 20 +- docs/gallery/howto/autogen/parallel.py | 2 +- docs/gallery/tutorial/autogen/eos.py | 16 +- docs/gallery/tutorial/autogen/qe.py | 16 +- docs/gallery/tutorial/autogen/zero_to_hero.py | 44 ++-- docs/source/built-in/monitor.ipynb | 6 +- docs/source/built-in/pythonjob.ipynb | 36 ++-- docs/source/built-in/shelljob.ipynb | 6 +- docs/source/howto/context.ipynb | 4 +- pyproject.toml | 4 +- src/aiida_workgraph/collection.py | 48 +---- src/aiida_workgraph/config.py | 4 +- src/aiida_workgraph/decorator.py | 61 +++--- src/aiida_workgraph/engine/task_manager.py | 77 +++---- src/aiida_workgraph/engine/utils.py | 8 +- src/aiida_workgraph/engine/workgraph.py | 9 +- src/aiida_workgraph/orm/mapping.py | 18 ++ src/aiida_workgraph/socket.py | 19 +- src/aiida_workgraph/sockets/builtins.py | 63 +++--- src/aiida_workgraph/task.py | 57 +----- src/aiida_workgraph/tasks/builtins.py | 190 ++++++++++-------- src/aiida_workgraph/tasks/monitors.py | 64 +++--- src/aiida_workgraph/tasks/pythonjob.py | 70 +++---- src/aiida_workgraph/tasks/test.py | 64 +++--- src/aiida_workgraph/utils/__init__.py | 69 +------ src/aiida_workgraph/utils/analysis.py | 48 ++--- src/aiida_workgraph/utils/graph.py | 8 +- src/aiida_workgraph/workgraph.py | 45 +++-- tests/conftest.py | 12 +- tests/test_action.py | 10 +- tests/test_awaitable_task.py | 10 +- tests/test_build_task.py | 8 +- tests/test_calcfunction.py | 4 +- tests/test_calcjob.py | 2 +- tests/test_ctx.py | 10 +- tests/test_data_task.py | 6 +- tests/test_decorator.py | 88 ++++---- tests/test_engine.py | 2 +- tests/test_error_handler.py | 9 +- tests/test_failed_node.py | 2 +- tests/test_for.py | 4 +- tests/test_if.py | 18 +- tests/test_link.py | 8 +- tests/test_normal_function.py | 8 +- tests/test_pythonjob.py | 56 +++--- tests/test_shell.py | 12 +- tests/test_socket.py | 18 +- tests/test_task_from_workgraph.py | 28 ++- tests/test_tasks.py | 66 ++++-- tests/test_while.py | 34 ++-- tests/test_workchain.py | 2 +- tests/test_workgraph.py | 24 ++- tests/test_yaml.py | 10 +- tests/test_zone.py | 6 +- tests/widget/test_widget.py | 2 +- 62 files changed, 767 insertions(+), 1031 deletions(-) create mode 100644 src/aiida_workgraph/orm/mapping.py diff --git a/README.md b/README.md index fe918e9b..0d008d57 100644 --- a/README.md +++ b/README.md @@ -49,19 +49,19 @@ def multiply(x, y): wg = WorkGraph("test_add_multiply") wg.add_task(add, name="add1") wg.add_task(multiply, name="multiply1") -wg.add_link(wg.tasks["add1"].outputs["result"], wg.tasks["multiply1"].inputs["x"]) +wg.add_link(wg.tasks.add1.outputs.result, wg.tasks.multiply1.inputs.x) ``` -Prepare inputs and submit the workflow: +Prepare inputs and run the workflow: ```python from aiida import load_profile load_profile() -wg.submit(inputs = {"add1": {"x": 2, "y": 3}, "multiply1": {"y": 4}}, wait=True) -print("Result of multiply1 is", wg.tasks["multiply1"].outputs[0].value) +wg.run(inputs = {"add1": {"x": 2, "y": 3}, "multiply1": {"y": 4}}) +print("Result of multiply1 is", wg.tasks.multiply1.outputs.result.value) ``` ## Web ui To use the web ui, first install the web ui package: diff --git a/docs/gallery/autogen/quick_start.py b/docs/gallery/autogen/quick_start.py index d7f0b7a8..3fef1f7a 100644 --- a/docs/gallery/autogen/quick_start.py +++ b/docs/gallery/autogen/quick_start.py @@ -107,7 +107,7 @@ def multiply(x, y): wg = WorkGraph("add_multiply_workflow") add_task = wg.add_task(add, name="add1") # link the output of the `add` task to one of the `x` input of the `multiply` task. -wg.add_task(multiply, name="multiply1", x=add_task.outputs["result"]) +wg.add_task(multiply, name="multiply1", x=add_task.outputs.result) # export the workgraph to html file so that it can be visualized in a browser wg.to_html() @@ -128,8 +128,8 @@ def multiply(x, y): ) print("State of WorkGraph: {}".format(wg.state)) -print("Result of add : {}".format(wg.tasks["add1"].outputs[0].value)) -print("Result of multiply : {}".format(wg.tasks["multiply1"].outputs[0].value)) +print("Result of add : {}".format(wg.tasks.add1.outputs.result.value)) +print("Result of multiply : {}".format(wg.tasks.multiply1.outputs.result.value)) ###################################################################### @@ -141,172 +141,6 @@ def multiply(x, y): generate_node_graph(wg.pk) -###################################################################### -# Remote job -# ---------- -# -# The ``PythonJob`` is a built-in task that allows users to run Python -# functions on a remote computer. -# -# In this case, we use define the task using normal function instead of -# ``calcfunction``. Thus, user does not need to install AiiDA on the -# remote computer. -# - -from aiida_workgraph import WorkGraph, task - -# define add task -@task() -def add(x, y): - return x + y - - -# define multiply task -@task() -def multiply(x, y): - return x * y - - -wg = WorkGraph("second_workflow") -# You might need to adapt the label to python3 if you use this as your default python -wg.add_task("PythonJob", function=add, name="add", command_info={"label": "python"}) -wg.add_task( - "PythonJob", - function=multiply, - name="multiply", - x=wg.tasks["add"].outputs[0], - command_info={"label": "python"}, -) - -# export the workgraph to html file so that it can be visualized in a browser -wg.to_html() -# visualize the workgraph in jupyter-notebook -# wg - - -###################################################################### -# Submit the workgraph -# ~~~~~~~~~~~~~~~~~~~~ -# -# **Code**: We can set the ``computer`` to the remote computer where we -# want to run the job. This will create a code ``python@computer`` if not -# exists. Of course, you can also set the ``code`` directly if you have -# already created the code. -# -# **Data**: Users can (and is recoomaneded) use normal Python data as -# input. The workgraph will transfer the data to AiiDA data -# (``PickledData``) using pickle. -# -# **Python Version**: since pickle is used to store and load data, the -# Python version on the remote computer should match the one used in the -# localhost. One can use conda to create a virtual environment with the -# same Python version. Then activate the environment before running the -# script. -# -# .. code:: python -# -# # For real applications, one can pass metadata to the scheduler to activate the conda environment -# metadata = { -# "options": { -# 'custom_scheduler_commands' : 'module load anaconda\nconda activate py3.11\n', -# } -# } -# - -from aiida_workgraph.utils import generate_node_graph - -# ------------------------- Submit the calculation ------------------- -# For real applications, one can pass metadata to the scheduler to activate the conda environment -metadata = { - "options": { - # 'custom_scheduler_commands' : 'module load anaconda\nconda activate py3.11\n', - "custom_scheduler_commands": "", - } -} - -wg.submit( - inputs={ - "add": {"x": 2, "y": 3, "computer": "localhost", "metadata": metadata}, - "multiply": {"y": 4, "computer": "localhost", "metadata": metadata}, - }, - wait=True, -) -# ------------------------- Print the output ------------------------- -print( - "\nResult of multiply is {} \n\n".format( - wg.tasks["multiply"].outputs["result"].value - ) -) -# ------------------------- Generate node graph ------------------- -generate_node_graph(wg.pk) - - -###################################################################### -# Use parent folder -# ~~~~~~~~~~~~~~~~~ -# -# The parent_folder parameter allows a task to access the output files of -# a parent task. This feature is particularly useful when you want to -# reuse data generated by a previous computation in subsequent -# computations. In the following example, the multiply task uses the -# ``result.txt`` file created by the add task. -# -# By default, the content of the parent folder is symlinked to the working -# directory. In the function, you can access the parent folder using the -# relative path. For example, ``./parent_folder/result.txt``. -# - -from aiida_workgraph import WorkGraph, task - -# define add task -@task() -def add(x, y): - z = x + y - with open("result.txt", "w") as f: - f.write(str(z)) - - -# define multiply task -@task() -def multiply(x, y): - with open("parent_folder/result.txt", "r") as f: - z = int(f.read()) - return x * y + z - - -wg = WorkGraph("third_workflow") -# You might need to adapt the label to python3 if you use this as your default python -wg.add_task("PythonJob", function=add, name="add", command_info={"label": "python"}) -wg.add_task( - "PythonJob", - function=multiply, - name="multiply", - parent_folder=wg.tasks["add"].outputs["remote_folder"], - command_info={"label": "python"}, -) - -wg.to_html() - - -###################################################################### -# Submit the calculation -# - -# ------------------------- Submit the calculation ------------------- -wg.submit( - inputs={ - "add": {"x": 2, "y": 3, "computer": "localhost"}, - "multiply": {"x": 3, "y": 4, "computer": "localhost"}, - }, - wait=True, -) -print( - "\nResult of multiply is {} \n\n".format( - wg.tasks["multiply"].outputs["result"].value - ) -) - - ###################################################################### # CalcJob and WorkChain # --------------------- @@ -341,7 +175,7 @@ def multiply(x, y): wg = WorkGraph("test_add_multiply") add1 = wg.add_task(ArithmeticAddCalculation, name="add1", x=Int(2), y=Int(3), code=code) add2 = wg.add_task(ArithmeticAddCalculation, name="add2", y=Int(3), code=code) -wg.add_link(wg.tasks["add1"].outputs["sum"], wg.tasks["add2"].inputs["x"]) +wg.add_link(wg.tasks.add1.outputs.sum, wg.tasks.add2.inputs.x) wg.to_html() @@ -350,7 +184,7 @@ def multiply(x, y): # wg.submit(wait=True) -print("Result of task add1: {}".format(wg.tasks["add2"].outputs["sum"].value)) +print("Result of task add1: {}".format(wg.tasks.add2.outputs.sum.value)) from aiida_workgraph.utils import generate_node_graph @@ -413,7 +247,7 @@ def add_multiply(x, y, z): wg = WorkGraph() wg.add_task(add, name="add", x=x, y=y) wg.add_task(multiply, name="multiply", x=z) - wg.add_link(wg.tasks["add"].outputs["result"], wg.tasks["multiply"].inputs["y"]) + wg.add_link(wg.tasks.add.outputs.result, wg.tasks.multiply.inputs.y) # don't forget to return the `wg` return wg @@ -431,7 +265,7 @@ def add_multiply(x, y, z): add_multiply1 = wg.add_task(add_multiply, x=Int(2), y=Int(3), z=Int(4)) add_multiply2 = wg.add_task(add_multiply, x=Int(2), y=Int(3)) # link the output of a task to the input of another task -wg.add_link(add_multiply1.outputs["multiply"], add_multiply2.inputs["z"]) +wg.add_link(add_multiply1.outputs.multiply, add_multiply2.inputs.z) wg.submit(wait=True) print("WorkGraph state: ", wg.state) @@ -440,9 +274,7 @@ def add_multiply(x, y, z): # Get the result of the tasks: # -print( - "Result of task add_multiply1: {}".format(add_multiply1.outputs["multiply"].value) -) +print("Result of task add_multiply1: {}".format(add_multiply1.outputs.multiply.value)) generate_node_graph(wg.pk) diff --git a/docs/gallery/built-in/autogen/shelljob.py b/docs/gallery/built-in/autogen/shelljob.py index 7632ef71..6bbca441 100644 --- a/docs/gallery/built-in/autogen/shelljob.py +++ b/docs/gallery/built-in/autogen/shelljob.py @@ -31,7 +31,7 @@ wg.submit(wait=True) # Print out the result: -print("\nResult: ", date_task.outputs["stdout"].value.get_content()) +print("\nResult: ", date_task.outputs.stdout.value.get_content()) # %% # Under the hood, an AiiDA ``Code`` instance ``date`` will be created on the ``localhost`` computer. In addition, it is also @@ -93,7 +93,7 @@ wg.submit(wait=True) # Print out the result: -print("\nResult: ", date_task.outputs["stdout"].value.get_content()) +print("\nResult: ", date_task.outputs.stdout.value.get_content()) # %% # Running a shell command with files as arguments @@ -117,7 +117,7 @@ wg.submit(wait=True) # Print out the result: -print("\nResult: ", cat_task.outputs["stdout"].value.get_content()) +print("\nResult: ", cat_task.outputs.stdout.value.get_content()) # %% # Create a workflow @@ -157,7 +157,7 @@ def parser(dirpath): name="expr_2", command="expr", arguments=["{result}", "*", "{z}"], - nodes={"z": Int(4), "result": expr_1.outputs["result"]}, + nodes={"z": Int(4), "result": expr_1.outputs.result}, parser=PickledData(parser), parser_outputs=[{"name": "result"}], ) diff --git a/docs/gallery/concept/autogen/socket.py b/docs/gallery/concept/autogen/socket.py index a769a6b7..09e41c00 100644 --- a/docs/gallery/concept/autogen/socket.py +++ b/docs/gallery/concept/autogen/socket.py @@ -25,8 +25,8 @@ def multiply(x, y): return x * y -print("Input ports: ", multiply.task().inputs.keys()) -print("Output ports: ", multiply.task().outputs.keys()) +print("Input ports: ", multiply.task().get_input_names()) +print("Output ports: ", multiply.task().get_output_names()) multiply.task().to_html() @@ -49,8 +49,8 @@ def add_minus(x, y): return {"sum": x + y, "difference": x - y} -print("Input ports: ", add_minus.task().inputs.keys()) -print("Ouput ports: ", add_minus.task().outputs.keys()) +print("Input ports: ", add_minus.task().get_input_names()) +print("Ouput ports: ", add_minus.task().get_output_names()) add_minus.task().to_html() @@ -85,8 +85,7 @@ def add(x: int, y: float) -> float: return x + y -for input in add.task().inputs: - print("{:30s}: {:20s}".format(input.name, input.identifier)) +print("inputs: ", add.task().inputs) ###################################################################### @@ -140,7 +139,7 @@ def add(x, y): def create_sockets(self): # create a General port. - inp = self.inputs.new("workgraph.Any", "symbols") + inp = self.add_input("workgraph.Any", "symbols") # add a string property to the port with default value "H". inp.add_property("String", "default", default="H") diff --git a/docs/gallery/concept/autogen/task.py b/docs/gallery/concept/autogen/task.py index 23872e97..d18724f9 100644 --- a/docs/gallery/concept/autogen/task.py +++ b/docs/gallery/concept/autogen/task.py @@ -47,8 +47,8 @@ def multiply(x, y): # add1 = add.task() -print("Inputs:", add1.inputs.keys()) -print("Outputs:", add1.outputs.keys()) +print("Inputs:", add1.get_input_names()) +print("Outputs:", add1.get_output_names()) ###################################################################### @@ -62,8 +62,8 @@ def add_minus(x, y): return {"sum": x + y, "difference": x - y} -print("Inputs:", add_minus.task().inputs.keys()) -print("Outputs:", add_minus.task().outputs.keys()) +print("Inputs:", add_minus.task().get_input_names()) +print("Outputs:", add_minus.task().get_output_names()) ###################################################################### # One can also add an ``identifier`` to indicates the data type. The data @@ -94,7 +94,7 @@ def add_minus(x, y): wg = WorkGraph() add_minus1 = wg.add_task(add_minus, name="add_minus1") multiply1 = wg.add_task(multiply, name="multiply1") -wg.add_link(add_minus1.outputs["sum"], multiply1.inputs["x"]) +wg.add_link(add_minus1.outputs.sum, multiply1.inputs.x) ###################################################################### @@ -124,14 +124,8 @@ def add_minus(x, y): wg = WorkGraph() norm_task = wg.add_task(NormTask, name="norm1") -print("Inputs:") -for input in norm_task.inputs: - if "." not in input.name: - print(f" - {input.name}") -print("Outputs:") -for output in norm_task.outputs: - if "." not in output.name: - print(f" - {output.name}") +print("Inputs: ", norm_task.inputs) +print("Outputs: ", norm_task.outputs) ###################################################################### # For specifying the outputs, the most explicit way is to provide a list of dictionaries, as shown above. In addition, @@ -194,11 +188,11 @@ class MyAdd(Task): } def create_sockets(self): - self.inputs.clear() - self.outputs.clear() - inp = self.inputs.new("workgraph.Any", "x") - inp = self.inputs.new("workgraph.Any", "y") - self.outputs.new("workgraph.Any", "sum") + self.inputs._clear() + self.outputs._clear() + inp = self.add_input("workgraph.Any", "x") + inp = self.add_input("workgraph.Any", "y") + self.add_output("workgraph.Any", "sum") ###################################################################### diff --git a/docs/gallery/concept/autogen/workgraph.py b/docs/gallery/concept/autogen/workgraph.py index 118357f2..12d01da9 100644 --- a/docs/gallery/concept/autogen/workgraph.py +++ b/docs/gallery/concept/autogen/workgraph.py @@ -33,7 +33,7 @@ def add(x, y): # %% # Add a link between tasks: -wg.add_link(add1.outputs["result"], add2.inputs["x"]) +wg.add_link(add1.outputs.result, add2.inputs.x) wg.to_html() # %% diff --git a/docs/gallery/howto/autogen/aggregate.py b/docs/gallery/howto/autogen/aggregate.py index fc3138fc..41d01d7c 100644 --- a/docs/gallery/howto/autogen/aggregate.py +++ b/docs/gallery/howto/autogen/aggregate.py @@ -109,12 +109,12 @@ def aggregate( aggregate_task = wg.add_task(aggregate, name="aggregate_task") # we have to increase the link limit because by default workgraph only supports one link per input socket -aggregate_task.inputs["collected_values"].link_limit = 50 +aggregate_task.inputs["collected_values"]._link_limit = 50 for i in range(2): # this can be chosen as wanted generator_task = wg.add_task(generator, name=f"generator{i}", seed=Int(i)) wg.add_link( - generator_task.outputs["result"], + generator_task.outputs.result, aggregate_task.inputs["collected_values"], ) @@ -128,7 +128,7 @@ def aggregate( # %% # Print the output -print("aggregate_task result", aggregate_task.outputs["sum"].value) +print("aggregate_task result", aggregate_task.outputs.sum.value) # %% @@ -188,8 +188,8 @@ def aggregate(**collected_values): # we have to increase the link limit because by default workgraph only supports # one link per input socket. -aggregate_task.inputs["collected_ints"].link_limit = 50 -aggregate_task.inputs["collected_floats"].link_limit = 50 +aggregate_task.inputs["collected_ints"]._link_limit = 50 +aggregate_task.inputs["collected_floats"]._link_limit = 50 for i in range(2): # this can be chosen as wanted @@ -200,12 +200,12 @@ def aggregate(**collected_values): ) wg.add_link( - generator_int_task.outputs["result"], + generator_int_task.outputs.result, aggregate_task.inputs["collected_ints"], ) wg.add_link( - generator_float_task.outputs["result"], + generator_float_task.outputs.result, aggregate_task.inputs["collected_floats"], ) wg.to_html() @@ -269,7 +269,7 @@ def generator_loop(nb_iterations: Int): aggregate_task = wg.add_task( aggregate, name="aggregate_task", - collected_values=generator_loop_task.outputs["result"], + collected_values=generator_loop_task.outputs.result, ) wg.to_html() @@ -283,7 +283,7 @@ def generator_loop(nb_iterations: Int): # %% # Print the output -print("aggregate_task result", aggregate_task.outputs["result"].value) +print("aggregate_task result", aggregate_task.outputs.result.value) # %% diff --git a/docs/gallery/howto/autogen/graph_builder.py b/docs/gallery/howto/autogen/graph_builder.py index f67d1a04..f5b074f0 100644 --- a/docs/gallery/howto/autogen/graph_builder.py +++ b/docs/gallery/howto/autogen/graph_builder.py @@ -52,7 +52,7 @@ def add_multiply(x=None, y=None, z=None): wg = WorkGraph() wg.add_task(add, name="add", x=x, y=y) wg.add_task(multiply, name="multiply", x=z) - wg.add_link(wg.tasks["add"].outputs["result"], wg.tasks["multiply"].inputs["y"]) + wg.add_link(wg.tasks.add.outputs.result, wg.tasks.multiply.inputs.y) return wg @@ -61,9 +61,7 @@ def add_multiply(x=None, y=None, z=None): add_multiply1 = wg.add_task(add_multiply(x=Int(2), y=Int(3), z=Int(4))) add_multiply2 = wg.add_task(add_multiply(x=Int(2), y=Int(3))) # link the output of a task to the input of another task -wg.add_link( - add_multiply1.outputs["multiply.result"], add_multiply2.inputs["multiply.x"] -) +wg.add_link(add_multiply1.outputs.multiply.result, add_multiply2.inputs.multiply.x) wg.to_html() # %% @@ -72,7 +70,7 @@ def add_multiply(x=None, y=None, z=None): wg.submit(wait=True) # (2+3)*4 = 20 # (2+3)*20 = 100 -assert add_multiply2.outputs["multiply.result"].value == 100 +assert add_multiply2.outputs.multiply.result.value == 100 # %% # Generate node graph from the AiiDA process @@ -115,7 +113,7 @@ def add_multiply(x, y, z): wg = WorkGraph() wg.add_task(add, name="add", x=x, y=y) wg.add_task(multiply, name="multiply", x=z) - wg.add_link(wg.tasks["add"].outputs[0], wg.tasks["multiply"].inputs["y"]) + wg.add_link(wg.tasks.add.outputs[0], wg.tasks.multiply.inputs.y) # Don't forget to return the `wg` return wg @@ -133,7 +131,7 @@ def add_multiply(x, y, z): add_multiply1 = wg.add_task(add_multiply, x=Int(2), y=Int(3), z=Int(4)) add_multiply2 = wg.add_task(add_multiply, x=Int(2), y=Int(3)) # link the output of a task to the input of another task -wg.add_link(add_multiply1.outputs[0], add_multiply2.inputs["z"]) +wg.add_link(add_multiply1.outputs[0], add_multiply2.inputs.z) wg.submit(wait=True) assert add_multiply2.outputs[0].value == 100 wg.to_html() @@ -211,7 +209,7 @@ def for_loop(nb_iterations: Int): # Running the workgraph. wg.submit(wait=True) -print("Output of last task", task.outputs["result"].value) # 1 + 1 result +print("Output of last task", task.outputs.result.value) # 1 + 1 result # %% # Plotting provenance @@ -250,15 +248,15 @@ def if_then_else(i: Int): wg = WorkGraph("Nested workflow: If") task1 = wg.add_task(if_then_else, i=Int(1)) -task2 = wg.add_task(if_then_else, i=task1.outputs["result"]) +task2 = wg.add_task(if_then_else, i=task1.outputs.result) wg.to_html() # %% # Running the workgraph. wg.submit(wait=True) -print("Output of first task", task1.outputs["result"].value) # 1 + 1 result -print("Output of second task", task2.outputs["result"].value) # 2 % 2 result +print("Output of first task", task1.outputs.result.value) # 1 + 1 result +print("Output of second task", task2.outputs.result.value) # 2 % 2 result # %% # Plotting provenance diff --git a/docs/gallery/howto/autogen/parallel.py b/docs/gallery/howto/autogen/parallel.py index 240c22b4..b8866fd8 100644 --- a/docs/gallery/howto/autogen/parallel.py +++ b/docs/gallery/howto/autogen/parallel.py @@ -132,7 +132,7 @@ def sum(**datas): # print("State of WorkGraph: {}".format(wg.state)) -print("Result of task add1: {}".format(wg.tasks["sum1"].outputs["result"].value)) +print("Result of task add1: {}".format(wg.tasks.sum1.outputs.result.value)) # %% diff --git a/docs/gallery/tutorial/autogen/eos.py b/docs/gallery/tutorial/autogen/eos.py index 51a1c09e..c57c3729 100644 --- a/docs/gallery/tutorial/autogen/eos.py +++ b/docs/gallery/tutorial/autogen/eos.py @@ -102,8 +102,8 @@ def eos(**datas): scale_structure1 = wg.add_task(scale_structure, name="scale_structure1") all_scf1 = wg.add_task(all_scf, name="all_scf1") eos1 = wg.add_task(eos, name="eos1") -wg.add_link(scale_structure1.outputs["structures"], all_scf1.inputs["structures"]) -wg.add_link(all_scf1.outputs["result"], eos1.inputs["datas"]) +wg.add_link(scale_structure1.outputs.structures, all_scf1.inputs.structures) +wg.add_link(all_scf1.outputs.result, eos1.inputs.datas) wg.to_html() # visualize the workgraph in jupyter-notebook # wg @@ -183,8 +183,8 @@ def eos(**datas): } # ------------------------------------------------------- # set the input parameters for each task -wg.tasks["scale_structure1"].set({"structure": si, "scales": [0.95, 1.0, 1.05]}) -wg.tasks["all_scf1"].set({"scf_inputs": scf_inputs}) +wg.tasks.scale_structure1.set({"structure": si, "scales": [0.95, 1.0, 1.05]}) +wg.tasks.all_scf1.set({"scf_inputs": scf_inputs}) print("Waiting for the workgraph to finish...") wg.submit(wait=True, timeout=300) # one can also run the workgraph directly @@ -196,7 +196,7 @@ def eos(**datas): # -data = wg.tasks["eos1"].outputs["result"].value.get_dict() +data = wg.tasks["eos1"].outputs.result.value.get_dict() print("B: {B}\nv0: {v0}\ne0: {e0}\nv0: {v0}".format(**data)) # %% @@ -216,8 +216,8 @@ def eos_workgraph(structure=None, scales=None, scf_inputs=None): ) all_scf1 = wg.add_task(all_scf, name="all_scf1", scf_inputs=scf_inputs) eos1 = wg.add_task(eos, name="eos1") - wg.add_link(scale_structure1.outputs["structures"], all_scf1.inputs["structures"]) - wg.add_link(all_scf1.outputs["result"], eos1.inputs["datas"]) + wg.add_link(scale_structure1.outputs.structures, all_scf1.inputs.structures) + wg.add_link(all_scf1.outputs.result, eos1.inputs.datas) return wg @@ -268,7 +268,7 @@ def eos_workgraph(structure=None, scales=None, scf_inputs=None): eos_wg_task = wg.add_task( eos_workgraph, name="eos1", scales=[0.95, 1.0, 1.05], scf_inputs=scf_inputs ) -wg.add_link(relax_task.outputs["output_structure"], eos_wg_task.inputs["structure"]) +wg.add_link(relax_task.outputs.output_structure, eos_wg_task.inputs["structure"]) # ------------------------------------------------------- # One can submit the workgraph directly # wg.submit(wait=True, timeout=300) diff --git a/docs/gallery/tutorial/autogen/qe.py b/docs/gallery/tutorial/autogen/qe.py index aaab06bb..fe3323ce 100644 --- a/docs/gallery/tutorial/autogen/qe.py +++ b/docs/gallery/tutorial/autogen/qe.py @@ -147,7 +147,7 @@ # ------------------------- Print the output ------------------------- print( "Energy of an un-relaxed N2 molecule: {:0.3f}".format( - pw1.outputs["output_parameters"].value.get_dict()["energy"] + pw1.outputs.output_parameters.value.get_dict()["energy"] ) ) # @@ -238,8 +238,8 @@ def atomization_energy(output_atom, output_mol): ) # create the task to calculate the atomization energy atomization = wg.add_task(atomization_energy, name="atomization_energy") -wg.add_link(pw_n.outputs["output_parameters"], atomization.inputs["output_atom"]) -wg.add_link(pw_n2.outputs["output_parameters"], atomization.inputs["output_mol"]) +wg.add_link(pw_n.outputs.output_parameters, atomization.inputs.output_atom) +wg.add_link(pw_n2.outputs.output_parameters, atomization.inputs.output_mol) wg.to_html() @@ -249,9 +249,7 @@ def atomization_energy(output_atom, output_mol): wg.submit(wait=True, timeout=300) -print( - "Atomization energy: {:0.3f} eV".format(atomization.outputs["result"].value.value) -) +print("Atomization energy: {:0.3f} eV".format(atomization.outputs.result.value.value)) # %% @@ -305,7 +303,7 @@ def pw_parameters(paras, relax_type): }, ) paras_task = wg.add_task(pw_parameters, "parameters", paras=paras, relax_type="relax") -wg.add_link(paras_task.outputs[0], pw_relax1.inputs["base.pw.parameters"]) +wg.add_link(paras_task.outputs[0], pw_relax1.inputs.base.pw.parameters) # One can submit the workgraph directly # wg.submit(wait=True, timeout=200) # print( @@ -341,10 +339,10 @@ def pw_parameters(paras, relax_type): # we can now inspect the inputs of the workchain print("The inputs for the PwBaseWorkchain are:") print("-" * 80) -pprint(pw_relax1.inputs["base"].value) +pprint(pw_relax1.inputs.base._value) print("\nThe input parameters for pw are:") print("-" * 80) -pprint(pw_relax1.inputs["base"].value["pw"]["parameters"].get_dict()) +pprint(pw_relax1.inputs.base.pw.parameters.value.get_dict()) # %% diff --git a/docs/gallery/tutorial/autogen/zero_to_hero.py b/docs/gallery/tutorial/autogen/zero_to_hero.py index 617e2e8e..3c93ba3f 100644 --- a/docs/gallery/tutorial/autogen/zero_to_hero.py +++ b/docs/gallery/tutorial/autogen/zero_to_hero.py @@ -70,7 +70,7 @@ def multiply(x, y): # wg = WorkGraph("add_multiply_workflow") wg.add_task(add, name="add1") -wg.add_task(multiply, name="multiply1", x=wg.tasks["add1"].outputs["result"]) +wg.add_task(multiply, name="multiply1", x=wg.tasks.add1.outputs.result) # export the workgraph to html file so that it can be visualized in a browser wg.to_html() # visualize the workgraph in jupyter-notebook @@ -92,11 +92,9 @@ def multiply(x, y): inputs={"add1": {"x": Int(2), "y": Int(3)}, "multiply1": {"y": Int(4)}}, wait=True ) # ------------------------- Print the output ------------------------- -assert wg.tasks["multiply1"].outputs["result"].value == 20 +assert wg.tasks.multiply1.outputs.result.value == 20 print( - "\nResult of multiply1 is {} \n\n".format( - wg.tasks["multiply1"].outputs["result"].value - ) + "\nResult of multiply1 is {} \n\n".format(wg.tasks.multiply1.outputs.result.value) ) # ------------------------- Generate node graph ------------------- generate_node_graph(wg.pk) @@ -118,7 +116,7 @@ def multiply(x, y): wg = WorkGraph("test_calcjob") new = wg.add_task new(ArithmeticAddCalculation, name="add1") -wg.add_task(ArithmeticAddCalculation, name="add2", x=wg.tasks["add1"].outputs["sum"]) +wg.add_task(ArithmeticAddCalculation, name="add2", x=wg.tasks.add1.outputs.sum) wg.to_html() # @@ -134,7 +132,7 @@ def multiply(x, y): # visualize the task -wg.tasks["add1"].to_html() +wg.tasks.add1.to_html() # # %% @@ -183,8 +181,8 @@ def atomization_energy(output_atom, output_mol): wg.add_task( atomization_energy, name="atomization_energy", - output_atom=pw_atom.outputs["output_parameters"], - output_mol=pw_mol.outputs["output_parameters"], + output_atom=pw_atom.outputs.output_parameters, + output_mol=pw_mol.outputs.output_parameters, ) # export the workgraph to html file so that it can be visualized in a browser wg.to_html() @@ -265,7 +263,7 @@ def atomization_energy(output_atom, output_mol): } # # ------------------------- Set the inputs ------------------------- -wg.tasks["pw_atom"].set( +wg.tasks.pw_atom.set( { "code": pw_code, "structure": structure_n, @@ -275,7 +273,7 @@ def atomization_energy(output_atom, output_mol): "metadata": metadata, } ) -wg.tasks["pw_mol"].set( +wg.tasks.pw_mol.set( { "code": pw_code, "structure": structure_n2, @@ -290,17 +288,17 @@ def atomization_energy(output_atom, output_mol): # ------------------------- Print the output ------------------------- print( "Energy of a N atom: {:0.3f}".format( - wg.tasks["pw_atom"].outputs["output_parameters"].value.get_dict()["energy"] + wg.tasks.pw_atom.outputs.output_parameters.value.get_dict()["energy"] ) ) print( "Energy of an un-relaxed N2 molecule: {:0.3f}".format( - wg.tasks["pw_mol"].outputs["output_parameters"].value.get_dict()["energy"] + wg.tasks.pw_mol.outputs.output_parameters.value.get_dict()["energy"] ) ) print( "Atomization energy: {:0.3f} eV".format( - wg.tasks["atomization_energy"].outputs["result"].value.value + wg.tasks.atomization_energy.outputs.result.value.value ) ) # @@ -364,9 +362,9 @@ def add_multiply_if_generator(x, y): wg.add_task( add_multiply_if_generator, name="add_multiply_if1", - x=wg.tasks["add1"].outputs["result"], + x=wg.tasks.add1.outputs.result, ) -wg.add_task(add, name="add2", x=wg.tasks["add_multiply_if1"].outputs["result"]) +wg.add_task(add, name="add2", x=wg.tasks.add_multiply_if1.outputs.result) wg.to_html() # @@ -383,8 +381,8 @@ def add_multiply_if_generator(x, y): wait=True, ) # ------------------------- Print the output ------------------------- -assert wg.tasks["add2"].outputs["result"].value == 7 -print("\nResult of add2 is {} \n\n".format(wg.tasks["add2"].outputs["result"].value)) +assert wg.tasks.add2.outputs.result.value == 7 +print("\nResult of add2 is {} \n\n".format(wg.tasks.add2.outputs.result.value)) # # %% # Note: one can not see the detail of the `add_multiply_if1` before you running it. @@ -465,9 +463,9 @@ def eos(**datas): wg = WorkGraph("eos") scale_structure1 = wg.add_task(scale_structure, name="scale_structure1") all_scf1 = wg.add_task( - all_scf, name="all_scf1", structures=scale_structure1.outputs["structures"] + all_scf, name="all_scf1", structures=scale_structure1.outputs.structures ) -eos1 = wg.add_task(eos, name="eos1", datas=all_scf1.outputs["result"]) +eos1 = wg.add_task(eos, name="eos1", datas=all_scf1.outputs.result) wg.to_html() # @@ -489,8 +487,8 @@ def eos_workgraph(structure=None, scales=None, scf_inputs=None): ) all_scf1 = wg.add_task(all_scf, name="all_scf1", scf_inputs=scf_inputs) eos1 = wg.add_task(eos, name="eos1") - wg.add_link(scale_structure1.outputs["structures"], all_scf1.inputs["structures"]) - wg.add_link(all_scf1.outputs["result"], eos1.inputs["datas"]) + wg.add_link(scale_structure1.outputs.structures, all_scf1.inputs.structures) + wg.add_link(all_scf1.outputs.result, eos1.inputs.datas) return wg @@ -500,7 +498,7 @@ def eos_workgraph(structure=None, scales=None, scf_inputs=None): wg = WorkGraph("relax_eos") relax_task = wg.add_task(PwCalculation, name="relax1") eos_wg_task = wg.add_task( - eos_workgraph, name="eos1", structure=relax_task.outputs["output_structure"] + eos_workgraph, name="eos1", structure=relax_task.outputs.output_structure ) wg.to_html() diff --git a/docs/source/built-in/monitor.ipynb b/docs/source/built-in/monitor.ipynb index 25be2343..293b9b9e 100644 --- a/docs/source/built-in/monitor.ipynb +++ b/docs/source/built-in/monitor.ipynb @@ -67,10 +67,10 @@ "wg2.add_task(add, \"add2\",x=1, y=2, t=5)\n", "wg2.submit(wait=True)\n", "wg1.wait()\n", - "print(\"ctime of add1: \", wg1.tasks[\"add1\"].node.ctime)\n", - "print(\"citme of add2: \", wg2.tasks[\"add2\"].node.ctime)\n", + "print(\"ctime of add1: \", wg1.tasks.add1.node.ctime)\n", + "print(\"citme of add2: \", wg2.tasks.add2.node.ctime)\n", "# calculate the time difference between the creation of the two task nodes\n", - "time_difference = wg1.tasks[\"add1\"].node.ctime - wg2.tasks[\"add2\"].node.ctime\n", + "time_difference = wg1.tasks.add1.node.ctime - wg2.tasks.add2.node.ctime\n", "assert time_difference.total_seconds() > 5" ] }, diff --git a/docs/source/built-in/pythonjob.ipynb b/docs/source/built-in/pythonjob.ipynb index a7e286ad..bcb76f71 100644 --- a/docs/source/built-in/pythonjob.ipynb +++ b/docs/source/built-in/pythonjob.ipynb @@ -106,7 +106,7 @@ "wg = WorkGraph(\"first_workflow\")\n", "wg.add_task(add, name=\"add\")\n", "# we can also use a normal python function directly, but provide the \"PythonJob\" as the first argument\n", - "wg.add_task(\"PythonJob\", function=multiply, name=\"multiply\", x=wg.tasks[\"add\"].outputs[0])\n", + "wg.add_task(\"PythonJob\", function=multiply, name=\"multiply\", x=wg.tasks.add.outputs[0])\n", "\n", "# visualize the workgraph\n", "wg.to_html()\n", @@ -362,7 +362,7 @@ " \"metadata\": metadata}},\n", " wait=True)\n", "#------------------------- Print the output -------------------------\n", - "print(\"\\nResult of multiply is {} \\n\\n\".format(wg.tasks[\"multiply\"].outputs['result'].value))\n", + "print(\"\\nResult of multiply is {} \\n\\n\".format(wg.tasks.multiply.outputs.result.value))\n", "#------------------------- Generate node graph -------------------\n", "generate_node_graph(wg.pk)" ] @@ -421,7 +421,7 @@ "wg = WorkGraph(\"PythonJob_parent_folder\")\n", "wg.add_task(\"PythonJob\", function=add, name=\"add\")\n", "wg.add_task(\"PythonJob\", function=multiply, name=\"multiply\",\n", - " parent_folder=wg.tasks[\"add\"].outputs[\"remote_folder\"],\n", + " parent_folder=wg.tasks.add.outputs.remote_folder,\n", " )\n", "\n", "wg.to_html()" @@ -457,7 +457,7 @@ "wg.submit(inputs = {\"add\": {\"x\": 2, \"y\": 3, \"computer\": \"localhost\"},\n", " \"multiply\": {\"x\": 3, \"y\": 4, \"computer\": \"localhost\"}},\n", " wait=True)\n", - "print(\"\\nResult of multiply is {} \\n\\n\".format(wg.tasks[\"multiply\"].outputs['result'].value))" + "print(\"\\nResult of multiply is {} \\n\\n\".format(wg.tasks.multiply.outputs.result.value))" ] }, { @@ -526,7 +526,7 @@ " },\n", " },\n", " wait=True)\n", - "print(\"\\nResult of add is {} \\n\\n\".format(wg.tasks[\"add\"].outputs['result'].value))" + "print(\"\\nResult of add is {} \\n\\n\".format(wg.tasks.add.outputs['result'].value))" ] }, { @@ -615,8 +615,8 @@ "pw_atom = wg.add_task(\"PythonJob\", function=emt, name=\"emt_atom\")\n", "pw_mol = wg.add_task(\"PythonJob\", function=emt, name=\"emt_mol\")\n", "wg.add_task(\"PythonJob\", function=atomization_energy, name=\"atomization_energy\",\n", - " energy_atom=pw_atom.outputs[\"result\"],\n", - " energy_molecule=pw_mol.outputs[\"result\"])\n", + " energy_atom=pw_atom.outputs.result,\n", + " energy_molecule=pw_mol.outputs.result)\n", "wg.to_html()" ] }, @@ -904,9 +904,9 @@ "#------------------------- Submit the calculation -------------------\n", "wg.submit(wait=True, timeout=200)\n", "#------------------------- Print the output -------------------------\n", - "print('Energy of a N atom: {:0.3f}'.format(wg.tasks['emt_atom'].outputs[\"result\"].value.value))\n", - "print('Energy of an un-relaxed N2 molecule: {:0.3f}'.format(wg.tasks['emt_mol'].outputs[\"result\"].value.value))\n", - "print('Atomization energy: {:0.3f} eV'.format(wg.tasks['atomization_energy'].outputs[\"result\"].value.value))\n", + "print('Energy of a N atom: {:0.3f}'.format(wg.tasks['emt_atom'].outputs.result.value.value))\n", + "print('Energy of an un-relaxed N2 molecule: {:0.3f}'.format(wg.tasks['emt_mol'].outputs.result.value.value))\n", + "print('Atomization energy: {:0.3f} eV'.format(wg.tasks['atomization_energy'].outputs.result.value.value))\n", "#------------------------- Generate node graph -------------------\n", "generate_node_graph(wg.pk)\n" ] @@ -981,7 +981,7 @@ "\n", "wg = WorkGraph(\"PythonJob_shell_command\")\n", "wg.add_task(\"PythonJob\", function=add, name=\"add\")\n", - "wg.add_task(\"PythonJob\", function=multiply, name=\"multiply\", x=wg.tasks[\"add\"].outputs[0])\n", + "wg.add_task(\"PythonJob\", function=multiply, name=\"multiply\", x=wg.tasks.add.outputs[0])\n", "\n", "# visualize the workgraph\n", "wg.to_html()\n" @@ -1191,7 +1191,7 @@ " \"multiply\": {\"y\": 4, \"computer\": \"localhost\"}},\n", " wait=True)\n", "#------------------------- Print the output -------------------------\n", - "print(\"\\nResult of multiply is {} \\n\\n\".format(wg.tasks[\"multiply\"].outputs['result'].value))\n", + "print(\"\\nResult of multiply is {} \\n\\n\".format(wg.tasks.multiply.outputs.result.value))\n", "#------------------------- Generate node graph -------------------\n", "generate_node_graph(wg.pk)" ] @@ -1546,14 +1546,14 @@ " # -------- calculate_enegies -----------\n", "calculate_enegies_task = wg.add_task(calculate_enegies,\n", " name=\"calculate_enegies\",\n", - " scaled_atoms=scale_atoms_task.outputs[\"scaled_atoms\"],\n", + " scaled_atoms=scale_atoms_task.outputs.scaled_atoms,\n", " )\n", " # -------- fit_eos -----------\n", "wg.add_task(\"PythonJob\",\n", " function=fit_eos,\n", " name=\"fit_eos\",\n", - " volumes=scale_atoms_task.outputs[\"volumes\"],\n", - " emt_results=calculate_enegies_task.outputs[\"results\"],\n", + " volumes=scale_atoms_task.outputs.volumes,\n", + " emt_results=calculate_enegies_task.outputs.results,\n", " )\n", "wg.to_html()" ] @@ -1598,7 +1598,7 @@ " )\n", "\n", "print(\"The fitted EOS parameters are:\")\n", - "wg.tasks[\"fit_eos\"].outputs[\"result\"].value.value\n" + "wg.tasks[\"fit_eos\"].outputs.result.value.value\n" ] }, { @@ -2257,8 +2257,8 @@ "wg.add_task(add, name=\"add\", x=1, y=-2)\n", "wg.submit(wait=True)\n", "\n", - "print(\"exit status: \", wg.tasks[\"add\"].node.exit_status)\n", - "print(\"exit message: \", wg.tasks[\"add\"].node.exit_message)" + "print(\"exit status: \", wg.tasks.add.node.exit_status)\n", + "print(\"exit message: \", wg.tasks.add.node.exit_message)" ] }, { diff --git a/docs/source/built-in/shelljob.ipynb b/docs/source/built-in/shelljob.ipynb index 142de3de..079801e4 100644 --- a/docs/source/built-in/shelljob.ipynb +++ b/docs/source/built-in/shelljob.ipynb @@ -249,19 +249,19 @@ "# bc command to calculate the expression\n", "bc_task_1 = wg.add_task(\"ShellJob\", name=\"bc_task_1\", command=\"bc\", arguments=[\"{expression}\"],\n", " parser=PickledData(parser),\n", - " nodes={'expression': echo_task_1.outputs[\"stdout\"]},\n", + " nodes={'expression': echo_task_1.outputs.stdout},\n", " parser_outputs=[{\"name\": \"result\"}],\n", " )\n", "# echo result + z expression\n", "echo_task_2 = wg.add_task(\"ShellJob\", name=\"echo_task_2\", command=\"echo\",\n", " arguments=[\"{result}\", \"*\", \"{z}\"],\n", " nodes={'z': Int(4),\n", - " \"result\": bc_task_1.outputs[\"result\"]},\n", + " \"result\": bc_task_1.outputs.result},\n", " )\n", "# bc command to calculate the expression\n", "bc_task_2 = wg.add_task(\"ShellJob\", name=\"bc_task_2\", command=\"bc\", arguments=[\"{expression}\"],\n", " parser=PickledData(parser),\n", - " nodes={'expression': echo_task_2.outputs[\"stdout\"]},\n", + " nodes={'expression': echo_task_2.outputs.stdout},\n", " parser_outputs=[{\"name\": \"result\"}],\n", " )\n", "display(wg.to_html())\n", diff --git a/docs/source/howto/context.ipynb b/docs/source/howto/context.ipynb index 9ccbca59..eea0b1cf 100644 --- a/docs/source/howto/context.ipynb +++ b/docs/source/howto/context.ipynb @@ -137,10 +137,10 @@ "# Set the context of the workgraph\n", "wg.context = {\"x\": 2, \"data.y\": 3}\n", "get_ctx1 = wg.add_task(\"workgraph.get_context\", name=\"get_ctx1\", key=\"x\")\n", - "add1 = wg.add_task(add, \"add1\", x=get_ctx1.outputs[\"result\"],\n", + "add1 = wg.add_task(add, \"add1\", x=get_ctx1.outputs.result,\n", " y=\"{{data.y}}\")\n", "set_ctx1 = wg.add_task(\"workgraph.set_context\", name=\"set_ctx1\", key=\"x\",\n", - " value=add1.outputs[\"result\"])\n", + " value=add1.outputs.result)\n", "wg.to_html()\n", "# wg" ] diff --git a/pyproject.toml b/pyproject.toml index 3679dde5..bb27084f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ dependencies = [ "numpy~=1.21", "scipy", "ase", - "node-graph==0.1.7", + "node-graph==0.1.9", "node-graph-widget", "aiida-core>=2.3", "cloudpickle", @@ -127,7 +127,7 @@ workgraph = "aiida_workgraph.cli.cmd_workgraph:workgraph" [project.entry-points."aiida_workgraph.socket"] "workgraph.any" = "aiida_workgraph.sockets.builtins:SocketAny" -"workgraph.namespace" = "aiida_workgraph.sockets.builtins:SocketNamespace" +"workgraph.namespace" = "aiida_workgraph.socket:TaskSocketNamespace" "workgraph.int" = "aiida_workgraph.sockets.builtins:SocketInt" "workgraph.float" = "aiida_workgraph.sockets.builtins:SocketFloat" "workgraph.string" = "aiida_workgraph.sockets.builtins:SocketString" diff --git a/src/aiida_workgraph/collection.py b/src/aiida_workgraph/collection.py index ac8f3337..e73dbf59 100644 --- a/src/aiida_workgraph/collection.py +++ b/src/aiida_workgraph/collection.py @@ -1,14 +1,12 @@ from node_graph.collection import ( NodeCollection, PropertyCollection, - InputSocketCollection, - OutputSocketCollection, ) from typing import Any, Callable, Optional, Union class TaskCollection(NodeCollection): - def new( + def _new( self, identifier: Union[Callable, str], name: Optional[str] = None, @@ -33,21 +31,21 @@ def new( outputs=kwargs.get("outputs", None), parser_outputs=kwargs.pop("parser_outputs", None), ) - task = super().new(identifier, name, uuid, **kwargs) + task = super()._new(identifier, name, uuid, **kwargs) return task if isinstance(identifier, str) and identifier.upper() == "WHILE": - task = super().new("workgraph.while", name, uuid, **kwargs) + task = super()._new("workgraph.while", name, uuid, **kwargs) return task if isinstance(identifier, str) and identifier.upper() == "IF": - task = super().new("workgraph.if", name, uuid, **kwargs) + task = super()._new("workgraph.if", name, uuid, **kwargs) return task if isinstance(identifier, WorkGraph): identifier = build_task_from_workgraph(identifier) - return super().new(identifier, name, uuid, **kwargs) + return super()._new(identifier, name, uuid, **kwargs) class WorkGraphPropertyCollection(PropertyCollection): - def new( + def _new( self, identifier: Union[Callable, str], name: Optional[str] = None, @@ -59,36 +57,4 @@ def new( if callable(identifier): identifier = build_property_from_AiiDA(identifier) # Call the original new method - return super().new(identifier, name, **kwargs) - - -class WorkGraphInputSocketCollection(InputSocketCollection): - def new( - self, - identifier: Union[Callable, str], - name: Optional[str] = None, - **kwargs: Any - ) -> Any: - from aiida_workgraph.socket import build_socket_from_AiiDA - - # build the socket on the fly if the identifier is a callable - if callable(identifier): - identifier = build_socket_from_AiiDA(identifier) - # Call the original new method - return super().new(identifier, name, **kwargs) - - -class WorkGraphOutputSocketCollection(OutputSocketCollection): - def new( - self, - identifier: Union[Callable, str], - name: Optional[str] = None, - **kwargs: Any - ) -> Any: - from aiida_workgraph.socket import build_socket_from_AiiDA - - # build the socket on the fly if the identifier is a callable - if callable(identifier): - identifier = build_socket_from_AiiDA(identifier) - # Call the original new method - return super().new(identifier, name, **kwargs) + return super()._new(identifier, name, **kwargs) diff --git a/src/aiida_workgraph/config.py b/src/aiida_workgraph/config.py index 4fce3c41..e3cd4d67 100644 --- a/src/aiida_workgraph/config.py +++ b/src/aiida_workgraph/config.py @@ -6,7 +6,9 @@ WORKGRAPH_SHORT_EXTRA_KEY = "_workgraph_short" -builtin_inputs = [{"name": "_wait", "link_limit": 1e6, "arg_type": "none"}] +builtin_inputs = [ + {"name": "_wait", "link_limit": 1e6, "metadata": {"arg_type": "none"}} +] builtin_outputs = [{"name": "_wait"}, {"name": "_outputs"}] diff --git a/src/aiida_workgraph/decorator.py b/src/aiida_workgraph/decorator.py index 0ab2d2df..b125991d 100644 --- a/src/aiida_workgraph/decorator.py +++ b/src/aiida_workgraph/decorator.py @@ -3,7 +3,6 @@ from typing import Any, Callable, Dict, List, Optional, Union, Tuple from aiida_workgraph.utils import get_executor from aiida.engine import calcfunction, workfunction, CalcJob, WorkChain -from aiida import orm from aiida.orm.nodes.process.calculation.calcfunction import CalcFunctionNode from aiida.orm.nodes.process.workflow.workfunction import WorkFunctionNode from aiida.engine.processes.ports import PortNamespace @@ -11,6 +10,8 @@ from aiida_workgraph.utils import build_callable, validate_task_inout import inspect from aiida_workgraph.config import builtin_inputs, builtin_outputs +from aiida_workgraph.orm.mapping import type_mapping + task_types = { CalcFunctionNode: "CALCFUNCTION", @@ -19,29 +20,18 @@ WorkChain: "WORKCHAIN", } -type_mapping = { - "default": "workgraph.any", - "namespace": "workgraph.namespace", - int: "workgraph.int", - float: "workgraph.float", - str: "workgraph.string", - bool: "workgraph.bool", - orm.Int: "workgraph.aiida_int", - orm.Float: "workgraph.aiida_float", - orm.Str: "workgraph.aiida_string", - orm.Bool: "workgraph.aiida_bool", - orm.List: "workgraph.aiida_list", - orm.Dict: "workgraph.aiida_dict", - orm.StructureData: "workgraph.aiida_structuredata", -} - def create_task(tdata): """Wrap create_node from node_graph to create a Task.""" from node_graph.decorator import create_node + from node_graph.utils import list_to_dict tdata["type_mapping"] = type_mapping tdata["metadata"]["node_type"] = tdata["metadata"].pop("task_type") + tdata["properties"] = list_to_dict(tdata.get("properties", {})) + tdata["inputs"] = list_to_dict(tdata.get("inputs", {})) + tdata["outputs"] = list_to_dict(tdata.get("outputs", {})) + return create_node(tdata) @@ -67,9 +57,11 @@ def add_input_recursive( { "identifier": "workgraph.namespace", "name": port_name, - "arg_type": "kwargs", - "metadata": {"required": required, "dynamic": port.dynamic}, - "property": {"identifier": "workgraph.any", "default": None}, + "metadata": { + "arg_type": "kwargs", + "required": required, + "dynamic": port.dynamic, + }, } ) for value in port.values(): @@ -88,8 +80,7 @@ def add_input_recursive( { "identifier": socket_type, "name": port_name, - "arg_type": "kwargs", - "metadata": {"required": required}, + "metadata": {"arg_type": "kwargs", "required": required}, } ) return inputs @@ -222,7 +213,7 @@ def build_task_from_AiiDA( tdata: Dict[str, Any], inputs: Optional[List[str]] = None, outputs: Optional[List[str]] = None, -) -> Task: +) -> Tuple[Task, Dict[str, Any]]: """Register a task from a AiiDA component. For example: CalcJob, WorkChain, CalcFunction, WorkFunction.""" @@ -248,11 +239,9 @@ def build_task_from_AiiDA( if name not in [input["name"] for input in inputs]: inputs.append( { - "identifier": "workgraph.any", + "identifier": "workgraph.namespace", "name": name, - "arg_type": "var_kwargs", - "metadata": {"dynamic": True}, - "property": {"identifier": "workgraph.any", "default": {}}, + "metadata": {"arg_type": "var_kwargs", "dynamic": True}, } ) @@ -392,10 +381,13 @@ def build_task_from_workgraph(wg: any) -> Task: } ) for socket in task.inputs: - if socket.name in builtin_input_names: + if socket._name in builtin_input_names: continue inputs.append( - {"identifier": socket.identifier, "name": f"{task.name}.{socket.name}"} + { + "identifier": socket._identifier, + "name": f"{task.name}.{socket._name}", + } ) # outputs outputs.append( @@ -405,15 +397,18 @@ def build_task_from_workgraph(wg: any) -> Task: } ) for socket in task.outputs: - if socket.name in builtin_output_names: + if socket._name in builtin_output_names: continue outputs.append( - {"identifier": socket.identifier, "name": f"{task.name}.{socket.name}"} + { + "identifier": socket._identifier, + "name": f"{task.name}.{socket._name}", + } ) group_outputs.append( { - "name": f"{task.name}.{socket.name}", - "from": f"{task.name}.{socket.name}", + "name": f"{task.name}.{socket._name}", + "from": f"{task.name}.{socket._name}", } ) # add built-in sockets diff --git a/src/aiida_workgraph/engine/task_manager.py b/src/aiida_workgraph/engine/task_manager.py index d0175dfe..30dc9334 100644 --- a/src/aiida_workgraph/engine/task_manager.py +++ b/src/aiida_workgraph/engine/task_manager.py @@ -41,7 +41,7 @@ def get_task(self, name: str): for output in task.outputs: output.value = get_nested_dict( self.ctx._tasks[name]["results"], - output.name, + output._name, default=output.value, ) return task @@ -249,6 +249,7 @@ def run_tasks(self, names: List[str], continue_workgraph: bool = True) -> None: kwargs[key] = args[i] # update the port namespace kwargs = update_nested_dict_with_special_keys(kwargs) + print("kwargs: ", kwargs) # kwargs["meta.label"] = name # output must be a Data type or a mapping of {string: Data} task["results"] = {} @@ -580,6 +581,7 @@ def get_inputs( Dict[str, Any], ]: """Get input based on the links.""" + from node_graph.utils import collect_values_inside_namespace args = [] args_dict = {} @@ -587,16 +589,21 @@ def get_inputs( var_args = None var_kwargs = None task = self.ctx._tasks[name] - properties = task.get("properties", {}) inputs = {} + for name, prop in task.get("properties", {}).items(): + inputs[name] = self.ctx_manager.update_context_variable(prop["value"]) for name, input in task["inputs"].items(): # print(f"input: {input['name']}") - if len(input["links"]) == 0: + if input["identifier"] == "workgraph.namespace": + # inputs[name] = self.ctx_manager.update_context_variable(input["value"]) + inputs[name] = collect_values_inside_namespace(input) + else: inputs[name] = self.ctx_manager.update_context_variable( input["property"]["value"] ) - elif len(input["links"]) == 1: - link = input["links"][0] + for name, links in task["input_links"].items(): + if len(links) == 1: + link = links[0] if self.ctx._tasks[link["from_node"]]["results"] is None: inputs[name] = None else: @@ -611,9 +618,9 @@ def get_inputs( link["from_socket"], ) # handle the case of multiple outputs - elif len(input["links"]) > 1: + elif len(links) > 1: value = {} - for link in input["links"]: + for link in links: item_name = f'{link["from_node"]}_{link["from_socket"]}' # handle the special socket _wait, _outputs if link["from_socket"] == "_wait": @@ -625,42 +632,18 @@ def get_inputs( "results" ][link["from_socket"]] inputs[name] = value - for name in task.get("args", []): - if name in inputs: - args.append(inputs[name]) - args_dict[name] = inputs[name] - else: - value = self.ctx_manager.update_context_variable( - properties[name]["value"] - ) - args.append(value) - args_dict[name] = value - for name in task.get("kwargs", []): - if name in inputs: - kwargs[name] = inputs[name] - else: - value = self.ctx_manager.update_context_variable( - properties[name]["value"] - ) - kwargs[name] = value - if task["var_args"] is not None: - name = task["var_args"] - if name in inputs: - var_args = inputs[name] - else: - value = self.ctx_manager.update_context_variable( - properties[name]["value"] - ) - var_args = value - if task["var_kwargs"] is not None: - name = task["var_kwargs"] - if name in inputs: - var_kwargs = inputs[name] - else: - value = self.ctx_manager.update_context_variable( - properties[name]["value"] - ) - var_kwargs = value + for name, input in inputs.items(): + # only need to check the top level key + key = name.split(".")[0] + if key in task["args"]: + args.append(input) + args_dict[name] = input + elif key in task["kwargs"]: + kwargs[name] = input + elif key == task["var_args"]: + var_args = input + elif key == task["var_kwargs"]: + var_kwargs = input return args, kwargs, var_args, var_kwargs, args_dict def update_task_state(self, name: str, success=True) -> None: @@ -728,7 +711,9 @@ def update_normal_task_state(self, name, results, success=True): """Set the results of a normal task. A normal task is created by decorating a function with @task(). """ - from aiida_workgraph.utils import get_sorted_names + from aiida_workgraph.config import builtin_outputs + + builtin_output_names = [output["name"] for output in builtin_outputs] if success: task = self.ctx._tasks[name] @@ -737,7 +722,9 @@ def update_normal_task_state(self, name, results, success=True): if len(task["outputs"]) - 2 != len(results): self.on_task_failed(name) return self.process.exit_codes.OUTPUS_NOT_MATCH_RESULTS - output_names = get_sorted_names(task["outputs"])[0:-2] + output_names = [ + name for name in task["outputs"] if name not in builtin_output_names + ] for i, output_name in enumerate(output_names): task["results"][output_name] = results[i] elif isinstance(results, dict): diff --git a/src/aiida_workgraph/engine/utils.py b/src/aiida_workgraph/engine/utils.py index ebfdcba3..7dadfdaa 100644 --- a/src/aiida_workgraph/engine/utils.py +++ b/src/aiida_workgraph/engine/utils.py @@ -13,9 +13,11 @@ def prepare_for_workgraph_task(task: dict, kwargs: dict) -> tuple: # because kwargs is updated using update_nested_dict_with_special_keys # which means the data is grouped by the task name for socket_name, value in data.items(): - wgdata["tasks"][task_name]["inputs"][socket_name]["property"][ - "value" - ] = value + input = wgdata["tasks"][task_name]["inputs"][socket_name] + if input["identifier"] == "workgraph.namespace": + input["value"] = value + else: + input["property"]["value"] = value # merge the properties # organize_nested_inputs(wgdata) # serialize_workgraph_inputs(wgdata) diff --git a/src/aiida_workgraph/engine/workgraph.py b/src/aiida_workgraph/engine/workgraph.py index 582cad94..4590ecfb 100644 --- a/src/aiida_workgraph/engine/workgraph.py +++ b/src/aiida_workgraph/engine/workgraph.py @@ -314,11 +314,10 @@ def read_wgdata_from_base(self) -> t.Dict[str, t.Any]: for name, task in wgdata["tasks"].items(): wgdata["tasks"][name] = deserialize_unsafe(task) for _, input in wgdata["tasks"][name]["inputs"].items(): - if input["property"] is None: - continue - prop = input["property"] - if isinstance(prop["value"], PickledLocalFunction): - prop["value"] = prop["value"].value + if input.get("property"): + prop = input["property"] + if isinstance(prop["value"], PickledLocalFunction): + prop["value"] = prop["value"].value wgdata["error_handlers"] = deserialize_unsafe(wgdata["error_handlers"]) wgdata["context"] = deserialize_unsafe(wgdata["context"]) return wgdata diff --git a/src/aiida_workgraph/orm/mapping.py b/src/aiida_workgraph/orm/mapping.py new file mode 100644 index 00000000..c6396447 --- /dev/null +++ b/src/aiida_workgraph/orm/mapping.py @@ -0,0 +1,18 @@ +from aiida import orm + + +type_mapping = { + "default": "workgraph.any", + "namespace": "workgraph.namespace", + int: "workgraph.int", + float: "workgraph.float", + str: "workgraph.string", + bool: "workgraph.bool", + orm.Int: "workgraph.aiida_int", + orm.Float: "workgraph.aiida_float", + orm.Str: "workgraph.aiida_string", + orm.Bool: "workgraph.aiida_bool", + orm.List: "workgraph.aiida_list", + orm.Dict: "workgraph.aiida_dict", + orm.StructureData: "workgraph.aiida_structuredata", +} diff --git a/src/aiida_workgraph/socket.py b/src/aiida_workgraph/socket.py index ff9f677a..a2e270f4 100644 --- a/src/aiida_workgraph/socket.py +++ b/src/aiida_workgraph/socket.py @@ -1,9 +1,13 @@ from typing import Any, Type from aiida import orm -from node_graph.socket import NodeSocket +from node_graph.socket import ( + NodeSocket, + NodeSocketNamespace, +) from aiida_workgraph.property import TaskProperty +from aiida_workgraph.orm.mapping import type_mapping class TaskSocket(NodeSocket): @@ -11,7 +15,7 @@ class TaskSocket(NodeSocket): # use TaskProperty from aiida_workgraph.property # to override the default NodeProperty from node_graph - node_property = TaskProperty + _socket_property_class = TaskProperty @property def node_value(self): @@ -30,6 +34,17 @@ def get_node_value(self): return self.value +class TaskSocketNamespace(NodeSocketNamespace): + """Represent a namespace of a Task in the AiiDA WorkGraph.""" + + _identifier = "workgraph.namespace" + _socket_property_class = TaskProperty + _type_mapping: dict = type_mapping + + def __init__(self, *args, **kwargs): + super().__init__(*args, entry_point="aiida_workgraph.socket", **kwargs) + + def build_socket_from_AiiDA(DataClass: Type[Any]) -> Type[TaskSocket]: """Create a socket class from AiiDA DataClass.""" diff --git a/src/aiida_workgraph/sockets/builtins.py b/src/aiida_workgraph/sockets/builtins.py index 3d31e5a2..57d05a29 100644 --- a/src/aiida_workgraph/sockets/builtins.py +++ b/src/aiida_workgraph/sockets/builtins.py @@ -4,103 +4,96 @@ class SocketAny(TaskSocket): """Any socket.""" - identifier: str = "workgraph.any" - property_identifier: str = "workgraph.any" - - -class SocketNamespace(TaskSocket): - """Namespace socket.""" - - identifier: str = "workgraph.namespace" - property_identifier: str = "workgraph.any" + _identifier: str = "workgraph.any" + _socket_property_identifier: str = "workgraph.any" class SocketFloat(TaskSocket): """Float socket.""" - identifier: str = "workgraph.float" - property_identifier: str = "workgraph.float" + _identifier: str = "workgraph.float" + _socket_property_identifier: str = "workgraph.float" class SocketInt(TaskSocket): """Int socket.""" - identifier: str = "workgraph.int" - property_identifier: str = "workgraph.int" + _identifier: str = "workgraph.int" + _socket_property_identifier: str = "workgraph.int" class SocketString(TaskSocket): """String socket.""" - identifier: str = "workgraph.string" - property_identifier: str = "workgraph.string" + _identifier: str = "workgraph.string" + _socket_property_identifier: str = "workgraph.string" class SocketBool(TaskSocket): """Bool socket.""" - identifier: str = "workgraph.bool" - property_identifier: str = "workgraph.bool" + _identifier: str = "workgraph.bool" + _socket_property_identifier: str = "workgraph.bool" class SocketAiiDAFloat(TaskSocket): """AiiDAFloat socket.""" - identifier: str = "workgraph.aiida_float" - property_identifier: str = "workgraph.aiida_float" + _identifier: str = "workgraph.aiida_float" + _socket_property_identifier: str = "workgraph.aiida_float" class SocketAiiDAInt(TaskSocket): """AiiDAInt socket.""" - identifier: str = "workgraph.aiida_int" - property_identifier: str = "workgraph.aiida_int" + _identifier: str = "workgraph.aiida_int" + _socket_property_identifier: str = "workgraph.aiida_int" class SocketAiiDAString(TaskSocket): """AiiDAString socket.""" - identifier: str = "workgraph.aiida_string" - property_identifier: str = "workgraph.aiida_string" + _identifier: str = "workgraph.aiida_string" + _socket_property_identifier: str = "workgraph.aiida_string" class SocketAiiDABool(TaskSocket): """AiiDABool socket.""" - identifier: str = "workgraph.aiida_bool" - property_identifier: str = "workgraph.aiida_bool" + _identifier: str = "workgraph.aiida_bool" + _socket_property_identifier: str = "workgraph.aiida_bool" class SocketAiiDAList(TaskSocket): """AiiDAList socket.""" - identifier: str = "workgraph.aiida_list" - property_identifier: str = "workgraph.aiida_list" + _identifier: str = "workgraph.aiida_list" + _socket_property_identifier: str = "workgraph.aiida_list" class SocketAiiDADict(TaskSocket): """AiiDADict socket.""" - identifier: str = "workgraph.aiida_dict" - property_identifier: str = "workgraph.aiida_dict" + _identifier: str = "workgraph.aiida_dict" + _socket_property_identifier: str = "workgraph.aiida_dict" class SocketAiiDAIntVector(TaskSocket): """Socket with a AiiDAIntVector property.""" - identifier: str = "workgraph.aiida_int_vector" - property_identifier: str = "workgraph.aiida_int_vector" + _identifier: str = "workgraph.aiida_int_vector" + _socket_property_identifier: str = "workgraph.aiida_int_vector" class SocketAiiDAFloatVector(TaskSocket): """Socket with a FloatVector property.""" - identifier: str = "workgraph.aiida_float_vector" - property_identifier: str = "workgraph.aiida_float_vector" + _identifier: str = "workgraph.aiida_float_vector" + _socket_property_identifier: str = "workgraph.aiida_float_vector" class SocketStructureData(TaskSocket): """Any socket.""" - identifier: str = "workgraph.aiida_structuredata" - property_identifier: str = "workgraph.aiida_structuredata" + _identifier: str = "workgraph.aiida_structuredata" + _socket_property_identifier: str = "workgraph.aiida_structuredata" diff --git a/src/aiida_workgraph/task.py b/src/aiida_workgraph/task.py index 882be22c..58fb2a5a 100644 --- a/src/aiida_workgraph/task.py +++ b/src/aiida_workgraph/task.py @@ -4,11 +4,10 @@ from aiida_workgraph.properties import property_pool from aiida_workgraph.sockets import socket_pool +from aiida_workgraph.socket import NodeSocketNamespace from node_graph_widget import NodeGraphWidget from aiida_workgraph.collection import ( WorkGraphPropertyCollection, - WorkGraphInputSocketCollection, - WorkGraphOutputSocketCollection, ) import aiida from typing import Any, Dict, Optional, Union, Callable, List, Set, Iterable @@ -38,8 +37,8 @@ def __init__( """ super().__init__( property_collection_class=WorkGraphPropertyCollection, - input_collection_class=WorkGraphInputSocketCollection, - output_collection_class=WorkGraphOutputSocketCollection, + input_collection_class=NodeSocketNamespace, + output_collection_class=NodeSocketNamespace, **kwargs, ) self.context_mapping = {} if context_mapping is None else context_mapping @@ -78,54 +77,12 @@ def set_context(self, context: Dict[str, Any]) -> None: key is the context key, value is the output key. """ # all values should belong to the outputs.keys() - remain_keys = set(context.values()).difference(self.outputs.keys()) + remain_keys = set(context.values()).difference(self.get_output_names()) if remain_keys: msg = f"Keys {remain_keys} are not in the outputs of this task." raise ValueError(msg) self.context_mapping.update(context) - def set(self, data: Dict[str, Any]) -> None: - from node_graph.socket import NodeSocket - - super().set(data) - - def process_nested_inputs( - base_key: str, value: Any, dynamic: bool = False - ) -> None: - """Recursive function to process nested inputs. - Creates sockets and links dynamically for nested values. - """ - if isinstance(value, dict): - keys = list(value.keys()) - for sub_key in keys: - sub_value = value[sub_key] - # Form the full key for the current nested level - full_key = f"{base_key}.{sub_key}" if base_key else sub_key - - # Create a new input socket if it does not exist - if full_key not in self.inputs.keys() and dynamic: - self.inputs.new( - "workgraph.any", - name=full_key, - metadata={"required": True}, - ) - if isinstance(sub_value, NodeSocket): - self.parent.links.new(sub_value, self.inputs[full_key]) - value.pop(sub_key) - else: - # Recursively process nested dictionaries - process_nested_inputs(full_key, sub_value, dynamic) - - # create input sockets and links for items inside a dynamic socket - # TODO the input value could be nested, but we only support one level for now - for key in data: - if self.inputs[key].identifier == "workgraph.namespace": - process_nested_inputs( - key, - self.inputs[key].value, - dynamic=self.inputs[key].metadata.get("dynamic", False), - ) - def set_from_builder(self, builder: Any) -> None: """Set the task inputs from a AiiDA ProcessBuilder.""" from aiida_workgraph.utils import get_dict_from_builder @@ -231,7 +188,7 @@ def to_widget_value(self): for key in ("properties", "executor", "node_class", "process"): tdata.pop(key, None) for input in tdata["inputs"].values(): - input.pop("property") + input.pop("property", None) tdata["label"] = tdata["identifier"] @@ -290,9 +247,9 @@ def _normalize_tasks( task_objects = [] for task in tasks: if isinstance(task, str): - if task not in self.graph.tasks.keys(): + if task not in self.graph.tasks: raise ValueError( - f"Task '{task}' is not in the graph. Available tasks: {self.graph.tasks.keys()}" + f"Task '{task}' is not in the graph. Available tasks: {self.graph.tasks}" ) task_objects.append(self.graph.tasks[task]) elif isinstance(task, Task): diff --git a/src/aiida_workgraph/tasks/builtins.py b/src/aiida_workgraph/tasks/builtins.py index 892f1859..03d83aca 100644 --- a/src/aiida_workgraph/tasks/builtins.py +++ b/src/aiida_workgraph/tasks/builtins.py @@ -17,10 +17,12 @@ def __init__(self, *args, **kwargs): self.children = TaskCollection(parent=self) def create_sockets(self) -> None: - self.inputs.clear() - self.outputs.clear() - self.inputs.new("workgraph.any", "_wait", arg_type="none", link_limit=100000) - self.outputs.new("workgraph.any", "_wait") + self.inputs._clear() + self.outputs._clear() + self.add_input( + "workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"} + ) + self.add_output("workgraph.any", "_wait") def to_dict(self, short: bool = False) -> Dict[str, Any]: tdata = super().to_dict(short=short) @@ -41,14 +43,16 @@ class While(Zone): catalog = "Control" def create_sockets(self) -> None: - self.inputs.clear() - self.outputs.clear() - self.inputs.new("workgraph.any", "_wait", arg_type="none", link_limit=100000) - self.inputs.new( + self.inputs._clear() + self.outputs._clear() + self.add_input( + "workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"} + ) + self.add_input( "node_graph.int", "max_iterations", property_data={"default": 10000} ) - self.inputs.new("workgraph.any", "conditions", link_limit=100000) - self.outputs.new("workgraph.any", "_wait") + self.add_input("workgraph.any", "conditions", link_limit=100000) + self.add_output("workgraph.any", "_wait") class If(Zone): @@ -60,12 +64,14 @@ class If(Zone): catalog = "Control" def create_sockets(self) -> None: - self.inputs.clear() - self.outputs.clear() - self.inputs.new("workgraph.any", "_wait", arg_type="none", link_limit=100000) - self.inputs.new("workgraph.any", "conditions") - self.inputs.new("workgraph.any", "invert_condition") - self.outputs.new("workgraph.any", "_wait") + self.inputs._clear() + self.outputs._clear() + self.add_input( + "workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"} + ) + self.add_input("workgraph.any", "conditions") + self.add_input("workgraph.any", "invert_condition") + self.add_output("workgraph.any", "_wait") class SetContext(Task): @@ -77,12 +83,14 @@ class SetContext(Task): catalog = "Control" def create_sockets(self) -> None: - self.inputs.clear() - self.outputs.clear() - self.inputs.new("workgraph.any", "key") - self.inputs.new("workgraph.any", "value") - self.inputs.new("workgraph.any", "_wait", arg_type="none", link_limit=100000) - self.outputs.new("workgraph.any", "_wait") + self.inputs._clear() + self.outputs._clear() + self.add_input("workgraph.any", "key") + self.add_input("workgraph.any", "value") + self.add_input( + "workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"} + ) + self.add_output("workgraph.any", "_wait") class GetContext(Task): @@ -94,12 +102,14 @@ class GetContext(Task): catalog = "Control" def create_sockets(self) -> None: - self.inputs.clear() - self.outputs.clear() - self.inputs.new("workgraph.any", "key") - self.inputs.new("workgraph.any", "_wait", arg_type="none", link_limit=100000) - self.outputs.new("workgraph.any", "result") - self.outputs.new("workgraph.any", "_wait") + self.inputs._clear() + self.outputs._clear() + self.add_input("workgraph.any", "key") + self.add_input( + "workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"} + ) + self.add_output("workgraph.any", "result") + self.add_output("workgraph.any", "_wait") class AiiDAInt(Task): @@ -114,10 +124,12 @@ class AiiDAInt(Task): } def create_sockets(self) -> None: - self.inputs.new("workgraph.any", "value", property_data={"default": 0.0}) - self.inputs.new("workgraph.any", "_wait", arg_type="none", link_limit=100000) - self.outputs.new("workgraph.aiida_int", "result") - self.outputs.new("workgraph.any", "_wait") + self.add_input("workgraph.any", "value", property_data={"default": 0.0}) + self.add_input( + "workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"} + ) + self.add_output("workgraph.aiida_int", "result") + self.add_output("workgraph.any", "_wait") class AiiDAFloat(Task): @@ -132,12 +144,14 @@ class AiiDAFloat(Task): } def create_sockets(self) -> None: - self.inputs.clear() - self.outputs.clear() - self.inputs.new("workgraph.float", "value", property_data={"default": 0.0}) - self.inputs.new("workgraph.any", "_wait", arg_type="none", link_limit=100000) - self.outputs.new("workgraph.aiida_float", "result") - self.outputs.new("workgraph.any", "_wait") + self.inputs._clear() + self.outputs._clear() + self.add_input("workgraph.float", "value", property_data={"default": 0.0}) + self.add_input( + "workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"} + ) + self.add_output("workgraph.aiida_float", "result") + self.add_output("workgraph.any", "_wait") class AiiDAString(Task): @@ -152,12 +166,14 @@ class AiiDAString(Task): } def create_sockets(self) -> None: - self.inputs.clear() - self.outputs.clear() - self.inputs.new("workgraph.string", "value", property_data={"default": ""}) - self.inputs.new("workgraph.any", "_wait", arg_type="none", link_limit=100000) - self.outputs.new("workgraph.aiida_string", "result") - self.outputs.new("workgraph.any", "_wait") + self.inputs._clear() + self.outputs._clear() + self.add_input("workgraph.string", "value", property_data={"default": ""}) + self.add_input( + "workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"} + ) + self.add_output("workgraph.aiida_string", "result") + self.add_output("workgraph.any", "_wait") class AiiDAList(Task): @@ -172,12 +188,14 @@ class AiiDAList(Task): } def create_sockets(self) -> None: - self.inputs.clear() - self.outputs.clear() - self.inputs.new("workgraph.any", "value", property_data={"default": []}) - self.inputs.new("workgraph.any", "_wait", arg_type="none", link_limit=100000) - self.outputs.new("workgraph.aiida_list", "result") - self.outputs.new("workgraph.any", "_wait") + self.inputs._clear() + self.outputs._clear() + self.add_input("workgraph.any", "value", property_data={"default": []}) + self.add_input( + "workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"} + ) + self.add_output("workgraph.aiida_list", "result") + self.add_output("workgraph.any", "_wait") class AiiDADict(Task): @@ -192,12 +210,14 @@ class AiiDADict(Task): } def create_sockets(self) -> None: - self.inputs.clear() - self.outputs.clear() - self.inputs.new("workgraph.any", "value", property_data={"default": {}}) - self.inputs.new("workgraph.any", "_wait", arg_type="none", link_limit=100000) - self.outputs.new("workgraph.aiida_dict", "result") - self.outputs.new("workgraph.any", "_wait") + self.inputs._clear() + self.outputs._clear() + self.add_input("workgraph.any", "value", property_data={"default": {}}) + self.add_input( + "workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"} + ) + self.add_output("workgraph.aiida_dict", "result") + self.add_output("workgraph.any", "_wait") class AiiDANode(Task): @@ -217,15 +237,17 @@ def create_properties(self) -> None: pass def create_sockets(self) -> None: - self.inputs.clear() - self.outputs.clear() - self.inputs.new("workgraph.any", "identifier") - self.inputs.new("workgraph.any", "pk") - self.inputs.new("workgraph.any", "uuid") - self.inputs.new("workgraph.any", "label") - self.inputs.new("workgraph.any", "_wait", arg_type="none", link_limit=100000) - self.outputs.new("workgraph.any", "node") - self.outputs.new("workgraph.any", "_wait") + self.inputs._clear() + self.outputs._clear() + self.add_input("workgraph.any", "identifier") + self.add_input("workgraph.any", "pk") + self.add_input("workgraph.any", "uuid") + self.add_input("workgraph.any", "label") + self.add_input( + "workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"} + ) + self.add_output("workgraph.any", "node") + self.add_output("workgraph.any", "_wait") class AiiDACode(Task): @@ -242,15 +264,17 @@ class AiiDACode(Task): } def create_sockets(self) -> None: - self.inputs.clear() - self.outputs.clear() - self.inputs.new("workgraph.any", "identifier") - self.inputs.new("workgraph.any", "pk") - self.inputs.new("workgraph.any", "uuid") - self.inputs.new("workgraph.any", "label") - self.inputs.new("workgraph.any", "_wait", arg_type="none", link_limit=100000) - self.outputs.new("workgraph.any", "Code") - self.outputs.new("workgraph.any", "_wait") + self.inputs._clear() + self.outputs._clear() + self.add_input("workgraph.any", "identifier") + self.add_input("workgraph.any", "pk") + self.add_input("workgraph.any", "uuid") + self.add_input("workgraph.any", "label") + self.add_input( + "workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"} + ) + self.add_output("workgraph.any", "Code") + self.add_output("workgraph.any", "_wait") class Select(Task): @@ -267,11 +291,13 @@ class Select(Task): } def create_sockets(self) -> None: - self.inputs.clear() - self.outputs.clear() - self.inputs.new("workgraph.any", "condition") - self.inputs.new("workgraph.any", "true") - self.inputs.new("workgraph.any", "false") - self.inputs.new("workgraph.any", "_wait", arg_type="none", link_limit=100000) - self.outputs.new("workgraph.any", "result") - self.outputs.new("workgraph.any", "_wait") + self.inputs._clear() + self.outputs._clear() + self.add_input("workgraph.any", "condition") + self.add_input("workgraph.any", "true") + self.add_input("workgraph.any", "false") + self.add_input( + "workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"} + ) + self.add_output("workgraph.any", "result") + self.add_output("workgraph.any", "_wait") diff --git a/src/aiida_workgraph/tasks/monitors.py b/src/aiida_workgraph/tasks/monitors.py index 44d4f1be..fe54a1ef 100644 --- a/src/aiida_workgraph/tasks/monitors.py +++ b/src/aiida_workgraph/tasks/monitors.py @@ -15,17 +15,19 @@ class TimeMonitor(Task): } def create_sockets(self) -> None: - self.inputs.clear() - self.outputs.clear() - self.inputs.new("workgraph.any", "time") - inp = self.inputs.new("workgraph.any", "interval") + self.inputs._clear() + self.outputs._clear() + self.add_input("workgraph.any", "time") + inp = self.add_input("workgraph.any", "interval") inp.add_property("workgraph.any", default=1.0) - inp = self.inputs.new("workgraph.any", "timeout") + inp = self.add_input("workgraph.any", "timeout") inp.add_property("workgraph.any", default=86400.0) - self.inputs.new("workgraph.any", "_wait", arg_type="none", link_limit=100000) - inp.link_limit = 100000 - self.outputs.new("workgraph.any", "result") - self.outputs.new("workgraph.any", "_wait") + self.add_input( + "workgraph.any", "_wait", metadata={"arg_type": "none"}, link_limit=100000 + ) + inp._link_limit = 100000 + self.add_output("workgraph.any", "result") + self.add_output("workgraph.any", "_wait") class FileMonitor(Task): @@ -42,17 +44,19 @@ class FileMonitor(Task): } def create_sockets(self) -> None: - self.inputs.clear() - self.outputs.clear() - self.inputs.new("workgraph.any", "filepath") - inp = self.inputs.new("workgraph.any", "interval") + self.inputs._clear() + self.outputs._clear() + self.add_input("workgraph.any", "filepath") + inp = self.add_input("workgraph.any", "interval") inp.add_property("workgraph.any", default=1.0) - inp = self.inputs.new("workgraph.any", "timeout") + inp = self.add_input("workgraph.any", "timeout") inp.add_property("workgraph.any", default=86400.0) - self.inputs.new("workgraph.any", "_wait", arg_type="none", link_limit=100000) - inp.link_limit = 100000 - self.outputs.new("workgraph.any", "result") - self.outputs.new("workgraph.any", "_wait") + self.add_input( + "workgraph.any", "_wait", metadata={"arg_type": "none"}, link_limit=100000 + ) + inp._link_limit = 100000 + self.add_output("workgraph.any", "result") + self.add_output("workgraph.any", "_wait") class TaskMonitor(Task): @@ -69,16 +73,18 @@ class TaskMonitor(Task): } def create_sockets(self) -> None: - self.inputs.clear() - self.outputs.clear() - self.inputs.new("workgraph.any", "workgraph_pk") - self.inputs.new("workgraph.any", "workgraph_name") - self.inputs.new("workgraph.any", "task_name") - inp = self.inputs.new("workgraph.any", "interval") + self.inputs._clear() + self.outputs._clear() + self.add_input("workgraph.any", "workgraph_pk") + self.add_input("workgraph.any", "workgraph_name") + self.add_input("workgraph.any", "task_name") + inp = self.add_input("workgraph.any", "interval") inp.add_property("workgraph.any", default=1.0) - inp = self.inputs.new("workgraph.any", "timeout") + inp = self.add_input("workgraph.any", "timeout") inp.add_property("workgraph.any", default=86400.0) - self.inputs.new("workgraph.any", "_wait", arg_type="none", link_limit=100000) - inp.link_limit = 100000 - self.outputs.new("workgraph.any", "result") - self.outputs.new("workgraph.any", "_wait") + self.add_input( + "workgraph.any", "_wait", metadata={"arg_type": "none"}, link_limit=100000 + ) + inp._link_limit = 100000 + self.add_output("workgraph.any", "result") + self.add_output("workgraph.any", "_wait") diff --git a/src/aiida_workgraph/tasks/pythonjob.py b/src/aiida_workgraph/tasks/pythonjob.py index b86a5106..ec92fc51 100644 --- a/src/aiida_workgraph/tasks/pythonjob.py +++ b/src/aiida_workgraph/tasks/pythonjob.py @@ -11,19 +11,28 @@ class PythonJob(Task): def update_from_dict(self, data: Dict[str, Any], **kwargs) -> "PythonJob": """Overwrite the update_from_dict method to handle the PythonJob data.""" - self.deserialize_pythonjob_data(data) + self.deserialize_pythonjob_data(data["inputs"]) super().update_from_dict(data) @classmethod - def serialize_pythonjob_data(cls, tdata: Dict[str, Any]): + def serialize_pythonjob_data( + cls, input_data: Dict[str, Any], is_function_input: bool = False + ) -> None: """Serialize the properties for PythonJob.""" - for input in tdata["inputs"].values(): - if input["metadata"].get("is_function_input", False): - input["property"]["value"] = cls.serialize_socket_data(input) + for input in input_data.values(): + if is_function_input or input["metadata"].get("is_function_input", False): + if input["identifier"] == "workgraph.namespace": + cls.serialize_pythonjob_data( + input["sockets"], is_function_input=True + ) + elif input.get("property", {}).get("value") is not None: + cls.serialize_socket_data(input) @classmethod - def deserialize_pythonjob_data(cls, tdata: Dict[str, Any]) -> None: + def deserialize_pythonjob_data( + cls, input_data: Dict[str, Any], is_function_input: bool = False + ) -> None: """ Process the task data dictionary for a PythonJob. It load the orignal Python data from the AiiDA Data node for the @@ -36,41 +45,26 @@ def deserialize_pythonjob_data(cls, tdata: Dict[str, Any]) -> None: Dict[str, Any]: The processed data dictionary. """ - for input in tdata["inputs"].values(): - if input["metadata"].get("is_function_input", False): - input["property"]["value"] = cls.deserialize_socket_data(input) + for input in input_data.values(): + if is_function_input or input["metadata"].get("is_function_input", False): + if input["identifier"] == "workgraph.namespace": + print("deserialize namespace: ", input["name"]) + cls.deserialize_pythonjob_data( + input["sockets"], is_function_input=True + ) + else: + print("deserialize socket: ", input["name"]) + cls.deserialize_socket_data(input) @classmethod def serialize_socket_data(cls, data: Dict[str, Any]) -> Any: - if data.get("identifier", "Any").upper() == "WORKGRAPH.NAMESPACE": - if data["property"]["value"] is None: - return None - if isinstance(data["property"]["value"], dict): - serialized_result = {} - for key, value in data["property"]["value"].items(): - serialized_result[key] = general_serializer(value) - return serialized_result - else: - raise ValueError("Namespace socket should be a dictionary.") - else: - if isinstance(data["property"]["value"], orm.Data): - return data["property"]["value"] - return general_serializer(data["property"]["value"]) + value = data.get("property", {}).get("value") + if value is None or isinstance(value, orm.Data): + return + data["property"]["value"] = general_serializer(value) @classmethod def deserialize_socket_data(cls, data: Dict[str, Any]) -> Any: - if data.get("identifier", "Any").upper() == "WORKGRAPH.NAMESPACE": - if isinstance(data["property"]["value"], dict): - deserialized_result = {} - for key, value in data["property"]["value"].items(): - if isinstance(value, orm.Data): - deserialized_result[key] = value.value - else: - deserialized_result[key] = value - return deserialized_result - else: - raise ValueError("Namespace socket should be a dictionary.") - else: - if isinstance(data["property"]["value"], orm.Data): - return data["property"]["value"].value - return data["property"]["value"] + value = data.get("property", {}).get("value") + if isinstance(value, orm.Data): + data["property"]["value"] = value.value diff --git a/src/aiida_workgraph/tasks/test.py b/src/aiida_workgraph/tasks/test.py index e4d0bd77..77adeff9 100644 --- a/src/aiida_workgraph/tasks/test.py +++ b/src/aiida_workgraph/tasks/test.py @@ -14,19 +14,21 @@ class TestAdd(Task): } def create_properties(self) -> None: - self.properties.new("workgraph.aiida_float", "t", default=1.0) + self.add_property("workgraph.aiida_float", "t", default=1.0) def create_sockets(self) -> None: - self.inputs.clear() - self.outputs.clear() - inp = self.inputs.new("workgraph.aiida_float", "x") + self.inputs._clear() + self.outputs._clear() + inp = self.add_input("workgraph.aiida_float", "x") inp.add_property("workgraph.aiida_float", "x", default=0.0) - inp = self.inputs.new("workgraph.aiida_float", "y") + inp = self.add_input("workgraph.aiida_float", "y") inp.add_property("workgraph.aiida_float", "y", default=0.0) - self.inputs.new("workgraph.any", "_wait", arg_type="none", link_limit=100000) - self.outputs.new("workgraph.aiida_float", "sum") - self.outputs.new("workgraph.any", "_wait") - self.outputs.new("workgraph.any", "_outputs") + self.add_input( + "workgraph.any", "_wait", metadata={"arg_type": "none"}, link_limit=100000 + ) + self.add_output("workgraph.aiida_float", "sum") + self.add_output("workgraph.any", "_wait") + self.add_output("workgraph.any", "_outputs") class TestSumDiff(Task): @@ -42,20 +44,22 @@ class TestSumDiff(Task): } def create_properties(self) -> None: - self.properties.new("workgraph.aiida_float", "t", default=1.0) + self.properties._new("workgraph.aiida_float", "t", default=1.0) def create_sockets(self) -> None: - self.inputs.clear() - self.outputs.clear() - inp = self.inputs.new("workgraph.aiida_float", "x") + self.inputs._clear() + self.outputs._clear() + inp = self.add_input("workgraph.aiida_float", "x") inp.add_property("workgraph.aiida_float", "x", default=0.0) - inp = self.inputs.new("workgraph.aiida_float", "y") + inp = self.add_input("workgraph.aiida_float", "y") inp.add_property("workgraph.aiida_float", "y", default=0.0) - self.inputs.new("workgraph.any", "_wait", arg_type="none", link_limit=100000) - self.outputs.new("workgraph.aiida_float", "sum") - self.outputs.new("workgraph.aiida_float", "diff") - self.outputs.new("workgraph.any", "_wait") - self.outputs.new("workgraph.any", "_outputs") + self.add_input( + "workgraph.any", "_wait", metadata={"arg_type": "none"}, link_limit=100000 + ) + self.add_output("workgraph.aiida_float", "sum") + self.add_output("workgraph.aiida_float", "diff") + self.add_output("workgraph.any", "_wait") + self.add_output("workgraph.any", "_outputs") class TestArithmeticMultiplyAdd(Task): @@ -74,16 +78,18 @@ def create_properties(self) -> None: pass def create_sockets(self) -> None: - self.inputs.clear() - self.outputs.clear() - self.inputs.new("workgraph.any", "code") - inp = self.inputs.new("workgraph.aiida_int", "x") + self.inputs._clear() + self.outputs._clear() + self.add_input("workgraph.any", "code") + inp = self.add_input("workgraph.aiida_int", "x") inp.add_property("workgraph.aiida_int", "x", default=0.0) - inp = self.inputs.new("workgraph.aiida_int", "y") + inp = self.add_input("workgraph.aiida_int", "y") inp.add_property("workgraph.aiida_int", "y", default=0.0) - inp = self.inputs.new("workgraph.aiida_int", "z") + inp = self.add_input("workgraph.aiida_int", "z") inp.add_property("workgraph.aiida_int", "z", default=0.0) - self.inputs.new("workgraph.any", "_wait", arg_type="none", link_limit=100000) - self.outputs.new("workgraph.aiida_int", "result") - self.outputs.new("workgraph.any", "_wait") - self.outputs.new("workgraph.any", "_outputs") + self.add_input( + "workgraph.any", "_wait", metadata={"arg_type": "none"}, link_limit=100000 + ) + self.add_output("workgraph.aiida_int", "result") + self.add_output("workgraph.any", "_wait") + self.add_output("workgraph.any", "_outputs") diff --git a/src/aiida_workgraph/utils/__init__.py b/src/aiida_workgraph/utils/__init__.py index 70c94a72..494a9b21 100644 --- a/src/aiida_workgraph/utils/__init__.py +++ b/src/aiida_workgraph/utils/__init__.py @@ -42,18 +42,6 @@ def build_callable(obj: Callable) -> Dict[str, Any]: return executor -def get_sorted_names(data: dict) -> list[str]: - """Get the sorted names from a dictionary.""" - sorted_names = [ - name - for name, _ in sorted( - ((name, item["list_index"]) for name, item in data.items()), - key=lambda x: x[1], - ) - ] - return sorted_names - - def store_nodes_recursely(data: Any) -> None: """Recurse through a data structure and store any unstored nodes that are found along the way :param data: a data structure potentially containing unstored nodes @@ -223,48 +211,6 @@ def update_nested_dict_with_special_keys(data: Dict[str, Any]) -> Dict[str, Any] return data -def organize_nested_inputs(wgdata: Dict[str, Any]) -> None: - """Merge sub properties to the root properties. - The sub properties will be se - For example: - task["inputs"]["base"]["property"]["value"] = None - task["inputs"]["base.pw.parameters"]["property"]["value"] = 2 - task["inputs"]["base.pw.code"]["property"]["value"] = 1 - task["inputs"]["metadata"]["property"]["value"] = {"options": {"resources": {"num_cpus": 1}} - task["inputs"]["metadata.options"]["property"]["value"] = {"resources": {"num_machine": 1}} - After organizing: - task["inputs"]["base"]["property"]["value"] = {"base": {"pw": {"parameters": 2, - "code": 1}, - "metadata": {"options": - {"resources": {"num_cpus": 1, - "num_machine": 1}}}}, - } - task["inputs"]["base.pw.parameters"]["property"]["value"] = None - task["inputs"]["base.pw.code"]["property"]["value"] = None - task["inputs"]["metadata"]["property"]["value"] = None - task["inputs"]["metadata.options"]["property"]["value"] = None - """ - for _, task in wgdata["tasks"].items(): - for key, prop in task["properties"].items(): - if "." in key and prop["value"] not in [None, {}]: - root, key = key.split(".", 1) - root_prop = task["properties"][root] - update_nested_dict(root_prop["value"], key, prop["value"]) - prop["value"] = None - for key, input in task["inputs"].items(): - if input["property"] is None: - continue - prop = input["property"] - if "." in key and prop["value"] not in [None, {}]: - root, key = key.split(".", 1) - root_prop = task["inputs"][root]["property"] - # update the root property - root_prop["value"] = update_nested_dict( - root_prop["value"], key, prop["value"] - ) - prop["value"] = None - - def generate_node_graph( pk: int, output: str = None, width: str = "100%", height: str = "600px" ) -> Any: @@ -464,13 +410,12 @@ def serialize_workgraph_inputs(wgdata): if not data["handler"]["use_module_path"]: pickle_callable(data["handler"]) if task["metadata"]["node_type"].upper() == "PYTHONJOB": - PythonJob.serialize_pythonjob_data(task) + PythonJob.serialize_pythonjob_data(task["inputs"]) for _, input in task["inputs"].items(): - if input["property"] is None: - continue - prop = input["property"] - if inspect.isfunction(prop["value"]): - prop["value"] = PickledLocalFunction(prop["value"]).store() + if input.get("property"): + prop = input["property"] + if inspect.isfunction(prop["value"]): + prop["value"] = PickledLocalFunction(prop["value"]).store() # error_handlers of the workgraph for _, data in wgdata["error_handlers"].items(): if not data["handler"]["use_module_path"]: @@ -548,7 +493,7 @@ def process_properties(task: Dict) -> Dict: } # for name, input in task["inputs"].items(): - if input["property"] is not None: + if input.get("property"): prop = input["property"] identifier = prop["identifier"] value = prop.get("value") @@ -631,6 +576,8 @@ def validate_task_inout(inout_list: list[str | dict], list_type: str) -> list[di elif isinstance(item, dict): processed_inout_list.append(item) + processed_inout_list = processed_inout_list + return processed_inout_list diff --git a/src/aiida_workgraph/utils/analysis.py b/src/aiida_workgraph/utils/analysis.py index e2dc1af3..43a446eb 100644 --- a/src/aiida_workgraph/utils/analysis.py +++ b/src/aiida_workgraph/utils/analysis.py @@ -82,29 +82,14 @@ def build_task_link(self) -> None: 1) workgraph links """ - # reset task input links - for name, task in self.wgdata["tasks"].items(): - for _, input in task["inputs"].items(): - input["links"] = [] - for _, output in task["outputs"].items(): - output["links"] = [] + # create a `input_links` to store the input links for each task + for task in self.wgdata["tasks"].values(): + task["input_links"] = {} for link in self.wgdata["links"]: - to_socket = [ - socket - for name, socket in self.wgdata["tasks"][link["to_node"]][ - "inputs" - ].items() - if name == link["to_socket"] - ][0] - from_socket = [ - socket - for name, socket in self.wgdata["tasks"][link["from_node"]][ - "outputs" - ].items() - if name == link["from_socket"] - ][0] - to_socket["links"].append(link) - from_socket["links"].append(link) + task = self.wgdata["tasks"][link["to_node"]] + if link["to_socket"] not in task["input_links"]: + task["input_links"][link["to_socket"]] = [] + task["input_links"][link["to_socket"]].append(link) def assign_zone(self) -> None: """Assign zone for each task.""" @@ -139,8 +124,8 @@ def find_zone_inputs(self, name: str) -> None: """Find the input and outputs tasks for the zone.""" task = self.wgdata["tasks"][name] input_tasks = [] - for _, input in self.wgdata["tasks"][name]["inputs"].items(): - for link in input["links"]: + for _, links in self.wgdata["tasks"][name]["input_links"].items(): + for link in links: input_tasks.append(link["from_node"]) # find all the input tasks for child_task in task["children"]: @@ -157,8 +142,8 @@ def find_zone_inputs(self, name: str) -> None: else: # if the child task is not a zone, get the input tasks of the child task # find all the input tasks which outside the while zone - for _, input in self.wgdata["tasks"][child_task]["inputs"].items(): - for link in input["links"]: + for _, links in self.wgdata["tasks"][child_task]["input_links"].items(): + for link in links: input_tasks.append(link["from_node"]) # find the input tasks which are not in the zone new_input_tasks = [] @@ -215,11 +200,10 @@ def insert_workgraph_to_db(self) -> None: self.save_task_states() for name, task in self.wgdata["tasks"].items(): for _, input in task["inputs"].items(): - if input["property"] is None: - continue - prop = input["property"] - if inspect.isfunction(prop["value"]): - prop["value"] = PickledLocalFunction(prop["value"]).store() + if input.get("property"): + prop = input["property"] + if inspect.isfunction(prop["value"]): + prop["value"] = PickledLocalFunction(prop["value"]).store() self.wgdata["tasks"][name] = serialize(task) # nodes is a copy of tasks, so we need to pop it out self.wgdata["error_handlers"] = serialize(self.wgdata["error_handlers"]) @@ -302,7 +286,7 @@ def check_diff( # change tasks to nodes for DifferenceAnalysis wg1["nodes"] = wg1.pop("tasks") self.wgdata["nodes"] = self.wgdata.pop("tasks") - dc = DifferenceAnalysis(nt1=wg1, nt2=self.wgdata) + dc = DifferenceAnalysis(ng1=wg1, ng2=self.wgdata) ( new_tasks, modified_tasks, diff --git a/src/aiida_workgraph/utils/graph.py b/src/aiida_workgraph/utils/graph.py index 2e3b668f..969b3f72 100644 --- a/src/aiida_workgraph/utils/graph.py +++ b/src/aiida_workgraph/utils/graph.py @@ -45,9 +45,9 @@ def link_creation_hook(self, link: Any) -> None: "type": "add_link", "data": { "from_node": link.from_node.name, - "from_socket": link.from_socket.name, + "from_socket": link.from_socket._name, "to_node": link.to_node.name, - "to_socket": link.to_socket.name, + "to_socket": link.to_socket._name, }, } ) @@ -65,9 +65,9 @@ def link_deletion_hook(self, link: Any) -> None: "type": "delete_link", "data": { "from_node": link.from_node.name, - "from_socket": link.from_socket.name, + "from_socket": link.from_socket._name, "to_node": link.to_node.name, - "to_socket": link.to_socket.name, + "to_socket": link.to_socket._name, }, } ) diff --git a/src/aiida_workgraph/workgraph.py b/src/aiida_workgraph/workgraph.py index cb74b312..1181ed17 100644 --- a/src/aiida_workgraph/workgraph.py +++ b/src/aiida_workgraph/workgraph.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import aiida.orm import node_graph import aiida -import node_graph.link -from aiida_workgraph.socket import NodeSocket +from node_graph.link import NodeLink +from aiida_workgraph.socket import TaskSocket from aiida_workgraph.tasks import task_pool from aiida_workgraph.task import Task import time @@ -68,13 +70,9 @@ def tasks(self) -> TaskCollection: def prepare_inputs( self, metadata: Optional[Dict[str, Any]] = None ) -> Dict[str, Any]: - from aiida_workgraph.utils import ( - organize_nested_inputs, - serialize_workgraph_inputs, - ) + from aiida_workgraph.utils import serialize_workgraph_inputs wgdata = self.to_dict() - organize_nested_inputs(wgdata) serialize_workgraph_inputs(wgdata) metadata = metadata or {} inputs = {"wg": wgdata, "metadata": metadata} @@ -94,7 +92,7 @@ def run( # set task inputs if inputs is not None: for name, input in inputs.items(): - if name not in self.tasks.keys(): + if name not in self.tasks: raise KeyError(f"Task {name} not found in WorkGraph.") self.tasks[name].set(input) # One can not run again if the process is alreay created. otherwise, a new process node will @@ -126,7 +124,7 @@ def submit( # set task inputs if inputs is not None: for name, input in inputs.items(): - if name not in self.tasks.keys(): + if name not in self.tasks: raise KeyError(f"Task {name} not found in WorkGraph.") self.tasks[name].set(input) @@ -292,12 +290,15 @@ def update(self) -> None: # even if the node.is_finished_ok is True if node.is_finished_ok: # update the output sockets - i = 0 for socket in self.tasks[name].outputs: - socket.value = get_nested_dict( - node.outputs, socket.name, default=None - ) - i += 1 + if socket._identifier == "workgraph.namespace": + socket._value = get_nested_dict( + node.outputs, socket._name, default=None + ) + else: + socket.value = get_nested_dict( + node.outputs, socket._name, default=None + ) # read results from the process outputs elif isinstance(node, aiida.orm.Data): self.tasks[name].outputs[0].value = node @@ -471,13 +472,13 @@ def extend(self, wg: "WorkGraph", prefix: str = "") -> None: for task in wg.tasks: task.name = prefix + task.name task.parent = self - self.tasks.append(task) + self.tasks._append(task) # self.sequence.extend([prefix + task for task in wg.sequence]) # self.conditions.extend(wg.conditions) self.context.update(wg.context) # links for link in wg.links: - self.links.append(link) + self.links._append(link) @property def error_handlers(self) -> Dict[str, Any]: @@ -492,14 +493,14 @@ def add_task( self, identifier: Union[str, callable], name: str = None, **kwargs ) -> Task: """Add a task to the workgraph.""" - node = self.tasks.new(identifier, name, **kwargs) + node = self.tasks._new(identifier, name, **kwargs) return node - def add_link( - self, source: NodeSocket, target: NodeSocket - ) -> node_graph.link.NodeLink: - """Add a link between two nodes.""" - link = self.links.new(source, target) + def add_link(self, source: TaskSocket | Task, target: TaskSocket) -> NodeLink: + """Add a link between two tasks.""" + if isinstance(source, Task): + source = source.outputs["_outputs"] + link = self.links._new(source, target) return link def to_widget_value(self) -> Dict[str, Any]: diff --git a/tests/conftest.py b/tests/conftest.py index 7facdb73..e553dc25 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -88,7 +88,7 @@ def wg_calcjob(add_code) -> WorkGraph: wg = WorkGraph(name="test_debug_math") add1 = wg.add_task(ArithmeticAddCalculation, "add1", x=2, y=3, code=add_code) add2 = wg.add_task(ArithmeticAddCalculation, "add2", x=4, code=add_code) - wg.add_link(add1.outputs["sum"], add2.inputs["y"]) + wg.add_link(add1.outputs.sum, add2.inputs["y"]) return wg @@ -233,11 +233,11 @@ def wg_engine(decorated_add, add_code) -> WorkGraph: add3 = wg.add_task(decorated_add, "add3", x=2, y=3) add4 = wg.add_task(ArithmeticAddCalculation, "add4", x=2, y=4, code=code) add5 = wg.add_task(decorated_add, "add5", x=2, y=5) - wg.add_link(add0.outputs["sum"], add2.inputs["x"]) - wg.add_link(add1.outputs[0], add3.inputs["x"]) - wg.add_link(add3.outputs[0], add4.inputs["x"]) - wg.add_link(add2.outputs["sum"], add5.inputs["x"]) - wg.add_link(add4.outputs["sum"], add5.inputs["y"]) + wg.add_link(add0.outputs.sum, add2.inputs.x) + wg.add_link(add1.outputs[0], add3.inputs.x) + wg.add_link(add3.outputs[0], add4.inputs.x) + wg.add_link(add2.outputs.sum, add5.inputs.x) + wg.add_link(add4.outputs.sum, add5.inputs["y"]) return wg diff --git a/tests/test_action.py b/tests/test_action.py index 82ad8954..bca6dcdd 100644 --- a/tests/test_action.py +++ b/tests/test_action.py @@ -23,20 +23,20 @@ def test_pause_play_task(wg_calcjob): wg.submit() # wait for the workgraph to launch add1 wg.wait(tasks={"add1": ["CREATED"]}, timeout=40, interval=5) - assert wg.tasks["add1"].node.process_state.value.upper() == "CREATED" - assert wg.tasks["add1"].node.process_status == "Paused through WorkGraph" + assert wg.tasks.add1.node.process_state.value.upper() == "CREATED" + assert wg.tasks.add1.node.process_status == "Paused through WorkGraph" # pause add2 after submit wg.pause_tasks(["add2"]) wg.play_tasks(["add1"]) # wait for the workgraph to launch add2 wg.wait(tasks={"add2": ["CREATED"]}, timeout=40, interval=5) - assert wg.tasks["add2"].node.process_state.value.upper() == "CREATED" - assert wg.tasks["add2"].node.process_status == "Paused through WorkGraph" + assert wg.tasks.add2.node.process_state.value.upper() == "CREATED" + assert wg.tasks.add2.node.process_status == "Paused through WorkGraph" # I disabled the following lines because the test is not stable # Seems the daemon is not responding to the play signal wg.play_tasks(["add2"]) wg.wait(interval=5) - assert wg.tasks["add2"].outputs["sum"].value == 9 + assert wg.tasks.add2.outputs.sum.value == 9 def test_pause_play_error_handler(wg_calcjob, finished_process_node): diff --git a/tests/test_awaitable_task.py b/tests/test_awaitable_task.py index 3bb552a7..f10b5808 100644 --- a/tests/test_awaitable_task.py +++ b/tests/test_awaitable_task.py @@ -17,11 +17,11 @@ async def awaitable_func(x, y): wg = WorkGraph(name="test_awaitable_decorator") awaitable_func1 = wg.add_task(awaitable_func, "awaitable_func1", x=1, y=2) - add1 = wg.add_task(decorated_add, "add1", x=1, y=awaitable_func1.outputs["result"]) + add1 = wg.add_task(decorated_add, "add1", x=1, y=awaitable_func1.outputs.result) wg.run() report = get_workchain_report(wg.process, "REPORT") assert "Waiting for child processes: awaitable_func1" in report - assert add1.outputs["result"].value == 4 + assert add1.outputs.result.value == 4 def test_monitor_decorator(): @@ -60,7 +60,7 @@ def test_time_monitor(decorated_add): wg.run() report = get_workchain_report(wg.process, "REPORT") assert "Waiting for child processes: monitor1" in report - assert add1.outputs["result"].value == 3 + assert add1.outputs.result.value == 3 def test_file_monitor(decorated_add, tmp_path): @@ -84,7 +84,7 @@ async def create_test_file(filepath="/tmp/test_file_monitor.txt", t=2): wg.run() report = get_workchain_report(wg.process, "REPORT") assert "Waiting for child processes: monitor1" in report - assert add1.outputs["result"].value == 3 + assert add1.outputs.result.value == 3 @pytest.mark.usefixtures("started_daemon_client") @@ -105,7 +105,7 @@ def test_task_monitor(decorated_add): wg1.add_task(decorated_add, "add1", x=1, y=2, t=5) wg1.submit(wait=True) wg2.wait() - assert wg2.tasks["add1"].node.ctime > wg1.tasks["add1"].node.ctime + assert wg2.tasks.add1.node.ctime > wg1.tasks.add1.node.ctime @pytest.mark.usefixtures("started_daemon_client") diff --git a/tests/test_build_task.py b/tests/test_build_task.py index bdffebd9..2b3abfae 100644 --- a/tests/test_build_task.py +++ b/tests/test_build_task.py @@ -57,17 +57,17 @@ def add_minus(x, y): ], ) assert issubclass(AddTask, Task) - assert "sum" in AddTask().outputs.keys() + assert "sum" in AddTask().get_output_names() # use the class directly wg = WorkGraph() add1 = wg.add_task(add, name="add1") - assert "result" in add1.outputs.keys() + assert "result" in add1.get_output_names() assert add1.name == "add1" AddTask_outputs_list = build_task(add_minus, outputs=["sum", "difference"]) assert issubclass(AddTask_outputs_list, Task) - assert "sum" in AddTask_outputs_list().outputs.keys() - assert "difference" in AddTask_outputs_list().outputs.keys() + assert "sum" in AddTask_outputs_list().get_output_names() + assert "difference" in AddTask_outputs_list().get_output_names() def test_function(): diff --git a/tests/test_calcfunction.py b/tests/test_calcfunction.py index d6c86d00..add763b9 100644 --- a/tests/test_calcfunction.py +++ b/tests/test_calcfunction.py @@ -11,7 +11,7 @@ def test_run(wg_calcfunction: WorkGraph) -> None: print("state: ", wg.state) # print("results: ", results[]) assert wg.tasks["sumdiff2"].node.outputs.sum == 9 - assert wg.tasks["sumdiff2"].outputs["sum"].value == 9 + assert wg.tasks["sumdiff2"].outputs.sum.value == 9 @pytest.mark.usefixtures("started_daemon_client") @@ -27,4 +27,4 @@ def add(**kwargs): wg = WorkGraph("test_dynamic_inputs") wg.add_task(add, name="add1", x=orm.Int(1), y=orm.Int(2)) wg.run() - assert wg.tasks["add1"].outputs["result"].value == 3 + assert wg.tasks.add1.outputs.result.value == 3 diff --git a/tests/test_calcjob.py b/tests/test_calcjob.py index 74768f5f..ff2eb327 100644 --- a/tests/test_calcjob.py +++ b/tests/test_calcjob.py @@ -8,4 +8,4 @@ def test_submit(wg_calcjob: WorkGraph) -> None: wg = wg_calcjob wg.name = "test_submit_calcjob" wg.submit(wait=True) - assert wg.tasks["add2"].outputs["sum"].value == 9 + assert wg.tasks.add2.outputs.sum.value == 9 diff --git a/tests/test_ctx.py b/tests/test_ctx.py index bffbc872..5e44cae7 100644 --- a/tests/test_ctx.py +++ b/tests/test_ctx.py @@ -16,14 +16,14 @@ def test_workgraph_ctx(decorated_add: Callable) -> None: wg.context = {"x": Float(2), "data.y": Float(3), "array": array} add1 = wg.add_task(decorated_add, "add1", x="{{ x }}", y="{{ data.y }}") wg.add_task( - "workgraph.set_context", name="set_ctx1", key="x", value=add1.outputs["result"] + "workgraph.set_context", name="set_ctx1", key="x", value=add1.outputs.result ) get_ctx1 = wg.add_task("workgraph.get_context", name="get_ctx1", key="x") # test the task can wait for another task get_ctx1.waiting_on.add(add1) - add2 = wg.add_task(decorated_add, "add2", x=get_ctx1.outputs["result"], y=1) + add2 = wg.add_task(decorated_add, "add2", x=get_ctx1.outputs.result, y=1) wg.run() - assert add2.outputs["result"].value == 6 + assert add2.outputs.result.value == 6 @pytest.mark.usefixtures("started_daemon_client") @@ -38,6 +38,6 @@ def test_task_set_ctx(decorated_add: Callable) -> None: assert str(e) == "Keys {'resul'} are not in the outputs of this task." add1.set_context({"sum": "result"}) add2 = wg.add_task(decorated_add, "add2", y="{{ sum }}") - wg.add_link(add1.outputs[0], add2.inputs["x"]) + wg.add_link(add1.outputs[0], add2.inputs.x) wg.submit(wait=True) - assert add2.outputs["result"].value == 10 + assert add2.outputs.result.value == 10 diff --git a/tests/test_data_task.py b/tests/test_data_task.py index e2939d65..74304edf 100644 --- a/tests/test_data_task.py +++ b/tests/test_data_task.py @@ -16,7 +16,7 @@ def test_data_task(identifier, data) -> None: wg = WorkGraph("test_normal_task") task1 = wg.add_task(identifier, name="task1", value=data) wg.run() - assert task1.outputs["result"].value == data + assert task1.outputs.result.value == data def test_data_dict_task(): @@ -25,7 +25,7 @@ def test_data_dict_task(): wg = WorkGraph("test_data_dict_task") task1 = wg.add_task("workgraph.aiida_dict", name="task1", value={"a": 1}) wg.run() - assert task1.outputs["result"].value == {"a": 1} + assert task1.outputs.result.value == {"a": 1} def test_data_list_task(): @@ -34,4 +34,4 @@ def test_data_list_task(): wg = WorkGraph("test_data_list_task") task1 = wg.add_task("workgraph.aiida_list", name="task1", value=[1, 2, 3]) wg.run() - assert task1.outputs["result"].value == [1, 2, 3] + assert task1.outputs.result.value == [1, 2, 3] diff --git a/tests/test_decorator.py b/tests/test_decorator.py index 5fd44293..2ef736ba 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -1,5 +1,6 @@ import pytest from aiida_workgraph import WorkGraph, task +from aiida_workgraph.socket import TaskSocketNamespace from typing import Callable @@ -11,48 +12,46 @@ def add_multiply(x, y): return {"sum": x + y, "product": x * y} n = add_multiply.task() - assert "sum" in n.outputs.keys() - assert "product" in n.outputs.keys() + assert "sum" in n.outputs + assert "product" in n.outputs -@pytest.fixture(params=["decorator_factory", "decorator"]) -def task_calcfunction(request): - if request.param == "decorator_factory": - - @task.calcfunction() - def test(a, b=1, **c): - print(a, b, c) - - elif request.param == "decorator": +def test_decorators_args() -> None: + @task() + def test(a, b=1, **c): + print(a, b, c) - @task.calcfunction - def test(a, b=1, **c): - print(a, b, c) + n = test.task() + tdata = n.to_dict() + assert tdata["args"] == [] + assert set(tdata["kwargs"]) == set(["a", "b"]) + assert tdata["var_args"] is None + assert tdata["var_kwargs"] == "c" + assert set(n.get_output_names()) == set(["result", "_outputs", "_wait"]) + assert isinstance(n.inputs.c, TaskSocketNamespace) - else: - raise ValueError(f"{request.param} not supported.") - return test +def test_decorators_calcfunction_args() -> None: + @task.calcfunction() + def test(a, b=1, **c): + print(a, b, c) -def test_decorators_calcfunction_args(task_calcfunction) -> None: metadata_kwargs = set( [ - f"metadata.{key}" - for key in task_calcfunction.process_class.spec() - .inputs.ports["metadata"] - .ports.keys() + f"{key}" + for key in test.process_class.spec().inputs.ports["metadata"].ports.keys() ] ) - kwargs = set(task_calcfunction.process_class.spec().inputs.ports.keys()).union( - metadata_kwargs - ) - n = task_calcfunction.task() + kwargs = set(test.process_class.spec().inputs.ports.keys()) + n = test.task() tdata = n.to_dict() assert tdata["args"] == [] assert set(tdata["kwargs"]) == set(kwargs) assert tdata["var_args"] is None assert tdata["var_kwargs"] == "c" - assert set(n.outputs.keys()) == set(["result", "_outputs", "_wait"]) + assert set(n.get_output_names()) == set(["result", "_outputs", "_wait"]) + assert isinstance(n.inputs.c, TaskSocketNamespace) + assert set(n.inputs.metadata._get_keys()) == metadata_kwargs @pytest.fixture(params=["decorator_factory", "decorator"]) @@ -107,15 +106,13 @@ def test(a, b=1, **c): def test_decorators_workfunction_args(task_workfunction) -> None: metadata_kwargs = set( [ - f"metadata.{key}" + f"{key}" for key in task_workfunction.process_class.spec() .inputs.ports["metadata"] .ports.keys() ] ) - kwargs = set(task_workfunction.process_class.spec().inputs.ports.keys()).union( - metadata_kwargs - ) + kwargs = set(task_workfunction.process_class.spec().inputs.ports.keys()) # n = task_workfunction.task() tdata = n.to_dict() @@ -123,7 +120,8 @@ def test_decorators_workfunction_args(task_workfunction) -> None: assert set(tdata["kwargs"]) == set(kwargs) assert tdata["var_args"] is None assert tdata["var_kwargs"] == "c" - assert set(n.outputs.keys()) == set(["result", "_outputs", "_wait"]) + assert set(n.get_output_names()) == set(["result", "_outputs", "_wait"]) + assert set(n.inputs.metadata._get_keys()) == metadata_kwargs def test_decorators_parameters() -> None: @@ -137,9 +135,9 @@ def test(a, b=1, **c): return {"sum": a + b, "product": a * b} test1 = test.task() - assert test1.inputs["c"].link_limit == 1000 - assert "sum" in test1.outputs.keys() - assert "product" in test1.outputs.keys() + assert test1.inputs["c"]._link_limit == 1000 + assert "sum" in test1.get_output_names() + assert "product" in test1.get_output_names() @pytest.fixture(params=["decorator_factory", "decorator"]) @@ -174,7 +172,7 @@ def test_decorators_graph_builder_args(task_graph_builder) -> None: assert tdata["kwargs"] == ["a", "b"] assert tdata["var_args"] is None assert tdata["var_kwargs"] == "c" - assert set(n.outputs.keys()) == set(["_outputs", "_wait"]) + assert set(n.get_output_names()) == set(["_outputs", "_wait"]) def test_inputs_outputs_workchain() -> None: @@ -182,9 +180,9 @@ def test_inputs_outputs_workchain() -> None: wg = WorkGraph() task = wg.add_task(MultiplyAddWorkChain) - assert "metadata" in task.inputs.keys() - assert "metadata.call_link_label" in task.inputs.keys() - assert "result" in task.outputs.keys() + assert "metadata" in task.get_input_names() + assert "call_link_label" in task.inputs.metadata._get_keys() + assert "result" in task.get_output_names() @pytest.mark.usefixtures("started_daemon_client") @@ -194,7 +192,7 @@ def test_decorator_calcfunction(decorated_add: Callable) -> None: wg = WorkGraph(name="test_decorator_calcfunction") wg.add_task(decorated_add, "add1", x=2, y=3) wg.submit(wait=True, timeout=100) - assert wg.tasks["add1"].outputs["result"].value == 5 + assert wg.tasks.add1.outputs.result.value == 5 def test_decorator_workfunction(decorated_add_multiply: Callable) -> None: @@ -203,7 +201,7 @@ def test_decorator_workfunction(decorated_add_multiply: Callable) -> None: wg = WorkGraph(name="test_decorator_workfunction") wg.add_task(decorated_add_multiply, "add_multiply1", x=2, y=3, z=4) wg.submit(wait=True, timeout=100) - assert wg.tasks["add_multiply1"].outputs["result"].value == 20 + assert wg.tasks["add_multiply1"].outputs.result.value == 20 @pytest.mark.usefixtures("started_daemon_client") @@ -213,10 +211,10 @@ def test_decorator_graph_builder(decorated_add_multiply_group: Callable) -> None add1 = wg.add_task("workgraph.test_add", "add1", x=2, y=3) add_multiply1 = wg.add_task(decorated_add_multiply_group, "add_multiply1", y=3, z=4) sum_diff1 = wg.add_task("workgraph.test_sum_diff", "sum_diff1") - wg.add_link(add1.outputs[0], add_multiply1.inputs["x"]) - wg.add_link(add_multiply1.outputs["result"], sum_diff1.inputs["x"]) + wg.add_link(add1.outputs[0], add_multiply1.inputs.x) + wg.add_link(add_multiply1.outputs.result, sum_diff1.inputs.x) # use run to check if graph builder workgraph can be submit inside the engine wg.run() assert wg.tasks["add_multiply1"].process.outputs.result.value == 32 - assert wg.tasks["add_multiply1"].outputs["result"].value == 32 - assert wg.tasks["sum_diff1"].outputs["sum"].value == 32 + assert wg.tasks["add_multiply1"].outputs.result.value == 32 + assert wg.tasks["sum_diff1"].outputs.sum.value == 32 diff --git a/tests/test_engine.py b/tests/test_engine.py index 04c4fea5..fc99a95c 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -50,4 +50,4 @@ def test_max_number_jobs(add_code) -> None: wg.submit(wait=True, timeout=40) report = get_workchain_report(wg.process, "REPORT") assert "tasks ready to run: add2" in report - wg.tasks["add2"].outputs["sum"].value == 2 + wg.tasks.add2.outputs.sum.value == 2 diff --git a/tests/test_error_handler.py b/tests/test_error_handler.py index d1eff81f..d2400811 100644 --- a/tests/test_error_handler.py +++ b/tests/test_error_handler.py @@ -1,10 +1,8 @@ -import pytest from aiida_workgraph import WorkGraph, Task from aiida import orm from aiida.calculations.arithmetic.add import ArithmeticAddCalculation -@pytest.mark.usefixtures("started_daemon_client") def test_error_handlers(add_code): """Test error handlers.""" from aiida.cmdline.utils.common import get_workchain_report @@ -16,7 +14,7 @@ def handle_negative_sum(task: Task): # modify task inputs task.set( { - "x": orm.Int(abs(task.inputs["x"].value)), + "x": orm.Int(abs(task.inputs.x.value)), "y": orm.Int(abs(task.inputs["y"].value)), } ) @@ -39,12 +37,11 @@ def handle_negative_sum(task: Task): }, ) assert len(wg.error_handlers) == 1 - wg.submit( + wg.run( inputs={ "add1": {"code": add_code, "x": orm.Int(1), "y": orm.Int(-2)}, }, - wait=True, ) report = get_workchain_report(wg.process, "REPORT") assert "Run error handler: handle_negative_sum." in report - assert wg.tasks["add1"].outputs["sum"].value == 3 + assert wg.tasks.add1.outputs.sum.value == 3 diff --git a/tests/test_failed_node.py b/tests/test_failed_node.py index 9849c42a..edc5be26 100644 --- a/tests/test_failed_node.py +++ b/tests/test_failed_node.py @@ -11,7 +11,7 @@ def test_failed_node(decorated_sqrt: Callable, decorated_add: Callable) -> None: wg = WorkGraph(name="test_failed_node") wg.add_task(decorated_add, "add1", x=Float(1), y=Float(2)) sqrt1 = wg.add_task(decorated_sqrt, "sqrt1", x=Float(-1)) - wg.add_task(decorated_sqrt, "sqrt2", x=sqrt1.outputs["result"]) + wg.add_task(decorated_sqrt, "sqrt2", x=sqrt1.outputs.result) wg.submit(wait=True) # print("results: ", results[]) assert wg.process.exit_status == 302 diff --git a/tests/test_for.py b/tests/test_for.py index fdd6fa47..d70adaab 100644 --- a/tests/test_for.py +++ b/tests/test_for.py @@ -22,7 +22,7 @@ def add_multiply_for(sequence): add1 = wg.add_task(decorated_add, name="add1", x="{{ total }}") # update the context variable add1.set_context({"total": "result"}) - wg.add_link(multiply1.outputs["result"], add1.inputs["y"]) + wg.add_link(multiply1.outputs.result, add1.inputs["y"]) # don't forget to return the workgraph return wg @@ -30,6 +30,6 @@ def add_multiply_for(sequence): wg = WorkGraph("test_for") for1 = wg.add_task(add_multiply_for, sequence=range(5)) add1 = wg.add_task(decorated_add, name="add1", y=orm.Int(1)) - wg.add_link(for1.outputs["result"], add1.inputs["x"]) + wg.add_link(for1.outputs.result, add1.inputs.x) wg.submit(wait=True, timeout=200) assert add1.node.outputs.result.value == 21 diff --git a/tests/test_if.py b/tests/test_if.py index 458810bd..c63b03b8 100644 --- a/tests/test_if.py +++ b/tests/test_if.py @@ -7,16 +7,16 @@ def test_if_task(decorated_add, decorated_multiply, decorated_compare): wg = WorkGraph("test_if") add1 = wg.add_task(decorated_add, name="add1", x=1, y=1) condition1 = wg.add_task(decorated_compare, name="condition1", x=1, y=0) - add2 = wg.add_task(decorated_add, name="add2", x=add1.outputs["result"], y=2) - if1 = wg.add_task("If", name="if_true", conditions=condition1.outputs["result"]) + add2 = wg.add_task(decorated_add, name="add2", x=add1.outputs.result, y=2) + if1 = wg.add_task("If", name="if_true", conditions=condition1.outputs.result) if1.children.add("add2") multiply1 = wg.add_task( - decorated_multiply, name="multiply1", x=add1.outputs["result"], y=2 + decorated_multiply, name="multiply1", x=add1.outputs.result, y=2 ) if2 = wg.add_task( "If", name="if_false", - conditions=condition1.outputs["result"], + conditions=condition1.outputs.result, invert_condition=True, ) if2.children.add("multiply1") @@ -24,13 +24,13 @@ def test_if_task(decorated_add, decorated_multiply, decorated_compare): select1 = wg.add_task( "workgraph.select", name="select1", - true=add2.outputs["result"], - false=multiply1.outputs["result"], - condition=condition1.outputs["result"], + true=add2.outputs.result, + false=multiply1.outputs.result, + condition=condition1.outputs.result, ) - add3 = wg.add_task(decorated_add, name="add3", x=select1.outputs["result"], y=1) + add3 = wg.add_task(decorated_add, name="add3", x=select1.outputs.result, y=1) wg.run() - assert add3.outputs["result"].value == 5 + assert add3.outputs.result.value == 5 def test_empty_if_task(): diff --git a/tests/test_link.py b/tests/test_link.py index e7f2a3ba..87f86592 100644 --- a/tests/test_link.py +++ b/tests/test_link.py @@ -20,10 +20,10 @@ def sum(**datas): float2 = wg.add_task("workgraph.aiida_node", pk=Float(2.0).store().pk) float3 = wg.add_task("workgraph.aiida_node", pk=Float(3.0).store().pk) sum1 = wg.add_task(sum, "sum1") - sum1.inputs["datas"].link_limit = 100 - wg.add_link(float1.outputs[0], sum1.inputs["datas"]) - wg.add_link(float2.outputs[0], sum1.inputs["datas"]) - wg.add_link(float3.outputs[0], sum1.inputs["datas"]) + sum1.inputs.datas._link_limit = 100 + wg.add_link(float1.outputs[0], sum1.inputs.datas) + wg.add_link(float2.outputs[0], sum1.inputs.datas) + wg.add_link(float3.outputs[0], sum1.inputs.datas) # wg.submit(wait=True) wg.run() assert sum1.node.outputs.result.value == 6 diff --git a/tests/test_normal_function.py b/tests/test_normal_function.py index 0b2cb5c8..21d06188 100644 --- a/tests/test_normal_function.py +++ b/tests/test_normal_function.py @@ -10,9 +10,9 @@ def test_normal_function_run( wg = WorkGraph(name="test_normal_function_run") add1 = wg.add_task(decorated_normal_add, "add1", x=2, y=3) add2 = wg.add_task(decorated_add, "add2", x=6) - wg.add_link(add1.outputs["result"], add2.inputs["y"]) + wg.add_link(add1.outputs.result, add2.inputs["y"]) wg.run() - assert wg.tasks["add2"].node.outputs.result == 11 + assert wg.tasks.add2.node.outputs.result == 11 @pytest.mark.usefixtures("started_daemon_client") @@ -23,6 +23,6 @@ def test_normal_function_submit( wg = WorkGraph(name="test_normal_function_submit") add1 = wg.add_task(decorated_normal_add, "add1", x=2, y=3) add2 = wg.add_task(decorated_add, "add2", x=6) - wg.add_link(add1.outputs["result"], add2.inputs["y"]) + wg.add_link(add1.outputs.result, add2.inputs["y"]) wg.submit(wait=True) - assert wg.tasks["add2"].node.outputs.result == 11 + assert wg.tasks.add2.node.outputs.result == 11 diff --git a/tests/test_pythonjob.py b/tests/test_pythonjob.py index d7793e22..50645178 100644 --- a/tests/test_pythonjob.py +++ b/tests/test_pythonjob.py @@ -33,19 +33,19 @@ def multiply(x: Any, y: Any) -> Any: wg.add_task( decorted_multiply, name="multiply1", - x=wg.tasks["add1"].outputs["sum"], + x=wg.tasks.add1.outputs.sum, y=3, computer="localhost", command_info={"label": python_executable_path}, ) # wg.submit(wait=True) wg.run() - assert wg.tasks["add1"].outputs["sum"].value.value == 3 - assert wg.tasks["add1"].outputs["diff"].value.value == -1 - assert wg.tasks["multiply1"].outputs["result"].value.value == 9 + assert wg.tasks.add1.outputs.sum.value.value == 3 + assert wg.tasks.add1.outputs["diff"].value.value == -1 + assert wg.tasks.multiply1.outputs.result.value.value == 9 # process_label and label - assert wg.tasks["add1"].node.process_label == "PythonJob" - assert wg.tasks["add1"].node.label == "add1" + assert wg.tasks.add1.node.process_label == "PythonJob" + assert wg.tasks.add1.node.label == "add1" def test_PythonJob_kwargs(fixture_localhost, python_executable_path): @@ -58,10 +58,10 @@ def add(x, y=1, **kwargs): return x wg = WorkGraph("test_PythonJob") - wg.add_task("PythonJob", function=add, name="add") + wg.add_task("PythonJob", function=add, name="add1") wg.run( inputs={ - "add": { + "add1": { "x": 1, "y": 2, "kwargs": {"m": 2, "n": 3}, @@ -71,11 +71,11 @@ def add(x, y=1, **kwargs): }, ) # data inside the kwargs should be serialized separately - wg.process.inputs.wg.tasks.add.inputs.kwargs.property.value.m.value == 2 - assert wg.tasks["add"].outputs["result"].value.value == 8 + wg.process.inputs.wg.tasks.add1.inputs.kwargs.sockets.m.property.value == 2 + assert wg.tasks.add1.outputs.result.value.value == 8 # load the workgraph wg = WorkGraph.load(wg.pk) - assert wg.tasks["add"].inputs["kwargs"].value == {"m": 2, "n": 3} + assert wg.tasks.add1.inputs["kwargs"]._value == {"m": 2, "n": 3} def test_PythonJob_namespace_output_input(fixture_localhost, python_executable_path): @@ -141,10 +141,10 @@ def myfunc3(x, y): }, } wg.run(inputs=inputs) - assert wg.tasks["myfunc"].outputs["add_multiply"].value.add.value == 3 - assert wg.tasks["myfunc"].outputs["add_multiply"].value.multiply.value == 2 - assert wg.tasks["myfunc2"].outputs["result"].value.value == 8 - assert wg.tasks["myfunc3"].outputs["result"].value.value == 7 + assert wg.tasks.myfunc.outputs.add_multiply.add.value == 3 + assert wg.tasks.myfunc.outputs.add_multiply.multiply.value == 2 + assert wg.tasks.myfunc2.outputs.result.value == 8 + assert wg.tasks.myfunc3.outputs.result.value == 7 def test_PythonJob_copy_files(fixture_localhost, python_executable_path): @@ -176,15 +176,15 @@ def multiply(x_folder_name, y_folder_name): name="multiply", ) wg.add_link( - wg.tasks["add1"].outputs["remote_folder"], + wg.tasks.add1.outputs.remote_folder, wg.tasks["multiply"].inputs["copy_files"], ) wg.add_link( - wg.tasks["add2"].outputs["remote_folder"], + wg.tasks.add2.outputs.remote_folder, wg.tasks["multiply"].inputs["copy_files"], ) # ------------------------- Submit the calculation ------------------- - wg.submit( + wg.run( inputs={ "add1": { "x": 2, @@ -205,9 +205,8 @@ def multiply(x_folder_name, y_folder_name): "command_info": {"label": python_executable_path}, }, }, - wait=True, ) - assert wg.tasks["multiply"].outputs["result"].value.value == 25 + assert wg.tasks["multiply"].outputs.result.value.value == 25 def test_load_pythonjob(fixture_localhost, python_executable_path): @@ -231,10 +230,10 @@ def add(x: str, y: str) -> str: }, # wait=True, ) - assert wg.tasks["add"].outputs["result"].value.value == "Hello, World!" + assert wg.tasks.add.outputs.result.value.value == "Hello, World!" wg = WorkGraph.load(wg.pk) - wg.tasks["add"].inputs["x"].value = "Hello, " - wg.tasks["add"].inputs["y"].value = "World!" + wg.tasks.add.inputs.x.value = "Hello, " + wg.tasks.add.inputs["y"].value = "World!" def test_exit_code(fixture_localhost, python_executable_path): @@ -246,7 +245,12 @@ def handle_negative_sum(task: Task): Simply make the inputs positive by taking the absolute value. """ - task.set({"x": abs(task.inputs["x"].value), "y": abs(task.inputs["y"].value)}) + task.set( + { + "x": abs(task.inputs.x.value), + "y": abs(task.inputs["y"].value), + } + ) return "Run error handler: handle_negative_sum." @@ -280,5 +284,5 @@ def add(x: array, y: array) -> array: == "Some elements are negative" ) # the final task should have exit status 0 - assert wg.tasks["add1"].node.exit_status == 0 - assert (wg.tasks["add1"].outputs["sum"].value.value == array([2, 3])).all() + assert wg.tasks.add1.node.exit_status == 0 + assert (wg.tasks.add1.outputs.sum.value.value == array([2, 3])).all() diff --git a/tests/test_shell.py b/tests/test_shell.py index 88045b42..ce15ca7c 100644 --- a/tests/test_shell.py +++ b/tests/test_shell.py @@ -72,11 +72,11 @@ def test_dynamic_port(): ) wg.add_link(echo_task.outputs["copied_file"], cat_task.inputs["nodes.input1"]) # task will create input for each item in the dynamic port (nodes) - assert "nodes.input1" in cat_task.inputs.keys() - assert "nodes.input2" in cat_task.inputs.keys() + assert "nodes.input1" in cat_task.inputs + assert "nodes.input2" in cat_task.inputs # if the value of the item is a Socket, then it will create a link, and pop the item - assert "nodes.input3" in cat_task.inputs.keys() - assert cat_task.inputs["nodes"].value == {"input1": None, "input2": Int(2)} + assert "nodes.input3" in cat_task.inputs + assert cat_task.inputs["nodes"]._value == {"input2": Int(2)} @pytest.mark.usefixtures("started_daemon_client") @@ -112,7 +112,7 @@ def parser(dirpath): name="job2", command="bc", arguments=["{expression}"], - nodes={"expression": job1.outputs["stdout"]}, + nodes={"expression": job1.outputs.stdout}, parser=parser, parser_outputs=[ {"identifier": "workgraph.any", "name": "result"} @@ -123,4 +123,4 @@ def parser(dirpath): wg = WorkGraph(name="test_shell_graph_builder") add_multiply1 = wg.add_task(add_multiply, x=Int(2), y=Int(3)) wg.submit(wait=True) - assert add_multiply1.outputs["result"].value.value == 5 + assert add_multiply1.outputs.result.value.value == 5 diff --git a/tests/test_socket.py b/tests/test_socket.py index 835ec99b..ac6ba4e2 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -30,8 +30,8 @@ def test_type_mapping(data_type, data, identifier) -> None: def add(x: data_type): pass - assert add.task().inputs["x"].identifier == identifier - assert add.task().inputs["x"].property.identifier == identifier + assert add.task().inputs.x._identifier == identifier + assert add.task().inputs.x.property.identifier == identifier add_task = add.task() add_task.set({"x": data}) # test set data from context @@ -43,7 +43,7 @@ def test_vector_socket() -> None: from aiida_workgraph import Task t = Task() - t.inputs.new( + t.add_input( "workgraph.aiida_int_vector", "vector2d", property_data={"size": 2, "default": [1, 2]}, @@ -71,8 +71,8 @@ def test_aiida_data_socket() -> None: def add(x: data_type): pass - assert add.task().inputs["x"].identifier == identifier - assert add.task().inputs["x"].property.identifier == identifier + assert add.task().inputs.x._identifier == identifier + assert add.task().inputs.x.property.identifier == identifier add_task = add.task() add_task.set({"x": data}) # test set data from context @@ -118,7 +118,6 @@ def test_numpy_array(decorated_normal_add): wg.submit(wait=True) # wg.run() assert wg.state.upper() == "FINISHED" - # assert (wg.tasks["add1"].outputs["result"].value == np.array([5, 7, 9])).all() def test_kwargs() -> None: @@ -129,9 +128,8 @@ def test(a, b=1, **kwargs): return {"sum": a + b, "product": a * b} test1 = test.node() - assert test1.inputs["kwargs"].link_limit == 1e6 - assert test1.inputs["kwargs"].identifier == "workgraph.namespace" - assert test1.inputs["kwargs"].property.value is None + assert test1.inputs["kwargs"]._link_limit == 1e6 + assert test1.inputs["kwargs"]._identifier == "workgraph.namespace" @pytest.mark.parametrize( @@ -158,7 +156,7 @@ def my_task(x: data_type): pass my_task1 = wg.add_task(my_task, name="my_task", x=socket_value) - socket = my_task1.inputs["x"] + socket = my_task1.inputs.x socket_node_value = socket.get_node_value() assert isinstance(socket_node_value, type(node_value)) diff --git a/tests/test_task_from_workgraph.py b/tests/test_task_from_workgraph.py index c6af4b6b..d93d24a8 100644 --- a/tests/test_task_from_workgraph.py +++ b/tests/test_task_from_workgraph.py @@ -14,10 +14,10 @@ def test_inputs_outptus(wg_calcfunction: WorkGraph) -> None: noutput = 0 for sub_task in wg_calcfunction.tasks: noutput += len(sub_task.outputs) - 2 + 1 - assert len(task1.inputs) == ninput + 1 - assert len(task1.outputs) == noutput + 2 - assert "sumdiff1.x" in task1.inputs.keys() - assert "sumdiff1.sum" in task1.outputs.keys() + assert len(task1.inputs) == len(wg_calcfunction.tasks) + 1 + assert len(task1.outputs) == len(wg_calcfunction.tasks) + 2 + assert "sumdiff1.x" in task1.inputs + assert "sumdiff1.sum" in task1.outputs @pytest.mark.usefixtures("started_daemon_client") @@ -25,20 +25,18 @@ def test_build_task_from_workgraph(decorated_add: Callable) -> None: # create a sub workgraph sub_wg = WorkGraph("build_task_from_workgraph") sub_wg.add_task(decorated_add, name="add1", x=1, y=3) - sub_wg.add_task( - decorated_add, name="add2", x=2, y=sub_wg.tasks["add1"].outputs["result"] - ) + sub_wg.add_task(decorated_add, name="add2", x=2, y=sub_wg.tasks.add1.outputs.result) # wg = WorkGraph("build_task_from_workgraph") add1_task = wg.add_task(decorated_add, name="add1", x=1, y=3) wg_task = wg.add_task(sub_wg, name="sub_wg") # the default value of the namespace is None - assert wg_task.inputs["add1"].value is None + assert wg_task.inputs["add1"]._value == {} wg.add_task(decorated_add, name="add2", y=3) - wg.add_link(add1_task.outputs["result"], wg_task.inputs["add1.x"]) - wg.add_link(wg_task.outputs["add2.result"], wg.tasks["add2"].inputs["x"]) - assert len(wg_task.inputs) == 21 - assert len(wg_task.outputs) == 6 - wg.submit(wait=True) - # wg.run() - assert wg.tasks["add2"].outputs["result"].value.value == 12 + wg.add_link(add1_task.outputs.result, wg_task.inputs["add1.x"]) + wg.add_link(wg_task.outputs["add2.result"], wg.tasks.add2.inputs.x) + assert len(wg_task.inputs) == 3 + assert len(wg_task.outputs) == 4 + # wg.submit(wait=True) + wg.run() + assert wg.tasks.add2.outputs.result.value.value == 12 diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 4044ecee..9c7cf378 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -15,10 +15,12 @@ def sum_diff(x, y): wg = WorkGraph("test_normal_task") task1 = wg.add_task(sum_diff, name="sum_diff", x=2, y=3) task2 = wg.add_task( - decorated_add, name="add", x=task1.outputs["sum"], y=task1.outputs["diff"] + decorated_add, name="add", x=task1.outputs.sum, y=task1.outputs["diff"] ) wg.run() - assert task2.outputs["result"].value == 4 + print("node: ", task2.node.outputs.result) + wg.update() + assert task2.outputs.result.value == 4 def test_task_collection(decorated_add: Callable) -> None: @@ -65,11 +67,11 @@ def test_set_non_dynamic_namespace_socket(decorated_add) -> None: task2 = wg.add_task(WorkChainWithNestNamespace) task2.set( { - "non_dynamic_port": {"a": task1.outputs["result"], "b": orm.Int(2)}, + "non_dynamic_port": {"a": task1.outputs.result, "b": orm.Int(2)}, } ) - assert len(task2.inputs["non_dynamic_port.a"].links) == 1 - assert task2.inputs["non_dynamic_port"].value == {"b": orm.Int(2)} + assert len(task2.inputs["non_dynamic_port.a"]._links) == 1 + assert task2.inputs["non_dynamic_port"]._value == {"b": orm.Int(2)} assert len(wg.links) == 1 @@ -82,11 +84,13 @@ def test_set_namespace_socket(decorated_add) -> None: task2 = wg.add_task(WorkChainWithNestNamespace) task2.set( { - "add": {"x": task1.outputs["result"], "y": orm.Int(2)}, + "add": {"x": task1.outputs.result, "y": orm.Int(2)}, } ) - assert len(task2.inputs["add.x"].links) == 1 - assert task2.inputs["add"].value == {"y": orm.Int(2)} + assert len(task2.inputs["add.x"]._links) == 1 + assert task2.inputs["add"]._value == { + "y": orm.Int(2), + } assert len(wg.links) == 1 @@ -104,20 +108,19 @@ def test_set_dynamic_port_input(decorated_add) -> None: dynamic_port={ "input1": None, "input2": orm.Int(2), - "input3": task1.outputs["result"], - "nested": {"input4": orm.Int(4), "input5": task1.outputs["result"]}, + "input3": task1.outputs.result, + "nested": {"input4": orm.Int(4), "input5": task1.outputs.result}, }, ) wg.add_link(task1.outputs["_wait"], task2.inputs["dynamic_port.input1"]) # task will create input for each item in the dynamic port (nodes) - assert "dynamic_port.input1" in task2.inputs.keys() - assert "dynamic_port.input2" in task2.inputs.keys() + assert "dynamic_port.input1" in task2.inputs + assert "dynamic_port.input2" in task2.inputs # if the value of the item is a Socket, then it will create a link, and pop the item - assert "dynamic_port.input3" in task2.inputs.keys() - assert "dynamic_port.nested.input4" in task2.inputs.keys() - assert "dynamic_port.nested.input5" in task2.inputs.keys() - assert task2.inputs["dynamic_port"].value == { - "input1": None, + assert "dynamic_port.input3" in task2.inputs + assert "dynamic_port.nested.input4" in task2.inputs + assert "dynamic_port.nested.input5" in task2.inputs + assert task2.inputs.dynamic_port._value == { "input2": orm.Int(2), "nested": {"input4": orm.Int(4)}, } @@ -133,9 +136,9 @@ def test_set_inputs(decorated_add: Callable) -> None: data = wg.prepare_inputs(metadata=None) assert data["wg"]["tasks"]["add1"]["inputs"]["y"]["property"]["value"] == 2 assert ( - data["wg"]["tasks"]["add1"]["inputs"]["metadata"]["property"]["value"][ + data["wg"]["tasks"]["add1"]["inputs"]["metadata"]["sockets"][ "store_provenance" - ] + ]["property"]["value"] is False ) @@ -151,7 +154,7 @@ def test_set_inputs_from_builder(add_code) -> None: builder.x = 1 builder.y = 2 add1.set_from_builder(builder) - assert add1.inputs["x"].value == 1 + assert add1.inputs.x.value == 1 assert add1.inputs["y"].value == 2 assert add1.inputs["code"].value == add_code with pytest.raises( @@ -159,3 +162,26 @@ def test_set_inputs_from_builder(add_code) -> None: match=f"Executor {ArithmeticAddCalculation.__name__} does not have the get_builder_from_protocol method.", ): add1.set_from_protocol(code=add_code, protocol="fast") + + +def test_namespace_outputs(): + @task.calcfunction( + outputs=[ + {"identifier": "workgraph.namespace", "name": "add_multiply"}, + {"name": "add_multiply.add"}, + {"name": "add_multiply.multiply"}, + {"name": "minus"}, + ] + ) + def myfunc(x, y): + return { + "add_multiply": {"add": orm.Float(x + y), "multiply": orm.Float(x * y)}, + "minus": orm.Float(x - y), + } + + wg = WorkGraph("test_namespace_outputs") + wg.add_task(myfunc, name="myfunc", x=1.0, y=2.0) + wg.run() + assert wg.tasks.myfunc.outputs.minus.value == -1 + assert wg.tasks.myfunc.outputs.add_multiply.add.value == 3 + assert wg.tasks.myfunc.outputs.add_multiply.multiply.value == 2 diff --git a/tests/test_while.py b/tests/test_while.py index 6c677e85..6644ae79 100644 --- a/tests/test_while.py +++ b/tests/test_while.py @@ -47,37 +47,33 @@ def raw_python_code(): # --------------------------------------------------------------------- # the `result` of compare1 taskis used as condition compare1 = wg.add_task(decorated_compare, name="compare1", x="{{m}}", y=10) - while1 = wg.add_task("While", name="while1", conditions=compare1.outputs["result"]) + while1 = wg.add_task("While", name="while1", conditions=compare1.outputs.result) add11 = wg.add_task(decorated_add, name="add11", x=1, y=1) # --------------------------------------------------------------------- compare2 = wg.add_task(decorated_compare, name="compare2", x="{{n}}", y=5) - while2 = wg.add_task("While", name="while2", conditions=compare2.outputs["result"]) - add21 = wg.add_task( - decorated_add, name="add21", x="{{n}}", y=add11.outputs["result"] - ) + while2 = wg.add_task("While", name="while2", conditions=compare2.outputs.result) + add21 = wg.add_task(decorated_add, name="add21", x="{{n}}", y=add11.outputs.result) add21.waiting_on.add("add1") - add22 = wg.add_task(decorated_add, name="add22", x=add21.outputs["result"], y=1) + add22 = wg.add_task(decorated_add, name="add22", x=add21.outputs.result, y=1) add22.set_context({"n": "result"}) while2.children.add(["add21", "add22"]) # --------------------------------------------------------------------- compare3 = wg.add_task(decorated_compare, name="compare3", x="{{l}}", y=5) while3 = wg.add_task( - "While", name="while3", max_iterations=1, conditions=compare3.outputs["result"] + "While", name="while3", max_iterations=1, conditions=compare3.outputs.result ) add31 = wg.add_task(decorated_add, name="add31", x="{{l}}", y=1) add31.waiting_on.add("add22") - add32 = wg.add_task(decorated_add, name="add32", x=add31.outputs["result"], y=1) + add32 = wg.add_task(decorated_add, name="add32", x=add31.outputs.result, y=1) add32.set_context({"l": "result"}) while3.children.add(["add31", "add32"]) # --------------------------------------------------------------------- - add12 = wg.add_task( - decorated_add, name="add12", x="{{m}}", y=add32.outputs["result"] - ) + add12 = wg.add_task(decorated_add, name="add12", x="{{m}}", y=add32.outputs.result) add12.set_context({"m": "result"}) while1.children.add(["add11", "while2", "while3", "add12", "compare2", "compare3"]) # --------------------------------------------------------------------- add2 = wg.add_task( - decorated_add, name="add2", x=add12.outputs["result"], y=add31.outputs["result"] + decorated_add, name="add2", x=add12.outputs.result, y=add31.outputs.result ) # wg.submit(wait=True, timeout=100) wg.run() @@ -85,7 +81,7 @@ def raw_python_code(): for link in wg.process.base.links.get_outgoing().all(): if isinstance(link.node, orm.ProcessNode): print(link.node.label, link.node.outputs.result) - assert add2.outputs["result"].value.value == raw_python_code().value + assert add2.outputs.result.value.value == raw_python_code().value @pytest.mark.usefixtures("started_daemon_client") @@ -102,10 +98,10 @@ def test_while_workgraph(decorated_add, decorated_multiply, decorated_compare): ) add1 = wg.add_task(decorated_add, name="add1", y=3) add1.set_context({"n": "result"}) - wg.add_link(multiply1.outputs["result"], add1.inputs["x"]) + wg.add_link(multiply1.outputs.result, add1.inputs.x) wg.submit(wait=True, timeout=100) assert wg.execution_count == 4 - assert wg.tasks["add1"].outputs["result"].value == 29 + assert wg.tasks.add1.outputs.result.value == 29 @pytest.mark.usefixtures("started_daemon_client") @@ -126,7 +122,7 @@ def my_while(n=0, limit=100): ) add1 = wg.add_task(decorated_add, name="add1", y=3) add1.set_context({"n": "result"}) - wg.add_link(multiply1.outputs["result"], add1.inputs["x"]) + wg.add_link(multiply1.outputs.result, add1.inputs.x) return wg # ----------------------------------------- @@ -134,8 +130,8 @@ def my_while(n=0, limit=100): add1 = wg.add_task(decorated_add, name="add1", x=orm.Int(25), y=orm.Int(25)) my_while1 = wg.add_task(my_while, n=orm.Int(1)) add2 = wg.add_task(decorated_add, name="add2", y=orm.Int(2)) - wg.add_link(add1.outputs["result"], my_while1.inputs["limit"]) - wg.add_link(my_while1.outputs["result"], add2.inputs["x"]) + wg.add_link(add1.outputs.result, my_while1.inputs["limit"]) + wg.add_link(my_while1.outputs.result, add2.inputs.x) wg.submit(wait=True, timeout=100) - assert add2.outputs["result"].value < 31 + assert add2.outputs.result.value < 31 assert my_while1.node.outputs.execution_count == 2 diff --git a/tests/test_workchain.py b/tests/test_workchain.py index 85f59b1f..164cddb3 100644 --- a/tests/test_workchain.py +++ b/tests/test_workchain.py @@ -19,7 +19,7 @@ def test_build_workchain_inputs_outputs(): node = build_task(MultiplyAddWorkChain)() inputs = MultiplyAddWorkChain.spec().inputs # inputs + metadata + _wait - ninput = len(inputs.ports) + len(inputs.ports["metadata"].ports) + 1 + ninput = len(inputs.ports) + 1 assert len(node.inputs) == ninput assert len(node.outputs) == 3 diff --git a/tests/test_workgraph.py b/tests/test_workgraph.py index e355aa33..1badd386 100644 --- a/tests/test_workgraph.py +++ b/tests/test_workgraph.py @@ -9,9 +9,7 @@ def test_from_dict(decorated_add): """Export NodeGraph to dict.""" wg = WorkGraph("test_from_dict") task1 = wg.add_task(decorated_add, x=2, y=3) - wg.add_task( - "workgraph.test_sum_diff", name="sumdiff2", x=4, y=task1.outputs["result"] - ) + wg.add_task("workgraph.test_sum_diff", name="sumdiff2", x=4, y=task1.outputs.result) wgdata = wg.to_dict() wg1 = WorkGraph.from_dict(wgdata) assert len(wg.tasks) == len(wg1.tasks) @@ -23,7 +21,7 @@ def test_add_task(): wg = WorkGraph("test_add_task") add1 = wg.add_task(ArithmeticAddCalculation, name="add1") add2 = wg.add_task(ArithmeticAddCalculation, name="add2") - wg.add_link(add1.outputs["sum"], add2.inputs["x"]) + wg.add_link(add1.outputs.sum, add2.inputs.x) assert len(wg.tasks) == 2 assert len(wg.links) == 1 @@ -62,12 +60,12 @@ def test_save_load(wg_calcfunction, decorated_add): wg2 = WorkGraph.load(wg.process.pk) assert len(wg.tasks) == len(wg2.tasks) # check the executor of the decorated task - callable = wg2.tasks["add1"].get_executor()["callable"] + callable = wg2.tasks.add1.get_executor()["callable"] assert isinstance(callable, PickledFunction) assert callable.value(1, 2) == 3 # TODO, the following code is not working # wg2.save() - # assert wg2.tasks["add1"].executor == decorated_add + # assert wg2.tasks.add1.executor == decorated_add # remove the extra, will raise an error wg.process.base.extras.delete(WORKGRAPH_EXTRA_KEY) with pytest.raises( @@ -79,6 +77,7 @@ def test_save_load(wg_calcfunction, decorated_add): def test_organize_nested_inputs(): """Merge sub properties to the root properties.""" from .utils.test_workchain import WorkChainWithNestNamespace + from node_graph.utils import collect_values_inside_namespace wg = WorkGraph("test_organize_nested_inputs") task1 = wg.add_task(WorkChainWithNestNamespace, name="task1") @@ -96,11 +95,14 @@ def test_organize_nested_inputs(): data = { "metadata": { "call_link_label": "nest", - "options": {"resources": {"num_cpus": 1, "num_machines": 1}}, + "options": {"resources": {"num_machines": 1}}, }, "x": "1", } - assert inputs["wg"]["tasks"]["task1"]["inputs"]["add"]["property"]["value"] == data + collected_data = collect_values_inside_namespace( + inputs["wg"]["tasks"]["task1"]["inputs"]["add"] + ) + assert collected_data == data @pytest.mark.usefixtures("started_daemon_client") @@ -114,7 +116,7 @@ def test_reset_message(wg_calcjob): # wait for the daemon to start the workgraph time.sleep(3) wg = WorkGraph.load(wg.process.pk) - wg.tasks["add2"].set({"y": orm.Int(10).store()}) + wg.tasks.add2.set({"y": orm.Int(10).store()}) wg.save() wg.wait() report = get_workchain_report(wg.process, "REPORT") @@ -131,7 +133,7 @@ def test_restart_and_reset(wg_calcfunction): "workgraph.test_sum_diff", "sumdiff3", x=4, - y=wg.tasks["sumdiff2"].outputs["sum"], + y=wg.tasks["sumdiff2"].outputs.sum, ) wg.name = "test_restart_0" wg.submit(wait=True) @@ -164,7 +166,7 @@ def test_extend_workgraph(decorated_add_multiply_group): assert "group_add1" in [ task.name for task in wg.tasks["group_multiply1"].waiting_on ] - wg.add_link(add1.outputs[0], wg.tasks["group_add1"].inputs["x"]) + wg.add_link(add1.outputs[0], wg.tasks["group_add1"].inputs.x) wg.run() assert wg.tasks["group_multiply1"].node.outputs.result == 45 diff --git a/tests/test_yaml.py b/tests/test_yaml.py index 8a9bb520..37111443 100644 --- a/tests/test_yaml.py +++ b/tests/test_yaml.py @@ -7,11 +7,11 @@ def test_calcfunction(): wg = WorkGraph.from_yaml(os.path.join(cwd, "datas/test_calcfunction.yaml")) - assert wg.tasks["float1"].inputs["value"].value == 3.0 - assert wg.tasks["sumdiff1"].inputs["x"].value == 2.0 - assert wg.tasks["sumdiff2"].inputs["x"].value == 4.0 + assert wg.tasks.float1.inputs.value.value == 3.0 + assert wg.tasks.sumdiff1.inputs.x.value == 2.0 + assert wg.tasks.sumdiff2.inputs.x.value == 4.0 wg.run() - assert wg.tasks["sumdiff2"].node.outputs.sum == 9 + assert wg.tasks.sumdiff2.node.outputs.sum == 9 # skip this test for now @@ -19,4 +19,4 @@ def test_calcfunction(): def test_calcjob(): wg = WorkGraph.from_yaml(os.path.join(cwd, "datas/test_calcjob.yaml")) wg.submit(wait=True) - assert wg.tasks["add2"].node.outputs.sum == 9 + assert wg.tasks.add2.node.outputs.sum == 9 diff --git a/tests/test_zone.py b/tests/test_zone.py index 007e286e..b46d575d 100644 --- a/tests/test_zone.py +++ b/tests/test_zone.py @@ -9,9 +9,9 @@ def test_zone_task(decorated_add): wg.context = {} add1 = wg.add_task(decorated_add, name="add1", x=1, y=1) wg.add_task(decorated_add, name="add2", x=1, y=1) - add3 = wg.add_task(decorated_add, name="add3", x=1, y=add1.outputs["result"]) - wg.add_task(decorated_add, name="add4", x=1, y=add3.outputs["result"]) - wg.add_task(decorated_add, name="add5", x=1, y=add3.outputs["result"]) + add3 = wg.add_task(decorated_add, name="add3", x=1, y=add1.outputs.result) + wg.add_task(decorated_add, name="add4", x=1, y=add3.outputs.result) + wg.add_task(decorated_add, name="add5", x=1, y=add3.outputs.result) zone1 = wg.add_task("workgraph.zone", name="Zone1") zone1.children.add(["add2", "add3"]) wg.run() diff --git a/tests/widget/test_widget.py b/tests/widget/test_widget.py index c7cb34a8..44307330 100644 --- a/tests/widget/test_widget.py +++ b/tests/widget/test_widget.py @@ -14,7 +14,7 @@ def test_workgraph_widget(wg_calcfunction, decorated_add): assert len(value["links"]) == 2 # check required sockets # there are more than 2 inputs, but only 2 are required - assert len(wg.tasks["add1"].inputs) > 2 + assert len(wg.tasks.add1.inputs) > 2 assert len(value["nodes"]["add1"]["inputs"]) == 2 # to_html data = wg.to_html()