From a9a7ba9a5c1703036a83498c86f9e0f3f6e6c0f5 Mon Sep 17 00:00:00 2001 From: Brian Pugh Date: Thu, 21 Sep 2023 10:57:43 -0700 Subject: [PATCH] Fix subclassing attrs of attrs --- autoregistry/_registry.py | 66 ++++++++++++++++++++++++++++----------- tests/test_attrs.py | 14 +++++++++ 2 files changed, 61 insertions(+), 19 deletions(-) diff --git a/autoregistry/_registry.py b/autoregistry/_registry.py index c2515cd..0b1cc69 100644 --- a/autoregistry/_registry.py +++ b/autoregistry/_registry.py @@ -10,7 +10,6 @@ from .exceptions import ( CannotDeriveNameError, CannotRegisterPythonBuiltInError, - InternalError, InvalidNameError, KeyCollisionError, ModuleAliasError, @@ -26,7 +25,44 @@ def __init__(self, config: RegistryConfig, name: str = ""): self.name = name # These will be populated later - self.cls: Any = None + self._cls: Any = None + + @property + def cls(self): + return self._cls + + @cls.setter + def cls(self, new_cls: type): + if self._cls is not None: + self._rereference(self._cls, new_cls) + self._cls = new_cls + + def _rereference(self, old_cls: type, new_cls: type) -> None: + """Recursively updates all registry references from ``old_cls`` to ``new_cls``.""" + # TODO: this could be optimized by only recursively apply to recursive parents + # And not searching self at first _rereference iteration. + for k, v in self.items(): + if v is old_cls: + self[k] = new_cls + + for parent_registry in self.walk_parent_registries(): + parent_registry._rereference(old_cls, new_cls) + + def walk_parent_registries(self) -> Generator["_Registry", None, None]: + """Iterates over immediate parenting classes and returns their ``_Registry``.""" + for parent_cls in self.cls.__bases__: + if parent_cls is Registry: + # Never register to the base Registry class. + # Unwanted cross-library interactions may occur, otherwise. + continue + + try: + parent_registry = parent_cls.__registry__ + except AttributeError: + # Not a Registry object + continue + + yield parent_registry def register( self, @@ -89,18 +125,7 @@ def register( # 1. This is the root ``__recursive__`` call. # 2. Both this.recursive is True, and parent.recursive is True. if (root or self.config.recursive) and self.cls is not None: - for parent_cls in self.cls.__bases__: - try: - parent_registry = parent_cls.__registry__ - except AttributeError: - # Not a Registry object - continue - - if parent_cls is Registry: - # Never register to the base Registry class. - # Unwanted cross-library interactions may occur, otherwise. - continue - + for parent_registry in self.walk_parent_registries(): if root or parent_registry.config.recursive: parent_registry.register(obj, name=name, aliases=aliases) @@ -203,12 +228,15 @@ def __new__( # that hooks like __init_subclass__ have appropriately set registry attributes. # Each subclass gets its own registry. - # Copy the nearest parent config, then update it with new params. - # Some class construction libraries, like ``attrs``, will recreate a class. - # In these situations, the old-class will have it's attributes (like the __registry__ - # object) passed in via the ``namespace``. + # Copy the nearest parent config if "__registry__" in namespace: - registry_config = namespace["__registry__"].config + # Some class construction libraries, like ``attrs``, will recreate a class. + # In these situations, the old-class will have it's attributes (like the __registry__ + # object) passed in via the ``namespace``. + new_cls = super().__new__(cls, cls_name, bases, namespace) + new_cls.__registry__.cls = new_cls + + return new_cls else: for parent_cls in bases: try: diff --git a/tests/test_attrs.py b/tests/test_attrs.py index 592b885..73633bb 100644 --- a/tests/test_attrs.py +++ b/tests/test_attrs.py @@ -35,3 +35,17 @@ class HorrorMovie(Movie): antagonist: str assert list(Media) == ["movie", "horror_movie"] + assert Media["movie"] is Movie + assert Media["horror_movie"] is HorrorMovie + assert Movie["horror_movie"] is HorrorMovie + + horror_movie = Media["horror_movie"]( + name="Nosferatu", + year=1922, + director="Murnau", + antagonist="Count Orlok", + ) + assert ( + str(horror_movie) + == "HorrorMovie(name='Nosferatu', year=1922, director='Murnau', antagonist='Count Orlok')" + )