diff --git a/src/ziplineio/app.py b/src/ziplineio/app.py index 2eefeac..35020d8 100644 --- a/src/ziplineio/app.py +++ b/src/ziplineio/app.py @@ -9,7 +9,7 @@ from ziplineio import settings from ziplineio.handler import Handler from ziplineio.request import Request -from ziplineio.response import NotFoundResponse, format_response +from ziplineio.response import Response, NotFoundResponse, format_response from ziplineio.router import Router from ziplineio.static import staticfiles from ziplineio.utils import call_handler, parse_scope @@ -103,8 +103,9 @@ async def _get_and_call_handler( return res if self._router._not_found_handler: - body = await call_handler(self._router._not_found_handler, req) - return NotFoundResponse(body) + response = await call_handler(self._router._not_found_handler, req) + headers = isinstance(response, Response) and response._headers or {} + return NotFoundResponse(response, headers) return NotFoundHttpException() @@ -131,42 +132,3 @@ async def uvicorn_handler(scope: dict, receive: Any, send: Any) -> None: ) return uvicorn_handler - - -app = App() - - -async def logging_middleware(req: Request) -> Tuple[Request, dict]: - print(f"Received request: {req.method} {req.path}") - # Pass along request and an empty context (no modifications) - return req, {} - - -async def auth_middleware(req: Request) -> Tuple[Request, dict]: - # Modify the context with some auth information - return req, {"auth": "Unauthorized"} - - -class LoggingService: - def log(self, message: str) -> None: - print(f"Logging: {message}") - - -@app.get("/") -@inject(LoggingService) -@middleware([logging_middleware, auth_middleware]) -async def handler(req: Request, ctx: dict, LoggingService: LoggingService) -> dict: - res = {"message": "Hello, world!"} - - # Use the injected LoggingService - LoggingService.log(f"Params: {req.query_params.get('bar')}") - - return {**res, **ctx} - - -@app.get("/foo/:bar") -async def handler2(req: Request) -> dict: - bar = req.path_params.get("bar") - res = {"message": f"Hello, {bar}!"} - - return res diff --git a/src/ziplineio/response.py b/src/ziplineio/response.py index 73bc786..073080e 100644 --- a/src/ziplineio/response.py +++ b/src/ziplineio/response.py @@ -55,8 +55,8 @@ def __init__(self, body: str): class NotFoundResponse(Response): - def __init__(self, body: str): - super().__init__(404, {}, body) + def __init__(self, body: str, headers: Dict[str, str] = {}): + super().__init__(404, headers, body) def format_headers(headers: Dict[str, str] | None) -> List[Tuple[bytes, bytes]]: diff --git a/test/test_e2e.py b/test/test_e2e.py index 9a2ea9a..b6ef4fb 100644 --- a/test/test_e2e.py +++ b/test/test_e2e.py @@ -111,6 +111,7 @@ async def test_404_jinja(self): response = requests.get("http://localhost:5050/some-random-route") self.assertEqual(response.status_code, 404) self.assertTrue("404" in response.text) + self.assertEqual(response.headers["Content-Type"], "text/html") async def asyncTearDown(self): self.proc.terminate()