Skip to content

Commit

Permalink
Fix KeyError when using cylc remove after a reload changed the graph
Browse files: browse the repository at this point in the history
  • Loading branch information
MetRonnie committed Dec 5, 2024
1 parent 5b2b180 commit 35819a5
Show file tree
Hide file tree
Showing 11 changed files with 94 additions and 72 deletions.
2 changes: 1 addition & 1 deletion cylc/flow/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1543,7 +1543,7 @@ def configure_workflow_state_polling_tasks(self):
"script cannot be defined for automatic" +
" workflow polling task '%s':\n%s" % (l_task, cs))
# Generate the automatic scripting.
for name, tdef in list(self.taskdefs.items()):
for name, tdef in self.taskdefs.items():
if name not in self.workflow_polling_tasks:
continue
rtc = tdef.rtconfig
Expand Down
56 changes: 23 additions & 33 deletions cylc/flow/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@
from cylc.flow.broadcast_mgr import BroadcastMgr
from cylc.flow.cfgspec.glbl_cfg import glbl_cfg
from cylc.flow.config import WorkflowConfig
from cylc.flow.cycling.loader import get_point
from cylc.flow.data_store_mgr import DataStoreMgr
from cylc.flow.exceptions import (
CommandFailedError,
Expand All @@ -93,9 +92,7 @@
get_user,
is_remote_platform,
)
from cylc.flow.id import (
Tokens,
)
from cylc.flow.id import Tokens, quick_relative_id
from cylc.flow.log_level import (
verbosity_to_env,
verbosity_to_opts,
Expand Down Expand Up @@ -1102,17 +1099,20 @@ def remove_tasks(

if flow_nums is None:
flow_nums = set()
# Mapping of *relative* task IDs to removed flow numbers:
removed: Dict[Tokens, FlowNums] = {}
not_removed: Set[Tokens] = set()
# Mapping of relative task IDs to removed flow numbers:
removed: Dict[str, FlowNums] = {}
not_removed: Set[str] = set()
# All the matched tasks (will add applicable active tasks below):
matched_tasks = inactive.copy()
to_kill: List[TaskProxy] = []

for itask in active:
fnums_to_remove = itask.match_flows(flow_nums)
if not fnums_to_remove:
not_removed.add(itask.tokens.task)
not_removed.add(itask.identity)
continue
removed[itask.tokens.task] = fnums_to_remove
removed[itask.identity] = fnums_to_remove
matched_tasks.add((itask.tdef, itask.point))
if fnums_to_remove == itask.flow_nums:
# Need to remove the task from the pool.
# Spawn next occurrence of xtrigger sequential task (otherwise
Expand All @@ -1123,21 +1123,13 @@ def remove_tasks(
itask.removed = True
itask.flow_nums.difference_update(fnums_to_remove)

# All the matched tasks (including inactive & applicable active tasks):
matched_tasks = {
*removed.keys(),
*(Tokens(cycle=str(cycle), task=task) for task, cycle in inactive),
}

for tokens in matched_tasks:
tdef = self.config.taskdefs[tokens['task']]
for tdef, point in matched_tasks:
task_id = quick_relative_id(point, tdef.name)

# Go through any tasks downstream of this matched task to see if
# any need to stand down as a result of this task being removed:
for child in set(itertools.chain.from_iterable(
generate_graph_children(
tdef, get_point(tokens['cycle'])
).values()
generate_graph_children(tdef, point).values()
)):
child_itask = self.pool.get_task(child.point, child.name)
if not child_itask:
Expand All @@ -1152,9 +1144,9 @@ def remove_tasks(
):
# Unset any prereqs naturally satisfied by these tasks
# (do not unset those satisfied by `cylc set --pre`):
if prereq.unset_naturally_satisfied(tokens.relative_id):
if prereq.unset_naturally_satisfied(task_id):
prereqs_changed = True
removed.setdefault(tokens, set()).update(
removed.setdefault(task_id, set()).update(
fnums_to_remove
)
if not prereqs_changed:
Expand All @@ -1173,7 +1165,7 @@ def remove_tasks(
# Check if downstream task should remain spawned:
if (
# Ignoring tasks we are already dealing with:
child_itask.tokens.task in matched_tasks
(child_itask.tdef, child_itask.point) in matched_tasks
or child_itask.state.any_satisfied_prerequisite_outputs()
):
continue
Expand All @@ -1187,33 +1179,31 @@ def remove_tasks(

# Remove the matched tasks from the flows in the DB tables:
db_removed_fnums = self.workflow_db_mgr.remove_task_from_flows(
tokens['cycle'], tokens['task'], flow_nums,
str(point), tdef.name, flow_nums,
)
if db_removed_fnums:
removed.setdefault(tokens, set()).update(db_removed_fnums)
removed.setdefault(task_id, set()).update(db_removed_fnums)

if task_id not in removed:
not_removed.add(task_id)

if to_kill:
self.kill_tasks(to_kill, warn=False)

if removed:
tasks_str_list = []
for task, fnums in removed.items():
self.data_store_mgr.delta_remove_task_flow_nums(
task.relative_id, fnums
)
self.data_store_mgr.delta_remove_task_flow_nums(task, fnums)
tasks_str_list.append(
f"{task.relative_id} {repr_flow_nums(fnums, full=True)}"
f"{task} {repr_flow_nums(fnums, full=True)}"
)
LOG.info(f"Removed task(s): {', '.join(sorted(tasks_str_list))}")

not_removed.update(matched_tasks.difference(removed))
if not_removed:
fnums_str = (
repr_flow_nums(flow_nums, full=True) if flow_nums else ''
)
tasks_str = ', '.join(
sorted(tokens.relative_id for tokens in not_removed)
)
tasks_str = ', '.join(sorted(not_removed))
LOG.warning(f"Task(s) not removable: {tasks_str} {fnums_str}")

if removed and self.pool.compute_runahead():
Expand Down
40 changes: 20 additions & 20 deletions cylc/flow/task_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
Tuple,
Type,
Union,
cast,
)

from cylc.flow import LOG
Expand Down Expand Up @@ -1335,9 +1336,9 @@ def hold_tasks(self, items: Iterable[str]) -> int:
for itask in itasks:
self.hold_active_task(itask)
# Set inactive tasks to be held:
for name, cycle in inactive_tasks:
self.data_store_mgr.delta_task_held(name, cycle, True)
self.tasks_to_hold.update(inactive_tasks)
for tdef, cycle in inactive_tasks:
self.data_store_mgr.delta_task_held(tdef.name, cycle, True)
self.tasks_to_hold.add((tdef.name, cycle))
self.workflow_db_mgr.put_tasks_to_hold(self.tasks_to_hold)
LOG.debug(f"Tasks to hold: {self.tasks_to_hold}")
return len(unmatched)
Expand All @@ -1353,9 +1354,9 @@ def release_held_tasks(self, items: Iterable[str]) -> int:
for itask in itasks:
self.release_held_active_task(itask)
# Unhold inactive tasks:
for name, cycle in inactive_tasks:
self.data_store_mgr.delta_task_held(name, cycle, False)
self.tasks_to_hold.difference_update(inactive_tasks)
for tdef, cycle in inactive_tasks:
self.data_store_mgr.delta_task_held(tdef.name, cycle, False)
self.tasks_to_hold.discard((tdef.name, cycle))
self.workflow_db_mgr.put_tasks_to_hold(self.tasks_to_hold)
LOG.debug(f"Tasks to hold: {self.tasks_to_hold}")
return len(unmatched)
Expand Down Expand Up @@ -1979,8 +1980,7 @@ def set_prereqs_and_outputs(
if not flow:
# default: assign to all active flows
flow_nums = self._get_active_flow_nums()
for name, point in inactive_tasks:
tdef = self.config.get_taskdef(name)
for tdef, point in inactive_tasks:
if prereqs:
self._set_prereqs_tdef(
point, tdef, prereqs, flow_nums, flow_wait)
Expand Down Expand Up @@ -2175,7 +2175,7 @@ def force_trigger_tasks(
"""
# Get matching tasks proxies, and matching inactive task IDs.
existing_tasks, inactive_ids, unmatched = self.filter_task_proxies(
existing_tasks, inactive, unmatched = self.filter_task_proxies(
items, inactive=True, warn_no_active=False,
)

Expand All @@ -2199,15 +2199,15 @@ def force_trigger_tasks(
if not flow:
# default: assign to all active flows
flow_nums = self._get_active_flow_nums()
for name, point in inactive_ids:
if not self.can_be_spawned(name, point):
for tdef, point in inactive:
if not self.can_be_spawned(tdef.name, point):
continue
submit_num, _, prev_fwait = (
self._get_task_history(name, point, flow_nums)
self._get_task_history(tdef.name, point, flow_nums)
)
itask = TaskProxy(
self.tokens,
self.config.get_taskdef(name),
tdef,
point,
flow_nums,
flow_wait=flow_wait,
Expand Down Expand Up @@ -2327,7 +2327,7 @@ def filter_task_proxies(
ids: Iterable[str],
warn_no_active: bool = True,
inactive: bool = False,
) -> 'Tuple[List[TaskProxy], Set[Tuple[str, PointBase]], List[str]]':
) -> 'Tuple[List[TaskProxy], Set[Tuple[TaskDef, PointBase]], List[str]]':
"""Return task proxies that match names, points, states in items.
Args:
Expand All @@ -2353,7 +2353,7 @@ def filter_task_proxies(
ids,
warn=warn_no_active,
)
inactive_matched: 'Set[Tuple[str, PointBase]]' = set()
inactive_matched: 'Set[Tuple[TaskDef, PointBase]]' = set()
if inactive and unmatched:
inactive_matched, unmatched = self.match_inactive_tasks(
unmatched
Expand All @@ -2364,7 +2364,7 @@ def filter_task_proxies(
def match_inactive_tasks(
self,
ids: Iterable[str],
) -> Tuple[Set[Tuple[str, 'PointBase']], List[str]]:
) -> 'Tuple[Set[Tuple[TaskDef, PointBase]], List[str]]':
"""Match task IDs against task definitions (rather than the task pool).
IDs will be matched providing the ID:
Expand All @@ -2377,7 +2377,7 @@ def match_inactive_tasks(
(matched_tasks, unmatched_tasks)
"""
matched_tasks: 'Set[Tuple[str, PointBase]]' = set()
matched_tasks: 'Set[Tuple[TaskDef, PointBase]]' = set()
unmatched_tasks: 'List[str]' = []
for id_ in ids:
try:
Expand All @@ -2404,8 +2404,8 @@ def match_inactive_tasks(
unmatched_tasks.append(id_)
continue

point_str = tokens['cycle']
name_str = tokens['task']
point_str = cast('str', tokens['cycle'])
name_str = cast('str', tokens['task'])
if name_str not in self.config.taskdefs:
if self.config.find_taskdefs(name_str):
# It's a family name; was not matched by active tasks
Expand All @@ -2427,7 +2427,7 @@ def match_inactive_tasks(
point = get_point(point_str)
taskdef = self.config.taskdefs[name_str]
if taskdef.is_valid_point(point):
matched_tasks.add((taskdef.name, point))
matched_tasks.add((taskdef, point))
else:
LOG.warning(
self.ERR_PREFIX_TASK_NOT_ON_SEQUENCE.format(
Expand Down
3 changes: 1 addition & 2 deletions cylc/flow/taskdef.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ class TaskDef:
def __init__(self, name, rtcfg, start_point, initial_point):
if not TaskID.is_valid_name(name):
raise TaskDefError("Illegal task name: %s" % name)

self.name: str = name
self.rtconfig = rtcfg
self.start_point = start_point
self.initial_point = initial_point
Expand All @@ -192,7 +192,6 @@ def __init__(self, name, rtcfg, start_point, initial_point):
self.external_triggers = []
self.xtrig_labels = {} # {sequence: [labels]}

self.name = name
self.elapsed_times = deque(maxlen=self.MAX_LEN_ELAPSED_TIMES)
self._add_std_outputs()
self.has_abs_triggers = False
Expand Down
2 changes: 1 addition & 1 deletion cylc/flow/workflow_db_mgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@ def _put_update_task_x(
self, table_name: str, itask: 'TaskProxy', set_args: 'DbArgDict'
) -> None:
"""Put UPDATE statement for a task_* table."""
where_args = {
where_args: Dict[str, Any] = {
"cycle": str(itask.point),
"name": itask.tdef.name,
}
Expand Down
3 changes: 1 addition & 2 deletions tests/integration/test_optional_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
get_completion_expression,
)
from cylc.flow.task_state import (
TASK_STATUSES_ACTIVE,
TASK_STATUS_EXPIRED,
TASK_STATUS_PREPARING,
TASK_STATUS_RUNNING,
Expand Down Expand Up @@ -484,7 +483,7 @@ async def test_removed_taskdef(
'R1': 'a'
}
}
}, id_=id_)
}, workflow_id=id_)

# restart the workflow
schd: 'Scheduler' = scheduler(id_)
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_reload.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ async def test_reload_failure(
async with start(schd):
# corrupt the config by removing the scheduling section
two_conf = {**one_conf, 'scheduling': {}}
flow(two_conf, id_=id_)
flow(two_conf, workflow_id=id_)

# reload the workflow
await commands.run_cmd(commands.reload_workflow(schd))
Expand Down
35 changes: 35 additions & 0 deletions tests/integration/test_remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from cylc.flow.commands import (
force_trigger_tasks,
reload_workflow,
remove_tasks,
run_cmd,
)
Expand Down Expand Up @@ -478,3 +479,37 @@ async def test_kill_running(flow, scheduler, run, complete, reflog):
('1/c', ('1/b',)),
# The a:failed output should not cause 1/q to run
}


async def test_reload_changed_config(flow, scheduler, run, complete):
    """Test removing a task after a reload has changed the graph.

    The workflow starts with a graph containing tasks a, b and s, and is then
    reloaded with the R1 graph reduced to just ``b`` while 1/a is still in
    the task pool. ``remove_tasks`` is then run on 1/a (whose name is no
    longer in the reloaded graph), after which the test waits on 1/b via the
    ``complete`` fixture.
    """
    wid = flow({
        'scheduling': {
            'graph': {
                'R1': '''
a => b
a:started => s & b
''',
            },
        },
        'runtime': {
            'a': {
                'simulation': {
                    # Ensure 1/a still in pool during reload
                    'fail cycle points': 'all',
                },
            },
        },
    })
    schd: Scheduler = scheduler(wid, paused_start=False)
    async with run(schd):
        # Wait on 1/s, which is downstream of a:started, before reloading.
        await complete(schd, '1/s')
        # Change graph then reload
        # (re-uses the same workflow ID; the bare string is presumably
        # shorthand for the R1 graph - the first assertion below checks this)
        flow('b', workflow_id=wid)
        await run_cmd(reload_workflow(schd))
        assert schd.config.cfg['scheduling']['graph']['R1'] == 'b'
        # Both tasks are still in the pool after the reload, even though
        # "a" is no longer in the graph.
        assert schd.pool.get_task_ids() == {'1/a', '1/b'}

        # Run the remove command on the task that is no longer in the graph.
        await run_cmd(remove_tasks(schd, ['1/a'], [FLOW_ALL]))
        await complete(schd, '1/b')
2 changes: 1 addition & 1 deletion tests/integration/test_stop_after_cycle_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def get_db_value(schd) -> Optional[str]:

# change the configured cycle point to "2"
config['scheduling']['stop after cycle point'] = '2'
id_ = flow(config, id_=id_)
id_ = flow(config, workflow_id=id_)
schd = scheduler(id_, paused_start=False)
async with run(schd):
# the cycle point should be reloaded from the workflow configuration
Expand Down
Loading

0 comments on commit 35819a5

Please sign in to comment.