diff --git a/benchmarks/bench_json_schema.py b/benchmarks/bench_json_schema.py
index c85982c7..6e5b9aa8 100644
--- a/benchmarks/bench_json_schema.py
+++ b/benchmarks/bench_json_schema.py
@@ -1,4 +1,3 @@
-from outlines_core.caching import cache_disabled
 from outlines_core.fsm.guide import RegexGuide
 from outlines_core.fsm.json_schema import build_regex_from_schema
 
@@ -71,11 +70,9 @@ def setup(self, schema_name):
         self.schema = schemas[schema_name]
         ensure_numba_compiled(self.tokenizer)
 
-    @cache_disabled()
     def time_json_schema_to_regex(self, schema_name):
         build_regex_from_schema(self.schema)
 
-    @cache_disabled()
     def time_json_schema_to_fsm(self, schema_name):
         regex = build_regex_from_schema(self.schema)
         RegexGuide(regex, self.tokenizer)
diff --git a/benchmarks/bench_numba_compile.py b/benchmarks/bench_numba_compile.py
index 35edc953..ae62702a 100644
--- a/benchmarks/bench_numba_compile.py
+++ b/benchmarks/bench_numba_compile.py
@@ -3,7 +3,6 @@
 import interegular
 import numba
 
-from outlines_core.caching import cache_disabled
 from outlines_core.fsm import regex
 
 from .common import setup_tokenizer
@@ -29,6 +28,5 @@ def mock_njit(*args, **kwargs):
     def teardown(self):
         numba.njit = self.original_njit
 
-    @cache_disabled()
     def time_compile_numba(self):
         self.regex.create_fsm_index_tokenizer(self.regex_pattern, self.tokenizer)
diff --git a/benchmarks/bench_regex_guide.py b/benchmarks/bench_regex_guide.py
index eeb1f983..5d505a48 100644
--- a/benchmarks/bench_regex_guide.py
+++ b/benchmarks/bench_regex_guide.py
@@ -1,4 +1,3 @@
-from outlines_core.caching import cache_disabled
 from outlines_core.fsm.guide import RegexGuide
 
 from .common import ensure_numba_compiled, setup_tokenizer
@@ -24,7 +23,6 @@ def setup(self, pattern_name):
         ensure_numba_compiled(self.tokenizer)
         self.pattern = regex_samples[pattern_name]
 
-    @cache_disabled()
     def time_regex_to_guide(self, pattern_name):
         RegexGuide(self.pattern, self.tokenizer)
 
@@ -37,6 +35,5 @@ def setup(self, pattern_name):
         ensure_numba_compiled(self.tokenizer)
         self.pattern = regex_samples[pattern_name]
 
-    @cache_disabled()
     def peakmem_regex_to_guide(self, pattern_name):
         RegexGuide(self.pattern, self.tokenizer)
diff --git a/src/outlines_core/__init__.py b/src/outlines_core/__init__.py
index 1a322f2b..ed2d5a9c 100644
--- a/src/outlines_core/__init__.py
+++ b/src/outlines_core/__init__.py
@@ -1,9 +1,4 @@
 """Outlines is a Generative Model Programming Framework."""
 import outlines_core.models
-from outlines_core.caching import clear_cache, disable_cache, get_cache
 
-__all__ = [
-    "clear_cache",
-    "disable_cache",
-    "get_cache",
-]
+__all__ = ["models"]
diff --git a/src/outlines_core/caching.py b/src/outlines_core/caching.py
deleted file mode 100644
index 92a08415..00000000
--- a/src/outlines_core/caching.py
+++ /dev/null
@@ -1,181 +0,0 @@
-import asyncio
-import contextlib
-import functools
-import os
-from typing import Callable, Optional
-
-import cloudpickle
-from diskcache import Cache, Disk
-from diskcache.core import ENOVAL, UNKNOWN, args_to_key, full_name
-
-_caching_enabled = True
-
-
-class CloudpickleDisk(Disk):
-    def __init__(self, directory, compress_level=1, **kwargs):
-        self.compress_level = compress_level
-        super().__init__(directory, **kwargs)
-
-    def put(self, key):
-        data = cloudpickle.dumps(key)
-        return super().put(data)
-
-    def get(self, key, raw):
-        data = super().get(key, raw)
-        return cloudpickle.loads(data)
-
-    def store(self, value, read, key=UNKNOWN):
-        if not read:
-            value = cloudpickle.dumps(value)
-        return super().store(value, read, key=key)
-
-    def fetch(self, mode, filename, value, read):
-        data = super().fetch(mode, filename, value, read)
-        if not read:
-            data = cloudpickle.loads(data)
-        return data
-
-
-@functools.lru_cache(1)
-def get_cache():
-    """Get the context object that contains previously-computed return values.
-
-    The cache is used to avoid unnecessary computations and API calls, which can
-    be long and expensive for large models.
-
-    The cache directory defaults to `HOMEDIR/.cache/outlines`, but this choice
-    can be overridden by the user by setting the value of the `OUTLINES_CACHE_DIR`
-    environment variable.
-
-    """
-    from outlines_core._version import (
-        __version__ as outlines_core_version,  # type: ignore
-    )
-
-    home_dir = os.path.expanduser("~")
-    cache_dir = os.environ.get("OUTLINES_CACHE_DIR", f"{home_dir}/.cache/outlines-core")
-    memory = Cache(
-        cache_dir,
-        eviction_policy="none",
-        cull_limit=0,
-        disk=CloudpickleDisk,
-    )
-
-    # ensure if version upgrade occurs, old cache is pruned
-    if outlines_core_version != memory.get("__version__"):
-        memory.clear()
-    memory["__version__"] = outlines_core_version
-
-    return memory
-
-
-def cache(expire: Optional[float] = None, typed=False, ignore=()):
-    """Caching decorator for memoizing function calls.
-
-    The cache key is created based on the values returned by the key_function callable
-    if provided or based on the arguments of the decorated function directly otherwise
-
-    This is based on `diskcache`'s `memoize`.
-
-    Parameters
-    ----------
-    expire
-        Seconds until arguments expire.
-    typed
-        Cache different types separately.
-    ignore
-        Positional or keyword arguments to ignore.
-
-    Returns
-    -------
-    A decorator function that can be applied to other functions.
-    """
-
-    def decorator(cached_function: Callable):
-        memory = get_cache()
-
-        base = (full_name(cached_function),)
-
-        if asyncio.iscoroutinefunction(cached_function):
-
-            async def wrapper(*args, **kwargs):
-                if not _caching_enabled:
-                    return await cached_function(*args, **kwargs)
-
-                cache_key = wrapper.__cache_key__(*args, **kwargs)
-                result = wrapper.__memory__.get(cache_key, default=ENOVAL, retry=True)
-
-                if result is ENOVAL:
-                    result = await cached_function(*args, **kwargs)
-                    wrapper.__memory__.set(cache_key, result, expire, retry=True)
-
-                return result
-
-        else:
-
-            def wrapper(*args, **kwargs):
-                if not _caching_enabled:
-                    return cached_function(*args, **kwargs)
-
-                cache_key = wrapper.__cache_key__(*args, **kwargs)
-                result = wrapper.__memory__.get(cache_key, default=ENOVAL, retry=True)
-
-                if result is ENOVAL:
-                    result = cached_function(*args, **kwargs)
-                    wrapper.__memory__.set(cache_key, result, expire, retry=True)
-
-                return result
-
-        def __cache_key__(*args, **kwargs):
-            """Make key for cache given function arguments."""
-            return args_to_key(base, args, kwargs, typed, ignore)
-
-        wrapper.__cache_key__ = __cache_key__  # type: ignore
-        wrapper.__memory__ = memory  # type: ignore
-        wrapper.__wrapped__ = cached_function  # type: ignore
-
-        return wrapper
-
-    return decorator
-
-
-def disable_cache():
-    """Disable the cache for this session.
-
-    Generative models output different results each time they are called when
-    sampling. This can be a desirable property for some workflows, in which case
-    one can call `outlines.call.disable` to disable the cache for the session.
-
-    This function does not delete the cache, call `outlines.cache.clear`
-    instead. It also does not overwrite the cache with the values returned
-    during the session.
-
-    Example
-    -------
-
-    `outlines.cache.disable` should be called right after importing outlines:
-
-    >>> import outlines.caching as cache
-    >>> cache.disable_cache()
-
-    """
-    global _caching_enabled
-    _caching_enabled = False
-
-
-def clear_cache():
-    """Erase the cache completely."""
-    memory = get_cache()
-    memory.clear()
-
-
-@contextlib.contextmanager
-def cache_disabled():
-    # outlines.caching._caching_enabled
-    global _caching_enabled
-    original_state = _caching_enabled
-    _caching_enabled = False
-    try:
-        yield
-    finally:
-        _caching_enabled = original_state
diff --git a/src/outlines_core/fsm/guide.py b/src/outlines_core/fsm/guide.py
index 3773505d..5bfdf81b 100644
--- a/src/outlines_core/fsm/guide.py
+++ b/src/outlines_core/fsm/guide.py
@@ -14,7 +14,6 @@
 import interegular
 import torch
 
-from outlines_core.caching import cache
 from outlines_core.fsm.regex import (
     create_fsm_index_tokenizer,
     make_byte_level_fsm,
@@ -114,7 +113,6 @@ def copy(self):
         return self
 
 
-@cache()
 def create_states_mapping(
     regex_string: str,
     tokenizer: "Tokenizer",
diff --git a/tests/test_cache.py b/tests/test_cache.py
deleted file mode 100644
index 766d97ad..00000000
--- a/tests/test_cache.py
+++ /dev/null
@@ -1,190 +0,0 @@
-import os
-import tempfile
-import unittest
-
-import diskcache
-import pytest
-
-
-@pytest.fixture
-def refresh_environment():
-    """Refresh the test environment.
-
-    This deletes any reference to `outlines` in the modules dictionary and unsets the
-    `OUTLINES_CACHE_DIR` environment variable if set. This is necessary because we
-    are using a module variable to hold the cache.
-
-    """
-    import sys
-
-    for key in list(sys.modules.keys()):
-        if "outlines" in key:
-            del sys.modules[key]
-
-    try:
-        del os.environ["OUTLINES_CACHE_DIR"]
-    except KeyError:
-        pass
-
-
-@pytest.fixture
-def test_cache(refresh_environment):
-    """Initialize a temporary cache and delete it after the test has run."""
-    with tempfile.TemporaryDirectory() as tempdir:
-        os.environ["OUTLINES_CACHE_DIR"] = tempdir
-        import outlines_core
-
-        memory = outlines_core.get_cache()
-        assert memory.directory == tempdir
-
-        yield outlines_core.caching.cache()
-
-        memory.clear()
-
-
-def test_get_cache(test_cache):
-    import outlines_core
-
-    memory = outlines_core.get_cache()
-    assert isinstance(memory, diskcache.Cache)
-
-    # If the cache is enabled then the size
-    # of `store` should not increase the
-    # second time `f` is called.
-    store = list()
-
-    @test_cache
-    def f(x):
-        store.append(1)
-        return x
-
-    f(1)
-    store_size = len(store)
-
-    f(1)
-    assert len(store) == store_size
-
-    f(2)
-    assert len(store) == store_size + 1
-
-
-def test_disable_cache(test_cache):
-    """Make sure that we can disable the cache."""
-    import outlines_core
-
-    outlines_core.disable_cache()
-
-    # If the cache is disabled then the size
-    # of `store` should increase every time
-    # `f` is called.
-    store = list()
-
-    @test_cache
-    def f(x):
-        store.append(1)
-        return x
-
-    f(1)
-    store_size = len(store)
-    f(1)
-    assert len(store) == store_size + 1
-
-
-def test_clear_cache(test_cache):
-    """Make sure that we can clear the cache."""
-    import outlines_core
-
-    store = list()
-
-    @test_cache
-    def f(x):
-        store.append(1)
-        return x
-
-    # The size of `store` does not increase since
-    # `f` is cached after the first run.
-    f(1)
-    store_size = len(store)
-    f(1)
-    assert len(store) == store_size
-
-    # The size of `store` should increase if we call `f`
-    # after clearing the cache.
-    outlines_core.clear_cache()
-    f(1)
-    assert len(store) == store_size + 1
-
-
-def test_version_upgrade_cache_invalidate(test_cache, mocker):
-    """Ensure we can change the signature of a cached function if we upgrade the version"""
-
-    import outlines_core.caching
-
-    def simulate_restart_outlines():
-        # clearing in-memory lru_cache which returns the diskcache in
-        # order to simulate a reload, we're not clearing the diskcache itself
-        outlines_core.caching.get_cache.cache_clear()
-
-    mocker.patch("outlines_core._version.__version__", new="0.0.0")
-    simulate_restart_outlines()
-
-    # initialize cache with signature of Tuple-of-3
-    @test_cache
-    def foo():
-        return (1, 2, 3)
-
-    a, b, c = foo()
-
-    # "restart" outlines without upgrading version
-    simulate_restart_outlines()
-
-    # change signature to Tuple-of-2
-    @test_cache
-    def foo():
-        return (1, 2)
-
-    # assert without version upgrade, old, bad cache is used
-    with pytest.raises(ValueError):
-        a, b = foo()
-
-    # "restart" outlines WITH version upgrade
-    mocker.patch("outlines_core._version.__version__", new="0.0.1")
-    simulate_restart_outlines()
-
-    # change signature to Tuple-of-2
-    @test_cache
-    def foo():
-        return (1, 2)
-
-    # assert with version upgrade, old cache is invalidated and new cache is used
-    a, b = foo()
-
-
-def test_cache_disabled_decorator(test_cache):
-    """Ensure cache can be disabled in a local scope"""
-
-    from outlines_core.caching import cache_disabled
-
-    mock = unittest.mock.MagicMock()
-
-    @test_cache
-    def fn():
-        mock()
-        return 1
-
-    # first call isn't cached
-    fn()
-    assert mock.call_count == 1
-
-    # second call doesn't run fn, uses cache
-    fn()
-    assert mock.call_count == 1
-
-    # cache_disabled decorator disables cache within scope
-    with cache_disabled():
-        fn()
-    assert mock.call_count == 2  # called once in cache_disabled scope
-
-    # scope has exited, cache is enabled again
-    fn()
-    assert mock.call_count == 2
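
Note: with @cache() removed from create_states_mapping, building a RegexGuide recompiles the token index every time, even for a pattern that has been seen before. Callers that relied on the removed disk cache can memoize at the call site instead. A minimal sketch, assuming a single long-lived tokenizer and a hypothetical make_guide_factory helper (not part of this diff), built on functools.lru_cache:

import functools

from outlines_core.fsm.guide import RegexGuide


def make_guide_factory(tokenizer, maxsize=128):
    """Return a callable that reuses RegexGuide instances per regex pattern.

    Hypothetical helper: in-process memoization only; nothing is persisted
    to disk the way the removed outlines_core.caching module did.
    """

    @functools.lru_cache(maxsize=maxsize)
    def get_guide(regex_string: str) -> RegexGuide:
        # The tokenizer is captured by closure so the lru_cache key
        # stays a plain, hashable string.
        return RegexGuide(regex_string, tokenizer)

    return get_guide

Repeated calls with the same pattern then return the already-compiled guide instead of rebuilding the FSM index.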