Skip to content

Commit

Permalink
factor out middleware to better control imports
Browse files Browse the repository at this point in the history
  • Loading branch information
bisgaard-itis committed Sep 14, 2023
1 parent ed10118 commit 96553f6
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 62 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import json

from fastapi import FastAPI
from pyinstrument import Profiler
from starlette.requests import Request


def _generate_response_headers(content: bytes) -> list[tuple[bytes, bytes]]:
headers: dict = dict()
headers[b"content-length"] = str(len(content)).encode("utf8")
headers[b"content-type"] = b"application/json"
return list(headers.items())


class ApiServerProfilerMiddleware:
"""Following
https://www.starlette.io/middleware/#cleanup-and-error-handling
https://www.starlette.io/middleware/#reusing-starlette-components
https://fastapi.tiangolo.com/advanced/middleware/#advanced-middleware
"""

def __init__(self, app: FastAPI):
self._app: FastAPI = app
self._profile_header_trigger: str = "x-profile-api-server"

async def __call__(self, scope, receive, send):
if scope["type"] != "http":
await self._app(scope, receive, send)
return

profiler = Profiler(async_mode="enabled")
request: Request = Request(scope)
headers = dict(request.headers)
if self._profile_header_trigger in headers:
headers.pop(self._profile_header_trigger)
scope["headers"] = [
(k.encode("utf8"), v.encode("utf8")) for k, v in headers.items()
]
profiler.start()

async def send_wrapper(message):
if profiler.is_running:
profiler.stop()
if profiler.last_session:
body: bytes = json.dumps(
{"profile": profiler.output_text(unicode=True, color=True)}
).encode("utf8")
if message["type"] == "http.response.start":
message["headers"] = _generate_response_headers(body)
elif message["type"] == "http.response.body":
message["body"] = body
await send(message)

await self._app(scope, receive, send_wrapper)
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import logging

from fastapi import FastAPI
Expand All @@ -24,70 +23,15 @@
from .openapi import override_openapi_method, use_route_names_as_operation_ids
from .settings import ApplicationSettings

_settings: ApplicationSettings = ApplicationSettings.create_from_envs()

if _settings.API_SERVER_DEV_FEATURES_ENABLED:
from pyinstrument import Profiler
from starlette.requests import Request


_logger = logging.getLogger(__name__)


def _generate_response_headers(content: bytes) -> list[tuple[bytes, bytes]]:
headers: dict = dict()
headers[b"content-length"] = str(len(content)).encode("utf8")
headers[b"content-type"] = b"application/json"
return list(headers.items())


class ApiServerProfilerMiddleware:
"""Following
https://www.starlette.io/middleware/#cleanup-and-error-handling
https://www.starlette.io/middleware/#reusing-starlette-components
https://fastapi.tiangolo.com/advanced/middleware/#advanced-middleware
"""

def __init__(self, app: FastAPI):
self._app: FastAPI = app

async def __call__(self, scope, receive, send):
if scope["type"] != "http":
await self._app(scope, receive, send)
return

profiler = Profiler(async_mode="enabled")
request: Request = Request(scope)
headers = dict(request.headers)
if "x-profile-api-server" in headers:
headers.pop("x-profile-api-server")
scope["headers"] = [
(k.encode("utf8"), v.encode("utf8")) for k, v in headers.items()
]
profiler.start()

async def send_wrapper(message):
if profiler.is_running:
profiler.stop()
if profiler.last_session:
body: bytes = json.dumps(
{"profile": profiler.output_text(unicode=True, color=True)}
).encode("utf8")
if message["type"] == "http.response.start":
message["headers"] = _generate_response_headers(body)
elif message["type"] == "http.response.body":
message["body"] = body
await send(message)

await self._app(scope, receive, send_wrapper)


def _label_info_with_state(title: str, version: str):
def _label_info_with_state(settings: ApplicationSettings, title: str, version: str):
labels = []
if _settings.API_SERVER_DEV_FEATURES_ENABLED:
if settings.API_SERVER_DEV_FEATURES_ENABLED:
labels.append("dev")

if _settings.debug:
if settings.debug:
labels.append("debug")

if suffix_label := "+".join(labels):
Expand All @@ -97,8 +41,9 @@ def _label_info_with_state(title: str, version: str):
return title, version


def init_app() -> FastAPI:
settings = ApplicationSettings.create_from_envs()
def init_app(settings: ApplicationSettings | None = None) -> FastAPI:
if settings is None:
settings = ApplicationSettings.create_from_envs()
assert settings # nosec

logging.basicConfig(level=settings.log_level)
Expand All @@ -110,7 +55,7 @@ def init_app() -> FastAPI:
title = "osparc.io web API"
version = API_VERSION
description = "osparc-simcore public API specifications"
title, version = _label_info_with_state(title, version)
title, version = _label_info_with_state(settings, title, version)

# creates app instance
app = FastAPI(
Expand Down Expand Up @@ -169,6 +114,8 @@ def init_app() -> FastAPI:
),
)
if settings.API_SERVER_DEV_FEATURES_ENABLED:
from ._profiler_middleware import ApiServerProfilerMiddleware

app.add_middleware(ApiServerProfilerMiddleware)

# routing
Expand Down

0 comments on commit 96553f6

Please sign in to comment.