diff --git a/src/ziplineio/app.py b/src/ziplineio/app.py index b4de303..b8c1e11 100644 --- a/src/ziplineio/app.py +++ b/src/ziplineio/app.py @@ -1,6 +1,8 @@ +from ast import Not from typing import Any, Callable, List, Tuple, Type +from ziplineio.exception import NotFoundHttpException from ziplineio.middleware import middleware, run_middleware_stack from ziplineio.dependency_injector import inject, injector, DependencyInjector from ziplineio import settings @@ -72,7 +74,7 @@ async def uvicorn_handler(scope: dict, receive: Any, send: Any) -> None: ) if res is None: - response = {"status": 404, "headers": [], "body": b"Not found"} + response = NotFoundHttpException() else: response = res diff --git a/src/ziplineio/exception.py b/src/ziplineio/exception.py index f4f5785..b8e8988 100644 --- a/src/ziplineio/exception.py +++ b/src/ziplineio/exception.py @@ -5,3 +5,8 @@ def __init__(self, message, status_code): def __len__(self): return 1 + + +class NotFoundHttpException(BaseHttpException): + def __init__(self, message="Not found"): + super().__init__(message, 404) diff --git a/test/test_e2e.py b/test/test_e2e.py index b222e9e..b0c8f57 100644 --- a/test/test_e2e.py +++ b/test/test_e2e.py @@ -13,11 +13,21 @@ app = App() -@app.get("/") -async def handler(req): +@app.get("/bytes") +async def bytes_handler(req): + return b"Hello, world!" + + +@app.get("/dict") +async def dict_handler(req): return {"message": "Hello, world!"} +@app.get("/str") +async def str_handler(req): + return "Hello, world!" + + # Will be made multithreaded @app.get("/sync-thread") def sync_handler(req): @@ -43,15 +53,30 @@ async def asyncSetUp(self): self.proc.start() await asyncio.sleep(0.2) # time for the server to start - async def test_basic_route(self): - response = requests.get("http://localhost:5050/") + async def test_handler_returns_bytes(self): + response = requests.get("http://localhost:5050/bytes") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b"Hello, world!") + + async def test_handler_returns_str(self): + response = requests.get("http://localhost:5050/str") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.text, "Hello, world!") + + async def test_handler_returns_dict(self): + response = requests.get("http://localhost:5050/dict") print(response.json()) self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), {"message": "Hello, world!"}) async def test_sync_route(self): response = requests.get("http://localhost:5050/sync-thread") print(response.content) self.assertEqual(response.status_code, 200) + async def test_404(self): + response = requests.get("http://localhost:5050/some-random-route") + self.assertEqual(response.status_code, 404) + async def asyncTearDown(self): self.proc.terminate()