-
-
Notifications
You must be signed in to change notification settings - Fork 132
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
257 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
"""OpenAPI core contrib starlette handlers module""" | ||
from typing import Any | ||
from typing import Callable | ||
from typing import Dict | ||
from typing import Iterable | ||
from typing import Optional | ||
from typing import Type | ||
|
||
from starlette.requests import Request | ||
from starlette.responses import JSONResponse | ||
from starlette.responses import Response | ||
|
||
from openapi_core.templating.media_types.exceptions import MediaTypeNotFound | ||
from openapi_core.templating.paths.exceptions import OperationNotFound | ||
from openapi_core.templating.paths.exceptions import PathNotFound | ||
from openapi_core.templating.paths.exceptions import ServerNotFound | ||
from openapi_core.templating.security.exceptions import SecurityNotFound | ||
from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult | ||
|
||
|
||
class StarletteOpenAPIErrorsHandler: | ||
OPENAPI_ERROR_STATUS: Dict[Type[BaseException], int] = { | ||
ServerNotFound: 400, | ||
SecurityNotFound: 403, | ||
OperationNotFound: 405, | ||
PathNotFound: 404, | ||
MediaTypeNotFound: 415, | ||
} | ||
|
||
def __call__( | ||
self, | ||
errors: Iterable[Exception], | ||
) -> JSONResponse: | ||
data_errors = [self.format_openapi_error(err) for err in errors] | ||
data = { | ||
"errors": data_errors, | ||
} | ||
data_error_max = max(data_errors, key=self.get_error_status) | ||
return JSONResponse(data, status=data_error_max["status"]) | ||
|
||
@classmethod | ||
def format_openapi_error(cls, error: BaseException) -> Dict[str, Any]: | ||
if error.__cause__ is not None: | ||
error = error.__cause__ | ||
return { | ||
"title": str(error), | ||
"status": cls.OPENAPI_ERROR_STATUS.get(error.__class__, 400), | ||
"type": str(type(error)), | ||
} | ||
|
||
@classmethod | ||
def get_error_status(cls, error: Dict[str, Any]) -> str: | ||
return str(error["status"]) | ||
|
||
|
||
class StarletteOpenAPIValidRequestHandler: | ||
def __init__(self, request: Request, call_next: Callable[[Any], Response]): | ||
self.request = request | ||
self.call_next = call_next | ||
|
||
async def __call__( | ||
self, request_unmarshal_result: RequestUnmarshalResult | ||
) -> Response: | ||
self.request.openapi = request_unmarshal_result | ||
return await self.call_next(self.request) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
"""OpenAPI core contrib starlette middlewares module""" | ||
from typing import Callable | ||
|
||
from starlette.middleware.base import BaseHTTPMiddleware | ||
from starlette.requests import Request | ||
from starlette.responses import Response | ||
|
||
from openapi_core import Spec | ||
from openapi_core.contrib.starlette.handlers import StarletteOpenAPIErrorsHandler | ||
from openapi_core.contrib.starlette.handlers import ( | ||
StarletteOpenAPIValidRequestHandler, | ||
) | ||
from openapi_core.contrib.starlette.requests import StarletteOpenAPIRequest | ||
from openapi_core.contrib.starlette.responses import StarletteOpenAPIResponse | ||
from openapi_core.unmarshalling.processors import AsyncUnmarshallingProcessor | ||
|
||
|
||
class StarletteOpenAPIMiddleware( | ||
BaseHTTPMiddleware, | ||
AsyncUnmarshallingProcessor[Request, Response] | ||
): | ||
request_cls = StarletteOpenAPIRequest | ||
response_cls = StarletteOpenAPIResponse | ||
valid_request_handler_cls = StarletteOpenAPIValidRequestHandler | ||
errors_handler = StarletteOpenAPIErrorsHandler() | ||
|
||
def __init__(self, app, spec: Spec): | ||
BaseHTTPMiddleware.__init__(self, app) | ||
AsyncUnmarshallingProcessor.__init__(self, spec) | ||
|
||
async def dispatch(self, request: Request, call_next) -> Response: | ||
valid_request_handler = self.valid_request_handler_cls( | ||
request, call_next | ||
) | ||
response = await self.handle_request( | ||
request, valid_request_handler, self.errors_handler | ||
) | ||
return await self.handle_response(request, response, self.errors_handler) | ||
|
||
async def _get_openapi_request( | ||
self, request: Request | ||
) -> StarletteOpenAPIRequest: | ||
body = await request.body() | ||
return self.request_cls(request, body) | ||
|
||
async def _get_openapi_response( | ||
self, response: Response | ||
) -> StarletteOpenAPIResponse: | ||
assert self.response_cls is not None | ||
return self.response_cls(response) | ||
|
||
def _validate_response(self) -> bool: | ||
return self.response_cls is not None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters