Back to the basics
hinthornw committed Nov 28, 2024
1 parent 04dbb5e commit 5101d1d
Showing 5 changed files with 127 additions and 179 deletions.
libs/checkpoint-postgres/langgraph/store/postgres/aio.py (6 changes: 5 additions & 1 deletion)
@@ -21,6 +21,7 @@
)
from langgraph.store.base.batch import AsyncBatchedBaseStore
from langgraph.store.postgres.base import (
_PLACEHOLDER,
BasePostgresStore,
PoolConfig,
PostgresIndexConfig,
@@ -292,7 +293,10 @@ async def _batch_search_ops(
[query for _, query in embedding_requests]
)
for (idx, _), vector in zip(embedding_requests, vectors):
queries[idx][1][0] = vector
_paramslist = queries[idx][1]
for i in range(len(_paramslist)):
if _paramslist[i] is _PLACEHOLDER:
_paramslist[i] = vector

for (idx, _), (query, params) in zip(search_ops, queries):
await cur.execute(query, params)
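
Note: the loop above patches every _PLACEHOLDER slot instead of assuming the embedding sits at index 0, because the rewritten search query in base.py (below) interpolates the score operator twice, once in the SELECT list and once in the ORDER BY, so one parameter list can carry more than one pending vector. A minimal sketch of the same substitution, with a hypothetical parameter list:

    params = [_PLACEHOLDER, "chat.user_123.%", _PLACEHOLDER, 21, 10, 0]
    params = [vector if p is _PLACEHOLDER else p for p in params]

The in-place loop used here is equivalent, but keeps the same list object that the query tuple already references.
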
libs/checkpoint-postgres/langgraph/store/postgres/base.py (210 changes: 87 additions & 123 deletions)
@@ -99,28 +99,14 @@ class Migration(NamedTuple):
"dims": lambda store: store.index_config["dims"],
"vector_type": lambda store: (
cast(PostgresIndexConfig, store.index_config)
.get("db_index_config", {})
.get("ann_index_config", {})
.get("vector_type", "vector")
),
},
),
Migration(
"""
CREATE INDEX IF NOT EXISTS store_vectors_embedding_idx ON store_vectors
USING %(index_type)s (embedding %(ops)s)%(index_params)s;
""",
params={
"index_type": lambda store: _get_index_params(store)[0],
"ops": lambda store: _get_vector_type_ops(store),
"index_params": lambda store: (
" WITH ("
+ ", ".join(f"{k}={v}" for k, v in _get_index_params(store)[1].items())
+ ")"
if _get_index_params(store)[1]
else ""
),
},
),
# TODO: Add an HNSW or IVFFlat index depending on config
# First must improve the search query when filtering by
# namespace
]

C = TypeVar("C", bound=Union[_pg_internal.Conn, _ainternal.Conn])
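
For reference, the migration removed above rendered to SQL along these lines when given an explicit config such as HNSWConfig(kind="hnsw", m=16, ef_construction=64), using the default values documented in the (also removed) config classes below; with no db_index_config at all, the WITH clause was omitted entirely:

    CREATE INDEX IF NOT EXISTS store_vectors_embedding_idx ON store_vectors
    USING hnsw (embedding vector_cosine_ops) WITH (m=16, ef_construction=64);

The TODO keeps ANN indexing on the roadmap: as the note added to _get_distance_operator below explains, pgvector can only use such an index once the search query itself is restructured.
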
@@ -151,22 +137,9 @@ class PoolConfig(TypedDict, total=False):
"""


class DBIndexConfig(TypedDict, total=False):
class ANNIndexConfig(TypedDict, total=False):
"""Configuration for vector index in PostgreSQL store."""

kind: Literal["hnsw", "ivfflat"]
"""Type of index to use.
'hnsw': Hierarchical Navigable Small World index.
'ivfflat': Inverted File Flat index.
HNSW has slower build times and uses more memory than IVFFlat, but has better query performance
(in terms of speed-recall tradeoff).
IVFFlat divides vectors into lists, then searches a subset closest to the query vector.
It has faster build times and uses less memory than HNSW, but lower query performance.
"""
vector_type: Literal["vector", "halfvec"]
"""Type of vector storage to use.
Options:
@@ -175,42 +148,13 @@ class DBIndexConfig(TypedDict, total=False):
"""


class HNSWConfig(DBIndexConfig, total=False):
"""Configuration for HNSW (Hierarchical Navigable Small World) index."""

kind: Literal["hnsw"] # type: ignore[misc]
m: int
"""Maximum number of connections per layer. Default is 16."""
ef_construction: int
"""Size of dynamic candidate list for index construction. Default is 64."""


class IVFFlatConfig(DBIndexConfig, total=False):
"""IVFFlat index divides vectors into lists, and then searches a subset of those lists that are closest to the query vector. It has faster build times and uses less memory than HNSW, but has lower query performance (in terms of speed-recall tradeoff).
Three keys to achieving good recall are:
1. Create the index after the table has some data
2. Choose an appropriate number of lists - a good place to start is rows / 1000 for up to 1M rows and sqrt(rows) for over 1M rows
3. When querying, specify an appropriate number of probes (higher is better for recall, lower is better for speed) - a good place to start is sqrt(lists)
"""

kind: Literal["ivfflat"] # type: ignore[misc]
nlist: int
"""Number of inverted lists (clusters) for IVF index.
Determines the number of clusters used in the index structure.
Higher values can improve search speed but increase index size and build time.
Typically set to the square root of the number of vectors in the index.
"""


class PostgresIndexConfig(IndexConfig, total=False):
"""Configuration for vector embeddings in PostgreSQL store with pgvector-specific options.
Extends EmbeddingConfig with additional configuration for pgvector index and vector types.
"""

db_index_config: Union[HNSWConfig, IVFFlatConfig]
ann_index_config: ANNIndexConfig
"""Specific configuration for the chosen index type (HNSW or IVF Flat)."""
distance_type: Literal["l2", "inner_product", "cosine"]
"""Distance metric to use for vector similarity search:
@@ -357,84 +301,102 @@ def _prepare_batch_search_queries(
embedding_requests = []

for idx, (_, op) in enumerate(search_ops):
base_query = """
SELECT prefix, key, value, created_at, updated_at
FROM store
WHERE prefix LIKE %s
"""
params: list = [f"{_namespace_to_text(op.namespace_prefix)}%"]
needs_vector_search = False
# Build filter conditions first
filter_params = []
filter_conditions = []
if op.filter:
for key, value in op.filter.items():
if isinstance(value, dict):
for op_name, val in value.items():
condition, filter_params_ = self._get_filter_condition(
key, op_name, val
)
filter_conditions.append(condition)
filter_params.extend(filter_params_)
else:
filter_conditions.append("value->%s = %s::jsonb")
filter_params.extend([key, json.dumps(value)])

# Vector search branch
if op.query and self.index_config:
needs_vector_search = True
embedding_requests.append((idx, op.query))

score_expr = _get_distance_operator(self)
score_operator = _get_distance_operator(self)
vector_type = (
cast(PostgresIndexConfig, self.index_config)
.get("db_index_config", {})
.get("ann_index_config", {})
.get("vector_type", "vector")
)

if (
vector_type == "bit"
and self.index_config.get("distance_type") == "hamming"
):
score_expr = score_expr % ("%s", self.index_config["dims"])
score_operator = score_operator % (
"%s",
self.index_config["dims"],
)
else:
score_expr = score_expr % ("%s", vector_type)
score_operator = score_operator % (
"%s",
vector_type,
)

vectors_per_doc_estimate = self.index_config["__estimated_num_vectors"]
expanded_limit = (op.limit * vectors_per_doc_estimate * 2) + 1
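                # A single document can match through several of its vectors,
                # so the inner CTE over-fetches (limit * vectors-per-doc * 2,
                # plus 1) before deduplicating to one best row per document.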

# Direct query with DISTINCT ON to get best score per document
# Vector search with CTE for proper score handling
filter_str = (
""
if not filter_conditions
else " AND " + " AND ".join(filter_conditions)
)
base_query = f"""
with scored as (
SELECT DISTINCT ON (s.prefix, s.key)
s.prefix, s.key, s.value, s.created_at, s.updated_at,
{score_expr} as score
WITH scored AS (
SELECT s.prefix, s.key, s.value, s.created_at, s.updated_at, {score_operator} AS score
FROM store s
JOIN store_vectors sv ON s.prefix = sv.prefix AND s.key = sv.key
WHERE s.prefix LIKE %s
ORDER BY s.prefix, s.key, score DESC
WHERE s.prefix LIKE %s {filter_str}
ORDER BY {score_operator} DESC
LIMIT %s
)
SELECT * FROM scored
SELECT * FROM (
SELECT DISTINCT ON (prefix, key)
prefix, key, value, created_at, updated_at, score
FROM scored
ORDER BY prefix, key, score DESC
) AS unique_docs
ORDER BY score DESC
LIMIT %s
OFFSET %s
"""
params = [
None, # Vector placeholder
_PLACEHOLDER, # Vector placeholder
f"{_namespace_to_text(op.namespace_prefix)}%",
*filter_params,
_PLACEHOLDER,
expanded_limit,
op.limit,
op.offset,
]
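                # The two _PLACEHOLDER entries line up with the two
                # {score_operator} sites in the f-string above (the SELECT
                # list and the ORDER BY); each consumes the query embedding.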

if op.filter:
filter_conditions = []
for key, value in op.filter.items():
if isinstance(value, dict):
for op_name, val in value.items():
condition, filter_params = self._get_filter_condition(
key, op_name, val
)
filter_conditions.append(condition)
params.extend(filter_params)
else:
filter_conditions.append("value->%s = %s::jsonb")
params.extend([key, json.dumps(value)])
# Regular search branch
else:
base_query = """
SELECT prefix, key, value, created_at, updated_at
FROM store
WHERE prefix LIKE %s
"""
params = [f"{_namespace_to_text(op.namespace_prefix)}%"]

if filter_conditions:
if needs_vector_search:
base_query += " WHERE " + " AND ".join(filter_conditions)
else:
base_query += " AND " + " AND ".join(filter_conditions)
params.extend(filter_params)
base_query += " AND " + " AND ".join(filter_conditions)

if needs_vector_search:
base_query += " ORDER BY score DESC"
else:
base_query += " ORDER BY updated_at DESC"
base_query += " LIMIT %s OFFSET %s"
params.extend([op.limit, op.offset])

base_query += " LIMIT %s OFFSET %s"
params.extend([op.limit, op.offset])
queries.append((base_query, params))

return queries, embedding_requests
@@ -707,7 +669,6 @@ def _batch_put_ops(
vectors = self.embeddings.embed_documents(
[param[-1] for param in txt_params]
)

queries.append(
(
query,
@@ -735,9 +696,13 @@ def _batch_search_ops(
[query for _, query in embedding_requests]
)
for (idx, _), embedding in zip(embedding_requests, embeddings):
queries[idx][1][0] = embedding
_paramslist = queries[idx][1]
for i in range(len(_paramslist)):
if _paramslist[i] is _PLACEHOLDER:
_paramslist[i] = embedding

for (idx, _), (query, params) in zip(search_ops, queries):
# Execute the actual query
cur.execute(query, params)
rows = cast(list[Row], cur.fetchall())
results[idx] = [
@@ -821,8 +786,7 @@ class Row(TypedDict):

# Private utilities

_DEFAULT_ANN_CONFIG = HNSWConfig(
kind="hnsw",
_DEFAULT_ANN_CONFIG = ANNIndexConfig(
vector_type="vector",
)

@@ -833,7 +797,7 @@ def _get_vector_type_ops(store: BasePostgresStore) -> str:
return "vector_cosine_ops"

config = cast(PostgresIndexConfig, store.index_config)
index_config = config.get("db_index_config", _DEFAULT_ANN_CONFIG).copy()
index_config = config.get("ann_index_config", _DEFAULT_ANN_CONFIG).copy()
vector_type = cast(str, index_config.get("vector_type", "vector"))
if vector_type not in ("vector", "halfvec"):
raise ValueError(
@@ -859,18 +823,6 @@ def _get_vector_type_ops(store: BasePostgresStore) -> str:
return f"{type_prefix}_{distance_suffix}"
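    # e.g. "vector" with cosine distance -> "vector_cosine_ops";
    # "halfvec" with l2 -> "halfvec_l2_ops".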


def _get_index_params(store: Any) -> tuple[str, dict[str, Any]]:
"""Get the index type and configuration based on config."""
if not store.index_config:
return "hnsw", {}

config = cast(PostgresIndexConfig, store.index_config)
index_config = config.get("db_index_config", _DEFAULT_ANN_CONFIG).copy()
kind = index_config.pop("kind", "hnsw")
index_config.pop("vector_type", None)
return kind, index_config


def _namespace_to_text(
namespace: tuple[str, ...], handle_wildcards: bool = False
) -> str:
@@ -965,6 +917,15 @@ def _decode_ns_bytes(namespace: Union[str, bytes, list]) -> tuple[str, ...]:

def _get_distance_operator(store: Any) -> str:
"""Get the distance operator and score expression based on config."""
# Note: Today, we are not using ANN indices due to restrictions
# on PGVector's support for mixing vector and non-vector filters
# To use the index, PGVector expects:
# - ORDER BY the operator NOT an expression (even negation blocks it)
# - ASCENDING order
# - Any WHERE clause should be over a partial index.
# If we violate any of these, it will use a sequential scan
# See https://github.com/pgvector/pgvector/issues/216 and the
# pgvector documentation for more details.
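    # For example, a bare ascending
    #   ORDER BY sv.embedding <=> %s LIMIT k
    # is index-eligible, but the scored CTE built in
    # _prepare_batch_search_queries wraps the operator in a larger
    # expression, sorts descending, and filters on prefix, so pgvector
    # falls back to a sequential scan there.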
if not store.index_config:
raise ValueError(
"Embedding configuration is required for vector operations "
@@ -1008,3 +969,6 @@ def _ensure_index_config(
index_config.get("embed"),
)
return embeddings, index_config


# Sentinel for query parameters that are filled in with an embedding at
# execution time. A unique object is used (rather than None) so the batch
# executors can patch exactly these slots via an identity check.
_PLACEHOLDER = object()
libs/checkpoint-postgres/tests/conftest.py (1 change: 0 additions & 1 deletion)
@@ -40,5 +40,4 @@ def fake_embeddings() -> CharacterEmbeddings:
return CharacterEmbeddings(dims=500)


INDEX_TYPES = ["hnsw", "ivfflat"]
VECTOR_TYPES = ["vector", "halfvec"]
(Diffs for the remaining 2 changed files not loaded.)
