Skip to content

Commit

Permalink
Merge pull request #35 from BrianPugh/attrs-subclass
Browse files Browse the repository at this point in the history
Fix Attrs Subclassing
  • Loading branch information
BrianPugh authored Sep 21, 2023
2 parents 2e2303f + a9a7ba9 commit 48d41cd
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 20 deletions.
66 changes: 47 additions & 19 deletions autoregistry/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from .exceptions import (
CannotDeriveNameError,
CannotRegisterPythonBuiltInError,
InternalError,
InvalidNameError,
KeyCollisionError,
ModuleAliasError,
Expand All @@ -26,7 +25,44 @@ def __init__(self, config: RegistryConfig, name: str = ""):
self.name = name

# These will be populated later
self.cls: Any = None
self._cls: Any = None

@property
def cls(self):
return self._cls

@cls.setter
def cls(self, new_cls: type):
if self._cls is not None:
self._rereference(self._cls, new_cls)
self._cls = new_cls

def _rereference(self, old_cls: type, new_cls: type) -> None:
"""Recursively updates all registry references from ``old_cls`` to ``new_cls``."""
# TODO: this could be optimized by only recursively apply to recursive parents
# And not searching self at first _rereference iteration.
for k, v in self.items():
if v is old_cls:
self[k] = new_cls

for parent_registry in self.walk_parent_registries():
parent_registry._rereference(old_cls, new_cls)

def walk_parent_registries(self) -> Generator["_Registry", None, None]:
"""Iterates over immediate parenting classes and returns their ``_Registry``."""
for parent_cls in self.cls.__bases__:
if parent_cls is Registry:
# Never register to the base Registry class.
# Unwanted cross-library interactions may occur, otherwise.
continue

try:
parent_registry = parent_cls.__registry__
except AttributeError:
# Not a Registry object
continue

yield parent_registry

def register(
self,
Expand Down Expand Up @@ -89,18 +125,7 @@ def register(
# 1. This is the root ``__recursive__`` call.
# 2. Both this.recursive is True, and parent.recursive is True.
if (root or self.config.recursive) and self.cls is not None:
for parent_cls in self.cls.__bases__:
try:
parent_registry = parent_cls.__registry__
except AttributeError:
# Not a Registry object
continue

if parent_cls is Registry:
# Never register to the base Registry class.
# Unwanted cross-library interactions may occur, otherwise.
continue

for parent_registry in self.walk_parent_registries():
if root or parent_registry.config.recursive:
parent_registry.register(obj, name=name, aliases=aliases)

Expand Down Expand Up @@ -203,12 +228,15 @@ def __new__(
# that hooks like __init_subclass__ have appropriately set registry attributes.
# Each subclass gets its own registry.

# Copy the nearest parent config, then update it with new params.
# Some class construction libraries, like ``attrs``, will recreate a class.
# In these situations, the old-class will have it's attributes (like the __registry__
# object) passed in via the ``namespace``.
# Copy the nearest parent config
if "__registry__" in namespace:
registry_config = namespace["__registry__"].config
# Some class construction libraries, like ``attrs``, will recreate a class.
# In these situations, the old-class will have it's attributes (like the __registry__
# object) passed in via the ``namespace``.
new_cls = super().__new__(cls, cls_name, bases, namespace)
new_cls.__registry__.cls = new_cls

return new_cls
else:
for parent_cls in bases:
try:
Expand Down
33 changes: 32 additions & 1 deletion tests/test_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from autoregistry import Registry


def test_attrs_compatability():
def test_attrs_root():
@frozen
class Media(Registry, snake_case=True):
name: str
Expand All @@ -18,3 +18,34 @@ class MusicVideo(Media):
assert list(Media) == ["movie", "music_video"]
assert Media["movie"] == Movie
assert Media["music_video"] == MusicVideo


def test_attrs_children():
@frozen
class Media(Registry, snake_case=True):
name: str
year: int

@frozen
class Movie(Media):
director: str

@frozen
class HorrorMovie(Movie):
antagonist: str

assert list(Media) == ["movie", "horror_movie"]
assert Media["movie"] is Movie
assert Media["horror_movie"] is HorrorMovie
assert Movie["horror_movie"] is HorrorMovie

horror_movie = Media["horror_movie"](
name="Nosferatu",
year=1922,
director="Murnau",
antagonist="Count Orlok",
)
assert (
str(horror_movie)
== "HorrorMovie(name='Nosferatu', year=1922, director='Murnau', antagonist='Count Orlok')"
)

0 comments on commit 48d41cd

Please sign in to comment.