Skip to content

Commit

Permalink
fix cache/ tests
Browse files Browse the repository at this point in the history
  • Loading branch information
CameronSima committed Sep 11, 2024
1 parent 8eb9cf0 commit e387fd9
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 67 deletions.
6 changes: 4 additions & 2 deletions src/ziplineio/app.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import inspect
import re
from typing import Any, Callable, List, Tuple, Type


from ziplineio.exception import NotFoundHttpException
from ziplineio.middleware import run_middleware_stack
from ziplineio.dependency_injector import inject, injector, DependencyInjector
from ziplineio import settings
from ziplineio.handler import Handler
from ziplineio.request import Request
from ziplineio.request_context import set_request
from ziplineio.response import Response, NotFoundResponse, format_response
from ziplineio.router import Router
from ziplineio.static import staticfiles
Expand Down Expand Up @@ -87,6 +86,9 @@ async def _get_and_call_handler(
handler, path_params = self.get_handler(method, path)
req.path_params = path_params

# set request context
set_request(req)

if handler is not None:
# If a handler is found, call it with the request
return await call_handler(handler, req=req)
Expand Down
105 changes: 77 additions & 28 deletions src/ziplineio/cache.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,93 @@
from functools import wraps
from typing import Any
from collections.abc import Callable
from typing import Any, Dict, Union
from datetime import datetime, timedelta

from ziplineio.handler import Handler

from ziplineio.request_context import get_request
from ziplineio.utils import call_handler

class MemoryCache:
_instance = None
_cache: dict[str, str]

def __new__(cls):
if cls._instance is None:
cls._instance = super(cls, cls).__new__(cls)
return cls._instance
class BaseCache:
pass

def __init__(self) -> None:
# Ensure the cache is only initialized once
if not hasattr(self, "_cache"):
self._cache = {}
async def get(self, key: str) -> Any:
pass

def get(self, key: str) -> str | None:
async def set(self, key: str, value: Any, duration: Union[int, float] = 0) -> None:
pass

async def clear(self) -> None:
pass


_cache: BaseCache = None


class MemoryCache(BaseCache):
def __init__(self):
self._cache: Dict[str, Any] = {}
self._expiry_times: Dict[str, datetime] = {}

async def get(self, key: str) -> Any:
"""Get a cache entry."""
if await self.is_expired(key):
# remove the expired cache entry
self._cache.pop(key, None)
self._expiry_times.pop(key, None)
return None
return self._cache.get(key, None)

def set(self, key: str, value: str) -> None:
async def set(self, key: str, value: Any, duration: Union[int, float] = 0) -> None:
"""Set a cache entry."""
self._cache[key] = value
self._expiry_times[key] = datetime.now() + timedelta(seconds=duration)

async def is_expired(self, key: str) -> bool:
"""Check if a cache entry has expired."""
if key not in self._expiry_times:
return True
return datetime.now() >= self._expiry_times[key]

def __call__(self, handler: Handler) -> Any:
@wraps(handler)
def wrapper(*args, **kwargs):
# Create a key based on the arguments
key = (args, frozenset(kwargs.items()))
def clear(self):
"""Clears the cache."""
self._cache.clear()
self._expiry_times.clear()

if key not in self._cache:
# Call the function and store the result in the cache
secache[key] = func(*args, **kwargs)
return cache[key]

def cache(duration: Union[int, float] = 0):
"""Cache decorator that accepts duration in seconds."""

def decorator(func: Callable) -> Callable:
async def wrapper(*args, **kwargs):
req = get_request()

url = req.path
query_params_str = "&".join(
[f"{k}={v}" for k, v in req.query_params.items()]
)
key = f"{func.__name__}:{kwargs}:{url}:{query_params_str}"

# Check if the cache has expired or does not exist
value = await _cache.get(key)
if value is None:
result = await call_handler(func, **kwargs)
await _cache.set(key, result, duration)
return result
else:
return value

return wrapper

return decorator


def get_cache() -> BaseCache:
"""Get the cache instance."""
return _cache

memory_cache1 = MemoryCache()
memory_cache2 = MemoryCache()

print(memory_cache1 is memory_cache2)
def set_cache(cache: BaseCache) -> None:
"""Set the cache instance."""
global _cache
_cache = cache
return None
13 changes: 13 additions & 0 deletions src/ziplineio/request_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from contextvars import ContextVar

from ziplineio.request import Request

_request_context_var = ContextVar("request_context")


def set_request(request: Request):
_request_context_var.set(request)


def get_request() -> Request:
return _request_context_var.get()
20 changes: 14 additions & 6 deletions src/ziplineio/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import re
import inspect
import asyncio

from typing import Dict

from httpx import request

from ziplineio.request import Request
from ziplineio.response import Response
from ziplineio.handler import Handler
Expand All @@ -18,16 +19,23 @@ def get_class_params(cls):
return inspect.signature(cls.__init__).parameters


"""
Only pass the kwargs that are required by the handler function.
"""


def clean_kwargs(kwargs: dict, handler: Handler) -> dict:
params = inspect.signature(handler).parameters
kwargs = {k: v for k, v in kwargs.items() if k in params}
return kwargs


async def call_handler(
handler: Handler,
**kwargs,
) -> bytes | str | dict | Response | Exception:
try:
if "req" not in kwargs:
raise ValueError("Request object not found in kwargs")

params = inspect.signature(handler).parameters
kwargs = {k: v for k, v in kwargs.items() if k in params}
kwargs = clean_kwargs(kwargs, handler)
if not inspect.iscoroutinefunction(handler):
response = await asyncio.to_thread(handler, **kwargs)
else:
Expand Down
49 changes: 49 additions & 0 deletions test/test_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import unittest
import random
from ziplineio.app import App
from ziplineio.cache import MemoryCache, set_cache, cache
from ziplineio.dependency_injector import inject
from ziplineio.request import Request


class TestMemoryCache(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.app = App()
set_cache(MemoryCache())

async def test_handler_cache(self):
@self.app.get("/cached_number")
@cache(5)
async def handler():
return random.randint(0, 9999)

req: Request = Request("GET", "/cached_number")

first_call = await self.app._get_and_call_handler("GET", "/cached_number", req)
second_call = await self.app._get_and_call_handler("GET", "/cached_number", req)

print("HERE", first_call, second_call)

# Ensure the result is cached
self.assertEqual(first_call, second_call)

async def test_handler_cache_with_dep_injector(self):
class Service:
def speak():
return "Hello"

@self.app.get("/cached_number")
@inject(Service)
@cache(5)
async def handler(s: Service):
return s.speak() + str(random.randint(0, 9999))

req: Request = Request("GET", "/cached_number")

first_call = await self.app._get_and_call_handler("GET", "/cached_number", req)
second_call = await self.app._get_and_call_handler("GET", "/cached_number", req)

print("HERE", first_call, second_call)

# Ensure the result is cached
self.assertEqual(first_call, second_call)
79 changes: 48 additions & 31 deletions test/test_e2e.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from multiprocessing import Process
import asyncio
from jinja2 import Environment, PackageLoader, select_autoescape
import requests
import httpx
import uvicorn
import unittest

Expand Down Expand Up @@ -78,51 +78,68 @@ def run_server():
class TestE2E(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
"""Bring server up."""
self.proc = Process(target=run_server, args=(), daemon=False)
self.proc = Process(target=run_server, args=(), daemon=True)
self.proc.start()

# Wait for the server to be up
await self.wait_for_server()

async def wait_for_server(self):
async def wait_for_server(self, timeout=30):
"""Wait for the server to be ready."""
while True:
try:
response = requests.get("http://localhost:5050/bytes")
if response.status_code == 200:
break
except requests.ConnectionError:
await asyncio.sleep(0.1) # Short sleep between retries

async def check_server():
async with httpx.AsyncClient() as client:
while True:
try:
response = await client.get("http://localhost:5050/bytes")
if response.status_code == 200:
return True
except httpx.RequestError:
pass
await asyncio.sleep(0.2) # Short sleep between retries

try:
await asyncio.wait_for(check_server(), timeout=timeout)
except asyncio.TimeoutError:
self.fail("Server did not start within the specified timeout")

async def asyncTearDown(self):
self.proc.terminate()
self.proc.join()

async def test_handler_returns_bytes(self):
response = requests.get("http://localhost:5050/bytes")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b"Hello, world!")
async with httpx.AsyncClient() as client:
response = await client.get("http://localhost:5050/bytes")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b"Hello, world!")

async def test_handler_returns_str(self):
response = requests.get("http://localhost:5050/str")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.text, "Hello, world!")
async with httpx.AsyncClient() as client:
response = await client.get("http://localhost:5050/str")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.text, "Hello, world!")

async def test_handler_returns_dict(self):
response = requests.get("http://localhost:5050/dict")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"message": "Hello, world!"})
async with httpx.AsyncClient() as client:
response = await client.get("http://localhost:5050/dict")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"message": "Hello, world!"})

async def test_sync_route(self):
response = requests.get("http://localhost:5050/sync-thread")
self.assertEqual(response.status_code, 200)
async with httpx.AsyncClient() as client:
response = await client.get("http://localhost:5050/sync-thread")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"message": "Hello, sync world!"})

async def test_jinja(self):
response = requests.get("http://localhost:5050/jinja")
self.assertEqual(response.status_code, 200)
self.assertTrue("service1 content" in response.text)
async with httpx.AsyncClient() as client:
response = await client.get("http://localhost:5050/jinja")
self.assertEqual(response.status_code, 200)
self.assertTrue("service1 content" in response.text)

async def test_404_jinja(self):
response = requests.get("http://localhost:5050/some-random-route")
self.assertEqual(response.status_code, 404)
self.assertTrue("404" in response.text)
self.assertEqual(response.headers["Content-Type"], "text/html")

async def asyncTearDown(self):
self.proc.terminate()
async with httpx.AsyncClient() as client:
response = await client.get("http://localhost:5050/some-random-route")
self.assertEqual(response.status_code, 404)
self.assertTrue("404" in response.text)
self.assertEqual(response.headers["Content-Type"], "text/html")

0 comments on commit e387fd9

Please sign in to comment.