Skip to content

Commit

Permalink
raise ValueError if class_alias lookup returns more than 1 step
Browse files Browse the repository at this point in the history
  • Loading branch information
braingram committed Oct 24, 2024
1 parent 4cc46da commit fd68d87
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 21 deletions.
20 changes: 18 additions & 2 deletions src/stpipe/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,29 @@ def resolve_step_class_alias(name):
else:
scope, class_name = None, name

# track all found steps keyed by package name
found_class_names = {}
for info in entry_points.get_steps():
if scope and info.package_name != scope:
continue
if info.class_alias is not None and class_name == info.class_alias:
return info.class_name
found_class_names[info.package_name] = info

return name
if not found_class_names:
return name

if len(found_class_names) == 1:
return found_class_names.popitem()[1].class_name

# class alias resolved to several possible steps
scopes = list(found_class_names.keys())
msg = (
f"class alias {name} matched more than 1 step. Please provide "
"the package name along with the step name. One of:\n"
)
for scope in scopes:
msg += f" {scope}::{name}\n"
raise ValueError(msg)


def import_class(full_name, subclassof=object, config_file=None):
Expand Down
98 changes: 79 additions & 19 deletions tests/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,39 +56,99 @@ def test_import_func_no_module():
import_func("foo")


@pytest.mark.parametrize(
"name, resolve",
(
("foo_step", True),
("stpipe::foo_step", True),
("some_other_package::foo_step", False),
),
)
def test_class_alias_lookup(name, resolve, monkeypatch):
@pytest.fixture()
def mock_entry_points(monkeypatch, request):
# as the test class above isn't registered via an entry point
# we mock the entry points here
class FakeDist:
name = "stpipe"
version = "dev"
def __init__(self, name):
self.name = name
self.version = "dev"

class FakeEntryPoint:
dist = FakeDist()
def __init__(self, dist_name, steps):
self.dist = FakeDist(dist_name)
self.steps = steps

def load(self):
def loader():
return [("Foo", "foo_step", False)]
return self.steps

return loader

def fake_entrypoints(group=None):
return [FakeEntryPoint()]
return [FakeEntryPoint(k, v) for k, v in request.param.items()]

import importlib_metadata

monkeypatch.setattr(importlib_metadata, "entry_points", fake_entrypoints)
yield


resolved_name = resolve_step_class_alias(name)
if resolve:
assert resolved_name == Foo.__name__
else:
assert resolved_name == name
@pytest.mark.parametrize("name", ("foo_step", "stpipe::foo_step"))
@pytest.mark.parametrize(
"mock_entry_points", [{"stpipe": [("Foo", "foo_step", False)]}], indirect=True
)
def test_class_alias_lookup(name, mock_entry_points):
"""
Test that a step name can be resolved if either:
- only a single step is found that matches
- a step is found and a valid package name was provided
"""
assert resolve_step_class_alias(name) == "Foo"


@pytest.mark.parametrize("name", ("bar_step", "other_package::foo_step"))
@pytest.mark.parametrize(
"mock_entry_points", [{"stpipe": [("Foo", "foo_step", False)]}], indirect=True
)
def test_class_alias_lookup_fallthrough(name, mock_entry_points):
"""
Test that passing in an unknown class alias or an alias scoped
to a different package falls through to returning the unresolved
class_alias (to match previous behavior).
"""
assert resolve_step_class_alias(name) == name


@pytest.mark.parametrize("name", ("aaa::foo_step", "zzz::foo_step"))
@pytest.mark.parametrize(
"mock_entry_points",
[
{
"aaa": [("Foo", "foo_step", False)],
"zzz": [("Foo", "foo_step", False)],
}
],
indirect=True,
)
def test_class_alias_lookup_scoped(name, mock_entry_points):
"""
Test the lookup succeeds if more than 1 package
provides a matching step name but the "scope" (package name)
is provided on lookup.
"""
assert resolve_step_class_alias(name) == "Foo"


@pytest.mark.parametrize(
"mock_entry_points",
[
{
"aaa": [("Foo", "foo_step", False)],
"zzz": [("Foo", "foo_step", False)],
}
],
indirect=True,
)
def test_class_alias_lookup_conflict(mock_entry_points):
"""
Test that an ambiguous lookup (a class alias that resolves
to more than 1 step from different packages) results in
an error.
When the package name is provided, tes
"""
with pytest.raises(ValueError) as err:
resolve_step_class_alias("foo_step")
assert err.match("aaa::foo_step")
assert err.match("zzz::foo_step")

0 comments on commit fd68d87

Please sign in to comment.