diff --git a/cylc/flow/command_validation.py b/cylc/flow/command_validation.py
index e5f87e85ae6..d87c0711a8d 100644
--- a/cylc/flow/command_validation.py
+++ b/cylc/flow/command_validation.py
@@ -30,7 +30,7 @@
ERR_OPT_FLOW_VAL = "Flow values must be an integer, or 'all', 'new', or 'none'"
-ERR_OPT_FLOW_INT = "Multiple flow options must all be integer valued"
+ERR_OPT_FLOW_COMBINE = "Cannot combine --flow={0} with other flow values"
ERR_OPT_FLOW_WAIT = (
f"--wait is not compatible with --flow={FLOW_NEW} or --flow={FLOW_NONE}"
)
@@ -51,7 +51,8 @@ def flow_opts(flows: List[str], flow_wait: bool) -> None:
Bad:
>>> flow_opts(["none", "1"], False)
Traceback (most recent call last):
- cylc.flow.exceptions.InputError: ... must all be integer valued
+ cylc.flow.exceptions.InputError: Cannot combine --flow=none with other
+ flow values
>>> flow_opts(["cheese", "2"], True)
Traceback (most recent call last):
@@ -59,24 +60,26 @@ def flow_opts(flows: List[str], flow_wait: bool) -> None:
>>> flow_opts(["new"], True)
Traceback (most recent call last):
- cylc.flow.exceptions.InputError: ...
+ cylc.flow.exceptions.InputError: --wait is not compatible with
+ --flow=new or --flow=none
"""
if not flows:
return
+ flows = [val.strip() for val in flows]
+
for val in flows:
- val = val.strip()
- if val in [FLOW_NONE, FLOW_NEW, FLOW_ALL]:
+ if val in {FLOW_NONE, FLOW_NEW, FLOW_ALL}:
if len(flows) != 1:
- raise InputError(ERR_OPT_FLOW_INT)
+ raise InputError(ERR_OPT_FLOW_COMBINE.format(val))
else:
try:
int(val)
except ValueError:
- raise InputError(ERR_OPT_FLOW_VAL.format(val))
+ raise InputError(ERR_OPT_FLOW_VAL)
- if flow_wait and flows[0] in [FLOW_NEW, FLOW_NONE]:
+ if flow_wait and flows[0] in {FLOW_NEW, FLOW_NONE}:
raise InputError(ERR_OPT_FLOW_WAIT)
diff --git a/cylc/flow/task_pool.py b/cylc/flow/task_pool.py
index b279b3bf25f..38eaeb9dede 100644
--- a/cylc/flow/task_pool.py
+++ b/cylc/flow/task_pool.py
@@ -1915,17 +1915,9 @@ def set_prereqs_and_outputs(
warn=False,
)
- if flow == [FLOW_NEW]:
- # Translate --flow=new to an actual flow number now to avoid
- # incrementing it twice below.
- flow = [
- str(
- self.flow_mgr.get_flow_num(meta=flow_descr)
- )
- ]
+ flow_nums = self._get_flow_nums(flow, flow_descr)
# Set existing task proxies.
- flow_nums = self._get_flow_nums(flow, flow_descr, active=True)
for itask in itasks:
self.merge_flows(itask, flow_nums)
if prereqs:
@@ -1934,7 +1926,9 @@ def set_prereqs_and_outputs(
self._set_outputs_itask(itask, outputs)
# Spawn and set future tasks.
- flow_nums = self._get_flow_nums(flow, flow_descr, active=False)
+ if not flow:
+ # default: assign to all active flows
+ flow_nums = self._get_active_flow_nums()
for name, point in future_tasks:
tdef = self.config.get_taskdef(name)
if prereqs:
@@ -2070,51 +2064,30 @@ def remove_tasks(self, items):
return len(bad_items)
def _get_flow_nums(
- self,
- flow: List[str],
- meta: Optional[str] = None,
- active: bool = False
+ self,
+ flow: List[str],
+ meta: Optional[str] = None,
) -> Set[int]:
"""Return flow numbers corresponding to user command options.
Arg should have been validated already during command validation.
- Call this method separately for active (n=0) and future tasks.
- - future tasks: assign the result to the new task
- - active tasks: merge the result with existing flow numbers
-
- Note if a single command results in two calls to this method (for
- active and future tasks), translate --flow=new to an actual flow
- number first, to avoid incrementing the flow counter twice.
-
- The result is different in the default case (no --flow option):
- - future tasks: return all active flows
- - active tasks: stick with the existing flows (so return empty set).
+ In the default case (--flow option not provided), stick with the
+ existing flows (so return empty set) - NOTE this only applies for
+ active tasks.
"""
- if not flow:
- # default (i.e. no --flow option was used)
- if active:
- # active tasks: stick with the existing flow
- flow_nums = set()
- else:
- # future tasks: assign to all active flows
- flow_nums = self._get_active_flow_nums()
- elif flow == [FLOW_NONE]:
- flow_nums = set()
- elif flow == [FLOW_ALL]:
- flow_nums = self._get_active_flow_nums()
- elif flow == [FLOW_NEW]:
- flow_nums = {self.flow_mgr.get_flow_num(meta=meta)}
- else:
- # specific flow numbers
- flow_nums = {
- self.flow_mgr.get_flow_num(
- flow_num=int(n), meta=meta
- )
- for n in flow
- }
- return flow_nums
+ if flow == [FLOW_NONE]:
+ return set()
+ if flow == [FLOW_ALL]:
+ return self._get_active_flow_nums()
+ if flow == [FLOW_NEW]:
+ return {self.flow_mgr.get_flow_num(meta=meta)}
+ # else specific flow numbers:
+ return {
+ self.flow_mgr.get_flow_num(flow_num=int(n), meta=meta)
+ for n in flow
+ }
def _force_trigger(self, itask):
"""Assumes task is in the pool"""
@@ -2182,17 +2155,9 @@ def force_trigger_tasks(
items, future=True, warn=False,
)
- if flow == [FLOW_NEW]:
- # Translate --flow=new to an actual flow number now to avoid
- # incrementing it twice below.
- flow = [
- str(
- self.flow_mgr.get_flow_num(meta=flow_descr)
- )
- ]
+ flow_nums = self._get_flow_nums(flow, flow_descr)
# Trigger active tasks.
- flow_nums = self._get_flow_nums(flow, flow_descr, active=True)
for itask in existing_tasks:
if itask.state(TASK_STATUS_PREPARING, *TASK_STATUSES_ACTIVE):
LOG.warning(f"[{itask}] ignoring trigger - already active")
@@ -2201,7 +2166,9 @@ def force_trigger_tasks(
self._force_trigger(itask)
# Spawn and trigger future tasks.
- flow_nums = self._get_flow_nums(flow, flow_descr, active=False)
+ if not flow:
+ # default: assign to all active flows
+ flow_nums = self._get_active_flow_nums()
for name, point in future_ids:
if not self.can_be_spawned(name, point):
continue
diff --git a/tests/integration/test_trigger.py b/tests/integration/test_trigger.py
index 3f6b5dee138..d9c5304b745 100644
--- a/tests/integration/test_trigger.py
+++ b/tests/integration/test_trigger.py
@@ -14,33 +14,9 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
-import logging
-
-from cylc.flow.flow_mgr import FLOW_ALL, FLOW_NEW, FLOW_NONE
-from cylc.flow.command_validation import flow_opts
-from cylc.flow.exceptions import InputError
-
-import pytest
import time
-
-@pytest.mark.parametrize(
- 'flow_strs',
- (
- [FLOW_ALL, '1'],
- ['1', FLOW_ALL],
- [FLOW_NEW, '1'],
- [FLOW_NONE, '1'],
- ['a'],
- ['1', 'a'],
- )
-)
-async def test_trigger_invalid(mod_one, start, log_filter, flow_strs):
- """Ensure invalid flow values are rejected during command validation."""
- async with start(mod_one) as log:
- log.clear()
- with pytest.raises(InputError):
- flow_opts(flow_strs, False)
+from cylc.flow.flow_mgr import FLOW_ALL
async def test_trigger_no_flows(one, start, log_filter):
diff --git a/tests/unit/test_command_validation.py b/tests/unit/test_command_validation.py
new file mode 100644
index 00000000000..42fdda5aedf
--- /dev/null
+++ b/tests/unit/test_command_validation.py
@@ -0,0 +1,41 @@
+# THIS FILE IS PART OF THE CYLC WORKFLOW ENGINE.
+# Copyright (C) NIWA & British Crown (Met Office) & Contributors.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+import pytest
+
+from cylc.flow.command_validation import (
+ ERR_OPT_FLOW_COMBINE,
+ ERR_OPT_FLOW_VAL,
+ flow_opts,
+)
+from cylc.flow.exceptions import InputError
+from cylc.flow.flow_mgr import FLOW_ALL, FLOW_NEW, FLOW_NONE
+
+
+@pytest.mark.parametrize('flow_strs, expected_msg', [
+ ([FLOW_ALL, '1'], ERR_OPT_FLOW_COMBINE.format(FLOW_ALL)),
+ (['1', FLOW_ALL], ERR_OPT_FLOW_COMBINE.format(FLOW_ALL)),
+ ([FLOW_NEW, '1'], ERR_OPT_FLOW_COMBINE.format(FLOW_NEW)),
+ ([FLOW_NONE, '1'], ERR_OPT_FLOW_COMBINE.format(FLOW_NONE)),
+ ([FLOW_NONE, FLOW_ALL], ERR_OPT_FLOW_COMBINE.format(FLOW_NONE)),
+ (['a'], ERR_OPT_FLOW_VAL),
+ (['1', 'a'], ERR_OPT_FLOW_VAL),
+])
+async def test_trigger_invalid(flow_strs, expected_msg):
+ """Ensure invalid flow values are rejected during command validation."""
+ with pytest.raises(InputError) as exc_info:
+ flow_opts(flow_strs, False)
+ assert str(exc_info.value) == expected_msg