diff --git a/changes.d/6472.feat.md b/changes.d/6472.feat.md new file mode 100644 index 00000000000..522ea3eec13 --- /dev/null +++ b/changes.d/6472.feat.md @@ -0,0 +1,5 @@ +`cylc remove` improvements: +- It can now remove tasks that are no longer active, making it look like they never ran. +- Removing a submitted/running task will kill it. +- Added the `--flow` option. +- Removed tasks are now demoted to `flow=none` but retained in the workflow database for provenance. diff --git a/cylc/flow/command_validation.py b/cylc/flow/command_validation.py index eb5e45c4734..bd570e2aa20 100644 --- a/cylc/flow/command_validation.py +++ b/cylc/flow/command_validation.py @@ -24,19 +24,34 @@ ) from cylc.flow.exceptions import InputError -from cylc.flow.id import IDTokens, Tokens +from cylc.flow.flow_mgr import ( + FLOW_ALL, + FLOW_NEW, + FLOW_NONE, +) +from cylc.flow.id import ( + IDTokens, + Tokens, +) from cylc.flow.task_outputs import TASK_OUTPUT_SUCCEEDED -from cylc.flow.flow_mgr import FLOW_ALL, FLOW_NEW, FLOW_NONE -ERR_OPT_FLOW_VAL = "Flow values must be an integer, or 'all', 'new', or 'none'" +ERR_OPT_FLOW_VAL = ( + f"Flow values must be integers, or '{FLOW_ALL}', '{FLOW_NEW}', " + f"or '{FLOW_NONE}'" +) +ERR_OPT_FLOW_VAL_2 = f"Flow values must be integers, or '{FLOW_ALL}'" ERR_OPT_FLOW_COMBINE = "Cannot combine --flow={0} with other flow values" ERR_OPT_FLOW_WAIT = ( f"--wait is not compatible with --flow={FLOW_NEW} or --flow={FLOW_NONE}" ) -def flow_opts(flows: List[str], flow_wait: bool) -> None: +def flow_opts( + flows: List[str], + flow_wait: bool, + allow_new_or_none: bool = True +) -> None: """Check validity of flow-related CLI options. Note the schema defaults flows to []. @@ -63,6 +78,10 @@ def flow_opts(flows: List[str], flow_wait: bool) -> None: cylc.flow.exceptions.InputError: --wait is not compatible with --flow=new or --flow=none + >>> flow_opts(["new"], False, allow_new_or_none=False) + Traceback (most recent call last): + cylc.flow.exceptions.InputError: ... must be integers, or 'all' + """ if not flows: return @@ -70,9 +89,12 @@ def flow_opts(flows: List[str], flow_wait: bool) -> None: flows = [val.strip() for val in flows] for val in flows: + val = val.strip() if val in {FLOW_NONE, FLOW_NEW, FLOW_ALL}: if len(flows) != 1: raise InputError(ERR_OPT_FLOW_COMBINE.format(val)) + if not allow_new_or_none and val in {FLOW_NEW, FLOW_NONE}: + raise InputError(ERR_OPT_FLOW_VAL_2) else: try: int(val) diff --git a/cylc/flow/commands.py b/cylc/flow/commands.py index 8ca1ecc7cd1..c45194752ab 100644 --- a/cylc/flow/commands.py +++ b/cylc/flow/commands.py @@ -53,18 +53,24 @@ """ from contextlib import suppress -from time import sleep, time +from time import ( + sleep, + time, +) from typing import ( + TYPE_CHECKING, AsyncGenerator, Callable, Dict, Iterable, List, Optional, - TYPE_CHECKING, + TypeVar, Union, ) +from metomi.isodatetime.parsers import TimePointParser + from cylc.flow import LOG import cylc.flow.command_validation as validate from cylc.flow.exceptions import ( @@ -73,6 +79,7 @@ CylcConfigError, ) import cylc.flow.flags +from cylc.flow.flow_mgr import get_flow_nums_set from cylc.flow.log_level import log_level_to_verbosity from cylc.flow.network.schema import WorkflowStopMode from cylc.flow.parsec.exceptions import ParsecError @@ -80,16 +87,14 @@ from cylc.flow.task_id import TaskID from cylc.flow.workflow_status import StopMode -from metomi.isodatetime.parsers import TimePointParser if TYPE_CHECKING: from cylc.flow.scheduler import Scheduler # define a type for command implementations - Command = Callable[ - ..., - AsyncGenerator, - ] + Command = Callable[..., AsyncGenerator] + # define a generic type needed for the @_command decorator + _TCommand = TypeVar('_TCommand', bound=Command) # a directory of registered commands (populated on module import) COMMANDS: 'Dict[str, Command]' = {} @@ -97,15 +102,15 @@ def _command(name: str): """Decorator to register a command.""" - def _command(fcn: 'Command'): + def _command(fcn: '_TCommand') -> '_TCommand': nonlocal name COMMANDS[name] = fcn - fcn.command_name = name # type: ignore + fcn.command_name = name # type: ignore[attr-defined] return fcn return _command -async def run_cmd(fcn, *args, **kwargs): +async def run_cmd(bound_fcn: AsyncGenerator): """Run a command outside of the scheduler's main loop. Normally commands are run via the Scheduler's command_queue (which is @@ -120,10 +125,9 @@ async def run_cmd(fcn, *args, **kwargs): For these purposes use "run_cmd", otherwise, queue commands via the scheduler as normal. """ - cmd = fcn(*args, **kwargs) - await cmd.__anext__() # validate + await bound_fcn.__anext__() # validate with suppress(StopAsyncIteration): - return await cmd.__anext__() # run + return await bound_fcn.__anext__() # run @_command('set') @@ -311,11 +315,15 @@ async def set_verbosity(schd: 'Scheduler', level: Union[int, str]): @_command('remove_tasks') -async def remove_tasks(schd: 'Scheduler', tasks: Iterable[str]): +async def remove_tasks( + schd: 'Scheduler', tasks: Iterable[str], flow: List[str] +): """Remove tasks.""" validate.is_tasks(tasks) + validate.flow_opts(flow, flow_wait=False, allow_new_or_none=False) yield - yield schd.pool.remove_tasks(tasks) + flow_nums = get_flow_nums_set(flow) + schd.remove_tasks(tasks, flow_nums) @_command('reload_workflow') diff --git a/cylc/flow/data_store_mgr.py b/cylc/flow/data_store_mgr.py index 0cc55580a8a..47c2a4c9efa 100644 --- a/cylc/flow/data_store_mgr.py +++ b/cylc/flow/data_store_mgr.py @@ -2357,22 +2357,42 @@ def delta_task_held( self.updates_pending = True def delta_task_flow_nums(self, itask: TaskProxy) -> None: - """Create delta for change in task proxy flow_nums. + """Create delta for change in task proxy flow numbers. Args: - itask (cylc.flow.task_proxy.TaskProxy): - Update task-node from corresponding task proxy - objects from the workflow task pool. - + itask: TaskProxy with updated flow numbers. """ tproxy: Optional[PbTaskProxy] tp_id, tproxy = self.store_node_fetcher(itask.tokens) if not tproxy: return - tp_delta = self.updated[TASK_PROXIES].setdefault( - tp_id, PbTaskProxy(id=tp_id)) + self._delta_task_flow_nums(tp_id, itask.flow_nums) + + def delta_remove_task_flow_nums( + self, task: str, removed: 'FlowNums' + ) -> None: + """Create delta for removal of flow numbers from a task proxy. + + Args: + task: Relative ID of task. + removed: Flow numbers to remove from the task proxy in the + data store. + """ + tproxy: Optional[PbTaskProxy] + tp_id, tproxy = self.store_node_fetcher( + Tokens(task, relative=True).duplicate(**self.id_) + ) + if not tproxy: + return + new_flow_nums = deserialise_set(tproxy.flow_nums).difference(removed) + self._delta_task_flow_nums(tp_id, new_flow_nums) + + def _delta_task_flow_nums(self, tp_id: str, flow_nums: 'FlowNums') -> None: + tp_delta: PbTaskProxy = self.updated[TASK_PROXIES].setdefault( + tp_id, PbTaskProxy(id=tp_id) + ) tp_delta.stamp = f'{tp_id}@{time()}' - tp_delta.flow_nums = serialise_set(itask.flow_nums) + tp_delta.flow_nums = serialise_set(flow_nums) self.updates_pending = True def delta_task_output( diff --git a/cylc/flow/dbstatecheck.py b/cylc/flow/dbstatecheck.py index fc2d9cf0da3..3fbad5c6723 100644 --- a/cylc/flow/dbstatecheck.py +++ b/cylc/flow/dbstatecheck.py @@ -28,7 +28,7 @@ IntegerPoint, IntegerInterval ) -from cylc.flow.flow_mgr import stringify_flow_nums +from cylc.flow.flow_mgr import repr_flow_nums from cylc.flow.pathutil import expand_path from cylc.flow.rundb import CylcWorkflowDAO from cylc.flow.task_outputs import ( @@ -318,7 +318,7 @@ def workflow_state_query( if flow_num is not None and flow_num not in flow_nums: # skip result, wrong flow continue - fstr = stringify_flow_nums(flow_nums) + fstr = repr_flow_nums(flow_nums) if fstr: res.append(fstr) db_res.append(res) diff --git a/cylc/flow/flow_mgr.py b/cylc/flow/flow_mgr.py index 1cd1c1e8c70..67f816982ec 100644 --- a/cylc/flow/flow_mgr.py +++ b/cylc/flow/flow_mgr.py @@ -16,8 +16,15 @@ """Manage flow counter and flow metadata.""" -from typing import Dict, Set, Optional, TYPE_CHECKING import datetime +from typing import ( + TYPE_CHECKING, + Dict, + Iterable, + List, + Optional, + Set, +) from cylc.flow import LOG @@ -55,36 +62,62 @@ def add_flow_opts(parser): ) -def stringify_flow_nums(flow_nums: Set[int], full: bool = False) -> str: - """Return a string representation of a set of flow numbers +def get_flow_nums_set(flow: List[str]) -> FlowNums: + """Return set of integer flow numbers from list of strings. - Return: - - "none" for no flow - - "" for the original flow (flows only matter if there are several) - - otherwise e.g. "(flow=1,2,3)" + Returns an empty set if the input is empty or contains only "all". + + >>> get_flow_nums_set(["1", "2", "3"]) + {1, 2, 3} + >>> get_flow_nums_set([]) + set() + >>> get_flow_nums_set(["all"]) + set() + """ + if flow == [FLOW_ALL]: + return set() + return {int(val.strip()) for val in flow} + + +def stringify_flow_nums(flow_nums: Iterable[int]) -> str: + """Return the canonical string for a set of flow numbers. Examples: + >>> stringify_flow_nums({1}) + '1' + + >>> stringify_flow_nums({3, 1, 2}) + '1,2,3' + >>> stringify_flow_nums({}) + '' + + """ + return ','.join(str(i) for i in sorted(flow_nums)) + + +def repr_flow_nums(flow_nums: FlowNums, full: bool = False) -> str: + """Return a representation of a set of flow numbers + + If `full` is False, return an empty string for flows=1. + + Examples: + >>> repr_flow_nums({}) '(flows=none)' - >>> stringify_flow_nums({1}) + >>> repr_flow_nums({1}) '' - >>> stringify_flow_nums({1}, True) + >>> repr_flow_nums({1}, full=True) '(flows=1)' - >>> stringify_flow_nums({1,2,3}) + >>> repr_flow_nums({1,2,3}) '(flows=1,2,3)' """ if not full and flow_nums == {1}: return "" - else: - return ( - "(flows=" - f"{','.join(str(i) for i in flow_nums) or 'none'}" - ")" - ) + return f"(flows={stringify_flow_nums(flow_nums) or 'none'})" class FlowMgr: diff --git a/cylc/flow/id.py b/cylc/flow/id.py index f2c8b05b4a1..68c62e3a118 100644 --- a/cylc/flow/id.py +++ b/cylc/flow/id.py @@ -22,6 +22,7 @@ from enum import Enum import re from typing import ( + TYPE_CHECKING, Iterable, List, Optional, @@ -33,6 +34,10 @@ from cylc.flow import LOG +if TYPE_CHECKING: + from cylc.flow.cycling import PointBase + + class IDTokens(Enum): """Cylc object identifier tokens.""" @@ -524,14 +529,14 @@ def duplicate( ) -def quick_relative_detokenise(cycle, task): +def quick_relative_id(cycle: Union[str, int, 'PointBase'], task: str) -> str: """Generate a relative ID for a task. This is a more efficient solution to `Tokens` for cases where you only want the ID string and don't have any use for a Tokens object. Example: - >>> q = quick_relative_detokenise + >>> q = quick_relative_id >>> q('1', 'a') == Tokens(cycle='1', task='a').relative_id True diff --git a/cylc/flow/network/schema.py b/cylc/flow/network/schema.py index 6388b7b7e87..84e019b17e0 100644 --- a/cylc/flow/network/schema.py +++ b/cylc/flow/network/schema.py @@ -22,8 +22,8 @@ from operator import attrgetter from typing import ( TYPE_CHECKING, - AsyncGenerator, Any, + AsyncGenerator, Dict, List, Optional, @@ -34,47 +34,71 @@ import graphene from graphene import ( - Boolean, Field, Float, ID, InputObjectType, Int, - Mutation, ObjectType, Schema, String, Argument, Interface + ID, + Argument, + Boolean, + Field, + Float, + InputObjectType, + Int, + Interface, + Mutation, + ObjectType, + Schema, + String, ) from graphene.types.generic import GenericScalar from graphene.utils.str_converters import to_snake_case from graphql.type.definition import get_named_type from cylc.flow import LOG_LEVELS -from cylc.flow.broadcast_mgr import ALL_CYCLE_POINTS_STRS, addict +from cylc.flow.broadcast_mgr import ( + ALL_CYCLE_POINTS_STRS, + addict, +) from cylc.flow.data_store_mgr import ( - FAMILIES, FAMILY_PROXIES, JOBS, TASKS, TASK_PROXIES, - DELTA_ADDED, DELTA_UPDATED + DELTA_ADDED, + DELTA_UPDATED, + FAMILIES, + FAMILY_PROXIES, + JOBS, + TASK_PROXIES, + TASKS, +) +from cylc.flow.flow_mgr import ( + FLOW_ALL, + FLOW_NEW, + FLOW_NONE, ) -from cylc.flow.flow_mgr import FLOW_ALL, FLOW_NEW, FLOW_NONE from cylc.flow.id import Tokens from cylc.flow.run_modes import ( TASK_CONFIG_RUN_MODES, WORKFLOW_RUN_MODES, RunMode) from cylc.flow.task_outputs import SORT_ORDERS from cylc.flow.task_state import ( - TASK_STATUSES_ORDERED, TASK_STATUS_DESC, - TASK_STATUS_WAITING, TASK_STATUS_EXPIRED, + TASK_STATUS_FAILED, TASK_STATUS_PREPARING, + TASK_STATUS_RUNNING, TASK_STATUS_SUBMIT_FAILED, TASK_STATUS_SUBMITTED, - TASK_STATUS_RUNNING, - TASK_STATUS_FAILED, - TASK_STATUS_SUCCEEDED + TASK_STATUS_SUCCEEDED, + TASK_STATUS_WAITING, + TASK_STATUSES_ORDERED, ) from cylc.flow.util import sstrip from cylc.flow.workflow_status import StopMode + if TYPE_CHECKING: from enum import Enum from graphql import ResolveInfo from graphql.type.definition import ( - GraphQLNamedType, GraphQLList, + GraphQLNamedType, GraphQLNonNull, ) + from cylc.flow.network.resolvers import BaseResolvers @@ -2141,6 +2165,19 @@ class Meta: ''') resolver = partial(mutator, command='remove_tasks') + class Arguments(TaskMutation.Arguments): + flow = graphene.List( + graphene.NonNull(Flow), + default_value=[FLOW_ALL], + description=sstrip(f''' + "Remove the task(s) from the specified flows. " + + This should be a list of flow numbers, or '{FLOW_ALL}' + to remove the task(s) from all flows they belong to + (which is the default). + ''') + ) + class SetPrereqsAndOutputs(Mutation, TaskMutation): class Meta: diff --git a/cylc/flow/platforms.py b/cylc/flow/platforms.py index 110d135edeb..c75faa5283b 100644 --- a/cylc/flow/platforms.py +++ b/cylc/flow/platforms.py @@ -268,7 +268,9 @@ def platform_from_name( # If platform name in run mode and not otherwise defined: if platform_name in JOBLESS_MODES: - return platforms['localhost'] + platform_data = deepcopy(platforms['localhost']) + platform_data['name'] = 'localhost' + return platform_data raise PlatformLookupError( f"No matching platform \"{platform_name}\" found") @@ -652,7 +654,7 @@ def get_install_target_to_platforms_map( Return {install_target_1: [platform_1_dict, platform_2_dict, ...], ...} """ ret: Dict[str, List[Dict[str, Any]]] = {} - for p_name in set(platform_names) - set(JOBLESS_MODES): + for p_name in set(platform_names): try: platform = platform_from_name(p_name) except PlatformLookupError as exc: @@ -662,10 +664,6 @@ def get_install_target_to_platforms_map( install_target = get_install_target_from_platform(platform) ret.setdefault(install_target, []).append(platform) - # Map jobless modes to localhost. - ret.setdefault('localhost', []).extend( - {'name': mode} for mode in JOBLESS_MODES - ) return ret diff --git a/cylc/flow/prerequisite.py b/cylc/flow/prerequisite.py index cd1b17442e2..8c077248711 100644 --- a/cylc/flow/prerequisite.py +++ b/cylc/flow/prerequisite.py @@ -23,7 +23,6 @@ ItemsView, Iterable, Iterator, - List, NamedTuple, Optional, Set, @@ -39,7 +38,7 @@ from cylc.flow.cycling.loader import get_point from cylc.flow.data_messages_pb2 import PbCondition, PbPrerequisite from cylc.flow.exceptions import TriggerExpressionError -from cylc.flow.id import quick_relative_detokenise +from cylc.flow.id import quick_relative_id from cylc.flow.run_modes import RunMode @@ -48,26 +47,26 @@ from cylc.flow.id import Tokens -AnyPrereqMessage = Tuple[Union['PointBase', str, int], str, str] +AnyPrereqTuple = Tuple[Union['PointBase', str, int], str, str] -class PrereqMessage(NamedTuple): - """A message pertaining to a Prerequisite.""" +class PrereqTuple(NamedTuple): + """A task output in a Prerequisite.""" point: str task: str output: str def get_id(self) -> str: - """Get the relative ID of the task in this prereq message.""" - return quick_relative_detokenise(self.point, self.task) + """Get the relative ID of the task in this prereq output.""" + return quick_relative_id(self.point, self.task) @staticmethod - def coerce(tuple_: AnyPrereqMessage) -> 'PrereqMessage': - """Coerce a tuple to a PrereqMessage.""" - if isinstance(tuple_, PrereqMessage): + def coerce(tuple_: AnyPrereqTuple) -> 'PrereqTuple': + """Coerce a tuple to a PrereqTuple.""" + if isinstance(tuple_, PrereqTuple): return tuple_ point, task, output = tuple_ - return PrereqMessage(point=str(point), task=task, output=output) + return PrereqTuple(point=str(point), task=task, output=output) SatisfiedState = Literal[ @@ -82,19 +81,20 @@ def coerce(tuple_: AnyPrereqMessage) -> 'PrereqMessage': class Prerequisite: """The concrete result of an abstract logical trigger expression. + A Prerequisite object represents the left-hand side of a single graph + arrow. + A single TaskProxy can have multiple Prerequisites, all of which require - satisfying. This corresponds to multiple tasks being dependencies of a task - in Cylc graphs (e.g. `a => c`, `b => c`). But a single Prerequisite can - also have multiple 'messages' (basically, subcomponents of a Prerequisite) - corresponding to parenthesised expressions in Cylc graphs (e.g. - `(a & b) => c` or `(a | b) => c`). For the OR operator (`|`), only one - message has to be satisfied for the Prerequisite to be satisfied. + satisfying. This corresponds to multiple graph arrow dependencies + (e.g. `a => c`, `b => c`). But a single Prerequisite object + can also have multiple dependencies from operator-joined left-hand side + expressions in the graph (e.g. `a & (b | c) => d`). """ # Memory optimization - constrain possible attributes to this list. __slots__ = ( "_satisfied", - "_all_satisfied", + "_cached_satisfied", "conditional_expression", "point", ) @@ -108,19 +108,20 @@ def __init__(self, point: 'PointBase'): # cylc.flow.cycling.PointBase self.point = point - # Dictionary of messages pertaining to this prerequisite. + # Dictionary of task outputs pertaining to this prerequisite + # (i.e. all the outputs on the LHS of the graph arrow). # {('point string', 'task name', 'output'): DEP_STATE_X, ...} - self._satisfied: Dict[PrereqMessage, SatisfiedState] = {} + self._satisfied: Dict[PrereqTuple, SatisfiedState] = {} - # Expression present only when conditions are used. - # '1/foo failed & 1/bar succeeded' + # Expression present only when the OR operator is used. + # '1/foo failed | 1/bar succeeded' self.conditional_expression: Optional[str] = None # The cached state of this prerequisite: # * `None` (no cached state) # * `True` (prerequisite satisfied) # * `False` (prerequisite unsatisfied). - self._all_satisfied: Optional[bool] = None + self._cached_satisfied: Optional[bool] = None def instantaneous_hash(self) -> int: """Generate a hash of this prerequisite in its current state. @@ -136,17 +137,17 @@ def instantaneous_hash(self) -> int: tuple(self._satisfied.keys()), )) - def __getitem__(self, key: AnyPrereqMessage) -> SatisfiedState: + def __getitem__(self, key: AnyPrereqTuple) -> SatisfiedState: """Return the satisfaction state of a dependency. Args: key: Tuple of (point, name, output) for a task. """ - return self._satisfied[PrereqMessage.coerce(key)] + return self._satisfied[PrereqTuple.coerce(key)] def __setitem__( self, - key: AnyPrereqMessage, + key: AnyPrereqTuple, value: Union[SatisfiedState, bool] = False, ) -> None: """Register an output with this prerequisite. @@ -157,37 +158,37 @@ def __setitem__( this should be True). """ - key = PrereqMessage.coerce(key) + key = PrereqTuple.coerce(key) if value is True: value = 'satisfied naturally' self._satisfied[key] = value - if not (self._all_satisfied and value): + if not (self._cached_satisfied and value): # Force later recalculation of cached satisfaction state: - self._all_satisfied = None + self._cached_satisfied = None - def __iter__(self) -> Iterator[PrereqMessage]: + def __iter__(self) -> Iterator[PrereqTuple]: return iter(self._satisfied) - def items(self) -> ItemsView[PrereqMessage, SatisfiedState]: + def items(self) -> ItemsView[PrereqTuple, SatisfiedState]: return self._satisfied.items() def get_raw_conditional_expression(self): """Return a representation of this prereq as a string. - Returns None if this prerequisite is not a conditional one. + Returns None if this prerequisite does not involve an OR operator. """ expr = self.conditional_expression if not expr: return None - for message in self._satisfied: - expr = expr.replace(self.SATISFIED_TEMPLATE % message, - self.MESSAGE_TEMPLATE % message) + for task_output in self._satisfied: + expr = expr.replace(self.SATISFIED_TEMPLATE % task_output, + self.MESSAGE_TEMPLATE % task_output) return expr - def set_condition(self, expr): + def set_conditional_expr(self, expr): """Set the conditional expression for this prerequisite. - Resets the cached state (self._all_satisfied). + Resets the cached state (self._cached_satisfied). Examples: # GH #3644 construct conditional expression when one task name @@ -196,22 +197,22 @@ def set_condition(self, expr): >>> preq = Prerequisite(1) >>> preq[(1, 'foo', 'succeeded')] = False >>> preq[(1, 'xfoo', 'succeeded')] = False - >>> preq.set_condition("1/foo succeeded|1/xfoo succeeded") + >>> preq.set_conditional_expr("1/foo succeeded|1/xfoo succeeded") >>> expr = preq.conditional_expression >>> expr.split('|') # doctest: +NORMALIZE_WHITESPACE ['bool(self._satisfied[("1", "foo", "succeeded")])', 'bool(self._satisfied[("1", "xfoo", "succeeded")])'] """ - self._all_satisfied = None + self._cached_satisfied = None if '|' in expr: # Make a Python expression so we can eval() the logic. - for message in self._satisfied: + for t_output in self._satisfied: # Use '\b' in case one task name is a substring of another # and escape special chars ('.', timezone '+') in task IDs. expr = re.sub( - fr"\b{re.escape(self.MESSAGE_TEMPLATE % message)}\b", - self.SATISFIED_TEMPLATE % message, + fr"\b{re.escape(self.MESSAGE_TEMPLATE % t_output)}\b", + self.SATISFIED_TEMPLATE % t_output, expr ) @@ -223,14 +224,14 @@ def is_satisfied(self): Return cached state if present, else evaluate the prerequisite. """ - if self._all_satisfied is not None: + if self._cached_satisfied is not None: # Cached value. - return self._all_satisfied + return self._cached_satisfied if self._satisfied == {}: # No prerequisites left after pre-initial simplification. return True - self._all_satisfied = self._eval_satisfied() - return self._all_satisfied + self._cached_satisfied = self._eval_satisfied() + return self._cached_satisfied def _eval_satisfied(self) -> bool: """Evaluate the prerequisite's condition expression. @@ -259,25 +260,32 @@ def satisfy_me( self, outputs: Iterable['Tokens'], mode: Optional[RunMode] = None, + forced: bool = False, ) -> 'Set[Tokens]': - """Attempt to satisfy me with given outputs. + """Set the given outputs as satisfied. - Updates cache with the result. Return outputs that match. + Args: + outputs: List of outputs to satisfy. + mode: Task run mode. + forced: If True, records that this should not be undone by + `cylc remove`. """ valid = set() for output in outputs: - prereq = PrereqMessage( + output_tuple = PrereqTuple( output['cycle'], output['task'], output['task_sel'] ) - if prereq not in self._satisfied: + if output_tuple not in self._satisfied: continue valid.add(output) - self[prereq] = ( + msg: SatisfiedState = ( 'satisfied by skip mode' if mode == RunMode.SKIP else 'satisfied naturally' ) + if self._satisfied[output_tuple] != msg: + self[output_tuple] = 'force satisfied' if forced else msg return valid def api_dump(self) -> Optional[PbPrerequisite]: @@ -290,21 +298,21 @@ def api_dump(self) -> Optional[PbPrerequisite]: ).replace('|', ' | ').replace('&', ' & ') else: expr = ' & '.join( - self.MESSAGE_TEMPLATE % s_msg - for s_msg in self._satisfied + self.MESSAGE_TEMPLATE % task_output + for task_output in self._satisfied ) conds = [] num_length = len(str(len(self._satisfied))) - for ind, message_tuple in enumerate(sorted(self._satisfied)): - t_id = message_tuple.get_id() + for ind, output_tuple in enumerate(sorted(self._satisfied)): + t_id = output_tuple.get_id() char = str(ind).zfill(num_length) - c_msg = self.MESSAGE_TEMPLATE % message_tuple - c_val = self._satisfied[message_tuple] + c_msg = self.MESSAGE_TEMPLATE % output_tuple + c_val = self._satisfied[output_tuple] conds.append( PbCondition( task_proxy=t_id, expr_alias=char, - req_state=message_tuple.output, + req_state=output_tuple.output, satisfied=bool(c_val), message=(c_val or 'unsatisfied'), ) @@ -323,17 +331,17 @@ def set_satisfied(self) -> None: State can be overridden by calling `self.satisfy_me`. """ - for message in self._satisfied: - if not self._satisfied[message]: - self._satisfied[message] = 'force satisfied' + for task_output in self._satisfied: + if not self._satisfied[task_output]: + self._satisfied[task_output] = 'force satisfied' if self.conditional_expression: - self._all_satisfied = self._eval_satisfied() + self._cached_satisfied = self._eval_satisfied() else: - self._all_satisfied = True + self._cached_satisfied = True def iter_target_point_strings(self): yield from { - message.point for message in self._satisfied + task_output.point for task_output in self._satisfied } def get_target_points(self): @@ -343,14 +351,15 @@ def get_target_points(self): get_point(p) for p in self.iter_target_point_strings() ] - def get_resolved_dependencies(self) -> List[str]: - """Return a list of satisfied dependencies. - - E.G: ['1/foo', '2/bar'] + def unset_naturally_satisfied(self, id_: str) -> bool: + """Set the dependencies with matching task IDs to unsatisfied only if + they were naturally satisfied. + Returns True if any dependencies were changed. """ - return [ - msg.get_id() - for msg, satisfied in self._satisfied.items() - if satisfied - ] + changed = False + for t_output, sat in self._satisfied.items(): + if t_output.get_id() == id_ and sat and sat != 'force satisfied': + self[t_output] = False + changed = True + return changed diff --git a/cylc/flow/rundb.py b/cylc/flow/rundb.py index 216895fd996..809e0a275d7 100644 --- a/cylc/flow/rundb.py +++ b/cylc/flow/rundb.py @@ -15,6 +15,7 @@ # along with this program. If not, see . """Provide data access object for the workflow runtime database.""" +from collections import defaultdict from contextlib import suppress from dataclasses import dataclass from os.path import expandvars @@ -23,6 +24,8 @@ import traceback from typing import ( TYPE_CHECKING, + Any, + DefaultDict, Dict, Iterable, List, @@ -30,11 +33,13 @@ Set, Tuple, Union, + cast, ) from cylc.flow import LOG from cylc.flow.exceptions import PlatformLookupError import cylc.flow.flags +from cylc.flow.flow_mgr import stringify_flow_nums from cylc.flow.util import ( deserialise_set, serialise_set, @@ -47,6 +52,13 @@ from cylc.flow.flow_mgr import FlowNums +DbArgDict = Dict[str, Any] +DbUpdateTuple = Union[ + Tuple[DbArgDict, DbArgDict], + Tuple[str, list] +] + + @dataclass class CylcWorkflowDAOTableColumn: """Represent a column in a table.""" @@ -69,7 +81,7 @@ class CylcWorkflowDAOTable: def __init__(self, name, column_items): self.name = name - self.columns = [] + self.columns: List[CylcWorkflowDAOTableColumn] = [] for column_item in column_items: name = column_item[0] attrs = {} @@ -81,7 +93,7 @@ def __init__(self, name, column_items): attrs.get("is_primary_key", False))) self.delete_queues = {} self.insert_queue = [] - self.update_queues = {} + self.update_queues: DefaultDict[str, list] = defaultdict(list) def get_create_stmt(self): """Return an SQL statement to create this table.""" @@ -150,14 +162,23 @@ def add_insert_item(self, args): args.get(column.name, None) for column in self.columns] self.insert_queue.append(stmt_args) - def add_update_item(self, set_args, where_args): + def add_update_item(self, item: DbUpdateTuple) -> None: """Queue an UPDATE item. + If stmt is not a string, it should be a tuple (set_args, where_args) - set_args should be a dict, with column keys and values to be set. where_args should be a dict, update will only apply to rows matching all these items. """ + if isinstance(item[0], str): + stmt = item[0] + params = cast('list', item[1]) + self.update_queues[stmt].extend(params) + return + + set_args = item[0] + where_args = cast('DbArgDict', item[1]) set_strs = [] stmt_args = [] for column in self.columns: @@ -177,9 +198,8 @@ def add_update_item(self, set_args, where_args): stmt = self.FMT_UPDATE % { "name": self.name, "set_str": set_str, - "where_str": where_str} - if stmt not in self.update_queues: - self.update_queues[stmt] = [] + "where_str": where_str + } self.update_queues[stmt].append(stmt_args) @@ -407,15 +427,18 @@ def add_insert_item(self, table_name, args): """ self.tables[table_name].add_insert_item(args) - def add_update_item(self, table_name, set_args, where_args=None): + def add_update_item( + self, table_name: str, item: DbUpdateTuple + ) -> None: """Queue an UPDATE item for a given table. + If stmt is not a string, it should be a tuple (set_args, where_args) - set_args should be a dict, with column keys and values to be set. where_args should be a dict, update will only apply to rows matching all these items. """ - self.tables[table_name].add_update_item(set_args, where_args) + self.tables[table_name].add_update_item(item) def close(self) -> None: """Explicitly close the connection.""" @@ -585,10 +608,10 @@ def select_workflow_params(self) -> Iterable[Tuple[str, Optional[str]]]: key, value FROM {self.TABLE_WORKFLOW_PARAMS} - ''' # nosec (table name is code constant) + ''' # nosec B608 (table name is code constant) return self.connect().execute(stmt) - def select_workflow_flows(self, flow_nums): + def select_workflow_flows(self, flow_nums: Iterable[int]): """Return flow data for selected flows.""" stmt = rf''' SELECT @@ -596,8 +619,8 @@ def select_workflow_flows(self, flow_nums): FROM {self.TABLE_WORKFLOW_FLOWS} WHERE - flow_num in ({','.join(str(f) for f in flow_nums)}) - ''' # nosec (table name is code constant, flow_nums just integers) + flow_num in ({stringify_flow_nums(flow_nums)}) + ''' # nosec B608 (table name is code constant, flow_nums just ints) flows = {} for flow_num, start_time, descr in self.connect().execute(stmt): flows[flow_num] = { @@ -613,7 +636,7 @@ def select_workflow_flows_max_flow_num(self): MAX(flow_num) FROM {self.TABLE_WORKFLOW_FLOWS} - ''' # nosec (table name is code constant) + ''' # nosec B608 (table name is code constant) return self.connect().execute(stmt).fetchone()[0] def select_workflow_params_restart_count(self): @@ -625,7 +648,7 @@ def select_workflow_params_restart_count(self): {self.TABLE_WORKFLOW_PARAMS} WHERE key == 'n_restart' - """ # nosec (table name is code constant) + """ # nosec B608 (table name is code constant) result = self.connect().execute(stmt).fetchone() return int(result[0]) if result else 0 @@ -641,7 +664,7 @@ def select_workflow_template_vars(self, callback): key, value FROM {self.TABLE_WORKFLOW_TEMPLATE_VARS} - ''' # nosec (table name is code constant) + ''' # nosec B608 (table name is code constant) )): callback(row_idx, list(row)) @@ -658,7 +681,7 @@ def select_task_action_timers(self, callback): {",".join(attrs)} FROM {self.TABLE_TASK_ACTION_TIMERS} - ''' # nosec + ''' # nosec B608 # * table name is code constant # * attrs are code constants for row_idx, row in enumerate(self.connect().execute(stmt)): @@ -684,7 +707,7 @@ def select_task_job(self, cycle, name, submit_num=None): AND name==? ORDER BY submit_num DESC LIMIT 1 - ''' # nosec + ''' # nosec B608 # * table name is code constant # * keys are code constants stmt_args = [cycle, name] @@ -698,7 +721,7 @@ def select_task_job(self, cycle, name, submit_num=None): cycle==? AND name==? AND submit_num==? - ''' # nosec + ''' # nosec B608 # * table name is code constant # * keys are code constants stmt_args = [cycle, name, submit_num] @@ -781,7 +804,7 @@ def select_task_job_platforms(self): platform_name FROM {self.TABLE_TASK_JOBS} - ''' # nosec (table name is code constant) + ''' # nosec B608 (table name is code constant) return {i[0] for i in self.connect().execute(stmt)} def select_prev_instances( @@ -794,7 +817,7 @@ def select_prev_instances( # Ignore bandit false positive: B608: hardcoded_sql_expressions # Not an injection, simply putting the table name in the SQL query # expression as a string constant local to this module. - stmt = ( # nosec + stmt = ( # nosec B608 r"SELECT flow_nums,submit_num,flow_wait,status FROM %(name)s" r" WHERE name==? AND cycle==?" ) % {"name": self.TABLE_TASK_STATES} @@ -842,7 +865,7 @@ def select_task_outputs( {self.TABLE_TASK_OUTPUTS} WHERE name==? AND cycle==? - ''' # nosec (table name is code constant) + ''' # nosec B608 (table name is code constant) return { outputs: deserialise_set(flow_nums) for flow_nums, outputs in self.connect().execute( @@ -856,7 +879,7 @@ def select_xtriggers_for_restart(self, callback): signature, results FROM {self.TABLE_XTRIGGERS} - ''' # nosec (table name is code constant) + ''' # nosec B608 (table name is code constant) for row_idx, row in enumerate(self.connect().execute(stmt, [])): callback(row_idx, list(row)) @@ -866,7 +889,7 @@ def select_abs_outputs_for_restart(self, callback): cycle, name, output FROM {self.TABLE_ABS_OUTPUTS} - ''' # nosec (table name is code constant) + ''' # nosec B608 (table name is code constant) for row_idx, row in enumerate(self.connect().execute(stmt, [])): callback(row_idx, list(row)) @@ -979,7 +1002,7 @@ def select_task_prerequisites( cycle == ? AND name == ? AND flow_nums == ? - """ # nosec (table name is code constant) + """ # nosec B608 (table name is code constant) stmt_args = [cycle, name, flow_nums] return list(self.connect().execute(stmt, stmt_args)) @@ -990,7 +1013,7 @@ def select_tasks_to_hold(self) -> List[Tuple[str, str]]: name, cycle FROM {self.TABLE_TASKS_TO_HOLD} - ''' # nosec (table name is code constant) + ''' # nosec B608 (table name is code constant) return list(self.connect().execute(stmt)) def select_task_times(self): @@ -1012,7 +1035,7 @@ def select_task_times(self): {self.TABLE_TASK_JOBS} WHERE run_status = 0 - """ # nosec (table name is code constant) + """ # nosec B608 (table name is code constant) columns = ( 'name', 'cycle', 'host', 'job_runner', 'submit_time', 'start_time', 'succeed_time' diff --git a/cylc/flow/scheduler.py b/cylc/flow/scheduler.py index 040a10af8bb..0a2bb04450d 100644 --- a/cylc/flow/scheduler.py +++ b/cylc/flow/scheduler.py @@ -18,28 +18,41 @@ import asyncio from collections import deque from contextlib import suppress +import itertools import os from pathlib import Path -from queue import Empty, Queue +from queue import ( + Empty, + Queue, +) from shlex import quote import signal from socket import gaierror -from subprocess import DEVNULL, PIPE, Popen +from subprocess import ( + DEVNULL, + PIPE, + Popen, +) import sys -from threading import Barrier, Thread -from time import sleep, time +from threading import ( + Barrier, + Thread, +) +from time import ( + sleep, + time, +) import traceback from typing import ( + TYPE_CHECKING, Any, AsyncGenerator, - Callable, Dict, Iterable, List, NoReturn, Optional, Set, - TYPE_CHECKING, Tuple, Union, ) @@ -50,13 +63,14 @@ from cylc.flow import ( LOG, __version__ as CYLC_VERSION, + commands, main_loop, + workflow_files, ) -from cylc.flow import workflow_files from cylc.flow.broadcast_mgr import BroadcastMgr from cylc.flow.cfgspec.glbl_cfg import glbl_cfg from cylc.flow.config import WorkflowConfig -from cylc.flow import commands +from cylc.flow.cycling.loader import get_point from cylc.flow.data_store_mgr import DataStoreMgr from cylc.flow.exceptions import ( CommandFailedError, @@ -64,7 +78,12 @@ InputError, ) import cylc.flow.flags -from cylc.flow.flow_mgr import FLOW_NEW, FLOW_NONE, FlowMgr +from cylc.flow.flow_mgr import ( + FLOW_NEW, + FLOW_NONE, + FlowMgr, + repr_flow_nums, +) from cylc.flow.host_select import ( HostSelectException, select_workflow_host, @@ -74,8 +93,14 @@ get_user, is_remote_platform, ) -from cylc.flow.id import Tokens -from cylc.flow.log_level import verbosity_to_env, verbosity_to_opts +from cylc.flow.id import ( + Tokens, + quick_relative_id, +) +from cylc.flow.log_level import ( + verbosity_to_env, + verbosity_to_opts, +) from cylc.flow.loggingutil import ( ReferenceLogFileHandler, RotatingLogFileHandler, @@ -108,14 +133,9 @@ ) from cylc.flow.profiler import Profiler from cylc.flow.resources import get_resources +from cylc.flow.run_modes import RunMode from cylc.flow.run_modes.simulation import sim_time_check from cylc.flow.subprocpool import SubProcPool -from cylc.flow.templatevars import eval_var -from cylc.flow.workflow_db_mgr import WorkflowDatabaseManager -from cylc.flow.workflow_events import WorkflowEventHandler -from cylc.flow.workflow_status import StopMode, AutoRestartMode -from cylc.flow.run_modes import RunMode -from cylc.flow.taskdef import TaskDef from cylc.flow.task_events_mgr import TaskEventsManager from cylc.flow.task_job_mgr import TaskJobManager from cylc.flow.task_pool import TaskPool @@ -136,7 +156,14 @@ TASK_STATUSES_ACTIVE, TASK_STATUSES_NEVER_ACTIVE, ) -from cylc.flow.templatevars import get_template_vars +from cylc.flow.taskdef import ( + TaskDef, + generate_graph_children, +) +from cylc.flow.templatevars import ( + eval_var, + get_template_vars, +) from cylc.flow.timer import Timer from cylc.flow.util import cli_format from cylc.flow.wallclock import ( @@ -144,8 +171,15 @@ get_time_string_from_unix_time as time2str, get_utc_mode, ) +from cylc.flow.workflow_db_mgr import WorkflowDatabaseManager +from cylc.flow.workflow_events import WorkflowEventHandler +from cylc.flow.workflow_status import ( + AutoRestartMode, + StopMode, +) from cylc.flow.xtrigger_mgr import XtriggerManager + if TYPE_CHECKING: from optparse import Values @@ -154,6 +188,7 @@ # TO: Python 3.8 from typing_extensions import Literal + from cylc.flow.flow_mgr import FlowNums from cylc.flow.network.resolvers import TaskMsg from cylc.flow.task_proxy import TaskProxy @@ -551,7 +586,7 @@ async def configure(self, params): elif self.config.cfg['scheduling']['hold after cycle point']: holdcp = self.config.cfg['scheduling']['hold after cycle point'] if holdcp is not None: - await commands.run_cmd(commands.set_hold_point, self, holdcp) + await commands.run_cmd(commands.set_hold_point(self, holdcp)) if self.options.paused_start: self.pause_workflow('Paused on start up') @@ -641,7 +676,7 @@ async def run_scheduler(self) -> None: if self.pool.get_tasks(): # (If we're not restarting a finished workflow) self.restart_remote_init() - await commands.run_cmd(commands.poll_tasks, self, ['*/*']) + await commands.run_cmd(commands.poll_tasks(self, ['*/*'])) self.run_event_handlers(self.EVENT_STARTUP, 'workflow starting') await asyncio.gather( @@ -956,10 +991,6 @@ def process_queued_task_messages(self) -> None: warn += f'\n {msg.job_id}: {msg.severity} - "{msg.message}"' LOG.warning(warn) - def get_command_method(self, command_name: str) -> Callable: - """Return a command processing method or raise AttributeError.""" - return getattr(self, f'command_{command_name}') - async def process_command_queue(self) -> None: """Process queued commands.""" qsize = self.command_queue.qsize() @@ -1033,14 +1064,15 @@ def kill_tasks( unkillable: List[TaskProxy] = [] for itask in itasks: if itask.state(*TASK_STATUSES_ACTIVE): - itask.state_reset( - # directly reset to failed in sim mode, else let - # task_job_mgr handle it - status=(TASK_STATUS_FAILED if jobless else None), - is_held=True, - ) - self.data_store_mgr.delta_task_state(itask) + if itask.state_reset(is_held=True): + self.data_store_mgr.delta_task_state(itask) to_kill.append(itask) + if jobless: + # Directly set failed in sim mode: + self.task_events_mgr.process_message( + itask, 'CRITICAL', TASK_STATUS_FAILED, + flag=self.task_events_mgr.FLAG_RECEIVED + ) else: unkillable.append(itask) if warn and unkillable: @@ -1053,6 +1085,136 @@ def kill_tasks( return len(unkillable) + def remove_tasks( + self, items: Iterable[str], flow_nums: Optional['FlowNums'] = None + ) -> None: + """Remove tasks (`cylc remove` command). + + Args: + items: Relative IDs or globs. + flow_nums: Flows to remove the tasks from. If empty or None, it + means 'all'. + """ + active, inactive, _unmatched = self.pool.filter_task_proxies( + items, warn_no_active=False, inactive=True + ) + if not (active or inactive): + return + + if flow_nums is None: + flow_nums = set() + # Mapping of task IDs to removed flow numbers: + removed: Dict[str, FlowNums] = {} + not_removed: Set[str] = set() + to_kill: List[TaskProxy] = [] + + for itask in active: + fnums_to_remove = itask.match_flows(flow_nums) + if not fnums_to_remove: + not_removed.add(itask.identity) + continue + removed[itask.identity] = fnums_to_remove + if fnums_to_remove == itask.flow_nums: + # Need to remove the task from the pool. + # Spawn next occurrence of xtrigger sequential task (otherwise + # this would not happen after removing this occurrence): + self.pool.check_spawn_psx_task(itask) + self.pool.remove(itask, 'request') + to_kill.append(itask) + itask.removed = True + itask.flow_nums.difference_update(fnums_to_remove) + + # All the matched tasks (including inactive & applicable active tasks): + matched_task_ids = { + *removed.keys(), + *(quick_relative_id(cycle, task) for task, cycle in inactive), + } + + for id_ in matched_task_ids: + point_str, name = id_.split('/', 1) + tdef = self.config.taskdefs[name] + + # Go through any tasks downstream of this matched task to see if + # any need to stand down as a result of this task being removed: + for child in set(itertools.chain.from_iterable( + generate_graph_children(tdef, get_point(point_str)).values() + )): + child_itask = self.pool.get_task(child.point, child.name) + if not child_itask: + continue + fnums_to_remove = child_itask.match_flows(flow_nums) + if not fnums_to_remove: + continue + prereqs_changed = False + for prereq in ( + *child_itask.state.prerequisites, + *child_itask.state.suicide_prerequisites, + ): + # Unset any prereqs naturally satisfied by these tasks + # (do not unset those satisfied by `cylc set --pre`): + if prereq.unset_naturally_satisfied(id_): + prereqs_changed = True + removed.setdefault(id_, set()).update(fnums_to_remove) + if not prereqs_changed: + continue + self.data_store_mgr.delta_task_prerequisite(child_itask) + # Check if downstream task is still ready to run: + if ( + child_itask.state.is_gte(TASK_STATUS_PREPARING) + # Still ready if the task exists in other flows: + or child_itask.flow_nums != fnums_to_remove + or child_itask.state.prerequisites_all_satisfied() + ): + continue + # No longer ready to run + self.pool.unqueue_task(child_itask) + # Check if downstream task should remain spawned: + if ( + # Ignoring tasks we are already dealing with: + child_itask.identity in matched_task_ids + or child_itask.state.any_satisfied_prerequisite_outputs() + ): + continue + # No longer has reason to be in pool: + self.pool.remove(child_itask, self.pool.REMOVED_BY_PREREQ) + # Remove this downstream task from flows in DB tables to ensure + # it is not skipped if it respawns in future: + self.workflow_db_mgr.remove_task_from_flows( + str(child.point), child.name, fnums_to_remove + ) + + # Remove the matched tasks from the flows in the DB tables: + db_removed_fnums = self.workflow_db_mgr.remove_task_from_flows( + point_str, name, flow_nums + ) + if db_removed_fnums: + removed.setdefault(id_, set()).update(db_removed_fnums) + + if to_kill: + self.kill_tasks(to_kill, warn=False) + + if removed: + tasks_str_list = [] + for task, fnums in removed.items(): + self.data_store_mgr.delta_remove_task_flow_nums(task, fnums) + tasks_str_list.append( + f"{task} {repr_flow_nums(fnums, full=True)}" + ) + LOG.info(f"Removed task(s): {', '.join(sorted(tasks_str_list))}") + + not_removed.update(matched_task_ids.difference(removed)) + if not_removed: + fnums_str = ( + repr_flow_nums(flow_nums, full=True) if flow_nums else '' + ) + LOG.warning( + "Task(s) not removable: " + f"{', '.join(sorted(not_removed))} {fnums_str}" + ) + + if removed and self.pool.compute_runahead(): + self.pool.release_runahead_tasks() + def get_restart_num(self) -> int: """Return the number of the restart, else 0 if not a restart. @@ -1434,8 +1596,8 @@ async def workflow_shutdown(self): self.time_next_kill is not None and time() > self.time_next_kill ): - await commands.run_cmd(commands.poll_tasks, self, ['*/*']) - await commands.run_cmd(commands.kill_tasks, self, ['*/*']) + await commands.run_cmd(commands.poll_tasks(self, ['*/*'])) + await commands.run_cmd(commands.kill_tasks(self, ['*/*'])) self.time_next_kill = time() + self.INTERVAL_STOP_KILL # Is the workflow set to auto stop [+restart] now ... @@ -1577,7 +1739,7 @@ async def _main_loop(self) -> None: self.broadcast_mgr.check_ext_triggers( itask, self.ext_trigger_queue) - if all(itask.is_ready_to_run()): + if itask.is_ready_to_run(): self.pool.queue_task(itask) if self.xtrigger_mgr.sequential_spawn_next: diff --git a/cylc/flow/scripts/clean.py b/cylc/flow/scripts/clean.py index fffd27752d7..b3afe679429 100644 --- a/cylc/flow/scripts/clean.py +++ b/cylc/flow/scripts/clean.py @@ -227,17 +227,17 @@ async def run(*ids: str, opts: 'Values') -> None: if multi_mode and not opts.skip_interactive: prompt(workflows) # prompt for approval or exit - failed = {} + failed = False for workflow in workflows: try: init_clean(workflow, opts) except Exception as exc: - failed[workflow] = exc + failed = True + LOG.error(f"Failed to clean {workflow}\nError: {exc}") + if cylc.flow.flags.verbosity > 0: + LOG.exception(exc) if failed: - msg = "Clean failed:" - for workflow, exc_message in failed.items(): - msg += f"\nWorkflow: {workflow}\nError: {exc_message}" - raise CylcError(msg) + raise CylcError("Clean failed") @cli_function(get_option_parser) diff --git a/cylc/flow/scripts/remove.py b/cylc/flow/scripts/remove.py index ef4c74d02c8..ce7a9c0113c 100755 --- a/cylc/flow/scripts/remove.py +++ b/cylc/flow/scripts/remove.py @@ -18,7 +18,36 @@ """cylc remove [OPTIONS] ARGS -Remove one or more task instances from a running workflow. +Remove tasks in the active window, or erase the run-history of past tasks. + +Final-status incomplete tasks can be removed from the n=0 active window to +prevent/recover from a stall, if they don't need to be completed to continue +the flow. + +Erasing the run-history of past tasks allows them to be run again in the +same flow (this is an alternative to starting a new flow). + +By default, the specified task(s) will be removed from all flows. + +Tasks removed from all flows, and any waiting downstream tasks spawned by +their outputs, will be recorded in the `None` flow and will not affect +the evolution of the workflow. + +If you remove a task from some but not all of its flows, it will still exist +in the remaining flows, but it will not affect the evolution of the removed +flows. + +Removing a submitted or running task will also kill it (see "cylc kill"). + +Examples: + # Remove a task that already ran. + # (Any downstream tasks that are already running or finished will be + # left alone. The task and its outputs will be left in the None flow.) + $ cylc remove + + # Remove a task from a specified flow. + # (The task may remain in other flows) + $ cylc remove --flow=1 """ from functools import partial @@ -33,6 +62,7 @@ ) from cylc.flow.terminal import cli_function + if TYPE_CHECKING: from optparse import Values @@ -41,10 +71,12 @@ mutation ( $wFlows: [WorkflowID]!, $tasks: [NamespaceIDGlob]!, + $flow: [Flow!], ) { remove ( workflows: $wFlows, tasks: $tasks, + flow: $flow ) { result } @@ -61,6 +93,18 @@ def get_option_parser() -> COP: argdoc=[FULL_ID_MULTI_ARG_DOC], ) + parser.add_option( + '--flow', + action='append', + dest='flow', + metavar='FLOW', + help=( + "Remove the task(s) from the specified flow. " + "Reuse the option to remove the task(s) from multiple flows. " + "(By default, the task(s) will be removed from all flows.)" + ), + ) + return parser @@ -75,6 +119,7 @@ async def run(options: 'Values', workflow_id: str, *tokens_list): tokens.relative_id_with_selectors for tokens in tokens_list ], + 'flow': options.flow, } } diff --git a/cylc/flow/scripts/show.py b/cylc/flow/scripts/show.py index 268a723d198..d72a3e45d32 100755 --- a/cylc/flow/scripts/show.py +++ b/cylc/flow/scripts/show.py @@ -60,6 +60,7 @@ from cylc.flow.option_parsers import ( CylcOptionParser as COP, ID_MULTI_ARG_DOC, + Options, ) from cylc.flow.terminal import cli_function from cylc.flow.util import BOOL_SYMBOLS @@ -250,6 +251,9 @@ def get_option_parser(): return parser +ShowOptions = Options(get_option_parser()) + + async def workflow_meta_query(workflow_id, pclient, options, json_filter): query = WORKFLOW_META_QUERY query_kwargs = { diff --git a/cylc/flow/task_events_mgr.py b/cylc/flow/task_events_mgr.py index f1e841ece02..8e7a8b8636a 100644 --- a/cylc/flow/task_events_mgr.py +++ b/cylc/flow/task_events_mgr.py @@ -104,9 +104,13 @@ if TYPE_CHECKING: + from cylc.flow.broadcast_mgr import BroadcastMgr + from cylc.flow.data_store_mgr import DataStoreMgr from cylc.flow.id import Tokens - from cylc.flow.task_proxy import TaskProxy from cylc.flow.scheduler import Scheduler + from cylc.flow.task_proxy import TaskProxy + from cylc.flow.workflow_db_mgr import WorkflowDatabaseManager + from cylc.flow.xtrigger_mgr import XtriggerManager class CustomTaskEventHandlerContext(NamedTuple): @@ -453,10 +457,10 @@ def __init__( ): self.workflow = workflow self.proc_pool = proc_pool - self.workflow_db_mgr = workflow_db_mgr - self.broadcast_mgr = broadcast_mgr - self.xtrigger_mgr = xtrigger_mgr - self.data_store_mgr = data_store_mgr + self.workflow_db_mgr: WorkflowDatabaseManager = workflow_db_mgr + self.broadcast_mgr: BroadcastMgr = broadcast_mgr + self.xtrigger_mgr: XtriggerManager = xtrigger_mgr + self.data_store_mgr: DataStoreMgr = data_store_mgr self.next_mail_time = None self.reset_inactivity_timer_func = reset_inactivity_timer_func # NOTE: do not mutate directly @@ -756,9 +760,7 @@ def process_message( ): # Already submit-failed return True - if self._process_message_submit_failed( - itask, event_time, submit_num, forced - ): + if self._process_message_submit_failed(itask, event_time, forced): self.spawn_children(itask, TASK_OUTPUT_SUBMIT_FAILED) elif message == self.EVENT_SUBMITTED: @@ -1090,7 +1092,7 @@ def _get_events_conf( glbl_cfg().get()["task events"], ): try: - value = getter.get(key) + value = getter.get(key) # type: ignore[union-attr] except (AttributeError, ItemNotFoundError, KeyError): pass else: @@ -1297,7 +1299,13 @@ def _retry_task(self, itask, wallclock_time, submit_retry=False): if itask.state_reset(TASK_STATUS_WAITING): self.data_store_mgr.delta_task_state(itask) - def _process_message_failed(self, itask, event_time, message, forced): + def _process_message_failed( + self, + itask: 'TaskProxy', + event_time: Optional[str], + message: str, + forced: bool, + ) -> bool: """Helper for process_message, handle a failed message. Return True if no retries (hence go to the failed state). @@ -1314,14 +1322,20 @@ def _process_message_failed(self, itask, event_time, message, forced): "time_run_exit": event_time, }) if ( - forced - or TimerFlags.EXECUTION_RETRY not in itask.try_timers - or itask.try_timers[TimerFlags.EXECUTION_RETRY].next() is None + forced + or TimerFlags.EXECUTION_RETRY not in itask.try_timers + or itask.try_timers[TimerFlags.EXECUTION_RETRY].next() is None ): # No retry lined up: definitive failure. no_retries = True if itask.state_reset(TASK_STATUS_FAILED, forced=forced): - self.setup_event_handlers(itask, self.EVENT_FAILED, message) + if itask.removed: + # Need to update DB as task not include in pool update + self.workflow_db_mgr.put_update_task_state(itask) + else: + self.setup_event_handlers( + itask, self.EVENT_FAILED, message + ) itask.state.outputs.set_message_complete(TASK_OUTPUT_FAILED) self.data_store_mgr.delta_task_output( itask, TASK_OUTPUT_FAILED) @@ -1399,7 +1413,10 @@ def _process_message_succeeded(self, itask, event_time, forced): self._reset_job_timers(itask) def _process_message_submit_failed( - self, itask, event_time, submit_num, forced + self, + itask: 'TaskProxy', + event_time: Optional[str], + forced: bool, ): """Helper for process_message, handle a submit-failed message. @@ -1415,17 +1432,22 @@ def _process_message_submit_failed( }) itask.summary['submit_method_id'] = None if ( - forced - or TimerFlags.SUBMISSION_RETRY not in itask.try_timers - or itask.try_timers[TimerFlags.SUBMISSION_RETRY].next() is None + forced + or TimerFlags.SUBMISSION_RETRY not in itask.try_timers + or itask.try_timers[TimerFlags.SUBMISSION_RETRY].next() is None ): # No submission retry lined up: definitive failure. # See github #476. no_retries = True if itask.state_reset(TASK_STATUS_SUBMIT_FAILED, forced=forced): - self.setup_event_handlers( - itask, self.EVENT_SUBMIT_FAILED, - f'job {self.EVENT_SUBMIT_FAILED}') + if itask.removed: + # Need to update DB as task not include in pool update + self.workflow_db_mgr.put_update_task_state(itask) + else: + self.setup_event_handlers( + itask, self.EVENT_SUBMIT_FAILED, + f'job {self.EVENT_SUBMIT_FAILED}' + ) itask.state.outputs.set_message_complete( TASK_OUTPUT_SUBMIT_FAILED ) diff --git a/cylc/flow/task_job_mgr.py b/cylc/flow/task_job_mgr.py index b52be9843c1..7e1cfa02cc4 100644 --- a/cylc/flow/task_job_mgr.py +++ b/cylc/flow/task_job_mgr.py @@ -26,13 +26,13 @@ from contextlib import suppress import json -import os from logging import ( CRITICAL, DEBUG, INFO, - WARNING + WARNING, ) +import os from shutil import rmtree from time import time from typing import ( @@ -47,7 +47,7 @@ ) from cylc.flow import LOG -from cylc.flow.job_runner_mgr import JobPollContext +from cylc.flow.cfgspec.globalcfg import SYSPATH from cylc.flow.exceptions import ( NoHostsError, NoPlatformsError, @@ -57,9 +57,10 @@ ) from cylc.flow.hostuserutil import ( get_host, - is_remote_platform + is_remote_platform, ) from cylc.flow.job_file import JobFileWriter +from cylc.flow.job_runner_mgr import JobPollContext from cylc.flow.pathutil import get_remote_workflow_run_job_dir from cylc.flow.platforms import ( get_host_from_platform, @@ -68,57 +69,62 @@ get_platform, ) from cylc.flow.remote import construct_ssh_cmd +from cylc.flow.run_modes import ( + WORKFLOW_ONLY_MODES, + RunMode, +) from cylc.flow.subprocctx import SubProcContext from cylc.flow.subprocpool import SubProcPool -from cylc.flow.run_modes import RunMode, WORKFLOW_ONLY_MODES from cylc.flow.task_action_timer import ( TaskActionTimer, - TimerFlags + TimerFlags, ) from cylc.flow.task_events_mgr import ( TaskEventsManager, - log_task_job_activity + log_task_job_activity, ) from cylc.flow.task_job_logs import ( JOB_LOG_JOB, NN, get_task_job_activity_log, get_task_job_job_log, - get_task_job_log + get_task_job_log, ) from cylc.flow.task_message import FAIL_MESSAGE_PREFIX from cylc.flow.task_outputs import ( TASK_OUTPUT_FAILED, TASK_OUTPUT_STARTED, TASK_OUTPUT_SUBMITTED, - TASK_OUTPUT_SUCCEEDED + TASK_OUTPUT_SUCCEEDED, ) from cylc.flow.task_remote_mgr import ( + REMOTE_FILE_INSTALL_255, REMOTE_FILE_INSTALL_DONE, REMOTE_FILE_INSTALL_FAILED, REMOTE_FILE_INSTALL_IN_PROGRESS, - REMOTE_INIT_IN_PROGRESS, REMOTE_INIT_255, - REMOTE_FILE_INSTALL_255, - REMOTE_INIT_DONE, REMOTE_INIT_FAILED, - TaskRemoteMgr + REMOTE_INIT_DONE, + REMOTE_INIT_FAILED, + REMOTE_INIT_IN_PROGRESS, + TaskRemoteMgr, ) from cylc.flow.task_state import ( TASK_STATUS_PREPARING, - TASK_STATUS_SUBMITTED, TASK_STATUS_RUNNING, + TASK_STATUS_SUBMITTED, TASK_STATUS_WAITING, ) +from cylc.flow.util import serialise_set from cylc.flow.wallclock import ( get_current_time_string, get_time_string_from_unix_time, - get_utc_mode + get_utc_mode, ) -from cylc.flow.cfgspec.globalcfg import SYSPATH -from cylc.flow.util import serialise_set + if TYPE_CHECKING: from cylc.flow.task_proxy import TaskProxy + from cylc.flow.workflow_db_mgr import WorkflowDatabaseManager class TaskJobManager: @@ -152,7 +158,7 @@ def __init__(self, workflow, proc_pool, workflow_db_mgr, task_events_mgr, data_store_mgr, bad_hosts): self.workflow = workflow self.proc_pool = proc_pool - self.workflow_db_mgr = workflow_db_mgr + self.workflow_db_mgr: WorkflowDatabaseManager = workflow_db_mgr self.task_events_mgr = task_events_mgr self.data_store_mgr = data_store_mgr self.job_file_writer = JobFileWriter() @@ -278,7 +284,7 @@ def submit_livelike_task_jobs( """Submission for live tasks and dummy tasks. """ done_tasks: 'List[TaskProxy]' = [] - # {platform: [itask, ...], ...} + # Mapping of platforms to task proxies: auth_itasks: 'Dict[str, List[TaskProxy]]' = {} prepared_tasks, bad_tasks = self.prep_submit_task_jobs( @@ -291,9 +297,7 @@ def submit_livelike_task_jobs( return bad_tasks for itask in prepared_tasks: - platform_name = itask.platform['name'] - auth_itasks.setdefault(platform_name, []) - auth_itasks[platform_name].append(itask) + auth_itasks.setdefault(itask.platform['name'], []).append(itask) # Submit task jobs for each platform # Non-prepared tasks can be considered done for now: @@ -1074,14 +1078,7 @@ def submit_nonlive_task_jobs( self, itask, rtconfig, workflow, now ): # A submit function returns true if this is a nonlive task: - self.workflow_db_mgr.put_insert_task_states( - itask, - { - 'submit_num': itask.submit_num, - 'flow_nums': serialise_set(itask.flow_nums), - 'time_created': itask.summary['submitted_time_string'] - } - ) + self.workflow_db_mgr.put_insert_task_states(itask) nonlive_tasks.append(itask) else: lively_tasks.append(itask) diff --git a/cylc/flow/task_pool.py b/cylc/flow/task_pool.py index 740f2988902..c771af215cd 100644 --- a/cylc/flow/task_pool.py +++ b/cylc/flow/task_pool.py @@ -16,35 +16,55 @@ """Wrangle task proxies to manage the workflow.""" -from contextlib import suppress from collections import Counter +from contextlib import suppress import json +import logging from textwrap import indent from typing import ( + TYPE_CHECKING, Dict, Iterable, List, NamedTuple, Optional, Set, - TYPE_CHECKING, Tuple, Type, Union, ) -import logging -import cylc.flow.flags from cylc.flow import LOG -from cylc.flow.cycling.loader import get_point, standardise_point_string +from cylc.flow.cycling.loader import ( + get_point, + standardise_point_string, +) from cylc.flow.exceptions import ( - WorkflowConfigError, PointParsingError, PlatformLookupError) -from cylc.flow.id import Tokens, detokenise + PlatformLookupError, + PointParsingError, + WorkflowConfigError, +) +import cylc.flow.flags +from cylc.flow.flow_mgr import ( + FLOW_ALL, + FLOW_NEW, + FLOW_NONE, + repr_flow_nums, +) +from cylc.flow.id import ( + Tokens, + detokenise, + quick_relative_id, +) from cylc.flow.id_cli import contains_fnmatch from cylc.flow.id_match import filter_ids +from cylc.flow.platforms import get_platform from cylc.flow.run_modes import RunMode -from cylc.flow.workflow_status import StopMode -from cylc.flow.task_action_timer import TaskActionTimer, TimerFlags +from cylc.flow.run_modes.skip import process_outputs as get_skip_mode_outputs +from cylc.flow.task_action_timer import ( + TaskActionTimer, + TimerFlags, +) from cylc.flow.task_events_mgr import ( CustomTaskEventHandlerContext, EventKey, @@ -52,47 +72,41 @@ TaskJobLogsRetrieveContext, ) from cylc.flow.task_id import TaskID +from cylc.flow.task_outputs import ( + TASK_OUTPUT_EXPIRED, + TASK_OUTPUT_FAILED, + TASK_OUTPUT_SUBMIT_FAILED, + TASK_OUTPUT_SUCCEEDED, +) from cylc.flow.task_proxy import TaskProxy +from cylc.flow.task_queues.independent import IndepQueueManager from cylc.flow.task_state import ( - TASK_STATUSES_ACTIVE, - TASK_STATUSES_FINAL, - TASK_STATUS_WAITING, TASK_STATUS_EXPIRED, + TASK_STATUS_FAILED, TASK_STATUS_PREPARING, - TASK_STATUS_SUBMITTED, TASK_STATUS_RUNNING, + TASK_STATUS_SUBMITTED, TASK_STATUS_SUCCEEDED, - TASK_STATUS_FAILED, + TASK_STATUS_WAITING, + TASK_STATUSES_ACTIVE, + TASK_STATUSES_FINAL, ) from cylc.flow.task_trigger import TaskTrigger -from cylc.flow.util import ( - serialise_set, - deserialise_set -) -from cylc.flow.wallclock import get_current_time_string -from cylc.flow.platforms import get_platform -from cylc.flow.run_modes.skip import ( - process_outputs as get_skip_mode_outputs) -from cylc.flow.task_outputs import ( - TASK_OUTPUT_SUCCEEDED, - TASK_OUTPUT_EXPIRED, - TASK_OUTPUT_FAILED, - TASK_OUTPUT_SUBMIT_FAILED, -) -from cylc.flow.task_queues.independent import IndepQueueManager +from cylc.flow.util import deserialise_set +from cylc.flow.workflow_status import StopMode -from cylc.flow.flow_mgr import ( - stringify_flow_nums, - FLOW_ALL, - FLOW_NONE, - FLOW_NEW -) if TYPE_CHECKING: from cylc.flow.config import WorkflowConfig - from cylc.flow.cycling import IntervalBase, PointBase + from cylc.flow.cycling import ( + IntervalBase, + PointBase, + ) from cylc.flow.data_store_mgr import DataStoreMgr - from cylc.flow.flow_mgr import FlowMgr, FlowNums + from cylc.flow.flow_mgr import ( + FlowMgr, + FlowNums, + ) from cylc.flow.prerequisite import SatisfiedState from cylc.flow.task_events_mgr import TaskEventsManager from cylc.flow.taskdef import TaskDef @@ -109,6 +123,7 @@ class TaskPool: ERR_TMPL_NO_TASKID_MATCH = "No matching tasks found: {0}" ERR_PREFIX_TASK_NOT_ON_SEQUENCE = "Invalid cycle point for task: {0}, {1}" SUICIDE_MSG = "suicide trigger" + REMOVED_BY_PREREQ = "prerequisite task(s) removed" def __init__( self, @@ -206,18 +221,7 @@ def db_add_new_flow_rows(self, itask: TaskProxy) -> None: Call when a new task is spawned or a flow merge occurs. """ # Add row to task_states table. - now = get_current_time_string() - self.workflow_db_mgr.put_insert_task_states( - itask, - { - "time_created": now, - "time_updated": now, - "status": itask.state.status, - "flow_nums": serialise_set(itask.flow_nums), - "flow_wait": itask.flow_wait, - "is_manual_submit": itask.is_manual_submit - } - ) + self.workflow_db_mgr.put_insert_task_states(itask) # Add row to task_outputs table: self.workflow_db_mgr.put_insert_task_outputs(itask) @@ -441,13 +445,24 @@ def check_task_output( output_msg: str, flow_nums: 'FlowNums', ) -> 'SatisfiedState': - """Returns truthy if the specified output is satisfied in the DB.""" + """Returns truthy if the specified output is satisfied in the DB. + + Args: + cycle: Cycle point of the task whose output is being checked. + task: Name of the task whose output is being checked. + output_msg: The output message to check for. + flow_nums: Flow numbers of the task whose output is being + checked. If this is empty it means 'none'; will return False. + """ + if not flow_nums: + return False + for task_outputs, task_flow_nums in ( self.workflow_db_mgr.pri_dao.select_task_outputs(task, cycle) ).items(): # loop through matching tasks + # (if task_flow_nums is empty, it means the 'none' flow) if flow_nums.intersection(task_flow_nums): - # this task is in the right flow # BACK COMPAT: In Cylc >8.0.0,<8.3.0, only the task # messages were stored in the DB as a list. # from: 8.0.0 @@ -709,7 +724,7 @@ def rh_release_and_queue(self, itask) -> None: """ if itask.state_reset(is_runahead=False): self.data_store_mgr.delta_task_state(itask) - if all(itask.is_ready_to_run()): + if itask.is_ready_to_run(): # (otherwise waiting on xtriggers etc.) self.queue_task(itask) @@ -736,9 +751,7 @@ def get_or_spawn_task( It does not add a spawned task proxy to the pool. """ - ntask = self._get_task_by_id( - Tokens(cycle=str(point), task=tdef.name).relative_id - ) + ntask = self.get_task(point, tdef.name) is_in_pool = False is_xtrig_sequential = False if ntask is None: @@ -817,7 +830,7 @@ def spawn_if_parentless(self, tdef, point, flow_nums): if ntask is not None and not is_in_pool: self.add_to_pool(ntask) - def remove(self, itask, reason=None): + def remove(self, itask: 'TaskProxy', reason: Optional[str] = None) -> None: """Remove a task from the pool.""" if itask.state.is_runahead and itask.flow_nums: @@ -829,11 +842,7 @@ def remove(self, itask, reason=None): itask.flow_nums ) - msg = "removed from active task pool" - if reason is None: - msg += ": completed" - else: - msg += f": {reason}" + msg = f"removed from active task pool: {reason or 'completed'}" if itask.is_xtrigger_sequential: self.xtrigger_mgr.sequential_spawn_next.discard(itask.identity) @@ -869,6 +878,8 @@ def remove(self, itask, reason=None): ): level = logging.WARNING msg += " - active job orphaned" + elif reason == self.REMOVED_BY_PREREQ: + level = logging.INFO LOG.log(level, f"[{itask}] {msg}") @@ -887,40 +898,57 @@ def get_tasks(self) -> List[TaskProxy]: # Cached list only for use internally in this method. if self.active_tasks_changed: self.active_tasks_changed = False - self._active_tasks_list = [] - for itask_id_map in self.active_tasks.values(): - for itask in itask_id_map.values(): - self._active_tasks_list.append(itask) + self._active_tasks_list = [ + itask + for itask_id_map in self.active_tasks.values() + for itask in itask_id_map.values() + ] return self._active_tasks_list + def get_task_ids(self) -> Set[str]: + """Return a list of task IDs in the task pool.""" + return {itask.identity for itask in self.get_tasks()} + def get_tasks_by_point(self) -> 'Dict[PointBase, List[TaskProxy]]': """Return a map of task proxies by cycle point.""" - point_itasks = {} - for point, itask_id_map in self.active_tasks.items(): - point_itasks[point] = list(itask_id_map.values()) - return point_itasks + return { + point: list(itask_id_map.values()) + for point, itask_id_map in self.active_tasks.items() + } def get_task(self, point: 'PointBase', name: str) -> Optional[TaskProxy]: """Retrieve a task from the pool.""" rel_id = f'{point}/{name}' tasks = self.active_tasks.get(point) - if tasks and rel_id in tasks: - return tasks[rel_id] + if tasks: + return tasks.get(rel_id) return None def _get_task_by_id(self, id_: str) -> Optional[TaskProxy]: """Return pool task by ID if it exists, or None.""" for itask_ids in self.active_tasks.values(): - with suppress(KeyError): + if id_ in itask_ids: return itask_ids[id_] return None def queue_task(self, itask: TaskProxy) -> None: - """Queue a task that is ready to run.""" + """Queue a task that is ready to run. + + If it is already queued, do nothing. + """ if itask.state_reset(is_queued=True): self.data_store_mgr.delta_task_state(itask) self.task_queue_mgr.push_task(itask) + def unqueue_task(self, itask: TaskProxy) -> None: + """Un-queue a task that is no longer ready to run. + + If it is not queued, do nothing. + """ + if itask.state_reset(is_queued=False): + self.data_store_mgr.delta_task_state(itask) + self.task_queue_mgr.remove_task(itask) + def release_queued_tasks(self): """Return list of queue-released tasks awaiting job prep. @@ -1100,8 +1128,7 @@ def _reload_taskdefs(self) -> None: if itask.state.is_queued: # Already queued continue - ready_check_items = itask.is_ready_to_run() - if all(ready_check_items) and not itask.state.is_runahead: + if itask.is_ready_to_run() and not itask.state.is_runahead: self.queue_task(itask) def set_stop_point(self, stop_point: 'PointBase') -> bool: @@ -1243,7 +1270,7 @@ def log_unsatisfied_prereqs(self) -> bool: LOG.warning( "Partially satisfied prerequisites:\n" + "\n".join( - f" * {id_} is waiting on {others}" + f" * {id_} is waiting on {sorted(others)}" for id_, others in unsat.items() ) ) @@ -1266,7 +1293,7 @@ def is_stalled(self) -> bool: itask.state(TASK_STATUS_WAITING) and not itask.state.is_runahead # (avoid waiting pre-spawned absolute-triggered tasks:) - and not itask.is_task_prereqs_not_done() + and itask.prereqs_are_satisfied() ) for itask in self.get_tasks() ): return False @@ -1284,7 +1311,7 @@ def hold_active_task(self, itask: TaskProxy) -> None: def release_held_active_task(self, itask: TaskProxy) -> None: if itask.state_reset(is_held=False): self.data_store_mgr.delta_task_state(itask) - if (not itask.state.is_runahead) and all(itask.is_ready_to_run()): + if (not itask.state.is_runahead) and itask.is_ready_to_run(): self.queue_task(itask) self.tasks_to_hold.discard((itask.tdef.name, itask.point)) self.workflow_db_mgr.put_tasks_to_hold(self.tasks_to_hold) @@ -1302,8 +1329,8 @@ def hold_tasks(self, items: Iterable[str]) -> int: # Hold active tasks: itasks, inactive_tasks, unmatched = self.filter_task_proxies( items, - warn=False, - future=True, + warn_no_active=False, + inactive=True, ) for itask in itasks: self.hold_active_task(itask) @@ -1320,8 +1347,8 @@ def release_held_tasks(self, items: Iterable[str]) -> int: # Release active tasks: itasks, inactive_tasks, unmatched = self.filter_task_proxies( items, - warn=False, - future=True, + warn_no_active=False, + inactive=True, ) for itask in itasks: self.release_held_active_task(itask) @@ -1400,12 +1427,7 @@ def spawn_on_output(self, itask: TaskProxy, output: str) -> None: str(itask.point), itask.tdef.name, output) self.workflow_db_mgr.process_queued_ops() - c_taskid = Tokens( - cycle=str(c_point), - task=c_name, - ).relative_id - - c_task = self._get_task_by_id(c_taskid) + c_task = self._get_task_by_id(quick_relative_id(c_point, c_name)) in_pool = c_task is not None if c_task is not None and c_task != itask: @@ -1422,7 +1444,7 @@ def spawn_on_output(self, itask: TaskProxy, output: str) -> None: if is_abs: tasks, *_ = self.filter_task_proxies( [f'*/{c_name}'], - warn=False, + warn_no_active=False, ) if c_task not in tasks: tasks.append(c_task) @@ -1654,6 +1676,7 @@ def _load_historical_outputs(self, itask: 'TaskProxy') -> None: else: flow_seen = False for outputs_str, fnums in info.items(): + # (if fnums is empty, it means the 'none' flow) if itask.flow_nums.intersection(fnums): # DB row has overlap with itask's flows flow_seen = True @@ -1721,12 +1744,10 @@ def spawn_task( and not itask.state.outputs.get_completed_outputs() ): # If itask has any history in this flow but no completed outputs - # we can infer it was deliberately removed, so don't respawn it. + # we can infer it has just been deliberately removed (N.B. not + # by `cylc remove`), so don't immediately respawn it. # TODO (follow-up work): # - this logic fails if task removed after some outputs completed - # - this is does not conform to future "cylc remove" flow-erasure - # behaviour which would result in respawning of the removed task - # See github.com/cylc/cylc-flow/pull/6186/#discussion_r1669727292 LOG.debug(f"Not respawning {point}/{name} - task was removed") return None @@ -1741,7 +1762,7 @@ def spawn_task( msg += " incomplete" LOG.info( - f"{msg} {stringify_flow_nums(flow_nums, full=True)})" + f"{msg} {repr_flow_nums(flow_nums, full=True)})" ) if prev_flow_wait: self._spawn_after_flow_wait(itask) @@ -1826,9 +1847,6 @@ def _get_task_proxy_db_outputs( self.xtrigger_mgr.xtriggers.sequential_xtrigger_labels ), ) - if itask is None: - return None - # Update it with outputs that were already completed. self._load_historical_outputs(itask) return itask @@ -1934,8 +1952,8 @@ def set_prereqs_and_outputs( # Get matching pool tasks and inactive task definitions. itasks, inactive_tasks, unmatched = self.filter_task_proxies( items, - future=True, - warn=False, + inactive=True, + warn_no_active=False, ) flow_nums = self._get_flow_nums(flow, flow_descr) @@ -1945,7 +1963,7 @@ def set_prereqs_and_outputs( if flow == ['none'] and itask.flow_nums != set(): LOG.error( f"[{itask}] ignoring 'flow=none' set: task already has" - f" {stringify_flow_nums(itask.flow_nums, full=True)}" + f" {repr_flow_nums(itask.flow_nums, full=True)}" ) continue self.merge_flows(itask, flow_nums) @@ -2038,7 +2056,7 @@ def _set_prereqs_itask( # Attempt to set the given presrequisites. # Log any that aren't valid for the task. presus = self._standardise_prereqs(prereqs) - unmatched = itask.satisfy_me(presus.keys()) + unmatched = itask.satisfy_me(presus.keys(), forced=True) for task_msg in unmatched: LOG.warning( f"{itask.identity} does not depend on" @@ -2083,17 +2101,6 @@ def _get_active_flow_nums(self) -> 'FlowNums': or {1} ) - def remove_tasks(self, items): - """Remove tasks from the pool (forced by command).""" - itasks, _, bad_items = self.filter_task_proxies(items) - for itask in itasks: - # Spawn next occurrence of xtrigger sequential task. - self.check_spawn_psx_task(itask) - self.remove(itask, 'request') - if self.compute_runahead(): - self.release_runahead_tasks() - return len(bad_items) - def _get_flow_nums( self, flow: List[str], @@ -2168,8 +2175,8 @@ def force_trigger_tasks( """ # Get matching tasks proxies, and matching inactive task IDs. - existing_tasks, future_ids, unmatched = self.filter_task_proxies( - items, future=True, warn=False, + existing_tasks, inactive_ids, unmatched = self.filter_task_proxies( + items, inactive=True, warn_no_active=False, ) flow_nums = self._get_flow_nums(flow, flow_descr) @@ -2179,7 +2186,7 @@ def force_trigger_tasks( if flow == ['none'] and itask.flow_nums != set(): LOG.error( f"[{itask}] ignoring 'flow=none' trigger: task already has" - f" {stringify_flow_nums(itask.flow_nums, full=True)}" + f" {repr_flow_nums(itask.flow_nums, full=True)}" ) continue if itask.state(TASK_STATUS_PREPARING, *TASK_STATUSES_ACTIVE): @@ -2192,7 +2199,7 @@ def force_trigger_tasks( if not flow: # default: assign to all active flows flow_nums = self._get_active_flow_nums() - for name, point in future_ids: + for name, point in inactive_ids: if not self.can_be_spawned(name, point): continue submit_num, _, prev_fwait = ( @@ -2318,41 +2325,41 @@ def log_task_pool(self, log_lvl=logging.DEBUG): def filter_task_proxies( self, ids: Iterable[str], - warn: bool = True, - future: bool = False, + warn_no_active: bool = True, + inactive: bool = False, ) -> 'Tuple[List[TaskProxy], Set[Tuple[str, PointBase]], List[str]]': """Return task proxies that match names, points, states in items. Args: ids: ID strings. - warn: - Whether to log a warning if no matching tasks are found. - future: + warn_no_active: + Whether to log a warning if no matching active tasks are found. + inactive: If True, unmatched IDs will be checked against taskdefs - and cycle, task pairs will be provided in the future_matched - argument providing the ID + and cycle, and any matches will be returned in the second + return value, provided that the ID: * Specifies a cycle point. * Is not a pattern. (e.g. `*/foo`). * Does not contain a state selector (e.g. `:failed`). Returns: - (matched, future_matched, unmatched) + (matched, inactive_matched, unmatched) """ matched, unmatched = filter_ids( self.active_tasks, ids, - warn=warn, + warn=warn_no_active, ) - future_matched: 'Set[Tuple[str, PointBase]]' = set() - if future and unmatched: - future_matched, unmatched = self.match_inactive_tasks( + inactive_matched: 'Set[Tuple[str, PointBase]]' = set() + if inactive and unmatched: + inactive_matched, unmatched = self.match_inactive_tasks( unmatched ) - return matched, future_matched, unmatched + return matched, inactive_matched, unmatched def match_inactive_tasks( self, diff --git a/cylc/flow/task_proxy.py b/cylc/flow/task_proxy.py index 5e88b19f891..4d98609df39 100644 --- a/cylc/flow/task_proxy.py +++ b/cylc/flow/task_proxy.py @@ -17,7 +17,6 @@ """Provide a class to represent a task proxy in a running workflow.""" from collections import Counter -from copy import copy from fnmatch import fnmatchcase from time import time from typing import ( @@ -30,33 +29,36 @@ List, Optional, Set, - Tuple, ) from metomi.isodatetime.timezone import get_local_time_zone from cylc.flow import LOG -from cylc.flow.flow_mgr import stringify_flow_nums +from cylc.flow.cycling.iso8601 import ( + interval_parse, + point_parse, +) +from cylc.flow.flow_mgr import repr_flow_nums from cylc.flow.platforms import get_platform from cylc.flow.run_modes import RunMode from cylc.flow.task_action_timer import TimerFlags from cylc.flow.task_state import ( - TaskState, - TASK_STATUS_WAITING, TASK_STATUS_EXPIRED, + TASK_STATUS_WAITING, + TaskState, ) from cylc.flow.taskdef import generate_graph_children from cylc.flow.wallclock import get_unix_time_from_time_string as str2time -from cylc.flow.cycling.iso8601 import ( - point_parse, - interval_parse, -) + if TYPE_CHECKING: from cylc.flow.cycling import PointBase from cylc.flow.flow_mgr import FlowNums from cylc.flow.id import Tokens - from cylc.flow.prerequisite import PrereqMessage, SatisfiedState + from cylc.flow.prerequisite import ( + PrereqTuple, + SatisfiedState, + ) from cylc.flow.run_modes.simulation import ModeSettings from cylc.flow.task_action_timer import TaskActionTimer from cylc.flow.taskdef import TaskDef @@ -149,7 +151,7 @@ class TaskProxy: .graph_children (dict) graph children: {msg: [(name, point), ...]} .flow_nums: - flows I belong to + flows I belong to (if empty, belongs to 'none' flow) flow_wait: wait for flow merge before spawning children .waiting_on_job_prep: @@ -161,6 +163,9 @@ class TaskProxy: .is_xtrigger_sequential: A flag used to determine whether this task needs to wait for xtrigger satisfaction to spawn. + .removed: + A flag to indicate this task has been removed by command (used + e.g. to disable failed/submit-failed event handlers). Args: tdef: The definition object of this task. @@ -175,7 +180,7 @@ class TaskProxy: """ # Memory optimization - constrain possible attributes to this list. - __slots__ = [ + __slots__ = ( 'clock_trigger_times', 'expire_time', 'identity', @@ -206,14 +211,15 @@ class TaskProxy: 'mode_settings', 'transient', 'is_xtrigger_sequential', - ] + 'removed', + ) def __init__( self, scheduler_tokens: 'Tokens', tdef: 'TaskDef', start_point: 'PointBase', - flow_nums: Optional[Set[int]] = None, + flow_nums: Optional['FlowNums'] = None, status: str = TASK_STATUS_WAITING, is_held: bool = False, submit_num: int = 0, @@ -234,7 +240,7 @@ def __init__( self.flow_nums = set() else: # (don't share flow_nums ref with parent task) - self.flow_nums = copy(flow_nums) + self.flow_nums = flow_nums.copy() self.flow_wait = flow_wait self.point = start_point self.tokens = scheduler_tokens.duplicate( @@ -281,6 +287,7 @@ def __init__( self.late_time: Optional[float] = None self.is_late = is_late self.waiting_on_job_prep = False + self.removed: bool = False self.state = TaskState(tdef, self.point, status, is_held) @@ -309,7 +316,7 @@ def __init__( ) def __repr__(self) -> str: - return f"<{self.__class__.__name__} '{self.tokens}'>" + return f"<{self.__class__.__name__} {self.identity}>" def __str__(self) -> str: """Stringify with tokens, state, submit_num, and flow_nums. @@ -320,11 +327,11 @@ def __str__(self) -> str: """ id_ = self.identity if self.transient: - return f"{id_}{stringify_flow_nums(self.flow_nums)}" + return f"{id_}{repr_flow_nums(self.flow_nums)}" if not self.state(TASK_STATUS_WAITING, TASK_STATUS_EXPIRED): id_ += f"/{self.submit_num:02d}" return ( - f"{id_}{stringify_flow_nums(self.flow_nums)}:{self.state}" + f"{id_}{repr_flow_nums(self.flow_nums)}:{self.state}" ) def copy_to_reload_successor( @@ -355,7 +362,7 @@ def copy_to_reload_successor( # pre-reload state of prerequisites that still exist post-reload. # Get all prereq states, e.g. {('1', 'c', 'succeeded'): False, ...} - pre_reload: Dict[PrereqMessage, SatisfiedState] = { + pre_reload: Dict[PrereqTuple, SatisfiedState] = { k: v for pre in self.state.prerequisites for (k, v) in pre.items() @@ -457,7 +464,7 @@ def next_point(self): """Return the next cycle point.""" return self.tdef.next_point(self.point) - def is_ready_to_run(self) -> Tuple[bool, ...]: + def is_ready_to_run(self) -> bool: """Is this task ready to run? Takes account of all dependence: on other tasks, xtriggers, and @@ -466,16 +473,18 @@ def is_ready_to_run(self) -> Tuple[bool, ...]: """ if self.is_manual_submit: # Manually triggered, ignore unsatisfied prerequisites. - return (True,) + return True if self.state.is_held: # A held task is not ready to run. - return (False,) + return False if self.state.status in self.try_timers: # A try timer is still active. - return (self.try_timers[self.state.status].is_delay_done(),) + return self.try_timers[self.state.status].is_delay_done() return ( - self.state(TASK_STATUS_WAITING), - self.is_waiting_prereqs_done() + self.state(TASK_STATUS_WAITING) + and self.prereqs_are_satisfied() + and self.state.external_triggers_all_satisfied() + and self.state.xtriggers_all_satisfied() ) def set_summary_time(self, event_key, time_str=None): @@ -489,18 +498,9 @@ def set_summary_time(self, event_key, time_str=None): self.summary[event_key + '_time'] = float(str2time(time_str)) self.summary[event_key + '_time_string'] = time_str - def is_task_prereqs_not_done(self): - """Are some task prerequisites not satisfied?""" - return (not all(pre.is_satisfied() - for pre in self.state.prerequisites)) - - def is_waiting_prereqs_done(self): - """Are ALL prerequisites satisfied?""" - return ( - all(pre.is_satisfied() for pre in self.state.prerequisites) - and self.state.external_triggers_all_satisfied() - and self.state.xtriggers_all_satisfied() - ) + def prereqs_are_satisfied(self) -> bool: + """Are all task prerequisites satisfied?""" + return all(pre.is_satisfied() for pre in self.state.prerequisites) def reset_try_timers(self): # unset any retry delay timers @@ -525,6 +525,17 @@ def name_match( match_func(ns, value) for ns in self.tdef.namespace_hierarchy ) + def match_flows(self, flow_nums: 'FlowNums') -> 'FlowNums': + """Return which of the given flow numbers the task belongs to. + + NOTE: If `flow_nums` is empty, it means 'all', whereas + if `self.flow_nums` is empty, it means this task is in the 'none' flow + and will not match. + """ + if not flow_nums or not self.flow_nums: + return self.flow_nums.copy() + return self.flow_nums.intersection(flow_nums) + def merge_flows(self, flow_nums: Set) -> None: """Merge another set of flow_nums with mine.""" self.flow_nums.update(flow_nums) @@ -559,6 +570,7 @@ def satisfy_me( self, task_messages: 'Iterable[Tokens]', mode: Optional[RunMode] = RunMode.LIVE, + forced: bool = False, ) -> 'Set[Tokens]': """Try to satisfy my prerequisites with given output messages. @@ -568,8 +580,7 @@ def satisfy_me( Return a set of unmatched task messages. """ - - used = self.state.satisfy_me(task_messages, mode) + used = self.state.satisfy_me(task_messages, mode=mode, forced=forced) return set(task_messages) - used def clock_expire(self) -> bool: diff --git a/cylc/flow/task_state.py b/cylc/flow/task_state.py index 6314719a25a..5b8023d6464 100644 --- a/cylc/flow/task_state.py +++ b/cylc/flow/task_state.py @@ -19,9 +19,9 @@ from typing import ( TYPE_CHECKING, Dict, - Optional, Iterable, List, + Optional, Set, ) @@ -30,8 +30,8 @@ TASK_OUTPUT_EXPIRED, TASK_OUTPUT_FAILED, TASK_OUTPUT_STARTED, - TASK_OUTPUT_SUBMITTED, TASK_OUTPUT_SUBMIT_FAILED, + TASK_OUTPUT_SUBMITTED, TASK_OUTPUT_SUCCEEDED, TaskOutputs, ) @@ -41,7 +41,7 @@ if TYPE_CHECKING: from cylc.flow.cycling import PointBase from cylc.flow.id import Tokens - from cylc.flow.prerequisite import PrereqMessage + from cylc.flow.prerequisite import PrereqTuple from cylc.flow.run_modes import RunMode from cylc.flow.taskdef import TaskDef @@ -278,33 +278,31 @@ def __str__(self): ret += '(runahead)' return ret + def __repr__(self) -> str: + return f"<{type(self).__name__} {self}>" + def __call__( - self, *status, is_held=None, is_queued=None, is_runahead=None): + self, + *status: Optional[str], + is_held: Optional[bool] = None, + is_queued: Optional[bool] = None, + is_runahead: Optional[bool] = None, + ) -> bool: """Compare task state attributes. Args: - status (str/list/None): - ``str`` - Check if the task status is the same as the one provided - ``list`` - Check if the task status is one of the ones provided - ``None`` - Do not check the task state. - is_held (bool): - ``bool`` - Check the task is_held attribute is the same as provided - ``None`` - Do not check the is_held attribute - is_queued (bool): - ``bool`` - Check the task is_queued attribute is the same as provided - ``None`` - Do not check the is_queued attribute - is_runahead (bool): - ``bool`` - Check the task is_runahead attribute is as provided - ``None`` - Do not check the is_runahead attribute + status: + Check if the task status is one of the ones provided, or + do not check the task state if None. + is_held: + Check the task is_held attribute is the same as provided, or + do not check the is_held attribute if None. + is_queued: + Check the task is_queued attribute is the same as provided, or + do not check the is_queued attribute if None. + is_runahead: + Check the task is_runahead attribute is as provided, or + do not check the is_runahead attribute if None. """ return ( @@ -326,7 +324,8 @@ def __call__( def satisfy_me( self, outputs: Iterable['Tokens'], - mode: "Optional[RunMode]", + mode: 'Optional[RunMode]', + forced: bool = False, ) -> Set['Tokens']: """Try to satisfy my prerequisites with given outputs. @@ -335,7 +334,7 @@ def satisfy_me( valid: Set[Tokens] = set() for prereq in (*self.prerequisites, *self.suicide_prerequisites): valid.update( - prereq.satisfy_me(outputs, mode) + prereq.satisfy_me(outputs, mode=mode, forced=forced) ) return valid @@ -384,7 +383,8 @@ def set_prerequisites_all_satisfied(self): prereq.set_satisfied() def get_resolved_dependencies(self): - """Return a list of dependencies which have been met for this task. + """Return a list of dependencies which have been met for this task + (ignoring the specific output in the depedency). E.G: ['1/foo', '2/bar'] @@ -393,9 +393,10 @@ def get_resolved_dependencies(self): """ return sorted( - dep + task_output.get_id() for prereq in self.prerequisites - for dep in prereq.get_resolved_dependencies() + for task_output, satisfied in prereq._satisfied.items() + if satisfied ) def reset( @@ -498,7 +499,7 @@ def _add_prerequisites(self, point: 'PointBase', tdef: 'TaskDef'): cpre[(p_prev, tdef.name, TASK_STATUS_SUCCEEDED)] = ( p_prev < tdef.start_point ) - cpre.set_condition(tdef.name) + cpre.set_conditional_expr(tdef.name) prerequisites[cpre.instantaneous_hash()] = cpre self.suicide_prerequisites = list(suicide_prerequisites.values()) @@ -523,9 +524,18 @@ def _add_xtriggers(self, point, tdef): for xtrig_label in xtrig_labels: self.add_xtrigger(xtrig_label) - def get_unsatisfied_prerequisites(self) -> List['PrereqMessage']: + def get_unsatisfied_prerequisites(self) -> List['PrereqTuple']: return [ key for prereq in self.prerequisites if not prereq.is_satisfied() for key, satisfied in prereq.items() if not satisfied ] + + def any_satisfied_prerequisite_outputs(self) -> bool: + """Return True if any of this task's prerequisite outputs are + satisfied.""" + return any( + satisfied + for prereq in self.prerequisites + for satisfied in prereq._satisfied.values() + ) diff --git a/cylc/flow/task_trigger.py b/cylc/flow/task_trigger.py index 820b5eb8a18..16a31db566b 100644 --- a/cylc/flow/task_trigger.py +++ b/cylc/flow/task_trigger.py @@ -238,7 +238,7 @@ def get_prerequisite( task_trigger.task_name, task_trigger.output, )] = False - cpre.set_condition(self.get_expression(point)) + cpre.set_conditional_expr(self.get_expression(point)) return cpre def get_expression(self, point): diff --git a/cylc/flow/util.py b/cylc/flow/util.py index b7e1e0e0c73..b649265dbd9 100644 --- a/cylc/flow/util.py +++ b/cylc/flow/util.py @@ -17,7 +17,10 @@ import ast from contextlib import suppress -from functools import partial +from functools import ( + lru_cache, + partial, +) import json import re from textwrap import dedent @@ -31,6 +34,7 @@ Tuple, ) + BOOL_SYMBOLS: Dict[bool, str] = { # U+2A2F (vector cross product) False: '⨯', @@ -163,15 +167,23 @@ def serialise_set(flow_nums: Optional[set] = None) -> str: '[]' """ - return json.dumps(sorted(flow_nums or ())) + return _serialise_set(tuple(sorted(flow_nums or ()))) + + +@lru_cache(maxsize=100) +def _serialise_set(flow_nums: tuple) -> str: + return json.dumps(flow_nums) +@lru_cache(maxsize=100) def deserialise_set(flow_num_str: str) -> set: """Convert json string to set. Example: - >>> sorted(deserialise_set('[2, 3]')) - [2, 3] + >>> deserialise_set('[2, 3]') == {2, 3} + True + >>> deserialise_set('[]') + set() """ return set(json.loads(flow_num_str)) diff --git a/cylc/flow/workflow_db_mgr.py b/cylc/flow/workflow_db_mgr.py index b06367e37b3..d9ae87150d8 100644 --- a/cylc/flow/workflow_db_mgr.py +++ b/cylc/flow/workflow_db_mgr.py @@ -23,39 +23,66 @@ * Manage existing run database files on restart. """ +from collections import defaultdict import json import os -from shutil import copy, rmtree +from shutil import ( + copy, + rmtree, +) from sqlite3 import OperationalError from tempfile import mkstemp from typing import ( - Any, AnyStr, Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union + TYPE_CHECKING, + Any, + AnyStr, + DefaultDict, + Dict, + List, + Optional, + Set, + Tuple, + Union, ) from packaging.version import parse as parse_version -from cylc.flow import LOG +from cylc.flow import ( + LOG, + __version__ as CYLC_VERSION, +) from cylc.flow.broadcast_report import get_broadcast_change_iter +from cylc.flow.exceptions import ( + CylcError, + ServiceFileError, +) from cylc.flow.rundb import CylcWorkflowDAO -from cylc.flow import __version__ as CYLC_VERSION -from cylc.flow.wallclock import get_current_time_string, get_utc_mode -from cylc.flow.exceptions import CylcError, ServiceFileError -from cylc.flow.util import serialise_set, deserialise_set +from cylc.flow.util import ( + deserialise_set, + serialise_set, +) +from cylc.flow.wallclock import ( + get_current_time_string, + get_utc_mode, +) + if TYPE_CHECKING: from pathlib import Path + + from packaging.version import Version + from cylc.flow.cycling import PointBase + from cylc.flow.flow_mgr import FlowNums + from cylc.flow.rundb import ( + DbArgDict, + DbUpdateTuple, + ) from cylc.flow.scheduler import Scheduler - from cylc.flow.task_pool import TaskPool from cylc.flow.task_events_mgr import EventKey + from cylc.flow.task_pool import TaskPool from cylc.flow.task_proxy import TaskProxy -Version = Any -# TODO: narrow down Any (should be str | int) after implementing type -# annotations in cylc.flow.task_state.TaskState -DbArgDict = Dict[str, Any] -DbUpdateTuple = Tuple[DbArgDict, DbArgDict] - PERM_PRIVATE = 0o600 # -rw------- @@ -141,7 +168,9 @@ def __init__(self, pri_d=None, pub_d=None): self.TABLE_TASKS_TO_HOLD: [], self.TABLE_XTRIGGERS: [], self.TABLE_ABS_OUTPUTS: []} - self.db_updates_map: Dict[str, List[DbUpdateTuple]] = {} + self.db_updates_map: DefaultDict[ + str, List[DbUpdateTuple] + ] = defaultdict(list) def copy_pri_to_pub(self) -> None: """Copy content of primary database file to public database file.""" @@ -232,29 +261,23 @@ def process_queued_ops(self) -> None: # Record workflow parameters and tasks in pool # Record any broadcast settings to be dumped out if any(self.db_deletes_map.values()): - for table_name, db_deletes in sorted( - self.db_deletes_map.items()): + for table_name, db_deletes in sorted(self.db_deletes_map.items()): while db_deletes: where_args = db_deletes.pop(0) self.pri_dao.add_delete_item(table_name, where_args) self.pub_dao.add_delete_item(table_name, where_args) if any(self.db_inserts_map.values()): - for table_name, db_inserts in sorted( - self.db_inserts_map.items()): + for table_name, db_inserts in sorted(self.db_inserts_map.items()): while db_inserts: db_insert = db_inserts.pop(0) self.pri_dao.add_insert_item(table_name, db_insert) self.pub_dao.add_insert_item(table_name, db_insert) - if (hasattr(self, 'db_updates_map') and - any(self.db_updates_map.values())): - for table_name, db_updates in sorted( - self.db_updates_map.items()): + if any(self.db_updates_map.values()): + for table_name, db_updates in sorted(self.db_updates_map.items()): while db_updates: - set_args, where_args = db_updates.pop(0) - self.pri_dao.add_update_item( - table_name, set_args, where_args) - self.pub_dao.add_update_item( - table_name, set_args, where_args) + db_update = db_updates.pop(0) + self.pri_dao.add_update_item(table_name, db_update) + self.pub_dao.add_update_item(table_name, db_update) # Previously, we used a separate thread for database writes. This has # now been removed. For the private database, there is no real @@ -426,7 +449,7 @@ def put_xtriggers(self, sat_xtrig): "signature": sig, "results": json.dumps(res)}) - def put_update_task_state(self, itask): + def put_update_task_state(self, itask: 'TaskProxy') -> None: """Update task_states table for current state of itask. NOTE the task_states table is normally updated along with the task pool @@ -447,9 +470,9 @@ def put_update_task_state(self, itask): "name": itask.tdef.name, "flow_nums": serialise_set(itask.flow_nums), } - self.db_updates_map.setdefault(self.TABLE_TASK_STATES, []) self.db_updates_map[self.TABLE_TASK_STATES].append( - (set_args, where_args)) + (set_args, where_args) + ) def put_update_task_flow_wait(self, itask): """Update flow_wait status of a task, in the task_states table. @@ -467,7 +490,6 @@ def put_update_task_flow_wait(self, itask): "name": itask.tdef.name, "flow_nums": serialise_set(itask.flow_nums), } - self.db_updates_map.setdefault(self.TABLE_TASK_STATES, []) self.db_updates_map[self.TABLE_TASK_STATES].append( (set_args, where_args)) @@ -495,7 +517,6 @@ def put_task_pool(self, pool: 'TaskPool') -> None: prereq.items() ): self.put_insert_task_prerequisites(itask, { - "flow_nums": serialise_set(itask.flow_nums), "prereq_name": p_name, "prereq_cycle": p_cycle, "prereq_output": p_output, @@ -551,7 +572,6 @@ def put_task_pool(self, pool: 'TaskPool') -> None: "name": itask.tdef.name, "flow_nums": serialise_set(itask.flow_nums) } - self.db_updates_map.setdefault(self.TABLE_TASK_STATES, []) self.db_updates_map[self.TABLE_TASK_STATES].append( (set_args, where_args) ) @@ -585,12 +605,25 @@ def put_insert_task_jobs(self, itask, args): """Put INSERT statement for task_jobs table.""" self._put_insert_task_x(CylcWorkflowDAO.TABLE_TASK_JOBS, itask, args) - def put_insert_task_states(self, itask, args): + def put_insert_task_states(self, itask: 'TaskProxy') -> None: """Put INSERT statement for task_states table.""" - self._put_insert_task_x(CylcWorkflowDAO.TABLE_TASK_STATES, itask, args) + now = get_current_time_string() + self._put_insert_task_x( + CylcWorkflowDAO.TABLE_TASK_STATES, + itask, + { + "time_created": now, + "time_updated": now, + "status": itask.state.status, + "flow_nums": serialise_set(itask.flow_nums), + "flow_wait": itask.flow_wait, + "is_manual_submit": itask.is_manual_submit, + }, + ) def put_insert_task_prerequisites(self, itask, args): """Put INSERT statement for task_prerequisites table.""" + args.setdefault("flow_nums", serialise_set(itask.flow_nums)) self._put_insert_task_x(self.TABLE_TASK_PREREQUISITES, itask, args) def put_insert_task_outputs(self, itask): @@ -627,20 +660,23 @@ def put_insert_workflow_flows(self, flow_num, flow_metadata): } ) - def _put_insert_task_x(self, table_name, itask, args): + def _put_insert_task_x( + self, table_name: str, itask: 'TaskProxy', args: 'DbArgDict' + ) -> None: """Put INSERT statement for a task_* table.""" args.update({ "name": itask.tdef.name, - "cycle": str(itask.point)}) - if "submit_num" not in args: - args["submit_num"] = itask.submit_num - self.db_inserts_map.setdefault(table_name, []) - self.db_inserts_map[table_name].append(args) + "cycle": str(itask.point), + }) + args.setdefault("submit_num", itask.submit_num) + self.db_inserts_map.setdefault(table_name, []).append(args) - def put_update_task_jobs(self, itask, set_args): + def put_update_task_jobs(self, itask: 'TaskProxy', set_args: dict) -> None: """Put UPDATE statement for task_jobs table.""" + set_args.setdefault('flow_nums', serialise_set(itask.flow_nums)) self._put_update_task_x( - CylcWorkflowDAO.TABLE_TASK_JOBS, itask, set_args) + CylcWorkflowDAO.TABLE_TASK_JOBS, itask, set_args + ) def put_update_task_outputs(self, itask: 'TaskProxy') -> None: """Put UPDATE statement for task_outputs table.""" @@ -654,22 +690,107 @@ def put_update_task_outputs(self, itask: 'TaskProxy') -> None: "name": itask.tdef.name, "flow_nums": serialise_set(itask.flow_nums), } - self.db_updates_map.setdefault(self.TABLE_TASK_OUTPUTS, []).append( + self.db_updates_map[self.TABLE_TASK_OUTPUTS].append( (set_args, where_args) ) - def _put_update_task_x(self, table_name, itask, set_args): + def _put_update_task_x( + self, table_name: str, itask: 'TaskProxy', set_args: 'DbArgDict' + ) -> None: """Put UPDATE statement for a task_* table.""" where_args = { "cycle": str(itask.point), - "name": itask.tdef.name} + "name": itask.tdef.name, + } if "submit_num" not in set_args: where_args["submit_num"] = itask.submit_num if "flow_nums" not in set_args: where_args["flow_nums"] = serialise_set(itask.flow_nums) - self.db_updates_map.setdefault(table_name, []) self.db_updates_map[table_name].append((set_args, where_args)) + def remove_task_from_flows( + self, point: str, name: str, flow_nums: 'FlowNums' + ) -> 'FlowNums': + """Remove flow numbers for a task in the task_states and task_outputs + tables. + + Args: + point: Cycle point of the task. + name: Name of the task. + flow_nums: Flow numbers to remove. If empty, remove all + flow numbers. + + Returns the flow numbers that were removed, if any. + + N.B. the task_prerequisites table is automatically updated separately + during the main loop. + """ + removed_flow_nums: FlowNums = set() + for table in ( + self.TABLE_TASK_STATES, + self.TABLE_TASK_OUTPUTS, + ): + fnums_select_stmt = rf''' + SELECT + flow_nums + FROM + {table} + WHERE + cycle = ? + AND name = ? + ''' # nosec B608 (table name is a code constant) + fnums_select_cursor = self.pri_dao.connect().execute( + fnums_select_stmt, (point, name) + ) + + if not flow_nums: + for db_fnums_str, *_ in fnums_select_cursor: + removed_flow_nums.update(deserialise_set(db_fnums_str)) + + stmt = rf''' + UPDATE OR REPLACE + {table} + SET + flow_nums = ? + WHERE + cycle = ? + AND name = ? + ''' # nosec B608 (table name is a code constant) + params: List[tuple] = [(serialise_set(), point, name)] + else: + # Mapping of existing flow nums to what should be left after + # removing the specified flow nums: + flow_nums_map: Dict[str, FlowNums] = {} + for db_fnums_str, *_ in fnums_select_cursor: + db_fnums: FlowNums = deserialise_set(db_fnums_str) + fnums_to_remove = db_fnums.intersection(flow_nums) + if fnums_to_remove: + flow_nums_map[db_fnums_str] = db_fnums.difference( + flow_nums + ) + removed_flow_nums.update(fnums_to_remove) + + stmt = rf''' + UPDATE OR REPLACE + {table} + SET + flow_nums = ? + WHERE + cycle = ? + AND name = ? + AND flow_nums = ? + ''' # nosec B608 (table name is a code constant) + params = [ + (serialise_set(new), point, name, old) + for old, new in flow_nums_map.items() + ] + + self.db_updates_map[table].append( + (stmt, params) + ) + + return removed_flow_nums + def recover_pub_from_pri(self): """Recover public database from private database.""" if self.pub_dao.n_tries >= self.pub_dao.MAX_TRIES: @@ -694,7 +815,7 @@ def restart_check(self) -> None: self.process_queued_ops() @classmethod - def _get_last_run_version(cls, pri_dao: CylcWorkflowDAO) -> Version: + def _get_last_run_version(cls, pri_dao: CylcWorkflowDAO) -> 'Version': """Return the version of Cylc this DB was last run with. Args: @@ -710,7 +831,7 @@ def _get_last_run_version(cls, pri_dao: CylcWorkflowDAO) -> Version: {cls.TABLE_WORKFLOW_PARAMS} WHERE key == ? - ''', # nosec (table name is a code constant) + ''', # nosec B608 (table name is a code constant) [cls.KEY_CYLC_VERSION] ).fetchone()[0] except (TypeError, OperationalError) as exc: @@ -785,7 +906,7 @@ def upgrade(cls, db_file: Union['Path', str]) -> None: cls.upgrade_pre_810(pri_dao) @classmethod - def check_db_compatibility(cls, db_file: Union['Path', str]) -> Version: + def check_db_compatibility(cls, db_file: Union['Path', str]) -> 'Version': """Check this DB is compatible with this Cylc version. Raises: diff --git a/tests/conftest.py b/tests/conftest.py index d07f788ba0b..8e6b988d5c2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -99,31 +99,30 @@ def _inner(cached=False): @pytest.fixture -def log_filter(): - """Filter caplog record_tuples. +def log_filter(caplog: pytest.LogCaptureFixture): + """Filter caplog record_tuples (also discarding the log name entry). Args: - log: The caplog instance. - name: Filter out records if they don't match this logger name. level: Filter out records if they aren't at this logging level. contains: Filter out records if this string is not in the message. regex: Filter out records if the message doesn't match this regex. exact_match: Filter out records if the message does not exactly match this string. + log: A caplog instance. """ def _log_filter( - log: pytest.LogCaptureFixture, - name: Optional[str] = None, level: Optional[int] = None, contains: Optional[str] = None, regex: Optional[str] = None, exact_match: Optional[str] = None, - ) -> List[Tuple[str, int, str]]: + log: Optional[pytest.LogCaptureFixture] = None + ) -> List[Tuple[int, str]]: + if log is None: + log = caplog return [ - (log_name, log_level, log_message) - for log_name, log_level, log_message in log.record_tuples - if (name is None or name == log_name) - and (level is None or level == log_level) + (log_level, log_message) + for _, log_level, log_message in log.record_tuples + if (level is None or level == log_level) and (contains is None or contains in log_message) and (regex is None or re.search(regex, log_message)) and (exact_match is None or exact_match == log_message) diff --git a/tests/functional/cylc-remove/00-simple/flow.cylc b/tests/functional/cylc-remove/00-simple/flow.cylc index 84c740ad421..0ee53fafed6 100644 --- a/tests/functional/cylc-remove/00-simple/flow.cylc +++ b/tests/functional/cylc-remove/00-simple/flow.cylc @@ -1,13 +1,15 @@ # Abort on stall timeout unless we remove unhandled failed and waiting task. [scheduler] [[events]] - stall timeout = PT20S + stall timeout = PT30S abort on stall timeout = True expected task failures = 1/b [scheduling] [[graph]] - R1 = """a => b => c - cleaner""" + R1 = """ + a => b => c + cleaner + """ [runtime] [[a,c]] script = true @@ -15,10 +17,10 @@ script = false [[cleaner]] script = """ -cylc__job__poll_grep_workflow_log -E '1/b/01:running.* \(received\)failed' -# Remove the unhandled failed task -cylc remove "$CYLC_WORKFLOW_ID//1/b" -# Remove waiting 1/c -# (not auto-removed because parent 1/b, an unhandled fail, is not finished.) -cylc remove "$CYLC_WORKFLOW_ID//1/c:waiting" -""" + cylc__job__poll_grep_workflow_log -E '1/b/01:running.* \(received\)failed' + # Remove the unhandled failed task + cylc remove "$CYLC_WORKFLOW_ID//1/b" + # Remove waiting 1/c + # (not auto-removed because parent 1/b, an unhandled fail, is not finished.) + cylc remove "$CYLC_WORKFLOW_ID//1/c:waiting" + """ diff --git a/tests/functional/cylc-remove/02-cycling/flow.cylc b/tests/functional/cylc-remove/02-cycling/flow.cylc index 3b6c1051493..249f0676cc4 100644 --- a/tests/functional/cylc-remove/02-cycling/flow.cylc +++ b/tests/functional/cylc-remove/02-cycling/flow.cylc @@ -28,18 +28,6 @@ [[foo, waz]] script = true [[bar]] - script = """ - if [[ $CYLC_TASK_CYCLE_POINT == 2020 ]]; then - false - else - true - fi - """ + script = [[ $CYLC_TASK_CYCLE_POINT != 2020 ]] [[baz]] - script = """ - if [[ $CYLC_TASK_CYCLE_POINT == 2021 ]]; then - false - else - true - fi - """ + script = [[ $CYLC_TASK_CYCLE_POINT != 2021 ]] diff --git a/tests/functional/cylc-remove/03-flow.t b/tests/functional/cylc-remove/03-flow.t new file mode 100644 index 00000000000..9c0d84ecfd8 --- /dev/null +++ b/tests/functional/cylc-remove/03-flow.t @@ -0,0 +1,50 @@ +#!/usr/bin/env bash +# THIS FILE IS PART OF THE CYLC WORKFLOW ENGINE. +# Copyright (C) NIWA & British Crown (Met Office) & Contributors. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +# High-level test of `cylc remove --flow` option. +# Integration tests exist for more comprehensive coverage. + +. "$(dirname "$0")/test_header" +set_test_number 6 + +init_workflow "${TEST_NAME_BASE}" <<'__EOF__' +[scheduler] + allow implicit tasks = True +[scheduling] + [[graph]] + R1 = foo +__EOF__ + +run_ok "${TEST_NAME_BASE}-validate" cylc validate "${WORKFLOW_NAME}" + +workflow_run_ok "${TEST_NAME_BASE}-run" cylc play "${WORKFLOW_NAME}" --pause + +run_ok "${TEST_NAME_BASE}-remove" cylc remove "${WORKFLOW_NAME}//1/foo" --flow 1 --flow 2 + +cylc stop "${WORKFLOW_NAME}" +poll_workflow_stopped + +grep_workflow_log_ok "${TEST_NAME_BASE}-grep" "Removed task(s): 1/foo (flows=1)" + +# Simple additional test of DB: +TEST_NAME="${TEST_NAME_BASE}-workflow-state" +run_ok "$TEST_NAME" cylc workflow-state "$WORKFLOW_NAME" +cmp_ok "${TEST_NAME}.stdout" <<__EOF__ +1/foo:waiting(flows=none) +__EOF__ + +purge diff --git a/tests/functional/cylc-remove/04-kill.t b/tests/functional/cylc-remove/04-kill.t new file mode 100644 index 00000000000..6b644529c2c --- /dev/null +++ b/tests/functional/cylc-remove/04-kill.t @@ -0,0 +1,65 @@ +#!/usr/bin/env bash +# THIS FILE IS PART OF THE CYLC WORKFLOW ENGINE. +# Copyright (C) NIWA & British Crown (Met Office) & Contributors. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +# Test that removing submited/running tasks causes them to be killed. +# Any downstream tasks that depend on the `:submit-fail`/`:fail` outputs +# should NOT run. +# Handlers for the `submission failed`/`failed` events should not run either. + +export REQUIRE_PLATFORM='runner:at' +. "$(dirname "$0")/test_header" +set_test_number 10 + +# Create platform that ensures job b will be in submitted state for long enough +create_test_global_config '' " +[platforms] + [[old_street]] + job runner = at + job runner command template = at now + 5 minutes + hosts = localhost + install target = localhost +" + +install_and_validate +reftest_run + +grep_workflow_log_ok "${TEST_NAME_BASE}-grep-a" \ + "[1/a/01(flows=none):failed(held)] job killed" -F + +J_LOG_A="${WORKFLOW_RUN_DIR}/log/job/1/a/NN/job-activity.log" +# Failed handler should not run: +grep_fail "[(('event-handler-00', 'failed'), 1) out]" "$J_LOG_A" -F +# (Check submitted handler as a control): +grep_ok "[(('event-handler-00', 'submitted'), 1) out]" "$J_LOG_A" -F + +grep_workflow_log_ok "${TEST_NAME_BASE}-grep-b" \ + "[1/b/01(flows=none):submit-failed(held)] job killed" -F + +J_LOG_B="${WORKFLOW_RUN_DIR}/log/job/1/b/NN/job-activity.log" +grep_fail "[(('event-handler-00', 'submission failed'), 1) out]" "$J_LOG_B" -F +grep_ok "[(('event-handler-00', 'submitted'), 1) out]" "$J_LOG_B" -F + +# Check task state updated in DB despite removal from task pool: +sqlite3 "${WORKFLOW_RUN_DIR}/.service/db" \ + "SELECT status, flow_nums FROM task_states WHERE name='a';" > task_states.out +cmp_ok task_states.out - <<< "failed|[]" +# Check job updated in DB: +sqlite3 "${WORKFLOW_RUN_DIR}/.service/db" \ + "SELECT run_status, time_run_exit FROM task_jobs WHERE cycle='1' AND name='a';" > task_jobs.out +cmp_ok_re task_jobs.out - <<< "1\|[\w:+-]+" + +purge diff --git a/tests/functional/cylc-remove/04-kill/flow.cylc b/tests/functional/cylc-remove/04-kill/flow.cylc new file mode 100644 index 00000000000..88d2e6f51a7 --- /dev/null +++ b/tests/functional/cylc-remove/04-kill/flow.cylc @@ -0,0 +1,31 @@ +[scheduler] + allow implicit tasks = True + [[events]] + expected task failures = 1/a, 1/b + stall timeout = PT0S + abort on stall timeout = True +[scheduling] + [[graph]] + R1 = """ + a:started => remover + a:failed => u + + b:submitted? => remover + b:submit-failed? => v + """ + +[runtime] + [[a, b]] + script = sleep 40 + [[[events]]] + submitted handlers = echo %(event)s + failed handlers = echo %(event)s + submission failed handlers = echo %(event)s + [[b]] + platform = old_street + [[remover]] + script = """ + cylc remove "$CYLC_WORKFLOW_ID//1/a" "$CYLC_WORKFLOW_ID//1/b" + cylc__job__poll_grep_workflow_log -E '1\/a.* => failed' + cylc__job__poll_grep_workflow_log -E '1\/b.* => submit-failed' + """ diff --git a/tests/functional/cylc-remove/04-kill/reference.log b/tests/functional/cylc-remove/04-kill/reference.log new file mode 100644 index 00000000000..379728f01bc --- /dev/null +++ b/tests/functional/cylc-remove/04-kill/reference.log @@ -0,0 +1,5 @@ +Initial point: 1 +Final point: 1 +1/a -triggered off [] +1/b -triggered off [] +1/remover -triggered off ['1/a', '1/b'] diff --git a/tests/functional/cylc-set/06-parentless/flow.cylc b/tests/functional/cylc-set/06-parentless/flow.cylc index 5078b84e484..3ad945cc7bb 100644 --- a/tests/functional/cylc-set/06-parentless/flow.cylc +++ b/tests/functional/cylc-set/06-parentless/flow.cylc @@ -1,15 +1,11 @@ # Start this with --start-task=1800/a. -# It should stall because x => b is off-flow. -# The stall handler should unstall it by spawning x. +# Task a's script should spawn x. # The log should show a clock-trigger check before x runs. [scheduler] [[events]] inactivity timeout = PT30S abort on inactivity timeout = True - stall timeout = PT10S - abort on stall timeout = True - stall handlers = "cylc set --pre=all %(workflow)s//1800/x" [scheduling] initial cycle point = 1800 @@ -19,4 +15,6 @@ @wall_clock => x => b """ [runtime] - [[a, b, x]] + [[a]] + script = cylc set --pre=all "${CYLC_WORKFLOW_ID}//1800/x" + [[b, x]] diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 1bd06697cd0..6300cefc74e 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -18,23 +18,39 @@ import asyncio from functools import partial from pathlib import Path -import pytest +import re from shutil import rmtree from time import time -from typing import List, TYPE_CHECKING, Set, Tuple, Union +from typing import ( + TYPE_CHECKING, + List, + Set, + Tuple, + Union, +) + +import pytest from cylc.flow.config import WorkflowConfig from cylc.flow.id import Tokens +from cylc.flow.network.client import WorkflowRuntimeClient from cylc.flow.option_parsers import Options from cylc.flow.pathutil import get_cylc_run_dir -from cylc.flow.rundb import CylcWorkflowDAO from cylc.flow.run_modes import RunMode -from cylc.flow.scripts.validate import ValidateOptions +from cylc.flow.rundb import CylcWorkflowDAO from cylc.flow.scripts.install import ( + get_option_parser as install_gop, install as cylc_install, - get_option_parser as install_gop ) -from cylc.flow.task_state import TASK_STATUS_SUBMITTED, TASK_STATUS_SUCCEEDED +from cylc.flow.scripts.show import ( + ShowOptions, + prereqs_and_outputs_query, +) +from cylc.flow.scripts.validate import ValidateOptions +from cylc.flow.task_state import ( + TASK_STATUS_SUBMITTED, + TASK_STATUS_SUCCEEDED, +) from cylc.flow.util import serialise_set from cylc.flow.wallclock import get_current_time_string from cylc.flow.workflow_files import infer_latest_run_from_id @@ -43,15 +59,14 @@ from .utils import _rm_if_empty from .utils.flow_tools import ( _make_flow, - _make_src_flow, _make_scheduler, + _make_src_flow, _run_flow, _start_flow, ) if TYPE_CHECKING: - from cylc.flow.network.client import WorkflowRuntimeClient from cylc.flow.scheduler import Scheduler from cylc.flow.task_proxy import TaskProxy @@ -119,7 +134,11 @@ def ses_test_dir(request, run_dir): @pytest.fixture(scope='module') def mod_test_dir(request, ses_test_dir): """The root run dir for test flows in this test module.""" - path = Path(ses_test_dir, request.module.__name__) + path = Path( + ses_test_dir, + # Shorten path by dropping `integration.` prefix: + re.sub(r'^integration\.', '', request.module.__name__) + ) path.mkdir(exist_ok=True) yield path if _pytest_passed(request): @@ -513,6 +532,10 @@ def reflog(): Note, you'll need to call this on the scheduler *after* you have started it. + N.B. Trigger order is not stable; using a set ensures that tests check + trigger logic rather than binding to specific trigger order which could + change in the future, breaking the test. + Args: schd: The scheduler to capture triggering information for. @@ -591,6 +614,9 @@ async def _complete( async_timeout (handles shutdown logic more cleanly). """ + if schd.is_paused: + raise Exception("Cannot wait for completion of a paused scheduler") + start_time = time() tokens_list: List[Tokens] = [] @@ -625,11 +651,16 @@ def _set_stop(mode=None): # determine the completion condition def done(): if wait_tokens: - return not tokens_list + if not tokens_list: + return True + if not schd.contact_data: + raise AssertionError( + "Scheduler shut down before tasks completed: " + + ", ".join(map(str, tokens_list)) + ) + return False # otherwise wait for the scheduler to shut down - if not schd.contact_data: - return True - return stop_requested + return stop_requested or not schd.contact_data with pytest.MonkeyPatch.context() as mp: mp.setattr(schd.pool, 'remove_if_complete', _remove_if_complete) @@ -677,6 +708,26 @@ async def _reftest( return _reftest +@pytest.fixture +def cylc_show(): + """Fixture that runs `cylc show` on a scheduler, returning JSON object.""" + + async def _cylc_show(schd: 'Scheduler', *task_ids: str) -> dict: + pclient = WorkflowRuntimeClient(schd.workflow) + await schd.update_data_structure() + json_filter: dict = {} + await prereqs_and_outputs_query( + schd.id, + [Tokens(id_, relative=True) for id_ in task_ids], + pclient, + ShowOptions(json=True), + json_filter, + ) + return json_filter + + return _cylc_show + + @pytest.fixture def capture_live_submissions(capcall, monkeypatch): """Capture live submission attempts. diff --git a/tests/integration/events/test_task_events.py b/tests/integration/events/test_task_events.py index 3ae30c1fe73..81bbfd4316e 100644 --- a/tests/integration/events/test_task_events.py +++ b/tests/integration/events/test_task_events.py @@ -52,7 +52,7 @@ async def test_mail_footer_template( # start the workflow and get it to send an email ctx = SimpleNamespace(mail_to=None, mail_from=None) id_keys = [EventKey('none', 'failed', 'failed', Tokens('//1/a'))] - async with start(mod_one) as one_log: + async with start(mod_one): mod_one.task_events_mgr._process_event_email(mod_one, ctx, id_keys) # warnings should appear only when the template is invalid @@ -60,11 +60,9 @@ async def test_mail_footer_template( # check that template issues are handled correctly assert bool(log_filter( - one_log, contains='Ignoring bad mail footer template', )) == should_log assert bool(log_filter( - one_log, contains=template, )) == should_log diff --git a/tests/integration/events/test_workflow_events.py b/tests/integration/events/test_workflow_events.py index 6b742264636..4569cbaed8b 100644 --- a/tests/integration/events/test_workflow_events.py +++ b/tests/integration/events/test_workflow_events.py @@ -69,11 +69,9 @@ async def test_mail_footer_template( # check that template issues are handled correctly assert bool(log_filter( - one_log, contains='Ignoring bad mail footer template', )) == should_log assert bool(log_filter( - one_log, contains=template, )) == should_log @@ -114,10 +112,8 @@ async def test_custom_event_handler_template( # check that template issues are handled correctly assert bool(log_filter( - one_log, contains='bad template', )) == should_log assert bool(log_filter( - one_log, contains=template, )) == should_log diff --git a/tests/integration/main_loop/test_auto_restart.py b/tests/integration/main_loop/test_auto_restart.py index 20cc4ea81c6..e03c97485a5 100644 --- a/tests/integration/main_loop/test_auto_restart.py +++ b/tests/integration/main_loop/test_auto_restart.py @@ -45,6 +45,6 @@ async def test_no_detach( id_: str = flow(one_conf) schd: Scheduler = scheduler(id_, paused_start=True, no_detach=True) with pytest.raises(MainLoopPluginException) as exc: - async with run(schd) as log: + async with run(schd): await asyncio.sleep(2) - assert log_filter(log, contains=f"Workflow shutting down - {exc.value}") + assert log_filter(contains=f"Workflow shutting down - {exc.value}") diff --git a/tests/integration/run_modes/test_simulation.py b/tests/integration/run_modes/test_simulation.py index 100a30c46b1..72cbd7e10f1 100644 --- a/tests/integration/run_modes/test_simulation.py +++ b/tests/integration/run_modes/test_simulation.py @@ -235,9 +235,7 @@ def test_task_sped_up(sim_time_check_setup, monkeytime): assert result is True -async def test_settings_restart( - monkeytime, flow, scheduler, start,validate -): +async def test_settings_restart(monkeytime, flow, scheduler, start): """Check that simulation mode settings are correctly restored upon restart. @@ -312,7 +310,7 @@ async def test_settings_restart( ) is False # Check that the itask.mode_settings is now re-created - + assert itask.mode_settings.simulated_run_length == 60.0 assert itask.mode_settings.sim_task_fails is True @@ -356,7 +354,7 @@ async def test_settings_reload( conf_file.read_text().replace('False', 'True')) # Reload Workflow: - await commands.run_cmd(commands.reload_workflow, schd) + await commands.run_cmd(commands.reload_workflow(schd)) # Submit second psuedo-job and "run" to success: itask = run_simjob(schd, one_1066.point, 'one') diff --git a/tests/integration/run_modes/test_skip.py b/tests/integration/run_modes/test_skip.py index afb4b91707e..fb58a82d427 100644 --- a/tests/integration/run_modes/test_skip.py +++ b/tests/integration/run_modes/test_skip.py @@ -177,7 +177,7 @@ async def test_doesnt_release_held_tasks( """ one_conf['runtime'] = {'one': {'run mode': 'skip'}} schd = scheduler(flow(one_conf), run_mode='live', paused_start=False) - async with start(schd) as log: + async with start(schd): itask, = schd.pool.get_tasks() msg = 'held tasks shoudn\'t {}' @@ -187,16 +187,15 @@ async def test_doesnt_release_held_tasks( # Relinquish contol to the main loop. schd.release_queued_tasks() - assert not log_filter(log, contains='=> running'), msg.format('run') - assert not log_filter(log, contains='=> succeeded'), msg.format( - 'succeed') + assert not log_filter(contains='=> running'), msg.format('run') + assert not log_filter(contains='=> succeeded'), msg.format('succeed') # Release held task and assert that it now skips successfully: schd.pool.release_held_tasks(['1/one']) schd.release_queued_tasks() - assert log_filter(log, contains='=> running'), msg.format('run') - assert log_filter(log, contains='=> succeeded'), msg.format('succeed') + assert log_filter(contains='=> running'), msg.format('run') + assert log_filter(contains='=> succeeded'), msg.format('succeed') async def test_prereqs_marked_satisfied_by_skip_mode( diff --git a/tests/integration/scripts/test_set.py b/tests/integration/scripts/test_set.py index bff8878754e..41fcb7e3bfe 100644 --- a/tests/integration/scripts/test_set.py +++ b/tests/integration/scripts/test_set.py @@ -47,7 +47,7 @@ async def test_set_parentless_spawning( 'graph': {'P1': 'a => z'}, }, }) - schd = scheduler(id_, paused_start=False) + schd: Scheduler = scheduler(id_, paused_start=False) async with run(schd): # mark cycle 1 as succeeded schd.pool.set_prereqs_and_outputs( @@ -55,9 +55,7 @@ async def test_set_parentless_spawning( ) # the parentless task "a" should be spawned out to the runahead limit - assert [ - itask.identity for itask in schd.pool.get_tasks() - ] == ['2/a', '3/a'] + assert schd.pool.get_task_ids() == {'2/a', '3/a'} # the workflow should run on to the next cycle await complete(schd, '2/a', timeout=5) @@ -151,9 +149,9 @@ async def test_incomplete_detection( ): """It should detect and log finished tasks left with incomplete outputs.""" schd = scheduler(flow(one_conf)) - async with start(schd) as log: + async with start(schd): schd.pool.set_prereqs_and_outputs(['1/one'], ['failed'], None, ['1']) - assert log_filter(log, contains='1/one did not complete') + assert log_filter(contains='1/one did not complete') async def test_pre_all(flow, scheduler, run): diff --git a/tests/integration/scripts/test_validate_integration.py b/tests/integration/scripts/test_validate_integration.py index 093d70d899c..4adabf7995f 100644 --- a/tests/integration/scripts/test_validate_integration.py +++ b/tests/integration/scripts/test_validate_integration.py @@ -174,8 +174,8 @@ def test_graph_upgrade_msg_default(flow, validate, caplog, log_filter): }, }) validate(id_) - assert log_filter(caplog, contains='[scheduling][dependencies][X]graph') - assert log_filter(caplog, contains='for X in:\n P1Y, R1') + assert log_filter(contains='[scheduling][dependencies][X]graph') + assert log_filter(contains='for X in:\n P1Y, R1') def test_graph_upgrade_msg_graph_equals(flow, validate, caplog, log_filter): @@ -192,8 +192,8 @@ def test_graph_upgrade_msg_graph_equals(flow, validate, caplog, log_filter): }) validate(id_) assert log_filter( - caplog, - contains='[scheduling][dependencies]graph -> [scheduling][graph]R1') + contains='[scheduling][dependencies]graph -> [scheduling][graph]R1' + ) def test_graph_upgrade_msg_graph_equals2(flow, validate, caplog, log_filter): @@ -216,7 +216,7 @@ def test_graph_upgrade_msg_graph_equals2(flow, validate, caplog, log_filter): '\n P1Y, graph' '\n ([scheduling][dependencies]graph moves to [scheduling][graph]R1)' ) - assert log_filter(caplog, contains=expect) + assert log_filter(contains=expect) def test_undefined_parent(flow, validate): @@ -248,5 +248,5 @@ def test_log_parent_demoted(flow, validate, monkeypatch, caplog, log_filter): } }) validate(id_) - assert log_filter(caplog, contains='First parent(s) demoted to secondary') - assert log_filter(caplog, contains="FOO as parent of 'foo'") + assert log_filter(contains='First parent(s) demoted to secondary') + assert log_filter(contains="FOO as parent of 'foo'") diff --git a/tests/integration/test_compat_mode.py b/tests/integration/test_compat_mode.py index f54d5fee1fb..8c8c02cd15e 100644 --- a/tests/integration/test_compat_mode.py +++ b/tests/integration/test_compat_mode.py @@ -59,7 +59,7 @@ async def test_blocked_tasks_in_n0(flow, scheduler, run, complete): assert schd.is_stalled # the "blocked" recover tasks should remain in the pool - assert {t.identity for t in schd.pool.get_tasks()} == { + assert schd.pool.get_task_ids() == { '1/recover', '2/recover', '3/recover', @@ -92,7 +92,7 @@ async def test_blocked_tasks_in_n0(flow, scheduler, run, complete): for cycle in range(1, 4): itask = schd.pool.get_task(IntegerPoint(str(cycle)), 'recover') schd.pool.remove(itask, 'suicide-trigger') - assert {t.identity for t in schd.pool.get_tasks()} == { + assert schd.pool.get_task_ids() == { '4/foo', '5/foo', '6/foo', diff --git a/tests/integration/test_config.py b/tests/integration/test_config.py index e21b77e4add..660d03eb290 100644 --- a/tests/integration/test_config.py +++ b/tests/integration/test_config.py @@ -153,7 +153,6 @@ def test_validate_param_env_templ( one_conf, validate, env_val, - caplog, log_filter, ): """It should validate parameter environment templates.""" @@ -168,8 +167,8 @@ def test_validate_param_env_templ( } }) validate(id_) - assert log_filter(caplog, contains='bad parameter environment template') - assert log_filter(caplog, contains=env_val) + assert log_filter(contains='bad parameter environment template') + assert log_filter(contains=env_val) def test_no_graph(flow, validate): @@ -291,9 +290,7 @@ def test_queue_treated_as_implicit(flow, validate, caplog, log_filter): } ) validate(id_) - assert log_filter( - caplog, - contains='Queues contain tasks not defined in runtime') + assert log_filter(contains='Queues contain tasks not defined in runtime') def test_queue_treated_as_comma_separated(flow, validate): @@ -619,7 +616,7 @@ def test_nonlive_mode_validation(flow, validate, caplog, log_filter): }) validate(wid) - assert log_filter(caplog, contains=msg1) + assert log_filter(contains=msg1) def test_skip_forbidden_as_output(flow, validate): diff --git a/tests/integration/test_data_store_mgr.py b/tests/integration/test_data_store_mgr.py index 88dd79163bb..5c42f9d352a 100644 --- a/tests/integration/test_data_store_mgr.py +++ b/tests/integration/test_data_store_mgr.py @@ -324,7 +324,7 @@ def test_delta_task_prerequisite(harness): schd: Scheduler schd, data = harness schd.pool.set_prereqs_and_outputs( - [t.identity for t in schd.pool.get_tasks()], + schd.pool.get_task_ids(), [(TASK_STATUS_SUCCEEDED,)], [], flow=[] @@ -333,9 +333,8 @@ def test_delta_task_prerequisite(harness): for itask in schd.pool.get_tasks(): # set prereqs as not-satisfied for prereq in itask.state.prerequisites: - prereq._all_satisfied = False for key in prereq: - prereq._satisfied[key] = False + prereq[key] = False schd.data_store_mgr.delta_task_prerequisite(itask) assert not any(p.satisfied for p in get_pb_prereqs(schd)) diff --git a/tests/integration/test_examples.py b/tests/integration/test_examples.py index dc3495fe39f..a0d15ee7289 100644 --- a/tests/integration/test_examples.py +++ b/tests/integration/test_examples.py @@ -23,9 +23,11 @@ import asyncio import logging from pathlib import Path + import pytest from cylc.flow import __version__ +from cylc.flow.scheduler import Scheduler async def test_create_flow(flow, run_dir): @@ -62,9 +64,9 @@ async def test_logging(flow, scheduler, start, one_conf, log_filter): # Ensure that the cylc version is logged on startup. id_ = flow(one_conf) schd = scheduler(id_) - async with start(schd) as log: + async with start(schd): # this returns a list of log records containing __version__ - assert log_filter(log, contains=__version__) + assert log_filter(contains=__version__) async def test_scheduler_arguments(flow, scheduler, start, one_conf): @@ -159,16 +161,12 @@ def killer(): # make sure that this error causes the flow to shutdown with pytest.raises(MyException): - async with run(one) as log: + async with run(one): # The `run` fixture's shutdown logic waits for the main loop to run pass # make sure the exception was logged - assert len(log_filter( - log, - level=logging.CRITICAL, - contains='mess' - )) == 1 + assert len(log_filter(logging.CRITICAL, contains='mess')) == 1 # make sure the server socket has closed - a good indication of a # successful clean shutdown @@ -290,3 +288,11 @@ async def test_reftest(flow, scheduler, reftest): ('1/a', None), ('1/b', ('1/a',)), } + + +async def test_show(one: Scheduler, start, cylc_show): + """Demonstrate the `cylc_show` fixture""" + async with start(one): + out = await cylc_show(one, '1/one') + assert list(out.keys()) == ['1/one'] + assert out['1/one']['state'] == 'waiting' diff --git a/tests/integration/test_flow_assignment.py b/tests/integration/test_flow_assignment.py index 6c0c58a8758..ea729efeb7b 100644 --- a/tests/integration/test_flow_assignment.py +++ b/tests/integration/test_flow_assignment.py @@ -27,7 +27,7 @@ FLOW_ALL, FLOW_NEW, FLOW_NONE, - stringify_flow_nums + repr_flow_nums ) from cylc.flow.scheduler import Scheduler @@ -110,7 +110,7 @@ async def test_flow_assignment( } id_ = flow(conf) schd: Scheduler = scheduler(id_, run_mode='simulation', paused_start=True) - async with start(schd) as log: + async with start(schd): if command == 'set': do_command: Callable = functools.partial( schd.pool.set_prereqs_and_outputs, outputs=['x'], prereqs=[] @@ -137,10 +137,9 @@ async def test_flow_assignment( do_command([active_a.identity], flow=[FLOW_NONE]) assert active_a.flow_nums == {1, 2} assert log_filter( - log, contains=( f'[{active_a}] ignoring \'flow=none\' {command}: ' - f'task already has {stringify_flow_nums(active_a.flow_nums)}' + f'task already has {repr_flow_nums(active_a.flow_nums)}' ), level=logging.ERROR ) diff --git a/tests/integration/test_job_runner_mgr.py b/tests/integration/test_job_runner_mgr.py index 93663aec892..39db087f7b9 100644 --- a/tests/integration/test_job_runner_mgr.py +++ b/tests/integration/test_job_runner_mgr.py @@ -28,7 +28,7 @@ async def test_kill_error(one, start, test_dir, capsys, log_filter): """It should report the failure to kill a job.""" - async with start(one) as log: + async with start(one): # make it look like the task is running itask = one.pool.get_tasks()[0] itask.submit_num += 1 @@ -78,7 +78,6 @@ async def test_kill_error(one, start, test_dir, capsys, log_filter): # a warning should be logged assert log_filter( - log, regex=r'1/one/01:running.*job kill failed', level=logging.WARNING, ) diff --git a/tests/integration/test_queues.py b/tests/integration/test_queues.py index fc94c4c4a3d..7da83e1a1aa 100644 --- a/tests/integration/test_queues.py +++ b/tests/integration/test_queues.py @@ -120,7 +120,7 @@ async def test_queue_held_tasks( # hold all tasks and resume the workflow # (nothing should have run yet because the workflow started paused) - await commands.run_cmd(commands.hold, schd, ['*/*']) + await commands.run_cmd(commands.hold(schd, ['*/*'])) schd.resume_workflow() # release queued tasks @@ -129,7 +129,7 @@ async def test_queue_held_tasks( assert len(submitted_tasks) == 0 # un-hold tasks - await commands.run_cmd(commands.release, schd, ['*/*']) + await commands.run_cmd(commands.release(schd, ['*/*'])) # release queued tasks # (tasks should now be released from the queues) diff --git a/tests/integration/test_reload.py b/tests/integration/test_reload.py index 65960ffcdb7..ad96b187722 100644 --- a/tests/integration/test_reload.py +++ b/tests/integration/test_reload.py @@ -89,7 +89,7 @@ def change_state(_=0): change_state() # reload the workflow - await commands.run_cmd(commands.reload_workflow, schd) + await commands.run_cmd(commands.reload_workflow(schd)) # the task should end in the submitted state assert foo.state(TASK_STATUS_SUBMITTED) @@ -127,18 +127,17 @@ async def test_reload_failure( """ id_ = flow(one_conf) schd = scheduler(id_) - async with start(schd) as log: + async with start(schd): # corrupt the config by removing the scheduling section two_conf = {**one_conf, 'scheduling': {}} flow(two_conf, id_=id_) # reload the workflow - await commands.run_cmd(commands.reload_workflow, schd) + await commands.run_cmd(commands.reload_workflow(schd)) # the reload should have failed but the workflow should still be # running assert log_filter( - log, contains=( 'Reload failed - WorkflowConfigError:' ' missing [scheduling][[graph]] section' diff --git a/tests/integration/test_remove.py b/tests/integration/test_remove.py new file mode 100644 index 00000000000..aab060a788f --- /dev/null +++ b/tests/integration/test_remove.py @@ -0,0 +1,480 @@ +# THIS FILE IS PART OF THE CYLC WORKFLOW ENGINE. +# Copyright (C) NIWA & British Crown (Met Office) & Contributors. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import logging + +import pytest + +from cylc.flow.commands import ( + force_trigger_tasks, + remove_tasks, + run_cmd, +) +from cylc.flow.cycling.integer import IntegerPoint +from cylc.flow.flow_mgr import FLOW_ALL +from cylc.flow.scheduler import Scheduler +from cylc.flow.task_outputs import TASK_OUTPUT_SUCCEEDED +from cylc.flow.task_proxy import TaskProxy +from cylc.flow.task_state import TASK_STATUS_FAILED + + +@pytest.fixture +async def cylc_show_prereqs(cylc_show): + """Fixture that returns the prereq info from `cylc show` in an + easy-to-use format.""" + async def inner(schd: Scheduler, task: str): + prerequisites = (await cylc_show(schd, task))[task]['prerequisites'] + return [ + ( + p['satisfied'], + {c['taskId']: c['satisfied'] for c in p['conditions']}, + ) + for p in prerequisites + ] + + return inner + + +@pytest.fixture +def example_workflow(flow): + return flow({ + 'scheduling': { + 'graph': { + # Note: test both `&` and separate arrows for combining + # dependencies + 'R1': ''' + a1 & a2 => b + a3 => b + ''', + }, + }, + }) + + +def get_data_store_flow_nums(schd: Scheduler, itask: TaskProxy): + _, ds_tproxy = schd.data_store_mgr.store_node_fetcher(itask.tokens) + if ds_tproxy: + return ds_tproxy.flow_nums + + +async def test_basic( + example_workflow, scheduler, start, db_select +): + """Test removing a task from all flows.""" + schd: Scheduler = scheduler(example_workflow) + async with start(schd): + a1 = schd.pool._get_task_by_id('1/a1') + a3 = schd.pool._get_task_by_id('1/a3') + schd.pool.spawn_on_output(a1, TASK_OUTPUT_SUCCEEDED) + schd.pool.spawn_on_output(a3, TASK_OUTPUT_SUCCEEDED) + await schd.update_data_structure() + + assert a1 in schd.pool.get_tasks() + for table in ('task_states', 'task_outputs'): + assert db_select(schd, True, table, 'flow_nums', name='a1') == [ + ('[1]',), + ] + assert db_select( + schd, True, 'task_prerequisites', 'satisfied', prereq_name='a1' + ) == [ + ('satisfied naturally',), + ] + assert get_data_store_flow_nums(schd, a1) == '[1]' + + await run_cmd(remove_tasks(schd, ['1/a1'], [FLOW_ALL])) + await schd.update_data_structure() + + assert a1 not in schd.pool.get_tasks() # removed from pool + for table in ('task_states', 'task_outputs'): + assert db_select(schd, True, table, 'flow_nums', name='a1') == [ + ('[]',), # removed from all flows + ] + assert db_select( + schd, True, 'task_prerequisites', 'satisfied', prereq_name='a1' + ) == [ + ('0',), # prereq is now unsatisfied + ] + assert get_data_store_flow_nums(schd, a1) == '[]' + + +async def test_specific_flow( + example_workflow, scheduler, start, db_select +): + """Test removing a task from a specific flow.""" + schd: Scheduler = scheduler(example_workflow) + + def select_prereqs(): + return db_select( + schd, + True, + 'task_prerequisites', + 'flow_nums', + 'satisfied', + prereq_name='a1', + ) + + async with start(schd): + a1 = schd.pool._get_task_by_id('1/a1') + schd.pool.force_trigger_tasks(['1/a1'], ['1', '2']) + schd.pool.spawn_on_output(a1, TASK_OUTPUT_SUCCEEDED) + await schd.update_data_structure() + + assert a1 in schd.pool.get_tasks() + assert a1.flow_nums == {1, 2} + for table in ('task_states', 'task_outputs'): + assert sorted( + db_select(schd, True, table, 'flow_nums', name='a1') + ) == [ + ('[1, 2]',), # triggered task + ('[1]',), # original spawned task + ] + assert select_prereqs() == [ + ('[1, 2]', 'satisfied naturally'), + ] + assert get_data_store_flow_nums(schd, a1) == '[1, 2]' + + await run_cmd(remove_tasks(schd, ['1/a1'], ['1'])) + await schd.update_data_structure() + + assert a1 in schd.pool.get_tasks() # still in pool + assert a1.flow_nums == {2} + for table in ('task_states', 'task_outputs'): + assert sorted( + db_select(schd, True, table, 'flow_nums', name='a1') + ) == [ + ('[2]',), + ('[]',), + ] + assert select_prereqs() == [ + ('[1, 2]', '0'), + ] + assert get_data_store_flow_nums(schd, a1) == '[2]' + + +async def test_unset_prereq(example_workflow, scheduler, start): + """Test removing a task unsets any prerequisites it satisfied.""" + schd: Scheduler = scheduler(example_workflow) + async with start(schd): + for task in ('a1', 'a2', 'a3'): + schd.pool.spawn_on_output( + schd.pool.get_task(IntegerPoint('1'), task), + TASK_OUTPUT_SUCCEEDED, + ) + b = schd.pool.get_task(IntegerPoint('1'), 'b') + assert b.prereqs_are_satisfied() + + await run_cmd(remove_tasks(schd, ['1/a1'], [FLOW_ALL])) + + assert not b.prereqs_are_satisfied() + + +async def test_not_unset_prereq( + example_workflow, scheduler, start, db_select +): + """Test removing a task does not unset a force-satisfied prerequisite + (one that was satisfied by `cylc set --pre`).""" + schd: Scheduler = scheduler(example_workflow) + async with start(schd): + # This set prereq should not be unset by removing a1: + schd.pool.set_prereqs_and_outputs( + ['1/b'], outputs=[], prereqs=['1/a1'], flow=[FLOW_ALL] + ) + # Whereas the prereq satisfied by this set output *should* be unset + # by removing a2: + schd.pool.set_prereqs_and_outputs( + ['1/a2'], outputs=['succeeded'], prereqs=[], flow=[FLOW_ALL] + ) + await schd.update_data_structure() + + assert sorted( + db_select( + schd, True, 'task_prerequisites', 'prereq_name', 'satisfied' + ) + ) == [ + ('a1', 'force satisfied'), + ('a2', 'satisfied naturally'), + ('a3', '0'), + ] + + await run_cmd(remove_tasks(schd, ['1/a1', '1/a2'], [FLOW_ALL])) + await schd.update_data_structure() + + assert sorted( + db_select( + schd, True, 'task_prerequisites', 'prereq_name', 'satisfied' + ) + ) == [ + ('a1', 'force satisfied'), + ('a2', '0'), + ('a3', '0'), + ] + + +async def test_nothing_to_do( + example_workflow, scheduler, start, log_filter +): + """Test removing an invalid task.""" + schd: Scheduler = scheduler(example_workflow) + async with start(schd): + await run_cmd(remove_tasks(schd, ['1/doh'], [FLOW_ALL])) + assert log_filter(logging.WARNING, "No matching tasks found: doh") + + +async def test_logging( + flow, scheduler, start, log_filter, caplog: pytest.LogCaptureFixture +): + """Test logging of a mixture of valid and invalid task removals.""" + schd: Scheduler = scheduler( + flow({ + 'scheduler': { + 'cycle point format': 'CCYY', + }, + 'scheduling': { + 'initial cycle point': '2000', + 'graph': { + 'R3//P1Y': 'b[-P1Y] => a & b', + }, + }, + }) + ) + tasks_to_remove = [ + # Active, removable tasks: + '2000/*', + # Future, non-removable tasks: + '2001/a', '2001/b', + # Glob that doesn't match any active tasks: + '2002/*', + # Invalid tasks: + '2005/a', '2000/doh', + ] + async with start(schd): + await run_cmd(remove_tasks(schd, tasks_to_remove, [FLOW_ALL])) + + assert log_filter( + logging.INFO, "Removed task(s): 2000/a (flows=1), 2000/b (flows=1)" + ) + + assert log_filter(logging.WARNING, "Task(s) not removable: 2001/a, 2001/b") + assert log_filter(logging.WARNING, "No active tasks matching: 2002/*") + assert log_filter(logging.WARNING, "Invalid cycle point for task: a, 2005") + assert log_filter(logging.WARNING, "No matching tasks found: doh") + # No tasks were submitted/running so none should have been killed: + assert "job killed" not in caplog.text + + +async def test_logging_flow_nums( + example_workflow, scheduler, start, log_filter +): + """Test logging of task removals involving flow numbers.""" + schd: Scheduler = scheduler(example_workflow) + async with start(schd): + schd.pool.force_trigger_tasks(['1/a1'], ['1', '2']) + # Removing from flow that doesn't exist doesn't work: + await run_cmd(remove_tasks(schd, ['1/a1'], ['3'])) + assert log_filter( + logging.WARNING, "Task(s) not removable: 1/a1 (flows=3)" + ) + + # But if a valid flow is included, it will be removed from that flow: + await run_cmd(remove_tasks(schd, ['1/a1'], ['2', '3'])) + assert log_filter(logging.INFO, "Removed task(s): 1/a1 (flows=2)") + assert schd.pool._get_task_by_id('1/a1').flow_nums == {1} + + +async def test_retrigger(flow, scheduler, run, reflog, complete): + """Test prereqs & re-run behaviour when removing tasks.""" + schd: Scheduler = scheduler( + flow('a => b => c'), + paused_start=False, + ) + async with run(schd): + reflog_triggers: set = reflog(schd) + await complete(schd, '1/b') + + await run_cmd(remove_tasks(schd, ['1/a', '1/b'], [FLOW_ALL])) + schd.process_workflow_db_queue() + # Removing 1/b should un-queue 1/c: + assert len(schd.pool.task_queue_mgr.queues['default'].deque) == 0 + + assert reflog_triggers == { + ('1/a', None), + ('1/b', ('1/a',)), + } + reflog_triggers.clear() + + await run_cmd(force_trigger_tasks(schd, ['1/a'], [])) + await complete(schd) + + assert reflog_triggers == { + ('1/a', None), + # 1/b should have run again after 1/a on the re-trigger in flow 1: + ('1/b', ('1/a',)), + ('1/c', ('1/b',)), + } + + +async def test_prereqs( + flow, scheduler, run, complete, cylc_show_prereqs, log_filter +): + """Test prereqs & stall behaviour when removing tasks.""" + schd: Scheduler = scheduler( + flow('(a1 | a2) & b => x'), + paused_start=False, + ) + async with run(schd): + await complete(schd, '1/a1', '1/a2', '1/b') + + await run_cmd(remove_tasks(schd, ['1/a1'], [FLOW_ALL])) + assert not schd.pool.is_stalled() + assert len(schd.pool.task_queue_mgr.queues['default'].deque) + # `cylc show` should reflect the now-unsatisfied condition: + assert await cylc_show_prereqs(schd, '1/x') == [ + (True, {'1/a1': False, '1/a2': True, '1/b': True}) + ] + + await run_cmd(remove_tasks(schd, ['1/b'], [FLOW_ALL])) + # Should cause stall now because 1/c prereq is unsatisfied: + assert len(schd.pool.task_queue_mgr.queues['default'].deque) == 0 + assert schd.pool.is_stalled() + assert log_filter( + logging.WARNING, + "1/x is waiting on ['1/a1:succeeded', '1/b:succeeded']", + ) + assert await cylc_show_prereqs(schd, '1/x') == [ + (False, {'1/a1': False, '1/a2': True, '1/b': False}) + ] + + assert schd.pool._get_task_by_id('1/x') + await run_cmd(remove_tasks(schd, ['1/a2'], [FLOW_ALL])) + # Should cause 1/x to be removed from the pool as it no longer has + # any satisfied prerequisite tasks: + assert not schd.pool._get_task_by_id('1/x') + assert log_filter( + logging.INFO, + regex=r"1/x.* removed .* prerequisite task\(s\) removed", + ) + + +async def test_downstream_preparing(flow, scheduler, start): + """Downstream dependents should not be removed if they are already + preparing.""" + schd: Scheduler = scheduler( + flow(''' + a => x + a => y + '''), + ) + async with start(schd): + a = schd.pool._get_task_by_id('1/a') + schd.pool.spawn_on_output(a, TASK_OUTPUT_SUCCEEDED) + assert schd.pool.get_task_ids() == {'1/a', '1/x', '1/y'} + + schd.pool._get_task_by_id('1/y').state_reset('preparing') + await run_cmd(remove_tasks(schd, ['1/a'], [FLOW_ALL])) + assert schd.pool.get_task_ids() == {'1/y'} + + +async def test_downstream_other_flows(flow, scheduler, run, complete): + """Downstream dependents should not be removed if they exist in other + flows.""" + schd: Scheduler = scheduler( + flow(''' + a => b => c => x + a => x + '''), + paused_start=False, + ) + async with run(schd): + await complete(schd, '1/a') + schd.pool.force_trigger_tasks(['1/c'], ['2']) + c = schd.pool._get_task_by_id('1/c') + schd.pool.spawn_on_output(c, TASK_OUTPUT_SUCCEEDED) + assert schd.pool._get_task_by_id('1/x').flow_nums == {1, 2} + + await run_cmd(remove_tasks(schd, ['1/c'], ['2'])) + assert schd.pool.get_task_ids() == {'1/b', '1/x'} + # Note: in future we might want to remove 1/x from flow 2 as well, to + # maintain flow continuity. However it is tricky at the moment because + # other prerequisite tasks could exist in flow 2 (we don't know as + # prereqs do not hold flow info other than in the DB). + assert schd.pool._get_task_by_id('1/x').flow_nums == {1, 2} + + +async def test_suicide(flow, scheduler, run, reflog, complete): + """Test that suicide prereqs are unset by `cylc remove`.""" + schd: Scheduler = scheduler( + flow(''' + a => b => c => d => x + a & c => !x + '''), + paused_start=False, + ) + async with run(schd): + reflog_triggers: set = reflog(schd) + await complete(schd, '1/b') + await run_cmd(remove_tasks(schd, ['1/a'], [FLOW_ALL])) + await complete(schd) + + assert reflog_triggers == { + ('1/a', None), + ('1/b', ('1/a',)), + ('1/c', ('1/b',)), + ('1/d', ('1/c',)), + # 1/x not suicided as 1/a was removed: + ('1/x', ('1/d',)), + } + + +async def test_kill_running(flow, scheduler, run, complete, reflog): + """Test removing a running task should kill it. + + Note this only tests simulation mode and a separate test for live mode + exists in tests/functional/cylc-remove. + """ + schd: Scheduler = scheduler( + flow({ + 'scheduling': { + 'graph': { + 'R1': ''' + a:started => b => c + a:failed => q + ''' + }, + }, + 'runtime': { + 'a': { + 'simulation': { + 'default run length': 'PT30S' + }, + }, + }, + }), + paused_start=False, + ) + async with run(schd): + reflog_triggers = reflog(schd) + await complete(schd, '1/b') + a = schd.pool._get_task_by_id('1/a') + await run_cmd(remove_tasks(schd, ['1/a'], [FLOW_ALL])) + assert a.state(TASK_STATUS_FAILED, is_held=True) + await complete(schd) + + assert reflog_triggers == { + ('1/a', None), + ('1/b', ('1/a',)), + ('1/c', ('1/b',)), + # The a:failed output should not cause 1/q to run + } diff --git a/tests/integration/test_resolvers.py b/tests/integration/test_resolvers.py index 981237a4d2a..4fa21dadbb9 100644 --- a/tests/integration/test_resolvers.py +++ b/tests/integration/test_resolvers.py @@ -234,7 +234,7 @@ async def test_command_logging(mock_flow, caplog, log_filter): {'mode': StopMode.REQUEST_CLEAN.value}, meta, ) - assert log_filter(caplog, contains='Command "stop" received') + assert log_filter(contains='Command "stop" received') # put_messages: only log for owner kwargs = { @@ -244,12 +244,11 @@ async def test_command_logging(mock_flow, caplog, log_filter): } meta["auth_user"] = mock_flow.owner await mock_flow.resolvers._mutation_mapper("put_messages", kwargs, meta) - assert not log_filter(caplog, contains='Command "put_messages" received:') + assert not log_filter(contains='Command "put_messages" received:') meta["auth_user"] = "Dr Spock" await mock_flow.resolvers._mutation_mapper("put_messages", kwargs, meta) - assert log_filter( - caplog, contains='Command "put_messages" received from Dr Spock') + assert log_filter(contains='Command "put_messages" received from Dr Spock') async def test_command_validation_failure( diff --git a/tests/integration/test_scheduler.py b/tests/integration/test_scheduler.py index 79b7327206a..6f1f581e899 100644 --- a/tests/integration/test_scheduler.py +++ b/tests/integration/test_scheduler.py @@ -174,7 +174,7 @@ async def test_holding_tasks_whilst_scheduler_paused( assert submitted_tasks == set() # hold all tasks & resume the workflow - await commands.run_cmd(commands.hold, one, ['*/*']) + await commands.run_cmd(commands.hold(one, ['*/*'])) one.resume_workflow() # release queued tasks @@ -183,7 +183,7 @@ async def test_holding_tasks_whilst_scheduler_paused( assert submitted_tasks == set() # release all tasks - await commands.run_cmd(commands.release, one, ['*/*']) + await commands.run_cmd(commands.release(one, ['*/*'])) # release queued tasks # (the task should be submitted) @@ -219,12 +219,12 @@ async def test_no_poll_waiting_tasks( polled_tasks = capture_polling(one) # Waiting tasks should not be polled. - await commands.run_cmd(commands.poll_tasks, one, ['*/*']) + await commands.run_cmd(commands.poll_tasks(one, ['*/*'])) assert polled_tasks == set() # Even if they have a submit number. task.submit_num = 1 - await commands.run_cmd(commands.poll_tasks, one, ['*/*']) + await commands.run_cmd(commands.poll_tasks(one, ['*/*'])) assert len(polled_tasks) == 0 # But these states should be: @@ -235,7 +235,7 @@ async def test_no_poll_waiting_tasks( TASK_STATUS_RUNNING ]: task.state.status = state - await commands.run_cmd(commands.poll_tasks, one, ['*/*']) + await commands.run_cmd(commands.poll_tasks(one, ['*/*'])) assert len(polled_tasks) == 1 polled_tasks.clear() @@ -267,7 +267,7 @@ def raise_ParsecError(*a, **k): pass assert log_filter( - log, level=logging.CRITICAL, + logging.CRITICAL, exact_match="Workflow shutting down - Mock error" ) assert TRACEBACK_MSG in log.text @@ -295,7 +295,7 @@ def mock_auto_restart(*a, **k): async with run(one) as log: pass - assert log_filter(log, level=logging.ERROR, contains=err_msg) + assert log_filter(logging.ERROR, err_msg) assert TRACEBACK_MSG in log.text @@ -361,20 +361,19 @@ async def test_restart_timeout( # restart the completed workflow schd = scheduler(id_) - async with run(schd) as log: + async with run(schd): # it should detect that the workflow has completed and alert the user assert log_filter( - log, contains='This workflow already ran to completion.' ) # it should activate a timeout - assert log_filter(log, contains='restart timer starts NOW') + assert log_filter(contains='restart timer starts NOW') # when we trigger tasks the timeout should be cleared schd.pool.force_trigger_tasks(['1/one'], {1}) await asyncio.sleep(0) # yield control to the main loop - assert log_filter(log, contains='restart timer stopped') + assert log_filter(contains='restart timer stopped') @pytest.mark.parametrize("signal", ((SIGHUP), (SIGINT), (SIGTERM))) @@ -387,14 +386,14 @@ async def test_signal_escallation(one, start, signal, log_filter): See https://github.com/cylc/cylc-flow/pull/6444 """ - async with start(one) as log: + async with start(one): # put the workflow in the stopping state one._set_stop(StopMode.REQUEST_CLEAN) assert one.stop_mode.name == 'REQUEST_CLEAN' # one signal should escalate this from CLEAN to NOW one._handle_signal(signal, None) - assert log_filter(log, contains=signal.name) + assert log_filter(contains=signal.name) assert one.stop_mode.name == 'REQUEST_NOW' # two signals should escalate this from NOW to NOW_NOW diff --git a/tests/integration/test_sequential_xtriggers.py b/tests/integration/test_sequential_xtriggers.py index 4ef563a1c53..3e9eea779c1 100644 --- a/tests/integration/test_sequential_xtriggers.py +++ b/tests/integration/test_sequential_xtriggers.py @@ -49,7 +49,7 @@ def sequential(flow, scheduler): return scheduler(id_) -async def test_remove(sequential, start): +async def test_remove(sequential: Scheduler, start): """It should spawn the next instance when a task is removed. Ensure that removing a task with a sequential xtrigger does not break the @@ -74,7 +74,7 @@ async def test_remove(sequential, start): ] # remove all tasks in the pool - sequential.pool.remove_tasks(['*']) + sequential.remove_tasks(['*']) # the next cycle should be automatically spawned assert list_cycles(sequential) == ['2004'] diff --git a/tests/integration/test_stop_after_cycle_point.py b/tests/integration/test_stop_after_cycle_point.py index f92e8d449f0..90bab288515 100644 --- a/tests/integration/test_stop_after_cycle_point.py +++ b/tests/integration/test_stop_after_cycle_point.py @@ -119,10 +119,11 @@ def get_db_value(schd) -> Optional[str]: # override this value whilst the workflow is running await commands.run_cmd( - commands.stop, - schd, - cycle_point=IntegerPoint('4'), - mode=StopMode.REQUEST_CLEAN, + commands.stop( + schd, + cycle_point=IntegerPoint('4'), + mode=StopMode.REQUEST_CLEAN, + ) ) assert schd.config.stop_point == IntegerPoint('4') diff --git a/tests/integration/test_subprocctx.py b/tests/integration/test_subprocctx.py index c4d0c4ca28a..871106dcd53 100644 --- a/tests/integration/test_subprocctx.py +++ b/tests/integration/test_subprocctx.py @@ -47,7 +47,7 @@ def myxtrigger(): return True, {} """)) schd = scheduler(id_) - async with start(schd, level=DEBUG) as log: + async with start(schd, level=DEBUG): # Set off check for x-trigger: task = schd.pool.get_tasks()[0] schd.xtrigger_mgr.call_xtriggers_async(task) @@ -59,4 +59,4 @@ def myxtrigger(): # Assert that both stderr and out from the print statement # in our xtrigger appear in the log. for expected in ['Hello World', 'Hello Hades']: - assert log_filter(log, contains=expected, level=DEBUG) + assert log_filter(DEBUG, expected) diff --git a/tests/integration/test_task_job_mgr.py b/tests/integration/test_task_job_mgr.py index 48a49eb30aa..b1cf1347071 100644 --- a/tests/integration/test_task_job_mgr.py +++ b/tests/integration/test_task_job_mgr.py @@ -23,7 +23,6 @@ from cylc.flow.task_state import TASK_STATUS_RUNNING - async def test_run_job_cmd_no_hosts_error( flow, scheduler, @@ -92,7 +91,6 @@ async def test_run_job_cmd_no_hosts_error( # ...but the failure should be logged assert log_filter( - log, contains='No available hosts for no-host-platform', ) log.clear() @@ -105,7 +103,6 @@ async def test_run_job_cmd_no_hosts_error( # ...but the failure should be logged assert log_filter( - log, contains='No available hosts for no-host-platform', ) @@ -217,7 +214,7 @@ async def test_broadcast_platform_change( schd = scheduler(id_, run_mode='live') - async with start(schd) as log: + async with start(schd): # Change the task platform with broadcast: schd.broadcast_mgr.put_broadcast( ['1'], ['mytask'], [{'platform': 'foo'}]) @@ -235,4 +232,4 @@ async def test_broadcast_platform_change( # Check that task platform hasn't become "localhost": assert schd.pool.get_tasks()[0].platform['name'] == 'foo' # ... and that remote init failed because all hosts bad: - assert log_filter(log, contains="(no hosts were reachable)") + assert log_filter(regex=r"platform: foo .*\(no hosts were reachable\)") diff --git a/tests/integration/test_task_pool.py b/tests/integration/test_task_pool.py index e4aa6602f4d..404ec8da87f 100644 --- a/tests/integration/test_task_pool.py +++ b/tests/integration/test_task_pool.py @@ -63,9 +63,6 @@ # immediately too, because we spawn autospawn absolute-triggered tasks as # well as parentless tasks. 3/asd does not spawn at start, however. EXAMPLE_FLOW_CFG = { - 'scheduler': { - 'allow implicit tasks': True - }, 'scheduling': { 'cycling mode': 'integer', 'initial cycle point': 1, @@ -86,7 +83,6 @@ EXAMPLE_FLOW_2_CFG = { 'scheduler': { - 'allow implicit tasks': True, 'UTC mode': True }, 'scheduling': { @@ -142,7 +138,7 @@ def assert_expected_log( @pytest.fixture(scope='module') async def mod_example_flow( mod_flow: Callable, mod_scheduler: Callable, mod_run: Callable -) -> 'Scheduler': +) -> AsyncGenerator['Scheduler', None]: """Return a scheduler for interrogating its task pool. This is module-scoped so faster than example_flow, but should only be used @@ -178,7 +174,7 @@ async def example_flow( @pytest.fixture(scope='module') async def mod_example_flow_2( mod_flow: Callable, mod_scheduler: Callable, mod_run: Callable -) -> 'Scheduler': +) -> AsyncGenerator['Scheduler', None]: """Return a scheduler for interrogating its task pool. This is module-scoped so faster than example_flow, but should only be used @@ -570,7 +566,7 @@ async def test_reload_stopcp( schd: 'Scheduler' = scheduler(flow(cfg)) async with start(schd): assert str(schd.pool.stop_point) == '2020' - await commands.run_cmd(commands.reload_workflow, schd) + await commands.run_cmd(commands.reload_workflow(schd)) assert str(schd.pool.stop_point) == '2020' @@ -584,11 +580,11 @@ async def test_runahead_after_remove( assert int(task_pool.runahead_limit_point) == 4 # No change after removing an intermediate cycle. - task_pool.remove_tasks(['3/*']) + example_flow.remove_tasks(['3/*']) assert int(task_pool.runahead_limit_point) == 4 # Should update after removing the first point. - task_pool.remove_tasks(['1/*']) + example_flow.remove_tasks(['1/*']) assert int(task_pool.runahead_limit_point) == 5 @@ -841,7 +837,7 @@ async def test_reload_prereqs( flow(conf, id_=id_) # Reload the workflow config - await commands.run_cmd(commands.reload_workflow, schd) + await commands.run_cmd(commands.reload_workflow(schd)) assert list_tasks(schd) == expected_3 # Check resulting dependencies of task z @@ -973,7 +969,7 @@ async def test_graph_change_prereq_satisfaction( flow(conf, id_=id_) # Reload the workflow config - await commands.run_cmd(commands.reload_workflow, schd) + await commands.run_cmd(commands.reload_workflow(schd)) await test.asend(schd) @@ -1183,9 +1179,6 @@ async def test_detect_incomplete_tasks( TASK_STATUS_SUBMIT_FAILED: TaskEventsManager.EVENT_SUBMIT_FAILED } id_ = flow({ - 'scheduler': { - 'allow implicit tasks': 'True', - }, 'scheduling': { 'graph': { # a workflow with one task for each of the final task states @@ -1206,7 +1199,6 @@ async def test_detect_incomplete_tasks( # ensure that it is correctly identified as incomplete assert not itask.state.outputs.is_complete() assert log_filter( - log, contains=( f"[{itask}] did not complete the required outputs:" ), @@ -1228,9 +1220,6 @@ async def test_future_trigger_final_point( """ id_ = flow( { - 'scheduler': { - 'allow implicit tasks': 'True', - }, 'scheduling': { 'cycling mode': 'integer', 'initial cycle point': 1, @@ -1246,7 +1235,6 @@ async def test_future_trigger_final_point( for itask in schd.pool.get_tasks(): schd.pool.spawn_on_output(itask, "succeeded") assert log_filter( - log, regex=( ".*1/baz.*not spawned: a prerequisite is beyond" r" the workflow stop point \(1\)" @@ -1271,17 +1259,17 @@ async def test_set_failed_complete( schd.pool.task_events_mgr.process_message(one, 1, "failed") assert log_filter( - log, regex="1/one.* setting implied output: submitted") + regex="1/one.* setting implied output: submitted") assert log_filter( - log, regex="1/one.* setting implied output: started") + regex="1/one.* setting implied output: started") assert log_filter( - log, regex="failed.* did not complete the required outputs") + regex="failed.* did not complete the required outputs") # Set failed task complete via default "set" args. schd.pool.set_prereqs_and_outputs([one.identity], None, None, ['all']) assert log_filter( - log, contains=f'[{one}] removed from active task pool: completed') + contains=f'[{one}] removed from active task pool: completed') db_outputs = db_select( schd, True, 'task_outputs', 'outputs', @@ -1305,9 +1293,6 @@ async def test_set_prereqs( """ id_ = flow( { - 'scheduler': { - 'allow implicit tasks': 'True', - }, 'scheduling': { 'initial cycle point': '2040', 'graph': { @@ -1339,7 +1324,8 @@ async def test_set_prereqs( schd.pool.set_prereqs_and_outputs( ["20400101T0000Z/qux"], None, ["20400101T0000Z/foo:a"], ['all']) assert log_filter( - log, contains='20400101T0000Z/qux does not depend on "20400101T0000Z/foo:a"') + contains='20400101T0000Z/qux does not depend on "20400101T0000Z/foo:a"' + ) # it should not add 20400101T0000Z/qux to the pool assert ( @@ -1390,7 +1376,6 @@ async def test_set_bad_prereqs( """ id_ = flow({ 'scheduler': { - 'allow implicit tasks': 'True', 'cycle point format': '%Y'}, 'scheduling': { 'initial cycle point': '2040', @@ -1406,11 +1391,11 @@ def set_prereqs(prereqs): async with start(schd) as log: # Invalid: task name wildcard: set_prereqs(["2040/*"]) - assert log_filter(log, contains='Invalid prerequisite task name') + assert log_filter(contains='Invalid prerequisite task name') # Invalid: cycle point wildcard. set_prereqs(["*/foo"]) - assert log_filter(log, contains='Invalid prerequisite cycle point') + assert log_filter(contains='Invalid prerequisite cycle point') async def test_set_outputs_live( @@ -1424,9 +1409,6 @@ async def test_set_outputs_live( """ id_ = flow( { - 'scheduler': { - 'allow implicit tasks': 'True', - }, 'scheduling': { 'graph': { 'R1': """ @@ -1476,15 +1458,12 @@ async def test_set_outputs_live( ) # it should complete implied outputs (submitted, started) too - assert log_filter( - log, contains="setting implied output: submitted") - assert log_filter( - log, contains="setting implied output: started") + assert log_filter(contains="setting implied output: submitted") + assert log_filter(contains="setting implied output: started") # set foo (default: all required outputs) to complete y. schd.pool.set_prereqs_and_outputs(["1/foo"], None, None, ['all']) - assert log_filter( - log, contains="output 1/foo:succeeded completed") + assert log_filter(contains="output 1/foo:succeeded completed") assert ( pool_get_task_ids(schd.pool) == ["1/bar", "1/baz"] ) @@ -1501,7 +1480,6 @@ async def test_set_outputs_live2( """ id_ = flow( { - 'scheduler': {'allow implicit tasks': 'True'}, 'scheduling': {'graph': { 'R1': """ foo:a => apple @@ -1517,7 +1495,6 @@ async def test_set_outputs_live2( async with start(schd) as log: schd.pool.set_prereqs_and_outputs(["1/foo"], None, None, ['all']) assert not log_filter( - log, contains="did not complete required outputs: ['a', 'b']" ) @@ -1533,9 +1510,6 @@ async def test_set_outputs_future( """ id_ = flow( { - 'scheduler': { - 'allow implicit tasks': 'True', - }, 'scheduling': { 'graph': { 'R1': "a:x & a:y => b => c" @@ -1571,9 +1545,9 @@ async def test_set_outputs_future( prereqs=None, flow=['all'] ) - assert log_filter(log, contains="output 1/a:cheese not found") - assert log_filter(log, contains="completed output x") - assert log_filter(log, contains="completed output y") + assert log_filter(contains="output 1/a:cheese not found") + assert log_filter(contains="completed output x") + assert log_filter(contains="completed output y") async def test_set_outputs_from_skip_settings( @@ -1617,7 +1591,7 @@ async def test_set_outputs_from_skip_settings( validate(id_) schd = scheduler(id_) - async with start(schd) as log: + async with start(schd): # it should start up with just tasks a: assert pool_get_task_ids(schd.pool) == ['1/a', '2/a'] @@ -1632,7 +1606,7 @@ async def test_set_outputs_from_skip_settings( # Check that the presence of "skip" in outputs doesn't # trigger a warning: - assert not log_filter(log, level=30) + assert not log_filter(level=30) # You should be able to set skip as part of a list of outputs: schd.pool.set_prereqs_and_outputs( @@ -1656,9 +1630,6 @@ async def test_prereq_satisfaction( """ id_ = flow( { - 'scheduler': { - 'allow implicit tasks': 'True', - }, 'scheduling': { 'graph': { 'R1': "a:x & a:y => b" @@ -1674,8 +1645,8 @@ async def test_prereq_satisfaction( } } ) - schd = scheduler(id_) - async with start(schd) as log: + schd: Scheduler = scheduler(id_) + async with start(schd): # it should start up with just 1/a assert pool_get_task_ids(schd.pool) == ["1/a"] # spawn b @@ -1686,21 +1657,19 @@ async def test_prereq_satisfaction( b = schd.pool.get_task(IntegerPoint("1"), "b") - assert not b.is_waiting_prereqs_done() + assert not b.prereqs_are_satisfied() # set valid and invalid prerequisites, by label and message. schd.pool.set_prereqs_and_outputs( prereqs=["1/a:xylophone", "1/a:y", "1/a:w", "1/a:z"], items=["1/b"], outputs=None, flow=['all'] ) - assert log_filter(log, contains="1/a:z not found") - assert log_filter(log, contains="1/a:w not found") - assert not log_filter(log, contains='1/b does not depend on "1/a:x"') - assert not log_filter( - log, contains='1/b does not depend on "1/a:xylophone"') - assert not log_filter(log, contains='1/b does not depend on "1/a:y"') + assert log_filter(contains="1/a:z not found") + assert log_filter(contains="1/a:w not found") + # FIXME: testing that something is *not* logged is extremely fragile: + assert not log_filter(regex='.*does not depend on.*') - assert b.is_waiting_prereqs_done() + assert b.prereqs_are_satisfied() @pytest.mark.parametrize('compat_mode', ['compat-mode', 'normal-mode']) @@ -1979,7 +1948,6 @@ async def test_fast_respawn( async def test_remove_active_task( example_flow: 'Scheduler', - caplog: pytest.LogCaptureFixture, log_filter: Callable, ) -> None: """Test warning on removing an active task.""" @@ -1994,7 +1962,6 @@ async def test_remove_active_task( assert foo not in task_pool.get_tasks() assert log_filter( - caplog, regex=( "1/foo.*removed from active task pool:" " request - active job orphaned" @@ -2016,7 +1983,6 @@ async def test_remove_by_suicide( * Removing a task manually (cylc remove) should work the same. """ id_ = flow({ - 'scheduler': {'allow implicit tasks': 'True'}, 'scheduling': { 'graph': { 'R1': ''' @@ -2035,7 +2001,6 @@ async def test_remove_by_suicide( # mark 1/a as failed and ensure 1/b is removed by suicide trigger schd.pool.spawn_on_output(a, TASK_OUTPUT_FAILED) assert log_filter( - log, regex="1/b.*removed from active task pool: suicide trigger" ) assert pool_get_task_ids(schd.pool) == ["1/a"] @@ -2044,14 +2009,14 @@ async def test_remove_by_suicide( log.clear() schd.pool.force_trigger_tasks(['1/b'], ['1']) assert log_filter( - log, regex='1/b.*added to active task pool', ) # remove 1/b by request (cylc remove) - await commands.run_cmd(commands.remove_tasks, schd, ['1/b']) + await commands.run_cmd( + commands.remove_tasks(schd, ['1/b'], [FLOW_ALL]) + ) assert log_filter( - log, regex='1/b.*removed from active task pool: request', ) @@ -2059,55 +2024,10 @@ async def test_remove_by_suicide( log.clear() schd.pool.force_trigger_tasks(['1/b'], ['1']) assert log_filter( - log, regex='1/b.*added to active task pool', ) -async def test_remove_no_respawn(flow, scheduler, start, log_filter): - """Ensure that removed tasks stay removed. - - If a task is removed by suicide trigger or "cylc remove", then it should - not be automatically spawned at a later time. - """ - id_ = flow({ - 'scheduling': { - 'graph': { - 'R1': 'a & b => z', - }, - }, - }) - schd: 'Scheduler' = scheduler(id_) - async with start(schd, level=logging.DEBUG) as log: - a1 = schd.pool.get_task(IntegerPoint("1"), "a") - b1 = schd.pool.get_task(IntegerPoint("1"), "b") - assert a1, '1/a should have been spawned on startup' - assert b1, '1/b should have been spawned on startup' - - # mark one of the upstream tasks as succeeded, 1/z should spawn - schd.pool.spawn_on_output(a1, TASK_OUTPUT_SUCCEEDED) - schd.workflow_db_mgr.process_queued_ops() - z1 = schd.pool.get_task(IntegerPoint("1"), "z") - assert z1, '1/z should have been spawned after 1/a succeeded' - - # manually remove 1/z, it should be removed from the pool - await commands.run_cmd(commands.remove_tasks, schd, ['1/z']) - schd.workflow_db_mgr.process_queued_ops() - z1 = schd.pool.get_task(IntegerPoint("1"), "z") - assert z1 is None, '1/z should have been removed (by request)' - - # mark the other upstream task as succeeded, 1/z should not be - # respawned as a result - schd.pool.spawn_on_output(b1, TASK_OUTPUT_SUCCEEDED) - assert log_filter( - log, contains='Not respawning 1/z - task was removed' - ) - z1 = schd.pool.get_task(IntegerPoint("1"), "z") - assert ( - z1 is None - ), '1/z should have stayed removed (but has been added back into the pool' - - async def test_set_future_flow(flow, scheduler, start, log_filter): """Manually-set outputs for new flow num must be recorded in the DB. @@ -2239,7 +2159,7 @@ async def list_data_store(): # reload flow(config, id_=id_) - await commands.run_cmd(commands.reload_workflow, schd) + await commands.run_cmd(commands.reload_workflow(schd)) # check xtrigs post-reload assert list_xtrig_mgr() == { diff --git a/tests/integration/test_workflow_db_mgr.py b/tests/integration/test_workflow_db_mgr.py index 774d8c21fac..b994f88377f 100644 --- a/tests/integration/test_workflow_db_mgr.py +++ b/tests/integration/test_workflow_db_mgr.py @@ -34,12 +34,12 @@ async def test(expected_restart_num: int, do_reload: bool = False): """(Re)start the workflow and check the restart number is as expected. """ schd: 'Scheduler' = scheduler(id_, paused_start=True) - async with start(schd) as log: + async with start(schd): if do_reload: - await commands.run_cmd(commands.reload_workflow, schd) + await commands.run_cmd(commands.reload_workflow(schd)) assert schd.workflow_db_mgr.n_restart == expected_restart_num assert log_filter( - log, contains=f"(re)start number={expected_restart_num + 1}" + contains=f"(re)start number={expected_restart_num + 1}" # (In the log, it's 1 higher than backend value) ) assert ('n_restart', f'{expected_restart_num}') in db_select( diff --git a/tests/integration/test_workflow_files.py b/tests/integration/test_workflow_files.py index 5b036ef2c11..77937ca66dd 100644 --- a/tests/integration/test_workflow_files.py +++ b/tests/integration/test_workflow_files.py @@ -151,7 +151,7 @@ def test_detect_old_contact_file_old_run(workflow, caplog, log_filter): # as a side effect the contact file should have been removed assert not workflow.contact_file.exists() - assert log_filter(caplog, contains='Removed contact file') + assert log_filter(contains='Removed contact file') def test_detect_old_contact_file_none(workflow): @@ -260,11 +260,9 @@ def _unlink(*args): # check the appropriate messages were logged assert bool(log_filter( - caplog, contains='Removed contact file', )) is remove_succeeded assert bool(log_filter( - caplog, contains=( f'Failed to remove contact file for {workflow.id_}:' '\nmocked-os-error' diff --git a/tests/integration/test_xtrigger_mgr.py b/tests/integration/test_xtrigger_mgr.py index 73dd66fc2c5..1f9073b94a7 100644 --- a/tests/integration/test_xtrigger_mgr.py +++ b/tests/integration/test_xtrigger_mgr.py @@ -18,16 +18,11 @@ import asyncio from pathlib import Path from textwrap import dedent -from typing import Set from cylc.flow.pathutil import get_workflow_run_dir from cylc.flow.scheduler import Scheduler -def get_task_ids(schd: Scheduler) -> Set[str]: - return {task.identity for task in schd.pool.get_tasks()} - - async def test_2_xtriggers(flow, start, scheduler, monkeypatch): """Test that if an itask has 4 wall_clock triggers with different offsets that xtrigger manager gets all of them. @@ -252,13 +247,13 @@ async def test_1_seq_clock_trigger_2_tasks(flow, start, scheduler): schd: Scheduler = scheduler(id_) async with start(schd): - start_task_pool = get_task_ids(schd) + start_task_pool = schd.pool.get_task_ids() assert start_task_pool == {'1990/foo', '1990/bar'} for _ in range(3): await schd._main_loop() - assert get_task_ids(schd) == start_task_pool.union( + assert schd.pool.get_task_ids() == start_task_pool.union( f'{year}/{name}' for year in range(1991, 1994) for name in ('foo', 'bar') diff --git a/tests/integration/tui/conftest.py b/tests/integration/tui/conftest.py index 55f83c143f3..86d2267da1e 100644 --- a/tests/integration/tui/conftest.py +++ b/tests/integration/tui/conftest.py @@ -4,7 +4,7 @@ from pathlib import Path import re from time import sleep -from uuid import uuid1 +from secrets import token_hex import pytest from urwid.display import html_fragment @@ -211,7 +211,7 @@ def wait_until_loaded(self, *ids, retries=20): ) if exc: msg += f'\n{exc}' - self.compare_screenshot(f'fail-{uuid1()}', msg, 1) + self.compare_screenshot(f'fail-{token_hex(4)}', msg, 1) @pytest.fixture diff --git a/tests/integration/tui/test_mutations.py b/tests/integration/tui/test_mutations.py index e9a41466d70..b87622bca5f 100644 --- a/tests/integration/tui/test_mutations.py +++ b/tests/integration/tui/test_mutations.py @@ -59,7 +59,7 @@ async def test_online_mutation( id_ = flow(one_conf, name='one') schd = scheduler(id_) with rakiura(size='80,15') as rk: - async with start(schd) as schd_log: + async with start(schd): await schd.update_data_structure() assert schd.command_queue.empty() @@ -91,7 +91,7 @@ async def test_online_mutation( # the mutation should be in the scheduler's command_queue await asyncio.sleep(0) - assert log_filter(schd_log, contains="hold(tasks=['1/one'])") + assert log_filter(contains="hold(tasks=['1/one'])") # close the dialogue and re-run the hold mutation rk.user_input('q', 'q', 'enter') @@ -127,7 +127,7 @@ def standardise_cli_cmds(monkeypatch): """This remove the variable bit of the workflow ID from CLI commands. The workflow ID changes from run to run. In order to make screenshots - stable, this + stable, this """ from cylc.flow.tui.data import extract_context def _extract_context(selection): diff --git a/tests/integration/utils/flow_tools.py b/tests/integration/utils/flow_tools.py index 7c94e2b38a4..86377bfaf50 100644 --- a/tests/integration/utils/flow_tools.py +++ b/tests/integration/utils/flow_tools.py @@ -28,7 +28,7 @@ import logging import pytest from typing import Any, Optional, Union -from uuid import uuid1 +from secrets import token_hex from cylc.flow import CYLC_LOG from cylc.flow.run_modes import RunMode @@ -42,7 +42,7 @@ def _make_src_flow(src_path, conf, filename=WorkflowFiles.FLOW_FILE): """Construct a workflow on the filesystem""" - flow_src_dir = (src_path / str(uuid1())) + flow_src_dir = (src_path / token_hex(4)) flow_src_dir.mkdir(parents=True, exist_ok=True) if isinstance(conf, dict): conf = flow_config_str(conf) @@ -54,7 +54,7 @@ def _make_src_flow(src_path, conf, filename=WorkflowFiles.FLOW_FILE): def _make_flow( cylc_run_dir: Union[Path, str], test_dir: Path, - conf: dict, + conf: Union[dict, str], name: Optional[str] = None, id_: Optional[str] = None, defaults: Optional[bool] = True, @@ -63,6 +63,8 @@ def _make_flow( """Construct a workflow on the filesystem. Args: + conf: Either a workflow config dictionary, or a graph string to be + used as the R1 graph in the workflow config. defaults: Set up a common defaults. * [scheduling]allow implicit tasks = true @@ -72,10 +74,18 @@ def _make_flow( flow_run_dir = (cylc_run_dir / id_) else: if name is None: - name = str(uuid1()) + name = token_hex(4) flow_run_dir = (test_dir / name) flow_run_dir.mkdir(parents=True, exist_ok=True) id_ = str(flow_run_dir.relative_to(cylc_run_dir)) + if isinstance(conf, str): + conf = { + 'scheduling': { + 'graph': { + 'R1': conf + } + } + } if defaults: # set the default simulation runtime to zero (can be overridden) ( diff --git a/tests/unit/cycling/test_iso8601.py b/tests/unit/cycling/test_iso8601.py index 4f9482dca28..6f5d08647d2 100644 --- a/tests/unit/cycling/test_iso8601.py +++ b/tests/unit/cycling/test_iso8601.py @@ -610,12 +610,12 @@ def test_exclusion_zero_duration_warning(set_cycling_type, caplog, log_filter): set_cycling_type(ISO8601_CYCLING_TYPE, "+05") with pytest.raises(Exception): ISO8601Sequence('3000', '2999') - assert log_filter(caplog, contains='zero-duration') + assert log_filter(contains='zero-duration') # parsing a point in an exclusion should not caplog.clear() ISO8601Sequence('P1Y ! 3000', '2999') - assert not log_filter(caplog, contains='zero-duration') + assert not log_filter(contains='zero-duration') def test_simple(set_cycling_type): diff --git a/tests/unit/post_install/test_log_vc_info.py b/tests/unit/post_install/test_log_vc_info.py index 59511db5461..67e204db747 100644 --- a/tests/unit/post_install/test_log_vc_info.py +++ b/tests/unit/post_install/test_log_vc_info.py @@ -279,13 +279,13 @@ def test_no_base_commit_git(tmp_path: Path): @require_svn def test_untracked_svn_subdir( - svn_source_repo: Tuple[str, str, str], caplog, log_filter + svn_source_repo: Tuple[str, str, str], log_filter ): repo_dir, *_ = svn_source_repo source_dir = Path(repo_dir, 'jar_jar_binks') source_dir.mkdir() assert get_vc_info(source_dir) is None - assert log_filter(caplog, level=logging.WARNING, contains="$ svn info") + assert log_filter(logging.WARNING, contains="$ svn info") def test_not_installed( @@ -306,7 +306,6 @@ def test_not_installed( caplog.set_level(logging.DEBUG) assert get_vc_info(tmp_path) is None assert log_filter( - caplog, - level=logging.DEBUG, + logging.DEBUG, contains=f"{fake_vcs} does not appear to be installed", ) diff --git a/tests/unit/test_clean.py b/tests/unit/test_clean.py index 285bfa6a23f..2308fe318f0 100644 --- a/tests/unit/test_clean.py +++ b/tests/unit/test_clean.py @@ -945,7 +945,7 @@ def mocked_remote_clean_cmd_side_effect(id_, platform, timeout, rm_dirs): id_, platform_names, timeout='irrelevant', rm_dirs=rm_dirs ) for msg in expected_err_msgs: - assert log_filter(caplog, level=logging.ERROR, contains=msg) + assert log_filter(logging.ERROR, msg) if expected_platforms: for p_name in expected_platforms: mocked_remote_clean_cmd.assert_any_call( diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index b3c97e824da..1b98d42f729 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -1361,7 +1361,7 @@ def test_implicit_tasks( @pytest.mark.parametrize('workflow_meta', [True, False]) @pytest.mark.parametrize('url_type', ['good', 'bad', 'ugly', 'broken']) -def test_process_urls(caplog, log_filter, workflow_meta, url_type): +def test_process_urls(log_filter, workflow_meta, url_type): if url_type == 'good': # valid cylc 8 syntax @@ -1397,7 +1397,6 @@ def test_process_urls(caplog, log_filter, workflow_meta, url_type): elif url_type == 'ugly': WorkflowConfig.process_metadata_urls(config) assert log_filter( - caplog, contains='Detected deprecated template variables', ) @@ -1425,7 +1424,6 @@ def test_zero_interval( should_warn: bool, opts: Values, tmp_flow_config: Callable, - caplog: pytest.LogCaptureFixture, log_filter: Callable, ): """Test that a zero-duration recurrence with >1 repetition gets an @@ -1443,7 +1441,6 @@ def test_zero_interval( """) WorkflowConfig(id_, flow_file, options=opts) logged = log_filter( - caplog, level=logging.WARNING, contains="Cannot have more than 1 repetition for zero-duration" ) diff --git a/tests/unit/test_id.py b/tests/unit/test_id.py index 2d50c9a2706..e3011ee1d57 100644 --- a/tests/unit/test_id.py +++ b/tests/unit/test_id.py @@ -17,12 +17,15 @@ import pytest +from cylc.flow.cycling.integer import IntegerPoint +from cylc.flow.cycling.iso8601 import ISO8601Point from cylc.flow.id import ( LEGACY_CYCLE_SLASH_TASK, LEGACY_TASK_DOT_CYCLE, RELATIVE_ID, - Tokens, UNIVERSAL_ID, + Tokens, + quick_relative_id, ) @@ -392,3 +395,14 @@ def test_task_property(): ) == ( Tokens('//c:cs/t:ts/j:js', relative=True) ) + + +@pytest.mark.parametrize('cycle, expected', [ + ('2000', '2000/foo'), + (2001, '2001/foo'), + (IntegerPoint('3'), '3/foo'), + # NOTE: ISO8601Points are not standardised by this function: + (ISO8601Point('2002'), '2002/foo'), +]) +def test_quick_relative_id(cycle, expected): + assert quick_relative_id(cycle, 'foo') == expected diff --git a/tests/unit/test_platforms.py b/tests/unit/test_platforms.py index 4e0e0c0bbd9..106aceea068 100644 --- a/tests/unit/test_platforms.py +++ b/tests/unit/test_platforms.py @@ -34,6 +34,7 @@ PlatformLookupError, GlobalConfigError ) +from cylc.flow.run_modes import JOBLESS_MODES PLATFORMS = { @@ -470,12 +471,24 @@ def test_get_install_target_to_platforms_map( for install_target in _map: _map[install_target] = sorted(_map[install_target], key=lambda k: k['name']) - expected_map.update( - {'localhost': [{'name': 'simulation'}, {'name': 'skip'}]} - ) assert result == expected_map +@pytest.mark.parametrize('mode', sorted(JOBLESS_MODES)) +def test_platform_from_name__jobless_modes(mode): + result = platform_from_name(mode) + assert result['name'] == 'localhost' + + +@pytest.mark.parametrize('mode', sorted(JOBLESS_MODES)) +def test_get_install_target_to_platforms_map__jobless_modes(mode): + result = get_install_target_to_platforms_map([mode]) + assert list(result) == ['localhost'] + assert len(result['localhost']) == 1 + assert result['localhost'][0]['hosts'] == ['localhost'] + assert result['localhost'][0]['install target'] == 'localhost' + + @pytest.mark.parametrize( 'platform, job, remote, expect', [ diff --git a/tests/unit/test_prerequisite.py b/tests/unit/test_prerequisite.py index 105e5e85401..7e1fd19124c 100644 --- a/tests/unit/test_prerequisite.py +++ b/tests/unit/test_prerequisite.py @@ -14,12 +14,17 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +from functools import partial + import pytest from cylc.flow.cycling.integer import IntegerPoint from cylc.flow.cycling.loader import ISO8601_CYCLING_TYPE, get_point -from cylc.flow.prerequisite import Prerequisite -from cylc.flow.id import Tokens +from cylc.flow.id import Tokens, detokenise +from cylc.flow.prerequisite import Prerequisite, SatisfiedState + + +detok = partial(detokenise, selectors=True, relative=True) @pytest.fixture @@ -43,10 +48,10 @@ def test_satisfied(prereq: Prerequisite): ('2001', 'd', 'custom'): False, } # No cached satisfaction state yet: - assert prereq._all_satisfied is None + assert prereq._cached_satisfied is None # Calling self.is_satisfied() should cache the result: assert not prereq.is_satisfied() - assert prereq._all_satisfied is False + assert prereq._cached_satisfied is False # mark two prerequisites as satisfied prereq.satisfy_me([ @@ -63,7 +68,7 @@ def test_satisfied(prereq: Prerequisite): ('2001', 'd', 'custom'): False, } # Should have reset cached satisfaction state: - assert prereq._all_satisfied is None + assert prereq._cached_satisfied is None assert not prereq.is_satisfied() # mark all prereqs as satisfied @@ -78,7 +83,7 @@ def test_satisfied(prereq: Prerequisite): ('2001', 'd', 'custom'): 'force satisfied', } # Should have set cached satisfaction state as must be true now: - assert prereq._all_satisfied is True + assert prereq._cached_satisfied is True assert prereq.is_satisfied() @@ -116,9 +121,9 @@ def test_items(prereq: Prerequisite): ] -def test_set_condition(prereq: Prerequisite): +def test_set_conditional_expr(prereq: Prerequisite): assert not prereq.is_satisfied() - prereq.set_condition('1999/a succeeded | 2000/b succeeded') + prereq.set_conditional_expr('1999/a succeeded | 2000/b succeeded') assert prereq.is_satisfied() @@ -138,14 +143,89 @@ def test_get_target_points(prereq): } -def test_get_resolved_dependencies(): +@pytest.fixture +def satisfied_states_prereq(): + """Fixture for testing the full range of possible satisfied states.""" prereq = Prerequisite(IntegerPoint('2')) prereq[('1', 'a', 'x')] = True prereq[('1', 'b', 'x')] = False prereq[('1', 'c', 'x')] = 'satisfied from database' prereq[('1', 'd', 'x')] = 'force satisfied' - assert prereq.get_resolved_dependencies() == [ - '1/a', - '1/c', - '1/d', - ] + return prereq + + +def test_unset_naturally_satisfied(satisfied_states_prereq: Prerequisite): + satisfied_states_prereq[('1', 'a', 'y')] = True + satisfied_states_prereq[('1', 'a', 'z')] = 'force satisfied' + for id_, expected in [ + ('1/a', True), + ('1/b', False), + ('1/c', True), + ('1/d', False), + ]: + assert ( + satisfied_states_prereq.unset_naturally_satisfied(id_) == expected + ) + assert satisfied_states_prereq._satisfied == { + ('1', 'a', 'x'): False, + ('1', 'a', 'y'): False, + ('1', 'a', 'z'): 'force satisfied', + ('1', 'b', 'x'): False, + ('1', 'c', 'x'): False, + ('1', 'd', 'x'): 'force satisfied', + } + + +def test_satisfy_me(): + prereq = Prerequisite(IntegerPoint('2')) + for task_name in ('a', 'b', 'c'): + prereq[('1', task_name, 'x')] = False + assert not prereq.is_satisfied() + assert prereq._cached_satisfied is False + + valid = prereq.satisfy_me( + [Tokens('//1/a:x'), Tokens('//1/d:x'), Tokens('//1/c:y')], + ) + assert {detok(tokens) for tokens in valid} == {'1/a:x'} + assert prereq._satisfied == { + ('1', 'a', 'x'): 'satisfied naturally', + ('1', 'b', 'x'): False, + ('1', 'c', 'x'): False, + } + # should have reset cached satisfaction state + assert prereq._cached_satisfied is None + + valid = prereq.satisfy_me( + [Tokens('//1/a:x'), Tokens('//1/b:x')], + forced=True, + ) + assert {detok(tokens) for tokens in valid} == {'1/a:x', '1/b:x'} + assert prereq._satisfied == { + # 1/a:x unaffected as already satisfied + ('1', 'a', 'x'): 'satisfied naturally', + ('1', 'b', 'x'): 'force satisfied', + ('1', 'c', 'x'): False, + } + + +@pytest.mark.parametrize('forced', [False, True]) +@pytest.mark.parametrize('existing, expected_when_forced', [ + (False, 'force satisfied'), + ('satisfied from database', 'force satisfied'), + ('force satisfied', 'force satisfied'), + ('satisfied naturally', 'satisfied naturally'), +]) +def test_satisfy_me__override( + forced: bool, + existing: SatisfiedState, + expected_when_forced: SatisfiedState, +): + """Test that satisfying a prereq with a different state works as expected + with and without the `forced` arg.""" + prereq = Prerequisite(IntegerPoint('2')) + prereq[('1', 'a', 'x')] = existing + + prereq.satisfy_me([Tokens('//1/a:x')], forced=forced) + assert prereq[('1', 'a', 'x')] == ( + expected_when_forced if forced else 'satisfied naturally' + ) diff --git a/tests/unit/test_rundb.py b/tests/unit/test_rundb.py index 06aba70699f..44db75fb2e5 100644 --- a/tests/unit/test_rundb.py +++ b/tests/unit/test_rundb.py @@ -112,7 +112,9 @@ def test_operational_error(tmp_path, caplog): # stage some stuff dao.add_delete_item(CylcWorkflowDAO.TABLE_TASK_JOBS) dao.add_insert_item(CylcWorkflowDAO.TABLE_TASK_JOBS, ['pub']) - dao.add_update_item(CylcWorkflowDAO.TABLE_TASK_JOBS, ['pub']) + dao.add_update_item( + CylcWorkflowDAO.TABLE_TASK_JOBS, ({'pub': None}, {}) + ) # connect the to DB dao.connect() diff --git a/tests/unit/test_scheduler.py b/tests/unit/test_scheduler.py index 36beb121d3b..ccd5f5dfed5 100644 --- a/tests/unit/test_scheduler.py +++ b/tests/unit/test_scheduler.py @@ -133,7 +133,7 @@ def _select_workflow_host(cached=False): ) caplog.set_level(logging.ERROR, CYLC_LOG) assert not Scheduler.workflow_auto_restart(schd, max_retries=2) - assert log_filter(caplog, contains='elephant') + assert log_filter(contains='elephant') def test_auto_restart_popen_error(monkeypatch, caplog, log_filter): @@ -166,4 +166,4 @@ def _popen(*args, **kwargs): ) caplog.set_level(logging.ERROR, CYLC_LOG) assert not Scheduler.workflow_auto_restart(schd, max_retries=2) - assert log_filter(caplog, contains='mystderr') + assert log_filter(contains='mystderr') diff --git a/tests/unit/test_task_pool.py b/tests/unit/test_task_pool.py index b32781895bc..d1ab1642442 100644 --- a/tests/unit/test_task_pool.py +++ b/tests/unit/test_task_pool.py @@ -21,6 +21,7 @@ import pytest from cylc.flow.flow_mgr import FlowNums +from cylc.flow.prerequisite import SatisfiedState from cylc.flow.task_pool import TaskPool @@ -55,3 +56,29 @@ def test_get_active_flow_nums( ) assert TaskPool._get_active_flow_nums(mock_task_pool) == expected + + +@pytest.mark.parametrize('output_msg, flow_nums, db_flow_nums, expected', [ + ('foo', set(), {1}, False), + ('foo', set(), set(), False), + ('foo', {1, 3}, {1}, 'satisfied from database'), + ('goo', {1, 3}, {1, 2}, 'satisfied from database'), + ('foo', {1, 3}, set(), False), + ('foo', {2}, {1}, False), + ('foo', {2}, {1, 2}, 'satisfied from database'), + ('f', {1}, {1}, False), +]) +def test_check_output( + output_msg: str, + flow_nums: set, + db_flow_nums: set, + expected: SatisfiedState, +): + mock_task_pool = Mock() + mock_task_pool.workflow_db_mgr.pri_dao.select_task_outputs.return_value = { + '{"f": "foo", "g": "goo"}': db_flow_nums, + } + + assert TaskPool.check_task_output( + mock_task_pool, '2000', 'haddock', output_msg, flow_nums + ) == expected diff --git a/tests/unit/test_task_proxy.py b/tests/unit/test_task_proxy.py index 98695ecd13f..e1b03132f3b 100644 --- a/tests/unit/test_task_proxy.py +++ b/tests/unit/test_task_proxy.py @@ -14,13 +14,15 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -import pytest -from pytest import param from typing import Callable, Optional from unittest.mock import Mock +import pytest +from pytest import param + from cylc.flow.cycling import PointBase from cylc.flow.cycling.iso8601 import ISO8601Point +from cylc.flow.flow_mgr import FlowNums from cylc.flow.task_proxy import TaskProxy @@ -101,3 +103,28 @@ def test_status_match(status_str: Optional[str], expected: bool): mock_itask = Mock(state=Mock(status='waiting')) assert TaskProxy.status_match(mock_itask, status_str) is expected + + +@pytest.mark.parametrize('itask_flow_nums, flow_nums, expected', [ + param({1, 2}, {2}, {2}, id="subset"), + param({2}, {1, 2}, {2}, id="superset"), + param({1, 2}, {3, 4}, set(), id="disjoint"), + param({1, 2}, set(), {1, 2}, id="all-matches-num"), + param(set(), {1, 2}, set(), id="num-doesnt-match-none"), + param(set(), set(), set(), id="all-doesnt-match-none"), +]) +def test_match_flows( + itask_flow_nums: FlowNums, flow_nums: FlowNums, expected: FlowNums +): + mock_itask = Mock(flow_nums=itask_flow_nums) + assert TaskProxy.match_flows(mock_itask, flow_nums) == expected + + +def test_match_flows_copy(): + """Test that this method does not return the same reference as + itask.flow_nums, otherwise you could end up unexpectedly mutating + itask.flow_nums.""" + mock_itask = Mock(flow_nums={1, 2}) + result = TaskProxy.match_flows(mock_itask, set()) + assert result == mock_itask.flow_nums + assert result is not mock_itask.flow_nums diff --git a/tests/unit/test_task_state.py b/tests/unit/test_task_state.py index 7350a9aed74..7b33e5f5b8b 100644 --- a/tests/unit/test_task_state.py +++ b/tests/unit/test_task_state.py @@ -14,9 +14,11 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +from unittest.mock import MagicMock import pytest from types import SimpleNamespace +from cylc.flow.prerequisite import Prerequisite from cylc.flow.taskdef import TaskDef from cylc.flow.cycling.integer import IntegerSequence, IntegerPoint from cylc.flow.run_modes import RunMode, disable_task_event_handlers @@ -122,6 +124,27 @@ def test_task_state_order(): assert not tstate.is_gte(TASK_STATUS_RUNNING) +def test_get_resolved_dependencies(): + prereq1 = Prerequisite(IntegerPoint('2')) + prereq1[('1', 'a', 'x')] = True + prereq1[('1', 'b', 'x')] = False + prereq1[('1', 'c', 'x')] = 'satisfied from database' + prereq1[('1', 'd', 'x')] = 'force satisfied' + prereq2 = Prerequisite(IntegerPoint('2')) + prereq2[('1', 'e', 'succeeded')] = False + prereq2[('1', 'e', 'failed')] = True + task_state = TaskState( + MagicMock(), IntegerPoint('2'), TASK_STATUS_WAITING, False + ) + task_state.prerequisites = [prereq1, prereq2] + assert task_state.get_resolved_dependencies() == [ + '1/a', + '1/c', + '1/d', + '1/e', + ] + + @pytest.mark.parametrize( 'itask_run_mode, disable_handlers, expect', ( diff --git a/tests/unit/test_workflow_db_mgr.py b/tests/unit/test_workflow_db_mgr.py new file mode 100644 index 00000000000..c749dc9eeef --- /dev/null +++ b/tests/unit/test_workflow_db_mgr.py @@ -0,0 +1,79 @@ +# THIS FILE IS PART OF THE CYLC WORKFLOW ENGINE. +# Copyright (C) NIWA & British Crown (Met Office) & Contributors. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from pathlib import Path +from typing import ( + List, + Set, +) +from unittest.mock import Mock + +import pytest +from pytest import param + +from cylc.flow.cycling.integer import IntegerPoint +from cylc.flow.flow_mgr import FlowNums +from cylc.flow.id import Tokens +from cylc.flow.task_proxy import TaskProxy +from cylc.flow.taskdef import TaskDef +from cylc.flow.util import serialise_set +from cylc.flow.workflow_db_mgr import WorkflowDatabaseManager + + +@pytest.mark.parametrize('flow_nums, expected_removed', [ + param(set(), {1, 2, 5}, id='all'), + param({1}, {1}, id='subset'), + param({1, 2, 5}, {1, 2, 5}, id='complete-set'), + param({1, 3, 5}, {1, 5}, id='intersect'), + param({3, 4}, set(), id='disjoint'), +]) +def test_remove_task_from_flows( + tmp_path: Path, flow_nums: FlowNums, expected_removed: FlowNums +): + db_flows: List[FlowNums] = [ + {1, 2}, + {5}, + set(), # FLOW_NONE + ] + expected_remaining = { + serialise_set(flow - expected_removed) for flow in db_flows + } + db_mgr = WorkflowDatabaseManager(tmp_path) + schd_tokens = Tokens('~asterix/gaul') + tdef = TaskDef('a', rtcfg={}, start_point=None, initial_point=None) + with db_mgr.get_pri_dao() as dao: + db_mgr.pri_dao = dao + db_mgr.pub_dao = Mock() + for flow in db_flows: + itask = TaskProxy( + schd_tokens, tdef, IntegerPoint('1'), flow_nums=flow + ) + db_mgr.put_insert_task_states(itask) + db_mgr.put_insert_task_outputs(itask) + db_mgr.process_queued_ops() + + removed_fnums = db_mgr.remove_task_from_flows('1', 'a', flow_nums) + assert removed_fnums == expected_removed + + db_mgr.process_queued_ops() + for table in ('task_states', 'task_outputs'): + remaining_fnums: Set[str] = { + fnums_str + for fnums_str, *_ in dao.connect().execute( + f'SELECT flow_nums FROM {table}' + ) + } + assert remaining_fnums == expected_remaining diff --git a/tests/unit/test_workflow_events.py b/tests/unit/test_workflow_events.py index 89449953f20..c9d08791781 100644 --- a/tests/unit/test_workflow_events.py +++ b/tests/unit/test_workflow_events.py @@ -83,13 +83,13 @@ def test_process_mail_footer(caplog, log_filter): assert process_mail_footer( '%(host)s|%(port)s|%(owner)s|%(suite)s|%(workflow)s', template_vars ) == 'myhost|42|me|my_workflow|my_workflow\n' - assert not log_filter(caplog, contains='Ignoring bad mail footer template') + assert not log_filter(contains='Ignoring bad mail footer template') # test invalid variable assert process_mail_footer('%(invalid)s', template_vars) == '' - assert log_filter(caplog, contains='Ignoring bad mail footer template') + assert log_filter(contains='Ignoring bad mail footer template') # test broken template caplog.clear() assert process_mail_footer('%(invalid)s', template_vars) == '' - assert log_filter(caplog, contains='Ignoring bad mail footer template') + assert log_filter(contains='Ignoring bad mail footer template') diff --git a/tests/unit/test_workflow_files.py b/tests/unit/test_workflow_files.py index b2b33e495aa..2e85d4180a0 100644 --- a/tests/unit/test_workflow_files.py +++ b/tests/unit/test_workflow_files.py @@ -196,7 +196,6 @@ def test_infer_latest_run( @pytest.mark.parametrize('warn_arg', [True, False]) def test_infer_latest_run_warns_for_runN( warn_arg: bool, - caplog: pytest.LogCaptureFixture, log_filter: Callable, tmp_run_dir: Callable, ): @@ -206,8 +205,7 @@ def test_infer_latest_run_warns_for_runN( runN_path.symlink_to('run1') infer_latest_run(runN_path, warn_runN=warn_arg) filtered_log = log_filter( - caplog, level=logging.WARNING, - contains="You do not need to include runN in the workflow ID" + logging.WARNING, "You do not need to include runN in the workflow ID" ) assert filtered_log if warn_arg else not filtered_log