Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parameter and header get value refactor #677

Merged
merged 1 commit into from
Sep 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 0 additions & 32 deletions openapi_core/schema/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,38 +45,6 @@ def get_explode(param_or_header: Spec) -> bool:
return style == "form"


def get_value(
param_or_header: Spec,
location: Mapping[str, Any],
name: Optional[str] = None,
) -> Any:
"""Returns parameter/header value from specific location"""
name = name or param_or_header["name"]
style = get_style(param_or_header)

if name not in location:
# Only check if the name is not in the location if the style of
# the param is deepObject,this is because deepObjects will never be found
# as their key also includes the properties of the object already.
if style != "deepObject":
raise KeyError
keys_str = " ".join(location.keys())
if not re.search(rf"{name}\[\w+\]", keys_str):
raise KeyError

aslist = get_aslist(param_or_header)
explode = get_explode(param_or_header)
if aslist and explode:
if style == "deepObject":
return get_deep_object_value(location, name)
if isinstance(location, SuportsGetAll):
return location.getall(name)
if isinstance(location, SuportsGetList):
return location.getlist(name)

return location[name]


def get_deep_object_value(
location: Mapping[str, Any],
name: Optional[str] = None,
Expand Down
4 changes: 4 additions & 0 deletions openapi_core/templating/media_types/finders.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ class MediaTypeFinder:
def __init__(self, content: Spec):
self.content = content

def get_first(self) -> MediaType:
mimetype, media_type = next(self.content.items())
return MediaType(media_type, mimetype)

def find(self, mimetype: str) -> MediaType:
if mimetype in self.content:
return MediaType(self.content / mimetype, mimetype)
Expand Down
6 changes: 3 additions & 3 deletions openapi_core/unmarshalling/response/unmarshallers.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _unmarshal(
operation: Spec,
) -> ResponseUnmarshalResult:
try:
operation_response = self._get_operation_response(
operation_response = self._find_operation_response(
response.status_code, operation
)
# don't process if operation errors
Expand Down Expand Up @@ -96,7 +96,7 @@ def _unmarshal_data(
operation: Spec,
) -> ResponseUnmarshalResult:
try:
operation_response = self._get_operation_response(
operation_response = self._find_operation_response(
response.status_code, operation
)
# don't process if operation errors
Expand Down Expand Up @@ -124,7 +124,7 @@ def _unmarshal_headers(
operation: Spec,
) -> ResponseUnmarshalResult:
try:
operation_response = self._get_operation_response(
operation_response = self._find_operation_response(
response.status_code, operation
)
# don't process if operation errors
Expand Down
17 changes: 8 additions & 9 deletions openapi_core/unmarshalling/unmarshallers.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,24 +87,23 @@ def _unmarshal_schema(self, schema: Spec, value: Any) -> Any:
)
return unmarshaller.unmarshal(value)

def _get_param_or_header_value(
def _convert_schema_style_value(
self,
raw: Any,
param_or_header: Spec,
location: Mapping[str, Any],
name: Optional[str] = None,
) -> Any:
casted, schema = self._get_param_or_header_value_and_schema(
param_or_header, location, name
casted, schema = self._convert_schema_style_value_and_schema(
raw, param_or_header
)
if schema is None:
return casted
return self._unmarshal_schema(schema, casted)

def _get_content_value(
self, raw: Any, mimetype: str, content: Spec
def _convert_content_schema_value(
self, raw: Any, content: Spec, mimetype: Optional[str] = None
) -> Any:
casted, schema = self._get_content_value_and_schema(
raw, mimetype, content
casted, schema = self._convert_content_schema_value_and_schema(
raw, content, mimetype
)
if schema is None:
return casted
Expand Down
5 changes: 3 additions & 2 deletions openapi_core/validation/request/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,9 @@ def _get_parameter(

param_location = param["in"]
location = parameters[param_location]

try:
return self._get_param_or_header_value(param, location)
return self._get_param_or_header(param, location, name=name)
except KeyError:
required = param.getkey("required", False)
if required:
Expand Down Expand Up @@ -248,7 +249,7 @@ def _get_body(
content = request_body / "content"

raw_body = self._get_body_value(body, request_body)
return self._get_content_value(raw_body, mimetype, content)
return self._convert_content_schema_value(raw_body, content, mimetype)

def _get_body_value(self, body: Optional[str], request_body: Spec) -> Any:
if not body:
Expand Down
12 changes: 6 additions & 6 deletions openapi_core/validation/response/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def _iter_errors(
operation: Spec,
) -> Iterator[Exception]:
try:
operation_response = self._get_operation_response(
operation_response = self._find_operation_response(
status_code, operation
)
# don't process if operation errors
Expand All @@ -64,7 +64,7 @@ def _iter_data_errors(
self, status_code: int, data: str, mimetype: str, operation: Spec
) -> Iterator[Exception]:
try:
operation_response = self._get_operation_response(
operation_response = self._find_operation_response(
status_code, operation
)
# don't process if operation errors
Expand All @@ -81,7 +81,7 @@ def _iter_headers_errors(
self, status_code: int, headers: Mapping[str, Any], operation: Spec
) -> Iterator[Exception]:
try:
operation_response = self._get_operation_response(
operation_response = self._find_operation_response(
status_code, operation
)
# don't process if operation errors
Expand All @@ -94,7 +94,7 @@ def _iter_headers_errors(
except HeadersError as exc:
yield from exc.context

def _get_operation_response(
def _find_operation_response(
self,
status_code: int,
operation: Spec,
Expand All @@ -114,7 +114,7 @@ def _get_data(
content = operation_response / "content"

raw_data = self._get_data_value(data)
return self._get_content_value(raw_data, mimetype, content)
return self._convert_content_schema_value(raw_data, content, mimetype)

def _get_data_value(self, data: str) -> Any:
if not data:
Expand Down Expand Up @@ -163,7 +163,7 @@ def _get_header(
)

try:
return self._get_param_or_header_value(header, headers, name=name)
return self._get_param_or_header(header, headers, name=name)
except KeyError:
required = header.getkey("required", False)
if required:
Expand Down
155 changes: 114 additions & 41 deletions openapi_core/validation/validators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""OpenAPI core validation validators module"""
import re
from functools import cached_property
from typing import Any
from typing import Mapping
Expand All @@ -23,7 +24,12 @@
)
from openapi_core.protocols import Request
from openapi_core.protocols import WebhookRequest
from openapi_core.schema.parameters import get_value
from openapi_core.schema.parameters import get_aslist
from openapi_core.schema.parameters import get_deep_object_value
from openapi_core.schema.parameters import get_explode
from openapi_core.schema.parameters import get_style
from openapi_core.schema.protocols import SuportsGetAll
from openapi_core.schema.protocols import SuportsGetList
from openapi_core.spec import Spec
from openapi_core.templating.media_types.datatypes import MediaType
from openapi_core.templating.paths.datatypes import PathOperationServer
Expand Down Expand Up @@ -70,10 +76,14 @@ def __init__(
self.extra_format_validators = extra_format_validators
self.extra_media_type_deserializers = extra_media_type_deserializers

def _get_media_type(self, content: Spec, mimetype: str) -> MediaType:
def _find_media_type(
self, content: Spec, mimetype: Optional[str] = None
) -> MediaType:
from openapi_core.templating.media_types.finders import MediaTypeFinder

finder = MediaTypeFinder(content)
if mimetype is None:
return finder.get_first()
return finder.find(mimetype)

def _deserialise_media_type(self, mimetype: str, value: Any) -> Any:
Expand All @@ -99,69 +109,93 @@ def _validate_schema(self, schema: Spec, value: Any) -> None:
)
validator.validate(value)

def _get_param_or_header_value(
def _get_param_or_header(
self,
param_or_header: Spec,
location: Mapping[str, Any],
name: Optional[str] = None,
) -> Any:
casted, schema = self._get_param_or_header_value_and_schema(
param_or_header, location, name
# Simple scenario
if "content" not in param_or_header:
return self._get_simple_param_or_header(
param_or_header, location, name=name
)

# Complex scenario
return self._get_complex_param_or_header(
param_or_header, location, name=name
)

def _get_simple_param_or_header(
self,
param_or_header: Spec,
location: Mapping[str, Any],
name: Optional[str] = None,
) -> Any:
try:
raw = self._get_style_value(param_or_header, location, name=name)
except KeyError:
# in simple scenrios schema always exist
schema = param_or_header / "schema"
if "default" not in schema:
raise
raw = schema["default"]
return self._convert_schema_style_value(raw, param_or_header)

def _get_complex_param_or_header(
self,
param_or_header: Spec,
location: Mapping[str, Any],
name: Optional[str] = None,
) -> Any:
content = param_or_header / "content"
# no point to catch KetError
# in complex scenrios schema doesn't exist
raw = self._get_media_type_value(param_or_header, location, name=name)
return self._convert_content_schema_value(raw, content)

def _convert_schema_style_value(
self,
raw: Any,
param_or_header: Spec,
) -> Any:
casted, schema = self._convert_schema_style_value_and_schema(
raw, param_or_header
)
if schema is None:
return casted
self._validate_schema(schema, casted)
return casted

def _get_content_value(
self, raw: Any, mimetype: str, content: Spec
def _convert_content_schema_value(
self, raw: Any, content: Spec, mimetype: Optional[str] = None
) -> Any:
casted, schema = self._get_content_value_and_schema(
raw, mimetype, content
casted, schema = self._convert_content_schema_value_and_schema(
raw, content, mimetype
)
if schema is None:
return casted
self._validate_schema(schema, casted)
return casted

def _get_param_or_header_value_and_schema(
def _convert_schema_style_value_and_schema(
self,
raw: Any,
param_or_header: Spec,
location: Mapping[str, Any],
name: Optional[str] = None,
) -> Tuple[Any, Spec]:
try:
raw_value = get_value(param_or_header, location, name=name)
except KeyError:
if "schema" not in param_or_header:
raise
schema = param_or_header / "schema"
if "default" not in schema:
raise
casted = schema["default"]
else:
# Simple scenario
if "content" not in param_or_header:
deserialised = self._deserialise_style(
param_or_header, raw_value
)
schema = param_or_header / "schema"
# Complex scenario
else:
content = param_or_header / "content"
mimetype, media_type = next(content.items())
deserialised = self._deserialise_media_type(
mimetype, raw_value
)
schema = media_type / "schema"
casted = self._cast(schema, deserialised)
deserialised = self._deserialise_style(param_or_header, raw)
schema = param_or_header / "schema"
casted = self._cast(schema, deserialised)
return casted, schema

def _get_content_value_and_schema(
self, raw: Any, mimetype: str, content: Spec
def _convert_content_schema_value_and_schema(
self,
raw: Any,
content: Spec,
mimetype: Optional[str] = None,
) -> Tuple[Any, Optional[Spec]]:
media_type, mimetype = self._get_media_type(content, mimetype)
deserialised = self._deserialise_media_type(mimetype, raw)
media_type, mime_type = self._find_media_type(content, mimetype)
deserialised = self._deserialise_media_type(mime_type, raw)
casted = self._cast(media_type, deserialised)

if "schema" not in media_type:
Expand All @@ -170,6 +204,45 @@ def _get_content_value_and_schema(
schema = media_type / "schema"
return casted, schema

def _get_style_value(
self,
param_or_header: Spec,
location: Mapping[str, Any],
name: Optional[str] = None,
) -> Any:
name = name or param_or_header["name"]
style = get_style(param_or_header)
if name not in location:
# Only check if the name is not in the location if the style of
# the param is deepObject,this is because deepObjects will never be found
# as their key also includes the properties of the object already.
if style != "deepObject":
raise KeyError
keys_str = " ".join(location.keys())
if not re.search(rf"{name}\[\w+\]", keys_str):
raise KeyError

aslist = get_aslist(param_or_header)
explode = get_explode(param_or_header)
if aslist and explode:
if style == "deepObject":
return get_deep_object_value(location, name)
if isinstance(location, SuportsGetAll):
return location.getall(name)
if isinstance(location, SuportsGetList):
return location.getlist(name)

return location[name]

def _get_media_type_value(
self,
param_or_header: Spec,
location: Mapping[str, Any],
name: Optional[str] = None,
) -> Any:
name = name or param_or_header["name"]
return location[name]


class BaseAPICallValidator(BaseValidator):
@cached_property
Expand Down
Loading