Skip to content

Commit

Permalink
serializer as composition
Browse files Browse the repository at this point in the history
  • Loading branch information
Krukov committed Nov 1, 2024
1 parent 6d9c487 commit c68c4df
Show file tree
Hide file tree
Showing 33 changed files with 404 additions and 320 deletions.
35 changes: 22 additions & 13 deletions cashews/backends/diskcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,24 @@
from diskcache import Cache, FanoutCache

from cashews._typing import Key, Value
from cashews.serialize import SerializerMixin
from cashews.serialize import DEFAULT_SERIALIZER, Serializer
from cashews.utils import Bitarray

from .interface import NOT_EXIST, UNLIMITED, Backend


class _DiskCache(Backend):
class DiskCache(Backend):
def __init__(self, *args, directory=None, shards=8, **kwargs: Any) -> None:
serializer = kwargs.pop("serializer", DEFAULT_SERIALIZER)
self.__is_init = False
self._set_locks: dict[str, asyncio.Lock] = {}
self._sharded = shards > 1
if not self._sharded:
self._cache = Cache(directory=directory, **kwargs)
else:
self._cache = FanoutCache(directory=directory, shards=shards, **kwargs)
super().__init__(**kwargs)
super().__init__(serializer=serializer, **kwargs)
self._serializer: Serializer

async def init(self):
self.__is_init = True
Expand All @@ -46,6 +48,7 @@ async def set(
expire: float | None = None,
exist: bool | None = None,
) -> bool:
value = await self._serializer.encode(self, key=key, value=value, expire=expire)
future = self._run_in_executor(self._set, key, value, expire, exist)
if exist is not None:
# we should have async lock until value real set
Expand All @@ -69,25 +72,34 @@ async def set_raw(self, key: Key, value: Any, **kwargs: Any):
return self._cache.set(key, value, **kwargs)

async def get(self, key: Key, default: Value | None = None) -> Value:
return await self._run_in_executor(self._cache.get, key, default)
value = await self._run_in_executor(self._cache.get, key, default)
return await self._serializer.decode(self, key=key, value=value, default=default)

async def get_raw(self, key: Key) -> Value:
return self._cache.get(key)

async def get_many(self, *keys: Key, default: Value | None = None) -> tuple[Value]:
return await self._run_in_executor(self._get_many, keys, default)
async def get_many(self, *keys: Key, default: Value | None = None) -> tuple[Value | None, ...]:
if not keys:
return ()
values = await self._run_in_executor(self._get_many, keys, default)
values = await asyncio.gather(
*[self._serializer.decode(self, key=key, value=value, default=default) for key, value in zip(keys, values)]
)
return tuple(None if isinstance(value, Bitarray) else value for value in values)

def _get_many(self, keys: list[Key], default: Value | None = None):
values = []
for key in keys:
val = self._cache.get(key, default=default)
if isinstance(val, Bitarray):
val = None
values.append(val)
return values

async def set_many(self, pairs: Mapping[Key, Value], expire: float | None = None):
return await self._run_in_executor(self._set_many, pairs, expire)
_pairs = {}
for key, value in pairs.items():
value = await self._serializer.encode(self, key=key, value=value, expire=expire)
_pairs[key] = value
return await self._run_in_executor(self._set_many, _pairs, expire)

def _set_many(self, pairs: Mapping[Key, Value], expire: float | None = None):
for key, value in pairs.items():
Expand Down Expand Up @@ -215,6 +227,7 @@ async def is_locked(
return await self.exists(key)

async def unlock(self, key: Key, value: Value) -> bool:
value = await self._serializer.encode(self, key=key, value=value, expire=None)
return await self._run_in_executor(self._unlock, key, value)

def _unlock(self, key: Key, value: Value) -> bool:
Expand Down Expand Up @@ -269,7 +282,3 @@ async def set_pop(self, key: Key, count: int = 100) -> Iterable[str]:

async def get_keys_count(self) -> int:
return await self._run_in_executor(lambda: len(self._cache))


class DiskCache(SerializerMixin, _DiskCache):
pass
5 changes: 4 additions & 1 deletion cashews/backends/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from cashews.commands import ALL, Command
from cashews.exceptions import CacheBackendInteractionError, LockedError
from cashews.serialize import Serializer

if TYPE_CHECKING: # pragma: no cover
from cashews._typing import Default, Key, OnRemoveCallback, Value
Expand Down Expand Up @@ -226,8 +227,10 @@ def enable(self, *cmds: Command) -> None:


class Backend(ControlMixin, _BackendInterface, metaclass=ABCMeta):
def __init__(self, *args, **kwargs) -> None:
def __init__(self, *args, serializer: Serializer | None = None, **kwargs) -> None:
super().__init__()
self._id = uuid.uuid4().hex
self._serializer = serializer
self._on_remove_callbacks: list[OnRemoveCallback] = []

def on_remove_callback(self, callback: OnRemoveCallback) -> None:
Expand Down
22 changes: 13 additions & 9 deletions cashews/backends/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from copy import copy
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterable, Mapping, overload

from cashews.serialize import SerializerMixin
from cashews.utils import Bitarray, get_obj_size

from .interface import NOT_EXIST, UNLIMITED, Backend
Expand All @@ -22,7 +21,7 @@
_missed = object()


class _Memory(Backend):
class Memory(Backend):
"""
Inmemory backend lru with ttl
"""
Expand Down Expand Up @@ -74,17 +73,22 @@ async def set(
) -> bool:
if exist is not None and (key in self.store) is not exist:
return False
if self._serializer:
value = await self._serializer.encode(self, key=key, value=value, expire=expire)
self._set(key, value, expire)
return True

async def set_raw(self, key: Key, value: Value, **kwargs: Any) -> None:
self.store[key] = value
self.store[key] = (None, value)

async def get(self, key: Key, default: Value | None = None) -> Value:
return await self._get(key, default=default)

async def get_raw(self, key: Key) -> Value:
return self.store.get(key)
val = self.store.get(key)
if val:
return val[1]
return None

async def get_many(self, *keys: Key, default: Value | None = None) -> tuple[Value | None, ...]:
values = []
Expand All @@ -97,6 +101,8 @@ async def get_many(self, *keys: Key, default: Value | None = None) -> tuple[Valu

async def set_many(self, pairs: Mapping[Key, Value], expire: float | None = None):
for key, value in pairs.items():
if self._serializer:
value = await self._serializer.encode(self, key=key, value=value, expire=expire)
self._set(key, value, expire)

async def scan(self, pattern: str, batch_size: int = 100) -> AsyncIterator[Key]: # type: ignore
Expand Down Expand Up @@ -200,7 +206,9 @@ async def _get(self, key: Key, default: Default | None = None) -> Value | None:
if expire_at and expire_at < time.time():
await self._delete(key)
return default
return value
if not self._serializer:
return value
return await self._serializer.decode(self, key=key, value=value, default=default)

async def _key_exist(self, key: Key) -> bool:
return (await self._get(key, default=_missed)) is not _missed
Expand Down Expand Up @@ -279,7 +287,3 @@ async def close(self):
del self.__remove_expired_stop
self.__remove_expired_stop = None
self.__is_init = False


class Memory(SerializerMixin, _Memory):
pass
7 changes: 2 additions & 5 deletions cashews/backends/redis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from cashews.picklers import DEFAULT_PICKLE
from cashews.serialize import SerializerMixin

from .backend import _Redis

__all__ = ["Redis"]


class Redis(SerializerMixin, _Redis):
pickle_type = DEFAULT_PICKLE
class Redis(_Redis):
pass
17 changes: 11 additions & 6 deletions cashews/backends/redis/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from cashews._typing import Key, Value
from cashews.backends.interface import Backend
from cashews.serialize import DEFAULT_SERIALIZER, Serializer

from .client import Redis, SafePipeline, SafeRedis

Expand Down Expand Up @@ -76,7 +77,8 @@ def __init__(
self._kwargs = kwargs
self._address = address
self.__is_init = False
super().__init__()
super().__init__(serializer=kwargs.pop("serializer", None))
self._serializer: Serializer = self._serializer or DEFAULT_SERIALIZER

@property
def is_init(self) -> bool:
Expand Down Expand Up @@ -105,6 +107,7 @@ async def set(
expire: float | None = None,
exist=None,
) -> bool:
value = await self._serializer.encode(self, key=key, value=value, expire=expire)
nx = xx = False
if exist is True:
xx = True
Expand All @@ -118,6 +121,7 @@ async def set_many(self, pairs: Mapping[Key, Value], expire: float | None = None
px = int(expire * 1000) if expire else None
async with self._pipeline as pipe:
for key, value in pairs.items():
value = await self._serializer.encode(self, key=key, value=value, expire=expire)
await pipe.set(key, value, px=px)
await pipe.execute()

Expand Down Expand Up @@ -211,23 +215,24 @@ async def get_size(self, key: Key) -> int:

async def get(self, key: Key, default: Value | None = None) -> Value:
value = await self._client.get(key)
return self._transform_value(value, default)
return await self._transform_value(key, value, default)

async def get_many(self, *keys: Key, default: Value | None = None) -> tuple[Value | None, ...]:
if not keys:
return ()
values = await self._client.mget(*keys)
if values is None:
return tuple([default] * len(keys))
return tuple(self._transform_value(value, default) for value in values)
return tuple(
await asyncio.gather(*[self._transform_value(key, value, default) for key, value in zip(keys, values)])
)

@staticmethod
def _transform_value(value: bytes | None, default: Value | None):
async def _transform_value(self, key: Key, value: bytes | None, default: Value | None):
if value is None:
return default
if value.isdigit():
return int(value)
return value
return await self._serializer.decode(self, key=key, value=value, default=default)

async def incr(self, key: Key, value: int = 1, expire: float | None = None) -> int:
if not expire:
Expand Down
7 changes: 4 additions & 3 deletions cashews/backends/redis/client_side.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,12 @@ def __init__(
self._expire_for_recently_update = 5
self._listen_started = asyncio.Event()
self.__listen_stop = asyncio.Event()
super().__init__(*args, suppress=suppress, **kwargs)
kwargs["suppress"] = suppress
super().__init__(*args, **kwargs)

async def init(self):
self._listen_started = asyncio.Event()
self.__listen_stop = asyncio.Event()
self._listen_started.clear()
self.__listen_stop.clear()
await self._local_cache.init()
await self._recently_update.init()
await super().init()
Expand Down
2 changes: 2 additions & 0 deletions cashews/backends/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@ class TransactionBackend(Backend):
"_local_cache",
"_to_delete",
"__disable",
"_id",
]

def __init__(self, backend: Backend):
self._backend = backend
self._local_cache = Memory()
self._to_delete: set[Key] = set()
super().__init__()
self._id = backend._id

def _key_is_delete(self, key: Key) -> bool:
if key in self._to_delete:
Expand Down
12 changes: 1 addition & 11 deletions cashews/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,16 @@ class Command(Enum):
DELETE_MANY = "delete_many"
DELETE_MATCH = "delete_match"

EXIST = "exists"
EXISTS = "exists"
SCAN = "scan"
INCR = "incr"
EXPIRE = "expire"
GET_EXPIRE = "get_expire"
CLEAR = "clear"

SET_LOCK = "set_lock"
UNLOCK = "unlock"
IS_LOCKED = "is_locked"

GET_BITS = "get_bits"
INCR_BITS = "incr_bits"

SLICE_INCR = "slice_incr"

SET_ADD = "set_add"
SET_REMOVE = "set_remove"
SET_POP = "set_pop"

PING = "ping"
GET_SIZE = "get_size"
GET_KEYS_COUNT = "get_keys_count"
Expand Down
2 changes: 1 addition & 1 deletion cashews/contrib/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def set_callback(key: str, result: Any):
_data = None
else:
_key = calls[0][0]
_data = calls[0][1][0]["value"]
_data = calls[0][1]["value"]
_etag = await self._set_etag(_key, _data)
return self._response_etag(response, _etag, request_etag)

Expand Down
2 changes: 1 addition & 1 deletion cashews/decorators/cache/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, previous_level=0, unset_token=None):
self._previous_level = previous_level

def _set(self, key: Key, **kwargs: Any) -> None:
self._value.append((key, [kwargs]))
self._value.append((key, kwargs))

@property
def calls(self):
Expand Down
19 changes: 14 additions & 5 deletions cashews/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,20 @@ async def _middleware(call: AsyncCallable_T, cmd: Command, backend: Backend, *ar

def memory_limit(min_bytes: int = 0, max_bytes: int | None = None) -> Middleware:
async def _middleware(call: AsyncCallable_T, cmd: Command, backend: Backend, *args, **kwargs) -> Result_T | None:
if cmd != Command.SET:
return await call(*args, **kwargs)
value_size = get_obj_size(kwargs["value"])
if max_bytes and value_size > max_bytes or value_size < min_bytes:
return None
if cmd == Command.SET_MANY:
pairs = {}
for key, value in kwargs["pairs"].items():
value_size = get_obj_size(value)
if max_bytes and value_size > max_bytes or value_size < min_bytes:
continue
pairs[key] = value
if not pairs:
return None
kwargs["pairs"] = pairs
elif cmd == Command.SET:
value_size = get_obj_size(kwargs["value"])
if max_bytes and value_size > max_bytes or value_size < min_bytes:
return None
return await call(*args, **kwargs)

return _middleware
9 changes: 6 additions & 3 deletions cashews/key.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,13 @@ def generate_key_template(func: Callable, exclude_parameters: Container = ()) ->

class _Star:
def __getattr__(self, item):
return _Star()
return self

def __getitem__(self, item):
return _Star()
return self

def __call__(self, *args, **kwargs):
return "*"


def _check_key_params(key: KeyOrTemplate, func_params: Iterable[str]):
Expand Down Expand Up @@ -142,7 +145,7 @@ def _get_func_signature(func: Callable):


def _get_call_values(func: Callable, args: Args, kwargs: Kwargs):
if len(args) == 0:
if not args:
_kwargs = {**kwargs}
for name, parameter in _get_func_signature(func).parameters.items():
if parameter.kind != inspect.Parameter.VAR_KEYWORD and name in _kwargs:
Expand Down
Loading

0 comments on commit c68c4df

Please sign in to comment.