From 35819a5adf335171ded0fe17c815b99effb29237 Mon Sep 17 00:00:00 2001 From: Ronnie Dutta <61982285+MetRonnie@users.noreply.github.com> Date: Thu, 5 Dec 2024 12:36:37 +0000 Subject: [PATCH] Fix KeyError when using `cylc remove` after a reload changed the graph --- cylc/flow/config.py | 2 +- cylc/flow/scheduler.py | 56 ++++++++----------- cylc/flow/task_pool.py | 40 ++++++------- cylc/flow/taskdef.py | 3 +- cylc/flow/workflow_db_mgr.py | 2 +- tests/integration/test_optional_outputs.py | 3 +- tests/integration/test_reload.py | 2 +- tests/integration/test_remove.py | 35 ++++++++++++ .../test_stop_after_cycle_point.py | 2 +- tests/integration/test_task_pool.py | 10 ++-- tests/integration/utils/flow_tools.py | 11 ++-- 11 files changed, 94 insertions(+), 72 deletions(-) diff --git a/cylc/flow/config.py b/cylc/flow/config.py index ea31695d2e5..c438cd9105a 100644 --- a/cylc/flow/config.py +++ b/cylc/flow/config.py @@ -1543,7 +1543,7 @@ def configure_workflow_state_polling_tasks(self): "script cannot be defined for automatic" + " workflow polling task '%s':\n%s" % (l_task, cs)) # Generate the automatic scripting. - for name, tdef in list(self.taskdefs.items()): + for name, tdef in self.taskdefs.items(): if name not in self.workflow_polling_tasks: continue rtc = tdef.rtconfig diff --git a/cylc/flow/scheduler.py b/cylc/flow/scheduler.py index 3339423640f..2a66ade7dcb 100644 --- a/cylc/flow/scheduler.py +++ b/cylc/flow/scheduler.py @@ -70,7 +70,6 @@ from cylc.flow.broadcast_mgr import BroadcastMgr from cylc.flow.cfgspec.glbl_cfg import glbl_cfg from cylc.flow.config import WorkflowConfig -from cylc.flow.cycling.loader import get_point from cylc.flow.data_store_mgr import DataStoreMgr from cylc.flow.exceptions import ( CommandFailedError, @@ -93,9 +92,7 @@ get_user, is_remote_platform, ) -from cylc.flow.id import ( - Tokens, -) +from cylc.flow.id import Tokens, quick_relative_id from cylc.flow.log_level import ( verbosity_to_env, verbosity_to_opts, @@ -1102,17 +1099,20 @@ def remove_tasks( if flow_nums is None: flow_nums = set() - # Mapping of *relative* task IDs to removed flow numbers: - removed: Dict[Tokens, FlowNums] = {} - not_removed: Set[Tokens] = set() + # Mapping of relative task IDs to removed flow numbers: + removed: Dict[str, FlowNums] = {} + not_removed: Set[str] = set() + # All the matched tasks (will add applicable active tasks below): + matched_tasks = inactive.copy() to_kill: List[TaskProxy] = [] for itask in active: fnums_to_remove = itask.match_flows(flow_nums) if not fnums_to_remove: - not_removed.add(itask.tokens.task) + not_removed.add(itask.identity) continue - removed[itask.tokens.task] = fnums_to_remove + removed[itask.identity] = fnums_to_remove + matched_tasks.add((itask.tdef, itask.point)) if fnums_to_remove == itask.flow_nums: # Need to remove the task from the pool. # Spawn next occurrence of xtrigger sequential task (otherwise @@ -1123,21 +1123,13 @@ def remove_tasks( itask.removed = True itask.flow_nums.difference_update(fnums_to_remove) - # All the matched tasks (including inactive & applicable active tasks): - matched_tasks = { - *removed.keys(), - *(Tokens(cycle=str(cycle), task=task) for task, cycle in inactive), - } - - for tokens in matched_tasks: - tdef = self.config.taskdefs[tokens['task']] + for tdef, point in matched_tasks: + task_id = quick_relative_id(point, tdef.name) # Go through any tasks downstream of this matched task to see if # any need to stand down as a result of this task being removed: for child in set(itertools.chain.from_iterable( - generate_graph_children( - tdef, get_point(tokens['cycle']) - ).values() + generate_graph_children(tdef, point).values() )): child_itask = self.pool.get_task(child.point, child.name) if not child_itask: @@ -1152,9 +1144,9 @@ def remove_tasks( ): # Unset any prereqs naturally satisfied by these tasks # (do not unset those satisfied by `cylc set --pre`): - if prereq.unset_naturally_satisfied(tokens.relative_id): + if prereq.unset_naturally_satisfied(task_id): prereqs_changed = True - removed.setdefault(tokens, set()).update( + removed.setdefault(task_id, set()).update( fnums_to_remove ) if not prereqs_changed: @@ -1173,7 +1165,7 @@ def remove_tasks( # Check if downstream task should remain spawned: if ( # Ignoring tasks we are already dealing with: - child_itask.tokens.task in matched_tasks + (child_itask.tdef, child_itask.point) in matched_tasks or child_itask.state.any_satisfied_prerequisite_outputs() ): continue @@ -1187,10 +1179,13 @@ def remove_tasks( # Remove the matched tasks from the flows in the DB tables: db_removed_fnums = self.workflow_db_mgr.remove_task_from_flows( - tokens['cycle'], tokens['task'], flow_nums, + str(point), tdef.name, flow_nums, ) if db_removed_fnums: - removed.setdefault(tokens, set()).update(db_removed_fnums) + removed.setdefault(task_id, set()).update(db_removed_fnums) + + if task_id not in removed: + not_removed.add(task_id) if to_kill: self.kill_tasks(to_kill, warn=False) @@ -1198,22 +1193,17 @@ def remove_tasks( if removed: tasks_str_list = [] for task, fnums in removed.items(): - self.data_store_mgr.delta_remove_task_flow_nums( - task.relative_id, fnums - ) + self.data_store_mgr.delta_remove_task_flow_nums(task, fnums) tasks_str_list.append( - f"{task.relative_id} {repr_flow_nums(fnums, full=True)}" + f"{task} {repr_flow_nums(fnums, full=True)}" ) LOG.info(f"Removed task(s): {', '.join(sorted(tasks_str_list))}") - not_removed.update(matched_tasks.difference(removed)) if not_removed: fnums_str = ( repr_flow_nums(flow_nums, full=True) if flow_nums else '' ) - tasks_str = ', '.join( - sorted(tokens.relative_id for tokens in not_removed) - ) + tasks_str = ', '.join(sorted(not_removed)) LOG.warning(f"Task(s) not removable: {tasks_str} {fnums_str}") if removed and self.pool.compute_runahead(): diff --git a/cylc/flow/task_pool.py b/cylc/flow/task_pool.py index c771af215cd..92a337e9a2c 100644 --- a/cylc/flow/task_pool.py +++ b/cylc/flow/task_pool.py @@ -32,6 +32,7 @@ Tuple, Type, Union, + cast, ) from cylc.flow import LOG @@ -1335,9 +1336,9 @@ def hold_tasks(self, items: Iterable[str]) -> int: for itask in itasks: self.hold_active_task(itask) # Set inactive tasks to be held: - for name, cycle in inactive_tasks: - self.data_store_mgr.delta_task_held(name, cycle, True) - self.tasks_to_hold.update(inactive_tasks) + for tdef, cycle in inactive_tasks: + self.data_store_mgr.delta_task_held(tdef.name, cycle, True) + self.tasks_to_hold.add((tdef.name, cycle)) self.workflow_db_mgr.put_tasks_to_hold(self.tasks_to_hold) LOG.debug(f"Tasks to hold: {self.tasks_to_hold}") return len(unmatched) @@ -1353,9 +1354,9 @@ def release_held_tasks(self, items: Iterable[str]) -> int: for itask in itasks: self.release_held_active_task(itask) # Unhold inactive tasks: - for name, cycle in inactive_tasks: - self.data_store_mgr.delta_task_held(name, cycle, False) - self.tasks_to_hold.difference_update(inactive_tasks) + for tdef, cycle in inactive_tasks: + self.data_store_mgr.delta_task_held(tdef.name, cycle, False) + self.tasks_to_hold.discard((tdef.name, cycle)) self.workflow_db_mgr.put_tasks_to_hold(self.tasks_to_hold) LOG.debug(f"Tasks to hold: {self.tasks_to_hold}") return len(unmatched) @@ -1979,8 +1980,7 @@ def set_prereqs_and_outputs( if not flow: # default: assign to all active flows flow_nums = self._get_active_flow_nums() - for name, point in inactive_tasks: - tdef = self.config.get_taskdef(name) + for tdef, point in inactive_tasks: if prereqs: self._set_prereqs_tdef( point, tdef, prereqs, flow_nums, flow_wait) @@ -2175,7 +2175,7 @@ def force_trigger_tasks( """ # Get matching tasks proxies, and matching inactive task IDs. - existing_tasks, inactive_ids, unmatched = self.filter_task_proxies( + existing_tasks, inactive, unmatched = self.filter_task_proxies( items, inactive=True, warn_no_active=False, ) @@ -2199,15 +2199,15 @@ def force_trigger_tasks( if not flow: # default: assign to all active flows flow_nums = self._get_active_flow_nums() - for name, point in inactive_ids: - if not self.can_be_spawned(name, point): + for tdef, point in inactive: + if not self.can_be_spawned(tdef.name, point): continue submit_num, _, prev_fwait = ( - self._get_task_history(name, point, flow_nums) + self._get_task_history(tdef.name, point, flow_nums) ) itask = TaskProxy( self.tokens, - self.config.get_taskdef(name), + tdef, point, flow_nums, flow_wait=flow_wait, @@ -2327,7 +2327,7 @@ def filter_task_proxies( ids: Iterable[str], warn_no_active: bool = True, inactive: bool = False, - ) -> 'Tuple[List[TaskProxy], Set[Tuple[str, PointBase]], List[str]]': + ) -> 'Tuple[List[TaskProxy], Set[Tuple[TaskDef, PointBase]], List[str]]': """Return task proxies that match names, points, states in items. Args: @@ -2353,7 +2353,7 @@ def filter_task_proxies( ids, warn=warn_no_active, ) - inactive_matched: 'Set[Tuple[str, PointBase]]' = set() + inactive_matched: 'Set[Tuple[TaskDef, PointBase]]' = set() if inactive and unmatched: inactive_matched, unmatched = self.match_inactive_tasks( unmatched @@ -2364,7 +2364,7 @@ def filter_task_proxies( def match_inactive_tasks( self, ids: Iterable[str], - ) -> Tuple[Set[Tuple[str, 'PointBase']], List[str]]: + ) -> 'Tuple[Set[Tuple[TaskDef, PointBase]], List[str]]': """Match task IDs against task definitions (rather than the task pool). IDs will be matched providing the ID: @@ -2377,7 +2377,7 @@ def match_inactive_tasks( (matched_tasks, unmatched_tasks) """ - matched_tasks: 'Set[Tuple[str, PointBase]]' = set() + matched_tasks: 'Set[Tuple[TaskDef, PointBase]]' = set() unmatched_tasks: 'List[str]' = [] for id_ in ids: try: @@ -2404,8 +2404,8 @@ def match_inactive_tasks( unmatched_tasks.append(id_) continue - point_str = tokens['cycle'] - name_str = tokens['task'] + point_str = cast('str', tokens['cycle']) + name_str = cast('str', tokens['task']) if name_str not in self.config.taskdefs: if self.config.find_taskdefs(name_str): # It's a family name; was not matched by active tasks @@ -2427,7 +2427,7 @@ def match_inactive_tasks( point = get_point(point_str) taskdef = self.config.taskdefs[name_str] if taskdef.is_valid_point(point): - matched_tasks.add((taskdef.name, point)) + matched_tasks.add((taskdef, point)) else: LOG.warning( self.ERR_PREFIX_TASK_NOT_ON_SEQUENCE.format( diff --git a/cylc/flow/taskdef.py b/cylc/flow/taskdef.py index 34461ca0d6f..49784822127 100644 --- a/cylc/flow/taskdef.py +++ b/cylc/flow/taskdef.py @@ -165,7 +165,7 @@ class TaskDef: def __init__(self, name, rtcfg, start_point, initial_point): if not TaskID.is_valid_name(name): raise TaskDefError("Illegal task name: %s" % name) - + self.name: str = name self.rtconfig = rtcfg self.start_point = start_point self.initial_point = initial_point @@ -192,7 +192,6 @@ def __init__(self, name, rtcfg, start_point, initial_point): self.external_triggers = [] self.xtrig_labels = {} # {sequence: [labels]} - self.name = name self.elapsed_times = deque(maxlen=self.MAX_LEN_ELAPSED_TIMES) self._add_std_outputs() self.has_abs_triggers = False diff --git a/cylc/flow/workflow_db_mgr.py b/cylc/flow/workflow_db_mgr.py index d9ae87150d8..70c01beed8c 100644 --- a/cylc/flow/workflow_db_mgr.py +++ b/cylc/flow/workflow_db_mgr.py @@ -698,7 +698,7 @@ def _put_update_task_x( self, table_name: str, itask: 'TaskProxy', set_args: 'DbArgDict' ) -> None: """Put UPDATE statement for a task_* table.""" - where_args = { + where_args: Dict[str, Any] = { "cycle": str(itask.point), "name": itask.tdef.name, } diff --git a/tests/integration/test_optional_outputs.py b/tests/integration/test_optional_outputs.py index d5c4e41ce81..fd0e6281be8 100644 --- a/tests/integration/test_optional_outputs.py +++ b/tests/integration/test_optional_outputs.py @@ -41,7 +41,6 @@ get_completion_expression, ) from cylc.flow.task_state import ( - TASK_STATUSES_ACTIVE, TASK_STATUS_EXPIRED, TASK_STATUS_PREPARING, TASK_STATUS_RUNNING, @@ -484,7 +483,7 @@ async def test_removed_taskdef( 'R1': 'a' } } - }, id_=id_) + }, workflow_id=id_) # restart the workflow schd: 'Scheduler' = scheduler(id_) diff --git a/tests/integration/test_reload.py b/tests/integration/test_reload.py index ad96b187722..6043d0931d0 100644 --- a/tests/integration/test_reload.py +++ b/tests/integration/test_reload.py @@ -130,7 +130,7 @@ async def test_reload_failure( async with start(schd): # corrupt the config by removing the scheduling section two_conf = {**one_conf, 'scheduling': {}} - flow(two_conf, id_=id_) + flow(two_conf, workflow_id=id_) # reload the workflow await commands.run_cmd(commands.reload_workflow(schd)) diff --git a/tests/integration/test_remove.py b/tests/integration/test_remove.py index aab060a788f..539bc33e782 100644 --- a/tests/integration/test_remove.py +++ b/tests/integration/test_remove.py @@ -20,6 +20,7 @@ from cylc.flow.commands import ( force_trigger_tasks, + reload_workflow, remove_tasks, run_cmd, ) @@ -478,3 +479,37 @@ async def test_kill_running(flow, scheduler, run, complete, reflog): ('1/c', ('1/b',)), # The a:failed output should not cause 1/q to run } + + +async def test_reload_changed_config(flow, scheduler, run, complete): + """Test that a task is removed from the pool if its configuration changes + to make it no longer match the graph.""" + wid = flow({ + 'scheduling': { + 'graph': { + 'R1': ''' + a => b + a:started => s & b + ''', + }, + }, + 'runtime': { + 'a': { + 'simulation': { + # Ensure 1/a still in pool during reload + 'fail cycle points': 'all', + }, + }, + }, + }) + schd: Scheduler = scheduler(wid, paused_start=False) + async with run(schd): + await complete(schd, '1/s') + # Change graph then reload + flow('b', workflow_id=wid) + await run_cmd(reload_workflow(schd)) + assert schd.config.cfg['scheduling']['graph']['R1'] == 'b' + assert schd.pool.get_task_ids() == {'1/a', '1/b'} + + await run_cmd(remove_tasks(schd, ['1/a'], [FLOW_ALL])) + await complete(schd, '1/b') diff --git a/tests/integration/test_stop_after_cycle_point.py b/tests/integration/test_stop_after_cycle_point.py index 90bab288515..6105835c111 100644 --- a/tests/integration/test_stop_after_cycle_point.py +++ b/tests/integration/test_stop_after_cycle_point.py @@ -85,7 +85,7 @@ def get_db_value(schd) -> Optional[str]: # change the configured cycle point to "2" config['scheduling']['stop after cycle point'] = '2' - id_ = flow(config, id_=id_) + id_ = flow(config, workflow_id=id_) schd = scheduler(id_, paused_start=False) async with run(schd): # the cycle point should be reloaded from the workflow configuration diff --git a/tests/integration/test_task_pool.py b/tests/integration/test_task_pool.py index 404ec8da87f..3fb0becadd6 100644 --- a/tests/integration/test_task_pool.py +++ b/tests/integration/test_task_pool.py @@ -714,7 +714,7 @@ async def test_restart_prereqs( # Edit the workflow to add a new dependency on "z" conf['scheduling']['graph']['R1'] = graph_2 - id_ = flow(conf, id_=id_) + id_ = flow(conf, workflow_id=id_) # Restart it schd = scheduler(id_, run_mode='simulation', paused_start=False) @@ -834,7 +834,7 @@ async def test_reload_prereqs( # Modify flow.cylc to add a new dependency on "z" conf['scheduling']['graph']['R1'] = graph_2 - flow(conf, id_=id_) + flow(conf, workflow_id=id_) # Reload the workflow config await commands.run_cmd(commands.reload_workflow(schd)) @@ -953,7 +953,7 @@ async def test_graph_change_prereq_satisfaction( # shutdown and change the workflow definiton conf['scheduling']['graph']['R1'] += '\nb => c' - flow(conf, id_=id_) + flow(conf, workflow_id=id_) schd = scheduler(id_, run_mode='simulation', paused_start=False) async with start(schd): @@ -966,7 +966,7 @@ async def test_graph_change_prereq_satisfaction( # Modify flow.cylc to add a new dependency on "b" conf['scheduling']['graph']['R1'] += '\nb => c' - flow(conf, id_=id_) + flow(conf, workflow_id=id_) # Reload the workflow config await commands.run_cmd(commands.reload_workflow(schd)) @@ -2158,7 +2158,7 @@ async def list_data_store(): ].replace('@a', '@c') # reload - flow(config, id_=id_) + flow(config, workflow_id=id_) await commands.run_cmd(commands.reload_workflow(schd)) # check xtrigs post-reload diff --git a/tests/integration/utils/flow_tools.py b/tests/integration/utils/flow_tools.py index 86377bfaf50..ba5e3d0470d 100644 --- a/tests/integration/utils/flow_tools.py +++ b/tests/integration/utils/flow_tools.py @@ -31,7 +31,6 @@ from secrets import token_hex from cylc.flow import CYLC_LOG -from cylc.flow.run_modes import RunMode from cylc.flow.workflow_files import WorkflowFiles from cylc.flow.scheduler import Scheduler, SchedulerStop from cylc.flow.scheduler_cli import RunOptions @@ -56,7 +55,7 @@ def _make_flow( test_dir: Path, conf: Union[dict, str], name: Optional[str] = None, - id_: Optional[str] = None, + workflow_id: Optional[str] = None, defaults: Optional[bool] = True, filename: str = WorkflowFiles.FLOW_FILE, ) -> str: @@ -70,14 +69,14 @@ def _make_flow( Set false for Cylc 7 upgrader tests. """ - if id_: - flow_run_dir = (cylc_run_dir / id_) + if workflow_id: + flow_run_dir = (cylc_run_dir / workflow_id) else: if name is None: name = token_hex(4) flow_run_dir = (test_dir / name) flow_run_dir.mkdir(parents=True, exist_ok=True) - id_ = str(flow_run_dir.relative_to(cylc_run_dir)) + workflow_id = str(flow_run_dir.relative_to(cylc_run_dir)) if isinstance(conf, str): conf = { 'scheduling': { @@ -100,7 +99,7 @@ def _make_flow( with open((flow_run_dir / filename), 'w+') as flow_file: flow_file.write(flow_config_str(conf)) - return id_ + return workflow_id @contextmanager