From 14f122b0dbe4f72957c4c6fb33dc932b6017fae2 Mon Sep 17 00:00:00 2001 From: "T. Franzel" Date: Fri, 6 Sep 2024 18:46:30 +0200 Subject: [PATCH] parametrize component registry identity #1288 --- drf_spectacular/contrib/pydantic.py | 1 + .../contrib/rest_framework_dataclasses.py | 7 ++- drf_spectacular/extensions.py | 4 ++ drf_spectacular/openapi.py | 12 ++++- drf_spectacular/plumbing.py | 31 ++++++++++--- .../test_rest_framework_dataclasses.py | 46 +++++++++++++++++++ tests/test_warnings.py | 2 +- 7 files changed, 94 insertions(+), 9 deletions(-) diff --git a/drf_spectacular/contrib/pydantic.py b/drf_spectacular/contrib/pydantic.py index f03dda6a..395a8a9a 100644 --- a/drf_spectacular/contrib/pydantic.py +++ b/drf_spectacular/contrib/pydantic.py @@ -23,6 +23,7 @@ def get_name(self, auto_schema, direction): # of the entry model, we simply use the class name as string for object. This hack may # create false positive warnings, so turn it off. However, this may suppress correct # warnings involving the entry class. + # TODO suppression may be migrated to new ComponentIdentity system set_override(self.target, 'suppress_collision_warning', True) return self.target.__name__ diff --git a/drf_spectacular/contrib/rest_framework_dataclasses.py b/drf_spectacular/contrib/rest_framework_dataclasses.py index 760cdca5..95a8adee 100644 --- a/drf_spectacular/contrib/rest_framework_dataclasses.py +++ b/drf_spectacular/contrib/rest_framework_dataclasses.py @@ -1,6 +1,8 @@ +from typing import Any + from drf_spectacular.drainage import get_override, has_override from drf_spectacular.extensions import OpenApiSerializerExtension -from drf_spectacular.plumbing import get_doc +from drf_spectacular.plumbing import ComponentIdentity, get_doc from drf_spectacular.utils import Direction @@ -18,6 +20,9 @@ def get_name(self): return get_override(self.target.dataclass_definition.dataclass_type, 'component_name') return self.target.dataclass_definition.dataclass_type.__name__ + def get_identity(self, auto_schema, direction: Direction) -> Any: + return ComponentIdentity(self.target.dataclass_definition.dataclass_type) + def strip_library_doc(self, schema): """Strip the DataclassSerializer library documentation from the schema.""" from rest_framework_dataclasses.serializers import DataclassSerializer diff --git a/drf_spectacular/extensions.py b/drf_spectacular/extensions.py index 052be3a0..1eae1c6d 100644 --- a/drf_spectacular/extensions.py +++ b/drf_spectacular/extensions.py @@ -68,6 +68,10 @@ def get_name(self, auto_schema: 'AutoSchema', direction: Direction) -> Optional[ """ return str for overriding default name extraction """ return None + def get_identity(self, auto_schema: 'AutoSchema', direction: Direction) -> Any: + """ return anything to compare instances of target. Target will be used by default. """ + return None + def map_serializer(self, auto_schema: 'AutoSchema', direction: Direction) -> _SchemaType: """ override for customized serializer mapping """ return auto_schema._map_serializer(self.target_class, direction, bypass_extensions=True) diff --git a/drf_spectacular/openapi.py b/drf_spectacular/openapi.py index f4f14f3b..5054f88f 100644 --- a/drf_spectacular/openapi.py +++ b/drf_spectacular/openapi.py @@ -1556,7 +1556,17 @@ def _get_response_headers_for_code(self, status_code, direction='response') -> _ return result + def get_serializer_identity(self, serializer: serializers.Serializer, direction: Direction) -> Any: + serializer_extension = OpenApiSerializerExtension.get_match(serializer) + if serializer_extension: + identity = serializer_extension.get_identity(self, direction) + if identity is not None: + return identity + + return serializer + def get_serializer_name(self, serializer: serializers.Serializer, direction: Direction) -> str: + """ override this for custom behaviour """ return serializer.__class__.__name__ def _get_serializer_name(self, serializer, direction, bypass_extensions=False) -> str: @@ -1612,7 +1622,7 @@ def resolve_serializer( component = ResolvedComponent( name=self._get_serializer_name(serializer, direction, bypass_extensions), type=ResolvedComponent.SCHEMA, - object=serializer, + object=self.get_serializer_identity(serializer, direction), ) if component in self.registry: return self.registry[component] # return component with schema diff --git a/drf_spectacular/plumbing.py b/drf_spectacular/plumbing.py index 74753efa..df1aa03c 100644 --- a/drf_spectacular/plumbing.py +++ b/drf_spectacular/plumbing.py @@ -723,6 +723,17 @@ def ref(self) -> _SchemaType: return {'$ref': f'#/components/{self.type}/{self.name}'} +class ComponentIdentity: + """ A container class to make object/component comparison explicit """ + def __init__(self, obj): + self.obj = obj + + def __eq__(self, other): + if isinstance(other, ComponentIdentity): + return self.obj == other.obj + return self.obj == other + + class ComponentRegistry: def __init__(self) -> None: self._components: Dict[Tuple[str, str], ResolvedComponent] = {} @@ -746,17 +757,25 @@ def __contains__(self, component): query_obj = component.object registry_obj = self._components[component.key].object - query_class = query_obj if inspect.isclass(query_obj) else query_obj.__class__ - registry_class = query_obj if inspect.isclass(registry_obj) else registry_obj.__class__ + + if isinstance(query_obj, ComponentIdentity) or inspect.isclass(query_obj): + query_id = query_obj + else: + query_id = query_obj.__class__ + + if isinstance(registry_obj, ComponentIdentity) or inspect.isclass(registry_obj): + registry_id = registry_obj + else: + registry_id = registry_obj.__class__ suppress_collision_warning = ( - get_override(registry_class, 'suppress_collision_warning', False) - or get_override(query_class, 'suppress_collision_warning', False) + get_override(registry_id, 'suppress_collision_warning', False) + or get_override(query_id, 'suppress_collision_warning', False) ) - if query_class != registry_class and not suppress_collision_warning: + if query_id != registry_id and not suppress_collision_warning: warn( f'Encountered 2 components with identical names "{component.name}" and ' - f'different classes {query_class} and {registry_class}. This will very ' + f'different identities {query_id} and {registry_id}. This will very ' f'likely result in an incorrect schema. Try renaming one.' ) return True diff --git a/tests/contrib/test_rest_framework_dataclasses.py b/tests/contrib/test_rest_framework_dataclasses.py index 0062e122..d0dd8c91 100644 --- a/tests/contrib/test_rest_framework_dataclasses.py +++ b/tests/contrib/test_rest_framework_dataclasses.py @@ -90,3 +90,49 @@ def custom_name_via_serializer_decoration(request): generate_schema(None, patterns=urlpatterns), 'tests/contrib/test_rest_framework_dataclasses.yml' ) + + +@pytest.mark.contrib('rest_framework_dataclasses') +@pytest.mark.skipif(sys.version_info < (3, 7), reason='dataclass required by package') +def test_rest_framework_dataclasses_class_reuse(no_warnings): + from dataclasses import dataclass + + from rest_framework_dataclasses.serializers import DataclassSerializer + + @dataclass + class Person: + name: str + age: int + + @dataclass + class Party: + person: Person + num_persons: int + + class PartySerializer(DataclassSerializer[Party]): + class Meta: + dataclass = Party + + class PersonSerializer(DataclassSerializer[Person]): + class Meta: + dataclass = Person + + @extend_schema(responses=PartySerializer) + @api_view() + def party(request): + pass # pragma: no cover + + @extend_schema(responses=PersonSerializer) + @api_view() + def person(request): + pass # pragma: no cover + + urlpatterns = [ + path('party', person), + path('person', party), + ] + + schema = generate_schema(None, patterns=urlpatterns) + # just existence is enough to check since its about no_warnings + assert 'Person' in schema['components']['schemas'] + assert 'Party' in schema['components']['schemas'] diff --git a/tests/test_warnings.py b/tests/test_warnings.py index 71c9539a..89c8ef44 100644 --- a/tests/test_warnings.py +++ b/tests/test_warnings.py @@ -49,7 +49,7 @@ class X2Viewset(mixins.ListModelMixin, viewsets.GenericViewSet): generate_schema(None, patterns=router.urls) stderr = capsys.readouterr().err - assert 'Encountered 2 components with identical names "X" and different classes' in stderr + assert 'Encountered 2 components with identical names "X" and different identities' in stderr def test_owned_serializer_naming_override_with_ref_name_collision(warnings):