Skip to content

Commit

Permalink
Starlette middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
p1c2u committed Sep 26, 2023
1 parent 1eff5a2 commit 41d8f03
Show file tree
Hide file tree
Showing 9 changed files with 257 additions and 15 deletions.
65 changes: 65 additions & 0 deletions openapi_core/contrib/starlette/handlers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""OpenAPI core contrib starlette handlers module"""
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Optional
from typing import Type

from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.responses import Response

from openapi_core.templating.media_types.exceptions import MediaTypeNotFound
from openapi_core.templating.paths.exceptions import OperationNotFound
from openapi_core.templating.paths.exceptions import PathNotFound
from openapi_core.templating.paths.exceptions import ServerNotFound
from openapi_core.templating.security.exceptions import SecurityNotFound
from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult


class StarletteOpenAPIErrorsHandler:
OPENAPI_ERROR_STATUS: Dict[Type[BaseException], int] = {
ServerNotFound: 400,
SecurityNotFound: 403,
OperationNotFound: 405,
PathNotFound: 404,
MediaTypeNotFound: 415,
}

def __call__(
self,
errors: Iterable[Exception],
) -> JSONResponse:
data_errors = [self.format_openapi_error(err) for err in errors]
data = {
"errors": data_errors,
}
data_error_max = max(data_errors, key=self.get_error_status)
return JSONResponse(data, status=data_error_max["status"])

@classmethod
def format_openapi_error(cls, error: BaseException) -> Dict[str, Any]:
if error.__cause__ is not None:
error = error.__cause__
return {
"title": str(error),
"status": cls.OPENAPI_ERROR_STATUS.get(error.__class__, 400),
"type": str(type(error)),
}

@classmethod
def get_error_status(cls, error: Dict[str, Any]) -> str:
return str(error["status"])


class StarletteOpenAPIValidRequestHandler:
def __init__(self, request: Request, call_next: Callable[[Any], Response]):
self.request = request
self.call_next = call_next

async def __call__(
self, request_unmarshal_result: RequestUnmarshalResult
) -> Response:
self.request.openapi = request_unmarshal_result
return await self.call_next(self.request)
53 changes: 53 additions & 0 deletions openapi_core/contrib/starlette/middlewares.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""OpenAPI core contrib starlette middlewares module"""
from typing import Callable

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response

from openapi_core import Spec
from openapi_core.contrib.starlette.handlers import StarletteOpenAPIErrorsHandler
from openapi_core.contrib.starlette.handlers import (
StarletteOpenAPIValidRequestHandler,
)
from openapi_core.contrib.starlette.requests import StarletteOpenAPIRequest
from openapi_core.contrib.starlette.responses import StarletteOpenAPIResponse
from openapi_core.unmarshalling.processors import AsyncUnmarshallingProcessor


class StarletteOpenAPIMiddleware(
BaseHTTPMiddleware,
AsyncUnmarshallingProcessor[Request, Response]
):
request_cls = StarletteOpenAPIRequest
response_cls = StarletteOpenAPIResponse
valid_request_handler_cls = StarletteOpenAPIValidRequestHandler
errors_handler = StarletteOpenAPIErrorsHandler()

def __init__(self, app, spec: Spec):
BaseHTTPMiddleware.__init__(self, app)
AsyncUnmarshallingProcessor.__init__(self, spec)

async def dispatch(self, request: Request, call_next) -> Response:
valid_request_handler = self.valid_request_handler_cls(
request, call_next
)
response = await self.handle_request(
request, valid_request_handler, self.errors_handler
)
return await self.handle_response(request, response, self.errors_handler)

async def _get_openapi_request(
self, request: Request
) -> StarletteOpenAPIRequest:
body = await request.body()
return self.request_cls(request, body)

async def _get_openapi_response(
self, response: Response
) -> StarletteOpenAPIResponse:
assert self.response_cls is not None
return self.response_cls(response)

def _validate_response(self) -> bool:
return self.response_cls is not None
12 changes: 3 additions & 9 deletions openapi_core/contrib/starlette/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


class StarletteOpenAPIRequest:
def __init__(self, request: Request):
def __init__(self, request: Request, body: str | None):
if not isinstance(request, Request):
raise TypeError(f"'request' argument is not type of {Request}")
self.request = request
Expand All @@ -19,7 +19,7 @@ def __init__(self, request: Request):
cookie=self.request.cookies,
)

self._get_body = AsyncToSync(self.request.body, force_new_loop=True)
self._body = body

@property
def host_url(self) -> str:
Expand All @@ -35,13 +35,7 @@ def method(self) -> str:

@property
def body(self) -> Optional[str]:
body = self._get_body()
if body is None:
return None
if isinstance(body, bytes):
return body.decode("utf-8")
assert isinstance(body, str)
return body
return self._body

@property
def mimetype(self) -> str:
Expand Down
77 changes: 77 additions & 0 deletions openapi_core/unmarshalling/processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
ResponseUnmarshallingProcessor,
)
from openapi_core.unmarshalling.response.types import ResponseUnmarshallerType
from openapi_core.util import force_async


class UnmarshallingProcessor(Generic[RequestType, ResponseType]):
Expand Down Expand Up @@ -89,3 +90,79 @@ def handle_response(
if response_unmarshal_result.errors:
return errors_handler(response_unmarshal_result.errors)
return response


class AsyncUnmarshallingProcessor(Generic[RequestType, ResponseType]):
def __init__(
self,
spec: Spec,
request_unmarshaller_cls: Optional[RequestUnmarshallerType] = None,
response_unmarshaller_cls: Optional[ResponseUnmarshallerType] = None,
**unmarshaller_kwargs: Any,
):
if (
request_unmarshaller_cls is None
or response_unmarshaller_cls is None
):
classes = get_classes(spec)
if request_unmarshaller_cls is None:
request_unmarshaller_cls = classes.request_unmarshaller_cls
if response_unmarshaller_cls is None:
response_unmarshaller_cls = classes.response_unmarshaller_cls

self.request_processor = RequestUnmarshallingProcessor(
spec,
request_unmarshaller_cls,
**unmarshaller_kwargs,
)
self.response_processor = ResponseUnmarshallingProcessor(
spec,
response_unmarshaller_cls,
**unmarshaller_kwargs,
)

def _get_openapi_request(self, request: RequestType) -> Request:
raise NotImplementedError

def _get_openapi_response(self, response: ResponseType) -> Response:
raise NotImplementedError

def _validate_response(self) -> bool:
raise NotImplementedError

async def handle_request(
self,
request: RequestType,
valid_handler: ValidRequestHandlerCallable[ResponseType],
errors_handler: ErrorsHandlerCallable[ResponseType],
) -> ResponseType:
awaitable_get_openapi_request = force_async(self._get_openapi_request)
openapi_request = await awaitable_get_openapi_request(request)
request_unmarshal_result = self.request_processor.process(
openapi_request
)
if request_unmarshal_result.errors:
return errors_handler(request_unmarshal_result.errors)
awaitable_valid_handler = force_async(valid_handler)
return await awaitable_valid_handler(request_unmarshal_result)

async def handle_response(
self,
request: RequestType,
response: ResponseType,
errors_handler: ErrorsHandlerCallable[ResponseType],
) -> ResponseType:
if not self._validate_response():
return response
awaitable_get_openapi_request = force_async(self._get_openapi_request)
openapi_request = await awaitable_get_openapi_request(request)
awaitable_get_openapi_response = force_async(
self._get_openapi_response
)
openapi_response = await awaitable_get_openapi_response(response)
response_unmarshal_result = self.response_processor.process(
openapi_request, openapi_response
)
if response_unmarshal_result.errors:
return errors_handler(response_unmarshal_result.errors)
return response
22 changes: 22 additions & 0 deletions openapi_core/util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
"""OpenAPI core util module"""
from asyncio import iscoroutinefunction
from functools import partial
from functools import wraps
from itertools import chain
from typing import Any
from typing import Iterable
Expand All @@ -20,3 +23,22 @@ def forcebool(val: Any) -> bool:
def chainiters(*lists: Iterable[Any]) -> Iterable[Any]:
iters = map(lambda l: l and iter(l) or [], lists)
return chain(*iters)


def is_async_callable(obj: Any) -> Any:
while isinstance(obj, partial):
obj = obj.func

return iscoroutinefunction(obj) or (
callable(obj) and iscoroutinefunction(obj.__call__)
)


def force_async(func):
async def decorated(*args, **kwargs):
if is_async_callable(func):
return await func(*args, **kwargs)
else:
return await run_in_threadpool(func, *args, **kwargs)

return decorated
18 changes: 16 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ pre-commit = "*"
pytest = "^7"
pytest-flake8 = "*"
pytest-cov = "*"
python-multipart = "*"
responses = "*"
starlette = ">=0.26.1,<0.32.0"
strict-rfc3339 = "^0.7"
webob = "*"
mypy = "^1.2"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,20 @@
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.routing import Route
from starletteproject.openapi import spec
from starletteproject.pets.endpoints import pet_photo_endpoint

from openapi_core.contrib.starlette.middlewares import (
StarletteOpenAPIMiddleware,
)

middleware = [
Middleware(
StarletteOpenAPIMiddleware,
spec=spec,
),
]

routes = [
Route(
"/v1/pets/{petId}/photo", pet_photo_endpoint, methods=["GET", "POST"]
Expand All @@ -10,5 +23,6 @@

app = Starlette(
debug=True,
middleware=middleware,
routes=routes,
)
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@
)


def pet_photo_endpoint(request):
openapi_request = StarletteOpenAPIRequest(request)
async def pet_photo_endpoint(request):
body = await request.body()
openapi_request = StarletteOpenAPIRequest(request, body)
request_unmarshalled = unmarshal_request(openapi_request, spec=spec)
if request.method == "GET":
response = StreamingResponse([OPENID_LOGO], media_type="image/gif")
elif request.method == "POST":
with request.form() as form:
async with request.form() as form:
filename = form["file"].filename
contents = form["file"].read()
contents = await form["file"].read()
response = Response(status_code=201)
openapi_response = StarletteOpenAPIResponse(response)
response_unmarshalled = unmarshal_response(
Expand Down

0 comments on commit 41d8f03

Please sign in to comment.