diff --git a/cylc/flow/xtrigger_mgr.py b/cylc/flow/xtrigger_mgr.py index 890918c49dc..d42e9a87710 100644 --- a/cylc/flow/xtrigger_mgr.py +++ b/cylc/flow/xtrigger_mgr.py @@ -16,6 +16,7 @@ from contextlib import suppress from enum import Enum +from inspect import signature import json import re from copy import deepcopy @@ -279,12 +280,22 @@ def check_xtrigger( fname, f"'{fname}' not found in xtrigger module '{fname}'", ) + if not callable(func): raise XtriggerConfigError( label, fname, f"'{fname}' not callable in xtrigger module '{fname}'", ) + if func is not wall_clock: + # Validate args and kwargs against the function signature + # (but not for wall_clock because it's a special case). + try: + signature(func).bind(*fctx.func_args, **fctx.func_kwargs) + except TypeError as exc: + raise XtriggerConfigError( + label, fname, f"{fctx.get_signature()}: {exc}" + ) # Check any string templates in the function arg values (note this # won't catch bad task-specific values - which are added dynamically). diff --git a/cylc/flow/xtriggers/xrandom.py b/cylc/flow/xtriggers/xrandom.py index f48789ce152..c61b2114729 100644 --- a/cylc/flow/xtriggers/xrandom.py +++ b/cylc/flow/xtriggers/xrandom.py @@ -112,21 +112,6 @@ def validate(f_args, f_kwargs, f_signature): If f_args used, convert to f_kwargs for clarity. """ - n_args = len(f_args) - n_kwargs = len(f_kwargs) - - if n_args + n_kwargs > 3: - raise WorkflowConfigError(f"Too many args: {f_signature}") - - if n_args + n_kwargs < 1: - raise WorkflowConfigError(f"Wrong number of args: {f_signature}") - - if n_kwargs: - # kwargs must be "secs" and "_" - kw = next(iter(f_kwargs)) - if kw not in ("secs", "_"): - raise WorkflowConfigError(f"Illegal arg '{kw}': {f_signature}") - # convert to kwarg f_kwargs["percent"] = f_args[0] del f_args[0] diff --git a/tests/integration/test_config.py b/tests/integration/test_config.py index df93d50f555..11d7bd483e1 100644 --- a/tests/integration/test_config.py +++ b/tests/integration/test_config.py @@ -16,17 +16,18 @@ from pathlib import Path import sqlite3 -from typing import TYPE_CHECKING +from typing import Any import pytest -from cylc.flow.exceptions import ServiceFileError, WorkflowConfigError +from cylc.flow.exceptions import ( + ServiceFileError, + WorkflowConfigError, + XtriggerConfigError, +) from cylc.flow.parsec.exceptions import ListValueError from cylc.flow.pathutil import get_workflow_run_pub_db_path -if TYPE_CHECKING: - from types import Any - - Fixture = Any +Fixture = Any @pytest.mark.parametrize( @@ -353,7 +354,7 @@ def test_xtrig_validation_wall_clock( flow: 'Fixture', validate: 'Fixture', ): - """If an xtrigger module has a `validate_config` it is called. + """If an xtrigger module has a `validate()` function is called. https://github.com/cylc/cylc-flow/issues/5448 """ @@ -376,14 +377,13 @@ def test_xtrig_validation_echo( flow: 'Fixture', validate: 'Fixture', ): - """If an xtrigger module has a `validate_config` it is called. + """If an xtrigger module has a `validate()` function is called. https://github.com/cylc/cylc-flow/issues/5448 """ id_ = flow({ 'scheduler': {'allow implicit tasks': True}, 'scheduling': { - 'initial cycle point': '1012', 'xtriggers': {'myxt': 'echo()'}, 'graph': {'R1': '@myxt => foo'}, } @@ -399,21 +399,20 @@ def test_xtrig_validation_xrandom( flow: 'Fixture', validate: 'Fixture', ): - """If an xtrigger module has a `validate_config` it is called. + """If an xtrigger module has a `validate()` function it is called. https://github.com/cylc/cylc-flow/issues/5448 """ id_ = flow({ 'scheduler': {'allow implicit tasks': True}, 'scheduling': { - 'initial cycle point': '1012', - 'xtriggers': {'myxt': 'xrandom()'}, + 'xtriggers': {'myxt': 'xrandom(200)'}, 'graph': {'R1': '@myxt => foo'}, } }) with pytest.raises( WorkflowConfigError, - match=r'Wrong number of args: xrandom\(\)' + match=r"'percent' should be a float between 0 and 100:" ): validate(id_) @@ -423,29 +422,30 @@ def test_xtrig_validation_custom( validate: 'Fixture', monkeypatch: 'Fixture', ): - """If an xtrigger module has a `validate_config` + """If an xtrigger module has a `validate()` function an exception is raised if that validate function fails. https://github.com/cylc/cylc-flow/issues/5448 - - Rather than create our own xtrigger module on disk - and attempt to trigger a validation failure we - mock our own exception, xtrigger and xtrigger - validation functions and inject these into the - appropriate locations: """ + # Rather than create our own xtrigger module on disk + # and attempt to trigger a validation failure we + # mock our own exception, xtrigger and xtrigger + # validation functions and inject these into the + # appropriate locations: GreenExc = type('Green', (Exception,), {}) - def kustom_mock(suite): + def kustom_xt(feature): return True, {} def kustom_validate(args, kwargs, sig): raise GreenExc('This is only a test.') + # Patch xtrigger func monkeypatch.setattr( 'cylc.flow.xtrigger_mgr.get_xtrig_func', - lambda *args: kustom_mock, + lambda *args: kustom_xt, ) + # Patch xtrigger's validate func monkeypatch.setattr( 'cylc.flow.config.get_xtrig_func', lambda *args: kustom_validate if "validate" in args else '' @@ -463,3 +463,22 @@ def kustom_validate(args, kwargs, sig): Path(id_) with pytest.raises(GreenExc, match=r'This is only a test.'): validate(id_) + + +def test_xtrig_signature_validation( + flow: 'Fixture', + validate: 'Fixture', +): + """Test automatic xtrigger function signature validation.""" + id_ = flow({ + 'scheduler': {'allow implicit tasks': True}, + 'scheduling': { + 'xtriggers': {'myxt': 'xrandom()'}, + 'graph': {'R1': '@myxt => foo'}, + } + }) + with pytest.raises( + XtriggerConfigError, + match=r"xrandom\(\): missing a required argument: 'percent'" + ): + validate(id_) diff --git a/tests/unit/xtriggers/test_xrandom.py b/tests/unit/xtriggers/test_xrandom.py index 50dacfebcc2..23fcd1b97fd 100644 --- a/tests/unit/xtriggers/test_xrandom.py +++ b/tests/unit/xtriggers/test_xrandom.py @@ -27,13 +27,10 @@ def test_validate_good_path(): @pytest.mark.parametrize( 'args, kwargs, err', ( - param([100], {'f': 1.1, 'b': 1, 'x': 2}, 'Too', id='too-many-args'), - param([], {}, 'Wrong number', id='too-few-args'), - param(['foo'], {}, '\'percent', id='percent-not-numeric'), - param([101], {}, '\'percent', id='percent>100'), - param([-1], {}, '\'percent', id='percent<0'), - param([100], {'egg': 1}, 'Illegal', id='invalid-kwarg'), - param([100], {'secs': 1.1}, "'secs'", id='secs-not-int'), + param(['foo'], {}, r"'percent", id='percent-not-numeric'), + param([101], {}, r"'percent", id='percent>100'), + param([-1], {}, r"'percent", id='percent<0'), + param([100], {'secs': 1.1}, r"'secs'", id='secs-not-int'), ) ) def test_validate_exceptions(args, kwargs, err):