diff --git a/drf_spectacular/openapi.py b/drf_spectacular/openapi.py index ca1967bb..13998ff2 100644 --- a/drf_spectacular/openapi.py +++ b/drf_spectacular/openapi.py @@ -35,10 +35,11 @@ build_mocked_view, build_object_type, build_parameter_type, build_serializer_context, filter_supported_arguments, follow_field_source, follow_model_field_lookup, force_instance, get_doc, get_list_serializer, get_manager, get_type_hints, get_view_model, is_basic_serializer, - is_basic_type, is_field, is_list_serializer, is_list_serializer_customized, - is_patched_serializer, is_serializer, is_trivial_string_variation, - modify_media_types_for_versioning, resolve_django_path_parameter, resolve_regex_path_parameter, - resolve_type_hint, safe_ref, sanitize_specification_extensions, whitelisted, + is_basic_type, is_field, is_higher_order_type_hint, is_list_serializer, + is_list_serializer_customized, is_patched_serializer, is_serializer, + is_trivial_string_variation, modify_media_types_for_versioning, resolve_django_path_parameter, + resolve_regex_path_parameter, resolve_type_hint, safe_ref, sanitize_specification_extensions, + whitelisted, ) from drf_spectacular.settings import spectacular_settings from drf_spectacular.types import OpenApiTypes @@ -1338,6 +1339,9 @@ def _get_request_for_media_type(self, serializer, direction='request'): elif is_basic_type(serializer): schema = build_basic_type(serializer) request_body_required = False + elif is_higher_order_type_hint(serializer): + schema = resolve_type_hint(serializer) + request_body_required = False elif isinstance(serializer, dict): # bypass processing and use given schema directly schema = serializer @@ -1358,6 +1362,7 @@ def _get_response_bodies(self, direction: Direction = 'response') -> _SchemaType if ( is_serializer(response_serializers) or is_basic_type(response_serializers) + or is_higher_order_type_hint(response_serializers) or isinstance(response_serializers, OpenApiResponse) ): if self.method == 'DELETE': @@ -1426,6 +1431,8 @@ def _get_response_for_code(self, serializer, status_code, media_types=None, dire schema = component.ref elif is_basic_type(serializer): schema = build_basic_type(serializer) + elif is_higher_order_type_hint(serializer): + schema = resolve_type_hint(serializer) elif isinstance(serializer, dict): # bypass processing and use given schema directly schema = serializer diff --git a/drf_spectacular/plumbing.py b/drf_spectacular/plumbing.py index d35da3aa..ffd54755 100644 --- a/drf_spectacular/plumbing.py +++ b/drf_spectacular/plumbing.py @@ -98,6 +98,10 @@ class Choices: # type: ignore T = TypeVar('T') +class _Sentinel: + pass + + class UnableToProceedError(Exception): pass @@ -1287,6 +1291,15 @@ def _resolve_typeddict(hint): ) +def is_higher_order_type_hint(hint) -> bool: + return isinstance(hint, ( + getattr(types, 'GenericAlias', _Sentinel), + getattr(types, 'UnionType', _Sentinel), + getattr(typing, '_GenericAlias', _Sentinel), + getattr(typing, '_UnionGenericAlias', _Sentinel), + )) + + def resolve_type_hint(hint): """ resolve return value type hints to schema """ origin, args = _get_type_hint_origin(hint) diff --git a/tests/test_regressions.py b/tests/test_regressions.py index 9bdac659..c2b2e396 100644 --- a/tests/test_regressions.py +++ b/tests/test_regressions.py @@ -1,6 +1,7 @@ import collections import datetime import re +import sys import typing import uuid from decimal import Decimal @@ -3300,3 +3301,30 @@ class XViewset(viewsets.ReadOnlyModelViewSet): 'readOnly': True }, } + + +def test_extend_schema_higher_order_types(no_warnings): + cases = [ + (typing.List[int], {'items': {'type': 'integer'}, 'type': 'array'}), + (typing.Dict[str, int], {'type': 'object', 'additionalProperties': {'type': 'integer'}}), + (typing.Union[int, float], {'oneOf': [{'type': 'integer'}, {'format': 'double', 'type': 'number'}]}), + (typing.Set[int], {'items': {'type': 'integer'}, 'type': 'array'}), + (typing.Optional[int], {'type': 'integer', 'nullable': True}), + ] + if sys.version_info >= (3, 10): + cases.extend([ + (list[int], {'items': {'type': 'integer'}, 'type': 'array'}), + (dict[str, int], {'type': 'object', 'additionalProperties': {'type': 'integer'}}), + (int | float, {'oneOf': [{'type': 'integer'}, {'format': 'double', 'type': 'number'}]}), + ]) + + for t, ref_schema in cases: + @extend_schema(request=t, responses=t) + @api_view(['POST']) + def view_func(request, format=None): + pass # pragma: no cover + + schema = generate_schema('x', view_function=view_func) + + assert get_response_schema(schema['paths']['/x']['post']) == ref_schema + assert get_request_schema(schema['paths']['/x']['post']) == ref_schema