Skip to content

Commit

Permalink
Merge pull request #1181 from tfranzel/type_hints_on_decorator
Browse files Browse the repository at this point in the history
Add support for direct usage of higher order hints #1174
  • Loading branch information
tfranzel authored Feb 21, 2024
2 parents ef82c0e + cada2e0 commit b65ae61
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 4 deletions.
15 changes: 11 additions & 4 deletions drf_spectacular/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@
build_mocked_view, build_object_type, build_parameter_type, build_serializer_context,
filter_supported_arguments, follow_field_source, follow_model_field_lookup, force_instance,
get_doc, get_list_serializer, get_manager, get_type_hints, get_view_model, is_basic_serializer,
is_basic_type, is_field, is_list_serializer, is_list_serializer_customized,
is_patched_serializer, is_serializer, is_trivial_string_variation,
modify_media_types_for_versioning, resolve_django_path_parameter, resolve_regex_path_parameter,
resolve_type_hint, safe_ref, sanitize_specification_extensions, whitelisted,
is_basic_type, is_field, is_higher_order_type_hint, is_list_serializer,
is_list_serializer_customized, is_patched_serializer, is_serializer,
is_trivial_string_variation, modify_media_types_for_versioning, resolve_django_path_parameter,
resolve_regex_path_parameter, resolve_type_hint, safe_ref, sanitize_specification_extensions,
whitelisted,
)
from drf_spectacular.settings import spectacular_settings
from drf_spectacular.types import OpenApiTypes
Expand Down Expand Up @@ -1338,6 +1339,9 @@ def _get_request_for_media_type(self, serializer, direction='request'):
elif is_basic_type(serializer):
schema = build_basic_type(serializer)
request_body_required = False
elif is_higher_order_type_hint(serializer):
schema = resolve_type_hint(serializer)
request_body_required = False
elif isinstance(serializer, dict):
# bypass processing and use given schema directly
schema = serializer
Expand All @@ -1358,6 +1362,7 @@ def _get_response_bodies(self, direction: Direction = 'response') -> _SchemaType
if (
is_serializer(response_serializers)
or is_basic_type(response_serializers)
or is_higher_order_type_hint(response_serializers)
or isinstance(response_serializers, OpenApiResponse)
):
if self.method == 'DELETE':
Expand Down Expand Up @@ -1426,6 +1431,8 @@ def _get_response_for_code(self, serializer, status_code, media_types=None, dire
schema = component.ref
elif is_basic_type(serializer):
schema = build_basic_type(serializer)
elif is_higher_order_type_hint(serializer):
schema = resolve_type_hint(serializer)
elif isinstance(serializer, dict):
# bypass processing and use given schema directly
schema = serializer
Expand Down
13 changes: 13 additions & 0 deletions drf_spectacular/plumbing.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ class Choices: # type: ignore
T = TypeVar('T')


class _Sentinel:
pass


class UnableToProceedError(Exception):
pass

Expand Down Expand Up @@ -1287,6 +1291,15 @@ def _resolve_typeddict(hint):
)


def is_higher_order_type_hint(hint) -> bool:
return isinstance(hint, (
getattr(types, 'GenericAlias', _Sentinel),
getattr(types, 'UnionType', _Sentinel),
getattr(typing, '_GenericAlias', _Sentinel),
getattr(typing, '_UnionGenericAlias', _Sentinel),
))


def resolve_type_hint(hint):
""" resolve return value type hints to schema """
origin, args = _get_type_hint_origin(hint)
Expand Down
28 changes: 28 additions & 0 deletions tests/test_regressions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import collections
import datetime
import re
import sys
import typing
import uuid
from decimal import Decimal
Expand Down Expand Up @@ -3300,3 +3301,30 @@ class XViewset(viewsets.ReadOnlyModelViewSet):
'readOnly': True
},
}


def test_extend_schema_higher_order_types(no_warnings):
cases = [
(typing.List[int], {'items': {'type': 'integer'}, 'type': 'array'}),
(typing.Dict[str, int], {'type': 'object', 'additionalProperties': {'type': 'integer'}}),
(typing.Union[int, float], {'oneOf': [{'type': 'integer'}, {'format': 'double', 'type': 'number'}]}),
(typing.Set[int], {'items': {'type': 'integer'}, 'type': 'array'}),
(typing.Optional[int], {'type': 'integer', 'nullable': True}),
]
if sys.version_info >= (3, 10):
cases.extend([
(list[int], {'items': {'type': 'integer'}, 'type': 'array'}),
(dict[str, int], {'type': 'object', 'additionalProperties': {'type': 'integer'}}),
(int | float, {'oneOf': [{'type': 'integer'}, {'format': 'double', 'type': 'number'}]}),
])

for t, ref_schema in cases:
@extend_schema(request=t, responses=t)
@api_view(['POST'])
def view_func(request, format=None):
pass # pragma: no cover

schema = generate_schema('x', view_function=view_func)

assert get_response_schema(schema['paths']['/x']['post']) == ref_schema
assert get_request_schema(schema['paths']['/x']['post']) == ref_schema

0 comments on commit b65ae61

Please sign in to comment.