diff --git a/fallback_property/__init__.py b/fallback_property/__init__.py index dce9a61..6e4b218 100644 --- a/fallback_property/__init__.py +++ b/fallback_property/__init__.py @@ -1,15 +1,31 @@ +import functools import logging -from typing import Type, TypeVar, Generic, Callable +from typing import Type, TypeVar, Generic, Callable, Optional logger = logging.getLogger(__name__) Class = TypeVar("Class") Value = TypeVar("Value") -Method = Callable[[Class], Value] +FuncType = Callable[[Class], Value] +Method = TypeVar('Method', bound=FuncType) + +CUSTOM_WRAPPER_ASSIGNMENTS = ( + 'admin_order_value', + 'allow_tags', + 'boolean', + 'empty_value_display', + 'short_description', +) +# TODO mypy sees `WRAPPER_ASSIGNMENT` as `Sequence[str]`, even if its actually defined as +# `Tuple[str, ...]`. mypy raises an error, since combining a `Sequence` and a `Typle` +# using `+` is invalid, +WRAPPER_ASSIGNMENTS = CUSTOM_WRAPPER_ASSIGNMENTS + functools.WRAPPER_ASSIGNMENTS # type: ignore # NOQA class FallbackDescriptor(Generic[Class, Value]): - def __init__(self, func: Method, cached: bool = True, logging: bool = False) -> None: + def __init__( + self, func: Optional[Method] = None, cached: bool = True, logging: bool = False, + ) -> None: """ Initialize the descriptor. @@ -21,13 +37,58 @@ def __init__(self, func: Method, cached: bool = True, logging: bool = False) -> Cache the value calculated by `func`. logging Log a warning if fallback function is used. + + + `func` is not `None`, when the descriptor is used as a "function", eg. + + def _bar(...) -> ...: + ... + bar = fallback_property(_bar) """ - self.__doc__ = getattr(func, "__doc__") # keep the docs - self.func = func self.cached = cached self.logging = logging + + if func is not None: + self.__call__(func) + + def __call__(self, func: Method) -> 'fallback_property': + """ + Apply decorator to specific method. + + Arguments + --------- + func + Fallback function if no value exists. + + + This method is either called from the constructor, when descriptor is used like + + def _bar(...) -> ...: + ... + bar = fallback_property(_bar) + + or directly after the descriptor has been created and the function will be wrapped + + # case 1 + @fallback_property + def foo(self) -> ...: + ... + + # case 2 + @fallback_property(...) + def foo(self) -> ...: + ... + """ + # copy attribute from method to descriptor + # TODO mypy expects a `Callable` as first argument, even though it is not required + functools.update_wrapper(self, func, assigned=WRAPPER_ASSIGNMENTS) # type: ignore + + # bind descriptor to method + self.func = func self.prop_name = f"__{self.func.__name__}" + return self + def __get__(self, obj: Class, cls: Type[Class]) -> Value: """ Get the value. @@ -35,6 +96,9 @@ def __get__(self, obj: Class, cls: Type[Class]) -> Value: Return either the cached value or call the underlying function and optionally cache its result. """ + # https://stackoverflow.com/a/21629855/7774036 + if obj is None: + return self if not hasattr(obj, self.prop_name): if self.logging: logger.warning("Using `%s` without prefetched value.", self.func) @@ -61,21 +125,4 @@ def __delete__(self, obj: Class) -> None: delattr(obj, self.prop_name) -def fallback_property( - cached: bool = True, logging: bool = False -) -> Callable[[Method], FallbackDescriptor]: - """ - Decorate a class method to return a precalculated value instead. - - This might be useful if you have a function that aggregates values from - related objects, which could already be fetched using an annotated queryset. - The decorated methods will favor the precalculated value over calling the - actual method. - - NOTE: The annotated value must have the same name as the decorated function! - """ - - def inner(func: Method) -> FallbackDescriptor: - return FallbackDescriptor(func, cached=cached, logging=logging) - - return inner +fallback_property = FallbackDescriptor diff --git a/setup.cfg b/setup.cfg index 0949546..0a0b37d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -44,3 +44,7 @@ norecursedirs = build dist testpaths = fallback_property tests + + +[mypy] +ignore_missing_imports = True diff --git a/tests/test_decorator.py b/tests/test_decorator.py index b6aec2d..37558ed 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -103,3 +103,33 @@ def test_fallback_property__logging(caplog): product.total_with_logging assert 'without prefetched value.' in caplog.text assert 'Product.total_with_logging' in caplog.text + + +def test_use_like_property(): + """ + Use as a function should be possible. + """ + class Foo: + @fallback_property + def bar(self): + """ + Test. + """ + return 1 + + assert Foo().bar == 1 + + +def test_use_as_function(): + """ + Use as a function should be possible. + """ + class Foo: + def _bar(self): + """ + Test. + """ + return 1 + bar = fallback_property(_bar, logging=False) + + assert Foo().bar == 1 diff --git a/tests/test_django.py b/tests/test_django.py index e2997c5..afd3f96 100644 --- a/tests/test_django.py +++ b/tests/test_django.py @@ -1,4 +1,7 @@ import pytest +from django.contrib.admin.utils import label_for_field + +from fallback_property import FallbackDescriptor, fallback_property from . import models @@ -24,3 +27,41 @@ def test_fallback_property(pipeline, django_assert_num_queries): pipeline = models.Pipeline.objects.with_total_length().get(pk=pipeline.pk) with django_assert_num_queries(0): assert pipeline.total_length == TOTAL_LENGTH + + +def test_admin_special_properties(): + """ + Copy special attribute to the decorator. + + The django `ModelAdmin` uses special attributes to alter the behaviour of a + property/method displayed in the admin. + """ + from django.db import models as django_models + + BOOLEAN = True + EMPTY = 'empty' + LABEL = "LABEL" + ORDER_VALUE = 'foo_bar' + + class Foo(django_models.Model): + def _bar(self): + """ + Test. + """ + return 1 + _bar.admin_order_value = ORDER_VALUE + _bar.boolean = BOOLEAN + _bar.empty_value_display = EMPTY + _bar.short_description = LABEL + bar = fallback_property(_bar, logging=False) + + descriptor = getattr(Foo, 'bar') + assert isinstance(descriptor, FallbackDescriptor) + + assert descriptor.admin_order_value == ORDER_VALUE + assert descriptor.boolean == BOOLEAN + assert descriptor.empty_value_display == EMPTY + assert descriptor.short_description == LABEL + + # Django should be able to extract the label + assert label_for_field('bar', Foo) == LABEL diff --git a/tests/urls.py b/tests/urls.py index 637600f..7a46d94 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -1 +1,3 @@ -urlpatterns = [] +from typing import Any, List + +urlpatterns: List[Any] = []