From 2c12d121f203726b64dd098a49d96294b0ab07b3 Mon Sep 17 00:00:00 2001 From: p1c2u Date: Fri, 22 Sep 2023 16:55:45 +0000 Subject: [PATCH] Parameter and header get value refactor --- openapi_core/schema/parameters.py | 32 ---- .../templating/media_types/finders.py | 4 + openapi_core/unmarshalling/unmarshallers.py | 9 +- openapi_core/validation/request/validators.py | 5 +- .../validation/response/validators.py | 4 +- openapi_core/validation/validators.py | 141 +++++++++++++----- tests/integration/data/v3.0/petstore.yaml | 15 ++ 7 files changed, 133 insertions(+), 77 deletions(-) diff --git a/openapi_core/schema/parameters.py b/openapi_core/schema/parameters.py index c8f2fa33..e8ab1fdf 100644 --- a/openapi_core/schema/parameters.py +++ b/openapi_core/schema/parameters.py @@ -45,38 +45,6 @@ def get_explode(param_or_header: Spec) -> bool: return style == "form" -def get_value( - param_or_header: Spec, - location: Mapping[str, Any], - name: Optional[str] = None, -) -> Any: - """Returns parameter/header value from specific location""" - name = name or param_or_header["name"] - style = get_style(param_or_header) - - if name not in location: - # Only check if the name is not in the location if the style of - # the param is deepObject,this is because deepObjects will never be found - # as their key also includes the properties of the object already. - if style != "deepObject": - raise KeyError - keys_str = " ".join(location.keys()) - if not re.search(rf"{name}\[\w+\]", keys_str): - raise KeyError - - aslist = get_aslist(param_or_header) - explode = get_explode(param_or_header) - if aslist and explode: - if style == "deepObject": - return get_deep_object_value(location, name) - if isinstance(location, SuportsGetAll): - return location.getall(name) - if isinstance(location, SuportsGetList): - return location.getlist(name) - - return location[name] - - def get_deep_object_value( location: Mapping[str, Any], name: Optional[str] = None, diff --git a/openapi_core/templating/media_types/finders.py b/openapi_core/templating/media_types/finders.py index b7be6a4d..6477c9d7 100644 --- a/openapi_core/templating/media_types/finders.py +++ b/openapi_core/templating/media_types/finders.py @@ -10,6 +10,10 @@ class MediaTypeFinder: def __init__(self, content: Spec): self.content = content + def get_first(self) -> MediaType: + mimetype, media_type = next(self.content.items()) + return MediaType(media_type, mimetype) + def find(self, mimetype: str) -> MediaType: if mimetype in self.content: return MediaType(self.content / mimetype, mimetype) diff --git a/openapi_core/unmarshalling/unmarshallers.py b/openapi_core/unmarshalling/unmarshallers.py index 5efaf5bf..7cc051ef 100644 --- a/openapi_core/unmarshalling/unmarshallers.py +++ b/openapi_core/unmarshalling/unmarshallers.py @@ -89,22 +89,21 @@ def _unmarshal_schema(self, schema: Spec, value: Any) -> Any: def _get_param_or_header_value( self, + raw: Any, param_or_header: Spec, - location: Mapping[str, Any], - name: Optional[str] = None, ) -> Any: casted, schema = self._get_param_or_header_value_and_schema( - param_or_header, location, name + raw, param_or_header ) if schema is None: return casted return self._unmarshal_schema(schema, casted) def _get_content_value( - self, raw: Any, mimetype: str, content: Spec + self, raw: Any, content: Spec, mimetype: Optional[str] = None ) -> Any: casted, schema = self._get_content_value_and_schema( - raw, mimetype, content + raw, content, mimetype ) if schema is None: return casted diff --git a/openapi_core/validation/request/validators.py b/openapi_core/validation/request/validators.py index fc21a933..44497987 100644 --- a/openapi_core/validation/request/validators.py +++ b/openapi_core/validation/request/validators.py @@ -189,8 +189,9 @@ def _get_parameter( param_location = param["in"] location = parameters[param_location] + try: - return self._get_param_or_header_value(param, location) + return self._get_param_or_header(param, location, name=name) except KeyError: required = param.getkey("required", False) if required: @@ -248,7 +249,7 @@ def _get_body( content = request_body / "content" raw_body = self._get_body_value(body, request_body) - return self._get_content_value(raw_body, mimetype, content) + return self._get_content_value(raw_body, content, mimetype) def _get_body_value(self, body: Optional[str], request_body: Spec) -> Any: if not body: diff --git a/openapi_core/validation/response/validators.py b/openapi_core/validation/response/validators.py index 49c6f193..fbb97b9c 100644 --- a/openapi_core/validation/response/validators.py +++ b/openapi_core/validation/response/validators.py @@ -114,7 +114,7 @@ def _get_data( content = operation_response / "content" raw_data = self._get_data_value(data) - return self._get_content_value(raw_data, mimetype, content) + return self._get_content_value(raw_data, content, mimetype) def _get_data_value(self, data: str) -> Any: if not data: @@ -163,7 +163,7 @@ def _get_header( ) try: - return self._get_param_or_header_value(header, headers, name=name) + return self._get_param_or_header(header, headers, name=name) except KeyError: required = header.getkey("required", False) if required: diff --git a/openapi_core/validation/validators.py b/openapi_core/validation/validators.py index 4fbd7e36..ce9c0fcc 100644 --- a/openapi_core/validation/validators.py +++ b/openapi_core/validation/validators.py @@ -1,4 +1,5 @@ """OpenAPI core validation validators module""" +import re from functools import cached_property from typing import Any from typing import Mapping @@ -23,7 +24,12 @@ ) from openapi_core.protocols import Request from openapi_core.protocols import WebhookRequest -from openapi_core.schema.parameters import get_value +from openapi_core.schema.parameters import get_aslist +from openapi_core.schema.parameters import get_deep_object_value +from openapi_core.schema.parameters import get_explode +from openapi_core.schema.parameters import get_style +from openapi_core.schema.protocols import SuportsGetAll +from openapi_core.schema.protocols import SuportsGetList from openapi_core.spec import Spec from openapi_core.templating.media_types.datatypes import MediaType from openapi_core.templating.paths.datatypes import PathOperationServer @@ -70,10 +76,14 @@ def __init__( self.extra_format_validators = extra_format_validators self.extra_media_type_deserializers = extra_media_type_deserializers - def _get_media_type(self, content: Spec, mimetype: str) -> MediaType: + def _get_media_type( + self, content: Spec, mimetype: Optional[str] = None + ) -> MediaType: from openapi_core.templating.media_types.finders import MediaTypeFinder finder = MediaTypeFinder(content) + if mimetype is None: + return finder.get_first() return finder.find(mimetype) def _deserialise_media_type(self, mimetype: str, value: Any) -> Any: @@ -99,14 +109,54 @@ def _validate_schema(self, schema: Spec, value: Any) -> None: ) validator.validate(value) - def _get_param_or_header_value( + def _get_param_or_header( + self, + param_or_header: Spec, + location: Mapping[str, Any], + name: Optional[str] = None, + ) -> Any: + # Simple scenario + if "content" not in param_or_header: + return self._get_simple_value(param_or_header, location, name=name) + + # Complex scenario + return self._get_complex(param_or_header, location, name=name) + + def _get_simple_value( + self, + param_or_header: Spec, + location: Mapping[str, Any], + name: Optional[str] = None, + ) -> Any: + try: + raw = self._get_style_value(param_or_header, location, name=name) + except KeyError: + # in simple scenrios schema always exist + schema = param_or_header / "schema" + if "default" not in schema: + raise + raw = schema["default"] + return self._get_param_or_header_value(raw, param_or_header) + + def _get_complex( self, param_or_header: Spec, location: Mapping[str, Any], name: Optional[str] = None, + ) -> Any: + content = param_or_header / "content" + # no point to catch KetError + # in complex scenrios schema doesn't exist + raw = self._get_media_type_value(param_or_header, location, name=name) + return self._get_content_value(raw, content) + + def _get_param_or_header_value( + self, + raw: Any, + param_or_header: Spec, ) -> Any: casted, schema = self._get_param_or_header_value_and_schema( - param_or_header, location, name + raw, param_or_header ) if schema is None: return casted @@ -114,10 +164,10 @@ def _get_param_or_header_value( return casted def _get_content_value( - self, raw: Any, mimetype: str, content: Spec + self, raw: Any, content: Spec, mimetype: Optional[str] = None ) -> Any: casted, schema = self._get_content_value_and_schema( - raw, mimetype, content + raw, content, mimetype ) if schema is None: return casted @@ -126,42 +176,22 @@ def _get_content_value( def _get_param_or_header_value_and_schema( self, + raw: Any, param_or_header: Spec, - location: Mapping[str, Any], - name: Optional[str] = None, ) -> Tuple[Any, Spec]: - try: - raw_value = get_value(param_or_header, location, name=name) - except KeyError: - if "schema" not in param_or_header: - raise - schema = param_or_header / "schema" - if "default" not in schema: - raise - casted = schema["default"] - else: - # Simple scenario - if "content" not in param_or_header: - deserialised = self._deserialise_style( - param_or_header, raw_value - ) - schema = param_or_header / "schema" - # Complex scenario - else: - content = param_or_header / "content" - mimetype, media_type = next(content.items()) - deserialised = self._deserialise_media_type( - mimetype, raw_value - ) - schema = media_type / "schema" - casted = self._cast(schema, deserialised) + deserialised = self._deserialise_style(param_or_header, raw) + schema = param_or_header / "schema" + casted = self._cast(schema, deserialised) return casted, schema def _get_content_value_and_schema( - self, raw: Any, mimetype: str, content: Spec + self, + raw: Any, + content: Spec, + mimetype: Optional[str] = None, ) -> Tuple[Any, Optional[Spec]]: - media_type, mimetype = self._get_media_type(content, mimetype) - deserialised = self._deserialise_media_type(mimetype, raw) + media_type, mime_type = self._get_media_type(content, mimetype) + deserialised = self._deserialise_media_type(mime_type, raw) casted = self._cast(media_type, deserialised) if "schema" not in media_type: @@ -170,6 +200,45 @@ def _get_content_value_and_schema( schema = media_type / "schema" return casted, schema + def _get_style_value( + self, + param_or_header: Spec, + location: Mapping[str, Any], + name: Optional[str] = None, + ) -> Any: + name = name or param_or_header["name"] + style = get_style(param_or_header) + if name not in location: + # Only check if the name is not in the location if the style of + # the param is deepObject,this is because deepObjects will never be found + # as their key also includes the properties of the object already. + if style != "deepObject": + raise KeyError + keys_str = " ".join(location.keys()) + if not re.search(rf"{name}\[\w+\]", keys_str): + raise KeyError + + aslist = get_aslist(param_or_header) + explode = get_explode(param_or_header) + if aslist and explode: + if style == "deepObject": + return get_deep_object_value(location, name) + if isinstance(location, SuportsGetAll): + return location.getall(name) + if isinstance(location, SuportsGetList): + return location.getlist(name) + + return location[name] + + def _get_media_type_value( + self, + param_or_header: Spec, + location: Mapping[str, Any], + name: Optional[str] = None, + ) -> Any: + name = name or param_or_header["name"] + return location[name] + class BaseAPICallValidator(BaseValidator): @cached_property diff --git a/tests/integration/data/v3.0/petstore.yaml b/tests/integration/data/v3.0/petstore.yaml index 282b880d..43b27398 100644 --- a/tests/integration/data/v3.0/petstore.yaml +++ b/tests/integration/data/v3.0/petstore.yaml @@ -82,6 +82,21 @@ paths: application/json: schema: $ref: "#/components/schemas/Coordinates" + - name: color + in: query + description: RGB color + style: deepObject + required: false + explode: true + schema: + type: object + properties: + R: + type: integer + G: + type: integer + B: + type: integer responses: '200': $ref: "#/components/responses/PetsResponse"