Skip to content

Commit

Permalink
support selecting steps by package name for strun (#202)
Browse files Browse the repository at this point in the history
  • Loading branch information
braingram authored Oct 25, 2024
2 parents 799be65 + fd68d87 commit 3ed69b7
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 9 deletions.
1 change: 1 addition & 0 deletions changes/202.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allow class aliases (used during strun) to contain the package name (for example "jwst::resample").
4 changes: 2 additions & 2 deletions src/stpipe/cli/strun.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import sys

from stpipe import Step
from stpipe import cmdline
from stpipe.cli.main import _print_versions
from stpipe.exceptions import StpipeExitException

Expand All @@ -21,7 +21,7 @@ def main():
sys.exit(0)

try:
Step.from_cmdline(sys.argv[1:])
cmdline.step_from_cmdline(sys.argv[1:])
except StpipeExitException as e:
sys.exit(e.exit_status)
except Exception:
Expand Down
4 changes: 2 additions & 2 deletions src/stpipe/entry_points.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import warnings
from collections import namedtuple

from importlib_metadata import entry_points
import importlib_metadata

STEPS_GROUP = "stpipe.steps"

Expand All @@ -26,7 +26,7 @@ class alias, and the third is a bool indicating whether the class is to be
"""
steps = []

for entry_point in entry_points(group=STEPS_GROUP):
for entry_point in importlib_metadata.entry_points(group=STEPS_GROUP):
package_name = entry_point.dist.name
package_version = entry_point.dist.version
package_steps = []
Expand Down
35 changes: 31 additions & 4 deletions src/stpipe/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,43 @@ def resolve_step_class_alias(name):
Parameters
----------
name : str
If name contains "::" only the package with
a name matching the characters before "::"
will be searched for the matching step.
Returns
-------
str
"""
# check if the name contains a package name
if "::" in name:
scope, class_name = name.split("::", maxsplit=1)
else:
scope, class_name = None, name

# track all found steps keyed by package name
found_class_names = {}
for info in entry_points.get_steps():
if info.class_alias is not None and name == info.class_alias:
return info.class_name

return name
if scope and info.package_name != scope:
continue
if info.class_alias is not None and class_name == info.class_alias:
found_class_names[info.package_name] = info

if not found_class_names:
return name

if len(found_class_names) == 1:
return found_class_names.popitem()[1].class_name

# class alias resolved to several possible steps
scopes = list(found_class_names.keys())
msg = (
f"class alias {name} matched more than 1 step. Please provide "
"the package name along with the step name. One of:\n"
)
for scope in scopes:
msg += f" {scope}::{name}\n"
raise ValueError(msg)


def import_class(full_name, subclassof=object, config_file=None):
Expand Down
102 changes: 101 additions & 1 deletion tests/test_utilities.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from stpipe import Step
from stpipe.utilities import import_class, import_func
from stpipe.utilities import import_class, import_func, resolve_step_class_alias


def what_is_your_quest():
Expand All @@ -13,6 +13,8 @@ class HovercraftFullOfEels:


class Foo(Step):
class_alias = "foo_step"

def process(self, input_data):
pass

Expand Down Expand Up @@ -52,3 +54,101 @@ def test_import_class_no_module():
def test_import_func_no_module():
with pytest.raises(ImportError):
import_func("foo")


@pytest.fixture()
def mock_entry_points(monkeypatch, request):
# as the test class above isn't registered via an entry point
# we mock the entry points here
class FakeDist:
def __init__(self, name):
self.name = name
self.version = "dev"

class FakeEntryPoint:
def __init__(self, dist_name, steps):
self.dist = FakeDist(dist_name)
self.steps = steps

def load(self):
def loader():
return self.steps

return loader

def fake_entrypoints(group=None):
return [FakeEntryPoint(k, v) for k, v in request.param.items()]

import importlib_metadata

monkeypatch.setattr(importlib_metadata, "entry_points", fake_entrypoints)
yield


@pytest.mark.parametrize("name", ("foo_step", "stpipe::foo_step"))
@pytest.mark.parametrize(
"mock_entry_points", [{"stpipe": [("Foo", "foo_step", False)]}], indirect=True
)
def test_class_alias_lookup(name, mock_entry_points):
"""
Test that a step name can be resolved if either:
- only a single step is found that matches
- a step is found and a valid package name was provided
"""
assert resolve_step_class_alias(name) == "Foo"


@pytest.mark.parametrize("name", ("bar_step", "other_package::foo_step"))
@pytest.mark.parametrize(
"mock_entry_points", [{"stpipe": [("Foo", "foo_step", False)]}], indirect=True
)
def test_class_alias_lookup_fallthrough(name, mock_entry_points):
"""
Test that passing in an unknown class alias or an alias scoped
to a different package falls through to returning the unresolved
class_alias (to match previous behavior).
"""
assert resolve_step_class_alias(name) == name


@pytest.mark.parametrize("name", ("aaa::foo_step", "zzz::foo_step"))
@pytest.mark.parametrize(
"mock_entry_points",
[
{
"aaa": [("Foo", "foo_step", False)],
"zzz": [("Foo", "foo_step", False)],
}
],
indirect=True,
)
def test_class_alias_lookup_scoped(name, mock_entry_points):
"""
Test the lookup succeeds if more than 1 package
provides a matching step name but the "scope" (package name)
is provided on lookup.
"""
assert resolve_step_class_alias(name) == "Foo"


@pytest.mark.parametrize(
"mock_entry_points",
[
{
"aaa": [("Foo", "foo_step", False)],
"zzz": [("Foo", "foo_step", False)],
}
],
indirect=True,
)
def test_class_alias_lookup_conflict(mock_entry_points):
"""
Test that an ambiguous lookup (a class alias that resolves
to more than 1 step from different packages) results in
an error.
When the package name is provided, tes
"""
with pytest.raises(ValueError) as err:
resolve_step_class_alias("foo_step")
assert err.match("aaa::foo_step")
assert err.match("zzz::foo_step")

0 comments on commit 3ed69b7

Please sign in to comment.