Skip to content

Commit

Permalink
bugfix: middelware
Browse files Browse the repository at this point in the history
  • Loading branch information
CameronSima committed Aug 27, 2024
1 parent b1259e4 commit 53cc0b6
Show file tree
Hide file tree
Showing 9 changed files with 71 additions and 44 deletions.
18 changes: 10 additions & 8 deletions src/ziplineio/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,32 +68,34 @@ async def uvicorn_handler(scope: dict, receive: Any, send: Any) -> None:
# try running through middlewares

req, ctx, res = await run_middleware_stack(
self._router._router_level_middelwares, req, {}
self._router._router_level_middelwares, req, **{}
)

if res is not None:
response = format_response(res, settings.DEFAULT_HEADERS)
else:
if res is None:
response = {"status": 404, "headers": [], "body": b"Not found"}
else:
response = res

else:
response = await call_handler(handler, req)

raw_response = format_response(response, settings.DEFAULT_HEADERS)

print("SENDING RESPONSE")
print(response)
print(raw_response)

await send(
{
"type": "http.response.start",
"status": response["status"],
"headers": response["headers"],
"status": raw_response["status"],
"headers": raw_response["headers"],
}
)

await send(
{
"type": "http.response.body",
"body": response["body"],
"body": raw_response["body"],
}
)

Expand Down
3 changes: 3 additions & 0 deletions src/ziplineio/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,6 @@ class BaseHttpException(Exception):
def __init__(self, message, status_code):
self.message = message
self.status_code = status_code

def __len__(self):
return 1
2 changes: 1 addition & 1 deletion src/ziplineio/html/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def jinja(env: Any, template_name: str):

def decorator(handler):
async def wrapped_handler(*args, **kwargs):
context = await call_handler(handler, *args, **kwargs, format=False)
context = await call_handler(handler, *args, **kwargs)
rendered = template.render(context)
return JinjaResponse(rendered)

Expand Down
30 changes: 12 additions & 18 deletions src/ziplineio/middleware.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import re
import inspect
from typing import List, Callable, Tuple
from ziplineio import response
from ziplineio.handler import Handler
from ziplineio.request import Request
from ziplineio.response import Response
from ziplineio.utils import call_handler


def middleware(middlewares: List[Callable]) -> Callable[[Callable], Callable]:
Expand All @@ -13,14 +14,8 @@ async def wrapped_handler(req: Request, **kwargs):

kwargs.setdefault("ctx", {})
# Run the middleware stack
try:
req, kwargs, res = await run_middleware_stack(middlewares, req, kwargs)
except Exception as e:
return {
"status": 500,
"headers": [],
"body": b"Internal server error: " + str(e).encode(),
}

req, kwargs, res = await run_middleware_stack(middlewares, req, **kwargs)

if res is not None:
return res
Expand All @@ -33,24 +28,23 @@ async def wrapped_handler(req: Request, **kwargs):


async def run_middleware_stack(
middlewares: list[Handler], request: Request, kwargs
middlewares: list[Handler], request: Request, **kwargs
) -> Tuple[Request, dict, bytes | str | dict | Response | None]:
for middleware in middlewares:
# if the middleware func takes params, pass them in. Otherwise, just pass req

if "ctx" not in kwargs:
kwargs["ctx"] = {}

print("KWARGS")
print(kwargs)
print(middleware.__code__.co_varnames)
if len(middleware.__code__.co_varnames) > 1:
_res = await middleware(request, **kwargs)
sig = inspect.signature(middleware)
if len(sig.parameters) > 1:
_res = await call_handler(middleware, request, **kwargs)
else:
_res = await middleware(request)
_res = await call_handler(middleware, request)

print("RES")
print(_res)
# regular handlers return a response, but middleware can return a tuple
if not isinstance(_res, tuple):
_res = (_res, kwargs)

if len(_res) != 2:
req = _res
Expand Down
15 changes: 12 additions & 3 deletions src/ziplineio/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def format_headers(headers: Dict[str, str] | None) -> List[Tuple[bytes, bytes]]:


def format_body(body: bytes | str | dict) -> bytes:
print("BODY:")
print(body)
print(type(body))
if isinstance(body, bytes):
return body
elif isinstance(body, str):
Expand All @@ -57,7 +60,7 @@ def format_body(body: bytes | str | dict) -> bytes:


def format_response(
response: bytes | dict | str | Response, default_headers: dict[str, str]
response: bytes | dict | str | Response | Exception, default_headers: dict[str, str]
) -> RawResponse:
print("TYUUPE:")
print(type(response))
Expand Down Expand Up @@ -86,7 +89,6 @@ def format_response(
"body": bytes(json.dumps(response), "utf-8"),
}
elif isinstance(response, Response):
print("HRE")
raw_response = {
"headers": format_headers(response.headers),
"status": response.status,
Expand All @@ -100,8 +102,15 @@ def format_response(
"status": response.status_code,
"body": format_body(response.message),
}
elif isinstance(response, Exception):
raw_response = {
"headers": [
(b"content-type", b"text/plain"),
],
"status": 500,
"body": b"Internal server error: " + bytes(str(response), "utf-8"),
}
else:
print(response)
raise ValueError("Invalid response type")

raw_response["headers"].extend(format_headers(default_headers))
Expand Down
27 changes: 15 additions & 12 deletions src/ziplineio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,35 @@
import asyncio
from typing import Dict
from ziplineio.request import Request
from ziplineio.response import RawResponse, format_response
from ziplineio.response import RawResponse, Response, format_response
from ziplineio.handler import Handler
from ziplineio.models import ASGIScope
from ziplineio.exception import BaseHttpException
from ziplineio import settings


async def call_handler(
handler: Handler, req: Request, format: bool = True
) -> RawResponse:
handler: Handler,
req: Request,
**kwargs,
) -> bytes | str | dict | Response | Exception:
try:
if not inspect.iscoroutinefunction(handler):
response = await asyncio.to_thread(handler, req)
response = await asyncio.to_thread(handler, req, **kwargs)
else:
response = await handler(req)
response = await handler(req, **kwargs)

except BaseHttpException as e:
response = e
except Exception as e:
response = e
print("EXCEPTION")
print(e)
response = BaseHttpException(e, 500)
# except Exception as e:
# print(e)
# response = BaseHttpException(e, 500)

if format:
return format_response(response, settings.DEFAULT_HEADERS)
else:
return response
print(f"response in call handler: {response}")

return response


def parse_scope(scope: ASGIScope) -> Request:
Expand Down
4 changes: 4 additions & 0 deletions test/test_exceptions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import unittest

from ziplineio import settings
from ziplineio.app import App
from ziplineio.exception import BaseHttpException
from ziplineio.response import format_response
from ziplineio.utils import call_handler


Expand All @@ -18,6 +20,7 @@ async def test_handler(req):

# Call the handler
response = await call_handler(test_handler, {})
response = format_response(response, settings.DEFAULT_HEADERS)
self.assertEqual(response["body"], b"Hey! We messed up")
self.assertEqual(response["status"], 409)

Expand All @@ -31,5 +34,6 @@ async def test_handler(req):

# Call the handler
response = await call_handler(test_handler, {})
response = format_response(response, settings.DEFAULT_HEADERS)
self.assertEqual(response["body"], b"Hey! We messed up bad")
self.assertEqual(response["status"], 402)
3 changes: 3 additions & 0 deletions test/test_html.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ async def asyncSetUp(self):
async def test_render_jinja(self):
response = requests.get("http://localhost:5050/")

print("Resooinse")
print(response.text)

self.assertEqual(response.status_code, 200)
self.assertEqual(response.headers["Content-Type"], "text/html")
self.assertTrue("<p>Welcome to the home page!</p>" in response.text)
Expand Down
13 changes: 11 additions & 2 deletions test/test_middleware.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from operator import call
import unittest
from ziplineio import settings
from ziplineio.request import Request
from ziplineio.app import App
from ziplineio.middleware import middleware
from ziplineio.response import format_response
from ziplineio.utils import call_handler


class TestMiddleware(unittest.IsolatedAsyncioTestCase):
Expand Down Expand Up @@ -79,7 +83,10 @@ async def test_handler_with_middleware(req: Request, ctx: dict):

# Call the route
handler, params = self.app._router.get_handler("GET", "/with-middleware")
response = await handler(req)
response = await call_handler(handler, req, format=True)
response = format_response(response, settings.DEFAULT_HEADERS)

print(f"response: {response}")

# Assertions
self.assertEqual(response["status"], 500)
Expand Down Expand Up @@ -110,7 +117,9 @@ async def test_handler_with_middleware(req: Request, ctx: dict):

# Call the route
handler, params = self.app._router.get_handler("GET", "/with-middleware")
response = await handler(req)
response = await call_handler(handler, req, format=False)

print(response)

# Assertions
self.assertEqual(response["message"], "Hi from middleware 2")
Expand Down

0 comments on commit 53cc0b6

Please sign in to comment.