Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix bm25 & keyword search #564

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion daras_ai_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1815,7 +1815,9 @@ def render_variables(self):
if not self.functions_in_settings:
functions_input(self.request.user)
variables_input(
template_keys=self.template_keys, allow_add=is_functions_enabled()
template_keys=self.template_keys,
allow_add=is_functions_enabled(),
exclude=self.fields_to_save(),
)

@classmethod
Expand Down
11 changes: 6 additions & 5 deletions daras_ai_v2/query_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ def generate_final_search_query(
context: dict = None,
response_format_type: typing.Literal["text", "json_object"] = None,
):
if context is None:
context = request.dict()
if response:
context |= response.dict()
instructions = render_prompt_vars(instructions, context).strip()
state = request.dict()
if response:
state |= response.dict()
if context:
state |= context
instructions = render_prompt_vars(instructions, state).strip()
if not instructions:
return ""
return run_language_model(
Expand Down
3 changes: 2 additions & 1 deletion daras_ai_v2/variables_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def variables_input(
description: str = "Variables let you pass custom parameters to your workflow. Access a variable in your instruction prompt with <a href='https://jinja.palletsprojects.com/en/3.1.x/templates/' target='_blank'>Jinja</a>, e.g. `{{ my_variable }}`\n ",
key: str = "variables",
allow_add: bool = False,
exclude: typing.Iterable[str] = (),
):
from recipes.BulkRunner import list_view_editor

Expand All @@ -45,7 +46,7 @@ def variables_input(
var_names = (
(template_var_names | set(variables.keys()))
- set(context_globals().keys()) # dont show global context variables
- set(gui.session_state.keys()) # dont show other session state variables
- set(exclude) # used for hiding request/response fields
)
pressed_add = False
if var_names or allow_add:
Expand Down
103 changes: 69 additions & 34 deletions daras_ai_v2/vector_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import re
import tempfile
import typing
import unicodedata
from functools import partial
from time import time

Expand Down Expand Up @@ -56,6 +57,7 @@
url_to_gdrive_file_id,
gdrive_metadata,
)
from daras_ai_v2.office_utils_pptx import pptx_to_text_pages
from daras_ai_v2.redis_cache import redis_lock
from daras_ai_v2.scraping_proxy import (
get_scraping_proxy_cert_path,
Expand All @@ -67,7 +69,6 @@
remove_quotes,
generate_text_fragment_url,
)
from daras_ai_v2.office_utils_pptx import pptx_to_text_pages
from daras_ai_v2.text_splitter import text_splitter, Document
from embeddings.models import EmbeddedFile, EmbeddingsReference
from files.models import FileMetadata
Expand Down Expand Up @@ -190,6 +191,7 @@ def get_top_k_references(
s = time()
search_result = query_vespa(
request.search_query,
request.keyword_query,
file_ids=vespa_file_ids,
limit=request.max_references or 100,
embedding_model=embedding_model,
Expand Down Expand Up @@ -232,34 +234,63 @@ def vespa_search_results_to_refs(

def query_vespa(
    search_query: str,
    keyword_query: str | list[str] | None,
    file_ids: list[str],
    limit: int,
    embedding_model: EmbeddingModels,
    semantic_weight: float = 1.0,
    threshold: float = 0.7,
    rerank_count: int = 1000,
) -> dict:
    """Run a BM25 / semantic / fusion search against Vespa.

    Args:
        search_query: natural-language query used for the embedding lookup
            (and as the BM25 fallback when no keyword_query is given).
        keyword_query: explicit keyword query for BM25; a list is joined
            with spaces. Falls back to search_query when empty/None.
        file_ids: Vespa file_id values to restrict the search to.
        limit: maximum number of hits to return.
        embedding_model: model used to embed search_query.
        semantic_weight: 1.0 = pure semantic, 0.0 = pure BM25, anything in
            between uses Vespa's "fusion" rank profile over both candidate sets.
        threshold: distanceThreshold passed to nearestNeighbor().
        rerank_count: targetHits for each arm of the fusion query before
            reranking (an integer hit count — was mistyped as float).

    Returns:
        The raw Vespa response JSON; an empty result shell when there is
        nothing to search (no file_ids, or the embedding came back None).
    """
    if not file_ids:
        return {"root": {"children": []}}

    yql = "select * from %(schema)s where file_id in (@fileIds) and " % dict(
        schema=settings.VESPA_SCHEMA
    )
    bm25_yql = "( {targetHits: %(hits)i} userInput(@bm25Query) )"
    semantic_yql = "( {targetHits: %(hits)i, distanceThreshold: %(threshold)f} nearestNeighbor(embedding, queryEmbedding) )"

    if semantic_weight == 0.0:
        yql += bm25_yql % dict(hits=limit)
        ranking = "bm25"
    elif semantic_weight == 1.0:
        yql += semantic_yql % dict(hits=limit, threshold=threshold)
        ranking = "semantic"
    else:
        # fusion: gather rerank_count candidates from each retriever, let the
        # "fusion" rank profile combine them.
        yql += (
            "( "
            + bm25_yql % dict(hits=rerank_count)
            + " or "
            + semantic_yql % dict(hits=rerank_count, threshold=threshold)
            + " )"
        )
        ranking = "fusion"

    body = {"yql": yql, "ranking": ranking, "hits": limit}

    if ranking in ("bm25", "fusion"):
        if isinstance(keyword_query, list):
            keyword_query = " ".join(keyword_query)
        # Vespa rejects control characters in query strings — strip them.
        body["bm25Query"] = remove_control_characters(keyword_query or search_query)

    logger.debug(
        "vespa query " + " ".join(repr(f"{k}={v}") for k, v in body.items()) + " ..."
    )

    if ranking in ("semantic", "fusion"):
        # Only pay for an embedding when the rank profile actually uses one.
        query_embedding = create_embeddings_cached(
            [search_query], model=embedding_model
        )[0]
        if query_embedding is None:
            return {"root": {"children": []}}
        body["input.query(queryEmbedding)"] = padded_embedding(query_embedding)

    body["fileIds"] = ", ".join(map(repr, file_ids))

    response = get_vespa_app().query(body)
    assert response.is_successful()

    return response.get_json()


Expand Down Expand Up @@ -485,6 +516,23 @@ def create_embeddings_in_search_db(
return refs


def format_embedding_row(
    doc_id: str,
    file_id: str,
    ref: SearchReference,
    embedding: np.ndarray,
    created_at: datetime.datetime,
):
    """Build one Vespa feed document for an embedded snippet.

    Returns a plain dict whose keys match the Vespa schema fields.
    """
    # Vespa stores created_at as epoch milliseconds.
    timestamp_ms = int(created_at.timestamp() * 1000)
    return {
        "id": doc_id,
        "file_id": file_id,
        "embedding": padded_embedding(embedding),
        "created_at": timestamp_ms,
        # control characters are stripped — Vespa rejects them in string
        # fields (see remove_control_characters)
        "title": remove_control_characters(ref["title"]),
        "snippet": remove_control_characters(ref["snippet"]),
    }


def get_embeds_for_doc(
*,
f_url: str,
Expand Down Expand Up @@ -940,22 +988,9 @@ def render_sources_widget(refs: list[SearchReference]):
)


def format_embedding_row(
doc_id: str,
file_id: str,
ref: SearchReference,
embedding: np.ndarray,
created_at: datetime.datetime,
):
return dict(
id=doc_id,
file_id=file_id,
embedding=padded_embedding(embedding),
created_at=int(created_at.timestamp() * 1000),
# url=ref["url"].encode("unicode-escape").decode(),
# title=ref["title"].encode("unicode-escape").decode(),
# snippet=ref["snippet"].encode("unicode-escape").decode(),
)
def remove_control_characters(s):
    """Return *s* with every Unicode control character (category C*) removed.

    From https://docs.vespa.ai/en/troubleshooting-encoding.html
    """
    kept = [ch for ch in s if not unicodedata.category(ch).startswith("C")]
    return "".join(kept)
Comment on lines +991 to +993
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: The remove_control_characters function could be more efficient using str.translate() with a translation table



EMBEDDING_SIZE = 3072
Expand Down
Loading
Loading