Skip to content

Commit

Permalink
Add test that sorting is correct
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw committed Nov 28, 2024
1 parent 5101d1d commit c25e383
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 65 deletions.
2 changes: 1 addition & 1 deletion libs/checkpoint-postgres/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langgraph-checkpoint-postgres"
version = "2.0.4"
version = "2.0.5"
description = "Library with a Postgres implementation of LangGraph checkpoint saver."
authors = []
license = "MIT"
Expand Down
81 changes: 18 additions & 63 deletions libs/checkpoint-postgres/tests/test_async_store.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# type: ignore
import itertools
import sys
import uuid
from collections.abc import AsyncIterator
Expand Down Expand Up @@ -413,12 +414,7 @@ async def test_vector_search_edge_cases(vector_store: AsyncPostgresStore) -> Non
@pytest.mark.parametrize(
"vector_type,distance_type",
[
("vector", "cosine"),
("vector", "inner_product"),
("vector", "l2"),
("halfvec", "cosine"),
("halfvec", "inner_product"),
("halfvec", "l2"),
*itertools.product(["vector", "halfvec"], ["cosine", "inner_product", "l2"]),
],
)
async def test_embed_with_path(
Expand Down Expand Up @@ -477,13 +473,10 @@ async def test_embed_with_path(
@pytest.mark.parametrize(
"vector_type,distance_type",
[
("vector", "cosine"),
("vector", "cosine"),
("halfvec", "cosine"),
("halfvec", "inner_product"),
*itertools.product(["vector", "halfvec"], ["cosine", "inner_product", "l2"]),
],
)
async def test_embed_with_path_operation_config(
async def test_search_sorting(
request: Any,
fake_embeddings: CharacterEmbeddings,
vector_type: str,
Expand All @@ -494,59 +487,21 @@ async def test_embed_with_path_operation_config(
vector_type,
distance_type,
fake_embeddings,
text_fields=["key17"], # Default fields that won't match our test data
text_fields=["key1"], # Default fields that won't match our test data
) as store:
doc3 = {
"key0": "aaa",
"key1": "bbb",
"key2": "ccc",
"key3": "ddd",
amatch = {
"key1": "mmm",
}
doc4 = {
"key0": "eee",
"key1": "bbb", # Same as doc3.key1
"key2": "fff",
"key3": "ggg",
}

await store.aput(("test",), "doc3", doc3, index=["key0", "key1"])
await store.aput(("test",), "doc4", doc4, index=["key1", "key3"])

results = await store.asearch(("test",), query="aaa")
assert len(results) == 2
assert results[0].key == "doc3"
assert results[0].score > results[1].score

results = await store.asearch(("test",), query="ggg")
assert len(results) == 2
assert results[0].key == "doc4"
await store.aput(("test", "M"), "M", amatch)
N = 100
for i in range(N):
await store.aput(("test", "A"), f"A{i}", {"key1": "no"})
for i in range(N):
await store.aput(("test", "Z"), f"Z{i}", {"key1": "no"})

results = await store.asearch(("test",), query="mmm", limit=10)
assert len(results) == 10
assert len(set(r.key for r in results)) == 10
assert results[0].key == "M"
assert results[0].score > results[1].score

results = await store.asearch(("test",), query="bbb")
assert len(results) == 2
assert results[0].key != results[1].key
assert results[0].score == pytest.approx(results[1].score, abs=1e-3)

results = await store.asearch(("test",), query="ccc")
assert len(results) == 2
assert all(
r.score < 0.9 for r in results
) # Unindexed field should have low scores

# Test index=False behavior
doc5 = {
"key0": "hhh",
"key1": "iii",
}
await store.aput(("test",), "doc5", doc5, index=False)
results = await store.asearch(("test",))
assert len(results) == 3
assert all(r.score is None for r in results)
assert any(r.key == "doc5" for r in results)

results = await store.asearch(("test",), query="hhh")
# TODO: We don't currently fill in additional results if there are not enough
# returned during vector search.
# assert len(results) == 3
# doc5_result = next(r for r in results if r.key == "doc5")
# assert doc5_result.score is None
2 changes: 1 addition & 1 deletion libs/checkpoint/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langgraph-checkpoint"
version = "2.0.5"
version = "2.0.6"
description = "Library with base interfaces for LangGraph checkpoint savers."
authors = []
license = "MIT"
Expand Down

0 comments on commit c25e383

Please sign in to comment.