diff --git a/src/ziplineio/response.py b/src/ziplineio/response.py index 073080e..8f04829 100644 --- a/src/ziplineio/response.py +++ b/src/ziplineio/response.py @@ -80,54 +80,42 @@ def format_body(body: bytes | str | dict) -> bytes: def format_response( response: bytes | dict | str | Response | Exception, default_headers: dict[str, str] ) -> RawResponse: + # Helper to merge and deduplicate headers + + # Format different response types if isinstance(response, bytes): - raw_response = { - "headers": [ - (b"content-type", b"text/plain"), - ], - "status": 200, - "body": response, - } + headers = [(b"content-type", b"text/plain")] + body = response + status = 200 + elif isinstance(response, str): - raw_response = { - "headers": [ - (b"content-type", b"text/plain"), - ], - "status": 200, - "body": bytes(response, "utf-8"), - } + headers = [(b"content-type", b"text/plain")] + body = response.encode("utf-8") + status = 200 + elif isinstance(response, dict): - raw_response = { - "headers": [ - (b"content-type", b"application/json"), - ], - "status": 200, - "body": bytes(json.dumps(response), "utf-8"), - } + headers = [(b"content-type", b"application/json")] + body = json.dumps(response).encode("utf-8") + status = 200 + elif isinstance(response, Response): - raw_response = { - "headers": format_headers(response.get_headers()), - "status": response.status, - "body": format_body(response.body), - } + headers = format_headers(response.get_headers()) + body = format_body(response.body) + status = response.status + elif isinstance(response, BaseHttpException): - raw_response = { - "headers": [ - (b"content-type", b"application/json"), - ], - "status": response.status_code, - "body": format_body(response.message), - } + headers = [(b"content-type", b"application/json")] + body = format_body(response.message) + status = response.status_code + elif isinstance(response, Exception): - raw_response = { - "headers": [ - (b"content-type", b"text/plain"), - ], - "status": 500, - "body": b"Internal server error: " + bytes(str(response), "utf-8"), - } + headers = [(b"content-type", b"text/plain")] + body = f"Internal server error: {str(response)}".encode("utf-8") + status = 500 + else: raise ValueError("Invalid response type") - raw_response["headers"].extend(format_headers(default_headers)) - return raw_response + default_headers = format_headers(default_headers) + + return {"headers": default_headers + headers, "status": status, "body": body} diff --git a/src/ziplineio/utils.py b/src/ziplineio/utils.py index edc9a2b..27ebf85 100644 --- a/src/ziplineio/utils.py +++ b/src/ziplineio/utils.py @@ -11,14 +11,6 @@ from ziplineio.models import ASGIScope -def get_func_params(func: Handler): - return inspect.signature(func).parameters - - -def get_class_params(cls): - return inspect.signature(cls.__init__).parameters - - """ Only pass the kwargs that are required by the handler function. """ diff --git a/test/test_e2e.py b/test/test_e2e.py index 24115f5..4f2381f 100644 --- a/test/test_e2e.py +++ b/test/test_e2e.py @@ -96,7 +96,7 @@ async def check_server(): return True except httpx.RequestError: pass - await asyncio.sleep(0.2) # Short sleep between retries + await asyncio.sleep(0.1) # Short sleep between retries try: await asyncio.wait_for(check_server(), timeout=timeout) diff --git a/test/test_response.py b/test/test_response.py index 9214a11..5bfd3c1 100644 --- a/test/test_response.py +++ b/test/test_response.py @@ -1,8 +1,14 @@ import unittest from ziplineio.app import App +from ziplineio.exception import BaseHttpException from ziplineio.request import Request -from ziplineio.response import StaticFileResponse +from ziplineio.response import ( + Response, + StaticFileResponse, + format_body, + format_response, +) class TestResponse(unittest.IsolatedAsyncioTestCase): @@ -23,3 +29,81 @@ async def test_static_file(self): self.assertEqual(r.status, 200) self.assertEqual(r.get_headers()["Content-Type"], "text/css") self.assertTrue(b"background-color: #f0f0f0;" in r.body) + + +class TestFormatBody(unittest.TestCase): + def test_format_body_dict(self): + body = format_body({"message": "Hello, world!"}) + self.assertEqual(body, b'{"message": "Hello, world!"}') + + def test_format_body_bytes(self): + body = format_body(b"Hello, world!") + self.assertEqual(body, b"Hello, world!") + + def test_format_body_str(self): + body = format_body("Hello, world!") + self.assertEqual(body, b"Hello, world!") + + +class TestFormatResponse(unittest.TestCase): + def test_format_bytes(self): + response = b"Hello, world!" + headers = {} + formatted = format_response(response, headers) + + self.assertEqual(formatted["status"], 200) + self.assertEqual(formatted["headers"], [(b"content-type", b"text/plain")]) + self.assertEqual(formatted["body"], response) + + def test_format_str(self): + response = "Hello, world!" + headers = {} + formatted = format_response(response, headers) + + self.assertEqual(formatted["status"], 200) + self.assertEqual(formatted["headers"], [(b"content-type", b"text/plain")]) + self.assertEqual(formatted["body"], b"Hello, world!") + + def test_format_dict(self): + response = {"message": "Hello, world!"} + headers = {} + formatted = format_response(response, headers) + + self.assertEqual(formatted["status"], 200) + self.assertEqual(formatted["headers"], [(b"content-type", b"application/json")]) + self.assertEqual(formatted["body"], b'{"message": "Hello, world!"}') + + def test_format_response_object(self): + response = Response(200, {"Content-Type": "CUSTOM-TYPE"}, "Hello, world!") + + formatted = format_response(response, {}) + self.assertEqual(formatted["status"], 200) + self.assertEqual(formatted["headers"], [(b"Content-Type", b"CUSTOM-TYPE")]) + self.assertEqual(formatted["body"], b"Hello, world!") + + def test_format_exception(self): + response = Exception("Hello, world!") + headers = {} + formatted = format_response(response, headers) + + self.assertEqual(formatted["status"], 500) + self.assertEqual(formatted["headers"], [(b"content-type", b"text/plain")]) + self.assertEqual(formatted["body"], b"Internal server error: Hello, world!") + + def test_format_http_exception(self): + response = BaseHttpException("Hello, world!", 444) + headers = {} + formatted = format_response(response, headers) + + self.assertEqual(formatted["status"], 444) + self.assertEqual(formatted["headers"], [(b"content-type", b"application/json")]) + self.assertEqual(formatted["body"], b"Hello, world!") + + def test_format_bytes_with_headers(self): + response = b"Hello, world!" + headers = {} + formatted = format_response(response, headers) + + self.assertEqual(formatted["status"], 200) + self.assertEqual(formatted["headers"], [(b"content-type", b"text/plain")]) + self.assertEqual(formatted["body"], response) diff --git a/test/test_utils.py b/test/test_utils.py index 8561adc..d784d51 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -17,3 +17,19 @@ def test_match_url_pattern(self): # Assertions self.assertEqual(parsed, {"id": "123"}) + + def test_parse_scope(self): + scope = { + "query_string": b"", + "path": "/user/12", + "headers": [(b"host", b"localhost")], + "method": "GET", + } + parsed = utils.parse_scope(scope) + + # Assertions + self.assertEqual(parsed.method, "GET") + self.assertEqual(parsed.path, "/user/12") + self.assertEqual(parsed.query_params, {}) + self.assertEqual(parsed.path_params, {}) + self.assertEqual(parsed.headers, {"host": "localhost"})