Skip to content

Commit

Permalink
feat: add default cache_knowledge to google bucket urls, updated vesp…
Browse files Browse the repository at this point in the history
…a_id lookup
  • Loading branch information
milovate authored and devxpy committed Dec 31, 2024
1 parent 2f34064 commit 68bb697
Showing 1 changed file with 61 additions and 44 deletions.
105 changes: 61 additions & 44 deletions daras_ai_v2/vector_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
from daras_ai_v2.embedding_model import create_embeddings_cached, EmbeddingModels
from daras_ai_v2.exceptions import raise_for_status, call_cmd, UserError
from daras_ai_v2.functional import (
flatmap_parallel,
map_parallel,
flatmap_parallel_ascompleted,
)
Expand Down Expand Up @@ -146,9 +145,6 @@ def get_top_k_references(
else:
selected_asr_model = google_translate_target = None

file_url_metas = flatmap_parallel(doc_or_yt_url_to_metadatas, input_docs)
file_urls, file_metas = zip(*file_url_metas)

yield "Creating knowledge embeddings..."

embedding_model = EmbeddingModels.get(
Expand All @@ -158,9 +154,8 @@ def get_top_k_references(
),
)
embedded_files: list[EmbeddedFile] = map_parallel(
lambda f_url, file_meta: get_or_create_embedded_file(
lambda f_url: get_or_create_embedded_file(
f_url=f_url,
file_meta=file_meta,
max_context_words=request.max_context_words,
scroll_jump=request.scroll_jump,
google_translate_target=google_translate_target,
Expand All @@ -169,8 +164,7 @@ def get_top_k_references(
is_user_url=is_user_url,
current_user=current_user,
),
file_urls,
file_metas,
input_docs,
max_workers=4,
)
if not embedded_files:
Expand Down Expand Up @@ -314,7 +308,13 @@ def doc_url_to_file_metadata(f_url: str) -> FileMetadata:
else:
try:
if is_user_uploaded_url(f_url):
r = requests.head(f_url)
name = f.path.segments[-1]
return FileMetadata(
name=name,
etag=None,
mime_type=mimetypes.guess_type(name)[0],
total_bytes=0,
)
else:
r = requests.head(
f_url,
Expand Down Expand Up @@ -387,7 +387,6 @@ def yt_dlp_extract_info(url: str) -> dict:
def get_or_create_embedded_file(
*,
f_url: str,
file_meta: FileMetadata,
max_context_words: int,
scroll_jump: int,
google_translate_target: str | None,
Expand All @@ -402,51 +401,69 @@ def get_or_create_embedded_file(
"""
lookup = dict(
url=f_url,
metadata__name=file_meta.name,
metadata__etag=file_meta.etag,
metadata__mime_type=file_meta.mime_type,
metadata__total_bytes=file_meta.total_bytes,
max_context_words=max_context_words,
scroll_jump=scroll_jump,
google_translate_target=google_translate_target or "",
selected_asr_model=selected_asr_model or "",
embedding_model=embedding_model.name,
)

file_id = hashlib.sha256(str(lookup).encode()).hexdigest()
with redis_lock(f"gooey/get_or_create_embeddings/v1/{file_id}"):
try:
return EmbeddedFile.objects.filter(**lookup).order_by("-updated_at")[0]
embedded_file = EmbeddedFile.objects.filter(**lookup).order_by(
"-updated_at"
)[0]
if is_user_uploaded_url(f_url):
return embedded_file

if not is_yt_dlp_able_url(f_url):
file_meta = doc_or_yt_url_to_metadatas(f_url)[0][1]
if file_meta == embedded_file.metadata:
return embedded_file
pass

except IndexError:
refs = create_embeddings_in_search_db(
f_url=f_url,
file_meta=file_meta,
file_id=file_id,
max_context_words=max_context_words,
scroll_jump=scroll_jump,
google_translate_target=google_translate_target or "",
selected_asr_model=selected_asr_model or "",
embedding_model=embedding_model,
is_user_url=is_user_url,
)
with transaction.atomic():
file_meta.save()
embedded_file = EmbeddedFile.objects.get_or_create(
**lookup,
defaults=dict(
metadata=file_meta,
vespa_file_id=file_id,
created_by=current_user,
),
)[0]
for ref in refs:
ref.embedded_file = embedded_file
EmbeddingsReference.objects.bulk_create(
refs,
update_conflicts=True,
update_fields=["url", "title", "snippet", "updated_at"],
unique_fields=["vespa_doc_id"],

for leaf_url, file_meta in doc_or_yt_url_to_metadatas(f_url):
lookup.update(
metadata__name=file_meta.name,
metadata__etag=file_meta.etag,
metadata__mime_type=file_meta.mime_type,
metadata__total_bytes=file_meta.total_bytes,
)
file_id = hashlib.sha256(str(lookup).encode()).hexdigest()

refs = create_embeddings_in_search_db(
f_url=f_url,
file_meta=file_meta,
file_id=file_id,
max_context_words=max_context_words,
scroll_jump=scroll_jump,
google_translate_target=google_translate_target or "",
selected_asr_model=selected_asr_model or "",
embedding_model=embedding_model,
is_user_url=is_user_url,
)
return embedded_file
with transaction.atomic():
file_meta.save()
embedded_file = EmbeddedFile.objects.get_or_create(
**lookup,
defaults=dict(
metadata=file_meta,
vespa_file_id=file_id,
created_by=current_user,
),
)[0]
for ref in refs:
ref.embedded_file = embedded_file
EmbeddingsReference.objects.bulk_create(
refs,
update_conflicts=True,
update_fields=["url", "title", "snippet", "updated_at"],
unique_fields=["vespa_doc_id"],
)
return embedded_file


def create_embeddings_in_search_db(
Expand Down

0 comments on commit 68bb697

Please sign in to comment.