diff --git a/poetry.lock b/poetry.lock index 6f0b633..014814b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -67,6 +67,38 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +[[package]] +name = "deflate-dict" +version = "1.0.11" +description = "Python package to deflate and re-inflate dictionaries." +optional = false +python-versions = "*" +files = [ + {file = "deflate_dict-1.0.11.tar.gz", hash = "sha256:07f0c57960070e8460d84271f03291f7dcf56a12dd68c9831a270cf76aa6f42a"}, +] + +[package.dependencies] +support_developer = ">=1.0.2" + +[package.extras] +test = ["codacy-coverage", "coveralls", "pytest", "pytest-cov", "random_dict", "validate_version_code"] + +[[package]] +name = "dict-hash" +version = "1.1.37" +description = "Python package to hash dictionaries using both default hash and sha256." +optional = false +python-versions = "*" +files = [ + {file = "dict_hash-1.1.37.tar.gz", hash = "sha256:40a5dfef0c0291abeead3990ee33ecbc662550f7335074d91316cc0a6e67b1d5"}, +] + +[package.dependencies] +deflate_dict = ">=1.0.8" + +[package.extras] +test = ["hvplot (>=0.9.1)", "netaddr", "numba", "polars", "pytest", "pytest-cov", "random_dict", "tqdm", "validate_version_code"] + [[package]] name = "dnspython" version = "2.6.1" @@ -892,6 +924,19 @@ anyio = ">=3.4.0,<5" [package.extras] full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.7)", "pyyaml"] +[[package]] +name = "support-developer" +version = "1.0.5" +description = "Package designed to centralize messages to support developers development work." +optional = false +python-versions = "*" +files = [ + {file = "support_developer-1.0.5.tar.gz", hash = "sha256:7739f63766b90480ee71d56c70ed8953500c48ab19e2782f867786642fce03d9"}, +] + +[package.extras] +test = ["codacy-coverage", "coveralls", "pytest", "pytest-cov", "validate_version_code"] + [[package]] name = "typer" version = "0.12.3" @@ -1234,4 +1279,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.12" -content-hash = "f7cffba1160cdc14ec038309c4d035302257cbb1058edd83503424100440c67e" +content-hash = "ceb62ce7559db5b51619896167492e6c10049032ad52d785084d1dfa6a37a7e7" diff --git a/pyproject.toml b/pyproject.toml index e69ecd0..45245bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ readme = "README.md" python = "^3.12" fastapi = "^0.111.0" pydantic = "^2.7.1" +dict-hash = "^1.1.37" [tool.poetry.group.dev.dependencies] diff --git a/tests/test_decorator.py b/tests/test_decorator.py index da34752..82b5fd5 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -57,7 +57,7 @@ def key_builder(): def sample_request(): - return Request({"type": "http", "headers": {}}) + return Request({"type": "http", "headers": {}, "method": "GET"}) def sample_response(): diff --git a/tests/test_e2e.py b/tests/test_e2e.py index 95e79e7..66f34e8 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -1,23 +1,123 @@ +import asyncio from fastapi.testclient import TestClient -from .utils import app +import pytest -client = TestClient(app) +from ultra_cache.decorator import _default_hash_fn +from . import utils +client = TestClient(utils.app) + + +@pytest.fixture(autouse=True, scope="function") +def reset_cache(): + try: + loop = asyncio.get_event_loop() + except: # noqa: E722 + loop = asyncio.new_event_loop() + loop.run_until_complete(utils.storage.clear()) + yield + loop.close() + + +# TODO: Reset cache between tests def test_cache_decorator(): response = client.get("/items/1") assert response.status_code == 200 assert response.json() == {"item_id": 1} assert response.headers.get("X-Cache") == "MISS" + assert response.headers.get("Cache-Control", "") == "" + etag_1 = response.headers.get("ETag") + assert etag_1 is not None # Test cache hit response = client.get("/items/1") assert response.status_code == 200 assert response.json() == {"item_id": 1} assert response.headers.get("X-Cache") == "HIT" + assert response.headers.get("Cache-Control", "") == "" + etag_2 = response.headers.get("ETag") + assert etag_2 is not None + assert etag_1 == etag_2 # Test cache miss response = client.get("/items/2") assert response.status_code == 200 assert response.json() == {"item_id": 2} assert response.headers.get("X-Cache") == "MISS" + assert response.headers.get("ETag") != etag_1 + assert response.headers.get("Cache-Control", "") == "" + + +def test_cache_with_maxage(): + response = client.get("/items/1", headers={"Cache-Control": "max-age=10"}) + assert response.status_code == 200 + assert response.json() == {"item_id": 1} + assert response.headers.get("X-Cache") == "MISS" + assert response.headers.get("Cache-Control", "") == "max-age=10" + + +def test_cache_with_if_none_match_hit(): + response = client.get( + "/items/1", + ) + assert response.status_code == 200 + assert response.json() == {"item_id": 1} + assert response.headers.get("X-Cache") == "MISS" + assert response.headers.get("Cache-Control", "") == "" + + etag = response.headers.get("ETag") + assert etag is not None + assert etag == _default_hash_fn(response.json()) + + # run again with If-None-Match + response = client.get("/items/1", headers={"If-None-Match": etag}) + + assert response.status_code == 304 + assert response.headers.get("X-Cache") == "HIT" + assert response.headers.get("Cache-Control", "") == "" + assert response.headers.get("ETag") == etag + + +def test_cache_with_if_none_match_hit_star(): + response = client.get( + "/items/1", + ) + assert response.status_code == 200 + assert response.json() == {"item_id": 1} + assert response.headers.get("X-Cache") == "MISS" + assert response.headers.get("Cache-Control", "") == "" + + etag = response.headers.get("ETag") + assert etag is not None + assert etag == _default_hash_fn(response.json()) + + # run again with If-None-Match + response = client.get("/items/1", headers={"If-None-Match": "*"}) + + assert response.status_code == 304 + assert response.headers.get("X-Cache") == "HIT" + assert response.headers.get("Cache-Control", "") == "" + assert response.headers.get("ETag") == etag + + +def test_cache_with_if_none_match_miss(): + response = client.get( + "/items/1", + ) + assert response.status_code == 200 + assert response.json() == {"item_id": 1} + assert response.headers.get("X-Cache") == "MISS" + assert response.headers.get("Cache-Control", "") == "" + + etag = response.headers.get("ETag") + assert etag is not None + assert etag == _default_hash_fn(response.json()) + + # run again with If-None-Match + response = client.get("/items/1", headers={"If-None-Match": "W/123"}) + + assert response.status_code == 200 + assert response.headers.get("X-Cache") == "HIT" + assert response.headers.get("Cache-Control", "") == "" + assert response.headers.get("ETag") == etag diff --git a/ultra_cache/cache_control.py b/ultra_cache/cache_control.py new file mode 100644 index 0000000..fddc334 --- /dev/null +++ b/ultra_cache/cache_control.py @@ -0,0 +1,52 @@ +from typing import Self + + +class CacheControl: + REQUEST_ONLY_KEYS = ["max-stale", "min-fresh", "only-if-cached"] + + def __init__(self, parts: dict[str, str]) -> None: + self.parts: dict[str, str | None] = parts + + def set(self, key: str, value: str) -> None: + self.parts[key] = value + + def get(self, key: str) -> str | None: + return self.parts.get(key, None) + + def setdefault(self, key: str, value: str) -> None: + self.parts.setdefault(key, value) + + @classmethod + def from_string(cls, cache_control: str | None) -> Self: + if cache_control is None: + return cls({}) + return cls( + { + x.split("=")[0].strip(): x.split("=")[1].strip() if "=" in x else None + for x in cache_control.lower().split(",") + } + ) + + @property + def max_age(self) -> int | None: + value = self.parts.get("max-age", None) + if value is None: + return None + return int(value) + + @property + def no_cache(self) -> bool: + return "no-cache" in self.parts + + @property + def no_store(self) -> bool: + return "no-store" in self.parts + + def to_response_header(self) -> str: + return ", ".join( + [ + f"{k}={v}" + for k, v in self.parts.items() + if k not in self.REQUEST_ONLY_KEYS + ] + ) diff --git a/ultra_cache/decorator.py b/ultra_cache/decorator.py index a6fc6ca..38ae2cf 100644 --- a/ultra_cache/decorator.py +++ b/ultra_cache/decorator.py @@ -6,8 +6,9 @@ import anyio import anyio.to_thread - +from dict_hash import dict_hash from ultra_cache.build_cache_key import BuildCacheKey, DefaultBuildCacheKey +from ultra_cache.cache_control import CacheControl from ultra_cache.main import get_storage from ultra_cache.storage.base import BaseStorage from fastapi import Request, Response @@ -28,6 +29,13 @@ def _extract_param_of_type( return None +def _default_hash_fn(x: Any) -> str: + if isinstance(x, dict) or isinstance(x, list): + return "W/" + str(dict_hash(x, maximal_recursion=10)) + + return "W/" + str(hash(x)) + + def _extract( param: inspect.Parameter | None, args: tuple[S1, ...], kwargs: dict[str, S2] ) -> tuple[tuple[S1, ...], dict[str, S2]]: @@ -48,10 +56,19 @@ def _extract( return (args_copy, kwargs) +def _does_etag_match(etag: str, if_none_match: str | None) -> bool: + if if_none_match is not None and ( + if_none_match == "*" or any(etag == x.strip() for x in if_none_match.split(",")) + ): + return True + return False + + def cache( ttl: int | float | None = None, build_cache_key: BuildCacheKey = DefaultBuildCacheKey(), storage: BaseStorage | None = None, + hash_fn: Callable[[Any], str] = _default_hash_fn, ): def _wrapper( func: Callable[P, Union[R, Coroutine[R, Any, Any]]], @@ -84,14 +101,10 @@ async def _decorator(*args: P.args, **kwargs: P.kwargs): request: Request = kwargs.get(request_param.name) response: Response = kwargs.get(response_param.name) - cache_control = request.headers.get("Cache-Control", None) - - no_cache = False - no_store = False - - if cache_control: - no_cache = "no-cache" in cache_control.lower() - no_store = "no-store" in cache_control.lower() + cache_control = CacheControl.from_string( + request.headers.get("cache-control", None) + ) + if_none_match = request.headers.get("if-none-match", None) args_for_key, kwargs_for_key = _extract( response_param, *(_extract(request_param, args, kwargs)) @@ -102,14 +115,26 @@ async def _decorator(*args: P.args, **kwargs: P.kwargs): storage = get_storage() cached = None - if not no_cache: + if not cache_control.no_cache: cached = await storage.get(key) + if ttl: + cache_control.setdefault("max-age", ttl) + + response.headers["Cache-Control"] = cache_control.to_response_header() + if cached is not None: response.headers["X-Cache"] = "HIT" - return cached + response.headers["ETag"] = hash_fn(cached) - response.headers["X-Cache"] = "MISS" + if request.method in ["HEAD", "GET"]: + if _does_etag_match(response.headers["ETag"], if_none_match): + response.status_code = 304 + return + + return cached + else: + response.headers["X-Cache"] = "MISS" if original_request_param is None: kwargs.pop("request") @@ -122,8 +147,16 @@ async def _decorator(*args: P.args, **kwargs: P.kwargs): else: output = await anyio.to_thread.run_sync(partial(func, *args, **kwargs)) - if not no_store: - await storage.save(key=key, value=output, ttl=ttl) + response.headers["ETag"] = hash_fn(output) + if request.method in ["HEAD", "GET"]: + if _does_etag_match(response.headers["ETag"], if_none_match): + response.status_code = 304 + return + + if not cache_control.no_store: + await storage.save( + key=key, value=output, ttl=cache_control.max_age or ttl + ) return output diff --git a/ultra_cache/storage/base.py b/ultra_cache/storage/base.py index 5fd3b2b..d285889 100644 --- a/ultra_cache/storage/base.py +++ b/ultra_cache/storage/base.py @@ -11,3 +11,6 @@ async def save(self, key: K, value: V, ttl: int | float | None = None) -> None: @abstractmethod async def get(self, key: K) -> V | None: ... + + @abstractmethod + async def clear(self) -> None: ... diff --git a/ultra_cache/storage/inmemory.py b/ultra_cache/storage/inmemory.py index 16cafc1..a9a1171 100644 --- a/ultra_cache/storage/inmemory.py +++ b/ultra_cache/storage/inmemory.py @@ -45,3 +45,6 @@ async def get(self, key: K) -> V | None: return None return item.value + + async def clear(self) -> None: + self.storage = {}