diff --git a/cylc/flow/scripts/workflow_state.py b/cylc/flow/scripts/workflow_state.py index a10a4efd40e..427f992f32d 100755 --- a/cylc/flow/scripts/workflow_state.py +++ b/cylc/flow/scripts/workflow_state.py @@ -84,7 +84,7 @@ import asyncio import sqlite3 import sys -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, List, Optional from cylc.flow.pathutil import get_cylc_run_dir from cylc.flow.id import Tokens @@ -117,9 +117,16 @@ class WorkflowPoller(Poller): """An object that polls for task states or outputs in a workflow DB.""" def __init__( - self, id_, offset, flow_num, alt_cylc_run_dir, default_status, - is_output, is_message, old_format, - *args, **kwargs + self, + id_: str, + offset: Optional[str], + flow_num: Optional[int], + alt_cylc_run_dir: Optional[str], + default_status: Optional[str], + is_output: bool, + is_message: bool, + old_format: bool, + **kwargs ): self.id_ = id_ self.offset = offset @@ -133,16 +140,15 @@ def __init__( self.cycle_raw = tokens["cycle"] self.task = tokens["task"] - self.workflow_id = None - self.cycle = None - self.results = None - self.db_checker = None + self.workflow_id: Optional[str] = None + self.cycle: Optional[str] = None + self.result: Optional[List[List[str]]] = None + self._db_checker: Optional[CylcWorkflowDBChecker] = None + self.is_message = is_message if is_message: - self.is_message = is_message self.is_output = False else: - self.is_message = False self.is_output = ( is_output or ( @@ -151,7 +157,7 @@ def __init__( ) ) - super().__init__(*args, **kwargs) + super().__init__(**kwargs) def _find_workflow(self) -> bool: """Find workflow and infer run directory, return True if found.""" @@ -169,18 +175,23 @@ def _find_workflow(self) -> bool: return True - def _db_connect(self) -> bool: - """Connect to workflow DB, return True if connected.""" - try: - self.db_checker = CylcWorkflowDBChecker( - get_cylc_run_dir(self.alt_cylc_run_dir), - self.workflow_id - ) - except (OSError, sqlite3.Error): - LOG.debug("DB not connected") - return False + @property + def db_checker(self) -> Optional[CylcWorkflowDBChecker]: + """Connect to workflow DB if not already connected. - return True + Returns DB checker if connected. + """ + if not self._db_checker: + try: + self._db_checker = CylcWorkflowDBChecker( + get_cylc_run_dir(self.alt_cylc_run_dir), + self.workflow_id + ) + except (OSError, sqlite3.Error): + LOG.debug("DB not connected") + return None + + return self._db_checker async def check(self) -> bool: """Return True if requested state achieved, else False. @@ -193,7 +204,7 @@ async def check(self) -> bool: if self.workflow_id is None and not self._find_workflow(): return False - if self.db_checker is None and not self._db_connect(): + if self.db_checker is None: return False if self.cycle is None: @@ -275,13 +286,13 @@ def main(parser: COP, options: 'Values', *ids: str) -> None: options.offset, options.flow_num, options.alt_cylc_run_dir, - None, # default status - options.is_output, - options.is_message, - options.old_format, - f'"{id_}"', - options.interval, - options.max_polls, + default_status=None, + is_output=options.is_output, + is_message=options.is_message, + old_format=options.old_format, + condition=id_, + interval=options.interval, + max_polls=options.max_polls, args=None ) diff --git a/cylc/flow/xtriggers/workflow_state.py b/cylc/flow/xtriggers/workflow_state.py index b71c336c52d..4087f869596 100644 --- a/cylc/flow/xtriggers/workflow_state.py +++ b/cylc/flow/xtriggers/workflow_state.py @@ -28,10 +28,10 @@ def workflow_state( workflow_task_id: str, offset: Optional[str] = None, flow_num: Optional[int] = 1, - is_output: Optional[bool] = False, - is_message: Optional[bool] = False, + is_output: bool = False, + is_message: bool = False, alt_cylc_run_dir: Optional[str] = None, -) -> Tuple[bool, Dict[str, Optional[str]]]: +) -> Tuple[bool, Dict[str, Any]]: """Connect to a workflow DB and check a task status or output. If the status or output has been achieved, return {True, result}. @@ -64,10 +64,11 @@ def workflow_state( workflow_task_id, offset, flow_num, alt_cylc_run_dir, TASK_STATUS_SUCCEEDED, is_output, is_message, - f'"{id}"', - '10', # interval (irrelevant, for a single poll) - 1, # max polls (for xtriggers the scheduler does the polling) - [], {} + old_format=False, + condition=workflow_task_id, + max_polls=1, # (for xtriggers the scheduler does the polling) + interval=0, # irrelevant for 1 poll + args=[] ) if asyncio.run(poller.poll()): return (