From 292b7b948a1f4ce38e735bf5653d27fa62df6749 Mon Sep 17 00:00:00 2001 From: Sergei Maertens Date: Fri, 29 Nov 2024 17:07:18 +0100 Subject: [PATCH] :label: First pass at addressing type checker errors * Decided to ignore the lookups.py because... well, the Django ORM. * and decided to ignore zgw_consumers.legacy which is scheduled for removal anyway, there are better places to sink our energy into * zgw_consumers/api_models/base.py should be rewritten using pydantic OR be based on TypedDict - the code is not type checker friendly --- pyright.pyproject.toml | 16 ++++++++++ zgw_consumers/api_models/_camel_case.py | 1 - zgw_consumers/api_models/base.py | 2 +- zgw_consumers/api_models/besluiten.py | 2 +- zgw_consumers/api_models/catalogi.py | 8 +++-- zgw_consumers/api_models/compat.py | 4 --- zgw_consumers/api_models/documenten.py | 2 +- zgw_consumers/api_models/types.py | 10 ++++--- zgw_consumers/api_models/zaken.py | 6 ++-- zgw_consumers/cache.py | 2 +- zgw_consumers/concurrent.py | 9 +++--- zgw_consumers/drf/serializers.py | 35 +++++++++++++++++----- zgw_consumers/drf/utils.py | 13 +++++--- zgw_consumers/models/abstract.py | 2 +- zgw_consumers/models/certificates.py | 2 +- zgw_consumers/models/fields.py | 25 +++++++++------- zgw_consumers/models/services.py | 8 +++-- zgw_consumers/nlx.py | 16 +++++++--- zgw_consumers/test/component_generation.py | 3 +- zgw_consumers/test/factories.py | 4 +-- zgw_consumers/utils.py | 21 +++++++++---- 21 files changed, 126 insertions(+), 65 deletions(-) create mode 100644 pyright.pyproject.toml delete mode 100644 zgw_consumers/api_models/compat.py diff --git a/pyright.pyproject.toml b/pyright.pyproject.toml new file mode 100644 index 0000000..e546dfa --- /dev/null +++ b/pyright.pyproject.toml @@ -0,0 +1,16 @@ +[tool.pyright] +include = [ + "zgw_consumers/" +] +exclude = [ + # should really be replaced with pydantic or typed dicts... + "zgw_consumers/api_models/base.py", + # this module is quite funky... doesn't hold up to the base types + "zgw_consumers/models/lookups.py", + # this should be removed instead of fixed + "zgw_consumers/legacy/", +] +ignore = [] + +pythonVersion = "3.11" +pythonPlatform = "Linux" diff --git a/zgw_consumers/api_models/_camel_case.py b/zgw_consumers/api_models/_camel_case.py index 8a363ed..59057bc 100644 --- a/zgw_consumers/api_models/_camel_case.py +++ b/zgw_consumers/api_models/_camel_case.py @@ -27,7 +27,6 @@ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. - import re from django.core.files import File diff --git a/zgw_consumers/api_models/base.py b/zgw_consumers/api_models/base.py index cc39e33..65f47a4 100644 --- a/zgw_consumers/api_models/base.py +++ b/zgw_consumers/api_models/base.py @@ -12,9 +12,9 @@ from dateutil.parser import parse from dateutil.relativedelta import relativedelta +from relativedeltafield.utils import parse_relativedelta from ._camel_case import underscoreize -from .compat import parse_relativedelta from .types import JSONObject __all__ = ["CONVERTERS", "factory", "Model", "ZGWModel"] diff --git a/zgw_consumers/api_models/besluiten.py b/zgw_consumers/api_models/besluiten.py index 585e069..57869aa 100644 --- a/zgw_consumers/api_models/besluiten.py +++ b/zgw_consumers/api_models/besluiten.py @@ -26,7 +26,7 @@ class Besluit(ZGWModel): uiterlijke_reactiedatum: Optional[date] = None def get_vervalreden_display(self) -> str: - return VervalRedenen.labels[self.vervalreden] + return VervalRedenen(self.vervalreden).label @dataclass diff --git a/zgw_consumers/api_models/catalogi.py b/zgw_consumers/api_models/catalogi.py index b14b61d..33d7e19 100644 --- a/zgw_consumers/api_models/catalogi.py +++ b/zgw_consumers/api_models/catalogi.py @@ -123,17 +123,21 @@ class Eigenschap(ZGWModel): zaaktype: str naam: str definitie: str - specificatie: dict + specificatie: EigenschapSpecificatie | dict toelichting: str = "" def __post_init__(self): super().__post_init__() - self.specificatie = factory(EigenschapSpecificatie, self.specificatie) + assert isinstance(self.specificatie, dict) + _specificatie = factory(EigenschapSpecificatie, self.specificatie) + assert isinstance(_specificatie, EigenschapSpecificatie) + self.specificatie = _specificatie def to_python(self, value: str) -> Union[str, Decimal, date, datetime]: """ Cast the string value into the appropriate python type based on the spec. """ + assert isinstance(self.specificatie, EigenschapSpecificatie) formaat = self.specificatie.formaat assert formaat in EIGENSCHAP_FORMATEN, f"Unknown format {formaat}" diff --git a/zgw_consumers/api_models/compat.py b/zgw_consumers/api_models/compat.py deleted file mode 100644 index 204c093..0000000 --- a/zgw_consumers/api_models/compat.py +++ /dev/null @@ -1,4 +0,0 @@ -try: - from relativedeltafield.utils import parse_relativedelta -except ImportError: # before 1.1.2 - from relativedeltafield import parse_relativedelta # noqa diff --git a/zgw_consumers/api_models/documenten.py b/zgw_consumers/api_models/documenten.py index e3d55c0..1f1e07a 100644 --- a/zgw_consumers/api_models/documenten.py +++ b/zgw_consumers/api_models/documenten.py @@ -34,4 +34,4 @@ class Document(ZGWModel): locked: bool = False def get_vertrouwelijkheidaanduiding_display(self): - return VertrouwelijkheidsAanduidingen.values[self.vertrouwelijkheidaanduiding] + return VertrouwelijkheidsAanduidingen(self.vertrouwelijkheidaanduiding).label diff --git a/zgw_consumers/api_models/types.py b/zgw_consumers/api_models/types.py index c80fe95..9dc87de 100644 --- a/zgw_consumers/api_models/types.py +++ b/zgw_consumers/api_models/types.py @@ -1,5 +1,7 @@ -from typing import Dict, List, Union +from __future__ import annotations -JSONPrimitive = Union[str, int, None, float] -JSONValue = Union[JSONPrimitive, "JSONObject", List["JSONValue"]] -JSONObject = Dict[str, JSONValue] +from typing import TypeAlias + +JSONPrimitive: TypeAlias = str | int | None | float +JSONValue: TypeAlias = "JSONPrimitive | JSONObject | list[JSONValue]" +JSONObject: TypeAlias = dict[str, JSONValue] diff --git a/zgw_consumers/api_models/zaken.py b/zgw_consumers/api_models/zaken.py index 126d570..4c7a4cf 100644 --- a/zgw_consumers/api_models/zaken.py +++ b/zgw_consumers/api_models/zaken.py @@ -30,7 +30,7 @@ class Zaak(ZGWModel): zaakgeometrie: dict = field(default_factory=dict) def get_vertrouwelijkheidaanduiding_display(self): - return VertrouwelijkheidsAanduidingen.values[self.vertrouwelijkheidaanduiding] + return VertrouwelijkheidsAanduidingen(self.vertrouwelijkheidaanduiding).label @dataclass @@ -92,10 +92,10 @@ class Rol(ZGWModel): betrokkene_identificatie: dict = field(default_factory=dict) def get_betrokkene_type_display(self): - return RolTypes.values[self.betrokkene_type] + return RolTypes(self.betrokkene_type).label def get_omschrijving_generiek_display(self): - return RolOmschrijving.values[self.omschrijving_generiek] + return RolOmschrijving(self.omschrijving_generiek).label @dataclass diff --git a/zgw_consumers/cache.py b/zgw_consumers/cache.py index f282269..0b744e9 100644 --- a/zgw_consumers/cache.py +++ b/zgw_consumers/cache.py @@ -56,4 +56,4 @@ def install_schema_fetcher_cache(): except ImportError: return - schema_fetcher.cache = OASCache() + schema_fetcher.cache = OASCache() # type: ignore - untyped library... diff --git a/zgw_consumers/concurrent.py b/zgw_consumers/concurrent.py index 261ea75..f28696f 100644 --- a/zgw_consumers/concurrent.py +++ b/zgw_consumers/concurrent.py @@ -47,12 +47,13 @@ class parallel: def __init__(self, **kwargs): self.executor = futures.ThreadPoolExecutor(**kwargs) - def submit(*args, **kwargs): - if len(args) >= 2: - self, _fn, *args = args + def submit(self, *args, **kwargs): + if len(args) >= 1: + _fn, *args = args elif "fn" in kwargs: _fn = kwargs.pop("fn") - self, *args = args + else: + raise TypeError("Invalid signature") fn = wrap_fn(_fn) diff --git a/zgw_consumers/drf/serializers.py b/zgw_consumers/drf/serializers.py index 52a0f5f..d8594e8 100644 --- a/zgw_consumers/drf/serializers.py +++ b/zgw_consumers/drf/serializers.py @@ -45,14 +45,18 @@ def get_fields(self): serializer_class=self.__class__.__name__ ) assert hasattr( - self.Meta, "model" + self.Meta, "model" # pyright: ignore[reportAttributeAccessIssue] ), 'Class {serializer_class} missing "Meta.model" attribute'.format( serializer_class=self.__class__.__name__ ) declared_fields = copy.deepcopy(self._declared_fields) - model = self.Meta.model - depth = getattr(self.Meta, "depth", 0) + model = self.Meta.model # pyright: ignore[reportAttributeAccessIssue] + depth = getattr( + self.Meta, # pyright: ignore[reportAttributeAccessIssue] + "depth", + 0, + ) if depth is not None: assert depth >= 0, "'depth' may not be negative." @@ -88,7 +92,7 @@ def get_fields(self): return fields def get_field_names(self, declared_fields): - fields = self.Meta.fields + fields = self.Meta.fields # pyright: ignore[reportAttributeAccessIssue] # Ensure that all declared fields have also been included in the # `Meta.fields` option. @@ -114,9 +118,19 @@ def get_extra_kwargs(self): Return a dictionary mapping field names to a dictionary of additional keyword arguments. """ - extra_kwargs = copy.deepcopy(getattr(self.Meta, "extra_kwargs", {})) + extra_kwargs = copy.deepcopy( + getattr( + self.Meta, # pyright: ignore[reportAttributeAccessIssue] + "extra_kwargs", + {}, + ) + ) - read_only_fields = getattr(self.Meta, "read_only_fields", None) + read_only_fields = getattr( + self.Meta, # pyright: ignore[reportAttributeAccessIssue] + "read_only_fields", + None, + ) if read_only_fields is not None: if not isinstance(read_only_fields, (list, tuple)): raise TypeError( @@ -131,7 +145,10 @@ def get_extra_kwargs(self): else: # Guard against the possible misspelling `readonly_fields` (used # by the Django admin and others). - assert not hasattr(self.Meta, "readonly_fields"), ( + assert not hasattr( + self.Meta, # pyright: ignore[reportAttributeAccessIssue] + "readonly_fields", + ), ( "Serializer `%s.%s` has field `readonly_fields`; " "the correct spelling for the option is `read_only_fields`." % (self.__class__.__module__, self.__class__.__name__) @@ -186,7 +203,9 @@ def build_standard_field(self, field_name, model_field_type): if "choices" in field_kwargs: # Fields with choices get coerced into `ChoiceField` # instead of using their regular typed field. - field_class = self.serializer_choice_field + # fmt: off + field_class = self.serializer_choice_field # pyright: ignore[reportAttributeAccessIssue] + # fmt: on # Some model fields may introduce kwargs that would not be valid # for the choice field. We need to strip these out. # Eg. models.DecimalField(max_digits=3, decimal_places=1, choices=DECIMAL_CHOICES) diff --git a/zgw_consumers/drf/utils.py b/zgw_consumers/drf/utils.py index e24bb8d..1534e2a 100644 --- a/zgw_consumers/drf/utils.py +++ b/zgw_consumers/drf/utils.py @@ -21,12 +21,17 @@ def extract_model_field_type(model_class, field_name): # support for Optional / List if hasattr(typehint, "__origin__"): - if typehint.__origin__ is list and typehint.__args__: - subtypehint = typehint.__args__[0] + if ( + typehint.__origin__ is list # pyright: ignore[reportAttributeAccessIssue] + and typehint.__args__ # pyright: ignore[reportAttributeAccessIssue] + ): + # fmt: off + subtypehint = typehint.__args__[0] # pyright: ignore[reportAttributeAccessIssue] + # fmt: on raise NotImplementedError("TODO: support collections") - if typehint.__origin__ is Union: - typehint = typehint.__args__ + if typehint.__origin__ is Union: # pyright: ignore[reportAttributeAccessIssue] + typehint = typehint.__args__ # pyright: ignore[reportAttributeAccessIssue] # Optional is ONE type combined with None typehint = next(t for t in typehint if t is not None) return typehint diff --git a/zgw_consumers/models/abstract.py b/zgw_consumers/models/abstract.py index 454cc3f..42c34a4 100644 --- a/zgw_consumers/models/abstract.py +++ b/zgw_consumers/models/abstract.py @@ -25,7 +25,7 @@ class RestAPIService(Service): validators=[FileExtensionValidator(["yml", "yaml"])], ) - class Meta: + class Meta: # pyright: ignore[reportIncompatibleVariableOverride] abstract = True def clean(self): diff --git a/zgw_consumers/models/certificates.py b/zgw_consumers/models/certificates.py index b0e1990..0932c71 100644 --- a/zgw_consumers/models/certificates.py +++ b/zgw_consumers/models/certificates.py @@ -2,5 +2,5 @@ class Certificate(NewCertificate): - class Meta: + class Meta: # pyright: ignore[reportIncompatibleVariableOverride] proxy = True diff --git a/zgw_consumers/models/fields.py b/zgw_consumers/models/fields.py index d5bb7da..4086bda 100644 --- a/zgw_consumers/models/fields.py +++ b/zgw_consumers/models/fields.py @@ -1,4 +1,3 @@ -from typing import Optional from urllib.parse import urljoin from django.core import checks @@ -12,14 +11,14 @@ def __init__(self, field): self.field = field def get_base_url(self, base_val) -> str: - return getattr(base_val, "api_root", None) + return getattr(base_val, "api_root", "") def get_base_val(self, detail_url: str): from zgw_consumers.models import Service return Service.get_service(detail_url) - def __get__(self, instance: Model, cls=None) -> Optional[str]: + def __get__(self, instance: Model | None, cls=None) -> str | None: if instance is None: return None @@ -30,7 +29,7 @@ def __get__(self, instance: Model, cls=None) -> Optional[str]: # todo cache value return urljoin(base_url, relative_val) - def __set__(self, instance: Model, value: Optional[str]): + def __set__(self, instance: Model, value: str | None): if value is None and not self.field.null: raise ValueError( "A 'None'-value is not allowed. Make the field " @@ -64,9 +63,9 @@ class ServiceUrlField(Field): """ # field flags - name = None + name: str concrete = False - column = None + column: str | None = None # pyright: ignore[reportIncompatibleVariableOverride] db_column = None descriptor_class = ServiceUrlDescriptor @@ -132,10 +131,10 @@ def _add_check_constraint( return @property - def attname(self) -> str: + def attname(self) -> str: # pyright: ignore[reportIncompatibleVariableOverride] return self.name - def get_attname_column(self): + def get_attname_column(self): # pyright: ignore[reportIncompatibleMethodOverride] return self.attname, None def deconstruct(self): @@ -150,13 +149,17 @@ def deconstruct(self): @property def _base_field(self) -> ForeignKey: - return self.model._meta.get_field(self.base_field) + field = self.model._meta.get_field(self.base_field) + assert isinstance(field, ForeignKey) + return field @property def _relative_field(self) -> CharField: - return self.model._meta.get_field(self.relative_field) + field = self.model._meta.get_field(self.relative_field) + assert isinstance(field, CharField) + return field - def check(self, **kwargs): + def check(self, **kwargs) -> list[checks.CheckMessage]: return [ *self._check_field_name(), *self._check_base_field(), diff --git a/zgw_consumers/models/services.py b/zgw_consumers/models/services.py index 1f2f130..581ef80 100644 --- a/zgw_consumers/models/services.py +++ b/zgw_consumers/models/services.py @@ -3,7 +3,7 @@ import logging import socket import uuid -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable from urllib.parse import urlparse, urlsplit, urlunsplit from django.core.exceptions import ValidationError @@ -116,7 +116,9 @@ class Service(RestAPIService): objects = ServiceManager() - class Meta: + get_api_type_display: Callable[[], str] + + class Meta: # pyright: ignore[reportIncompatibleVariableOverride] verbose_name = _("service") verbose_name_plural = _("services") @@ -290,7 +292,7 @@ class NLXConfig(SingletonModel): blank=True, ) - class Meta: + class Meta: # pyright: ignore[reportIncompatibleVariableOverride] verbose_name = _("NLX configuration") @property diff --git a/zgw_consumers/nlx.py b/zgw_consumers/nlx.py index cb224a0..10143cb 100644 --- a/zgw_consumers/nlx.py +++ b/zgw_consumers/nlx.py @@ -2,6 +2,7 @@ import logging from collections.abc import Iterable from itertools import groupby +from typing import TypedDict import requests from ape_pie import APIClient @@ -26,9 +27,12 @@ def _rewrite_url(value: str, rewrites: Iterable[tuple[str, str]]) -> str | None: class Rewriter: def __init__(self): - self.rewrites: list[tuple[str, str]] = Service.objects.exclude( - nlx="" - ).values_list("api_root", "nlx") + qs = Service.objects.exclude(nlx="").values_list("api_root", "nlx") + self._rewrites = qs + + @property + def rewrites(self) -> list[tuple[str, str]]: + return list(self._rewrites) @property def reverse_rewrites(self) -> list[tuple[str, str]]: @@ -159,7 +163,11 @@ class NLXClient(NLXMixin, APIClient): Organization = dict[str, str] -ServiceType = dict[str, str] + + +class ServiceType(TypedDict): + name: str + organization: Organization def get_nlx_services() -> list[tuple[Organization, list[ServiceType]]]: diff --git a/zgw_consumers/test/component_generation.py b/zgw_consumers/test/component_generation.py index 244a58d..084ee1c 100644 --- a/zgw_consumers/test/component_generation.py +++ b/zgw_consumers/test/component_generation.py @@ -1,10 +1,9 @@ import logging import random +from datetime import timezone from functools import lru_cache from typing import Any, Dict -from django.utils import timezone - import yaml from faker import Faker from typing_extensions import deprecated diff --git a/zgw_consumers/test/factories.py b/zgw_consumers/test/factories.py index 20d976d..bc94d9e 100644 --- a/zgw_consumers/test/factories.py +++ b/zgw_consumers/test/factories.py @@ -1,5 +1,3 @@ -import uuid - from django.utils.text import slugify import factory @@ -27,7 +25,7 @@ class ServiceFactory(factory.django.DjangoModelFactory): api_root = factory.Faker("api_root") slug = factory.LazyAttribute(lambda o: slugify(o.api_root)) - class Meta: + class Meta: # pyright: ignore[reportIncompatibleVariableOverride] model = Service django_get_or_create = ("api_root",) diff --git a/zgw_consumers/utils.py b/zgw_consumers/utils.py index a4f69eb..d693e94 100644 --- a/zgw_consumers/utils.py +++ b/zgw_consumers/utils.py @@ -1,16 +1,25 @@ import logging -from typing import Optional, TypedDict +from typing import Callable, Generic, Optional, TypedDict, TypeVar from django.http import HttpRequest from ape_pie.client import APIClient logger = logging.getLogger(__name__) -NOTSET = object() -class cache_on_request: - def __init__(self, request: HttpRequest, key: str, callback: callable): +class NotSet: + pass + + +NOTSET = NotSet() + + +T = TypeVar("T") + + +class cache_on_request(Generic[T]): + def __init__(self, request: HttpRequest, key: str, callback: Callable[[], T]): self.request = request self.key = key self.callback = callback @@ -22,10 +31,10 @@ def __exit__(self, *args, **kwargs): pass @property - def value(self): + def value(self) -> T: # check if it's cached on the request cached_value = getattr(self.request, self.key, NOTSET) - if cached_value is NOTSET: + if isinstance(cached_value, NotSet): value = self.callback() setattr(self.request, self.key, value) cached_value = value