Skip to content

Commit

Permalink
refactor(MkDoc): use classhelpers for iterating classes
Browse files Browse the repository at this point in the history
  • Loading branch information
phil65 committed Nov 10, 2023
1 parent a8e91a6 commit c0c5e54
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 29 deletions.
39 changes: 10 additions & 29 deletions mknodes/navs/mkdoc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

from collections.abc import Callable, Iterator, Sequence
import inspect
import types

from typing import Any
Expand Down Expand Up @@ -94,11 +93,10 @@ def _collect_classes(self):

def iter_classes(
self,
submodule: types.ModuleType | str | tuple | list | None = None,
submodule: types.ModuleType | str | tuple[str, ...] | list[str] | None = None,
*,
recursive: bool = False,
predicate: Callable[[type], bool] | None = None,
_seen: set | None = None,
) -> Iterator[type]:
"""Iterate over all classes of the module.
Expand All @@ -108,33 +106,16 @@ def iter_classes(
or whether it should also include classes from submodules.
predicate: filter classes based on a predicate.
"""
if isinstance(submodule, list):
submodule = tuple(submodule)
mod = classhelpers.to_module(submodule) if submodule else self.module
if mod is None:
module = submodule or self.module
if module is None:
return
if recursive:
seen = _seen or set()
# TODO: pkgutil.iter_modules would also list "unknown" modules
for submod in classhelpers.get_submodules(mod):
if submod not in seen:
seen.add(submod)
yield from self.iter_classes(
submod,
recursive=True,
predicate=predicate,
_seen=seen,
)
for klass_name, klass in inspect.getmembers(mod, inspect.isclass):
if self.filter_by___all__ and (
not hasattr(mod, "__all__") or klass_name not in mod.__all__
):
continue
if predicate and not predicate(klass):
continue
# if klass.__module__.startswith(self.module_name):
if self.module_name in klass.__module__.split("."):
yield klass
yield from classhelpers.list_classes(
module=tuple(module) if isinstance(module, list) else module,
recursive=recursive,
filter_by___all__=self.filter_by___all__,
predicate=predicate,
module_filter=self.module_name,
)

def add_class_page(
self,
Expand Down
12 changes: 12 additions & 0 deletions mknodes/utils/classhelpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ def list_classes(
type_filter: type | None | types.UnionType = None,
module_filter: str | None = None,
filter_by___all__: bool = False,
predicate: Callable[[type], bool] | None = None,
recursive: bool = False,
) -> list[type]:
"""Return list of classes from given module.
Expand All @@ -277,6 +278,7 @@ def list_classes(
type_filter: only return classes which are subclasses of given type.
module_filter: filter by a module prefix.
filter_by___all__: Whether to filter based on whats defined in __all__.
predicate: filter classes based on a predicate.
recursive: import all submodules recursively and also return their classes.
"""
return list(
Expand All @@ -285,6 +287,7 @@ def list_classes(
type_filter=type_filter,
module_filter=module_filter,
filter_by___all__=filter_by___all__,
predicate=predicate,
recursive=recursive,
)
)
Expand All @@ -296,7 +299,9 @@ def iter_classes(
type_filter: type | None | types.UnionType = None,
module_filter: str | None = None,
filter_by___all__: bool = False,
predicate: Callable[[type], bool] | None = None,
recursive: bool = False,
# _seen: set | None = None,
) -> Iterator[type]:
"""Yield classes from given module.
Expand All @@ -306,25 +311,32 @@ def iter_classes(
type_filter: only yield classes which are subclasses of given type.
module_filter: filter by a module prefix.
filter_by___all__: Whether to filter based on whats defined in __all__.
predicate: filter classes based on a predicate.
recursive: import all submodules recursively and also yield their classes.
"""
mod = to_module(module)
if not mod:
return []
if recursive:
# seen = _seen or set()
for submod in get_submodules(mod):
# if submod not in seen:
# seen.add(submod)
if submod.__name__.startswith(module_filter or ""):
yield from iter_classes(
submod,
type_filter=type_filter,
module_filter=submod.__name__,
filter_by___all__=filter_by___all__,
predicate=predicate,
recursive=True,
)
for klass_name, kls in get_members(mod, inspect.isclass):
has_all = hasattr(mod, "__all__")
if filter_by___all__ and (not has_all or klass_name not in mod.__all__):
continue
if predicate and not predicate(kls):
continue
if type_filter is not None and not issubclass(kls, type_filter):
continue
if module_filter is not None and not kls.__module__.startswith(module_filter):
Expand Down

0 comments on commit c0c5e54

Please sign in to comment.