diff --git a/changes/202.feature.rst b/changes/202.feature.rst new file mode 100644 index 00000000..1967d26c --- /dev/null +++ b/changes/202.feature.rst @@ -0,0 +1 @@ +Allow class aliases (used during strun) to contain the package name (for example "jwst::resample"). diff --git a/src/stpipe/cli/strun.py b/src/stpipe/cli/strun.py index 6ba31fbb..7c95bbde 100755 --- a/src/stpipe/cli/strun.py +++ b/src/stpipe/cli/strun.py @@ -1,6 +1,6 @@ import sys -from stpipe import Step +from stpipe import cmdline from stpipe.cli.main import _print_versions from stpipe.exceptions import StpipeExitException @@ -21,7 +21,7 @@ def main(): sys.exit(0) try: - Step.from_cmdline(sys.argv[1:]) + cmdline.step_from_cmdline(sys.argv[1:]) except StpipeExitException as e: sys.exit(e.exit_status) except Exception: diff --git a/src/stpipe/entry_points.py b/src/stpipe/entry_points.py index 133fa6ac..323abfce 100644 --- a/src/stpipe/entry_points.py +++ b/src/stpipe/entry_points.py @@ -1,7 +1,7 @@ import warnings from collections import namedtuple -from importlib_metadata import entry_points +import importlib_metadata STEPS_GROUP = "stpipe.steps" @@ -26,7 +26,7 @@ class alias, and the third is a bool indicating whether the class is to be """ steps = [] - for entry_point in entry_points(group=STEPS_GROUP): + for entry_point in importlib_metadata.entry_points(group=STEPS_GROUP): package_name = entry_point.dist.name package_version = entry_point.dist.version package_steps = [] diff --git a/src/stpipe/utilities.py b/src/stpipe/utilities.py index 1a069c0b..5ba436c0 100644 --- a/src/stpipe/utilities.py +++ b/src/stpipe/utilities.py @@ -19,16 +19,43 @@ def resolve_step_class_alias(name): Parameters ---------- name : str + If name contains "::" only the package with + a name matching the characters before "::" + will be searched for the matching step. Returns ------- str """ + # check if the name contains a package name + if "::" in name: + scope, class_name = name.split("::", maxsplit=1) + else: + scope, class_name = None, name + + # track all found steps keyed by package name + found_class_names = {} for info in entry_points.get_steps(): - if info.class_alias is not None and name == info.class_alias: - return info.class_name - - return name + if scope and info.package_name != scope: + continue + if info.class_alias is not None and class_name == info.class_alias: + found_class_names[info.package_name] = info + + if not found_class_names: + return name + + if len(found_class_names) == 1: + return found_class_names.popitem()[1].class_name + + # class alias resolved to several possible steps + scopes = list(found_class_names.keys()) + msg = ( + f"class alias {name} matched more than 1 step. Please provide " + "the package name along with the step name. One of:\n" + ) + for scope in scopes: + msg += f" {scope}::{name}\n" + raise ValueError(msg) def import_class(full_name, subclassof=object, config_file=None): diff --git a/tests/test_utilities.py b/tests/test_utilities.py index f466618b..9ca34e64 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -1,7 +1,7 @@ import pytest from stpipe import Step -from stpipe.utilities import import_class, import_func +from stpipe.utilities import import_class, import_func, resolve_step_class_alias def what_is_your_quest(): @@ -13,6 +13,8 @@ class HovercraftFullOfEels: class Foo(Step): + class_alias = "foo_step" + def process(self, input_data): pass @@ -52,3 +54,101 @@ def test_import_class_no_module(): def test_import_func_no_module(): with pytest.raises(ImportError): import_func("foo") + + +@pytest.fixture() +def mock_entry_points(monkeypatch, request): + # as the test class above isn't registered via an entry point + # we mock the entry points here + class FakeDist: + def __init__(self, name): + self.name = name + self.version = "dev" + + class FakeEntryPoint: + def __init__(self, dist_name, steps): + self.dist = FakeDist(dist_name) + self.steps = steps + + def load(self): + def loader(): + return self.steps + + return loader + + def fake_entrypoints(group=None): + return [FakeEntryPoint(k, v) for k, v in request.param.items()] + + import importlib_metadata + + monkeypatch.setattr(importlib_metadata, "entry_points", fake_entrypoints) + yield + + +@pytest.mark.parametrize("name", ("foo_step", "stpipe::foo_step")) +@pytest.mark.parametrize( + "mock_entry_points", [{"stpipe": [("Foo", "foo_step", False)]}], indirect=True +) +def test_class_alias_lookup(name, mock_entry_points): + """ + Test that a step name can be resolved if either: + - only a single step is found that matches + - a step is found and a valid package name was provided + """ + assert resolve_step_class_alias(name) == "Foo" + + +@pytest.mark.parametrize("name", ("bar_step", "other_package::foo_step")) +@pytest.mark.parametrize( + "mock_entry_points", [{"stpipe": [("Foo", "foo_step", False)]}], indirect=True +) +def test_class_alias_lookup_fallthrough(name, mock_entry_points): + """ + Test that passing in an unknown class alias or an alias scoped + to a different package falls through to returning the unresolved + class_alias (to match previous behavior). + """ + assert resolve_step_class_alias(name) == name + + +@pytest.mark.parametrize("name", ("aaa::foo_step", "zzz::foo_step")) +@pytest.mark.parametrize( + "mock_entry_points", + [ + { + "aaa": [("Foo", "foo_step", False)], + "zzz": [("Foo", "foo_step", False)], + } + ], + indirect=True, +) +def test_class_alias_lookup_scoped(name, mock_entry_points): + """ + Test the lookup succeeds if more than 1 package + provides a matching step name but the "scope" (package name) + is provided on lookup. + """ + assert resolve_step_class_alias(name) == "Foo" + + +@pytest.mark.parametrize( + "mock_entry_points", + [ + { + "aaa": [("Foo", "foo_step", False)], + "zzz": [("Foo", "foo_step", False)], + } + ], + indirect=True, +) +def test_class_alias_lookup_conflict(mock_entry_points): + """ + Test that an ambiguous lookup (a class alias that resolves + to more than 1 step from different packages) results in + an error. + When the package name is provided, tes + """ + with pytest.raises(ValueError) as err: + resolve_step_class_alias("foo_step") + assert err.match("aaa::foo_step") + assert err.match("zzz::foo_step")