From f45b1bab1e79338943ce2d1957435f3d7c95e63a Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Tue, 8 Aug 2023 17:31:52 -0400 Subject: [PATCH 1/2] x --- libs/langchain/langchain/storage/redis.py | 112 ++++++++++++++++++ .../integration_tests/storage/__init__.py | 0 .../integration_tests/storage/test_redis.py | 104 ++++++++++++++++ .../tests/unit_tests/storage/test_redis.py | 11 ++ 4 files changed, 227 insertions(+) create mode 100644 libs/langchain/langchain/storage/redis.py create mode 100644 libs/langchain/tests/integration_tests/storage/__init__.py create mode 100644 libs/langchain/tests/integration_tests/storage/test_redis.py create mode 100644 libs/langchain/tests/unit_tests/storage/test_redis.py diff --git a/libs/langchain/langchain/storage/redis.py b/libs/langchain/langchain/storage/redis.py new file mode 100644 index 0000000000000..be8a49c7fc74b --- /dev/null +++ b/libs/langchain/langchain/storage/redis.py @@ -0,0 +1,112 @@ +from typing import Any, Iterator, List, Optional, Sequence, Tuple + +from langchain.schema import BaseStore + + +class RedisStore(BaseStore[str, bytes]): + """BaseStore implementation using Redis as the underlying store. + + Examples: + Create a RedisStore instance and perform operations on it: + + .. code-block:: python + + # Instantiate the RedisStore with a Redis connection + from langchain.storage import RedisStore + from langchain.vectorstores.redis import get_client + + client = get_client('redis://localhost:6379') + redis_store = RedisStore(client) + + # Set values for keys + redis_store.mset([("key1", b"value1"), ("key2", b"value2")]) + + # Get values for keys + values = redis_store.mget(["key1", "key2"]) + # [b"value1", b"value2"] + + # Delete keys + redis_store.mdelete(["key1"]) + + # Iterate over keys + for key in redis_store.yield_keys(): + print(key) + """ + + def __init__( + self, client: Any, *, ttl: Optional[int] = None, namespace: Optional[str] = None + ) -> None: + """Initialize the RedisStore with a Redis connection. + + Args: + client: A Redis connection instance + ttl: time to expire keys in seconds if provided, + if None keys will never expire + namespace: if provided, all keys will be prefixed with this namespace + """ + try: + from redis import Redis + except ImportError as e: + raise ImportError( + "The RedisStore requires the redis library to be installed. " + "pip install redis" + ) from e + + if not isinstance(client, Redis): + raise TypeError( + f"Expected Redis client, got {type(client).__name__} instead." + ) + + self.client = client + + if not isinstance(ttl, int) and ttl is not None: + raise TypeError(f"Expected int or None, got {type(ttl)} instead.") + + self.ttl = ttl + self.namespace = namespace + self.namespace_delimiter = "/" + + def _get_prefixed_key(self, key: str) -> str: + """Get the key with the namespace prefix. + + Args: + key (str): The original key. + + Returns: + str: The key with the namespace prefix. + """ + if self.namespace: + return f"{self.namespace}{self.namespace_delimiter}{key}" + return key + + def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]: + """Get the values associated with the given keys.""" + return self.client.mget([self._get_prefixed_key(key) for key in keys]) + + def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None: + """Set the given key-value pairs.""" + pipe = self.client.pipeline() + + for key, value in key_value_pairs: + pipe.set(self._get_prefixed_key(key), value, ex=self.ttl) + pipe.execute() + + def mdelete(self, keys: Sequence[str]) -> None: + """Delete the given keys.""" + _keys = [self._get_prefixed_key(key) for key in keys] + self.client.delete(*_keys) + + def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]: + """Yield keys in the store.""" + if prefix: + pattern = self._get_prefixed_key(prefix) + else: + pattern = self._get_prefixed_key("*") + scan_iter = self.client.scan_iter(match=pattern) + for key in scan_iter: + decoded_key = key.decode("utf-8") + if self.namespace: + relative_key = decoded_key[len(self.namespace) + 1 :] + yield relative_key + else: + yield decoded_key diff --git a/libs/langchain/tests/integration_tests/storage/__init__.py b/libs/langchain/tests/integration_tests/storage/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/libs/langchain/tests/integration_tests/storage/test_redis.py b/libs/langchain/tests/integration_tests/storage/test_redis.py new file mode 100644 index 0000000000000..cec22ae22624d --- /dev/null +++ b/libs/langchain/tests/integration_tests/storage/test_redis.py @@ -0,0 +1,104 @@ +"""Implement integration tests for Redis storage.""" +import os +import typing +import uuid +from typing import Any + +import pytest +import redis + +from langchain.storage.redis import RedisStore + +if typing.TYPE_CHECKING: + try: + from redis import Redis + except ImportError: + Redis = Any +else: + Redis = Any + + +pytest.importorskip("redis") + + +@pytest.fixture +def redis_client() -> Redis: + """Yield redis client.""" + # Using standard port, but protecting against accidental data loss + # by requiring a password. + # This fixture flushes the database! + # The only role of the password is to prevent users from accidentally + # deleting their data. + # The password should establish the identity of the server being. + port = 6379 + password = os.environ.get("REDIS_PASSWORD") or str(uuid.uuid4()) + password = None + client = redis.Redis(host="localhost", port=port, password=password, db=0) + try: + client.ping() + except redis.exceptions.ConnectionError: + pytest.skip( + "Redis server is not running or is not accessible. " + "Verify that credentials are correct. " + ) + # ATTENTION: This will delete all keys in the database! + client.flushdb() + return client + + +def test_mget(redis_client: Redis) -> None: + """Test mget method.""" + store = RedisStore(redis_client, ttl=None) + keys = ["key1", "key2"] + redis_client.mset({"key1": b"value1", "key2": b"value2"}) + result = store.mget(keys) + assert result == [b"value1", b"value2"] + + +def test_mset(redis_client: Redis) -> None: + """Test that multiple keys can be set.""" + store = RedisStore(redis_client, ttl=None) + key_value_pairs = [("key1", b"value1"), ("key2", b"value2")] + store.mset(key_value_pairs) + result = redis_client.mget(["key1", "key2"]) + assert result == [b"value1", b"value2"] + + +def test_mdelete(redis_client: Redis) -> None: + """Test that deletion works as expected.""" + store = RedisStore(redis_client, ttl=None) + keys = ["key1", "key2"] + redis_client.mset({"key1": b"value1", "key2": b"value2"}) + store.mdelete(keys) + result = redis_client.mget(keys) + assert result == [None, None] + + +def test_yield_keys(redis_client: Redis) -> None: + store = RedisStore(redis_client, ttl=None) + redis_client.mset({"key1": b"value1", "key2": b"value2"}) + assert sorted(store.yield_keys()) == ["key1", "key2"] + assert sorted(store.yield_keys(prefix="key*")) == ["key1", "key2"] + assert sorted(store.yield_keys(prefix="lang*")) == [] + + +def test_namespace(redis_client: Redis) -> None: + """Test that a namespace is prepended to all keys properly.""" + store = RedisStore(redis_client, ttl=None, namespace="meow") + key_value_pairs = [("key1", b"value1"), ("key2", b"value2")] + store.mset(key_value_pairs) + + assert sorted(redis_client.scan_iter("*")) == [ + b"meow/key1", + b"meow/key2", + ] + + store.mdelete(["key1"]) + + assert sorted(redis_client.scan_iter("*")) == [ + b"meow/key2", + ] + + assert list(store.yield_keys()) == ["key2"] + assert list(store.yield_keys(prefix="key*")) == ["key2"] + assert list(store.yield_keys(prefix="key1")) == [] diff --git a/libs/langchain/tests/unit_tests/storage/test_redis.py b/libs/langchain/tests/unit_tests/storage/test_redis.py new file mode 100644 index 0000000000000..ff882ed8df852 --- /dev/null +++ b/libs/langchain/tests/unit_tests/storage/test_redis.py @@ -0,0 +1,11 @@ +"""Light weight unit test that attempts to import RedisStore. + +The actual code is tested in integration tests. + +This test is intended to catch errors in the import process. +""" + + +def test_import_storage() -> None: + """Attempt to import storage modules.""" + from langchain.storage.redis import RedisStore # noqa From 4c45a8c48e44915f8fc9ce1d12982248ff725b7e Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 9 Aug 2023 13:04:35 -0400 Subject: [PATCH 2/2] x --- libs/langchain/langchain/storage/redis.py | 9 ++++++--- .../tests/integration_tests/storage/test_redis.py | 7 ++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/libs/langchain/langchain/storage/redis.py b/libs/langchain/langchain/storage/redis.py index be8a49c7fc74b..900144aa2c161 100644 --- a/libs/langchain/langchain/storage/redis.py +++ b/libs/langchain/langchain/storage/redis.py @@ -1,4 +1,4 @@ -from typing import Any, Iterator, List, Optional, Sequence, Tuple +from typing import Any, Iterator, List, Optional, Sequence, Tuple, cast from langchain.schema import BaseStore @@ -81,7 +81,10 @@ def _get_prefixed_key(self, key: str) -> str: def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]: """Get the values associated with the given keys.""" - return self.client.mget([self._get_prefixed_key(key) for key in keys]) + return cast( + List[Optional[bytes]], + self.client.mget([self._get_prefixed_key(key) for key in keys]), + ) def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None: """Set the given key-value pairs.""" @@ -102,7 +105,7 @@ def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]: pattern = self._get_prefixed_key(prefix) else: pattern = self._get_prefixed_key("*") - scan_iter = self.client.scan_iter(match=pattern) + scan_iter = cast(Iterator[bytes], self.client.scan_iter(match=pattern)) for key in scan_iter: decoded_key = key.decode("utf-8") if self.namespace: diff --git a/libs/langchain/tests/integration_tests/storage/test_redis.py b/libs/langchain/tests/integration_tests/storage/test_redis.py index cec22ae22624d..853028953f248 100644 --- a/libs/langchain/tests/integration_tests/storage/test_redis.py +++ b/libs/langchain/tests/integration_tests/storage/test_redis.py @@ -13,9 +13,11 @@ try: from redis import Redis except ImportError: - Redis = Any + # Ignoring mypy here to allow assignment of Any to Redis in the event + # that redis is not installed. + Redis = Any # type:ignore else: - Redis = Any + Redis = Any # type:ignore pytest.importorskip("redis") @@ -32,7 +34,6 @@ def redis_client() -> Redis: # The password should establish the identity of the server being. port = 6379 password = os.environ.get("REDIS_PASSWORD") or str(uuid.uuid4()) - password = None client = redis.Redis(host="localhost", port=port, password=password, db=0) try: client.ping()