Skip to content

Commit

Permalink
To provide a better user experience, we raise an exception explicitly…
Browse files Browse the repository at this point in the history
… when the timeout is exceeded in the wait method.
  • Loading branch information
superstar54 committed Dec 13, 2024
1 parent 575106e commit fbd90ac
Showing 1 changed file with 21 additions and 5 deletions.
26 changes: 21 additions & 5 deletions src/aiida_workgraph/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,14 @@ def submit(
self,
inputs: Optional[Dict[str, Any]] = None,
wait: bool = False,
timeout: int = 60,
timeout: int = 600,
interval: int = 5,
metadata: Optional[Dict[str, Any]] = None,
) -> aiida.orm.ProcessNode:
"""Submit the AiiDA workgraph process and optionally wait for it to finish.
Args:
wait (bool): Wait for the process to finish.
timeout (int): The maximum time in seconds to wait for the process to finish. Defaults to 60.
timeout (int): The maximum time in seconds to wait for the process to finish. Defaults to 600.
restart (bool): Restart the process, and reset the modified tasks, then only re-run the modified tasks.
new (bool): Submit a new process.
"""
Expand Down Expand Up @@ -228,11 +228,17 @@ def get_error_handlers(self) -> Dict[str, Any]:
task["exit_codes"] = exit_codes
return error_handlers

def wait(self, timeout: int = 50, tasks: dict = None, interval: int = 5) -> None:
def wait(self, timeout: int = 600, tasks: dict = None, interval: int = 5) -> None:
"""
Periodically checks and waits for the AiiDA workgraph process to finish until a given timeout.
Args:
timeout (int): The maximum time in seconds to wait for the process to finish. Defaults to 50.
timeout (int): The maximum time in seconds to wait for the process to finish. Defaults to 600.
tasks (dict): Optional; specifies task states to wait for in the format {task_name: [acceptable_states]}.
interval (int): The time interval in seconds between checks. Defaults to 5.
Raises:
TimeoutError: If the process does not finish within the given timeout.
"""
terminating_states = (
"KILLED",
Expand All @@ -245,8 +251,10 @@ def wait(self, timeout: int = 50, tasks: dict = None, interval: int = 5) -> None
start = time.time()
self.update()
finished = False

while not finished:
self.update()

if tasks is not None:
states = []
for name, value in tasks.items():
Expand All @@ -255,9 +263,17 @@ def wait(self, timeout: int = 50, tasks: dict = None, interval: int = 5) -> None
finished = all(states)
else:
finished = self.state in terminating_states

if finished:
print(f"Process {self.process.pk} finished with state: {self.state}")
return

time.sleep(interval)

if time.time() - start > timeout:
break
raise TimeoutError(
f"Timeout reached after {timeout} seconds while waiting for the WorkGraph: {self.process.pk}. "
)

def update(self) -> None:
"""
Expand Down

0 comments on commit fbd90ac

Please sign in to comment.