Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[postgres] Sort Ascending #2594

Merged
merged 2 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 31 additions & 12 deletions libs/checkpoint-postgres/langgraph/store/postgres/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import asyncio

Check notice on line 1 in libs/checkpoint-postgres/langgraph/store/postgres/base.py

View workflow job for this annotation

GitHub Actions / benchmark

Benchmark results

......................................... fanout_to_subgraph_10x: Mean +- std dev: 61.3 ms +- 1.8 ms ......................................... fanout_to_subgraph_10x_sync: Mean +- std dev: 51.5 ms +- 1.1 ms ......................................... fanout_to_subgraph_10x_checkpoint: Mean +- std dev: 91.4 ms +- 7.6 ms ......................................... fanout_to_subgraph_10x_checkpoint_sync: Mean +- std dev: 93.2 ms +- 0.9 ms ......................................... fanout_to_subgraph_100x: Mean +- std dev: 588 ms +- 24 ms ......................................... fanout_to_subgraph_100x_sync: Mean +- std dev: 501 ms +- 5 ms ......................................... fanout_to_subgraph_100x_checkpoint: Mean +- std dev: 919 ms +- 49 ms ......................................... fanout_to_subgraph_100x_checkpoint_sync: Mean +- std dev: 919 ms +- 16 ms ......................................... react_agent_10x: Mean +- std dev: 30.7 ms +- 0.6 ms ......................................... react_agent_10x_sync: Mean +- std dev: 22.4 ms +- 0.3 ms ......................................... react_agent_10x_checkpoint: Mean +- std dev: 47.0 ms +- 1.0 ms ......................................... react_agent_10x_checkpoint_sync: Mean +- std dev: 36.7 ms +- 0.5 ms ......................................... react_agent_100x: Mean +- std dev: 348 ms +- 7 ms ......................................... react_agent_100x_sync: Mean +- std dev: 272 ms +- 4 ms ......................................... react_agent_100x_checkpoint: Mean +- std dev: 940 ms +- 10 ms ......................................... react_agent_100x_checkpoint_sync: Mean +- std dev: 852 ms +- 17 ms ......................................... wide_state_25x300: Mean +- std dev: 24.3 ms +- 0.4 ms ......................................... wide_state_25x300_sync: Mean +- std dev: 15.6 ms +- 0.1 ms ......................................... 
wide_state_25x300_checkpoint: Mean +- std dev: 281 ms +- 5 ms ......................................... wide_state_25x300_checkpoint_sync: Mean +- std dev: 269 ms +- 3 ms ......................................... wide_state_15x600: Mean +- std dev: 28.4 ms +- 0.5 ms ......................................... wide_state_15x600_sync: Mean +- std dev: 17.9 ms +- 0.2 ms ......................................... wide_state_15x600_checkpoint: Mean +- std dev: 482 ms +- 6 ms ......................................... wide_state_15x600_checkpoint_sync: Mean +- std dev: 466 ms +- 7 ms ......................................... wide_state_9x1200: Mean +- std dev: 28.2 ms +- 0.8 ms ......................................... wide_state_9x1200_sync: Mean +- std dev: 17.8 ms +- 0.3 ms ......................................... wide_state_9x1200_checkpoint: Mean +- std dev: 312 ms +- 3 ms ......................................... wide_state_9x1200_checkpoint_sync: Mean +- std dev: 299 ms +- 4 ms

Check notice on line 1 in libs/checkpoint-postgres/langgraph/store/postgres/base.py

View workflow job for this annotation

GitHub Actions / benchmark

Comparison against main

+------------------------------------+---------+-----------------------+ | Benchmark | main | changes | +====================================+=========+=======================+ | fanout_to_subgraph_100x_checkpoint | 940 ms | 919 ms: 1.02x faster | +------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x | 594 ms | 588 ms: 1.01x faster | +------------------------------------+---------+-----------------------+ | wide_state_9x1200_checkpoint | 314 ms | 312 ms: 1.01x faster | +------------------------------------+---------+-----------------------+ | wide_state_9x1200_sync | 17.7 ms | 17.8 ms: 1.00x slower | +------------------------------------+---------+-----------------------+ | wide_state_15x600_checkpoint | 480 ms | 482 ms: 1.00x slower | +------------------------------------+---------+-----------------------+ | wide_state_9x1200_checkpoint_sync | 298 ms | 299 ms: 1.00x slower | +------------------------------------+---------+-----------------------+ | react_agent_100x_sync | 271 ms | 272 ms: 1.01x slower | +------------------------------------+---------+-----------------------+ | react_agent_10x_checkpoint | 46.7 ms | 47.0 ms: 1.01x slower | +------------------------------------+---------+-----------------------+ | wide_state_15x600_sync | 17.8 ms | 17.9 ms: 1.01x slower | +------------------------------------+---------+-----------------------+ | wide_state_25x300_checkpoint_sync | 266 ms | 269 ms: 1.01x slower | +------------------------------------+---------+-----------------------+ | react_agent_10x_checkpoint_sync | 36.4 ms | 36.7 ms: 1.01x slower | +------------------------------------+---------+-----------------------+ | wide_state_25x300_checkpoint | 279 ms | 281 ms: 1.01x slower | +------------------------------------+---------+-----------------------+ | wide_state_25x300 | 24.1 ms | 24.3 ms: 1.01x slower | +------------------------------------+---------+-----------------------+ | wide_state_25x300_sync | 15.4 ms | 
15.6 ms: 1.01x slower | +------------------------------------+---------+-----------------------+ | wide_state_15x600 | 28.0 ms | 28.4 ms: 1.02x slower | +------------------------------------+---------+-----------------------+ | react_agent_100x_checkpoint_sync | 837 ms | 852 ms: 1.02x slower | +------------------------------------+---------+-----------------------+ | react_agent_100x | 340 ms | 348 ms: 1.02x slower | +------------------------------------+---------+-----------------------+ | Geometric mean | (ref) | 1.00x slower | +------------------------------------+---------+-----------------------+ Benchmark hidden because not significant (11): react_agent_10x, fanout_to_subgraph_10x_checkpoint, fanout_to_subgraph_10x, fanout_to_subgraph_10x_checkpoint_sync, fanout_to_subgraph_100x_checkpoint_sync, fanout_to_subgraph_100x_sync, react_agent_10x_sync, wide_state_15x600_checkpoint_sync, wide_state_9x1200, react_agent_100x_checkpoint, fanout_to_subgraph_10x_sync
import json
import logging
import threading
Expand Down Expand Up @@ -321,7 +321,7 @@
if op.query and self.index_config:
embedding_requests.append((idx, op.query))

score_operator = _get_distance_operator(self)
score_operator, post_operator = _get_distance_operator(self)
vector_type = (
cast(PostgresIndexConfig, self.index_config)
.get("ann_index_config", {})
Expand Down Expand Up @@ -351,18 +351,28 @@
if not filter_conditions
else " AND " + " AND ".join(filter_conditions)
)
if op.namespace_prefix:
prefix_filter_str = f"WHERE s.prefix LIKE %s {filter_str} "
ns_args: Sequence = (f"{_namespace_to_text(op.namespace_prefix)}%",)
else:
ns_args = ()
if filter_str:
prefix_filter_str = f"WHERE {filter_str} "
else:
prefix_filter_str = ""

base_query = f"""
WITH scored AS (
SELECT s.prefix, s.key, s.value, s.created_at, s.updated_at, {score_operator} AS score
SELECT s.prefix, s.key, s.value, s.created_at, s.updated_at, {score_operator} AS neg_score
FROM store s
JOIN store_vectors sv ON s.prefix = sv.prefix AND s.key = sv.key
WHERE s.prefix LIKE %s {filter_str}
ORDER BY {score_operator} DESC
{prefix_filter_str}
ORDER BY {score_operator} ASC
LIMIT %s
)
SELECT * FROM (
SELECT DISTINCT ON (prefix, key)
prefix, key, value, created_at, updated_at, score
prefix, key, value, created_at, updated_at, {post_operator} as score
FROM scored
ORDER BY prefix, key, score DESC
) AS unique_docs
Expand All @@ -372,7 +382,7 @@
"""
params = [
_PLACEHOLDER, # Vector placeholder
f"{_namespace_to_text(op.namespace_prefix)}%",
*ns_args,
*filter_params,
_PLACEHOLDER,
expanded_limit,
Expand Down Expand Up @@ -702,7 +712,6 @@
_paramslist[i] = embedding

for (idx, _), (query, params) in zip(search_ops, queries):
# Execute the actual query
cur.execute(query, params)
rows = cast(list[Row], cur.fetchall())
results[idx] = [
Expand Down Expand Up @@ -915,7 +924,7 @@
return tuple(namespace.split("."))


def _get_distance_operator(store: Any) -> str:
def _get_distance_operator(store: Any) -> tuple[str, str]:
"""Get the distance operator and score expression based on config."""
# Note: Today, we are not using ANN indices due to restrictions
# on PGVector's support for mixing vector and non-vector filters
Expand All @@ -936,12 +945,22 @@
config = cast(PostgresIndexConfig, store.index_config)
distance_type = config.get("distance_type", "cosine")

# Return the operator and the score expression
# The operator is used in the CTE and will be compatible with an ASCENDING ORDER
# sort clause.
# The score expression is used in the final query and will be compatible with
# a DESCENDING ORDER sort clause and the user's expectations of what the similarity score
# should be.
if distance_type == "l2":
return "1 - (sv.embedding <-> %s::%s)"
# Final: "-(sv.embedding <-> %s::%s)"
# We return the "l2 similarity" so that the sorting order is the same
return "sv.embedding <-> %s::%s", "-scored.neg_score"
elif distance_type == "inner_product":
return "-(sv.embedding <#> %s::%s)"
else: # cosine
return "1 - (sv.embedding <=> %s::%s)"
# Final: "-(sv.embedding <#> %s::%s)"
return "sv.embedding <#> %s::%s", "-(scored.neg_score)"
else: # cosine similarity
# Final: "1 - (sv.embedding <=> %s::%s)"
return "sv.embedding <=> %s::%s", "1 - scored.neg_score"


def _ensure_index_config(
Expand Down
87 changes: 87 additions & 0 deletions libs/checkpoint-postgres/tests/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,7 @@ def test_embed_with_path_operation_config(
distance_type: str,
) -> None:
"""Test operation-level field configuration for vector search."""

with _create_vector_store(
vector_type,
distance_type,
Expand Down Expand Up @@ -695,3 +696,89 @@ def test_embed_with_path_operation_config(
# assert len(results) == 3
# doc5_result = next(r for r in results if r.key == "doc5")
# assert doc5_result.score is None


def _cosine_similarity(X: list[float], Y: list[list[float]]) -> list[float]:
"""
Compute cosine similarity between a vector X and a matrix Y.
Lazy import numpy for efficiency.
"""

similarities = []
for y in Y:
dot_product = sum(a * b for a, b in zip(X, y))
norm1 = sum(a * a for a in X) ** 0.5
norm2 = sum(a * a for a in y) ** 0.5
similarity = dot_product / (norm1 * norm2) if norm1 > 0 and norm2 > 0 else 0.0
similarities.append(similarity)

return similarities


def _inner_product(X: list[float], Y: list[list[float]]) -> list[float]:
"""
Compute inner product between a vector X and a matrix Y.
Lazy import numpy for efficiency.
"""

similarities = []
for y in Y:
similarity = sum(a * b for a, b in zip(X, y))
similarities.append(similarity)

return similarities


def _neg_l2_distance(X: list[float], Y: list[list[float]]) -> list[float]:
"""
Compute l2 distance between a vector X and a matrix Y.
Lazy import numpy for efficiency.
"""

similarities = []
for y in Y:
similarity = sum((a - b) ** 2 for a, b in zip(X, y)) ** 0.5
similarities.append(-similarity)

return similarities


@pytest.mark.parametrize(
    "vector_type,distance_type",
    [
        ("vector", "cosine"),
        ("vector", "inner_product"),
        ("halfvec", "l2"),
    ],
)
@pytest.mark.parametrize("query", ["aaa", "bbb", "ccc", "abcd", "poisson"])
def test_scores(
    fake_embeddings: CharacterEmbeddings,
    vector_type: str,
    distance_type: str,
    query: str,
) -> None:
    """Test that search scores match the configured distance type.

    A single document is stored and searched for with ``query``; the score
    on the returned item is compared with a reference value computed in
    Python from the same embeddings (cosine similarity, inner product, or
    negated L2 distance, depending on ``distance_type``).
    """
    with _create_vector_store(
        vector_type,
        distance_type,
        fake_embeddings,
        text_fields=["key0"],
    ) as store:
        doc = {
            "key0": "aaa",
        }
        store.put(("test",), "doc", doc, index=["key0", "key1"])

        results = store.search((), query=query)
        vec0 = fake_embeddings.embed_query(doc["key0"])
        vec1 = fake_embeddings.embed_query(query)
        if distance_type == "cosine":
            similarities = _cosine_similarity(vec1, [vec0])
        elif distance_type == "inner_product":
            similarities = _inner_product(vec1, [vec0])
        elif distance_type == "l2":
            similarities = _neg_l2_distance(vec1, [vec0])
        else:
            # Fail loudly instead of hitting an unbound ``similarities`` below
            # if a new distance type is parametrized without a reference.
            raise ValueError(f"Unexpected distance_type: {distance_type!r}")

        assert len(results) == 1
        assert results[0].score == pytest.approx(similarities[0], abs=1e-3)
Loading