From a3779495bfc4c75d61b24ff586a9d492462d6669 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Sun, 29 Dec 2024 22:38:10 +0530 Subject: [PATCH] Refactor `vector_search.py` and `DocExtract.py` to improve YouTube handling and embedding process - Introduce `astuple` method in `FileMetadata` for tuple representation and correct comparison - Remove null metadata for bucket urls - Store playlists as a single `EmbeddedFile` & compare playlist level metadata when doing search - Make fetching playlist metadata faster by using `--flat-playlist` - Delete previous `EmbeddedFile` with same lookup when creating embeddings - Display url level progress when creating embeddings Enhances handling of YouTube playlists/videos and optimizes embedding management for more accurate results --- daras_ai_v2/vector_search.py | 215 ++++++++++++++++++----------------- files/models.py | 4 + recipes/DocExtract.py | 39 +++---- 3 files changed, 136 insertions(+), 122 deletions(-) diff --git a/daras_ai_v2/vector_search.py b/daras_ai_v2/vector_search.py index 806a50c8a..ad322bc91 100644 --- a/daras_ai_v2/vector_search.py +++ b/daras_ai_v2/vector_search.py @@ -14,62 +14,63 @@ import gooey_gui as gui import numpy as np import requests +from app_users.models import AppUser +from daras_ai.image_input import ( + get_mimetype_from_response, + safe_filename, + upload_file_from_bytes, +) from django.db import transaction from django.db.models import F from django.utils import timezone +from embeddings.models import EmbeddedFile, EmbeddingsReference +from files.models import FileMetadata from furl import furl from loguru import logger from pydantic import BaseModel, Field -from app_users.models import AppUser -from daras_ai.image_input import ( - upload_file_from_bytes, - safe_filename, - get_mimetype_from_response, -) from daras_ai_v2 import settings from daras_ai_v2.asr import ( AsrModels, + download_youtube_to_wav, run_asr, run_google_translate, - download_youtube_to_wav, ) from daras_ai_v2.azure_doc_extract import ( - table_arr_to_prompt_chunked, THEAD, azure_doc_extract_page_num, + table_arr_to_prompt_chunked, ) from daras_ai_v2.doc_search_settings_widgets import ( is_user_uploaded_url, ) -from daras_ai_v2.embedding_model import create_embeddings_cached, EmbeddingModels -from daras_ai_v2.exceptions import raise_for_status, call_cmd, UserError +from daras_ai_v2.embedding_model import EmbeddingModels, create_embeddings_cached +from daras_ai_v2.exceptions import UserError, call_cmd, raise_for_status from daras_ai_v2.functional import ( - map_parallel, + apply_parallel, flatmap_parallel_ascompleted, + map_parallel, ) from daras_ai_v2.gdrive_downloader import ( gdrive_download, - is_gdrive_url, + gdrive_metadata, is_gdrive_presentation_url, + is_gdrive_url, url_to_gdrive_file_id, - gdrive_metadata, ) +from daras_ai_v2.office_utils_pptx import pptx_to_text_pages from daras_ai_v2.redis_cache import redis_lock from daras_ai_v2.scraping_proxy import ( + SCRAPING_PROXIES, get_scraping_proxy_cert_path, requests_scraping_kwargs, - SCRAPING_PROXIES, ) from daras_ai_v2.search_ref import ( SearchReference, - remove_quotes, generate_text_fragment_url, + remove_quotes, ) -from daras_ai_v2.office_utils_pptx import pptx_to_text_pages -from daras_ai_v2.text_splitter import text_splitter, Document -from embeddings.models import EmbeddedFile, EmbeddingsReference -from files.models import FileMetadata +from daras_ai_v2.text_splitter import Document, text_splitter class DocSearchRequest(BaseModel): @@ -145,15 +146,13 @@ def get_top_k_references( else: selected_asr_model = google_translate_target = None - yield "Creating knowledge embeddings..." - embedding_model = EmbeddingModels.get( request.embedding_model, default=EmbeddingModels.get( EmbeddedFile._meta.get_field("embedding_model").default ), ) - embedded_files: list[EmbeddedFile] = map_parallel( + embedded_files: list[EmbeddedFile] = yield from apply_parallel( lambda f_url: get_or_create_embedded_file( f_url=f_url, max_context_words=request.max_context_words, @@ -166,6 +165,7 @@ def get_top_k_references( ), input_docs, max_workers=4, + message="Fetching latest knowledge docs & Embeddings...", ) if not embedded_files: yield "No embeddings found - skipping search" @@ -263,25 +263,44 @@ def get_vespa_app(): return Vespa(url=settings.VESPA_URL) -def doc_or_yt_url_to_metadatas(f_url: str) -> list[tuple[str, FileMetadata]]: +def doc_or_yt_url_to_file_metas( + f_url: str, +) -> tuple[FileMetadata, list[tuple[str, FileMetadata]]]: if is_yt_dlp_able_url(f_url): - entries = yt_dlp_get_video_entries(f_url) - return [ - ( - entry["webpage_url"], - FileMetadata( - name=entry.get("title", "YouTube Video"), - # youtube doesn't provide etag, so we use filesize_approx or upload_date - etag=entry.get("filesize_approx") or entry.get("upload_date"), - # we will later convert & save as wav - mime_type="audio/wav", - total_bytes=entry.get("filesize_approx", 0), - ), - ) - for entry in entries - ] + data = yt_dlp_extract_info(f_url) + if data.get("_type") == "playlist": + file_meta = yt_info_to_playlist_metadata(data) + return file_meta, [ + (entry["url"], yt_info_to_video_metadata(entry)) + for entry in yt_dlp_info_to_entries(data) + ] + else: + file_meta = yt_info_to_video_metadata(data) + return file_meta, [(f_url, file_meta)] else: - return [(f_url, doc_url_to_file_metadata(f_url))] + file_meta = doc_url_to_file_metadata(f_url) + return file_meta, [(f_url, file_meta)] + + +def yt_info_to_playlist_metadata(data: dict) -> FileMetadata: + return FileMetadata( + name=data.get("title", "YouTube Playlist"), + # youtube doesn't provide etag, so we use modified_date / playlist_count + etag=data.get("modified_date") or data.get("playlist_count"), + # will be converted later & saved as wav + mime_type="audio/wav", + ) + + +def yt_info_to_video_metadata(data: dict) -> FileMetadata: + return FileMetadata( + name=data.get("title", "YouTube Video"), + # youtube doesn't provide etag, so we use filesize_approx or upload_date + etag=data.get("filesize_approx") or data.get("upload_date"), + # we will later convert & save as wav + mime_type="audio/wav", + total_bytes=data.get("filesize_approx", 0), + ) def doc_url_to_file_metadata(f_url: str) -> FileMetadata: @@ -308,13 +327,7 @@ def doc_url_to_file_metadata(f_url: str) -> FileMetadata: else: try: if is_user_uploaded_url(f_url): - name = f.path.segments[-1] - return FileMetadata( - name=name, - etag=None, - mime_type=mimetypes.guess_type(name)[0], - total_bytes=0, - ) + r = requests.head(f_url) else: r = requests.head( f_url, @@ -358,21 +371,24 @@ def doc_url_to_file_metadata(f_url: str) -> FileMetadata: return file_metadata -def yt_dlp_get_video_entries(url: str) -> list[dict]: - data = yt_dlp_extract_info(url) - entries = data.get("entries", [data]) +def yt_dlp_info_to_entries(data: dict) -> list[dict]: + entries = data.pop("entries", [data]) return [e for e in entries if e] -def yt_dlp_extract_info(url: str) -> dict: +def yt_dlp_extract_info(url: str, **params) -> dict: import yt_dlp # https://github.com/yt-dlp/yt-dlp/blob/master/yt_dlp/options.py - params = dict( - ignoreerrors=True, - check_formats=False, - proxy=SCRAPING_PROXIES.get("https"), - client_certificate=get_scraping_proxy_cert_path(), + params = ( + dict( + ignoreerrors=True, + check_formats=False, + extract_flat="in_playlist", + proxy=SCRAPING_PROXIES.get("https"), + client_certificate=get_scraping_proxy_cert_path(), + ) + | params ) with yt_dlp.YoutubeDL(params) as ydl: data = ydl.extract_info(url, download=False) @@ -407,63 +423,56 @@ def get_or_create_embedded_file( selected_asr_model=selected_asr_model or "", embedding_model=embedding_model.name, ) - - file_id = hashlib.sha256(str(lookup).encode()).hexdigest() - with redis_lock(f"gooey/get_or_create_embeddings/v1/{file_id}"): + lock_id = hashlib.sha256(str(lookup).encode()).hexdigest() + with redis_lock(f"gooey/get_or_create_embeddings/v1/{lock_id}"): try: embedded_file = EmbeddedFile.objects.filter(**lookup).order_by( "-updated_at" )[0] + except IndexError: + embedded_file = None + else: + # skip metadata check for bucket urls (since they are unique & static) if is_user_uploaded_url(f_url): return embedded_file - if not is_yt_dlp_able_url(f_url): - file_meta = doc_or_yt_url_to_metadatas(f_url)[0][1] - if file_meta == embedded_file.metadata: - return embedded_file - pass + file_meta, leaf_url_metas = doc_or_yt_url_to_file_metas(f_url) + if embedded_file and embedded_file.metadata.astuple() == file_meta.astuple(): + # metadata hasn't changed, return existing file + return embedded_file - except IndexError: + file_id_fields = lookup | dict(metadata=file_meta.astuple()) + file_id = hashlib.sha256(str(file_id_fields).encode()).hexdigest() - for leaf_url, file_meta in doc_or_yt_url_to_metadatas(f_url): - lookup.update( - metadata__name=file_meta.name, - metadata__etag=file_meta.etag, - metadata__mime_type=file_meta.mime_type, - metadata__total_bytes=file_meta.total_bytes, - ) - file_id = hashlib.sha256(str(lookup).encode()).hexdigest() - - refs = create_embeddings_in_search_db( - f_url=f_url, - file_meta=file_meta, - file_id=file_id, - max_context_words=max_context_words, - scroll_jump=scroll_jump, - google_translate_target=google_translate_target or "", - selected_asr_model=selected_asr_model or "", - embedding_model=embedding_model, - is_user_url=is_user_url, - ) - with transaction.atomic(): - file_meta.save() - embedded_file = EmbeddedFile.objects.get_or_create( - **lookup, - defaults=dict( - metadata=file_meta, - vespa_file_id=file_id, - created_by=current_user, - ), - )[0] - for ref in refs: - ref.embedded_file = embedded_file - EmbeddingsReference.objects.bulk_create( - refs, - update_conflicts=True, - update_fields=["url", "title", "snippet", "updated_at"], - unique_fields=["vespa_doc_id"], - ) - return embedded_file + # create fresh embeddings + for leaf_url, leaf_meta in leaf_url_metas: + refs = create_embeddings_in_search_db( + f_url=leaf_url, + file_meta=leaf_meta, + file_id=file_id, + max_context_words=max_context_words, + scroll_jump=scroll_jump, + google_translate_target=google_translate_target or "", + selected_asr_model=selected_asr_model or "", + embedding_model=embedding_model, + is_user_url=is_user_url, + ) + with transaction.atomic(): + EmbeddedFile.objects.filter(**lookup).delete() + file_meta.save() + embedded_file = EmbeddedFile.objects.get_or_create( + vespa_file_id=file_id, + defaults=lookup | dict(metadata=file_meta, created_by=current_user), + )[0] + for ref in refs: + ref.embedded_file = embedded_file + EmbeddingsReference.objects.bulk_create( + refs, + update_conflicts=True, + update_fields=["url", "title", "snippet", "updated_at"], + unique_fields=["vespa_doc_id"], + ) + return embedded_file def create_embeddings_in_search_db( diff --git a/files/models.py b/files/models.py index b03a598b3..a84a5303a 100644 --- a/files/models.py +++ b/files/models.py @@ -7,6 +7,7 @@ class FileMetadata(models.Model): etag = models.CharField(max_length=255, null=True) mime_type = models.CharField(max_length=255, default="", blank=True) total_bytes = models.PositiveIntegerField(default=0, blank=True) + export_links: dict[str, str] | None = None def __str__(self): @@ -17,6 +18,9 @@ def __str__(self): ret += f" - {self.etag}" return ret + def astuple(self) -> tuple: + return self.name, self.etag, self.mime_type, self.total_bytes + class Meta: indexes = [ models.Index(fields=["name", "etag", "mime_type", "total_bytes"]), diff --git a/recipes/DocExtract.py b/recipes/DocExtract.py index 0fa063379..35a73e9e9 100644 --- a/recipes/DocExtract.py +++ b/recipes/DocExtract.py @@ -5,20 +5,16 @@ import gooey_gui as gui import requests from aifail import retry_if -from django.db.models import IntegerChoices -from furl import furl -from pydantic import BaseModel, Field - from bots.models import Workflow from daras_ai.image_input import upload_file_from_bytes from daras_ai_v2 import settings from daras_ai_v2.asr import ( - run_translate, AsrModels, - run_asr, - download_youtube_to_wav_url, - audio_url_to_wav, TranslationModels, + audio_url_to_wav, + download_youtube_to_wav_url, + run_asr, + run_translate, ) from daras_ai_v2.azure_doc_extract import ( azure_doc_extract_page_num, @@ -33,15 +29,15 @@ apply_parallel, flatapply_parallel, ) -from daras_ai_v2.gdrive_downloader import is_gdrive_url, gdrive_download +from daras_ai_v2.gdrive_downloader import gdrive_download, is_gdrive_url from daras_ai_v2.language_model import ( - run_language_model, LargeLanguageModels, + run_language_model, ) from daras_ai_v2.language_model_settings_widgets import ( - language_model_settings, - language_model_selector, LanguageModelSettings, + language_model_selector, + language_model_settings, ) from daras_ai_v2.loom_video_widget import youtube_video from daras_ai_v2.pydantic_validation import FieldHttpUrl @@ -49,17 +45,22 @@ from daras_ai_v2.settings import service_account_key_path from daras_ai_v2.vector_search import ( add_page_number_to_pdf, - yt_dlp_get_video_entries, + doc_or_yt_url_to_file_metas, doc_url_to_file_metadata, - get_pdf_num_pages, doc_url_to_text_pages, - doc_or_yt_url_to_metadatas, + get_pdf_num_pages, is_yt_dlp_able_url, + yt_dlp_extract_info, + yt_dlp_info_to_entries, ) +from django.db.models import IntegerChoices from files.models import FileMetadata +from furl import furl +from pydantic import BaseModel, Field + +from recipes.asr_page import AsrPage from recipes.DocSearch import render_documents from recipes.Translation import TranslationOptions -from recipes.asr_page import AsrPage DEFAULT_YOUTUBE_BOT_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/ddc8ffac-93fb-11ee-89fb-02420a0001cb/Youtube%20transcripts.jpg.png" @@ -226,7 +227,7 @@ def run_v2( ) else: file_url_metas = yield from flatapply_parallel( - doc_or_yt_url_to_metadatas, + lambda f_url: doc_or_yt_url_to_file_metas(f_url)[1], request.documents, message="Extracting metadata...", ) @@ -367,7 +368,7 @@ def col_i2a(col: int) -> str: def extract_info(url: str) -> list[dict | None]: if is_yt_dlp_able_url(url): - return yt_dlp_get_video_entries(url) + return yt_dlp_info_to_entries(yt_dlp_extract_info(url)) # assume it's a direct link file_meta = doc_url_to_file_metadata(url) @@ -582,8 +583,8 @@ def update_cell(spreadsheet_id: str, row: int, col: int, value: str): def get_spreadsheet_service(): - from oauth2client.service_account import ServiceAccountCredentials from googleapiclient.discovery import build + from oauth2client.service_account import ServiceAccountCredentials try: return threadlocal.spreadsheets