diff --git a/mknodes/navs/mkdoc.py b/mknodes/navs/mkdoc.py index de8b7693..3beb61b0 100644 --- a/mknodes/navs/mkdoc.py +++ b/mknodes/navs/mkdoc.py @@ -1,7 +1,6 @@ from __future__ import annotations from collections.abc import Callable, Iterator, Sequence -import inspect import types from typing import Any @@ -94,11 +93,10 @@ def _collect_classes(self): def iter_classes( self, - submodule: types.ModuleType | str | tuple | list | None = None, + submodule: types.ModuleType | str | tuple[str, ...] | list[str] | None = None, *, recursive: bool = False, predicate: Callable[[type], bool] | None = None, - _seen: set | None = None, ) -> Iterator[type]: """Iterate over all classes of the module. @@ -108,33 +106,16 @@ def iter_classes( or whether it should also include classes from submodules. predicate: filter classes based on a predicate. """ - if isinstance(submodule, list): - submodule = tuple(submodule) - mod = classhelpers.to_module(submodule) if submodule else self.module - if mod is None: + module = submodule or self.module + if module is None: return - if recursive: - seen = _seen or set() - # TODO: pkgutil.iter_modules would also list "unknown" modules - for submod in classhelpers.get_submodules(mod): - if submod not in seen: - seen.add(submod) - yield from self.iter_classes( - submod, - recursive=True, - predicate=predicate, - _seen=seen, - ) - for klass_name, klass in inspect.getmembers(mod, inspect.isclass): - if self.filter_by___all__ and ( - not hasattr(mod, "__all__") or klass_name not in mod.__all__ - ): - continue - if predicate and not predicate(klass): - continue - # if klass.__module__.startswith(self.module_name): - if self.module_name in klass.__module__.split("."): - yield klass + yield from classhelpers.list_classes( + module=tuple(module) if isinstance(module, list) else module, + recursive=recursive, + filter_by___all__=self.filter_by___all__, + predicate=predicate, + module_filter=self.module_name, + ) def add_class_page( self, diff --git a/mknodes/utils/classhelpers.py b/mknodes/utils/classhelpers.py index df257e10..f99bb611 100644 --- a/mknodes/utils/classhelpers.py +++ b/mknodes/utils/classhelpers.py @@ -267,6 +267,7 @@ def list_classes( type_filter: type | None | types.UnionType = None, module_filter: str | None = None, filter_by___all__: bool = False, + predicate: Callable[[type], bool] | None = None, recursive: bool = False, ) -> list[type]: """Return list of classes from given module. @@ -277,6 +278,7 @@ def list_classes( type_filter: only return classes which are subclasses of given type. module_filter: filter by a module prefix. filter_by___all__: Whether to filter based on whats defined in __all__. + predicate: filter classes based on a predicate. recursive: import all submodules recursively and also return their classes. """ return list( @@ -285,6 +287,7 @@ def list_classes( type_filter=type_filter, module_filter=module_filter, filter_by___all__=filter_by___all__, + predicate=predicate, recursive=recursive, ) ) @@ -296,7 +299,9 @@ def iter_classes( type_filter: type | None | types.UnionType = None, module_filter: str | None = None, filter_by___all__: bool = False, + predicate: Callable[[type], bool] | None = None, recursive: bool = False, + # _seen: set | None = None, ) -> Iterator[type]: """Yield classes from given module. @@ -306,25 +311,32 @@ def iter_classes( type_filter: only yield classes which are subclasses of given type. module_filter: filter by a module prefix. filter_by___all__: Whether to filter based on whats defined in __all__. + predicate: filter classes based on a predicate. recursive: import all submodules recursively and also yield their classes. """ mod = to_module(module) if not mod: return [] if recursive: + # seen = _seen or set() for submod in get_submodules(mod): + # if submod not in seen: + # seen.add(submod) if submod.__name__.startswith(module_filter or ""): yield from iter_classes( submod, type_filter=type_filter, module_filter=submod.__name__, filter_by___all__=filter_by___all__, + predicate=predicate, recursive=True, ) for klass_name, kls in get_members(mod, inspect.isclass): has_all = hasattr(mod, "__all__") if filter_by___all__ and (not has_all or klass_name not in mod.__all__): continue + if predicate and not predicate(kls): + continue if type_filter is not None and not issubclass(kls, type_filter): continue if module_filter is not None and not kls.__module__.startswith(module_filter):