From d7c3745a5dbc6a245b8ae4a7f16d25990472ba82 Mon Sep 17 00:00:00 2001 From: superstar54 Date: Tue, 10 Dec 2024 14:53:36 +0100 Subject: [PATCH] use input_links --- src/aiida_workgraph/decorator.py | 18 +++-- src/aiida_workgraph/engine/task_manager.py | 79 ++++++++-------------- src/aiida_workgraph/tasks/pythonjob.py | 11 +-- src/aiida_workgraph/utils/analysis.py | 37 +++------- src/aiida_workgraph/workgraph.py | 6 +- tests/test_action.py | 2 +- tests/test_awaitable_task.py | 6 +- tests/test_calcfunction.py | 4 +- tests/test_calcjob.py | 2 +- tests/test_ctx.py | 4 +- tests/test_data_task.py | 6 +- tests/test_decorator.py | 14 ++-- tests/test_engine.py | 2 +- tests/test_error_handler.py | 11 ++- tests/test_if.py | 2 +- tests/test_pythonjob.py | 38 ++++++----- tests/test_shell.py | 10 +-- tests/test_socket.py | 1 - tests/test_task_from_workgraph.py | 12 ++-- tests/test_while.py | 6 +- 20 files changed, 119 insertions(+), 152 deletions(-) diff --git a/src/aiida_workgraph/decorator.py b/src/aiida_workgraph/decorator.py index 7bc27bd1..ac53873e 100644 --- a/src/aiida_workgraph/decorator.py +++ b/src/aiida_workgraph/decorator.py @@ -381,10 +381,13 @@ def build_task_from_workgraph(wg: any) -> Task: } ) for socket in task.inputs: - if socket.name in builtin_input_names: + if socket.socket_name in builtin_input_names: continue inputs.append( - {"identifier": socket.identifier, "name": f"{task.name}.{socket.name}"} + { + "identifier": socket._socket_identifier, + "name": f"{task.name}.{socket.socket_name}", + } ) # outputs outputs.append( @@ -394,15 +397,18 @@ def build_task_from_workgraph(wg: any) -> Task: } ) for socket in task.outputs: - if socket.name in builtin_output_names: + if socket.socket_name in builtin_output_names: continue outputs.append( - {"identifier": socket.identifier, "name": f"{task.name}.{socket.name}"} + { + "identifier": socket._socket_identifier, + "name": f"{task.name}.{socket.socket_name}", + } ) group_outputs.append( { - "name": f"{task.name}.{socket.name}", - "from": f"{task.name}.{socket.name}", + "name": f"{task.name}.{socket.socket_name}", + "from": f"{task.name}.{socket.socket_name}", } ) # add built-in sockets diff --git a/src/aiida_workgraph/engine/task_manager.py b/src/aiida_workgraph/engine/task_manager.py index 2ef1f148..0d1464f2 100644 --- a/src/aiida_workgraph/engine/task_manager.py +++ b/src/aiida_workgraph/engine/task_manager.py @@ -39,10 +39,10 @@ def get_task(self, name: str): task = Task.from_dict(self.ctx._tasks[name]) # update task results for output in task.outputs: - output.value = get_nested_dict( + output.socket_value = get_nested_dict( self.ctx._tasks[name]["results"], output.socket_name, - default=output.value, + default=output.socket_value, ) return task @@ -588,21 +588,20 @@ def get_inputs( var_args = None var_kwargs = None task = self.ctx._tasks[name] - properties = task.get("properties", {}) inputs = {} + for name, prop in task.get("properties", {}).items(): + inputs[name] = self.ctx_manager.update_context_variable(prop["value"]) for name, input in task["inputs"].items(): # print(f"input: {input['name']}") - if len(input["links"]) == 0: - if input["identifier"] == "workgraph.namespace": - inputs[name] = self.ctx_manager.update_context_variable( - input["value"] - ) - else: - inputs[name] = self.ctx_manager.update_context_variable( - input["property"]["value"] - ) - elif len(input["links"]) == 1: - link = input["links"][0] + if input["identifier"] == "workgraph.namespace": + inputs[name] = self.ctx_manager.update_context_variable(input["value"]) + else: + inputs[name] = self.ctx_manager.update_context_variable( + input["property"]["value"] + ) + for name, links in task["input_links"].items(): + if len(links) == 1: + link = links[0] if self.ctx._tasks[link["from_node"]]["results"] is None: inputs[name] = None else: @@ -617,9 +616,9 @@ def get_inputs( link["from_socket"], ) # handle the case of multiple outputs - elif len(input["links"]) > 1: + elif len(links) > 1: value = {} - for link in input["links"]: + for link in links: item_name = f'{link["from_node"]}_{link["from_socket"]}' # handle the special socket _wait, _outputs if link["from_socket"] == "_wait": @@ -631,42 +630,18 @@ def get_inputs( "results" ][link["from_socket"]] inputs[name] = value - for name in task.get("args", []): - if name in inputs: - args.append(inputs[name]) - args_dict[name] = inputs[name] - else: - value = self.ctx_manager.update_context_variable( - properties[name]["value"] - ) - args.append(value) - args_dict[name] = value - for name in task.get("kwargs", []): - if name in inputs: - kwargs[name] = inputs[name] - else: - value = self.ctx_manager.update_context_variable( - properties[name]["value"] - ) - kwargs[name] = value - if task["var_args"] is not None: - name = task["var_args"] - if name in inputs: - var_args = inputs[name] - else: - value = self.ctx_manager.update_context_variable( - properties[name]["value"] - ) - var_args = value - if task["var_kwargs"] is not None: - name = task["var_kwargs"] - if name in inputs: - var_kwargs = inputs[name] - else: - value = self.ctx_manager.update_context_variable( - properties[name]["value"] - ) - var_kwargs = value + for name, input in inputs.items(): + # only need to check the top level key + key = name.split(".")[0] + if key in task["args"]: + args.append(input) + args_dict[name] = input + elif key in task["kwargs"]: + kwargs[name] = input + elif key == task["var_args"]: + var_args = input + elif key == task["var_kwargs"]: + var_kwargs = input return args, kwargs, var_args, var_kwargs, args_dict def update_task_state(self, name: str, success=True) -> None: diff --git a/src/aiida_workgraph/tasks/pythonjob.py b/src/aiida_workgraph/tasks/pythonjob.py index b86a5106..16328611 100644 --- a/src/aiida_workgraph/tasks/pythonjob.py +++ b/src/aiida_workgraph/tasks/pythonjob.py @@ -20,7 +20,8 @@ def serialize_pythonjob_data(cls, tdata: Dict[str, Any]): for input in tdata["inputs"].values(): if input["metadata"].get("is_function_input", False): - input["property"]["value"] = cls.serialize_socket_data(input) + if input.get("property", {}).get("value") is not None: + input["property"]["value"] = cls.serialize_socket_data(input) @classmethod def deserialize_pythonjob_data(cls, tdata: Dict[str, Any]) -> None: @@ -43,7 +44,7 @@ def deserialize_pythonjob_data(cls, tdata: Dict[str, Any]) -> None: @classmethod def serialize_socket_data(cls, data: Dict[str, Any]) -> Any: if data.get("identifier", "Any").upper() == "WORKGRAPH.NAMESPACE": - if data["property"]["value"] is None: + if data.get("property", {}).get("value") is None: return None if isinstance(data["property"]["value"], dict): serialized_result = {} @@ -53,14 +54,14 @@ def serialize_socket_data(cls, data: Dict[str, Any]) -> Any: else: raise ValueError("Namespace socket should be a dictionary.") else: - if isinstance(data["property"]["value"], orm.Data): + if isinstance(data.get("property", {}).get("value"), orm.Data): return data["property"]["value"] return general_serializer(data["property"]["value"]) @classmethod def deserialize_socket_data(cls, data: Dict[str, Any]) -> Any: if data.get("identifier", "Any").upper() == "WORKGRAPH.NAMESPACE": - if isinstance(data["property"]["value"], dict): + if isinstance(data.get("property", {}).get("value"), dict): deserialized_result = {} for key, value in data["property"]["value"].items(): if isinstance(value, orm.Data): @@ -71,6 +72,6 @@ def deserialize_socket_data(cls, data: Dict[str, Any]) -> Any: else: raise ValueError("Namespace socket should be a dictionary.") else: - if isinstance(data["property"]["value"], orm.Data): + if isinstance(data.get("property", {}).get("value"), orm.Data): return data["property"]["value"].value return data["property"]["value"] diff --git a/src/aiida_workgraph/utils/analysis.py b/src/aiida_workgraph/utils/analysis.py index ef7112ba..43a446eb 100644 --- a/src/aiida_workgraph/utils/analysis.py +++ b/src/aiida_workgraph/utils/analysis.py @@ -82,29 +82,14 @@ def build_task_link(self) -> None: 1) workgraph links """ - # reset task input links - for name, task in self.wgdata["tasks"].items(): - for _, input in task["inputs"].items(): - input["links"] = [] - for _, output in task["outputs"].items(): - output["links"] = [] + # create a `input_links` to store the input links for each task + for task in self.wgdata["tasks"].values(): + task["input_links"] = {} for link in self.wgdata["links"]: - to_socket = [ - socket - for name, socket in self.wgdata["tasks"][link["to_node"]][ - "inputs" - ].items() - if name == link["to_socket"] - ][0] - from_socket = [ - socket - for name, socket in self.wgdata["tasks"][link["from_node"]][ - "outputs" - ].items() - if name == link["from_socket"] - ][0] - to_socket["links"].append(link) - from_socket["links"].append(link) + task = self.wgdata["tasks"][link["to_node"]] + if link["to_socket"] not in task["input_links"]: + task["input_links"][link["to_socket"]] = [] + task["input_links"][link["to_socket"]].append(link) def assign_zone(self) -> None: """Assign zone for each task.""" @@ -139,8 +124,8 @@ def find_zone_inputs(self, name: str) -> None: """Find the input and outputs tasks for the zone.""" task = self.wgdata["tasks"][name] input_tasks = [] - for _, input in self.wgdata["tasks"][name]["inputs"].items(): - for link in input["links"]: + for _, links in self.wgdata["tasks"][name]["input_links"].items(): + for link in links: input_tasks.append(link["from_node"]) # find all the input tasks for child_task in task["children"]: @@ -157,8 +142,8 @@ def find_zone_inputs(self, name: str) -> None: else: # if the child task is not a zone, get the input tasks of the child task # find all the input tasks which outside the while zone - for _, input in self.wgdata["tasks"][child_task]["inputs"].items(): - for link in input["links"]: + for _, links in self.wgdata["tasks"][child_task]["input_links"].items(): + for link in links: input_tasks.append(link["from_node"]) # find the input tasks which are not in the zone new_input_tasks = [] diff --git a/src/aiida_workgraph/workgraph.py b/src/aiida_workgraph/workgraph.py index cf379d41..ebd53733 100644 --- a/src/aiida_workgraph/workgraph.py +++ b/src/aiida_workgraph/workgraph.py @@ -96,7 +96,7 @@ def run( # set task inputs if inputs is not None: for name, input in inputs.items(): - if name not in self.tasks.keys(): + if name not in self.tasks: raise KeyError(f"Task {name} not found in WorkGraph.") self.tasks[name].set(input) # One can not run again if the process is alreay created. otherwise, a new process node will @@ -128,7 +128,7 @@ def submit( # set task inputs if inputs is not None: for name, input in inputs.items(): - if name not in self.tasks.keys(): + if name not in self.tasks: raise KeyError(f"Task {name} not found in WorkGraph.") self.tasks[name].set(input) @@ -300,7 +300,7 @@ def update(self) -> None: ) # read results from the process outputs elif isinstance(node, aiida.orm.Data): - self.tasks[name].outputs[0].value = node + self.tasks[name].outputs[0].socket_value = node execution_count = getattr(self.process.outputs, "execution_count", None) self.execution_count = execution_count if execution_count else 0 if self._widget is not None: diff --git a/tests/test_action.py b/tests/test_action.py index 82ad8954..0a07c850 100644 --- a/tests/test_action.py +++ b/tests/test_action.py @@ -36,7 +36,7 @@ def test_pause_play_task(wg_calcjob): # Seems the daemon is not responding to the play signal wg.play_tasks(["add2"]) wg.wait(interval=5) - assert wg.tasks["add2"].outputs["sum"].value == 9 + assert wg.tasks["add2"].outputs["sum"].socket_value == 9 def test_pause_play_error_handler(wg_calcjob, finished_process_node): diff --git a/tests/test_awaitable_task.py b/tests/test_awaitable_task.py index 3bb552a7..7827fc99 100644 --- a/tests/test_awaitable_task.py +++ b/tests/test_awaitable_task.py @@ -21,7 +21,7 @@ async def awaitable_func(x, y): wg.run() report = get_workchain_report(wg.process, "REPORT") assert "Waiting for child processes: awaitable_func1" in report - assert add1.outputs["result"].value == 4 + assert add1.outputs["result"].socket_value == 4 def test_monitor_decorator(): @@ -60,7 +60,7 @@ def test_time_monitor(decorated_add): wg.run() report = get_workchain_report(wg.process, "REPORT") assert "Waiting for child processes: monitor1" in report - assert add1.outputs["result"].value == 3 + assert add1.outputs["result"].socket_value == 3 def test_file_monitor(decorated_add, tmp_path): @@ -84,7 +84,7 @@ async def create_test_file(filepath="/tmp/test_file_monitor.txt", t=2): wg.run() report = get_workchain_report(wg.process, "REPORT") assert "Waiting for child processes: monitor1" in report - assert add1.outputs["result"].value == 3 + assert add1.outputs["result"].socket_value == 3 @pytest.mark.usefixtures("started_daemon_client") diff --git a/tests/test_calcfunction.py b/tests/test_calcfunction.py index d6c86d00..ab93018f 100644 --- a/tests/test_calcfunction.py +++ b/tests/test_calcfunction.py @@ -11,7 +11,7 @@ def test_run(wg_calcfunction: WorkGraph) -> None: print("state: ", wg.state) # print("results: ", results[]) assert wg.tasks["sumdiff2"].node.outputs.sum == 9 - assert wg.tasks["sumdiff2"].outputs["sum"].value == 9 + assert wg.tasks["sumdiff2"].outputs["sum"].socket_value == 9 @pytest.mark.usefixtures("started_daemon_client") @@ -27,4 +27,4 @@ def add(**kwargs): wg = WorkGraph("test_dynamic_inputs") wg.add_task(add, name="add1", x=orm.Int(1), y=orm.Int(2)) wg.run() - assert wg.tasks["add1"].outputs["result"].value == 3 + assert wg.tasks["add1"].outputs["result"].socket_value == 3 diff --git a/tests/test_calcjob.py b/tests/test_calcjob.py index 74768f5f..d22c40b7 100644 --- a/tests/test_calcjob.py +++ b/tests/test_calcjob.py @@ -8,4 +8,4 @@ def test_submit(wg_calcjob: WorkGraph) -> None: wg = wg_calcjob wg.name = "test_submit_calcjob" wg.submit(wait=True) - assert wg.tasks["add2"].outputs["sum"].value == 9 + assert wg.tasks["add2"].outputs["sum"].socket_value == 9 diff --git a/tests/test_ctx.py b/tests/test_ctx.py index bffbc872..3cdd1cbf 100644 --- a/tests/test_ctx.py +++ b/tests/test_ctx.py @@ -23,7 +23,7 @@ def test_workgraph_ctx(decorated_add: Callable) -> None: get_ctx1.waiting_on.add(add1) add2 = wg.add_task(decorated_add, "add2", x=get_ctx1.outputs["result"], y=1) wg.run() - assert add2.outputs["result"].value == 6 + assert add2.outputs["result"].socket_value == 6 @pytest.mark.usefixtures("started_daemon_client") @@ -40,4 +40,4 @@ def test_task_set_ctx(decorated_add: Callable) -> None: add2 = wg.add_task(decorated_add, "add2", y="{{ sum }}") wg.add_link(add1.outputs[0], add2.inputs["x"]) wg.submit(wait=True) - assert add2.outputs["result"].value == 10 + assert add2.outputs["result"].socket_value == 10 diff --git a/tests/test_data_task.py b/tests/test_data_task.py index e2939d65..8668b1e8 100644 --- a/tests/test_data_task.py +++ b/tests/test_data_task.py @@ -16,7 +16,7 @@ def test_data_task(identifier, data) -> None: wg = WorkGraph("test_normal_task") task1 = wg.add_task(identifier, name="task1", value=data) wg.run() - assert task1.outputs["result"].value == data + assert task1.outputs["result"].socket_value == data def test_data_dict_task(): @@ -25,7 +25,7 @@ def test_data_dict_task(): wg = WorkGraph("test_data_dict_task") task1 = wg.add_task("workgraph.aiida_dict", name="task1", value={"a": 1}) wg.run() - assert task1.outputs["result"].value == {"a": 1} + assert task1.outputs["result"].socket_value == {"a": 1} def test_data_list_task(): @@ -34,4 +34,4 @@ def test_data_list_task(): wg = WorkGraph("test_data_list_task") task1 = wg.add_task("workgraph.aiida_list", name="task1", value=[1, 2, 3]) wg.run() - assert task1.outputs["result"].value == [1, 2, 3] + assert task1.outputs["result"].socket_value == [1, 2, 3] diff --git a/tests/test_decorator.py b/tests/test_decorator.py index 5bbb995a..a84f4e32 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -51,7 +51,7 @@ def test(a, b=1, **c): assert tdata["var_kwargs"] == "c" assert set(n.get_output_names()) == set(["result", "_outputs", "_wait"]) assert isinstance(n.inputs.c, TaskSocketNamespace) - assert set(n.inputs.metadata._keys()) == metadata_kwargs + assert set(n.inputs.metadata._get_keys()) == metadata_kwargs @pytest.fixture(params=["decorator_factory", "decorator"]) @@ -121,7 +121,7 @@ def test_decorators_workfunction_args(task_workfunction) -> None: assert tdata["var_args"] is None assert tdata["var_kwargs"] == "c" assert set(n.get_output_names()) == set(["result", "_outputs", "_wait"]) - assert set(n.inputs.metadata._keys()) == metadata_kwargs + assert set(n.inputs.metadata._get_keys()) == metadata_kwargs def test_decorators_parameters() -> None: @@ -181,7 +181,7 @@ def test_inputs_outputs_workchain() -> None: wg = WorkGraph() task = wg.add_task(MultiplyAddWorkChain) assert "metadata" in task.get_input_names() - assert "call_link_label" in task.inputs.metadata._keys() + assert "call_link_label" in task.inputs.metadata._get_keys() assert "result" in task.get_output_names() @@ -192,7 +192,7 @@ def test_decorator_calcfunction(decorated_add: Callable) -> None: wg = WorkGraph(name="test_decorator_calcfunction") wg.add_task(decorated_add, "add1", x=2, y=3) wg.submit(wait=True, timeout=100) - assert wg.tasks["add1"].outputs["result"].value == 5 + assert wg.tasks["add1"].outputs["result"].socket_value == 5 def test_decorator_workfunction(decorated_add_multiply: Callable) -> None: @@ -201,7 +201,7 @@ def test_decorator_workfunction(decorated_add_multiply: Callable) -> None: wg = WorkGraph(name="test_decorator_workfunction") wg.add_task(decorated_add_multiply, "add_multiply1", x=2, y=3, z=4) wg.submit(wait=True, timeout=100) - assert wg.tasks["add_multiply1"].outputs["result"].value == 20 + assert wg.tasks["add_multiply1"].outputs["result"].socket_value == 20 @pytest.mark.usefixtures("started_daemon_client") @@ -216,5 +216,5 @@ def test_decorator_graph_builder(decorated_add_multiply_group: Callable) -> None # use run to check if graph builder workgraph can be submit inside the engine wg.run() assert wg.tasks["add_multiply1"].process.outputs.result.value == 32 - assert wg.tasks["add_multiply1"].outputs["result"].value == 32 - assert wg.tasks["sum_diff1"].outputs["sum"].value == 32 + assert wg.tasks["add_multiply1"].outputs["result"].socket_value == 32 + assert wg.tasks["sum_diff1"].outputs["sum"].socket_value == 32 diff --git a/tests/test_engine.py b/tests/test_engine.py index 04c4fea5..72cb21be 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -50,4 +50,4 @@ def test_max_number_jobs(add_code) -> None: wg.submit(wait=True, timeout=40) report = get_workchain_report(wg.process, "REPORT") assert "tasks ready to run: add2" in report - wg.tasks["add2"].outputs["sum"].value == 2 + wg.tasks["add2"].outputs["sum"].socket_value == 2 diff --git a/tests/test_error_handler.py b/tests/test_error_handler.py index d1eff81f..5b3982c3 100644 --- a/tests/test_error_handler.py +++ b/tests/test_error_handler.py @@ -1,10 +1,8 @@ -import pytest from aiida_workgraph import WorkGraph, Task from aiida import orm from aiida.calculations.arithmetic.add import ArithmeticAddCalculation -@pytest.mark.usefixtures("started_daemon_client") def test_error_handlers(add_code): """Test error handlers.""" from aiida.cmdline.utils.common import get_workchain_report @@ -16,8 +14,8 @@ def handle_negative_sum(task: Task): # modify task inputs task.set( { - "x": orm.Int(abs(task.inputs["x"].value)), - "y": orm.Int(abs(task.inputs["y"].value)), + "x": orm.Int(abs(task.inputs["x"].socket_value)), + "y": orm.Int(abs(task.inputs["y"].socket_value)), } ) msg = "Run error handler: handle_negative_sum." @@ -39,12 +37,11 @@ def handle_negative_sum(task: Task): }, ) assert len(wg.error_handlers) == 1 - wg.submit( + wg.run( inputs={ "add1": {"code": add_code, "x": orm.Int(1), "y": orm.Int(-2)}, }, - wait=True, ) report = get_workchain_report(wg.process, "REPORT") assert "Run error handler: handle_negative_sum." in report - assert wg.tasks["add1"].outputs["sum"].value == 3 + assert wg.tasks["add1"].outputs["sum"].socket_value == 3 diff --git a/tests/test_if.py b/tests/test_if.py index 458810bd..8e9fbf56 100644 --- a/tests/test_if.py +++ b/tests/test_if.py @@ -30,7 +30,7 @@ def test_if_task(decorated_add, decorated_multiply, decorated_compare): ) add3 = wg.add_task(decorated_add, name="add3", x=select1.outputs["result"], y=1) wg.run() - assert add3.outputs["result"].value == 5 + assert add3.outputs["result"].socket_value == 5 def test_empty_if_task(): diff --git a/tests/test_pythonjob.py b/tests/test_pythonjob.py index 8fe7369b..5928eff3 100644 --- a/tests/test_pythonjob.py +++ b/tests/test_pythonjob.py @@ -40,9 +40,9 @@ def multiply(x: Any, y: Any) -> Any: ) # wg.submit(wait=True) wg.run() - assert wg.tasks["add1"].outputs["sum"].value.value == 3 - assert wg.tasks["add1"].outputs["diff"].value.value == -1 - assert wg.tasks["multiply1"].outputs["result"].value.value == 9 + assert wg.tasks["add1"].outputs["sum"].socket_value.value == 3 + assert wg.tasks["add1"].outputs["diff"].socket_value.value == -1 + assert wg.tasks["multiply1"].outputs["result"].socket_value.value == 9 # process_label and label assert wg.tasks["add1"].node.process_label == "PythonJob" assert wg.tasks["add1"].node.label == "add1" @@ -72,10 +72,10 @@ def add(x, y=1, **kwargs): ) # data inside the kwargs should be serialized separately wg.process.inputs.wg.tasks.add.inputs.kwargs.socket_property.value.m.value == 2 - assert wg.tasks["add"].outputs["result"].value.value == 8 + assert wg.tasks["add"].outputs["result"].socket_value.value == 8 # load the workgraph wg = WorkGraph.load(wg.pk) - assert wg.tasks["add"].inputs["kwargs"].value == {"m": 2, "n": 3} + assert wg.tasks["add"].inputs["kwargs"].socket_value == {"m": 2, "n": 3} def test_PythonJob_namespace_output_input(fixture_localhost, python_executable_path): @@ -141,10 +141,10 @@ def myfunc3(x, y): }, } wg.run(inputs=inputs) - assert wg.tasks["myfunc"].outputs["add_multiply"].value.add.value == 3 - assert wg.tasks["myfunc"].outputs["add_multiply"].value.multiply.value == 2 - assert wg.tasks["myfunc2"].outputs["result"].value.value == 8 - assert wg.tasks["myfunc3"].outputs["result"].value.value == 7 + assert wg.tasks.myfunc.outputs.add_multiply.add.socket_value == 3 + assert wg.tasks.myfunc.outputs.add_multiply.multiply.socket_value == 2 + assert wg.tasks.myfunc2.outputs.result.socket_value == 8 + assert wg.tasks.myfunc3.outputs.result.socket_value == 7 def test_PythonJob_copy_files(fixture_localhost, python_executable_path): @@ -184,7 +184,7 @@ def multiply(x_folder_name, y_folder_name): wg.tasks["multiply"].inputs["copy_files"], ) # ------------------------- Submit the calculation ------------------- - wg.submit( + wg.run( inputs={ "add1": { "x": 2, @@ -205,9 +205,8 @@ def multiply(x_folder_name, y_folder_name): "command_info": {"label": python_executable_path}, }, }, - wait=True, ) - assert wg.tasks["multiply"].outputs["result"].value.value == 25 + assert wg.tasks["multiply"].outputs["result"].socket_value.value == 25 def test_load_pythonjob(fixture_localhost, python_executable_path): @@ -231,10 +230,10 @@ def add(x: str, y: str) -> str: }, # wait=True, ) - assert wg.tasks["add"].outputs["result"].value.value == "Hello, World!" + assert wg.tasks["add"].outputs["result"].socket_value.value == "Hello, World!" wg = WorkGraph.load(wg.pk) - wg.tasks["add"].inputs["x"].value = "Hello, " - wg.tasks["add"].inputs["y"].value = "World!" + wg.tasks["add"].inputs["x"].socket_value = "Hello, " + wg.tasks["add"].inputs["y"].socket_value = "World!" def test_exit_code(fixture_localhost, python_executable_path): @@ -246,7 +245,12 @@ def handle_negative_sum(task: Task): Simply make the inputs positive by taking the absolute value. """ - task.set({"x": abs(task.inputs["x"].value), "y": abs(task.inputs["y"].value)}) + task.set( + { + "x": abs(task.inputs["x"].socket_value), + "y": abs(task.inputs["y"].socket_value), + } + ) return "Run error handler: handle_negative_sum." @@ -281,4 +285,4 @@ def add(x: array, y: array) -> array: ) # the final task should have exit status 0 assert wg.tasks["add1"].node.exit_status == 0 - assert (wg.tasks["add1"].outputs["sum"].value.value == array([2, 3])).all() + assert (wg.tasks["add1"].outputs["sum"].socket_value.value == array([2, 3])).all() diff --git a/tests/test_shell.py b/tests/test_shell.py index cdd7e778..4192a5fe 100644 --- a/tests/test_shell.py +++ b/tests/test_shell.py @@ -72,11 +72,11 @@ def test_dynamic_port(): ) wg.add_link(echo_task.outputs["copied_file"], cat_task.inputs["nodes.input1"]) # task will create input for each item in the dynamic port (nodes) - assert "nodes.input1" in cat_task.get_input_names() - assert "nodes.input2" in cat_task.get_input_names() + assert "nodes.input1" in cat_task.inputs + assert "nodes.input2" in cat_task.inputs # if the value of the item is a Socket, then it will create a link, and pop the item - assert "nodes.input3" in cat_task.get_input_names() - assert cat_task.inputs["nodes"].value == {"input1": None, "input2": Int(2)} + assert "nodes.input3" in cat_task.inputs + assert cat_task.inputs["nodes"].socket_value == {"input2": Int(2)} @pytest.mark.usefixtures("started_daemon_client") @@ -123,4 +123,4 @@ def parser(dirpath): wg = WorkGraph(name="test_shell_graph_builder") add_multiply1 = wg.add_task(add_multiply, x=Int(2), y=Int(3)) wg.submit(wait=True) - assert add_multiply1.outputs["result"].value.value == 5 + assert add_multiply1.outputs["result"].socket_value.value == 5 diff --git a/tests/test_socket.py b/tests/test_socket.py index d96ed634..07439e13 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -118,7 +118,6 @@ def test_numpy_array(decorated_normal_add): wg.submit(wait=True) # wg.run() assert wg.state.upper() == "FINISHED" - # assert (wg.tasks["add1"].outputs["result"].value == np.array([5, 7, 9])).all() def test_kwargs() -> None: diff --git a/tests/test_task_from_workgraph.py b/tests/test_task_from_workgraph.py index d7afbdf4..e5e8f773 100644 --- a/tests/test_task_from_workgraph.py +++ b/tests/test_task_from_workgraph.py @@ -14,10 +14,10 @@ def test_inputs_outptus(wg_calcfunction: WorkGraph) -> None: noutput = 0 for sub_task in wg_calcfunction.tasks: noutput += len(sub_task.outputs) - 2 + 1 - assert len(task1.inputs) == ninput + 1 - assert len(task1.outputs) == noutput + 2 - assert "sumdiff1.x" in task1.get_input_names() - assert "sumdiff1.sum" in task1.get_output_names() + assert len(task1.inputs) == len(wg_calcfunction.tasks) + 1 + assert len(task1.outputs) == len(wg_calcfunction.tasks) + 2 + assert "sumdiff1.x" in task1.inputs + assert "sumdiff1.sum" in task1.outputs @pytest.mark.usefixtures("started_daemon_client") @@ -33,7 +33,7 @@ def test_build_task_from_workgraph(decorated_add: Callable) -> None: add1_task = wg.add_task(decorated_add, name="add1", x=1, y=3) wg_task = wg.add_task(sub_wg, name="sub_wg") # the default value of the namespace is None - assert wg_task.inputs["add1"].value is None + assert wg_task.inputs["add1"].socket_value == {"metadata": {}} wg.add_task(decorated_add, name="add2", y=3) wg.add_link(add1_task.outputs["result"], wg_task.inputs["add1.x"]) wg.add_link(wg_task.outputs["add2.result"], wg.tasks["add2"].inputs["x"]) @@ -41,4 +41,4 @@ def test_build_task_from_workgraph(decorated_add: Callable) -> None: assert len(wg_task.outputs) == 6 wg.submit(wait=True) # wg.run() - assert wg.tasks["add2"].outputs["result"].value.value == 12 + assert wg.tasks["add2"].outputs["result"].socket_value.value == 12 diff --git a/tests/test_while.py b/tests/test_while.py index 6c677e85..841cfc9b 100644 --- a/tests/test_while.py +++ b/tests/test_while.py @@ -85,7 +85,7 @@ def raw_python_code(): for link in wg.process.base.links.get_outgoing().all(): if isinstance(link.node, orm.ProcessNode): print(link.node.label, link.node.outputs.result) - assert add2.outputs["result"].value.value == raw_python_code().value + assert add2.outputs["result"].socket_value.value == raw_python_code().value @pytest.mark.usefixtures("started_daemon_client") @@ -105,7 +105,7 @@ def test_while_workgraph(decorated_add, decorated_multiply, decorated_compare): wg.add_link(multiply1.outputs["result"], add1.inputs["x"]) wg.submit(wait=True, timeout=100) assert wg.execution_count == 4 - assert wg.tasks["add1"].outputs["result"].value == 29 + assert wg.tasks["add1"].outputs["result"].socket_value == 29 @pytest.mark.usefixtures("started_daemon_client") @@ -137,5 +137,5 @@ def my_while(n=0, limit=100): wg.add_link(add1.outputs["result"], my_while1.inputs["limit"]) wg.add_link(my_while1.outputs["result"], add2.inputs["x"]) wg.submit(wait=True, timeout=100) - assert add2.outputs["result"].value < 31 + assert add2.outputs["result"].socket_value < 31 assert my_while1.node.outputs.execution_count == 2