Skip to content

Commit

Permalink
Add coro equivalent for taskcache
Browse files Browse the repository at this point in the history
  • Loading branch information
mikeshardmind committed Nov 20, 2024
1 parent fd36783 commit 924531f
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 5 deletions.
135 changes: 135 additions & 0 deletions async_utils/corofunc_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Copyright 2020-present Michael Hall
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import asyncio
from collections.abc import Callable, Coroutine, Hashable
from functools import partial, wraps
from typing import Any, ParamSpec, TypeVar

from ._cpython_stuff import make_key

__all__ = ("corocache", "lrucorocache")


P = ParamSpec("P")
T = TypeVar("T")
K = TypeVar("K")
V = TypeVar("V")


type CoroFunc[**P, T] = Callable[P, Coroutine[Any, Any, T]]


class LRU[K, V]:
def __init__(self, maxsize: int, /):
self.cache: dict[K, V] = {}
self.maxsize = maxsize

def get(self, key: K, default: T, /) -> V | T:
if key not in self.cache:
return default
self.cache[key] = self.cache.pop(key)
return self.cache[key]

def __getitem__(self, key: K, /) -> V:
self.cache[key] = self.cache.pop(key)
return self.cache[key]

def __setitem__(self, key: K, value: V, /):
self.cache[key] = value
if len(self.cache) > self.maxsize:
self.cache.pop(next(iter(self.cache)))

def remove(self, key: K) -> None:
self.cache.pop(key, None)


def corocache(
ttl: float | None = None,
) -> Callable[[CoroFunc[P, T]], CoroFunc[P, T]]:
"""Decorator to cache coroutine functions.
This is less powerful than the version in task_cache.py but may work better for
some cases where typing of libraries this interacts with is too restrictive.
Note: This uses the args and kwargs of the original coroutine function as a cache key.
This includes instances (self) when wrapping methods.
Consider not wrapping instance methods, but what those methods call when feasible in cases where this may matter.
The ordering of args and kwargs matters."""

def wrapper(coro: CoroFunc[P, T]) -> CoroFunc[P, T]:
internal_cache: dict[Hashable, asyncio.Task[T]] = {}

async def wrapped(*args: P.args, **kwargs: P.kwargs) -> T:
key = make_key(args, kwargs)
try:
return await internal_cache[key]
except KeyError:
internal_cache[key] = task = asyncio.create_task(coro(*args, **kwargs))
if ttl is not None:
# This results in internal_cache.pop(key, task) later
# while avoiding a late binding issue with a lambda instead
call_after_ttl = partial(
asyncio.get_running_loop().call_later,
ttl,
internal_cache.pop,
key,
)
task.add_done_callback(call_after_ttl)
return await task

return wrapped

return wrapper


def _lru_evict(ttl: float, cache: LRU[Hashable, Any], key: Hashable, _ignored_task: object) -> None:
asyncio.get_running_loop().call_later(ttl, cache.remove, key)


def lrucorocache(ttl: float | None = None, maxsize: int = 1024) -> Callable[[CoroFunc[P, T]], CoroFunc[P, T]]:
"""Decorator to cache coroutine functions.
This is less powerful than the version in task_cache.py but may work better for
some cases where typing of libraries this interacts with is too restrictive.
Note: This uses the args and kwargs of the original coroutine function as a cache key.
This includes instances (self) when wrapping methods.
Consider not wrapping instance methods, but what those methods call when feasible in cases where this may matter.
The ordering of args and kwargs matters.
tasks are evicted by LRU and ttl.
"""

def wrapper(coro: CoroFunc[P, T]) -> CoroFunc[P, T]:
internal_cache: LRU[Hashable, asyncio.Task[T]] = LRU(maxsize)

@wraps(coro)
async def wrapped(*args: P.args, **kwargs: P.kwargs) -> T:
key = make_key(args, kwargs)
try:
return await internal_cache[key]
except KeyError:
internal_cache[key] = task = asyncio.create_task(coro(*args, **kwargs))
if ttl is not None:
task.add_done_callback(partial(_lru_evict, ttl, internal_cache, key))
return await task

return wrapped

return wrapper
10 changes: 5 additions & 5 deletions async_utils/task_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

import asyncio
from collections.abc import Callable, Coroutine, Hashable
from functools import partial
from typing import Any, Generic, ParamSpec, TypeVar
from functools import partial, wraps
from typing import Any, ParamSpec, TypeVar

from ._cpython_stuff import make_key

Expand All @@ -26,11 +26,9 @@

P = ParamSpec("P")
T = TypeVar("T")
K = TypeVar("K")
V = TypeVar("V")


class LRU(Generic[K, V]):
class LRU[K, V]:
def __init__(self, maxsize: int, /):
self.cache: dict[K, V] = {}
self.maxsize = maxsize
Expand Down Expand Up @@ -71,6 +69,7 @@ def taskcache(
def wrapper(coro: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, asyncio.Task[T]]:
internal_cache: dict[Hashable, asyncio.Task[T]] = {}

@wraps(coro, assigned=("__module__", "__name__", "__qualname__", "__doc__"))
def wrapped(*args: P.args, **kwargs: P.kwargs) -> asyncio.Task[T]:
key = make_key(args, kwargs)
try:
Expand Down Expand Up @@ -118,6 +117,7 @@ def lrutaskcache(
def wrapper(coro: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, asyncio.Task[T]]:
internal_cache: LRU[Hashable, asyncio.Task[T]] = LRU(maxsize)

@wraps(coro, assigned=("__module__", "__name__", "__qualname__", "__doc__"))
def wrapped(*args: P.args, **kwargs: P.kwargs) -> asyncio.Task[T]:
key = make_key(args, kwargs)
try:
Expand Down

0 comments on commit 924531f

Please sign in to comment.