diff --git a/openapi_core/contrib/starlette/handlers.py b/openapi_core/contrib/starlette/handlers.py new file mode 100644 index 00000000..10487d67 --- /dev/null +++ b/openapi_core/contrib/starlette/handlers.py @@ -0,0 +1,65 @@ +"""OpenAPI core contrib starlette handlers module""" +from typing import Any +from typing import Callable +from typing import Dict +from typing import Iterable +from typing import Optional +from typing import Type + +from starlette.requests import Request +from starlette.responses import JSONResponse +from starlette.responses import Response + +from openapi_core.templating.media_types.exceptions import MediaTypeNotFound +from openapi_core.templating.paths.exceptions import OperationNotFound +from openapi_core.templating.paths.exceptions import PathNotFound +from openapi_core.templating.paths.exceptions import ServerNotFound +from openapi_core.templating.security.exceptions import SecurityNotFound +from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult + + +class StarletteOpenAPIErrorsHandler: + OPENAPI_ERROR_STATUS: Dict[Type[BaseException], int] = { + ServerNotFound: 400, + SecurityNotFound: 403, + OperationNotFound: 405, + PathNotFound: 404, + MediaTypeNotFound: 415, + } + + def __call__( + self, + errors: Iterable[Exception], + ) -> JSONResponse: + data_errors = [self.format_openapi_error(err) for err in errors] + data = { + "errors": data_errors, + } + data_error_max = max(data_errors, key=self.get_error_status) + return JSONResponse(data, status=data_error_max["status"]) + + @classmethod + def format_openapi_error(cls, error: BaseException) -> Dict[str, Any]: + if error.__cause__ is not None: + error = error.__cause__ + return { + "title": str(error), + "status": cls.OPENAPI_ERROR_STATUS.get(error.__class__, 400), + "type": str(type(error)), + } + + @classmethod + def get_error_status(cls, error: Dict[str, Any]) -> str: + return str(error["status"]) + + +class StarletteOpenAPIValidRequestHandler: + def __init__(self, request: Request, call_next: Callable[[Any], Response]): + self.request = request + self.call_next = call_next + + async def __call__( + self, request_unmarshal_result: RequestUnmarshalResult + ) -> Response: + self.request.openapi = request_unmarshal_result + return await self.call_next(self.request) diff --git a/openapi_core/contrib/starlette/middlewares.py b/openapi_core/contrib/starlette/middlewares.py new file mode 100644 index 00000000..d5f3ac25 --- /dev/null +++ b/openapi_core/contrib/starlette/middlewares.py @@ -0,0 +1,53 @@ +"""OpenAPI core contrib starlette middlewares module""" +from typing import Callable + +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import Response + +from openapi_core import Spec +from openapi_core.contrib.starlette.handlers import StarletteOpenAPIErrorsHandler +from openapi_core.contrib.starlette.handlers import ( + StarletteOpenAPIValidRequestHandler, +) +from openapi_core.contrib.starlette.requests import StarletteOpenAPIRequest +from openapi_core.contrib.starlette.responses import StarletteOpenAPIResponse +from openapi_core.unmarshalling.processors import AsyncUnmarshallingProcessor + + +class StarletteOpenAPIMiddleware( + BaseHTTPMiddleware, + AsyncUnmarshallingProcessor[Request, Response] +): + request_cls = StarletteOpenAPIRequest + response_cls = StarletteOpenAPIResponse + valid_request_handler_cls = StarletteOpenAPIValidRequestHandler + errors_handler = StarletteOpenAPIErrorsHandler() + + def __init__(self, app, spec: Spec): + BaseHTTPMiddleware.__init__(self, app) + AsyncUnmarshallingProcessor.__init__(self, spec) + + async def dispatch(self, request: Request, call_next) -> Response: + valid_request_handler = self.valid_request_handler_cls( + request, call_next + ) + response = await self.handle_request( + request, valid_request_handler, self.errors_handler + ) + return await self.handle_response(request, response, self.errors_handler) + + async def _get_openapi_request( + self, request: Request + ) -> StarletteOpenAPIRequest: + body = await request.body() + return self.request_cls(request, body) + + async def _get_openapi_response( + self, response: Response + ) -> StarletteOpenAPIResponse: + assert self.response_cls is not None + return self.response_cls(response) + + def _validate_response(self) -> bool: + return self.response_cls is not None diff --git a/openapi_core/contrib/starlette/requests.py b/openapi_core/contrib/starlette/requests.py index d31886bc..f6e3e0dc 100644 --- a/openapi_core/contrib/starlette/requests.py +++ b/openapi_core/contrib/starlette/requests.py @@ -8,7 +8,7 @@ class StarletteOpenAPIRequest: - def __init__(self, request: Request): + def __init__(self, request: Request, body: str | None): if not isinstance(request, Request): raise TypeError(f"'request' argument is not type of {Request}") self.request = request @@ -19,7 +19,7 @@ def __init__(self, request: Request): cookie=self.request.cookies, ) - self._get_body = AsyncToSync(self.request.body, force_new_loop=True) + self._body = body @property def host_url(self) -> str: @@ -35,13 +35,7 @@ def method(self) -> str: @property def body(self) -> Optional[str]: - body = self._get_body() - if body is None: - return None - if isinstance(body, bytes): - return body.decode("utf-8") - assert isinstance(body, str) - return body + return self._body @property def mimetype(self) -> str: diff --git a/openapi_core/unmarshalling/processors.py b/openapi_core/unmarshalling/processors.py index fcec7c26..4316a93b 100644 --- a/openapi_core/unmarshalling/processors.py +++ b/openapi_core/unmarshalling/processors.py @@ -19,6 +19,7 @@ ResponseUnmarshallingProcessor, ) from openapi_core.unmarshalling.response.types import ResponseUnmarshallerType +from openapi_core.util import force_async class UnmarshallingProcessor(Generic[RequestType, ResponseType]): @@ -89,3 +90,79 @@ def handle_response( if response_unmarshal_result.errors: return errors_handler(response_unmarshal_result.errors) return response + + +class AsyncUnmarshallingProcessor(Generic[RequestType, ResponseType]): + def __init__( + self, + spec: Spec, + request_unmarshaller_cls: Optional[RequestUnmarshallerType] = None, + response_unmarshaller_cls: Optional[ResponseUnmarshallerType] = None, + **unmarshaller_kwargs: Any, + ): + if ( + request_unmarshaller_cls is None + or response_unmarshaller_cls is None + ): + classes = get_classes(spec) + if request_unmarshaller_cls is None: + request_unmarshaller_cls = classes.request_unmarshaller_cls + if response_unmarshaller_cls is None: + response_unmarshaller_cls = classes.response_unmarshaller_cls + + self.request_processor = RequestUnmarshallingProcessor( + spec, + request_unmarshaller_cls, + **unmarshaller_kwargs, + ) + self.response_processor = ResponseUnmarshallingProcessor( + spec, + response_unmarshaller_cls, + **unmarshaller_kwargs, + ) + + def _get_openapi_request(self, request: RequestType) -> Request: + raise NotImplementedError + + def _get_openapi_response(self, response: ResponseType) -> Response: + raise NotImplementedError + + def _validate_response(self) -> bool: + raise NotImplementedError + + async def handle_request( + self, + request: RequestType, + valid_handler: ValidRequestHandlerCallable[ResponseType], + errors_handler: ErrorsHandlerCallable[ResponseType], + ) -> ResponseType: + awaitable_get_openapi_request = force_async(self._get_openapi_request) + openapi_request = await awaitable_get_openapi_request(request) + request_unmarshal_result = self.request_processor.process( + openapi_request + ) + if request_unmarshal_result.errors: + return errors_handler(request_unmarshal_result.errors) + awaitable_valid_handler = force_async(valid_handler) + return await awaitable_valid_handler(request_unmarshal_result) + + async def handle_response( + self, + request: RequestType, + response: ResponseType, + errors_handler: ErrorsHandlerCallable[ResponseType], + ) -> ResponseType: + if not self._validate_response(): + return response + awaitable_get_openapi_request = force_async(self._get_openapi_request) + openapi_request = await awaitable_get_openapi_request(request) + awaitable_get_openapi_response = force_async( + self._get_openapi_response + ) + openapi_response = await awaitable_get_openapi_response(response) + response_unmarshal_result = self.response_processor.process( + openapi_request, openapi_response + ) + if response_unmarshal_result.errors: + return errors_handler(response_unmarshal_result.errors) + return response diff --git a/openapi_core/util.py b/openapi_core/util.py index cf551e24..ffc06cd1 100644 --- a/openapi_core/util.py +++ b/openapi_core/util.py @@ -1,4 +1,7 @@ """OpenAPI core util module""" +from asyncio import iscoroutinefunction +from functools import partial +from functools import wraps from itertools import chain from typing import Any from typing import Iterable @@ -20,3 +23,22 @@ def forcebool(val: Any) -> bool: def chainiters(*lists: Iterable[Any]) -> Iterable[Any]: iters = map(lambda l: l and iter(l) or [], lists) return chain(*iters) + + +def is_async_callable(obj: Any) -> Any: + while isinstance(obj, partial): + obj = obj.func + + return iscoroutinefunction(obj) or ( + callable(obj) and iscoroutinefunction(obj.__call__) + ) + + +def force_async(func): + async def decorated(*args, **kwargs): + if is_async_callable(func): + return await func(*args, **kwargs) + else: + return await run_in_threadpool(func, *args, **kwargs) + + return decorated diff --git a/poetry.lock b/poetry.lock index 4dfb9ae0..1d21d1bd 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1777,6 +1777,20 @@ files = [ flake8 = ">=4.0" pytest = ">=7.0" +[[package]] +name = "python-multipart" +version = "0.0.6" +description = "A streaming multipart parser for Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "python_multipart-0.0.6-py3-none-any.whl", hash = "sha256:ee698bab5ef148b0a760751c261902cd096e57e10558e11aca17646b74ee1c18"}, + {file = "python_multipart-0.0.6.tar.gz", hash = "sha256:e9925a80bb668529f1b67c7fdb0a5dacdd7cbfc6fb0bff3ea443fe22bdd62132"}, +] + +[package.extras] +dev = ["atomicwrites (==1.2.1)", "attrs (==19.2.0)", "coverage (==6.5.0)", "hatch", "invoke (==1.7.3)", "more-itertools (==4.3.0)", "pbr (==4.3.0)", "pluggy (==1.0.0)", "py (==1.11.0)", "pytest (==7.2.0)", "pytest-cov (==4.0.0)", "pytest-timeout (==2.1.0)", "pyyaml (==5.1)"] + [[package]] name = "pytz" version = "2023.3" @@ -2233,7 +2247,7 @@ test = ["pytest", "pytest-cov"] name = "starlette" version = "0.31.1" description = "The little ASGI library that shines." -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "starlette-0.31.1-py3-none-any.whl", hash = "sha256:009fb98ecd551a55017d204f033c58b13abcd4719cb5c41503abbf6d260fde11"}, @@ -2483,4 +2497,4 @@ starlette = ["starlette"] [metadata] lock-version = "2.0" python-versions = "^3.8.0" -content-hash = "80ad9a19a5925d231dfd01e7d7f5637190b54daa5c925e8431ae5cd69333ec25" +content-hash = "a7531f5fc9b00339130780ef4d6592133bb78e913e392b387f1acc7d01e1af20" diff --git a/pyproject.toml b/pyproject.toml index 7fb5530b..50fbd1d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,7 +95,9 @@ pre-commit = "*" pytest = "^7" pytest-flake8 = "*" pytest-cov = "*" +python-multipart = "*" responses = "*" +starlette = ">=0.26.1,<0.32.0" strict-rfc3339 = "^0.7" webob = "*" mypy = "^1.2" diff --git a/tests/integration/contrib/starlette/data/v3.0/starletteproject/__main__.py b/tests/integration/contrib/starlette/data/v3.0/starletteproject/__main__.py index bf1b0e7a..be5623d8 100644 --- a/tests/integration/contrib/starlette/data/v3.0/starletteproject/__main__.py +++ b/tests/integration/contrib/starlette/data/v3.0/starletteproject/__main__.py @@ -1,7 +1,20 @@ from starlette.applications import Starlette +from starlette.middleware import Middleware from starlette.routing import Route +from starletteproject.openapi import spec from starletteproject.pets.endpoints import pet_photo_endpoint +from openapi_core.contrib.starlette.middlewares import ( + StarletteOpenAPIMiddleware, +) + +middleware = [ + Middleware( + StarletteOpenAPIMiddleware, + spec=spec, + ), +] + routes = [ Route( "/v1/pets/{petId}/photo", pet_photo_endpoint, methods=["GET", "POST"] @@ -10,5 +23,6 @@ app = Starlette( debug=True, + middleware=middleware, routes=routes, ) diff --git a/tests/integration/contrib/starlette/data/v3.0/starletteproject/pets/endpoints.py b/tests/integration/contrib/starlette/data/v3.0/starletteproject/pets/endpoints.py index 535da4e5..76710805 100644 --- a/tests/integration/contrib/starlette/data/v3.0/starletteproject/pets/endpoints.py +++ b/tests/integration/contrib/starlette/data/v3.0/starletteproject/pets/endpoints.py @@ -20,15 +20,16 @@ ) -def pet_photo_endpoint(request): - openapi_request = StarletteOpenAPIRequest(request) +async def pet_photo_endpoint(request): + body = await request.body() + openapi_request = StarletteOpenAPIRequest(request, body) request_unmarshalled = unmarshal_request(openapi_request, spec=spec) if request.method == "GET": response = StreamingResponse([OPENID_LOGO], media_type="image/gif") elif request.method == "POST": - with request.form() as form: + async with request.form() as form: filename = form["file"].filename - contents = form["file"].read() + contents = await form["file"].read() response = Response(status_code=201) openapi_response = StarletteOpenAPIResponse(response) response_unmarshalled = unmarshal_response(