From 5610b663e480da294daf7b66fc571da0fcebcc4c Mon Sep 17 00:00:00 2001 From: p1c2u Date: Sun, 24 Sep 2023 09:56:40 +0000 Subject: [PATCH] Mimetype parameters handling --- .../media_types/deserializers.py | 4 ++- .../deserializing/media_types/factories.py | 8 ++++- .../deserializing/media_types/util.py | 17 +++++++--- .../templating/media_types/datatypes.py | 5 ++- .../templating/media_types/finders.py | 33 +++++++++++++++---- openapi_core/validation/validators.py | 11 +++++-- tests/integration/test_petstore.py | 8 +++-- .../test_media_types_deserializers.py | 30 +++++++++++++---- .../templating/test_media_types_finders.py | 13 ++++++-- 9 files changed, 101 insertions(+), 28 deletions(-) diff --git a/openapi_core/deserializing/media_types/deserializers.py b/openapi_core/deserializing/media_types/deserializers.py index 43f99c81..2bdef976 100644 --- a/openapi_core/deserializing/media_types/deserializers.py +++ b/openapi_core/deserializing/media_types/deserializers.py @@ -16,9 +16,11 @@ def __init__( self, mimetype: str, deserializer_callable: Optional[DeserializerCallable] = None, + **parameters: str, ): self.mimetype = mimetype self.deserializer_callable = deserializer_callable + self.parameters = parameters def deserialize(self, value: Any) -> Any: if self.deserializer_callable is None: @@ -26,6 +28,6 @@ def deserialize(self, value: Any) -> Any: return value try: - return self.deserializer_callable(value) + return self.deserializer_callable(value, **self.parameters) except (ParseError, ValueError, TypeError, AttributeError): raise MediaTypeDeserializeError(self.mimetype, value) diff --git a/openapi_core/deserializing/media_types/factories.py b/openapi_core/deserializing/media_types/factories.py index f35257b2..9087c6b1 100644 --- a/openapi_core/deserializing/media_types/factories.py +++ b/openapi_core/deserializing/media_types/factories.py @@ -1,3 +1,4 @@ +from typing import Mapping from typing import Optional from openapi_core.deserializing.media_types.datatypes import ( @@ -23,10 +24,13 @@ def __init__( def create( self, mimetype: str, + parameters: Optional[Mapping[str, str]] = None, extra_media_type_deserializers: Optional[ MediaTypeDeserializersDict ] = None, ) -> CallableMediaTypeDeserializer: + if parameters is None: + parameters = {} if extra_media_type_deserializers is None: extra_media_type_deserializers = {} deserialize_callable = self.get_deserializer_callable( @@ -34,7 +38,9 @@ def create( extra_media_type_deserializers=extra_media_type_deserializers, ) - return CallableMediaTypeDeserializer(mimetype, deserialize_callable) + return CallableMediaTypeDeserializer( + mimetype, deserialize_callable, **parameters + ) def get_deserializer_callable( self, diff --git a/openapi_core/deserializing/media_types/util.py b/openapi_core/deserializing/media_types/util.py index df03eba2..c73315d7 100644 --- a/openapi_core/deserializing/media_types/util.py +++ b/openapi_core/deserializing/media_types/util.py @@ -5,17 +5,26 @@ from urllib.parse import parse_qsl -def plain_loads(value: Union[str, bytes]) -> str: +def plain_loads(value: Union[str, bytes], **parameters: str) -> str: + charset = "utf-8" + if "charset" in parameters: + charset = parameters["charset"] if isinstance(value, bytes): - value = value.decode("ASCII", errors="surrogateescape") + try: + return value.decode(charset) + # fallback safe decode + except UnicodeDecodeError: + return value.decode("ASCII", errors="surrogateescape") return value -def urlencoded_form_loads(value: Any) -> Dict[str, Any]: +def urlencoded_form_loads(value: Any, **parameters: str) -> Dict[str, Any]: return dict(parse_qsl(value)) -def data_form_loads(value: Union[str, bytes]) -> Dict[str, Any]: +def data_form_loads( + value: Union[str, bytes], **parameters: str +) -> Dict[str, Any]: if isinstance(value, bytes): value = value.decode("ASCII", errors="surrogateescape") parser = Parser() diff --git a/openapi_core/templating/media_types/datatypes.py b/openapi_core/templating/media_types/datatypes.py index d76fe9d2..37c4c064 100644 --- a/openapi_core/templating/media_types/datatypes.py +++ b/openapi_core/templating/media_types/datatypes.py @@ -1,3 +1,6 @@ from collections import namedtuple +from dataclasses import dataclass +from typing import Mapping +from typing import Optional -MediaType = namedtuple("MediaType", ["value", "key"]) +MediaType = namedtuple("MediaType", ["mime_type", "parameters", "media_type"]) diff --git a/openapi_core/templating/media_types/finders.py b/openapi_core/templating/media_types/finders.py index 6477c9d7..15ffe89e 100644 --- a/openapi_core/templating/media_types/finders.py +++ b/openapi_core/templating/media_types/finders.py @@ -1,5 +1,7 @@ """OpenAPI core templating media types finders module""" import fnmatch +from typing import Mapping +from typing import Tuple from openapi_core.spec import Spec from openapi_core.templating.media_types.datatypes import MediaType @@ -12,15 +14,34 @@ def __init__(self, content: Spec): def get_first(self) -> MediaType: mimetype, media_type = next(self.content.items()) - return MediaType(media_type, mimetype) + return MediaType(mimetype, {}, media_type) def find(self, mimetype: str) -> MediaType: - if mimetype in self.content: - return MediaType(self.content / mimetype, mimetype) + if mimetype is None: + raise MediaTypeNotFound(mimetype, list(self.content.keys())) - if mimetype: + mime_type, parameters = self._parse_mimetype(mimetype) + + # simple mime type + for m in [mimetype, mime_type]: + if m in self.content: + return MediaType(mime_type, parameters, self.content / m) + + # range mime type + if mime_type: for key, value in self.content.items(): - if fnmatch.fnmatch(mimetype, key): - return MediaType(value, key) + if fnmatch.fnmatch(mime_type, key): + return MediaType(key, parameters, value) raise MediaTypeNotFound(mimetype, list(self.content.keys())) + + def _parse_mimetype(self, mimetype: str) -> Tuple[str, Mapping[str, str]]: + mimetype_parts = mimetype.split("; ") + mime_type = mimetype_parts[0] + parameters = {} + if len(mimetype_parts) > 1: + parameters_list = ( + param_str.split("=") for param_str in mimetype_parts[1:] + ) + parameters = dict(parameters_list) + return mime_type, parameters diff --git a/openapi_core/validation/validators.py b/openapi_core/validation/validators.py index 20166ae9..b9e7f397 100644 --- a/openapi_core/validation/validators.py +++ b/openapi_core/validation/validators.py @@ -86,10 +86,13 @@ def _find_media_type( return finder.get_first() return finder.find(mimetype) - def _deserialise_media_type(self, mimetype: str, value: Any) -> Any: + def _deserialise_media_type( + self, mimetype: str, parameters: Mapping[str, str], value: Any + ) -> Any: deserializer = self.media_type_deserializers_factory.create( mimetype, extra_media_type_deserializers=self.extra_media_type_deserializers, + parameters=parameters, ) return deserializer.deserialize(value) @@ -194,8 +197,10 @@ def _convert_content_schema_value_and_schema( content: Spec, mimetype: Optional[str] = None, ) -> Tuple[Any, Optional[Spec]]: - media_type, mime_type = self._find_media_type(content, mimetype) - deserialised = self._deserialise_media_type(mime_type, raw) + mime_type, parameters, media_type = self._find_media_type( + content, mimetype + ) + deserialised = self._deserialise_media_type(mime_type, parameters, raw) casted = self._cast(media_type, deserialised) if "schema" not in media_type: diff --git a/tests/integration/test_petstore.py b/tests/integration/test_petstore.py index 2d8794d5..1c28dc36 100644 --- a/tests/integration/test_petstore.py +++ b/tests/integration/test_petstore.py @@ -230,13 +230,15 @@ def test_get_pets_response_no_schema(self, spec): assert result.body is None - data = "" - response = MockResponse(data, status_code=404, mimetype="text/html") + data = b"" + response = MockResponse( + data, status_code=404, mimetype="text/html; charset=utf-8" + ) response_result = unmarshal_response(request, response, spec=spec) assert response_result.errors == [] - assert response_result.data == data + assert response_result.data == data.decode("utf-8") def test_get_pets_invalid_response(self, spec, response_unmarshaller): host_url = "http://petstore.swagger.io/v1" diff --git a/tests/unit/deserializing/test_media_types_deserializers.py b/tests/unit/deserializing/test_media_types_deserializers.py index e6f3bed8..28279f93 100644 --- a/tests/unit/deserializing/test_media_types_deserializers.py +++ b/tests/unit/deserializing/test_media_types_deserializers.py @@ -14,6 +14,7 @@ class TestMediaTypeDeserializer: def deserializer_factory(self): def create_deserializer( media_type, + parameters=None, media_type_deserializers=media_type_deserializers, extra_media_type_deserializers=None, ): @@ -21,6 +22,7 @@ def create_deserializer( media_type_deserializers, ).create( media_type, + parameters=parameters, extra_media_type_deserializers=extra_media_type_deserializers, ) @@ -49,19 +51,33 @@ def test_no_deserializer(self, deserializer_factory): assert result == value @pytest.mark.parametrize( - "mimetype", + "mimetype,parameters,value,expected", [ - "text/plain", - "text/html", + ( + "text/plain", + {"charset": "iso-8859-2"}, + b"\xb1\xb6\xbc\xe6", + "ąśźć", + ), + ( + "text/plain", + {"charset": "utf-8"}, + b"\xc4\x85\xc5\x9b\xc5\xba\xc4\x87", + "ąśźć", + ), + ("text/plain", {}, b"\xc4\x85\xc5\x9b\xc5\xba\xc4\x87", "ąśźć"), + ("text/plain", {}, "somestr", "somestr"), + ("text/html", {}, "somestr", "somestr"), ], ) - def test_plain_valid(self, deserializer_factory, mimetype): - deserializer = deserializer_factory(mimetype) - value = "somestr" + def test_plain_valid( + self, deserializer_factory, mimetype, parameters, value, expected + ): + deserializer = deserializer_factory(mimetype, parameters=parameters) result = deserializer.deserialize(value) - assert result == value + assert result == expected @pytest.mark.parametrize( "mimetype", diff --git a/tests/unit/templating/test_media_types_finders.py b/tests/unit/templating/test_media_types_finders.py index 3a93fb94..62adfdae 100644 --- a/tests/unit/templating/test_media_types_finders.py +++ b/tests/unit/templating/test_media_types_finders.py @@ -22,17 +22,26 @@ def content(self, spec): def finder(self, content): return MediaTypeFinder(content) + def test_charset(self, finder, content): + mimetype = "text/html; charset=utf-8" + + mimetype, parameters, _ = finder.find(mimetype) + assert mimetype == "text/*" + assert parameters == {"charset": "utf-8"} + def test_exact(self, finder, content): mimetype = "application/json" - _, mimetype = finder.find(mimetype) + mimetype, parameters, _ = finder.find(mimetype) assert mimetype == "application/json" + assert parameters == {} def test_match(self, finder, content): mimetype = "text/html" - _, mimetype = finder.find(mimetype) + mimetype, parameters, _ = finder.find(mimetype) assert mimetype == "text/*" + assert parameters == {} def test_not_found(self, finder, content): mimetype = "unknown"