diff --git a/cylc/flow/command_validation.py b/cylc/flow/command_validation.py index e5f87e85ae6..d87c0711a8d 100644 --- a/cylc/flow/command_validation.py +++ b/cylc/flow/command_validation.py @@ -30,7 +30,7 @@ ERR_OPT_FLOW_VAL = "Flow values must be an integer, or 'all', 'new', or 'none'" -ERR_OPT_FLOW_INT = "Multiple flow options must all be integer valued" +ERR_OPT_FLOW_COMBINE = "Cannot combine --flow={0} with other flow values" ERR_OPT_FLOW_WAIT = ( f"--wait is not compatible with --flow={FLOW_NEW} or --flow={FLOW_NONE}" ) @@ -51,7 +51,8 @@ def flow_opts(flows: List[str], flow_wait: bool) -> None: Bad: >>> flow_opts(["none", "1"], False) Traceback (most recent call last): - cylc.flow.exceptions.InputError: ... must all be integer valued + cylc.flow.exceptions.InputError: Cannot combine --flow=none with other + flow values >>> flow_opts(["cheese", "2"], True) Traceback (most recent call last): @@ -59,24 +60,26 @@ def flow_opts(flows: List[str], flow_wait: bool) -> None: >>> flow_opts(["new"], True) Traceback (most recent call last): - cylc.flow.exceptions.InputError: ... + cylc.flow.exceptions.InputError: --wait is not compatible with + --flow=new or --flow=none """ if not flows: return + flows = [val.strip() for val in flows] + for val in flows: - val = val.strip() - if val in [FLOW_NONE, FLOW_NEW, FLOW_ALL]: + if val in {FLOW_NONE, FLOW_NEW, FLOW_ALL}: if len(flows) != 1: - raise InputError(ERR_OPT_FLOW_INT) + raise InputError(ERR_OPT_FLOW_COMBINE.format(val)) else: try: int(val) except ValueError: - raise InputError(ERR_OPT_FLOW_VAL.format(val)) + raise InputError(ERR_OPT_FLOW_VAL) - if flow_wait and flows[0] in [FLOW_NEW, FLOW_NONE]: + if flow_wait and flows[0] in {FLOW_NEW, FLOW_NONE}: raise InputError(ERR_OPT_FLOW_WAIT) diff --git a/cylc/flow/task_pool.py b/cylc/flow/task_pool.py index b279b3bf25f..38eaeb9dede 100644 --- a/cylc/flow/task_pool.py +++ b/cylc/flow/task_pool.py @@ -1915,17 +1915,9 @@ def set_prereqs_and_outputs( warn=False, ) - if flow == [FLOW_NEW]: - # Translate --flow=new to an actual flow number now to avoid - # incrementing it twice below. - flow = [ - str( - self.flow_mgr.get_flow_num(meta=flow_descr) - ) - ] + flow_nums = self._get_flow_nums(flow, flow_descr) # Set existing task proxies. - flow_nums = self._get_flow_nums(flow, flow_descr, active=True) for itask in itasks: self.merge_flows(itask, flow_nums) if prereqs: @@ -1934,7 +1926,9 @@ def set_prereqs_and_outputs( self._set_outputs_itask(itask, outputs) # Spawn and set future tasks. - flow_nums = self._get_flow_nums(flow, flow_descr, active=False) + if not flow: + # default: assign to all active flows + flow_nums = self._get_active_flow_nums() for name, point in future_tasks: tdef = self.config.get_taskdef(name) if prereqs: @@ -2070,51 +2064,30 @@ def remove_tasks(self, items): return len(bad_items) def _get_flow_nums( - self, - flow: List[str], - meta: Optional[str] = None, - active: bool = False + self, + flow: List[str], + meta: Optional[str] = None, ) -> Set[int]: """Return flow numbers corresponding to user command options. Arg should have been validated already during command validation. - Call this method separately for active (n=0) and future tasks. - - future tasks: assign the result to the new task - - active tasks: merge the result with existing flow numbers - - Note if a single command results in two calls to this method (for - active and future tasks), translate --flow=new to an actual flow - number first, to avoid incrementing the flow counter twice. - - The result is different in the default case (no --flow option): - - future tasks: return all active flows - - active tasks: stick with the existing flows (so return empty set). + In the default case (--flow option not provided), stick with the + existing flows (so return empty set) - NOTE this only applies for + active tasks. """ - if not flow: - # default (i.e. no --flow option was used) - if active: - # active tasks: stick with the existing flow - flow_nums = set() - else: - # future tasks: assign to all active flows - flow_nums = self._get_active_flow_nums() - elif flow == [FLOW_NONE]: - flow_nums = set() - elif flow == [FLOW_ALL]: - flow_nums = self._get_active_flow_nums() - elif flow == [FLOW_NEW]: - flow_nums = {self.flow_mgr.get_flow_num(meta=meta)} - else: - # specific flow numbers - flow_nums = { - self.flow_mgr.get_flow_num( - flow_num=int(n), meta=meta - ) - for n in flow - } - return flow_nums + if flow == [FLOW_NONE]: + return set() + if flow == [FLOW_ALL]: + return self._get_active_flow_nums() + if flow == [FLOW_NEW]: + return {self.flow_mgr.get_flow_num(meta=meta)} + # else specific flow numbers: + return { + self.flow_mgr.get_flow_num(flow_num=int(n), meta=meta) + for n in flow + } def _force_trigger(self, itask): """Assumes task is in the pool""" @@ -2182,17 +2155,9 @@ def force_trigger_tasks( items, future=True, warn=False, ) - if flow == [FLOW_NEW]: - # Translate --flow=new to an actual flow number now to avoid - # incrementing it twice below. - flow = [ - str( - self.flow_mgr.get_flow_num(meta=flow_descr) - ) - ] + flow_nums = self._get_flow_nums(flow, flow_descr) # Trigger active tasks. - flow_nums = self._get_flow_nums(flow, flow_descr, active=True) for itask in existing_tasks: if itask.state(TASK_STATUS_PREPARING, *TASK_STATUSES_ACTIVE): LOG.warning(f"[{itask}] ignoring trigger - already active") @@ -2201,7 +2166,9 @@ def force_trigger_tasks( self._force_trigger(itask) # Spawn and trigger future tasks. - flow_nums = self._get_flow_nums(flow, flow_descr, active=False) + if not flow: + # default: assign to all active flows + flow_nums = self._get_active_flow_nums() for name, point in future_ids: if not self.can_be_spawned(name, point): continue diff --git a/tests/integration/test_trigger.py b/tests/integration/test_trigger.py index 3f6b5dee138..d9c5304b745 100644 --- a/tests/integration/test_trigger.py +++ b/tests/integration/test_trigger.py @@ -14,33 +14,9 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -import logging - -from cylc.flow.flow_mgr import FLOW_ALL, FLOW_NEW, FLOW_NONE -from cylc.flow.command_validation import flow_opts -from cylc.flow.exceptions import InputError - -import pytest import time - -@pytest.mark.parametrize( - 'flow_strs', - ( - [FLOW_ALL, '1'], - ['1', FLOW_ALL], - [FLOW_NEW, '1'], - [FLOW_NONE, '1'], - ['a'], - ['1', 'a'], - ) -) -async def test_trigger_invalid(mod_one, start, log_filter, flow_strs): - """Ensure invalid flow values are rejected during command validation.""" - async with start(mod_one) as log: - log.clear() - with pytest.raises(InputError): - flow_opts(flow_strs, False) +from cylc.flow.flow_mgr import FLOW_ALL async def test_trigger_no_flows(one, start, log_filter): diff --git a/tests/unit/test_command_validation.py b/tests/unit/test_command_validation.py new file mode 100644 index 00000000000..42fdda5aedf --- /dev/null +++ b/tests/unit/test_command_validation.py @@ -0,0 +1,41 @@ +# THIS FILE IS PART OF THE CYLC WORKFLOW ENGINE. +# Copyright (C) NIWA & British Crown (Met Office) & Contributors. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import pytest + +from cylc.flow.command_validation import ( + ERR_OPT_FLOW_COMBINE, + ERR_OPT_FLOW_VAL, + flow_opts, +) +from cylc.flow.exceptions import InputError +from cylc.flow.flow_mgr import FLOW_ALL, FLOW_NEW, FLOW_NONE + + +@pytest.mark.parametrize('flow_strs, expected_msg', [ + ([FLOW_ALL, '1'], ERR_OPT_FLOW_COMBINE.format(FLOW_ALL)), + (['1', FLOW_ALL], ERR_OPT_FLOW_COMBINE.format(FLOW_ALL)), + ([FLOW_NEW, '1'], ERR_OPT_FLOW_COMBINE.format(FLOW_NEW)), + ([FLOW_NONE, '1'], ERR_OPT_FLOW_COMBINE.format(FLOW_NONE)), + ([FLOW_NONE, FLOW_ALL], ERR_OPT_FLOW_COMBINE.format(FLOW_NONE)), + (['a'], ERR_OPT_FLOW_VAL), + (['1', 'a'], ERR_OPT_FLOW_VAL), +]) +async def test_trigger_invalid(flow_strs, expected_msg): + """Ensure invalid flow values are rejected during command validation.""" + with pytest.raises(InputError) as exc_info: + flow_opts(flow_strs, False) + assert str(exc_info.value) == expected_msg