From 53cc0b644528ddbe2e9db3d09ac271ac329546f1 Mon Sep 17 00:00:00 2001 From: Cam Sima Date: Tue, 27 Aug 2024 19:17:31 -0400 Subject: [PATCH] bugfix: middelware --- src/ziplineio/app.py | 18 ++++++++++-------- src/ziplineio/exception.py | 3 +++ src/ziplineio/html/jinja.py | 2 +- src/ziplineio/middleware.py | 30 ++++++++++++------------------ src/ziplineio/response.py | 15 ++++++++++++--- src/ziplineio/utils.py | 27 +++++++++++++++------------ test/test_exceptions.py | 4 ++++ test/test_html.py | 3 +++ test/test_middleware.py | 13 +++++++++++-- 9 files changed, 71 insertions(+), 44 deletions(-) diff --git a/src/ziplineio/app.py b/src/ziplineio/app.py index b02276f..b4de303 100644 --- a/src/ziplineio/app.py +++ b/src/ziplineio/app.py @@ -68,32 +68,34 @@ async def uvicorn_handler(scope: dict, receive: Any, send: Any) -> None: # try running through middlewares req, ctx, res = await run_middleware_stack( - self._router._router_level_middelwares, req, {} + self._router._router_level_middelwares, req, **{} ) - if res is not None: - response = format_response(res, settings.DEFAULT_HEADERS) - else: + if res is None: response = {"status": 404, "headers": [], "body": b"Not found"} + else: + response = res else: response = await call_handler(handler, req) + raw_response = format_response(response, settings.DEFAULT_HEADERS) + print("SENDING RESPONSE") - print(response) + print(raw_response) await send( { "type": "http.response.start", - "status": response["status"], - "headers": response["headers"], + "status": raw_response["status"], + "headers": raw_response["headers"], } ) await send( { "type": "http.response.body", - "body": response["body"], + "body": raw_response["body"], } ) diff --git a/src/ziplineio/exception.py b/src/ziplineio/exception.py index 17312dc..f4f5785 100644 --- a/src/ziplineio/exception.py +++ b/src/ziplineio/exception.py @@ -2,3 +2,6 @@ class BaseHttpException(Exception): def __init__(self, message, status_code): self.message = message self.status_code = status_code + + def __len__(self): + return 1 diff --git a/src/ziplineio/html/jinja.py b/src/ziplineio/html/jinja.py index de46e81..9daa08a 100644 --- a/src/ziplineio/html/jinja.py +++ b/src/ziplineio/html/jinja.py @@ -9,7 +9,7 @@ def jinja(env: Any, template_name: str): def decorator(handler): async def wrapped_handler(*args, **kwargs): - context = await call_handler(handler, *args, **kwargs, format=False) + context = await call_handler(handler, *args, **kwargs) rendered = template.render(context) return JinjaResponse(rendered) diff --git a/src/ziplineio/middleware.py b/src/ziplineio/middleware.py index 8003ce3..595bbc6 100644 --- a/src/ziplineio/middleware.py +++ b/src/ziplineio/middleware.py @@ -1,9 +1,10 @@ -import re +import inspect from typing import List, Callable, Tuple from ziplineio import response from ziplineio.handler import Handler from ziplineio.request import Request from ziplineio.response import Response +from ziplineio.utils import call_handler def middleware(middlewares: List[Callable]) -> Callable[[Callable], Callable]: @@ -13,14 +14,8 @@ async def wrapped_handler(req: Request, **kwargs): kwargs.setdefault("ctx", {}) # Run the middleware stack - try: - req, kwargs, res = await run_middleware_stack(middlewares, req, kwargs) - except Exception as e: - return { - "status": 500, - "headers": [], - "body": b"Internal server error: " + str(e).encode(), - } + + req, kwargs, res = await run_middleware_stack(middlewares, req, **kwargs) if res is not None: return res @@ -33,7 +28,7 @@ async def wrapped_handler(req: Request, **kwargs): async def run_middleware_stack( - middlewares: list[Handler], request: Request, kwargs + middlewares: list[Handler], request: Request, **kwargs ) -> Tuple[Request, dict, bytes | str | dict | Response | None]: for middleware in middlewares: # if the middleware func takes params, pass them in. Otherwise, just pass req @@ -41,16 +36,15 @@ async def run_middleware_stack( if "ctx" not in kwargs: kwargs["ctx"] = {} - print("KWARGS") - print(kwargs) - print(middleware.__code__.co_varnames) - if len(middleware.__code__.co_varnames) > 1: - _res = await middleware(request, **kwargs) + sig = inspect.signature(middleware) + if len(sig.parameters) > 1: + _res = await call_handler(middleware, request, **kwargs) else: - _res = await middleware(request) + _res = await call_handler(middleware, request) - print("RES") - print(_res) + # regular handlers return a response, but middleware can return a tuple + if not isinstance(_res, tuple): + _res = (_res, kwargs) if len(_res) != 2: req = _res diff --git a/src/ziplineio/response.py b/src/ziplineio/response.py index dcc2ac7..c46d702 100644 --- a/src/ziplineio/response.py +++ b/src/ziplineio/response.py @@ -47,6 +47,9 @@ def format_headers(headers: Dict[str, str] | None) -> List[Tuple[bytes, bytes]]: def format_body(body: bytes | str | dict) -> bytes: + print("BODY:") + print(body) + print(type(body)) if isinstance(body, bytes): return body elif isinstance(body, str): @@ -57,7 +60,7 @@ def format_body(body: bytes | str | dict) -> bytes: def format_response( - response: bytes | dict | str | Response, default_headers: dict[str, str] + response: bytes | dict | str | Response | Exception, default_headers: dict[str, str] ) -> RawResponse: print("TYUUPE:") print(type(response)) @@ -86,7 +89,6 @@ def format_response( "body": bytes(json.dumps(response), "utf-8"), } elif isinstance(response, Response): - print("HRE") raw_response = { "headers": format_headers(response.headers), "status": response.status, @@ -100,8 +102,15 @@ def format_response( "status": response.status_code, "body": format_body(response.message), } + elif isinstance(response, Exception): + raw_response = { + "headers": [ + (b"content-type", b"text/plain"), + ], + "status": 500, + "body": b"Internal server error: " + bytes(str(response), "utf-8"), + } else: - print(response) raise ValueError("Invalid response type") raw_response["headers"].extend(format_headers(default_headers)) diff --git a/src/ziplineio/utils.py b/src/ziplineio/utils.py index bfd6bdd..609616c 100644 --- a/src/ziplineio/utils.py +++ b/src/ziplineio/utils.py @@ -3,7 +3,7 @@ import asyncio from typing import Dict from ziplineio.request import Request -from ziplineio.response import RawResponse, format_response +from ziplineio.response import RawResponse, Response, format_response from ziplineio.handler import Handler from ziplineio.models import ASGIScope from ziplineio.exception import BaseHttpException @@ -11,24 +11,27 @@ async def call_handler( - handler: Handler, req: Request, format: bool = True -) -> RawResponse: + handler: Handler, + req: Request, + **kwargs, +) -> bytes | str | dict | Response | Exception: try: if not inspect.iscoroutinefunction(handler): - response = await asyncio.to_thread(handler, req) + response = await asyncio.to_thread(handler, req, **kwargs) else: - response = await handler(req) + response = await handler(req, **kwargs) - except BaseHttpException as e: - response = e except Exception as e: + response = e + print("EXCEPTION") print(e) - response = BaseHttpException(e, 500) + # except Exception as e: + # print(e) + # response = BaseHttpException(e, 500) - if format: - return format_response(response, settings.DEFAULT_HEADERS) - else: - return response + print(f"response in call handler: {response}") + + return response def parse_scope(scope: ASGIScope) -> Request: diff --git a/test/test_exceptions.py b/test/test_exceptions.py index 7a65be1..5719914 100644 --- a/test/test_exceptions.py +++ b/test/test_exceptions.py @@ -1,7 +1,9 @@ import unittest +from ziplineio import settings from ziplineio.app import App from ziplineio.exception import BaseHttpException +from ziplineio.response import format_response from ziplineio.utils import call_handler @@ -18,6 +20,7 @@ async def test_handler(req): # Call the handler response = await call_handler(test_handler, {}) + response = format_response(response, settings.DEFAULT_HEADERS) self.assertEqual(response["body"], b"Hey! We messed up") self.assertEqual(response["status"], 409) @@ -31,5 +34,6 @@ async def test_handler(req): # Call the handler response = await call_handler(test_handler, {}) + response = format_response(response, settings.DEFAULT_HEADERS) self.assertEqual(response["body"], b"Hey! We messed up bad") self.assertEqual(response["status"], 402) diff --git a/test/test_html.py b/test/test_html.py index 104dd42..c578648 100644 --- a/test/test_html.py +++ b/test/test_html.py @@ -34,6 +34,9 @@ async def asyncSetUp(self): async def test_render_jinja(self): response = requests.get("http://localhost:5050/") + print("Resooinse") + print(response.text) + self.assertEqual(response.status_code, 200) self.assertEqual(response.headers["Content-Type"], "text/html") self.assertTrue("

Welcome to the home page!

" in response.text) diff --git a/test/test_middleware.py b/test/test_middleware.py index af7197c..dbaee19 100644 --- a/test/test_middleware.py +++ b/test/test_middleware.py @@ -1,7 +1,11 @@ +from operator import call import unittest +from ziplineio import settings from ziplineio.request import Request from ziplineio.app import App from ziplineio.middleware import middleware +from ziplineio.response import format_response +from ziplineio.utils import call_handler class TestMiddleware(unittest.IsolatedAsyncioTestCase): @@ -79,7 +83,10 @@ async def test_handler_with_middleware(req: Request, ctx: dict): # Call the route handler, params = self.app._router.get_handler("GET", "/with-middleware") - response = await handler(req) + response = await call_handler(handler, req, format=True) + response = format_response(response, settings.DEFAULT_HEADERS) + + print(f"response: {response}") # Assertions self.assertEqual(response["status"], 500) @@ -110,7 +117,9 @@ async def test_handler_with_middleware(req: Request, ctx: dict): # Call the route handler, params = self.app._router.get_handler("GET", "/with-middleware") - response = await handler(req) + response = await call_handler(handler, req, format=False) + + print(response) # Assertions self.assertEqual(response["message"], "Hi from middleware 2")