diff --git a/visual-retrieval-colpali/.gitignore b/visual-retrieval-colpali/.gitignore
index 0e95ca24d..d48359af6 100644
--- a/visual-retrieval-colpali/.gitignore
+++ b/visual-retrieval-colpali/.gitignore
@@ -8,4 +8,6 @@ template/
 *.json
 output/
 pdfs/
-static/saved/
\ No newline at end of file
+static/full_images/
+static/sim_maps/
+embeddings/
\ No newline at end of file
diff --git a/visual-retrieval-colpali/backend/colpali.py b/visual-retrieval-colpali/backend/colpali.py
index 8c337bcd9..30b0d1f17 100644
--- a/visual-retrieval-colpali/backend/colpali.py
+++ b/visual-retrieval-colpali/backend/colpali.py
@@ -7,7 +7,7 @@ from pathlib import Path
 import base64
 from io import BytesIO
-from typing import Union, Tuple, List, Dict, Any
+from typing import Union, Tuple, List
 import matplotlib
 import matplotlib.cm as cm
 import re
@@ -49,7 +49,7 @@ def load_model() -> Tuple[ColPali, ColPaliProcessor]:
     # Load the processor
     processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))

-    return model, processor
+    return model, processor, device


 def load_vit_config(model):
@@ -63,7 +63,6 @@ def gen_similarity_maps(
     model: ColPali,
     processor: ColPaliProcessor,
     device,
-    vit_config,
     query: str,
     query_embs: torch.Tensor,
     token_idx_map: dict,
@@ -88,7 +87,7 @@ def gen_similarity_maps(
         Tuple[int, str, str]: A tuple containing the image index, the selected token, and the base64-encoded image.

     """
-
+    vit_config = load_vit_config(model)
     # Process images and store original images and sizes
     processed_images = []
     original_images = []
@@ -254,7 +253,7 @@ def gen_similarity_maps(

             # Store the base64-encoded image
             result_per_image[token] = blended_img_base64
-            yield idx, token, blended_img_base64
+            yield idx, token, token_idx, blended_img_base64
         end3 = time.perf_counter()
         print(f"Blending images took: {end3 - start3} s")

@@ -287,54 +286,3 @@ def is_special_token(token: str) -> bool:
     if (len(token) < 3) or pattern.match(token):
         return True
     return False
-
-
-def add_sim_maps_to_result(
-    result: Dict[str, Any],
-    model: ColPali,
-    processor: ColPaliProcessor,
-    query: str,
-    q_embs: Any,
-    token_to_idx: Dict[str, int],
-    query_id: str,
-    result_cache,
-) -> Dict[str, Any]:
-    vit_config = load_vit_config(model)
-    imgs: List[str] = []
-    vespa_sim_maps: List[str] = []
-    for single_result in result["root"]["children"]:
-        img = single_result["fields"]["blur_image"]
-        if img:
-            imgs.append(img)
-        vespa_sim_map = single_result["fields"].get("summaryfeatures", None)
-        if vespa_sim_map:
-            vespa_sim_maps.append(vespa_sim_map)
-    if not imgs:
-        return result
-    sim_map_imgs_generator = gen_similarity_maps(
-        model=model,
-        processor=processor,
-        device=model.device if hasattr(model, "device") else "cpu",
-        vit_config=vit_config,
-        query=query,
-        query_embs=q_embs,
-        token_idx_map=token_to_idx,
-        images=imgs,
-        vespa_sim_maps=vespa_sim_maps,
-    )
-    for img_idx, token, sim_mapb64 in sim_map_imgs_generator:
-        print(f"Created sim map for image {img_idx} and token {token}")
-        if (
-            len(result["root"]["children"]) > img_idx
-            and "fields" in result["root"]["children"][img_idx]
-            and "sim_map" in result["root"]["children"][img_idx]["fields"]
-        ):
-            result["root"]["children"][img_idx]["fields"][f"sim_map_{token}"] = (
-                sim_mapb64
-            )
-        # Update result_cache with the new sim_map
-        result_cache.set(query_id, result)
-    # for single_result, sim_map_dict in zip(result["root"]["children"], sim_map_imgs):
-    #     for token, sim_mapb64 in sim_map_dict.items():
-    #         single_result["fields"][f"sim_map_{token}"] = sim_mapb64
-    return result
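Note: callers of gen_similarity_maps must adapt to the new contract — the function now loads the ViT config itself and yields four-tuples that include the token index, and load_model() additionally returns the device. A minimal consumption sketch under stated assumptions (the query string, image path, output filenames, and the "nn+colpali_sim" ranking value are placeholders; it assumes a reachable Vespa application and that the local images correspond one-to-one to the returned hits):

    import base64

    from backend.colpali import gen_similarity_maps, get_query_embeddings_and_token_map, load_model
    from backend.vespa_app import VespaQueryClient

    model, processor, device = load_model()  # load_model() now also returns the device
    query = "sustainable energy"             # placeholder query
    q_embs, token_to_idx = get_query_embeddings_and_token_map(processor, model, query)

    # Precomputed similarity maps (summaryfeatures) come from Vespa via the "_sim" rank profile
    vespa_app = VespaQueryClient()
    vespa_sim_maps = vespa_app.get_sim_maps_from_query(
        query=query, q_embs=q_embs, ranking="nn+colpali_sim", token_to_idx=token_to_idx
    )

    for idx, token, token_idx, blended_img_base64 in gen_similarity_maps(
        model=model,
        processor=processor,
        device=device,
        query=query,
        query_embs=q_embs,
        token_idx_map=token_to_idx,
        images=["static/full_images/example_0.jpg"],  # placeholder image path, one per hit
        vespa_sim_maps=vespa_sim_maps,
    ):
        # Persist each blended overlay; the filename pattern is only illustrative here
        with open(f"sim_map_{idx}_{token_idx}.png", "wb") as f:
            f.write(base64.b64decode(blended_img_base64))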
diff --git a/visual-retrieval-colpali/backend/modelmanager.py b/visual-retrieval-colpali/backend/modelmanager.py
index 2b0314c24..fee8d1d25 100644
--- a/visual-retrieval-colpali/backend/modelmanager.py
+++ b/visual-retrieval-colpali/backend/modelmanager.py
@@ -17,7 +17,7 @@ def get_instance():
     def initialize_model_and_processor(self):
         if self.model is None or self.processor is None:  # Ensure no reinitialization
-            self.model, self.processor = load_model()
+            self.model, self.processor, self.device = load_model()
             if self.model is None or self.processor is None:
                 print("Failed to initialize model or processor at startup")
             else:
diff --git a/visual-retrieval-colpali/backend/vespa_app.py b/visual-retrieval-colpali/backend/vespa_app.py
index c9e01df15..37d1fab99 100644
--- a/visual-retrieval-colpali/backend/vespa_app.py
+++ b/visual-retrieval-colpali/backend/vespa_app.py
@@ -1,18 +1,19 @@
 import os
 import time
 from typing import Any, Dict, Tuple
-
+import asyncio
 import numpy as np
 import torch
 from dotenv import load_dotenv
 from vespa.application import Vespa
 from vespa.io import VespaQueryResponse
+from .colpali import is_special_token


 class VespaQueryClient:
     MAX_QUERY_TERMS = 64
     VESPA_SCHEMA_NAME = "pdf_page"
-    SELECT_FIELDS = "id,title,url,blur_image,page_number,snippet,text,summaryfeatures"
+    SELECT_FIELDS = "id,title,url,blur_image,page_number,snippet,text"

     def __init__(self):
         """
@@ -73,6 +74,12 @@ def __init__(self):
         self.app.wait_for_application_up()
         print(f"Connected to Vespa at {self.vespa_app_url}")

+    def get_fields(self, sim_map: bool = False):
+        if not sim_map:
+            return self.SELECT_FIELDS
+        else:
+            return "summaryfeatures"
+
     def format_query_results(
         self, query: str, response: VespaQueryResponse, hits: int = 5
     ) -> dict:
@@ -100,6 +107,7 @@ async def query_vespa_default(
         q_emb: torch.Tensor,
         hits: int = 3,
         timeout: str = "10s",
+        sim_map: bool = False,
         **kwargs,
     ) -> dict:
         """
@@ -121,9 +129,9 @@ async def query_vespa_default(
             response: VespaQueryResponse = await session.query(
                 body={
                     "yql": (
-                        f"select {self.SELECT_FIELDS} from {self.VESPA_SCHEMA_NAME} where userQuery();"
+                        f"select {self.get_fields(sim_map=sim_map)} from {self.VESPA_SCHEMA_NAME} where userQuery();"
                     ),
-                    "ranking": "default",
+                    "ranking": self.get_rank_profile("default", sim_map),
                     "query": query,
                     "timeout": timeout,
                     "hits": hits,
@@ -146,6 +154,7 @@ async def query_vespa_bm25(
         q_emb: torch.Tensor,
         hits: int = 3,
         timeout: str = "10s",
+        sim_map: bool = False,
         **kwargs,
     ) -> dict:
         """
@@ -167,9 +176,9 @@ async def query_vespa_bm25(
             response: VespaQueryResponse = await session.query(
                 body={
                     "yql": (
-                        f"select {self.SELECT_FIELDS} from {self.VESPA_SCHEMA_NAME} where userQuery();"
+                        f"select {self.get_fields(sim_map=sim_map)} from {self.VESPA_SCHEMA_NAME} where userQuery();"
                     ),
-                    "ranking": "bm25",
+                    "ranking": self.get_rank_profile("bm25", sim_map),
                     "query": query,
                     "timeout": timeout,
                     "hits": hits,
@@ -266,30 +275,54 @@ async def get_result_from_query(
         Returns:
             Dict[str, Any]: The query results.
         """
-        print(query)
-        print(token_to_idx)
-
-        if ranking == "nn+colpali":
-            result = await self.query_vespa_nearest_neighbor(query, q_embs)
-        elif ranking == "bm25+colpali":
-            result = await self.query_vespa_default(query, q_embs)
-        elif ranking == "bm25":
-            result = await self.query_vespa_bm25(query, q_embs)
+        rank_method = ranking.split("_")[0]
+        sim_map: bool = len(ranking.split("_")) > 1 and ranking.split("_")[1] == "sim"
+        if rank_method == "nn+colpali":
+            result = await self.query_vespa_nearest_neighbor(
+                query, q_embs, sim_map=sim_map
+            )
+        elif rank_method == "bm25+colpali":
+            result = await self.query_vespa_default(query, q_embs, sim_map=sim_map)
+        elif rank_method == "bm25":
+            result = await self.query_vespa_bm25(query, q_embs, sim_map=sim_map)
         else:
-            raise ValueError(f"Unsupported ranking: {ranking}")
-
+            raise ValueError(f"Unsupported ranking: {rank_method}")
         # Print score, title id, and text of the results
         if "root" not in result or "children" not in result["root"]:
             result["root"] = {"children": []}
             return result
-        for idx, child in enumerate(result["root"]["children"]):
-            print(
-                f"Result {idx+1}: {child['relevance']}, {child['fields']['title']}, {child['fields']['id']}"
-            )
         for single_result in result["root"]["children"]:
             print(single_result["fields"].keys())

         return result

+    def get_sim_maps_from_query(
+        self, query: str, q_embs: torch.Tensor, ranking: str, token_to_idx: dict
+    ):
+        """
+        Get similarity maps from Vespa based on the ranking method.
+
+        Args:
+            query (str): The query text.
+            q_embs (torch.Tensor): Query embeddings.
+            ranking (str): The ranking method to use.
+            token_to_idx (dict): Token to index mapping.
+
+        Returns:
+            List[str]: The similarity maps (Vespa summaryfeatures), one per hit.
+        """
+        # Get the result by calling asyncio.run
+        result = asyncio.run(
+            self.get_result_from_query(query, q_embs, ranking, token_to_idx)
+        )
+        vespa_sim_maps = []
+        for single_result in result["root"]["children"]:
+            vespa_sim_map = single_result["fields"].get("summaryfeatures", None)
+            if vespa_sim_map is not None:
+                vespa_sim_maps.append(vespa_sim_map)
+            else:
+                raise ValueError("No sim_map found in Vespa response")
+        return vespa_sim_maps
+
     async def get_full_image_from_vespa(self, doc_id: str) -> str:
         """
         Retrieve the full image from Vespa for a given document ID.
@@ -317,6 +350,23 @@ async def get_full_image_from_vespa(self, doc_id: str) -> str:
         )
         return response.json["root"]["children"][0]["fields"]["full_image"]

+    def get_results_children(self, result: VespaQueryResponse) -> list:
+        return result["root"]["children"]
+
+    def results_to_search_results(
+        self, result: VespaQueryResponse, token_to_idx: dict
+    ) -> list:
+        # Initialize sim_map_ fields in the result
+        fields_to_add = [
+            f"sim_map_{token}_{idx}"
+            for idx, token in enumerate(token_to_idx.keys())
+            if not is_special_token(token)
+        ]
+        for child in result["root"]["children"]:
+            for sim_map_key in fields_to_add:
+                child["fields"][sim_map_key] = None
+        return self.get_results_children(result)
+
     async def get_suggestions(self, query: str) -> list:
         async with self.app.asyncio(connections=1) as session:
             start = time.perf_counter()
@@ -348,6 +398,12 @@ async def get_suggestions(self, query: str) -> list:
         flat_questions = [item for sublist in questions for item in sublist]
         return flat_questions

+    def get_rank_profile(self, ranking: str, sim_map: bool) -> str:
+        if sim_map:
+            return f"{ranking}_sim"
+        else:
+            return ranking
+
     async def query_vespa_nearest_neighbor(
         self,
         query: str,
@@ -355,6 +411,7 @@ async def query_vespa_nearest_neighbor(
         target_hits_per_query_tensor: int = 20,
         hits: int = 3,
         timeout: str = "10s",
+        sim_map: bool = False,
         **kwargs,
     ) -> dict:
         """
@@ -385,15 +442,16 @@ async def query_vespa_nearest_neighbor(
             binary_query_embeddings, target_hits_per_query_tensor
         )
         query_tensors.update(nn_query_dict)
-
         response: VespaQueryResponse = await session.query(
             body={
                 **query_tensors,
                 "presentation.timing": True,
                 "yql": (
-                    f"select {self.SELECT_FIELDS} from {self.VESPA_SCHEMA_NAME} where {nn_string} or userQuery()"
+                    f"select {self.get_fields(sim_map=sim_map)} from {self.VESPA_SCHEMA_NAME} where {nn_string} or userQuery()"
+                ),
+                "ranking.profile": self.get_rank_profile(
+                    "retrieval-and-rerank", sim_map
                 ),
-                "ranking.profile": "retrieval-and-rerank",
                 "timeout": timeout,
                 "hits": hits,
                 "query": query,
diff --git a/visual-retrieval-colpali/frontend/app.py b/visual-retrieval-colpali/frontend/app.py
index ebe239db4..6d60163bb 100644
--- a/visual-retrieval-colpali/frontend/app.py
+++ b/visual-retrieval-colpali/frontend/app.py
@@ -323,14 +323,13 @@ def SimMapButtonReady(query_id, idx, token, img_src):
     )


-def SimMapButtonPoll(query_id, idx, token):
+def SimMapButtonPoll(query_id, idx, token, token_idx):
     return Button(
         Lucide(icon="loader-circle", size="15", cls="animate-spin"),
         size="sm",
         disabled=True,
-        hx_get=f"/get_sim_map?query_id={query_id}&idx={idx}&token={token}",
-        # Poll every x seconds, where x is 0.3 x idx, formatted to 2 decimals
-        hx_trigger=f"every {(idx+1)*0.3:.2f}s",
+        hx_get=f"/get_sim_map?query_id={query_id}&idx={idx}&token={token}&token_idx={token_idx}",
+        hx_trigger="every 0.5s",
         hx_swap="outerHTML",
         cls="pointer-events-auto text-xs h-5 rounded-none px-2",
     )
@@ -352,7 +351,6 @@ def SearchResult(results: list, query_id: Optional[str] = None):
         fields = result["fields"]  # Extract the 'fields' part of each result
         blur_image_base64 = f"data:image/jpeg;base64,{fields['blur_image']}"

-        # Filter sim_map fields that are words with 4 or more characters
         sim_map_fields = {
             key: value
             for key, value in fields.items()
@@ -370,14 +368,17 @@ def SearchResult(results: list, query_id: Optional[str] = None):
                     SimMapButtonReady(
                         query_id=query_id,
                         idx=idx,
-                        token=key.split("_")[-1],
+                        token=key.split("_")[-2],
                         img_src=sim_map_base64,
                     )
                 )
             else:
                 sim_map_buttons.append(
                     SimMapButtonPoll(
-                        query_id=query_id, idx=idx, token=key.split("_")[-1]
+                        query_id=query_id,
+                        idx=idx,
+                        token=key.split("_")[-2],
+                        token_idx=int(key.split("_")[-1]),
                     )
                 )
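Note: the backend and frontend now share an implicit field-key convention — results_to_search_results() pre-populates keys of the form sim_map_{token}_{token_idx}, and SearchResult() parses the token and token index back out of the key to build the polling buttons. A tiny sketch of the assumed convention (illustrative values; it presumes the tokens themselves contain no underscores):

    token, token_idx = "energy", 3               # illustrative values
    key = f"sim_map_{token}_{token_idx}"          # key written by results_to_search_results()

    assert key.split("_")[-2] == token            # how SearchResult() recovers the token
    assert int(key.split("_")[-1]) == token_idx   # ...and the token index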
diff --git a/visual-retrieval-colpali/main.py b/visual-retrieval-colpali/main.py
index ec87bddd9..27720f913 100644
--- a/visual-retrieval-colpali/main.py
+++ b/visual-retrieval-colpali/main.py
@@ -1,25 +1,33 @@
 import asyncio
-import base64
-import io
 import os
 import time
-from concurrent.futures import ThreadPoolExecutor
-from functools import partial
 from pathlib import Path
+from concurrent.futures import ThreadPoolExecutor
 import uuid
-
 import google.generativeai as genai
-from fasthtml.common import *
-from PIL import Image
-from shad4fast import *
+from fasthtml.common import (
+    Div,
+    Img,
+    Main,
+    P,
+    Script,
+    Link,
+    fast_app,
+    HighlightJS,
+    FileResponse,
+    RedirectResponse,
+    Aside,
+    StreamingResponse,
+    JSONResponse,
+    serve,
+)
+from shad4fast import ShadHead
 from vespa.application import Vespa
+import base64
+from fastcore.parallel import threaded
+from PIL import Image

-from backend.cache import LRUCache
-from backend.colpali import (
-    add_sim_maps_to_result,
-    get_query_embeddings_and_token_map,
-    is_special_token,
-)
+from backend.colpali import get_query_embeddings_and_token_map, gen_similarity_maps
 from backend.modelmanager import ModelManager
 from backend.vespa_app import VespaQueryClient
 from frontend.app import (
@@ -76,10 +84,6 @@
     ),
 )
 vespa_app: Vespa = VespaQueryClient()
-result_cache = LRUCache(max_size=20)  # Each result can be ~10MB
-task_cache = LRUCache(
-    max_size=1000
-)  # Map from query_id to boolean value - False if not all results are ready.
 thread_pool = ThreadPoolExecutor()

 # Gemini config
@@ -94,9 +98,11 @@
 gemini_model = genai.GenerativeModel(
     "gemini-1.5-flash-8b", system_instruction=GEMINI_SYSTEM_PROMPT
 )
-STATIC_DIR = Path(__file__).parent / "static"
-IMG_DIR = STATIC_DIR / "saved"
-os.makedirs(STATIC_DIR, exist_ok=True)
+STATIC_DIR = Path("static")
+IMG_DIR = STATIC_DIR / "full_images"
+SIM_MAP_DIR = STATIC_DIR / "sim_maps"
+os.makedirs(IMG_DIR, exist_ok=True)
+os.makedirs(SIM_MAP_DIR, exist_ok=True)


 @app.on_event("startup")
@@ -111,8 +117,9 @@ async def keepalive():
     return


-def generate_query_id(query):
-    return uuid.uuid4().hex
+def generate_query_id(query, ranking_value):
+    hash_input = (query + ranking_value).encode("utf-8")
+    return hash(hash_input)


 @rt("/static/{filepath:path}")
@@ -121,7 +128,9 @@ def serve_static(filepath: str):


 @rt("/")
-def get():
+def get(session):
+    if "session_id" not in session:
+        session["session_id"] = str(uuid.uuid4())
     return Layout(Main(Home()))


 @rt("/search")
@@ -156,13 +165,7 @@ def get(request):
             )
         )
     # Generate a unique query_id based on the query and ranking value
-    query_id = generate_query_id(query_value + ranking_value)
-    # See if results are already in cache
-    # if result_cache.get(query_id) is not None:
-    #     print(f"Results for query_id {query_id} already in cache")
-    #     result = result_cache.get(query_id)
-    #     search_results = get_results_children(result)
-    #     return Layout(Search(request, search_results))
+    query_id = generate_query_id(query_value, ranking_value)
    # Show the loading message if a query is provided
     return Layout(
         Main(Search(request), data_overlayscrollbars_initialize=True, cls="border-t"),
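Note: generate_query_id() is now deterministic in the query and ranking value, so /search, /fetch_results, and the polling endpoints can all reconstruct the same id and meet at the same filenames on disk. One caveat worth flagging: Python's built-in hash() is salted per process, so ids only stay stable within a single server process. A hedged alternative using a stable digest (a hypothetical drop-in, not part of the patch):

    import hashlib

    def generate_query_id(query: str, ranking_value: str) -> str:
        # Stable across processes and restarts, unlike the built-in hash()
        return hashlib.sha256((query + ranking_value).encode("utf-8")).hexdigest()[:16]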
@@ -173,26 +176,33 @@ def get(request):
     )  # Show SearchBox and Loading message initially


+@rt("/fetch_results2")
+def get(query: str, ranking: str):
+    # 1. Get the results from Vespa (without sim_maps and full_images).
+    # Call the search endpoint in Vespa synchronously.
+
+    # 2. Kick off tasks to fetch sim_maps and full_images.
+    # Sim maps - call the search endpoint async.
+    # (A) New rank_profile that does not calculate sim_maps.
+    # (B) Make Vespa endpoints take select_fields as a parameter.
+    # One sim map per image per token;
+    # the filename is query_id_result_idx_token_idx.png.
+    # Full image: based on the doc_id.
+    # Each of these tasks saves to disk.
+    # Need a cleanup task to delete old files.
+    # Polling endpoints for sim_maps and full_images check if the file exists and return it.
+    pass
+
+
 @rt("/fetch_results")
-async def get(request, query: str, nn: bool = True):
+async def get(session, request, query: str, ranking: str):
     if "hx-request" not in request.headers:
         return RedirectResponse("/search")
-    # Extract ranking option from the request
-    ranking_value = request.query_params.get("ranking")
-    print(
-        f"/fetch_results: Fetching results for query: {query}, ranking: {ranking_value}"
-    )
-    # Generate a unique query_id based on the query and ranking value
-    query_id = generate_query_id(query + ranking_value)
-    # See if results are already in cache
-    # if result_cache.get(query_id) is not None:
-    #     print(f"Results for query_id {query_id} already in cache")
-    #     result = result_cache.get(query_id)
-    #     search_results = get_results_children(result)
-    #     return SearchResult(search_results, query_id)
+    # Get the hash of the query and ranking value
+    query_id = generate_query_id(query, ranking)
+    print(f"Query id in /fetch_results: {query_id}")
     # Run the embedding and query against Vespa app
-    task_cache.set(query_id, False)
     model = app.manager.model
     processor = app.manager.processor
     q_embs, token_to_idx = get_query_embeddings_and_token_map(processor, model, query)
@@ -202,30 +212,21 @@ async def get(request, query: str, nn: bool = True):
     result = await vespa_app.get_result_from_query(
         query=query,
         q_embs=q_embs,
-        ranking=ranking_value,
+        ranking=ranking,
         token_to_idx=token_to_idx,
     )
     end = time.perf_counter()
     print(
         f"Search results fetched in {end - start:.2f} seconds, Vespa says searchtime was {result['timing']['searchtime']} seconds"
     )
-    # Add result to cache
-    result_cache.set(query_id, result)
-    # Start generating the similarity map in the background
-    asyncio.create_task(
-        generate_similarity_map(
-            model, processor, query, q_embs, token_to_idx, result, query_id
-        )
+    search_results = vespa_app.results_to_search_results(result, token_to_idx)
+    get_and_store_sim_maps(
+        query_id=query_id,
+        query=query,
+        q_embs=q_embs,
+        ranking=ranking,
+        token_to_idx=token_to_idx,
     )
-    fields_to_add = [
-        f"sim_map_{token}"
-        for token in token_to_idx.keys()
-        if not is_special_token(token)
-    ]
-    search_results = get_results_children(result)
-    for result in search_results:
-        for sim_map_key in fields_to_add:
-            result["fields"][sim_map_key] = None

     return SearchResult(search_results, query_id)


@@ -245,78 +246,84 @@ async def poll_vespa_keepalive():
         print(f"Vespa keepalive: {time.time()}")


-async def generate_similarity_map(
-    model, processor, query, q_embs, token_to_idx, result, query_id
-):
-    loop = asyncio.get_event_loop()
-    sim_map_task = partial(
-        add_sim_maps_to_result,
-        result=result,
-        model=model,
-        processor=processor,
+@threaded
+def get_and_store_sim_maps(query_id, query: str, q_embs, ranking, token_to_idx):
+    ranking_sim = ranking + "_sim"
+    vespa_sim_maps = vespa_app.get_sim_maps_from_query(
         query=query,
         q_embs=q_embs,
+        ranking=ranking_sim,
         token_to_idx=token_to_idx,
-        query_id=query_id,
-        result_cache=result_cache,
     )
-    sim_map_result = await loop.run_in_executor(thread_pool, sim_map_task)
-    result_cache.set(query_id, sim_map_result)
-    task_cache.set(query_id, True)
+    img_paths = [
+        IMG_DIR / f"{query_id}_{idx}.jpg" for idx in range(len(vespa_sim_maps))
+    ]
+    # All images should be downloaded, but best to wait 5 secs
+    max_wait = 5
+    start_time = time.time()
+    while (
+        not all([os.path.exists(img_path) for img_path in img_paths])
+        and time.time() - start_time < max_wait
+    ):
+        time.sleep(0.2)
+    if not all([os.path.exists(img_path) for img_path in img_paths]):
+        print(f"Images not ready in 5 seconds for query_id: {query_id}")
+        return False
+    sim_map_generator = gen_similarity_maps(
+        model=app.manager.model,
+        processor=app.manager.processor,
+        device=app.manager.device,
+        query=query,
+        query_embs=q_embs,
+        token_idx_map=token_to_idx,
+        images=img_paths,
+        vespa_sim_maps=vespa_sim_maps,
+    )
+    for idx, token, token_idx, blended_img_base64 in sim_map_generator:
+        with open(SIM_MAP_DIR / f"{query_id}_{idx}_{token_idx}.png", "wb") as f:
+            f.write(base64.b64decode(blended_img_base64))
+        print(
+            f"Sim map saved to disk for query_id: {query_id}, idx: {idx}, token: {token}"
+        )
+    return True


 @app.get("/get_sim_map")
-async def get_sim_map(query_id: str, idx: int, token: str):
+async def get_sim_map(query_id: str, idx: int, token: str, token_idx: int):
     """
     Endpoint that each of the sim map button polls to get the sim map image
     when it is ready. If it is not ready, returns a SimMapButtonPoll, that
     continues to poll every 1 second.
     """
-    result = result_cache.get(query_id)
-    if result is None:
-        return SimMapButtonPoll(query_id=query_id, idx=idx, token=token)
-    search_results = get_results_children(result)
-    # Check if idx exists in list of children
-    if idx >= len(search_results):
-        return SimMapButtonPoll(query_id=query_id, idx=idx, token=token)
+    sim_map_path = SIM_MAP_DIR / f"{query_id}_{idx}_{token_idx}.png"
+    if not os.path.exists(sim_map_path):
+        print(f"Sim map not ready for query_id: {query_id}, idx: {idx}, token: {token}")
+        return SimMapButtonPoll(
+            query_id=query_id, idx=idx, token=token, token_idx=token_idx
+        )
     else:
-        sim_map_key = f"sim_map_{token}"
-        sim_map_b64 = search_results[idx]["fields"].get(sim_map_key, None)
-        if sim_map_b64 is None:
-            return SimMapButtonPoll(query_id=query_id, idx=idx, token=token)
-        sim_map_img_src = f"data:image/png;base64,{sim_map_b64}"
         return SimMapButtonReady(
-            query_id=query_id, idx=idx, token=token, img_src=sim_map_img_src
+            query_id=query_id, idx=idx, token=token, img_src=sim_map_path
         )


-async def update_full_image_cache(docid: str, query_id: str, idx: int, image_data: str):
-    result = None
-    max_wait = 20  # seconds. If horribly slow network latency.
-    start_time = time.time()
-    while result is None and time.time() - start_time < max_wait:
-        result = result_cache.get(query_id)
-        if result is None:
-            await asyncio.sleep(0.1)
-    try:
-        result["root"]["children"][idx]["fields"]["full_image"] = image_data
-    except KeyError as err:
-        print(f"Error updating full image cache: {err}")
-    result_cache.set(query_id, result)
-    print(f"Full image cache updated for query_id {query_id}")
-    return
-
-
 @app.get("/full_image")
 async def full_image(docid: str, query_id: str, idx: int):
     """
     Endpoint to get the full quality image for a given result id.
     """
-    image_data = await vespa_app.get_full_image_from_vespa(docid)
-    # Update the cache with the full image data
-    asyncio.create_task(update_full_image_cache(docid, query_id, idx, image_data))
+    img_path = IMG_DIR / f"{query_id}_{idx}.jpg"
+    if not os.path.exists(img_path):
+        image_data = await vespa_app.get_full_image_from_vespa(docid)
+        # image data is base 64 encoded string. Save it to disk as jpg.
+        with open(img_path, "wb") as f:
+            f.write(base64.b64decode(image_data))
+        print(f"Full image saved to disk for query_id: {query_id}, idx: {idx}")
+    else:
+        with open(img_path, "rb") as f:
+            image_data = base64.b64encode(f.read()).decode("utf-8")
     return Img(
-        src=f"data:image/png;base64,{image_data}",
+        src=f"data:image/jpeg;base64,{image_data}",
         alt="something",
         cls="result-image w-full h-full object-contain",
    )
@@ -336,28 +343,25 @@ async def get_suggestions(request):

 async def message_generator(query_id: str, query: str):
     images = []
-    result = None
-    all_images_ready = False
+    num_images = 3  # Number of images before firing chat request
     max_wait = 10  # seconds
     start_time = time.time()
-    while not all_images_ready and time.time() - start_time < max_wait:
-        result = result_cache.get(query_id)
-        if result is None:
-            await asyncio.sleep(0.1)
-            continue
-        search_results = get_results_children(result)
-        for single_result in search_results:
-            img = single_result["fields"].get("full_image", None)
-            if img is not None:
-                images.append(img)
-                if len(images) == len(search_results):
-                    all_images_ready = True
-                    break
+    # Check if full images are ready on disk
+    while len(images) < num_images and time.time() - start_time < max_wait:
+        for idx in range(num_images):
+            if not os.path.exists(IMG_DIR / f"{query_id}_{idx}.jpg"):
+                print(
+                    f"Message generator: Full image not ready for query_id: {query_id}, idx: {idx}"
+                )
+                continue
             else:
-                await asyncio.sleep(0.1)
-
-    # from b64 to PIL image
-    images = [Image.open(io.BytesIO(base64.b64decode(img))) for img in images]
+                print(
+                    f"Message generator: image ready for query_id: {query_id}, idx: {idx}"
+                )
+                images.append(Image.open(IMG_DIR / f"{query_id}_{idx}.jpg"))
+        await asyncio.sleep(0.2)
+    # yield message with number of images ready
+    yield f"event: message\ndata: Generating response based on {len(images)} images.\n\n"
     if not images:
         yield "event: message\ndata: I am sorry, I do not have enough information in the image to answer your question.\n\n"
         yield "event: close\ndata: \n\n"