From fbd90ac9f67e021d8b6e87d29d10e21c9d70d01e Mon Sep 17 00:00:00 2001 From: superstar54 Date: Fri, 13 Dec 2024 06:39:06 +0100 Subject: [PATCH] To provide a better user experience, we raise an exception explicitly when the timeout is exceeded in the wait method. --- src/aiida_workgraph/workgraph.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/src/aiida_workgraph/workgraph.py b/src/aiida_workgraph/workgraph.py index 1181ed17..dba249e0 100644 --- a/src/aiida_workgraph/workgraph.py +++ b/src/aiida_workgraph/workgraph.py @@ -110,14 +110,14 @@ def submit( self, inputs: Optional[Dict[str, Any]] = None, wait: bool = False, - timeout: int = 60, + timeout: int = 600, interval: int = 5, metadata: Optional[Dict[str, Any]] = None, ) -> aiida.orm.ProcessNode: """Submit the AiiDA workgraph process and optionally wait for it to finish. Args: wait (bool): Wait for the process to finish. - timeout (int): The maximum time in seconds to wait for the process to finish. Defaults to 60. + timeout (int): The maximum time in seconds to wait for the process to finish. Defaults to 600. restart (bool): Restart the process, and reset the modified tasks, then only re-run the modified tasks. new (bool): Submit a new process. """ @@ -228,11 +228,17 @@ def get_error_handlers(self) -> Dict[str, Any]: task["exit_codes"] = exit_codes return error_handlers - def wait(self, timeout: int = 50, tasks: dict = None, interval: int = 5) -> None: + def wait(self, timeout: int = 600, tasks: dict = None, interval: int = 5) -> None: """ Periodically checks and waits for the AiiDA workgraph process to finish until a given timeout. + Args: - timeout (int): The maximum time in seconds to wait for the process to finish. Defaults to 50. + timeout (int): The maximum time in seconds to wait for the process to finish. Defaults to 600. + tasks (dict): Optional; specifies task states to wait for in the format {task_name: [acceptable_states]}. + interval (int): The time interval in seconds between checks. Defaults to 5. + + Raises: + TimeoutError: If the process does not finish within the given timeout. """ terminating_states = ( "KILLED", @@ -245,8 +251,10 @@ def wait(self, timeout: int = 50, tasks: dict = None, interval: int = 5) -> None start = time.time() self.update() finished = False + while not finished: self.update() + if tasks is not None: states = [] for name, value in tasks.items(): @@ -255,9 +263,17 @@ def wait(self, timeout: int = 50, tasks: dict = None, interval: int = 5) -> None finished = all(states) else: finished = self.state in terminating_states + + if finished: + print(f"Process {self.process.pk} finished with state: {self.state}") + return + time.sleep(interval) + if time.time() - start > timeout: - break + raise TimeoutError( + f"Timeout reached after {timeout} seconds while waiting for the WorkGraph: {self.process.pk}. " + ) def update(self) -> None: """