Refactor vector_search.py and DocExtract.py to improve YouTube handling and embedding process

- Introduce an `astuple` method on `FileMetadata` for tuple representation and correct equality comparison
- Remove null metadata for bucket URLs
- Store playlists as a single `EmbeddedFile` & compare playlist-level metadata when searching
- Make fetching playlist metadata faster by using `--flat-playlist`
- Delete the previous `EmbeddedFile` with the same lookup when creating embeddings
- Display URL-level progress when creating embeddings

Enhances handling of YouTube playlists/videos and optimizes embedding management for more accurate results. A rough sketch of the new flow follows below.
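
As a minimal illustration (not the code from this commit — `fetch_playlist_info` and `needs_reembedding` are hypothetical helper names), fetching playlist metadata with yt-dlp's flat extraction and comparing metadata tuples looks roughly like this:

import yt_dlp


def fetch_playlist_info(url: str) -> dict:
    # extract_flat="in_playlist" is the programmatic form of --flat-playlist:
    # playlist entries are listed without resolving every video, so large
    # playlists resolve much faster
    params = dict(ignoreerrors=True, check_formats=False, extract_flat="in_playlist")
    with yt_dlp.YoutubeDL(params) as ydl:
        return ydl.extract_info(url, download=False)


def needs_reembedding(saved_meta: tuple, fresh_meta: tuple) -> bool:
    # FileMetadata.astuple() returns (name, etag, mime_type, total_bytes);
    # comparing plain tuples avoids the pitfalls of comparing model instances
    return saved_meta != fresh_meta


info = fetch_playlist_info("https://www.youtube.com/playlist?list=<playlist-id>")
if info.get("_type") == "playlist":
    # YouTube provides no etag, so modified_date / playlist_count stand in for one
    print(info.get("title"), info.get("modified_date") or info.get("playlist_count"))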
devxpy committed Dec 31, 2024
1 parent 68bb697 commit cc1cb0a
Showing 3 changed files with 136 additions and 122 deletions.
215 changes: 112 additions & 103 deletions daras_ai_v2/vector_search.py
@@ -14,62 +14,63 @@
import gooey_gui as gui
import numpy as np
import requests
from app_users.models import AppUser
from daras_ai.image_input import (
get_mimetype_from_response,
safe_filename,
upload_file_from_bytes,
)
from django.db import transaction
from django.db.models import F
from django.utils import timezone
from embeddings.models import EmbeddedFile, EmbeddingsReference
from files.models import FileMetadata
from furl import furl
from loguru import logger
from pydantic import BaseModel, Field

from app_users.models import AppUser
from daras_ai.image_input import (
upload_file_from_bytes,
safe_filename,
get_mimetype_from_response,
)
from daras_ai_v2 import settings
from daras_ai_v2.asr import (
AsrModels,
download_youtube_to_wav,
run_asr,
run_google_translate,
download_youtube_to_wav,
)
from daras_ai_v2.azure_doc_extract import (
table_arr_to_prompt_chunked,
THEAD,
azure_doc_extract_page_num,
table_arr_to_prompt_chunked,
)
from daras_ai_v2.doc_search_settings_widgets import (
is_user_uploaded_url,
)
from daras_ai_v2.embedding_model import create_embeddings_cached, EmbeddingModels
from daras_ai_v2.exceptions import raise_for_status, call_cmd, UserError
from daras_ai_v2.embedding_model import EmbeddingModels, create_embeddings_cached
from daras_ai_v2.exceptions import UserError, call_cmd, raise_for_status
from daras_ai_v2.functional import (
map_parallel,
apply_parallel,
flatmap_parallel_ascompleted,
map_parallel,
)
from daras_ai_v2.gdrive_downloader import (
gdrive_download,
is_gdrive_url,
gdrive_metadata,
is_gdrive_presentation_url,
is_gdrive_url,
url_to_gdrive_file_id,
gdrive_metadata,
)
from daras_ai_v2.office_utils_pptx import pptx_to_text_pages
from daras_ai_v2.redis_cache import redis_lock
from daras_ai_v2.scraping_proxy import (
SCRAPING_PROXIES,
get_scraping_proxy_cert_path,
requests_scraping_kwargs,
SCRAPING_PROXIES,
)
from daras_ai_v2.search_ref import (
SearchReference,
remove_quotes,
generate_text_fragment_url,
remove_quotes,
)
from daras_ai_v2.office_utils_pptx import pptx_to_text_pages
from daras_ai_v2.text_splitter import text_splitter, Document
from embeddings.models import EmbeddedFile, EmbeddingsReference
from files.models import FileMetadata
from daras_ai_v2.text_splitter import Document, text_splitter


class DocSearchRequest(BaseModel):
@@ -145,15 +146,13 @@ def get_top_k_references(
else:
selected_asr_model = google_translate_target = None

yield "Creating knowledge embeddings..."

embedding_model = EmbeddingModels.get(
request.embedding_model,
default=EmbeddingModels.get(
EmbeddedFile._meta.get_field("embedding_model").default
),
)
embedded_files: list[EmbeddedFile] = map_parallel(
embedded_files: list[EmbeddedFile] = yield from apply_parallel(
lambda f_url: get_or_create_embedded_file(
f_url=f_url,
max_context_words=request.max_context_words,
@@ -166,6 +165,7 @@
),
input_docs,
max_workers=4,
message="Fetching latest knowledge docs & Embeddings...",
)
if not embedded_files:
yield "No embeddings found - skipping search"
@@ -263,25 +263,44 @@ def get_vespa_app():
return Vespa(url=settings.VESPA_URL)


def doc_or_yt_url_to_metadatas(f_url: str) -> list[tuple[str, FileMetadata]]:
def doc_or_yt_url_to_file_metas(
f_url: str,
) -> tuple[FileMetadata, list[tuple[str, FileMetadata]]]:
if is_yt_dlp_able_url(f_url):
entries = yt_dlp_get_video_entries(f_url)
return [
(
entry["webpage_url"],
FileMetadata(
name=entry.get("title", "YouTube Video"),
# youtube doesn't provide etag, so we use filesize_approx or upload_date
etag=entry.get("filesize_approx") or entry.get("upload_date"),
# we will later convert & save as wav
mime_type="audio/wav",
total_bytes=entry.get("filesize_approx", 0),
),
)
for entry in entries
]
data = yt_dlp_extract_info(f_url)
if data.get("_type") == "playlist":
file_meta = yt_info_to_playlist_metadata(data)
return file_meta, [
(entry["url"], yt_info_to_video_metadata(entry))
for entry in yt_dlp_info_to_entries(data)
]
else:
file_meta = yt_info_to_video_metadata(data)
return file_meta, [(f_url, file_meta)]
else:
return [(f_url, doc_url_to_file_metadata(f_url))]
file_meta = doc_url_to_file_metadata(f_url)
return file_meta, [(f_url, file_meta)]


def yt_info_to_playlist_metadata(data: dict) -> FileMetadata:
return FileMetadata(
name=data.get("title", "YouTube Playlist"),
# youtube doesn't provide etag, so we use modified_date / playlist_count
etag=data.get("modified_date") or data.get("playlist_count"),
# will be converted later & saved as wav
mime_type="audio/wav",
)


def yt_info_to_video_metadata(data: dict) -> FileMetadata:
return FileMetadata(
name=data.get("title", "YouTube Video"),
# youtube doesn't provide etag, so we use filesize_approx or upload_date
etag=data.get("filesize_approx") or data.get("upload_date"),
# we will later convert & save as wav
mime_type="audio/wav",
total_bytes=data.get("filesize_approx", 0),
)


def doc_url_to_file_metadata(f_url: str) -> FileMetadata:
@@ -308,13 +327,7 @@ def doc_url_to_file_metadata(f_url: str) -> FileMetadata:
else:
try:
if is_user_uploaded_url(f_url):
name = f.path.segments[-1]
return FileMetadata(
name=name,
etag=None,
mime_type=mimetypes.guess_type(name)[0],
total_bytes=0,
)
r = requests.head(f_url)
else:
r = requests.head(
f_url,
@@ -358,21 +371,24 @@ def doc_url_to_file_metadata(f_url: str) -> FileMetadata:
return file_metadata


def yt_dlp_get_video_entries(url: str) -> list[dict]:
data = yt_dlp_extract_info(url)
entries = data.get("entries", [data])
def yt_dlp_info_to_entries(data: dict) -> list[dict]:
entries = data.pop("entries", [data])
return [e for e in entries if e]


def yt_dlp_extract_info(url: str) -> dict:
def yt_dlp_extract_info(url: str, **params) -> dict:
import yt_dlp

# https://github.com/yt-dlp/yt-dlp/blob/master/yt_dlp/options.py
params = dict(
ignoreerrors=True,
check_formats=False,
proxy=SCRAPING_PROXIES.get("https"),
client_certificate=get_scraping_proxy_cert_path(),
params = (
dict(
ignoreerrors=True,
check_formats=False,
extract_flat="in_playlist",
proxy=SCRAPING_PROXIES.get("https"),
client_certificate=get_scraping_proxy_cert_path(),
)
| params
)
with yt_dlp.YoutubeDL(params) as ydl:
data = ydl.extract_info(url, download=False)
@@ -407,63 +423,56 @@ def get_or_create_embedded_file(
selected_asr_model=selected_asr_model or "",
embedding_model=embedding_model.name,
)

file_id = hashlib.sha256(str(lookup).encode()).hexdigest()
with redis_lock(f"gooey/get_or_create_embeddings/v1/{file_id}"):
lock_id = hashlib.sha256(str(lookup).encode()).hexdigest()
with redis_lock(f"gooey/get_or_create_embeddings/v1/{lock_id}"):
try:
embedded_file = EmbeddedFile.objects.filter(**lookup).order_by(
"-updated_at"
)[0]
except IndexError:
embedded_file = None
else:
# skip metadata check for bucket urls (since they are unique & static)
if is_user_uploaded_url(f_url):
return embedded_file

if not is_yt_dlp_able_url(f_url):
file_meta = doc_or_yt_url_to_metadatas(f_url)[0][1]
if file_meta == embedded_file.metadata:
return embedded_file
pass
file_meta, leaf_url_metas = doc_or_yt_url_to_file_metas(f_url)
if embedded_file and embedded_file.metadata.astuple() == file_meta.astuple():
# metadata hasn't changed, return existing file
return embedded_file

except IndexError:
file_id_fields = lookup | dict(metadata=file_meta.astuple())
file_id = hashlib.sha256(str(file_id_fields).encode()).hexdigest()

for leaf_url, file_meta in doc_or_yt_url_to_metadatas(f_url):
lookup.update(
metadata__name=file_meta.name,
metadata__etag=file_meta.etag,
metadata__mime_type=file_meta.mime_type,
metadata__total_bytes=file_meta.total_bytes,
)
file_id = hashlib.sha256(str(lookup).encode()).hexdigest()

refs = create_embeddings_in_search_db(
f_url=f_url,
file_meta=file_meta,
file_id=file_id,
max_context_words=max_context_words,
scroll_jump=scroll_jump,
google_translate_target=google_translate_target or "",
selected_asr_model=selected_asr_model or "",
embedding_model=embedding_model,
is_user_url=is_user_url,
)
with transaction.atomic():
file_meta.save()
embedded_file = EmbeddedFile.objects.get_or_create(
**lookup,
defaults=dict(
metadata=file_meta,
vespa_file_id=file_id,
created_by=current_user,
),
)[0]
for ref in refs:
ref.embedded_file = embedded_file
EmbeddingsReference.objects.bulk_create(
refs,
update_conflicts=True,
update_fields=["url", "title", "snippet", "updated_at"],
unique_fields=["vespa_doc_id"],
)
return embedded_file
# create fresh embeddings
for leaf_url, leaf_meta in leaf_url_metas:
refs = create_embeddings_in_search_db(
f_url=leaf_url,
file_meta=leaf_meta,
file_id=file_id,
max_context_words=max_context_words,
scroll_jump=scroll_jump,
google_translate_target=google_translate_target or "",
selected_asr_model=selected_asr_model or "",
embedding_model=embedding_model,
is_user_url=is_user_url,
)
with transaction.atomic():
EmbeddedFile.objects.filter(**lookup).delete()
file_meta.save()
embedded_file = EmbeddedFile.objects.get_or_create(
vespa_file_id=file_id,
defaults=lookup | dict(metadata=file_meta, created_by=current_user),
)[0]
for ref in refs:
ref.embedded_file = embedded_file
EmbeddingsReference.objects.bulk_create(
refs,
update_conflicts=True,
update_fields=["url", "title", "snippet", "updated_at"],
unique_fields=["vespa_doc_id"],
)
return embedded_file


def create_embeddings_in_search_db(
4 changes: 4 additions & 0 deletions files/models.py
@@ -7,6 +7,7 @@ class FileMetadata(models.Model):
etag = models.CharField(max_length=255, null=True)
mime_type = models.CharField(max_length=255, default="", blank=True)
total_bytes = models.PositiveIntegerField(default=0, blank=True)

export_links: dict[str, str] | None = None

def __str__(self):
@@ -17,6 +18,9 @@ def __str__(self):
ret += f" - {self.etag}"
return ret

def astuple(self) -> tuple:
return self.name, self.etag, self.mime_type, self.total_bytes

class Meta:
indexes = [
models.Index(fields=["name", "etag", "mime_type", "total_bytes"]),
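
A quick usage sketch of the new `FileMetadata.astuple()` method (hypothetical values; assumes a configured Django environment where `files.models` is importable):

from files.models import FileMetadata

saved = FileMetadata(name="My Playlist", etag="20241229", mime_type="audio/wav")
fresh = FileMetadata(name="My Playlist", etag="20241231", mime_type="audio/wav")

# astuple() returns (name, etag, mime_type, total_bytes), so two instances can be
# compared by field values instead of Django's primary-key based model equality
assert saved.astuple() != fresh.astuple()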
(diff for DocExtract.py not shown)

0 comments on commit cc1cb0a
