diff --git a/autoregistry/_registry.py b/autoregistry/_registry.py index c2515cd..0b1cc69 100644 --- a/autoregistry/_registry.py +++ b/autoregistry/_registry.py @@ -10,7 +10,6 @@ from .exceptions import ( CannotDeriveNameError, CannotRegisterPythonBuiltInError, - InternalError, InvalidNameError, KeyCollisionError, ModuleAliasError, @@ -26,7 +25,44 @@ def __init__(self, config: RegistryConfig, name: str = ""): self.name = name # These will be populated later - self.cls: Any = None + self._cls: Any = None + + @property + def cls(self): + return self._cls + + @cls.setter + def cls(self, new_cls: type): + if self._cls is not None: + self._rereference(self._cls, new_cls) + self._cls = new_cls + + def _rereference(self, old_cls: type, new_cls: type) -> None: + """Recursively updates all registry references from ``old_cls`` to ``new_cls``.""" + # TODO: this could be optimized by only recursively apply to recursive parents + # And not searching self at first _rereference iteration. + for k, v in self.items(): + if v is old_cls: + self[k] = new_cls + + for parent_registry in self.walk_parent_registries(): + parent_registry._rereference(old_cls, new_cls) + + def walk_parent_registries(self) -> Generator["_Registry", None, None]: + """Iterates over immediate parenting classes and returns their ``_Registry``.""" + for parent_cls in self.cls.__bases__: + if parent_cls is Registry: + # Never register to the base Registry class. + # Unwanted cross-library interactions may occur, otherwise. + continue + + try: + parent_registry = parent_cls.__registry__ + except AttributeError: + # Not a Registry object + continue + + yield parent_registry def register( self, @@ -89,18 +125,7 @@ def register( # 1. This is the root ``__recursive__`` call. # 2. Both this.recursive is True, and parent.recursive is True. if (root or self.config.recursive) and self.cls is not None: - for parent_cls in self.cls.__bases__: - try: - parent_registry = parent_cls.__registry__ - except AttributeError: - # Not a Registry object - continue - - if parent_cls is Registry: - # Never register to the base Registry class. - # Unwanted cross-library interactions may occur, otherwise. - continue - + for parent_registry in self.walk_parent_registries(): if root or parent_registry.config.recursive: parent_registry.register(obj, name=name, aliases=aliases) @@ -203,12 +228,15 @@ def __new__( # that hooks like __init_subclass__ have appropriately set registry attributes. # Each subclass gets its own registry. - # Copy the nearest parent config, then update it with new params. - # Some class construction libraries, like ``attrs``, will recreate a class. - # In these situations, the old-class will have it's attributes (like the __registry__ - # object) passed in via the ``namespace``. + # Copy the nearest parent config if "__registry__" in namespace: - registry_config = namespace["__registry__"].config + # Some class construction libraries, like ``attrs``, will recreate a class. + # In these situations, the old-class will have it's attributes (like the __registry__ + # object) passed in via the ``namespace``. + new_cls = super().__new__(cls, cls_name, bases, namespace) + new_cls.__registry__.cls = new_cls + + return new_cls else: for parent_cls in bases: try: diff --git a/tests/test_attrs.py b/tests/test_attrs.py index 96f7191..73633bb 100644 --- a/tests/test_attrs.py +++ b/tests/test_attrs.py @@ -3,7 +3,7 @@ from autoregistry import Registry -def test_attrs_compatability(): +def test_attrs_root(): @frozen class Media(Registry, snake_case=True): name: str @@ -18,3 +18,34 @@ class MusicVideo(Media): assert list(Media) == ["movie", "music_video"] assert Media["movie"] == Movie assert Media["music_video"] == MusicVideo + + +def test_attrs_children(): + @frozen + class Media(Registry, snake_case=True): + name: str + year: int + + @frozen + class Movie(Media): + director: str + + @frozen + class HorrorMovie(Movie): + antagonist: str + + assert list(Media) == ["movie", "horror_movie"] + assert Media["movie"] is Movie + assert Media["horror_movie"] is HorrorMovie + assert Movie["horror_movie"] is HorrorMovie + + horror_movie = Media["horror_movie"]( + name="Nosferatu", + year=1922, + director="Murnau", + antagonist="Count Orlok", + ) + assert ( + str(horror_movie) + == "HorrorMovie(name='Nosferatu', year=1922, director='Murnau', antagonist='Count Orlok')" + )