Skip to content

Commit

Permalink
Parameter and header get value refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
p1c2u committed Sep 22, 2023
1 parent 0da2a38 commit ee082bc
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 75 deletions.
32 changes: 0 additions & 32 deletions openapi_core/schema/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,38 +45,6 @@ def get_explode(param_or_header: Spec) -> bool:
return style == "form"


def get_value(
param_or_header: Spec,
location: Mapping[str, Any],
name: Optional[str] = None,
) -> Any:
"""Returns parameter/header value from specific location"""
name = name or param_or_header["name"]
style = get_style(param_or_header)

if name not in location:
# Only check if the name is not in the location if the style of
# the param is deepObject,this is because deepObjects will never be found
# as their key also includes the properties of the object already.
if style != "deepObject":
raise KeyError
keys_str = " ".join(location.keys())
if not re.search(rf"{name}\[\w+\]", keys_str):
raise KeyError

aslist = get_aslist(param_or_header)
explode = get_explode(param_or_header)
if aslist and explode:
if style == "deepObject":
return get_deep_object_value(location, name)
if isinstance(location, SuportsGetAll):
return location.getall(name)
if isinstance(location, SuportsGetList):
return location.getlist(name)

return location[name]


def get_deep_object_value(
location: Mapping[str, Any],
name: Optional[str] = None,
Expand Down
4 changes: 4 additions & 0 deletions openapi_core/templating/media_types/finders.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ class MediaTypeFinder:
def __init__(self, content: Spec):
self.content = content

def get_first(self) -> MediaType:
mimetype, media_type = next(self.content.items())
return MediaType(media_type, mimetype)

def find(self, mimetype: str) -> MediaType:
if mimetype in self.content:
return MediaType(self.content / mimetype, mimetype)
Expand Down
9 changes: 4 additions & 5 deletions openapi_core/unmarshalling/unmarshallers.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,22 +89,21 @@ def _unmarshal_schema(self, schema: Spec, value: Any) -> Any:

def _get_param_or_header_value(
self,
raw: Any,
param_or_header: Spec,
location: Mapping[str, Any],
name: Optional[str] = None,
) -> Any:
casted, schema = self._get_param_or_header_value_and_schema(
param_or_header, location, name
raw, param_or_header
)
if schema is None:
return casted
return self._unmarshal_schema(schema, casted)

def _get_content_value(
self, raw: Any, mimetype: str, content: Spec
self, raw: Any, content: Spec, mimetype: Optional[str] = None
) -> Any:
casted, schema = self._get_content_value_and_schema(
raw, mimetype, content
raw, content, mimetype
)
if schema is None:
return casted
Expand Down
5 changes: 3 additions & 2 deletions openapi_core/validation/request/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,9 @@ def _get_parameter(

param_location = param["in"]
location = parameters[param_location]

try:
return self._get_param_or_header_value(param, location)
return self._get_param_or_header(param, location, name=name)
except KeyError:
required = param.getkey("required", False)
if required:
Expand Down Expand Up @@ -248,7 +249,7 @@ def _get_body(
content = request_body / "content"

raw_body = self._get_body_value(body, request_body)
return self._get_content_value(raw_body, mimetype, content)
return self._get_content_value(raw_body, content, mimetype)

def _get_body_value(self, body: Optional[str], request_body: Spec) -> Any:
if not body:
Expand Down
4 changes: 2 additions & 2 deletions openapi_core/validation/response/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def _get_data(
content = operation_response / "content"

raw_data = self._get_data_value(data)
return self._get_content_value(raw_data, mimetype, content)
return self._get_content_value(raw_data, content, mimetype)

def _get_data_value(self, data: str) -> Any:
if not data:
Expand Down Expand Up @@ -163,7 +163,7 @@ def _get_header(
)

try:
return self._get_param_or_header_value(header, headers, name=name)
return self._get_param_or_header(header, headers, name=name)
except KeyError:
required = header.getkey("required", False)
if required:
Expand Down
146 changes: 112 additions & 34 deletions openapi_core/validation/validators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""OpenAPI core validation validators module"""
import re
from functools import cached_property
from typing import Any
from typing import Mapping
Expand All @@ -23,7 +24,12 @@
)
from openapi_core.protocols import Request
from openapi_core.protocols import WebhookRequest
from openapi_core.schema.parameters import get_value
from openapi_core.schema.parameters import get_aslist
from openapi_core.schema.parameters import get_deep_object_value
from openapi_core.schema.parameters import get_explode
from openapi_core.schema.parameters import get_style
from openapi_core.schema.protocols import SuportsGetAll
from openapi_core.schema.protocols import SuportsGetList
from openapi_core.spec import Spec
from openapi_core.templating.media_types.datatypes import MediaType
from openapi_core.templating.paths.datatypes import PathOperationServer
Expand Down Expand Up @@ -70,10 +76,14 @@ def __init__(
self.extra_format_validators = extra_format_validators
self.extra_media_type_deserializers = extra_media_type_deserializers

def _get_media_type(self, content: Spec, mimetype: str) -> MediaType:
def _get_media_type(
self, content: Spec, mimetype: Optional[str] = None
) -> MediaType:
from openapi_core.templating.media_types.finders import MediaTypeFinder

finder = MediaTypeFinder(content)
if mimetype is None:
return finder.get_first()
return finder.find(mimetype)

def _deserialise_media_type(self, mimetype: str, value: Any) -> Any:
Expand All @@ -99,25 +109,74 @@ def _validate_schema(self, schema: Spec, value: Any) -> None:
)
validator.validate(value)

def _get_param_or_header_value(
def _get_param_or_header(
self,
param_or_header: Spec,
location: Mapping[str, Any],
name: Optional[str] = None,
) -> Any:
# Simple scenario
if "content" not in param_or_header:
return self._get_simple_value(param_or_header, location, name=name)

# Complex scenario
return self._get_complex(param_or_header, location, name=name)

def _get_simple_value(
self,
param_or_header: Spec,
location: Mapping[str, Any],
name: Optional[str] = None,
) -> Any:
try:
raw = self._get_style_value(param_or_header, location, name=name)
except KeyError:
if "schema" not in param_or_header:
raise
schema = param_or_header / "schema"
if "default" not in schema:
raise
raw = schema["default"]
return self._get_param_or_header_value(raw, param_or_header)

def _get_complex(
self,
param_or_header: Spec,
location: Mapping[str, Any],
name: Optional[str] = None,
) -> Any:
content = param_or_header / "content"
try:
raw = self._get_media_type_value(
param_or_header, location, name=name
)
except KeyError:
if "schema" not in param_or_header:
raise
schema = param_or_header / "schema"
if "default" not in schema:
raise
raw = schema["default"]
return self._get_content_value(raw, content)

def _get_param_or_header_value(
self,
raw: Any,
param_or_header: Spec,
) -> Any:
casted, schema = self._get_param_or_header_value_and_schema(
param_or_header, location, name
raw, param_or_header
)
if schema is None:
return casted
self._validate_schema(schema, casted)
return casted

def _get_content_value(
self, raw: Any, mimetype: str, content: Spec
self, raw: Any, content: Spec, mimetype: Optional[str] = None
) -> Any:
casted, schema = self._get_content_value_and_schema(
raw, mimetype, content
raw, content, mimetype
)
if schema is None:
return casted
Expand All @@ -126,39 +185,19 @@ def _get_content_value(

def _get_param_or_header_value_and_schema(
self,
raw: Any,
param_or_header: Spec,
location: Mapping[str, Any],
name: Optional[str] = None,
) -> Tuple[Any, Spec]:
try:
raw_value = get_value(param_or_header, location, name=name)
except KeyError:
if "schema" not in param_or_header:
raise
schema = param_or_header / "schema"
if "default" not in schema:
raise
casted = schema["default"]
else:
# Simple scenario
if "content" not in param_or_header:
deserialised = self._deserialise_style(
param_or_header, raw_value
)
schema = param_or_header / "schema"
# Complex scenario
else:
content = param_or_header / "content"
mimetype, media_type = next(content.items())
deserialised = self._deserialise_media_type(
mimetype, raw_value
)
schema = media_type / "schema"
casted = self._cast(schema, deserialised)
deserialised = self._deserialise_style(param_or_header, raw)
schema = param_or_header / "schema"
casted = self._cast(schema, deserialised)
return casted, schema

def _get_content_value_and_schema(
self, raw: Any, mimetype: str, content: Spec
self,
raw: Any,
content: Spec,
mimetype: Optional[str] = None,
) -> Tuple[Any, Optional[Spec]]:
media_type, mimetype = self._get_media_type(content, mimetype)
deserialised = self._deserialise_media_type(mimetype, raw)
Expand All @@ -170,6 +209,45 @@ def _get_content_value_and_schema(
schema = media_type / "schema"
return casted, schema

def _get_style_value(
self,
param_or_header: Spec,
location: Mapping[str, Any],
name: Optional[str] = None,
) -> Any:
name = name or param_or_header["name"]
style = get_style(param_or_header)
if name not in location:
# Only check if the name is not in the location if the style of
# the param is deepObject,this is because deepObjects will never be found
# as their key also includes the properties of the object already.
if style != "deepObject":
raise KeyError
keys_str = " ".join(location.keys())
if not re.search(rf"{name}\[\w+\]", keys_str):
raise KeyError

aslist = get_aslist(param_or_header)
explode = get_explode(param_or_header)
if aslist and explode:
if style == "deepObject":
return get_deep_object_value(location, name)
if isinstance(location, SuportsGetAll):
return location.getall(name)
if isinstance(location, SuportsGetList):
return location.getlist(name)

return location[name]

def _get_media_type_value(
self,
param_or_header: Spec,
location: Mapping[str, Any],
name: Optional[str] = None,
) -> Any:
name = name or param_or_header["name"]
return location[name]


class BaseAPICallValidator(BaseValidator):
@cached_property
Expand Down

0 comments on commit ee082bc

Please sign in to comment.