Skip to content

Commit

Permalink
Remove global instance (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
jegork authored Jun 8, 2024
1 parent 0c33a99 commit 9ec14f2
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 128 deletions.
5 changes: 3 additions & 2 deletions tests/test_decorator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ultra_cache.build_cache_key import DefaultBuildCacheKey
from ultra_cache.decorator import cache
from ultra_cache.decorator import UltraCache
from ultra_cache.storage.inmemory import InMemoryStorage
import pytest
from fastapi import Request, Response
Expand Down Expand Up @@ -126,7 +126,8 @@ async def test_decorator_cached(
fn_with_args, "fn"
) # Weird syntax, but did not find any alternative

cached_fn = cache(storage=storage)(fn_with_args.fn)
cache = UltraCache(storage=storage)
cached_fn = cache()(fn_with_args.fn)
result = await cached_fn(
*fn_with_args.args, **{**fn_with_args.kwargs, **fn_with_args.injected_kwargs}
)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def reset_cache():
loop = asyncio.get_event_loop()
except: # noqa: E722
loop = asyncio.new_event_loop()
loop.run_until_complete(utils.storage.clear())
loop.run_until_complete(utils.cache.storage.clear())
yield
loop.close()

Expand Down
19 changes: 11 additions & 8 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from ultra_cache.main import get_storage, init_cache
from fakeredis import FakeAsyncRedis
from ultra_cache.decorator import UltraCache
from ultra_cache.storage.inmemory import InMemoryStorage
import pytest

from ultra_cache.storage.redis import RedisStorage

def test_init_cache():
with pytest.raises(ValueError):
get_storage()

created_storage = InMemoryStorage()
init_cache(created_storage)
s2 = get_storage()
@pytest.fixture(params=[InMemoryStorage(), RedisStorage(FakeAsyncRedis())])
def storage(request):
return request.param

assert created_storage == s2

def test_init_cache(storage):
cache = UltraCache(storage=storage)

assert isinstance(cache.storage, type(storage))
5 changes: 3 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from fastapi import FastAPI
from ultra_cache.decorator import cache
from ultra_cache.decorator import UltraCache
from ultra_cache.storage.inmemory import InMemoryStorage

app = FastAPI()
storage = InMemoryStorage()
cache = UltraCache(storage=storage)


@app.get("/items/{item_id}")
@cache(storage=storage)
@cache()
async def read_item(item_id: int):
return {"item_id": item_id}
190 changes: 99 additions & 91 deletions ultra_cache/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from dict_hash import dict_hash
from ultra_cache.build_cache_key import BuildCacheKey, DefaultBuildCacheKey
from ultra_cache.cache_control import CacheControl
from ultra_cache.main import get_storage
from ultra_cache.storage.base import BaseStorage
from fastapi import Request, Response
import sys
Expand Down Expand Up @@ -70,102 +69,111 @@ def _does_etag_match(etag: str, if_none_match: Union[str, None]) -> bool:
return False


def cache(
ttl: Union[int, float, None] = None,
build_cache_key: BuildCacheKey = DefaultBuildCacheKey(),
storage: Union[BaseStorage, None] = None,
hash_fn: Callable[[Any], str] = _default_hash_fn,
):
def _wrapper(
func: Callable[P, Union[R, Coroutine[R, Any, Any]]],
) -> Callable[P, Coroutine[R, Any, Any]]:
sig = inspect.signature(func)

original_request_param = _extract_param_of_type(sig, Request)
original_response_param = _extract_param_of_type(sig, Response)
request_param = original_request_param
response_param = original_response_param

new_parameters = list(sig.parameters.values())
if request_param is None:
request_param = inspect.Parameter(
"request", annotation=Request, kind=inspect.Parameter.KEYWORD_ONLY
)
new_parameters.append(request_param)
if response_param is None:
response_param = inspect.Parameter(
"response", annotation=Response, kind=inspect.Parameter.KEYWORD_ONLY
)
new_parameters.append(response_param)

func.__signature__ = sig.replace(parameters=new_parameters)

# allows for the decorator to be used with fastapi params interospection
@wraps(func)
async def _decorator(*args: P.args, **kwargs: P.kwargs):
nonlocal storage
request: Request = kwargs.get(request_param.name)
response: Response = kwargs.get(response_param.name)

cache_control = CacheControl.from_string(
request.headers.get("cache-control", None)
)
if_none_match = request.headers.get("if-none-match", None)

args_for_key, kwargs_for_key = _extract(
response_param, *(_extract(request_param, args, kwargs))
)
key = build_cache_key(func, args=args_for_key, kwargs=kwargs_for_key)

if storage is None:
storage = get_storage()

cached = None
if not cache_control.no_cache:
cached = await storage.get(key)

if ttl:
cache_control.setdefault("max-age", ttl)

response.headers["Cache-Control"] = cache_control.to_response_header()

if cached is not None:
response.headers["X-Cache"] = "HIT"
response.headers["ETag"] = hash_fn(cached)
class UltraCache:
storage: Union[BaseStorage, None] = None

def __init__(self, storage: BaseStorage) -> None:
self.storage = storage

def __call__(
self,
ttl: Union[int, float, None] = None,
build_cache_key: BuildCacheKey = DefaultBuildCacheKey(),
storage: Union[BaseStorage, None] = None,
hash_fn: Callable[[Any], str] = _default_hash_fn,
):
def _wrapper(
func: Callable[P, Union[R, Coroutine[R, Any, Any]]],
) -> Callable[P, Coroutine[R, Any, Any]]:
sig = inspect.signature(func)

original_request_param = _extract_param_of_type(sig, Request)
original_response_param = _extract_param_of_type(sig, Response)
request_param = original_request_param
response_param = original_response_param

new_parameters = list(sig.parameters.values())
if request_param is None:
request_param = inspect.Parameter(
"request", annotation=Request, kind=inspect.Parameter.KEYWORD_ONLY
)
new_parameters.append(request_param)
if response_param is None:
response_param = inspect.Parameter(
"response", annotation=Response, kind=inspect.Parameter.KEYWORD_ONLY
)
new_parameters.append(response_param)

func.__signature__ = sig.replace(parameters=new_parameters)

# allows for the decorator to be used with fastapi params interospection
@wraps(func)
async def _decorator(*args: P.args, **kwargs: P.kwargs):
nonlocal storage
request: Request = kwargs.get(request_param.name)
response: Response = kwargs.get(response_param.name)

cache_control = CacheControl.from_string(
request.headers.get("cache-control", None)
)
if_none_match = request.headers.get("if-none-match", None)

args_for_key, kwargs_for_key = _extract(
response_param, *(_extract(request_param, args, kwargs))
)
key = build_cache_key(func, args=args_for_key, kwargs=kwargs_for_key)

if storage is None:
storage = self.storage

cached = None
if not cache_control.no_cache:
cached = await storage.get(key)

if ttl:
cache_control.setdefault("max-age", ttl)

response.headers["Cache-Control"] = cache_control.to_response_header()

if cached is not None:
response.headers["X-Cache"] = "HIT"
response.headers["ETag"] = hash_fn(cached)

if request.method in ["HEAD", "GET"]:
if _does_etag_match(response.headers["ETag"], if_none_match):
response.status_code = 304
return

return cached
else:
response.headers["X-Cache"] = "MISS"

if original_request_param is None:
kwargs.pop("request")
if original_response_param is None:
kwargs.pop("response")

# Note: inspect.iscoroutinefunction returns False for AsyncMock
if asyncio.iscoroutinefunction(func):
output = await func(*args, **kwargs)
else:
output = await anyio.to_thread.run_sync(
partial(func, *args, **kwargs)
)

response.headers["ETag"] = hash_fn(output)
if request.method in ["HEAD", "GET"]:
if _does_etag_match(response.headers["ETag"], if_none_match):
response.status_code = 304
return

return cached
else:
response.headers["X-Cache"] = "MISS"

if original_request_param is None:
kwargs.pop("request")
if original_response_param is None:
kwargs.pop("response")

# Note: inspect.iscoroutinefunction returns False for AsyncMock
if asyncio.iscoroutinefunction(func):
output = await func(*args, **kwargs)
else:
output = await anyio.to_thread.run_sync(partial(func, *args, **kwargs))

response.headers["ETag"] = hash_fn(output)
if request.method in ["HEAD", "GET"]:
if _does_etag_match(response.headers["ETag"], if_none_match):
response.status_code = 304
return

if not cache_control.no_store:
await storage.save(
key=key, value=output, ttl=cache_control.max_age or ttl
)
if not cache_control.no_store:
await storage.save(
key=key, value=output, ttl=cache_control.max_age or ttl
)

return output
return output

return _decorator
return _decorator

return _wrapper
return _wrapper
24 changes: 0 additions & 24 deletions ultra_cache/main.py
Original file line number Diff line number Diff line change
@@ -1,24 +0,0 @@
from contextlib import AbstractAsyncContextManager
from typing import Union
from ultra_cache.storage.base import BaseStorage

_storage_instance: Union[BaseStorage, None] = None


def init_cache(storage: BaseStorage) -> None:
global _storage_instance
_storage_instance = storage


def get_storage() -> BaseStorage:
if _storage_instance is None:
raise ValueError("Cache not initialized")

return _storage_instance


class FastCache(AbstractAsyncContextManager):
storage: Union[BaseStorage, None] = None

def __init__(self) -> None:
pass

0 comments on commit 9ec14f2

Please sign in to comment.