Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Attrs Subclassing #35

Merged
merged 2 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 47 additions & 19 deletions autoregistry/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from .exceptions import (
CannotDeriveNameError,
CannotRegisterPythonBuiltInError,
InternalError,
InvalidNameError,
KeyCollisionError,
ModuleAliasError,
Expand All @@ -26,7 +25,44 @@ def __init__(self, config: RegistryConfig, name: str = ""):
self.name = name

# These will be populated later
self.cls: Any = None
self._cls: Any = None

@property
def cls(self):
return self._cls

@cls.setter
def cls(self, new_cls: type):
if self._cls is not None:
self._rereference(self._cls, new_cls)
self._cls = new_cls

def _rereference(self, old_cls: type, new_cls: type) -> None:
"""Recursively updates all registry references from ``old_cls`` to ``new_cls``."""
# TODO: this could be optimized by only recursively apply to recursive parents
# And not searching self at first _rereference iteration.
for k, v in self.items():
if v is old_cls:
self[k] = new_cls

for parent_registry in self.walk_parent_registries():
parent_registry._rereference(old_cls, new_cls)

def walk_parent_registries(self) -> Generator["_Registry", None, None]:
"""Iterates over immediate parenting classes and returns their ``_Registry``."""
for parent_cls in self.cls.__bases__:
if parent_cls is Registry:
# Never register to the base Registry class.
# Unwanted cross-library interactions may occur, otherwise.
continue

try:
parent_registry = parent_cls.__registry__
except AttributeError:
# Not a Registry object
continue

yield parent_registry

def register(
self,
Expand Down Expand Up @@ -89,18 +125,7 @@ def register(
# 1. This is the root ``__recursive__`` call.
# 2. Both this.recursive is True, and parent.recursive is True.
if (root or self.config.recursive) and self.cls is not None:
for parent_cls in self.cls.__bases__:
try:
parent_registry = parent_cls.__registry__
except AttributeError:
# Not a Registry object
continue

if parent_cls is Registry:
# Never register to the base Registry class.
# Unwanted cross-library interactions may occur, otherwise.
continue

for parent_registry in self.walk_parent_registries():
if root or parent_registry.config.recursive:
parent_registry.register(obj, name=name, aliases=aliases)

Expand Down Expand Up @@ -203,12 +228,15 @@ def __new__(
# that hooks like __init_subclass__ have appropriately set registry attributes.
# Each subclass gets its own registry.

# Copy the nearest parent config, then update it with new params.
# Some class construction libraries, like ``attrs``, will recreate a class.
# In these situations, the old-class will have it's attributes (like the __registry__
# object) passed in via the ``namespace``.
# Copy the nearest parent config
if "__registry__" in namespace:
registry_config = namespace["__registry__"].config
# Some class construction libraries, like ``attrs``, will recreate a class.
# In these situations, the old-class will have it's attributes (like the __registry__
# object) passed in via the ``namespace``.
new_cls = super().__new__(cls, cls_name, bases, namespace)
new_cls.__registry__.cls = new_cls

return new_cls
else:
for parent_cls in bases:
try:
Expand Down
33 changes: 32 additions & 1 deletion tests/test_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from autoregistry import Registry


def test_attrs_compatability():
def test_attrs_root():
@frozen
class Media(Registry, snake_case=True):
name: str
Expand All @@ -18,3 +18,34 @@ class MusicVideo(Media):
assert list(Media) == ["movie", "music_video"]
assert Media["movie"] == Movie
assert Media["music_video"] == MusicVideo


def test_attrs_children():
@frozen
class Media(Registry, snake_case=True):
name: str
year: int

@frozen
class Movie(Media):
director: str

@frozen
class HorrorMovie(Movie):
antagonist: str

assert list(Media) == ["movie", "horror_movie"]
assert Media["movie"] is Movie
assert Media["horror_movie"] is HorrorMovie
assert Movie["horror_movie"] is HorrorMovie

horror_movie = Media["horror_movie"](
name="Nosferatu",
year=1922,
director="Murnau",
antagonist="Count Orlok",
)
assert (
str(horror_movie)
== "HorrorMovie(name='Nosferatu', year=1922, director='Murnau', antagonist='Count Orlok')"
)
Loading