diff --git a/setup.cfg b/setup.cfg index 878508c5..b3ffc770 100644 --- a/setup.cfg +++ b/setup.cfg @@ -36,6 +36,7 @@ install_requires = django-filter>=2.0 django-solo djangorestframework>=3.11.0 + djangorestframework-gis>=1.0 djangorestframework_camel_case>=1.2.0 django-rest-framework-condition drf-extra-fields>=3.7.0 diff --git a/tests/test_cache_headers.py b/tests/test_cache_headers.py deleted file mode 100644 index bf49e46a..00000000 --- a/tests/test_cache_headers.py +++ /dev/null @@ -1,295 +0,0 @@ -from unittest.mock import patch - -from django.db import transaction - -import pytest -from drf_spectacular.generators import SchemaGenerator -from rest_framework import status, viewsets -from rest_framework.reverse import reverse -from rest_framework.test import APIRequestFactory -from rest_framework.views import APIView - -from testapp.factories import GroupFactory, HobbyFactory, PersonFactory -from testapp.models import Hobby, Person -from testapp.serializers import HobbySerializer -from testapp.viewsets import PersonViewSet -from vng_api_common.caching.decorators import conditional_retrieve -from vng_api_common.inspectors.cache import get_cache_headers - -pytestmark = pytest.mark.django_db(transaction=True) - - -@pytest.mark.django_db(transaction=False) -def test_etag_header_present(api_client, person): - path = reverse("person-detail", kwargs={"pk": person.pk}) - - response = api_client.get(path) - - person.refresh_from_db() - assert response.status_code == status.HTTP_200_OK - assert "ETag" in response - assert response["ETag"] == f'"{person._etag}"' - - -def skip_test_304_on_cached_resource(api_client, person): - person.calculate_etag_value() - path = reverse("person-detail", kwargs={"pk": person.pk}) - - response = api_client.get(path, HTTP_IF_NONE_MATCH=f'"{person._etag}"') - - assert response.status_code == status.HTTP_304_NOT_MODIFIED - assert "Etag" in response - - -def skip_test_200_on_stale_resource(api_client, person): - path = reverse("person-detail", kwargs={"pk": person.pk}) - - response = api_client.get(path, HTTP_IF_NONE_MATCH='"stale"') - - assert response.status_code == status.HTTP_200_OK - - -def skip_test_cache_headers_detected(): - request = APIRequestFactory().get("/api/persons/1") - request = APIView().initialize_request(request) - callback = PersonViewSet.as_view({"get": "retrieve"}, detail=True) - generator = SchemaGenerator() - - view = generator.create_view(callback, "GET", request=request) - - headers = get_cache_headers(view) - - assert any((True for header in headers if header.name == "ETag")) - - -@pytest.mark.django_db(transaction=False) -def skip_test_related_resource_changes_recalculate_etag1( - django_capture_on_commit_callbacks, -): - # Assert that resources references in the serializer trigger ETag recalculates, while - # resources not referenced don't. - hobbies = [ - HobbyFactory.create(name="playing guitar"), - HobbyFactory.create(name="playing synths"), - ] - person = PersonFactory.create( - name="Jon Carpenter", - address_street="Synthwave", - address_number="101", - # included in serializer, however this is through a serializer method field and thus - # explicitly declared. See the next test case for implicit relation following from - # the serializer fields. - group=GroupFactory.create(name="Brut"), - ) - person.hobbies.set(hobbies) # not included in serializer - person.calculate_etag_value() - - # discard any scheduled callback handlers from test set up - transaction.get_connection().run_on_commit = [] - - assert person._etag, "No ETag value calculated" - initial_etag_value = person._etag - - # start test 1 - changing the hobbies should not result in changed etags - with django_capture_on_commit_callbacks(execute=True): - person.hobbies.clear() - - person.refresh_from_db() - - assert person._etag == initial_etag_value - - # start test 2 - changing the group does affect the serializer output and thus the etag value - # discard any scheduled callback handlers from test set up - transaction.get_connection().run_on_commit = [] - with django_capture_on_commit_callbacks(execute=True): - person.group.name = "DWTD" - person.group.save() - - person.refresh_from_db() - - assert person._etag, "ETag should have been set" - assert person._etag != initial_etag_value, "ETag value should have been changed" - - -@pytest.mark.django_db(transaction=False) -def skip_test_related_resource_changes_recalculate_etag2( - django_capture_on_commit_callbacks, -): - # has a simple (reverse) m2m to Person - person = PersonFactory.create() - hobby = HobbyFactory.create() - hobby.calculate_etag_value() - - assert hobby._etag, "No ETag value calculated" - initial_etag_value = hobby._etag - - # now, change the related people resource to the hobby, which should trigger a - # re-calculate - # discard any scheduled callback handlers from test set up - transaction.get_connection().run_on_commit = [] - with django_capture_on_commit_callbacks(execute=True): - person.hobbies.add(hobby) - - hobby.refresh_from_db() - assert hobby._etag, "ETag should have been set" - assert hobby._etag != initial_etag_value, "ETag value should have been changed" - - -def test_etag_changes_m2m_changes_forward(api_client, hobby, person): - # ensure etags are calculted - person_path = reverse("person-detail", kwargs={"pk": person.pk}) - hobby_path = reverse("hobby-detail", kwargs={"pk": hobby.pk}) - person_response = api_client.get(person_path) - hobby_response = api_client.get(hobby_path) - person.refresh_from_db() - hobby.refresh_from_db() - - # change the m2m, in the forward direction - person.hobbies.add(hobby) - - # compare the new ETags - person_response2 = api_client.get(person_path) - hobby_response2 = api_client.get(hobby_path) - assert person_response["ETag"] - assert person_response["ETag"] != '""' - assert person_response["ETag"] == person_response2["ETag"] - - assert hobby_response["ETag"] - assert hobby_response["ETag"] != '""' - assert hobby_response["ETag"] != hobby_response2["ETag"] - - -def skip_test_etag_changes_m2m_changes_reverse(api_client, hobby, person): - path = reverse("hobby-detail", kwargs={"pk": hobby.pk}) - response = api_client.get(path) - hobby.refresh_from_db() - assert "ETag" in response - etag = response["ETag"] - - # change the m2m - reverse direction - hobby.people.add(person) - - response2 = api_client.get(path) - assert "ETag" in response2 - assert response2["ETag"] - assert response2["ETag"] != '""' - assert response2["ETag"] != etag - - -def skip_test_remove_m2m(api_client, person, hobby): - hobby_path = reverse("hobby-detail", kwargs={"pk": hobby.pk}) - person.hobbies.add(hobby) - - etag = api_client.get(hobby_path)["ETag"] - hobby.refresh_from_db() - assert etag - assert etag != '""' - - # this changes the output of the hobby resource - person.hobbies.remove(hobby) - - new_etag = api_client.get(hobby_path)["ETag"] - assert new_etag - assert new_etag != '""' - assert new_etag != etag - - -def skip_test_remove_m2m_reverse(api_client, person, hobby): - hobby_path = reverse("hobby-detail", kwargs={"pk": hobby.pk}) - person.hobbies.add(hobby) - - etag = api_client.get(hobby_path)["ETag"] - hobby.refresh_from_db() - assert etag - assert etag != '""' - - # this changes the output of the hobby resource - hobby.people.remove(person) - - new_etag = api_client.get(hobby_path)["ETag"] - assert new_etag - assert new_etag != '""' - assert new_etag != etag - - -def skip_test_related_object_changes_etag(api_client, person, group): - path = reverse("person-detail", kwargs={"pk": person.pk}) - - # set up group object for person - person.group = group - person.save() - - etag1 = api_client.get(path)["ETag"] - person.refresh_from_db() - assert etag1 - assert etag1 != '""' - - # change the group name, should change the ETag - group.name = "bar" - group.save() - - etag2 = api_client.get(path)["ETag"] - - assert etag2 - assert etag2 != '""' - assert etag2 != etag1 - - -def skip_test_etag_clearing_without_raw_key_in_kwargs(person): - person.delete() - - -def skip_test_delete_resource_after_get(api_client, person): - path = reverse("person-detail", kwargs={"pk": person.pk}) - - api_client.get(path) - - person.refresh_from_db() - person.delete() - - -def skip_test_fetching_cache_enabled_deleted_resource_404s(api_client, person): - path = reverse("person-detail", kwargs={"pk": person.pk}) - person.delete() - - response = api_client.get(path) - - assert response.status_code == 404 - - -@pytest.mark.django_db(transaction=False) -def skip_test_etag_updates_deduped(django_capture_on_commit_callbacks): - with patch( - "testapp.models.Person.calculate_etag_value" - ) as mock_calculate_etag_value: - with django_capture_on_commit_callbacks(execute=True): - # one post_save - person = PersonFactory.create() - # second post_save - person.save() - - assert mock_calculate_etag_value.call_count == 1 - - -class DynamicSerializerViewSet(viewsets.ReadOnlyModelViewSet): - queryset = Hobby.objects.all() - - def get_serializer(self, *args, **kwargs): - return HobbySerializer() - - -def skip_test_dynamic_serializer(): - REPLACEMENT_REGISTRY = {} - with patch( - "vng_api_common.caching.registry.DEPENDENCY_REGISTRY", new=REPLACEMENT_REGISTRY - ): - conditional_retrieve()(DynamicSerializerViewSet) - - assert Person in REPLACEMENT_REGISTRY - - -def skip_test_etag_object_cascading_delete(): - group = GroupFactory.create() - PersonFactory.create(group=group) - - group.delete() diff --git a/tests/test_field_extensions.py b/tests/test_field_extensions.py index e0b6e194..44bdc11c 100644 --- a/tests/test_field_extensions.py +++ b/tests/test_field_extensions.py @@ -29,7 +29,7 @@ class Meta: class LengthHyperLinkedSerializer(serializers.ModelSerializer): poly = LengthHyperlinkedRelatedField( - view_name="poly-detail", + view_name="field_extention_poly-detail", lookup_field="uuid", queryset=Poly.objects, min_length=20, @@ -68,6 +68,8 @@ class HyperlinkedIdentityViewSet(viewsets.ModelViewSet): serializer_class = HyperlinkedIdentityFieldSerializer +app_name = "field_extensions" + router = routers.DefaultRouter(trailing_slash=False) router.register("base64", Base64ViewSet) router.register("length", LengthHyperLinkedViewSet) @@ -112,7 +114,7 @@ def test_read_only(): assert "name" not in path["MediaFileModel"]["properties"] - assert "name" not in path["PatchedMediaFileModel"]["properties"]["name"] + assert "name" not in path["PatchedMediaFileModel"]["properties"] def test_hyper_link_related_field(): diff --git a/tests/test_filter_extension.py b/tests/test_filter_extension.py index 253dde54..d71b7c14 100644 --- a/tests/test_filter_extension.py +++ b/tests/test_filter_extension.py @@ -34,6 +34,8 @@ class FkModelViewSet(viewsets.ModelViewSet): ] +app_name = "filter_extensions" + router = routers.DefaultRouter(trailing_slash=False) router.register("camilize", FkModelViewSet) diff --git a/tests/test_schema.py b/tests/test_schema.py index 03f92e11..2dcf8500 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -5,6 +5,7 @@ from test_field_extensions import Base64ViewSet +app_name = "schema" router = routers.DefaultRouter(trailing_slash=False) router.register("base64", Base64ViewSet) diff --git a/tests/test_serializer_extensions.py b/tests/test_serializer_extensions.py index 9d9e2085..0dc32541 100644 --- a/tests/test_serializer_extensions.py +++ b/tests/test_serializer_extensions.py @@ -62,6 +62,8 @@ class PolyView(viewsets.ModelViewSet): serializer_class = PolySerializer +app_name = "serializer_extensions" + router = routers.DefaultRouter(trailing_slash=False) router.register("group", GroupView) router.register("poly", PolyView) diff --git a/vng_api_common/inspectors/__init__.py b/vng_api_common/inspectors/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/vng_api_common/inspectors/cache.py b/vng_api_common/inspectors/cache.py deleted file mode 100644 index d46ac129..00000000 --- a/vng_api_common/inspectors/cache.py +++ /dev/null @@ -1,55 +0,0 @@ -from django.utils.translation import gettext_lazy as _ - -from drf_spectacular.types import OpenApiTypes -from drf_spectacular.utils import OpenApiExample, OpenApiParameter -from rest_framework.views import APIView - -from ..caching.introspection import has_cache_header - -CACHE_REQUEST_HEADERS = [ - OpenApiParameter( - name="If-None-Match", - type=OpenApiTypes.STR, - location=OpenApiParameter.HEADER, - required=False, - description=_( - "Perform conditional requests. This header should contain one or " - "multiple ETag values of resources the client has cached. If the " - "current resource ETag value is in this set, then an HTTP 304 " - "empty body will be returned. See " - "[MDN](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/If-None-Match) " - "for details." - ), - examples=[ - OpenApiExample( - "oneValue", - summary=_("One ETag value"), - value='"79054025255fb1a26e4bc422aef54eb4"', - ), - OpenApiExample( - "multipleValues", - summary=_("Multiple ETag values"), - value='"79054025255fb1a26e4bc422aef54eb4", "e4d909c290d0fb1ca068ffaddf22cbd0"', - ), - ], - ) -] - - -def get_cache_headers(view: APIView) -> list[OpenApiParameter]: - if not has_cache_header(view): - return [] - - return [ - OpenApiParameter( - "ETag", - type=str, - location=OpenApiParameter.HEADER, - description=_( - "De ETag berekend op de response body JSON. " - "Indien twee resources exact dezelfde ETag hebben, dan zijn " - "deze resources identiek aan elkaar. Je kan de ETag gebruiken " - "om caching te implementeren." - ), - ), - ] diff --git a/vng_api_common/inspectors/query.py b/vng_api_common/inspectors/query.py deleted file mode 100644 index 8b84e278..00000000 --- a/vng_api_common/inspectors/query.py +++ /dev/null @@ -1,91 +0,0 @@ -from django.db import models -from django.utils.encoding import force_str -from django.utils.translation import gettext as _ - -from django_filters.filters import BaseCSVFilter, ChoiceFilter -from drf_yasg import openapi -from drf_yasg.inspectors.query import CoreAPICompatInspector -from rest_framework.filters import OrderingFilter - -from ..filters import URLModelChoiceFilter -from ..utils import underscore_to_camel -from .utils import get_target_field - - -class FilterInspector(CoreAPICompatInspector): - """ - Filter inspector that specifies the format of URL-based fields and lists - enum options. - """ - - def get_filter_parameters(self, filter_backend): - fields = super().get_filter_parameters(filter_backend) - if isinstance(filter_backend, OrderingFilter): - return fields - - if fields: - queryset = self.view.get_queryset() - filter_class = filter_backend.get_filterset_class(self.view, queryset) - - for parameter in fields: - filter_field = filter_class.base_filters[parameter.name] - model_field = get_target_field(queryset.model, parameter.name) - parameter._filter_field = filter_field - - help_text = filter_field.extra.get( - "help_text", - getattr(model_field, "help_text", "") if model_field else "", - ) - - if isinstance(filter_field, BaseCSVFilter): - if "choices" in filter_field.extra: - schema = openapi.Schema( - type=openapi.TYPE_ARRAY, - items=openapi.Schema( - type=openapi.TYPE_STRING, - enum=[ - choice[0] - for choice in filter_field.extra["choices"] - ], - ), - ) - else: - schema = openapi.Schema( - type=openapi.TYPE_ARRAY, - items=openapi.Schema(type=openapi.TYPE_STRING), - ) - parameter["type"] = openapi.TYPE_ARRAY - parameter["schema"] = schema - parameter["style"] = "form" - parameter["explode"] = False - elif isinstance(filter_field, URLModelChoiceFilter): - description = _("URL to the related {resource}").format( - resource=parameter.name - ) - parameter.description = help_text or description - parameter.format = openapi.FORMAT_URI - elif isinstance(filter_field, ChoiceFilter): - parameter.enum = [ - choice[0] for choice in filter_field.extra["choices"] - ] - elif model_field and isinstance(model_field, models.URLField): - parameter.format = openapi.FORMAT_URI - - if not parameter.description and help_text: - parameter.description = force_str(help_text) - - if "max_length" in filter_field.extra: - parameter.max_length = filter_field.extra["max_length"] - if "min_length" in filter_field.extra: - parameter.min_length = filter_field.extra["min_length"] - - return fields - - def process_result(self, result, method_name, obj, **kwargs): - """ - Convert snake-case to camelCase. - """ - if result and type(result) is list: - for parameter in result: - parameter.name = underscore_to_camel(parameter.name) - return result diff --git a/vng_api_common/inspectors/utils.py b/vng_api_common/inspectors/utils.py deleted file mode 100644 index a6fd0982..00000000 --- a/vng_api_common/inspectors/utils.py +++ /dev/null @@ -1,40 +0,0 @@ -from typing import Optional, Type - -from django.db import models - -from rest_framework.utils.model_meta import get_field_info - - -def get_target_field(model: Type[models.Model], field: str) -> Optional[models.Field]: - """ - Retrieve the end-target that ``field`` points to. - - :param field: A string containing a lookup, potentially spanning relations. E.g.: - foo__bar__lte. - :return: A Django model field instance or `None` - """ - - start, *remaining = field.split("__") - field_info = get_field_info(model) - - # simple, non relational field? - if start in field_info.fields: - return field_info.fields[start] - - # simple relational field? - if start in field_info.forward_relations: - relation_info = field_info.forward_relations[start] - if not remaining: - return relation_info.model_field - else: - return get_target_field(relation_info.related_model, "__".join(remaining)) - - # check the reverse relations - note that the model name is used instead of model_name_set - # in the queries -> we can't just test for containment in field_info.reverse_relations - for relation_info in field_info.reverse_relations.values(): - # not sure about this - what if there are more relations with different related names? - if relation_info.related_model._meta.model_name != start: - continue - return get_target_field(relation_info.related_model, "__".join(remaining)) - - return None diff --git a/vng_api_common/inspectors/view.py b/vng_api_common/inspectors/view.py index 31950824..0cb7c3d5 100644 --- a/vng_api_common/inspectors/view.py +++ b/vng_api_common/inspectors/view.py @@ -1,33 +1,15 @@ import inspect import logging -from collections import OrderedDict -from itertools import chain -from typing import Optional, Tuple, Union from django.apps import apps -from django.conf import settings -from django.utils.translation import gettext, gettext_lazy as _ +from django.utils.translation import gettext_lazy as _ -from drf_yasg import openapi -from drf_yasg.inspectors import SwaggerAutoSchema -from drf_yasg.utils import get_consumes -from rest_framework import exceptions, serializers, status, viewsets +from rest_framework import exceptions, serializers, viewsets +from drf_spectacular.openapi import OpenApiTypes -from drf_spectacular.openapi import AutoSchema, OpenApiTypes, OpenApiParameter +from ..exceptions import Conflict, Gone -from ..constants import HEADER_AUDIT, HEADER_LOGRECORD_ID, VERSION_HEADER -from ..exceptions import Conflict, Gone, PreconditionFailed -from ..geo import GeoMixin -from ..permissions import BaseAuthRequired, get_required_scopes -from ..search import is_search_view -from ..serializers import ( - FoutSerializer, - ValidatieFoutSerializer, - add_choice_values_help_text, -) -from .cache import CACHE_REQUEST_HEADERS, get_cache_headers, has_cache_header - logger = logging.getLogger(__name__) TYPE_TO_FIELDMAPPING = { @@ -60,98 +42,8 @@ "destroy": COMMON_ERRORS + [exceptions.NotFound], } -HTTP_STATUS_CODE_TITLES = { - status.HTTP_100_CONTINUE: "Continue", - status.HTTP_101_SWITCHING_PROTOCOLS: "Switching protocols", - status.HTTP_200_OK: "OK", - status.HTTP_201_CREATED: "Created", - status.HTTP_202_ACCEPTED: "Accepted", - status.HTTP_203_NON_AUTHORITATIVE_INFORMATION: "Non authoritative information", - status.HTTP_204_NO_CONTENT: "No content", - status.HTTP_205_RESET_CONTENT: "Reset content", - status.HTTP_206_PARTIAL_CONTENT: "Partial content", - status.HTTP_207_MULTI_STATUS: "Multi status", - status.HTTP_300_MULTIPLE_CHOICES: "Multiple choices", - status.HTTP_301_MOVED_PERMANENTLY: "Moved permanently", - status.HTTP_302_FOUND: "Found", - status.HTTP_303_SEE_OTHER: "See other", - status.HTTP_304_NOT_MODIFIED: "Not modified", - status.HTTP_305_USE_PROXY: "Use proxy", - status.HTTP_306_RESERVED: "Reserved", - status.HTTP_307_TEMPORARY_REDIRECT: "Temporary redirect", - status.HTTP_400_BAD_REQUEST: "Bad request", - status.HTTP_401_UNAUTHORIZED: "Unauthorized", - status.HTTP_402_PAYMENT_REQUIRED: "Payment required", - status.HTTP_403_FORBIDDEN: "Forbidden", - status.HTTP_404_NOT_FOUND: "Not found", - status.HTTP_405_METHOD_NOT_ALLOWED: "Method not allowed", - status.HTTP_406_NOT_ACCEPTABLE: "Not acceptable", - status.HTTP_407_PROXY_AUTHENTICATION_REQUIRED: "Proxy authentication required", - status.HTTP_408_REQUEST_TIMEOUT: "Request timeout", - status.HTTP_409_CONFLICT: "Conflict", - status.HTTP_410_GONE: "Gone", - status.HTTP_411_LENGTH_REQUIRED: "Length required", - status.HTTP_412_PRECONDITION_FAILED: "Precondition failed", - status.HTTP_413_REQUEST_ENTITY_TOO_LARGE: "Request entity too large", - status.HTTP_414_REQUEST_URI_TOO_LONG: "Request uri too long", - status.HTTP_415_UNSUPPORTED_MEDIA_TYPE: "Unsupported media type", - status.HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE: "Requested range not satisfiable", - status.HTTP_417_EXPECTATION_FAILED: "Expectation failed", - status.HTTP_422_UNPROCESSABLE_ENTITY: "Unprocessable entity", - status.HTTP_423_LOCKED: "Locked", - status.HTTP_424_FAILED_DEPENDENCY: "Failed dependency", - status.HTTP_428_PRECONDITION_REQUIRED: "Precondition required", - status.HTTP_429_TOO_MANY_REQUESTS: "Too many requests", - status.HTTP_431_REQUEST_HEADER_FIELDS_TOO_LARGE: "Request header fields too large", - status.HTTP_451_UNAVAILABLE_FOR_LEGAL_REASONS: "Unavailable for legal reasons", - status.HTTP_500_INTERNAL_SERVER_ERROR: "Internal server error", - status.HTTP_501_NOT_IMPLEMENTED: "Not implemented", - status.HTTP_502_BAD_GATEWAY: "Bad gateway", - status.HTTP_503_SERVICE_UNAVAILABLE: "Service unavailable", - status.HTTP_504_GATEWAY_TIMEOUT: "Gateway timeout", - status.HTTP_505_HTTP_VERSION_NOT_SUPPORTED: "HTTP version not supported", - status.HTTP_507_INSUFFICIENT_STORAGE: "Insufficient storage", - status.HTTP_511_NETWORK_AUTHENTICATION_REQUIRED: "Network authentication required", -} - AUDIT_TRAIL_ENABLED = apps.is_installed("vng_api_common.audittrails") -AUDIT_REQUEST_HEADERS = [ - OpenApiParameter( - name=HEADER_LOGRECORD_ID, - type=OpenApiTypes.STR, - location="header", - required=False, - description=gettext( - "Identifier of the request, traceable throughout the network" - ), - ), - OpenApiParameter( - name=HEADER_AUDIT, - type=OpenApiTypes.STR, - location="header", - required=False, - description=gettext("Explanation why the request is done"), - ), -] - - -def response_header(description: str, type: str, format: str = None) -> OrderedDict: - header = OrderedDict( - (("schema", OrderedDict((("type", type),))), ("description", description)) - ) - if format is not None: - header["schema"]["format"] = format - return header - - -version_header = response_header( - "Geeft een specifieke API-versie aan in de context van een specifieke aanroep. Voorbeeld: 1.2.1.", - type=OpenApiTypes.STR, -) - -location_header = response_header("URL waar de resource leeft.", type=OpenApiTypes.URI) - def _view_supports_audittrail(view: viewsets.ViewSet) -> bool: if not AUDIT_TRAIL_ENABLED: @@ -183,365 +75,3 @@ def _view_supports_audittrail(view: viewsets.ViewSet) -> bool: ) return action_in_audit_bases - - -class ResponseRef(openapi._Ref): - def __init__(self, resolver, response_name, ignore_unresolved=False): - """ - Adds a reference to a named Response defined in the ``#/responses/`` object. - """ - assert "responses" in resolver.scopes - super().__init__( - resolver, response_name, "responses", openapi.Response, ignore_unresolved - ) - - -class AutoSchema(SwaggerAutoSchema): - @property - def model(self): - if hasattr(self.view, "queryset") and self.view.queryset is not None: - return self.view.queryset.model - - if hasattr(self.view, "get_queryset"): - qs = self.view.get_queryset() - return qs.model - return None - - @property - def _is_search_view(self): - return is_search_view(self.view) - - def get_operation_id(self, operation_keys=None) -> str: - """ - Simply return the model name as lowercase string, postfixed with the operation name. - """ - operation_keys = operation_keys or self.operation_keys - - operation_id = self.overrides.get("operation_id", "") - if operation_id: - return operation_id - - action = operation_keys[-1] - if self.model is not None: - model_name = self.model._meta.model_name - return f"{model_name}_{action}" - else: - operation_id = "_".join(operation_keys) - return operation_id - - def should_page(self): - if self._is_search_view: - return hasattr(self.view, "paginator") - return super().should_page() - - def get_request_serializer(self): - if not self._is_search_view: - return super().get_request_serializer() - - Base = self.view.get_search_input_serializer_class() - - filter_fields = [] - for filter_backend in self.view.filter_backends: - filter_fields += ( - self.probe_inspectors( - self.filter_inspectors, "get_filter_parameters", filter_backend() - ) - or [] - ) - - filters = {} - for parameter in filter_fields: - help_text = parameter.description - # we can't get the verbose_label back from the enum, so the inspector - # in vng_api_common.inspectors.fields leaves a filter field reference behind - _filter_field = getattr(parameter, "_filter_field", None) - choices = getattr(_filter_field, "extra", {}).get("choices", []) - if choices: - FieldClass = serializers.ChoiceField - extra = {"choices": choices} - value_display_mapping = add_choice_values_help_text(choices) - help_text += f"\n\n{value_display_mapping}" - else: - FieldClass = TYPE_TO_FIELDMAPPING[parameter.type] - extra = {} - - filters[parameter.name] = FieldClass( - help_text=help_text, required=parameter.required, **extra - ) - - SearchSerializer = type(Base.__name__, (Base,), filters) - return SearchSerializer() - - def _get_search_responses(self): - response_status = status.HTTP_200_OK - response_schema = self.serializer_to_schema(self.get_view_serializer()) - schema = openapi.Schema(type=openapi.TYPE_ARRAY, items=response_schema) - if self.should_page(): - schema = self.get_paginated_response(schema) or schema - return OrderedDict({str(response_status): schema}) - - def register_error_responses(self): - ref_responses = self.components.with_scope("responses") - - if not ref_responses.keys(): - # general errors - general_classes = list(chain(*DEFAULT_ACTION_ERRORS.values())) - # add geo and validation errors - exception_classes = general_classes + [ - PreconditionFailed, - exceptions.ValidationError, - ] - status_codes = sorted({e.status_code for e in exception_classes}) - - fout_schema = self.serializer_to_schema(FoutSerializer()) - validation_fout_schema = self.serializer_to_schema( - ValidatieFoutSerializer() - ) - for status_code in status_codes: - schema = ( - validation_fout_schema - if status_code == exceptions.ValidationError.status_code - else fout_schema - ) - response = openapi.Response( - description=HTTP_STATUS_CODE_TITLES.get(status_code, ""), - schema=schema, - ) - self.set_response_headers(str(status_code), response) - ref_responses.set(str(status_code), response) - - def _get_error_responses(self) -> OrderedDict: - """ - Add the appropriate possible error responses to the schema. - - E.g. - we know that HTTP 400 on a POST/PATCH/PUT leads to validation - errors, 403 to Permission Denied etc. - """ - # only supports viewsets - if not hasattr(self.view, "action"): - return OrderedDict() - - self.register_error_responses() - - action = self.view.action - if ( - action not in DEFAULT_ACTION_ERRORS and self._is_search_view - ): # similar to a CREATE - action = "create" - - # general errors - general_klasses = DEFAULT_ACTION_ERRORS.get(action) - if general_klasses is None: - logger.debug("Unknown action %s, no default error responses added") - return OrderedDict() - - exception_klasses = general_klasses[:] - # add geo and validation errors - has_validation_errors = self.get_filter_parameters() or any( - issubclass(klass, exceptions.ValidationError) for klass in exception_klasses - ) - if has_validation_errors: - exception_klasses.append(exceptions.ValidationError) - - if isinstance(self.view, GeoMixin): - exception_klasses.append(PreconditionFailed) - - status_codes = sorted({e.status_code for e in exception_klasses}) - - return OrderedDict( - [ - (status_code, ResponseRef(self.components, str(status_code))) - for status_code in status_codes - ] - ) - - def get_default_responses(self) -> OrderedDict: - if self._is_search_view: - responses = self._get_search_responses() - serializer = self.get_view_serializer() - else: - responses = super().get_default_responses() - serializer = self.get_request_serializer() or self.get_view_serializer() - - # inject any headers - _responses = OrderedDict() - custom_headers = OrderedDict() - for status_, schema in responses.items(): - if serializer is not None: - custom_headers = ( - self.probe_inspectors( - self.field_inspectors, - "get_response_headers", - serializer, - {"field_inspectors": self.field_inspectors}, - status=status_, - ) - or OrderedDict() - ) - - # add the cache headers, if applicable - for header, header_schema in get_cache_headers(self.view).items(): - custom_headers[header] = header_schema - - assert isinstance(schema, openapi.Schema.OR_REF) or schema == "" - response = openapi.Response( - description=HTTP_STATUS_CODE_TITLES.get(int(status_), ""), - schema=schema or None, - headers=custom_headers, - ) - _responses[status_] = response - - for status_code, response in self._get_error_responses().items(): - _responses[status_code] = response - - return _responses - - @staticmethod - def set_response_headers( - status_code: str, response: Union[openapi.Response, ResponseRef] - ): - if not isinstance(response, openapi.Response): - return - - response.setdefault("headers", OrderedDict()) - response["headers"][VERSION_HEADER] = version_header - - if status_code == "201": - response["headers"]["Location"] = location_header - - def get_response_schemas(self, response_serializers): - # parent class doesn't support responses as ref objects, - # so we temporary remove them - ref_responses = OrderedDict() - for status_code, serializer in response_serializers.copy().items(): - if isinstance(serializer, ResponseRef): - ref_responses[str(status_code)] = response_serializers.pop(status_code) - - responses = super().get_response_schemas(response_serializers) - - # and add them again - responses.update(ref_responses) - responses = OrderedDict(sorted(responses.items())) - - # add the Api-Version headers - for status_code, response in responses.items(): - self.set_response_headers(status_code, response) - - return responses - - def get_request_content_type_header(self) -> Optional[openapi.Parameter]: - if self.method not in ["POST", "PUT", "PATCH"]: - return None - - consumes = get_consumes(self.get_parser_classes()) - return openapi.Parameter( - name="Content-Type", - in_=openapi.IN_HEADER, - type=openapi.TYPE_STRING, - required=True, - enum=consumes, - description=_("Content type of the request body."), - ) - - def add_manual_parameters(self, parameters): - base = super().add_manual_parameters(parameters) - - content_type = self.get_request_content_type_header() - if content_type is not None: - base = [content_type] + base - - if self._is_search_view: - serializer = self.get_request_serializer() - else: - serializer = self.get_request_serializer() or self.get_view_serializer() - - extra = [] - if serializer is not None: - extra = ( - self.probe_inspectors( - self.field_inspectors, - "get_request_header_parameters", - serializer, - {"field_inspectors": self.field_inspectors}, - ) - or [] - ) - result = base + extra - - if has_cache_header(self.view): - result += CACHE_REQUEST_HEADERS - - if _view_supports_audittrail(self.view): - result += AUDIT_REQUEST_HEADERS - - return result - - def get_security(self): - """Return a list of security requirements for this operation. - - Returning an empty list marks the endpoint as unauthenticated (i.e. removes all accepted - authentication schemes). Returning ``None`` will inherit the top-level secuirty requirements. - - :return: security requirements - :rtype: list[dict[str,list[str]]]""" - permissions = self.view.get_permissions() - scope_permissions = [ - perm for perm in permissions if isinstance(perm, BaseAuthRequired) - ] - - if not scope_permissions: - return super().get_security() - - if len(permissions) != len(scope_permissions): - logger.warning( - "Can't represent all permissions in OAS for path %s and method %s", - self.path, - self.method, - ) - - required_scopes = [] - for perm in scope_permissions: - scopes = get_required_scopes(self.request, self.view) - if scopes is None: - continue - required_scopes.append(scopes) - - if not required_scopes: - return None # use global security - - scopes = [str(scope) for scope in sorted(required_scopes)] - - # operation level security - return [{settings.SECURITY_DEFINITION_NAME: scopes}] - - # all of these break if you accept method HEAD because the view.action is None - def is_list_view(self) -> bool: - if self.method == "HEAD": - return False - return super().is_list_view() - - def get_summary_and_description(self) -> Tuple[str, str]: - if self.method != "HEAD": - return super().get_summary_and_description() - - default_description = _( - "De headers voor een specifiek(e) {model_name} opvragen" - ).format(model_name=self.model._meta.model_name.upper()) - default_summary = _( - "Vraag de headers op die je bij een GET request zou krijgen." - ) - - description = self.overrides.get("operation_description", default_description) - summary = self.overrides.get("operation_summary", default_summary) - return description, summary - - # patch around drf-yasg not taking overrides into account - # TODO: contribute back in PR - def get_produces(self) -> list: - produces = super().get_produces() - return self.overrides.get("produces", produces) - - -# translations aren't picked up/defined in DRF, so we need to hook them up here -_("A page number within the paginated result set.") -_("Number of results to return per page.")