Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support http cache headers, add fastapi tests #9

Merged
merged 2 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 46 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ readme = "README.md"
python = "^3.12"
fastapi = "^0.111.0"
pydantic = "^2.7.1"
dict-hash = "^1.1.37"


[tool.poetry.group.dev.dependencies]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def key_builder():


def sample_request():
return Request({"type": "http", "headers": {}})
return Request({"type": "http", "headers": {}, "method": "GET"})


def sample_response():
Expand Down
104 changes: 102 additions & 2 deletions tests/test_e2e.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,123 @@
import asyncio
from fastapi.testclient import TestClient
from .utils import app
import pytest

client = TestClient(app)
from ultra_cache.decorator import _default_hash_fn
from . import utils


client = TestClient(utils.app)


@pytest.fixture(autouse=True, scope="function")
def reset_cache():
try:
loop = asyncio.get_event_loop()
except: # noqa: E722
loop = asyncio.new_event_loop()
loop.run_until_complete(utils.storage.clear())
yield
loop.close()


# TODO: Reset cache between tests
def test_cache_decorator():
response = client.get("/items/1")
assert response.status_code == 200
assert response.json() == {"item_id": 1}
assert response.headers.get("X-Cache") == "MISS"
assert response.headers.get("Cache-Control", "") == ""
etag_1 = response.headers.get("ETag")
assert etag_1 is not None

# Test cache hit
response = client.get("/items/1")
assert response.status_code == 200
assert response.json() == {"item_id": 1}
assert response.headers.get("X-Cache") == "HIT"
assert response.headers.get("Cache-Control", "") == ""
etag_2 = response.headers.get("ETag")
assert etag_2 is not None
assert etag_1 == etag_2

# Test cache miss
response = client.get("/items/2")
assert response.status_code == 200
assert response.json() == {"item_id": 2}
assert response.headers.get("X-Cache") == "MISS"
assert response.headers.get("ETag") != etag_1
assert response.headers.get("Cache-Control", "") == ""


def test_cache_with_maxage():
response = client.get("/items/1", headers={"Cache-Control": "max-age=10"})
assert response.status_code == 200
assert response.json() == {"item_id": 1}
assert response.headers.get("X-Cache") == "MISS"
assert response.headers.get("Cache-Control", "") == "max-age=10"


def test_cache_with_if_none_match_hit():
response = client.get(
"/items/1",
)
assert response.status_code == 200
assert response.json() == {"item_id": 1}
assert response.headers.get("X-Cache") == "MISS"
assert response.headers.get("Cache-Control", "") == ""

etag = response.headers.get("ETag")
assert etag is not None
assert etag == _default_hash_fn(response.json())

# run again with If-None-Match
response = client.get("/items/1", headers={"If-None-Match": etag})

assert response.status_code == 304
assert response.headers.get("X-Cache") == "HIT"
assert response.headers.get("Cache-Control", "") == ""
assert response.headers.get("ETag") == etag


def test_cache_with_if_none_match_hit_star():
response = client.get(
"/items/1",
)
assert response.status_code == 200
assert response.json() == {"item_id": 1}
assert response.headers.get("X-Cache") == "MISS"
assert response.headers.get("Cache-Control", "") == ""

etag = response.headers.get("ETag")
assert etag is not None
assert etag == _default_hash_fn(response.json())

# run again with If-None-Match
response = client.get("/items/1", headers={"If-None-Match": "*"})

assert response.status_code == 304
assert response.headers.get("X-Cache") == "HIT"
assert response.headers.get("Cache-Control", "") == ""
assert response.headers.get("ETag") == etag


def test_cache_with_if_none_match_miss():
response = client.get(
"/items/1",
)
assert response.status_code == 200
assert response.json() == {"item_id": 1}
assert response.headers.get("X-Cache") == "MISS"
assert response.headers.get("Cache-Control", "") == ""

etag = response.headers.get("ETag")
assert etag is not None
assert etag == _default_hash_fn(response.json())

# run again with If-None-Match
response = client.get("/items/1", headers={"If-None-Match": "W/123"})

assert response.status_code == 200
assert response.headers.get("X-Cache") == "HIT"
assert response.headers.get("Cache-Control", "") == ""
assert response.headers.get("ETag") == etag
52 changes: 52 additions & 0 deletions ultra_cache/cache_control.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from typing import Self


class CacheControl:
REQUEST_ONLY_KEYS = ["max-stale", "min-fresh", "only-if-cached"]

def __init__(self, parts: dict[str, str]) -> None:
self.parts: dict[str, str | None] = parts

def set(self, key: str, value: str) -> None:
self.parts[key] = value

def get(self, key: str) -> str | None:
return self.parts.get(key, None)

def setdefault(self, key: str, value: str) -> None:
self.parts.setdefault(key, value)

@classmethod
def from_string(cls, cache_control: str | None) -> Self:
if cache_control is None:
return cls({})
return cls(
{
x.split("=")[0].strip(): x.split("=")[1].strip() if "=" in x else None
for x in cache_control.lower().split(",")
}
)

@property
def max_age(self) -> int | None:
value = self.parts.get("max-age", None)
if value is None:
return None
return int(value)

@property
def no_cache(self) -> bool:
return "no-cache" in self.parts

@property
def no_store(self) -> bool:
return "no-store" in self.parts

def to_response_header(self) -> str:
return ", ".join(
[
f"{k}={v}"
for k, v in self.parts.items()
if k not in self.REQUEST_ONLY_KEYS
]
)
61 changes: 47 additions & 14 deletions ultra_cache/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

import anyio
import anyio.to_thread

from dict_hash import dict_hash
from ultra_cache.build_cache_key import BuildCacheKey, DefaultBuildCacheKey
from ultra_cache.cache_control import CacheControl
from ultra_cache.main import get_storage
from ultra_cache.storage.base import BaseStorage
from fastapi import Request, Response
Expand All @@ -28,6 +29,13 @@ def _extract_param_of_type(
return None


def _default_hash_fn(x: Any) -> str:
if isinstance(x, dict) or isinstance(x, list):
return "W/" + str(dict_hash(x, maximal_recursion=10))

return "W/" + str(hash(x))


def _extract(
param: inspect.Parameter | None, args: tuple[S1, ...], kwargs: dict[str, S2]
) -> tuple[tuple[S1, ...], dict[str, S2]]:
Expand All @@ -48,10 +56,19 @@ def _extract(
return (args_copy, kwargs)


def _does_etag_match(etag: str, if_none_match: str | None) -> bool:
if if_none_match is not None and (
if_none_match == "*" or any(etag == x.strip() for x in if_none_match.split(","))
):
return True
return False


def cache(
ttl: int | float | None = None,
build_cache_key: BuildCacheKey = DefaultBuildCacheKey(),
storage: BaseStorage | None = None,
hash_fn: Callable[[Any], str] = _default_hash_fn,
):
def _wrapper(
func: Callable[P, Union[R, Coroutine[R, Any, Any]]],
Expand Down Expand Up @@ -84,14 +101,10 @@ async def _decorator(*args: P.args, **kwargs: P.kwargs):
request: Request = kwargs.get(request_param.name)
response: Response = kwargs.get(response_param.name)

cache_control = request.headers.get("Cache-Control", None)

no_cache = False
no_store = False

if cache_control:
no_cache = "no-cache" in cache_control.lower()
no_store = "no-store" in cache_control.lower()
cache_control = CacheControl.from_string(
request.headers.get("cache-control", None)
)
if_none_match = request.headers.get("if-none-match", None)

args_for_key, kwargs_for_key = _extract(
response_param, *(_extract(request_param, args, kwargs))
Expand All @@ -102,14 +115,26 @@ async def _decorator(*args: P.args, **kwargs: P.kwargs):
storage = get_storage()

cached = None
if not no_cache:
if not cache_control.no_cache:
cached = await storage.get(key)

if ttl:
cache_control.setdefault("max-age", ttl)

response.headers["Cache-Control"] = cache_control.to_response_header()

if cached is not None:
response.headers["X-Cache"] = "HIT"
return cached
response.headers["ETag"] = hash_fn(cached)

response.headers["X-Cache"] = "MISS"
if request.method in ["HEAD", "GET"]:
if _does_etag_match(response.headers["ETag"], if_none_match):
response.status_code = 304
return

return cached
else:
response.headers["X-Cache"] = "MISS"

if original_request_param is None:
kwargs.pop("request")
Expand All @@ -122,8 +147,16 @@ async def _decorator(*args: P.args, **kwargs: P.kwargs):
else:
output = await anyio.to_thread.run_sync(partial(func, *args, **kwargs))

if not no_store:
await storage.save(key=key, value=output, ttl=ttl)
response.headers["ETag"] = hash_fn(output)
if request.method in ["HEAD", "GET"]:
if _does_etag_match(response.headers["ETag"], if_none_match):
response.status_code = 304
return

if not cache_control.no_store:
await storage.save(
key=key, value=output, ttl=cache_control.max_age or ttl
)

return output

Expand Down
3 changes: 3 additions & 0 deletions ultra_cache/storage/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,6 @@ async def save(self, key: K, value: V, ttl: int | float | None = None) -> None:

@abstractmethod
async def get(self, key: K) -> V | None: ...

@abstractmethod
async def clear(self) -> None: ...
3 changes: 3 additions & 0 deletions ultra_cache/storage/inmemory.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,6 @@ async def get(self, key: K) -> V | None:
return None

return item.value

async def clear(self) -> None:
self.storage = {}