Skip to content

Commit

Permalink
Move embedding config to base
Browse files (browse the repository at this point in the history)
  • Loading branch information
hinthornw committed Nov 26, 2024
1 parent b1fe171 commit 9c25bd1
Show file tree
Hide file tree
Showing 7 changed files with 570 additions and 75 deletions.
29 changes: 20 additions & 9 deletions libs/checkpoint-postgres/langgraph/store/postgres/aio.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from collections.abc import AsyncIterator, Iterable, Sequence
from contextlib import asynccontextmanager
from typing import Any, Callable, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast

import orjson
from psycopg import AsyncConnection, AsyncCursor, AsyncPipeline, Capabilities
Expand All @@ -19,18 +19,22 @@
Result,
SearchItem,
SearchOp,
ensure_embeddings,
)
from langgraph.store.base.batch import AsyncBatchedBaseStore
from langgraph.store.postgres.base import (
BasePostgresStore,
EmbeddingConfig,
PoolConfig,
PostgresEmbeddingConfig,
Row,
_decode_ns_bytes,
_group_ops,
_row_to_item,
)

if TYPE_CHECKING:
from langchain_core.embeddings import Embeddings

logger = logging.getLogger(__name__)


Expand All @@ -51,7 +55,7 @@ def __init__(
deserializer: Optional[
Callable[[Union[bytes, orjson.Fragment]], dict[str, Any]]
] = None,
embedding: Optional[EmbeddingConfig] = None,
embedding: Optional[PostgresEmbeddingConfig] = None,
) -> None:
if isinstance(conn, AsyncConnectionPool) and pipe is not None:
raise ValueError(
Expand All @@ -65,6 +69,13 @@ def __init__(
self.loop = asyncio.get_running_loop()
self.supports_pipeline = Capabilities().has_pipeline()
self.embedding_config = embedding
if self.embedding_config:
self.embeddings: Optional[Embeddings] = ensure_embeddings(
self.embedding_config.get("embed"),
aembed=self.embedding_config.get("aembed"),
)
else:
self.embeddings = None

async def abatch(self, ops: Iterable[Op]) -> list[Result]:
grouped_ops, num_ops = _group_ops(ops)
Expand Down Expand Up @@ -142,7 +153,7 @@ async def _batch_put_ops(
) -> None:
queries, embedding_request = self._prepare_batch_PUT_queries(put_ops)
if embedding_request:
if self.embedding_config is None:
if self.embeddings is None:
# Should not get here since the embedding config is required
# to return an embedding_request above
raise ValueError(
Expand All @@ -152,7 +163,7 @@ async def _batch_put_ops(
)
query, txt_params = embedding_request
# Update the params to replace the raw text with the vectors
vectors = await self.embedding_config["embed"].aembed_documents(
vectors = await self.embeddings.aembed_documents(
[param[-1] for param in txt_params]
)
queries.extend(
Expand All @@ -173,8 +184,8 @@ async def _batch_search_ops(
) -> None:
queries, embedding_requests = self._prepare_batch_search_queries(search_ops)

if embedding_requests and self.embedding_config:
embeddings = await self.embedding_config["embed"].aembed_documents(
if embedding_requests and self.embeddings:
embeddings = await self.embeddings.aembed_documents(
[query for _, query in embedding_requests]
)
for (idx, _), embedding in zip(embedding_requests, embeddings):
Expand Down Expand Up @@ -261,7 +272,7 @@ async def from_conn_string(
*,
pipeline: bool = False,
pool_config: Optional[PoolConfig] = None,
embedding: Optional[EmbeddingConfig] = None,
embedding: Optional[PostgresEmbeddingConfig] = None,
) -> AsyncIterator["AsyncPostgresStore"]:
"""Create a new AsyncPostgresStore instance from a connection string.
Expand All @@ -271,7 +282,7 @@ async def from_conn_string(
pool_config (Optional[PoolConfig]): Configuration for the connection pool.
If provided, will create a connection pool and use it instead of a single connection.
This overrides the `pipeline` argument.
embedding (Optional[EmbeddingConfig]): The embedding config.
embedding (Optional[PostgresEmbeddingConfig]): The embedding config.
Returns:
AsyncPostgresStore: A new AsyncPostgresStore instance.
Expand Down
Loading

0 comments on commit 9c25bd1

Please sign in to comment.