From a333329ac6a2ee824632b14a6cde09c9159a088f Mon Sep 17 00:00:00 2001 From: tarsil Date: Mon, 21 Aug 2023 09:58:32 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=AA=B2=20Fix=20regression=20in=20OpenAPI?= =?UTF-8?q?=20with=20middleware?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- esmerald/openapi/openapi.py | 15 ++++++++++++++- esmerald/routing/router.py | 2 +- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/esmerald/openapi/openapi.py b/esmerald/openapi/openapi.py index a304a8a7..aef74681 100644 --- a/esmerald/openapi/openapi.py +++ b/esmerald/openapi/openapi.py @@ -6,6 +6,7 @@ from pydantic import AnyUrl from pydantic.fields import FieldInfo from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue +from starlette.middleware import Middleware from starlette.routing import BaseRoute from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY from typing_extensions import Literal @@ -335,6 +336,9 @@ def should_include_in_schema(route: router.Include) -> bool: if not route.include_in_schema: return False + if not is_middleware_app(route): + return True + if ( isinstance(route.app, (Esmerald, ChildEsmerald)) or ( @@ -355,6 +359,15 @@ def should_include_in_schema(route: router.Include) -> bool: return True +def is_middleware_app(route: router.Include) -> bool: + """ + Checks if the app is a middleware or a router + """ + from esmerald import MiddlewareProtocol + + return bool(isinstance(route.app, (Middleware, MiddlewareProtocol))) + + def get_openapi( *, app: Any, @@ -427,7 +440,7 @@ def iterate_routes( continue # For external middlewares - if getattr(route.app, "routes", None) is None: + if getattr(route.app, "routes", None) is None and not is_middleware_app(route): continue if hasattr(route, "app") and isinstance(route.app, (Esmerald, ChildEsmerald)): diff --git a/esmerald/routing/router.py b/esmerald/routing/router.py index c09d42d2..037b9fb3 100644 --- a/esmerald/routing/router.py +++ b/esmerald/routing/router.py @@ -1075,7 +1075,7 @@ def __init__( super().__init__( self.path, - app=self.app, + app=app, routes=routes, name=name, middleware=cast("Sequence[StarletteMiddleware]", include_middleware),