Skip to content

Commit

Permalink
Sort ASC
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw committed Dec 3, 2024
1 parent 5f4ea1c commit 275eb4f
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 86 deletions.
101 changes: 15 additions & 86 deletions libs/checkpoint-postgres/langgraph/store/postgres/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class Migration(NamedTuple):
params: Optional[dict[str, Any]] = None


MIGRATIONS: Sequence[Union[str, Migration]] = [
MIGRATIONS: Sequence[str] = [
"""
CREATE TABLE IF NOT EXISTS store (
-- 'prefix' represents the doc's 'namespace'
Expand Down Expand Up @@ -104,26 +104,11 @@ class Migration(NamedTuple):
),
},
),
Migration(
"""
CREATE INDEX IF NOT EXISTS store_vectors_embedding_idx ON store_vectors
USING %(index_type)s (embedding %(ops)s)%(index_params)s;
""",
params={
"index_type": lambda store: _get_index_params(store)[0],
"ops": lambda store: _get_vector_type_ops(store),
"index_params": lambda store: (
" WITH ("
+ ", ".join(f"{k}={v}" for k, v in _get_index_params(store)[1].items())
+ ")"
if _get_index_params(store)[1]
else ""
),
},
),
# TODO: Add an HNSW or IVFFlat index depending on config
# First must improve the search query when filtering by
# namespace
]


C = TypeVar("C", bound=Union[_pg_internal.Conn, _ainternal.Conn])


Expand Down Expand Up @@ -155,8 +140,6 @@ class PoolConfig(TypedDict, total=False):
class ANNIndexConfig(TypedDict, total=False):
"""Configuration for vector index in PostgreSQL store."""

kind: Literal["hnsw", "ivfflat"]
"""Type of index to use: 'hnsw' for Hierarchical Navigable Small World, or 'ivfflat' for Inverted File Flat."""
vector_type: Literal["vector", "halfvec"]
"""Type of vector storage to use.
Options:
Expand All @@ -165,35 +148,6 @@ class ANNIndexConfig(TypedDict, total=False):
"""


class HNSWConfig(ANNIndexConfig, total=False):
"""Configuration for HNSW (Hierarchical Navigable Small World) index."""

kind: Literal["hnsw"] # type: ignore[misc]
m: int
"""Maximum number of connections per layer. Default is 16."""
ef_construction: int
"""Size of dynamic candidate list for index construction. Default is 64."""


class IVFFlatConfig(ANNIndexConfig, total=False):
"""IVFFlat index divides vectors into lists, and then searches a subset of those lists that are closest to the query vector. It has faster build times and uses less memory than HNSW, but has lower query performance (in terms of speed-recall tradeoff).
Three keys to achieving good recall are:
1. Create the index after the table has some data
2. Choose an appropriate number of lists - a good place to start is rows / 1000 for up to 1M rows and sqrt(rows) for over 1M rows
3. When querying, specify an appropriate number of probes (higher is better for recall, lower is better for speed) - a good place to start is sqrt(lists)
"""

kind: Literal["ivfflat"] # type: ignore[misc]
nlist: int
"""Number of inverted lists (clusters) for IVF index.
Determines the number of clusters used in the index structure.
Higher values can improve search speed but increase index size and build time.
Typically set to the square root of the number of vectors in the index.
"""


class PostgresIndexConfig(IndexConfig, total=False):
"""Configuration for vector embeddings in PostgreSQL store with pgvector-specific options.
Expand Down Expand Up @@ -367,7 +321,7 @@ def _prepare_batch_search_queries(
if op.query and self.index_config:
embedding_requests.append((idx, op.query))

score_operator, post_operator = _get_distance_operator(self)
score_operator = _get_distance_operator(self)
vector_type = (
cast(PostgresIndexConfig, self.index_config)
.get("ann_index_config", {})
Expand All @@ -390,17 +344,18 @@ def _prepare_batch_search_queries(

vectors_per_doc_estimate = self.index_config["__estimated_num_vectors"]
expanded_limit = (op.limit * vectors_per_doc_estimate * 2) + 1

# Vector search with CTE for proper score handling
filter_str = (
""
if not filter_conditions
else " AND " + " AND ".join(filter_conditions)
)
ns_args = []
if op.namespace_prefix:
prefix_filter_str = f"WHERE s.prefix = %s {filter_str} "
ns_args = [f"{_namespace_to_text(op.namespace_prefix)}"]
prefix_filter_str = f"WHERE s.prefix LIKE %s {filter_str} "
ns_args = (f"{_namespace_to_text(op.namespace_prefix)}%",)
else:
ns_args = ()
if filter_str:
prefix_filter_str = f"WHERE {filter_str} "
else:
Expand All @@ -417,7 +372,7 @@ def _prepare_batch_search_queries(
)
SELECT * FROM (
SELECT DISTINCT ON (prefix, key)
prefix, key, value, created_at, updated_at, {post_operator} as score
prefix, key, value, created_at, updated_at, -neg_score as score
FROM scored
ORDER BY prefix, key, score DESC
) AS unique_docs
Expand Down Expand Up @@ -757,20 +712,6 @@ def _batch_search_ops(
_paramslist[i] = embedding

for (idx, _), (query, params) in zip(search_ops, queries):
# Get and print pgvector version
cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
version = cur.fetchone()
if version:
print(f"pgvector version: {list(version.values())[0]}", flush=True)

# Run EXPLAIN on the query, verbose to get the query plan
cur.execute(f"EXPLAIN {query}", params)
# Print the query plan line by line. Truncate at 300 chars per line
print("^" * 80, flush=True)
for line in cur.fetchall():
print(list(line.values())[0][:300], flush=True)
print("*" * 80, flush=True)
# Execute the actual query
cur.execute(query, params)
rows = cast(list[Row], cur.fetchall())
results[idx] = [
Expand Down Expand Up @@ -891,18 +832,6 @@ def _get_vector_type_ops(store: BasePostgresStore) -> str:
return f"{type_prefix}_{distance_suffix}"


def _get_index_params(store: Any) -> tuple[str, dict[str, Any]]:
"""Get the index type and configuration based on config."""
if not store.index_config:
return "hnsw", {}

config = cast(PostgresIndexConfig, store.index_config)
index_config = config.get("ann_index_config", _DEFAULT_ANN_CONFIG).copy()
kind = index_config.pop("kind", "hnsw")
index_config.pop("vector_type", None)
return kind, index_config


def _namespace_to_text(
namespace: tuple[str, ...], handle_wildcards: bool = False
) -> str:
Expand Down Expand Up @@ -995,7 +924,7 @@ def _decode_ns_bytes(namespace: Union[str, bytes, list]) -> tuple[str, ...]:
return tuple(namespace.split("."))


def _get_distance_operator(store: Any) -> tuple[str, str]:
def _get_distance_operator(store: Any) -> tuple[str, bool]:
"""Get the distance operator and score expression based on config."""
# Note: Today, we are not using ANN indices due to restrictions
# on PGVector's support for mixing vector and non-vector filters
Expand Down Expand Up @@ -1023,14 +952,14 @@ def _get_distance_operator(store: Any) -> tuple[str, str]:
# a DESCENDING ORDER sort clause and the user's expectations of what the similarity score
# should be.
if distance_type == "l2":
# Final: "1 - (sv.embedding <-> %s::%s)"
return "sv.embedding <-> %s::%s", "1 - (scored.neg_score)"
# Final: "sv.embedding <-> %s::%s"
return "sv.embedding <-> %s::%s"
elif distance_type == "inner_product":
# Final: "-(sv.embedding <#> %s::%s)"
return "sv.embedding <#> %s::%s", "-(scored.neg_score)"
return "sv.embedding <#> %s::%s"
else: # cosine
# Final: "1 - (sv.embedding <=> %s::%s)"
return "sv.embedding <=> %s::%s", "1 - (scored.neg_score)"
return "sv.embedding <=> %s::%s"


def _ensure_index_config(
Expand Down
87 changes: 87 additions & 0 deletions libs/checkpoint-postgres/tests/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,7 @@ def test_embed_with_path_operation_config(
distance_type: str,
) -> None:
"""Test operation-level field configuration for vector search."""

with _create_vector_store(
vector_type,
distance_type,
Expand Down Expand Up @@ -695,3 +696,89 @@ def test_embed_with_path_operation_config(
# assert len(results) == 3
# doc5_result = next(r for r in results if r.key == "doc5")
# assert doc5_result.score is None


def _cosine_similarity(X: list[float], Y: list[list[float]]) -> list[float]:
"""
Compute cosine similarity between a vector X and a matrix Y.
Lazy import numpy for efficiency.
"""

similarities = []
for y in Y:
dot_product = sum(a * b for a, b in zip(X, y))
norm1 = sum(a * a for a in X) ** 0.5
norm2 = sum(a * a for a in y) ** 0.5
similarity = dot_product / (norm1 * norm2) if norm1 > 0 and norm2 > 0 else 0.0
similarities.append(similarity)

return similarities


def _inner_product(X: list[float], Y: list[list[float]]) -> list[float]:
"""
Compute inner product between a vector X and a matrix Y.
Lazy import numpy for efficiency.
"""

similarities = []
for y in Y:
similarity = sum(a * b for a, b in zip(X, y))
similarities.append(similarity)

return similarities


def _l2_distance(X: list[float], Y: list[list[float]]) -> list[float]:
"""
Compute l2 distance between a vector X and a matrix Y.
Lazy import numpy for efficiency.
"""

similarities = []
for y in Y:
similarity = sum((a - b) ** 2 for a, b in zip(X, y)) ** 0.5
similarities.append(similarity)

return similarities


@pytest.mark.parametrize(
"vector_type,distance_type",
[
("vector", "cosine"),
("vector", "inner_product"),
("halfvec", "l2"),
],
)
@pytest.mark.parametrize("query", ["aaa", "bbb", "ccc", "abcd", "poisson"])
def test_scores(
fake_embeddings: CharacterEmbeddings,
vector_type: str,
distance_type: str,
query: str,
) -> None:
"""Test operation-level field configuration for vector search."""
with _create_vector_store(
vector_type,
distance_type,
fake_embeddings,
text_fields=["key0"],
) as store:
doc = {
"key0": "aaa",
}
store.put(("test",), "doc", doc, index=["key0", "key1"])

results = store.search((), query=query)
vec0 = fake_embeddings.embed_query(doc["key0"])
vec1 = fake_embeddings.embed_query(query)
if distance_type == "cosine":
similarities = _cosine_similarity(vec1, [vec0])
elif distance_type == "inner_product":
similarities = _inner_product(vec1, [vec0])
elif distance_type == "l2":
similarities = _l2_distance(vec1, [vec0])

assert len(results) == 1
assert results[0].score == pytest.approx(similarities[0], abs=1e-3)

0 comments on commit 275eb4f

Please sign in to comment.