-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
8eb9cf0
commit e387fd9
Showing
6 changed files
with
205 additions
and
67 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,44 +1,93 @@ | ||
from functools import wraps | ||
from typing import Any | ||
from collections.abc import Callable | ||
from typing import Any, Dict, Union | ||
from datetime import datetime, timedelta | ||
|
||
from ziplineio.handler import Handler | ||
|
||
from ziplineio.request_context import get_request | ||
from ziplineio.utils import call_handler | ||
|
||
class MemoryCache: | ||
_instance = None | ||
_cache: dict[str, str] | ||
|
||
def __new__(cls): | ||
if cls._instance is None: | ||
cls._instance = super(cls, cls).__new__(cls) | ||
return cls._instance | ||
class BaseCache: | ||
pass | ||
|
||
def __init__(self) -> None: | ||
# Ensure the cache is only initialized once | ||
if not hasattr(self, "_cache"): | ||
self._cache = {} | ||
async def get(self, key: str) -> Any: | ||
pass | ||
|
||
def get(self, key: str) -> str | None: | ||
async def set(self, key: str, value: Any, duration: Union[int, float] = 0) -> None: | ||
pass | ||
|
||
async def clear(self) -> None: | ||
pass | ||
|
||
|
||
_cache: BaseCache = None | ||
|
||
|
||
class MemoryCache(BaseCache): | ||
def __init__(self): | ||
self._cache: Dict[str, Any] = {} | ||
self._expiry_times: Dict[str, datetime] = {} | ||
|
||
async def get(self, key: str) -> Any: | ||
"""Get a cache entry.""" | ||
if await self.is_expired(key): | ||
# remove the expired cache entry | ||
self._cache.pop(key, None) | ||
self._expiry_times.pop(key, None) | ||
return None | ||
return self._cache.get(key, None) | ||
|
||
def set(self, key: str, value: str) -> None: | ||
async def set(self, key: str, value: Any, duration: Union[int, float] = 0) -> None: | ||
"""Set a cache entry.""" | ||
self._cache[key] = value | ||
self._expiry_times[key] = datetime.now() + timedelta(seconds=duration) | ||
|
||
async def is_expired(self, key: str) -> bool: | ||
"""Check if a cache entry has expired.""" | ||
if key not in self._expiry_times: | ||
return True | ||
return datetime.now() >= self._expiry_times[key] | ||
|
||
def __call__(self, handler: Handler) -> Any: | ||
@wraps(handler) | ||
def wrapper(*args, **kwargs): | ||
# Create a key based on the arguments | ||
key = (args, frozenset(kwargs.items())) | ||
def clear(self): | ||
"""Clears the cache.""" | ||
self._cache.clear() | ||
self._expiry_times.clear() | ||
|
||
if key not in self._cache: | ||
# Call the function and store the result in the cache | ||
secache[key] = func(*args, **kwargs) | ||
return cache[key] | ||
|
||
def cache(duration: Union[int, float] = 0): | ||
"""Cache decorator that accepts duration in seconds.""" | ||
|
||
def decorator(func: Callable) -> Callable: | ||
async def wrapper(*args, **kwargs): | ||
req = get_request() | ||
|
||
url = req.path | ||
query_params_str = "&".join( | ||
[f"{k}={v}" for k, v in req.query_params.items()] | ||
) | ||
key = f"{func.__name__}:{kwargs}:{url}:{query_params_str}" | ||
|
||
# Check if the cache has expired or does not exist | ||
value = await _cache.get(key) | ||
if value is None: | ||
result = await call_handler(func, **kwargs) | ||
await _cache.set(key, result, duration) | ||
return result | ||
else: | ||
return value | ||
|
||
return wrapper | ||
|
||
return decorator | ||
|
||
|
||
def get_cache() -> BaseCache: | ||
"""Get the cache instance.""" | ||
return _cache | ||
|
||
memory_cache1 = MemoryCache() | ||
memory_cache2 = MemoryCache() | ||
|
||
print(memory_cache1 is memory_cache2) | ||
def set_cache(cache: BaseCache) -> None: | ||
"""Set the cache instance.""" | ||
global _cache | ||
_cache = cache | ||
return None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from contextvars import ContextVar | ||
|
||
from ziplineio.request import Request | ||
|
||
_request_context_var = ContextVar("request_context") | ||
|
||
|
||
def set_request(request: Request): | ||
_request_context_var.set(request) | ||
|
||
|
||
def get_request() -> Request: | ||
return _request_context_var.get() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import unittest | ||
import random | ||
from ziplineio.app import App | ||
from ziplineio.cache import MemoryCache, set_cache, cache | ||
from ziplineio.dependency_injector import inject | ||
from ziplineio.request import Request | ||
|
||
|
||
class TestMemoryCache(unittest.IsolatedAsyncioTestCase): | ||
def setUp(self): | ||
self.app = App() | ||
set_cache(MemoryCache()) | ||
|
||
async def test_handler_cache(self): | ||
@self.app.get("/cached_number") | ||
@cache(5) | ||
async def handler(): | ||
return random.randint(0, 9999) | ||
|
||
req: Request = Request("GET", "/cached_number") | ||
|
||
first_call = await self.app._get_and_call_handler("GET", "/cached_number", req) | ||
second_call = await self.app._get_and_call_handler("GET", "/cached_number", req) | ||
|
||
print("HERE", first_call, second_call) | ||
|
||
# Ensure the result is cached | ||
self.assertEqual(first_call, second_call) | ||
|
||
async def test_handler_cache_with_dep_injector(self): | ||
class Service: | ||
def speak(): | ||
return "Hello" | ||
|
||
@self.app.get("/cached_number") | ||
@inject(Service) | ||
@cache(5) | ||
async def handler(s: Service): | ||
return s.speak() + str(random.randint(0, 9999)) | ||
|
||
req: Request = Request("GET", "/cached_number") | ||
|
||
first_call = await self.app._get_and_call_handler("GET", "/cached_number", req) | ||
second_call = await self.app._get_and_call_handler("GET", "/cached_number", req) | ||
|
||
print("HERE", first_call, second_call) | ||
|
||
# Ensure the result is cached | ||
self.assertEqual(first_call, second_call) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters