Skip to content

Commit

Permalink
chore: use epregistry for entry point management
Browse files Browse the repository at this point in the history
  • Loading branch information
phil65 committed Nov 14, 2024
1 parent bcbd1e7 commit 57cccd7
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 76 deletions.
4 changes: 2 additions & 2 deletions mknodes/basenodes/mkclidoc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def info(self) -> commandinfo.CommandInfo | None:
mod = importlib.import_module(module)
instance = getattr(mod, command)
case None:
if cli_eps := self.ctx.metadata.entry_points.get("console_scripts"):
module, command = cli_eps[0].dotted_path.split(":")
if cli_eps := self.ctx.metadata.entry_points.get_group("console_scripts"):
module, command = cli_eps[0].value.split(":")
prog_name = cli_eps[0].name
mod = importlib.import_module(module)
instance = getattr(mod, command)
Expand Down
5 changes: 4 additions & 1 deletion mknodes/info/contexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pathlib
from typing import TYPE_CHECKING, Any

import epregistry
import jinjarope

import mknodes as mk
Expand Down Expand Up @@ -204,7 +205,9 @@ class PackageContext(Context):
"""A icon-name -> URL dictionary containing ."""
inventory_url: str | None = ""
"""A best guess for an inventory URL for the package."""
entry_points: dict = dataclasses.field(default_factory=dict)
entry_points: epregistry.ModuleEntryPointRegistry = dataclasses.field(
default_factory=epregistry.ModuleEntryPointRegistry
)
"""A dictionary containing the entry points of the distribution."""
cli: str | None = None
"""The cli package name used by the distribution."""
Expand Down
8 changes: 5 additions & 3 deletions mknodes/info/packageinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import collections
import contextlib
import functools
from typing import Any

import epregistry
from requests import structures

from mknodes.info.cli import clihelpers, commandinfo
Expand Down Expand Up @@ -201,17 +203,17 @@ def cli(self) -> str | None:
@functools.cached_property
def cli_info(self) -> commandinfo.CommandInfo | None:
"""Return a CLI info object containing infos about all CLI commands / options."""
if eps := self.entry_points.get("console_scripts"):
if eps := self.entry_points.get_group("console_scripts"):
ep = eps[0].load()
qual_name = ep.__class__.__module__.lower()
if qual_name.startswith(("typer", "click")):
return clihelpers.get_cli_info(ep)
return None

@functools.cached_property
def entry_points(self) -> dict[str, list[packagehelpers.EntryPoint]]:
def entry_points(self) -> epregistry.ModuleEntryPointRegistry[Any]:
"""Get entry points for this package."""
return packagehelpers.get_entry_points(self.distribution)
return epregistry.ModuleEntryPointRegistry(self.package_name)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion mknodes/templatenodes/mkpluginflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def plugins(self):
match self._plugin:
case None:
ep_group = self.event_plugin.entry_point_group
eps = self.ctx.metadata.entry_points.get(ep_group, [])
eps = self.ctx.metadata.entry_points.get_group(ep_group, [])
return [i.load() for i in eps]
case type():
return [self._plugin]
Expand Down
70 changes: 1 addition & 69 deletions mknodes/utils/packagehelpers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from __future__ import annotations

import collections
import dataclasses
import functools
import importlib
from importlib import metadata
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING

from packaging.markers import Marker
from packaging.requirements import Requirement
Expand Down Expand Up @@ -90,11 +88,6 @@ def get_marker(marker) -> Marker:
return Marker(marker)


@functools.cache
def _get_entry_points(dist: metadata.Distribution | None = None, **kwargs: Any):
return dist.entry_points if dist else metadata.entry_points(**kwargs)


def get_extras(markers: list) -> list[str]:
extras = []
for marker in markers:
Expand All @@ -116,67 +109,6 @@ def import_dotted_path(path: str) -> type | types.ModuleType:
return getattr(mod, kls_name) if kls_name else mod


@dataclasses.dataclass
class EntryPoint:
"""EntryPoint including imported module."""

name: str
dotted_path: str
group: str

def load(self) -> Any:
"""Import and return the EntryPoint object."""
return import_dotted_path(self.dotted_path)

@property
def module(self) -> str:
"""The module of the entry point."""
return self.dotted_path.split(":")[0]


@functools.cache
def get_entry_points(
dist: metadata.Distribution | str | None = None,
group: str | None = None,
**kwargs: Any,
) -> EntryPointMap: # [str, list[EntryPoint]]
"""Returns a dictionary with entry point group as key, entry points as value.
Args:
dist: Optional distribution filter.
group: Optional group filter.
kwargs: Entry point filters
"""
if dist:
if isinstance(dist, str):
dist = get_distribution(dist)
eps = [i for i in _get_entry_points(dist) if i.group == group or not group]
else:
kw_args = dict(group=group, **kwargs) if group else kwargs
eps = [i for ls in _get_entry_points(**kw_args).values() for i in ls]

return EntryPointMap(eps)


class EntryPointMap(collections.defaultdict[str, list[EntryPoint]]):
def __init__(self, eps: list | None = None):
super().__init__(list)
for ep in eps or []:
if not isinstance(ep, EntryPoint):
ep = EntryPoint(name=ep.name, dotted_path=ep.value, group=ep.group)
self[ep.group].append(ep)

@property
def all_eps(self) -> list[EntryPoint]:
return [i for ls in self.values() for i in ls]

def by_name(self, name: str) -> EntryPoint | None:
return next((i for i in self.all_eps if i.name == name), None)

def by_dotted_path(self, dotted_path: str) -> EntryPoint | None:
return next((i for i in self.all_eps if i.dotted_path == dotted_path), None)


class Dependency:
def __init__(self, name: str):
self.req = Requirement(name)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ dependencies = [
"pipdeptree",
"git-changelog",
"mkdocstrings[python]",
"epregistry",
]
license = { file = "LICENSE" }

Expand Down

0 comments on commit 57cccd7

Please sign in to comment.