-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
hybrid search example code (for docs) #30
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
import turbopuffer as tpuf | ||
|
||
|
||
# Helper to convert the results to a dictionary with ranks | ||
def results_to_ranks(results, reverse=False): | ||
if reverse: | ||
results = sorted(results, key=lambda item: item.dist, reverse=True) | ||
return {item.id: rank for rank, item in enumerate(results, start=1)} | ||
|
||
|
||
# Fuses two search result sets together into one | ||
# Uses reciprocal rank fusion | ||
def rank_fusion(bm25_results, vector_results, k=60): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we may want to inline/copy-paste this in a few places, so I think we need 1 top-level comments and then the function to be as small as possible, e.g. something alone the lines of this (not tested): def results_to_ranks(results, reverse=False):
return {item.id: rank for rank, item in enumerate(sorted(results, key=lambda item: item.dist, reverse=reverse), start=1)}
def rank_fusion(bm25_results, vector_results, k=60):
bm25_ranks, vector_ranks = results_to_ranks(bm25_results), results_to_ranks(vector_results, reverse=True)
scores = {doc_id: (1 / (k + bm25_ranks[doc_id]) if doc_id in bm25_ranks else 0) +
(1 / (k + vector_ranks[doc_id]) if doc_id in vector_ranks else 0)
for doc_id in set(bm25_ranks) | set(vector_ranks)}
return [{"id": doc_id, "score": score} for doc_id, score in sorted(scores.items(), key=lambda item: item[1], reverse=True)] The more code you have to paste, the more you'll question why we don't build this into this library (good reasons). The shorter it is, the more you'll want to tweak it (good). |
||
# Compute the ranks of all docs | ||
# Note: a higher score in bm25 denotes a better result, whereas a higher score | ||
# with vector euclidean distance indicates a worse result. When computing RRF, | ||
# we reverse the order of vector results so that higher = better. | ||
bm25_ranks = results_to_ranks(bm25_results) | ||
vector_ranks = results_to_ranks(vector_results, reverse=True) | ||
|
||
print("bm25_ranks:", bm25_ranks) | ||
print("vector ranks", vector_ranks) | ||
|
||
# Get all the unique document IDs | ||
doc_ids = set(bm25_ranks.keys()).union(set(vector_ranks.keys())) | ||
|
||
# Calculate the RRF score for each document | ||
scores = {} | ||
for doc_id in doc_ids: | ||
score = 0.0 | ||
if doc_id in bm25_ranks: | ||
score += 1.0 / (k + bm25_ranks[doc_id]) | ||
if doc_id in vector_ranks: | ||
score += 1.0 / (k + vector_ranks[doc_id]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hm, is classic RRF only rank-based? Nothing about scores? Probably fine to begin with |
||
scores[doc_id] = score | ||
|
||
# Fuse the results together and return | ||
fused_results = sorted(scores.items(), key=lambda item: item[1], reverse=True) | ||
return [{"id": doc_id, "score": score} for doc_id, score in fused_results] | ||
|
||
|
||
def main(): | ||
ns = tpuf.Namespace("tpuf_python_hybrid_search") | ||
try: | ||
ns.delete_all() # For cleaning up from previous runs | ||
except tpuf.NotFoundError: | ||
pass | ||
|
||
# Upsert 10 documents, containing both vectors and text | ||
# Whale facts from: https://www.natgeokids.com/uk/discover/animals/sea-life/10-blue-whale-facts/ | ||
ns.upsert( | ||
{ | ||
"ids": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], | ||
"vectors": [ | ||
[0.1, 0.1], | ||
[0.2, 0.2], | ||
[0.3, 0.3], | ||
[0.4, 0.4], | ||
[0.5, 0.5], | ||
[0.6, 0.6], | ||
[0.7, 0.7], | ||
[0.8, 0.8], | ||
[0.9, 0.9], | ||
[1.0, 1.0], | ||
], | ||
"attributes": { | ||
"text": [ | ||
"blue whales can grow to over 30m long", | ||
"pretty much everything about the blue whale is massive", | ||
"blue whales can be found in all oceans, except for the arctic", | ||
"blue whales eat tiny shrimp-like crustaceans called krill", | ||
"to communicate with each other, blue whales make a series of super loud vocal sounds", | ||
"blah", | ||
"blah", | ||
"blah", | ||
"blah", | ||
"blah", | ||
] | ||
}, | ||
"schema": { | ||
"text": { | ||
"type": "?string", | ||
"bm25": { | ||
"language": "english", | ||
"stemming": True, | ||
"case_sensitive": False, | ||
"remove_stopwords": True, | ||
}, | ||
}, | ||
}, | ||
"distance_metric": "euclidean_squared", | ||
} | ||
) | ||
print("upsert documents successfully") | ||
print("") | ||
|
||
# First, do query w/ BM25 full-text search | ||
bm25_results = ns.query({"top_k": 10, "rank_by": ("text", "BM25", "blue whale")}) | ||
print( | ||
"search results for 'blue whale':", | ||
[(item.id, item.dist) for item in bm25_results], | ||
) | ||
print("") | ||
|
||
# Now, we also do a query with a vector | ||
vector_results = ns.query( | ||
top_k=10, vector=[1.0, 1.0], distance_metric="euclidean_squared" | ||
) | ||
print( | ||
"search results for [1.0, 1.0]:", | ||
[(item.id, item.dist) for item in vector_results], | ||
) | ||
print("") | ||
|
||
# Fuse the results with Reciprocal rank fusion (RRF) | ||
fused_results = rank_fusion(bm25_results, vector_results) | ||
print("") | ||
print("fused results:", fused_results) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this run as part of the test suite?