Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
CameronSima committed Aug 30, 2024
1 parent f8bc3f6 commit 8c20541
Show file tree
Hide file tree
Showing 10 changed files with 83 additions and 43 deletions.
29 changes: 14 additions & 15 deletions src/ziplineio/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def get_handler(self, method: str, path: str) -> Tuple[Handler, dict]:
]

for name, service in app_level_deps.items():
if name in filtered_kwargs_names:
if "kwargs" in sig.parameters.keys() or name in filtered_kwargs_names:
handler = inject(service, name)(handler)

return handler, params
Expand All @@ -79,23 +79,22 @@ def static(self, path: str, path_prefix: str = "/static") -> None:
async def _get_and_call_handler(
self, method: str, path: str, req: Request
) -> Callable:
# Retrieve the handler and path parameters for the given method and path
handler, path_params = self.get_handler(method, path)
req.path_params = path_params

if handler is None:
# try running through middlewares
req, ctx, res = await run_middleware_stack(
self._router._router_level_middelwares, req, **{}
)

if res is None:
response = NotFoundHttpException()
else:
response = res

else:
response = await call_handler(handler, req)
return response
if handler is not None:
# If a handler is found, call it with the request
return await call_handler(handler, req)

# If no handler was found, attempt to run middlewares.
# (If a handler was found, middlewares will be run by `call_handler`)
req, ctx, res = await run_middleware_stack(
self._router._router_level_middelwares, req
)

# If middleware does not provide a response, return a 404 Not Found
return res if res is not None else NotFoundHttpException()

def __call__(self, *args: Any, **kwds: Any) -> Any:
async def uvicorn_handler(scope: dict, receive: Any, send: Any) -> None:
Expand Down
17 changes: 13 additions & 4 deletions src/ziplineio/dependency_injector.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
from typing import Any, Callable
from ziplineio.request import Request
from ziplineio.service import Service, is_service_class
Expand Down Expand Up @@ -44,22 +45,28 @@ async def wrapped_handler(req: Request, **kwargs):
def add_injected_service(
self, service_class: Any, name: str = None, scope: str = "func"
) -> None:
service_name = name if name else service_class.__name__.lower()

if scope not in self._injected_services:
self._injected_services[scope] = {}

if name:
service_name = name
elif hasattr(service_class, "name"):
service_name = service_class.name
else:
service_name = service_class.__name__.lower()

services = self._injected_services[scope]

# Check if the service class is a subclass of `Service`
if is_service_class(service_class):
sig = inspect.signature(service_class.__init__)
services_in_scope = self.get_injected_services(scope)

# Prepare dependencies for the current service instance
service_kwargs = {
name: service
for name, service in services_in_scope.items()
if isinstance(service, Service)
if isinstance(service, Service) and name in sig.parameters
}

# Create a new instance of the service
Expand All @@ -69,7 +76,9 @@ def add_injected_service(
# Inject this service instance into all other services within the scope
for _name, service in services_in_scope.items():
if isinstance(service, Service):
setattr(service, service_name, instance)
sig = inspect.signature(service.__init__)
if service_name in sig.parameters:
setattr(service, service_name, instance)

return instance, service_name
else:
Expand Down
8 changes: 4 additions & 4 deletions src/ziplineio/html/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ def jinja(env: Any, template_name: str):
def decorator(handler):
async def wrapped_handler(req, **kwargs):
# Pass all arguments directly to the handler
sig = inspect.signature(handler)
print(f"sig: {sig}")
# sig = inspect.signature(handler)

# Filter kwargs to only pass those that the handler expects
filtered_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters}
context = await call_handler(handler, req, **filtered_kwargs)
# filtered_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters}
context = await call_handler(handler, req, **kwargs)
rendered = template.render(context)
return JinjaResponse(rendered)

Expand Down
6 changes: 1 addition & 5 deletions src/ziplineio/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,7 @@ async def run_middleware_stack(
if "ctx" not in kwargs:
kwargs["ctx"] = {}

sig = inspect.signature(middleware)
if len(sig.parameters) > 1:
_res = await call_handler(middleware, request, **kwargs)
else:
_res = await call_handler(middleware, request)
_res = await call_handler(middleware, request, **kwargs)

# regular handlers return a response, but middleware can return a tuple
if not isinstance(_res, tuple):
Expand Down
5 changes: 0 additions & 5 deletions src/ziplineio/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,6 @@ def format_headers(headers: Dict[str, str] | None) -> List[Tuple[bytes, bytes]]:


def format_body(body: bytes | str | dict) -> bytes:
print("BODY:")
print(body)
print(type(body))
if isinstance(body, bytes):
return body
elif isinstance(body, str):
Expand All @@ -62,8 +59,6 @@ def format_body(body: bytes | str | dict) -> bytes:
def format_response(
response: bytes | dict | str | Response | Exception, default_headers: dict[str, str]
) -> RawResponse:
print("TYUUPE:")
print(type(response))
if isinstance(response, bytes):
raw_response = {
"headers": [
Expand Down
1 change: 0 additions & 1 deletion src/ziplineio/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ def _get_headers(filename: str) -> dict[str, str]:

def staticfiles(filepath: str, path_prefix: str):
async def handler(req, ctx):
print(f"Request path: {req.path}")
if req.path.startswith(path_prefix):
# remove path prefix
_filepath = path.join(filepath, req.path[len(path_prefix) :])
Expand Down
15 changes: 13 additions & 2 deletions src/ziplineio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,27 @@
from ziplineio.models import ASGIScope


def get_func_params(func: Handler):
return inspect.signature(func).parameters


def get_class_params(cls):
return inspect.signature(cls.__init__).parameters


async def call_handler(
handler: Handler,
req: Request,
**kwargs,
) -> bytes | str | dict | Response | Exception:
try:
kwargs = {"req": req, **kwargs}
params = inspect.signature(handler).parameters
kwargs = {k: v for k, v in kwargs.items() if k in params}
if not inspect.iscoroutinefunction(handler):
response = await asyncio.to_thread(handler, req, **kwargs)
response = await asyncio.to_thread(handler, **kwargs)
else:
response = await handler(req, **kwargs)
response = await handler(**kwargs)

except Exception as e:
response = e
Expand Down
1 change: 1 addition & 0 deletions test/mocks/templates/home.html
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@
<body>
<h1>Home</h1>
<p>Welcome to the home page!</p>
{{ content }}
</body>
</html>
4 changes: 2 additions & 2 deletions test/test_dependency_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ async def test_handler(req, service1: Service1):

# Call the handler
handler, params = self.ziplineio.get_handler("GET", "/")
print(self.ziplineio._injector._injected_services)

response = await call_handler(handler, {})
print(response)

self.assertEqual(response["message"], "Service 1")
40 changes: 35 additions & 5 deletions test/test_e2e.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,44 @@
from multiprocessing import Process
import asyncio
from jinja2 import Environment, PackageLoader, select_autoescape
import requests
import uvicorn
import unittest
import unittest.async_case


from ziplineio.app import App
from ziplineio.html.jinja import jinja
from ziplineio.router import Router
from ziplineio.service import Service

env = Environment(loader=PackageLoader("test.mocks"), autoescape=select_autoescape())

app = App()


class Service1(Service):
name = "service1"


class Service2(Service):
name = "service2"

def __init__(self, service1: Service1):
self.service1 = service1


app.inject([Service1, Service2])


@app.get("/jinja")
@jinja(env, "home.html")
def jinja_handler(service2: Service2):
return {"content": service2.service1.name + " content"}


@app.get("/bytes")
async def bytes_handler(req):
async def bytes_handler():
return b"Hello, world!"


Expand All @@ -24,13 +48,13 @@ async def dict_handler(req):


@app.get("/str")
async def str_handler(req):
async def str_handler():
return "Hello, world!"


# Will be made multithreaded
@app.get("/sync-thread")
def sync_handler(req):
def sync_handler():
return {"message": "Hello, sync world!"}


Expand Down Expand Up @@ -58,25 +82,31 @@ async def test_handler_returns_bytes(self):
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b"Hello, world!")

print("BYTES")
print(response.content)

async def test_handler_returns_str(self):
response = requests.get("http://localhost:5050/str")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.text, "Hello, world!")

async def test_handler_returns_dict(self):
response = requests.get("http://localhost:5050/dict")
print(response.json())
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"message": "Hello, world!"})

async def test_sync_route(self):
response = requests.get("http://localhost:5050/sync-thread")
print(response.content)
self.assertEqual(response.status_code, 200)

async def test_404(self):
response = requests.get("http://localhost:5050/some-random-route")
self.assertEqual(response.status_code, 404)

async def test_jinja(self):
response = requests.get("http://localhost:5050/jinja")
self.assertEqual(response.status_code, 200)
self.assertTrue("service1 content" in response.text)

async def asyncTearDown(self):
self.proc.terminate()

0 comments on commit 8c20541

Please sign in to comment.