Skip to content

Commit

Permalink
support selecting steps by package name for strun
Browse files Browse the repository at this point in the history
  • Loading branch information
braingram committed Oct 24, 2024
1 parent 799be65 commit c6bf192
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/stpipe/cli/strun.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import sys

from stpipe import Step
from stpipe import cmdline
from stpipe.cli.main import _print_versions
from stpipe.exceptions import StpipeExitException

Expand All @@ -21,7 +21,7 @@ def main():
sys.exit(0)

try:
Step.from_cmdline(sys.argv[1:])
cmdline.step_from_cmdline(sys.argv[1:])
except StpipeExitException as e:
sys.exit(e.exit_status)
except Exception:
Expand Down
4 changes: 2 additions & 2 deletions src/stpipe/entry_points.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import warnings
from collections import namedtuple

from importlib_metadata import entry_points
import importlib_metadata

STEPS_GROUP = "stpipe.steps"

Expand All @@ -26,7 +26,7 @@ class alias, and the third is a bool indicating whether the class is to be
"""
steps = []

for entry_point in entry_points(group=STEPS_GROUP):
for entry_point in importlib_metadata.entry_points(group=STEPS_GROUP):
package_name = entry_point.dist.name
package_version = entry_point.dist.version
package_steps = []
Expand Down
13 changes: 12 additions & 1 deletion src/stpipe/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,24 @@ def resolve_step_class_alias(name):
Parameters
----------
name : str
If name contains "::" only the package with
a name matching the characters before "::"
will be searched for the matching step.
Returns
-------
str
"""
# check if the name contains a package name
if "::" in name:
scope, class_name = name.split("::", maxsplit=1)
else:
scope, class_name = None, name

for info in entry_points.get_steps():
if info.class_alias is not None and name == info.class_alias:
if scope and info.package_name != scope:
continue
if info.class_alias is not None and class_name == info.class_alias:
return info.class_name

return name
Expand Down
42 changes: 41 additions & 1 deletion tests/test_utilities.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from stpipe import Step
from stpipe.utilities import import_class, import_func
from stpipe.utilities import import_class, import_func, resolve_step_class_alias


def what_is_your_quest():
Expand All @@ -13,6 +13,8 @@ class HovercraftFullOfEels:


class Foo(Step):
class_alias = "foo_step"

def process(self, input_data):
pass

Expand Down Expand Up @@ -52,3 +54,41 @@ def test_import_class_no_module():
def test_import_func_no_module():
with pytest.raises(ImportError):
import_func("foo")


@pytest.mark.parametrize(
"name, resolve",
(
("foo_step", True),
("stpipe::foo_step", True),
("some_other_package::foo_step", False),
),
)
def test_class_alias_lookup(name, resolve, monkeypatch):
# as the test class above isn't registered via an entry point
# we mock the entry points here
class FakeDist:
name = "stpipe"
version = "dev"

class FakeEntryPoint:
dist = FakeDist()

def load(self):
def loader():
return [("Foo", "foo_step", False)]

return loader

def fake_entrypoints(group=None):
return [FakeEntryPoint()]

import importlib_metadata

monkeypatch.setattr(importlib_metadata, "entry_points", fake_entrypoints)

resolved_name = resolve_step_class_alias(name)
if resolve:
assert resolved_name == Foo.__name__
else:
assert resolved_name == name

0 comments on commit c6bf192

Please sign in to comment.