Skip to content

Commit

Permalink
Merge pull request #1545 from vespa-engine/thomasht86/fix-sim-maps
Browse files Browse the repository at this point in the history
(colpalidemo) fix sim maps
  • Loading branch information
andreer authored Oct 28, 2024
2 parents 4589eb4 + 596da92 commit 9570d72
Show file tree
Hide file tree
Showing 6 changed files with 233 additions and 220 deletions.
4 changes: 3 additions & 1 deletion visual-retrieval-colpali/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,6 @@ template/
*.json
output/
pdfs/
static/saved/
static/full_images/
static/sim_maps/
embeddings/
60 changes: 4 additions & 56 deletions visual-retrieval-colpali/backend/colpali.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pathlib import Path
import base64
from io import BytesIO
from typing import Union, Tuple, List, Dict, Any
from typing import Union, Tuple, List
import matplotlib
import matplotlib.cm as cm
import re
Expand Down Expand Up @@ -49,7 +49,7 @@ def load_model() -> Tuple[ColPali, ColPaliProcessor]:

# Load the processor
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
return model, processor
return model, processor, device


def load_vit_config(model):
Expand All @@ -63,7 +63,6 @@ def gen_similarity_maps(
model: ColPali,
processor: ColPaliProcessor,
device,
vit_config,
query: str,
query_embs: torch.Tensor,
token_idx_map: dict,
Expand All @@ -88,7 +87,7 @@ def gen_similarity_maps(
Tuple[int, str, str]: A tuple containing the image index, the selected token, and the base64-encoded image.
"""

vit_config = load_vit_config(model)
# Process images and store original images and sizes
processed_images = []
original_images = []
Expand Down Expand Up @@ -254,7 +253,7 @@ def gen_similarity_maps(

# Store the base64-encoded image
result_per_image[token] = blended_img_base64
yield idx, token, blended_img_base64
yield idx, token, token_idx, blended_img_base64
end3 = time.perf_counter()
print(f"Blending images took: {end3 - start3} s")

Expand Down Expand Up @@ -287,54 +286,3 @@ def is_special_token(token: str) -> bool:
if (len(token) < 3) or pattern.match(token):
return True
return False


def add_sim_maps_to_result(
result: Dict[str, Any],
model: ColPali,
processor: ColPaliProcessor,
query: str,
q_embs: Any,
token_to_idx: Dict[str, int],
query_id: str,
result_cache,
) -> Dict[str, Any]:
# Purpose: generate similarity-map images for the hits in `result` and write
# them back into each hit's "fields" under the key f"sim_map_{token}".
# Returns the same `result` dict that was passed in (mutated in place).
#
# NOTE(review): leading indentation is not preserved in this view, so the
# nesting of the statements below (which ones sit inside a loop or an `if`)
# cannot be confirmed from here -- verify against the repository revision.
vit_config = load_vit_config(model)
# Inputs handed to gen_similarity_maps, collected from the hits below.
imgs: List[str] = []
vespa_sim_maps: List[str] = []
for single_result in result["root"]["children"]:
# "blur_image" is read with [] and is therefore required on every hit;
# only truthy values are collected.
img = single_result["fields"]["blur_image"]
if img:
imgs.append(img)
# "summaryfeatures" is optional (.get) and only truthy values are collected.
# NOTE(review): the two lists are filled by separate conditions, so they are
# presumably not guaranteed to stay index-aligned -- confirm with callers.
vespa_sim_map = single_result["fields"].get("summaryfeatures", None)
if vespa_sim_map:
vespa_sim_maps.append(vespa_sim_map)
# Nothing to render: hand the result back untouched.
if not imgs:
return result
sim_map_imgs_generator = gen_similarity_maps(
model=model,
processor=processor,
# Falls back to "cpu" when the model object exposes no `device` attribute.
device=model.device if hasattr(model, "device") else "cpu",
vit_config=vit_config,
query=query,
query_embs=q_embs,
token_idx_map=token_to_idx,
images=imgs,
vespa_sim_maps=vespa_sim_maps,
)
# Each item is unpacked as (image index, token, base64 image string).
for img_idx, token, sim_mapb64 in sim_map_imgs_generator:
print(f"Created sim map for image {img_idx} and token {token}")
# Only write when the hit exists, has "fields", and "fields" already holds a
# key named exactly "sim_map".
if (
len(result["root"]["children"]) > img_idx
and "fields" in result["root"]["children"][img_idx]
and "sim_map" in result["root"]["children"][img_idx]["fields"]
):
result["root"]["children"][img_idx]["fields"][f"sim_map_{token}"] = (
sim_mapb64
)
# Update result_cache with the new sim_map
# NOTE(review): presumably executed per generated map so the cache is
# refreshed progressively -- nesting not verifiable here, confirm.
result_cache.set(query_id, result)
# for single_result, sim_map_dict in zip(result["root"]["children"], sim_map_imgs):
# for token, sim_mapb64 in sim_map_dict.items():
# single_result["fields"][f"sim_map_{token}"] = sim_mapb64
return result
2 changes: 1 addition & 1 deletion visual-retrieval-colpali/backend/modelmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def get_instance():

def initialize_model_and_processor(self):
if self.model is None or self.processor is None: # Ensure no reinitialization
self.model, self.processor = load_model()
self.model, self.processor, self.device = load_model()
if self.model is None or self.processor is None:
print("Failed to initialize model or processor at startup")
else:
Expand Down
106 changes: 82 additions & 24 deletions visual-retrieval-colpali/backend/vespa_app.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
import os
import time
from typing import Any, Dict, Tuple

import asyncio
import numpy as np
import torch
from dotenv import load_dotenv
from vespa.application import Vespa
from vespa.io import VespaQueryResponse
from .colpali import is_special_token


class VespaQueryClient:
MAX_QUERY_TERMS = 64
VESPA_SCHEMA_NAME = "pdf_page"
SELECT_FIELDS = "id,title,url,blur_image,page_number,snippet,text,summaryfeatures"
SELECT_FIELDS = "id,title,url,blur_image,page_number,snippet,text"

def __init__(self):
"""
Expand Down Expand Up @@ -73,6 +74,12 @@ def __init__(self):
self.app.wait_for_application_up()
print(f"Connected to Vespa at {self.vespa_app_url}")

def get_fields(self, sim_map: bool = False):
    """
    Choose the field list placed in the YQL ``select`` clause.

    Args:
        sim_map (bool): When truthy, only ``summaryfeatures`` is selected;
            otherwise the class-level ``SELECT_FIELDS`` string is used.

    Returns:
        str: The comma-separated field list for the query.
    """
    return "summaryfeatures" if sim_map else self.SELECT_FIELDS

def format_query_results(
self, query: str, response: VespaQueryResponse, hits: int = 5
) -> dict:
Expand Down Expand Up @@ -100,6 +107,7 @@ async def query_vespa_default(
q_emb: torch.Tensor,
hits: int = 3,
timeout: str = "10s",
sim_map: bool = False,
**kwargs,
) -> dict:
"""
Expand All @@ -121,9 +129,9 @@ async def query_vespa_default(
response: VespaQueryResponse = await session.query(
body={
"yql": (
f"select {self.SELECT_FIELDS} from {self.VESPA_SCHEMA_NAME} where userQuery();"
f"select {self.get_fields(sim_map=sim_map)} from {self.VESPA_SCHEMA_NAME} where userQuery();"
),
"ranking": "default",
"ranking": self.get_rank_profile("default", sim_map),
"query": query,
"timeout": timeout,
"hits": hits,
Expand All @@ -146,6 +154,7 @@ async def query_vespa_bm25(
q_emb: torch.Tensor,
hits: int = 3,
timeout: str = "10s",
sim_map: bool = False,
**kwargs,
) -> dict:
"""
Expand All @@ -167,9 +176,9 @@ async def query_vespa_bm25(
response: VespaQueryResponse = await session.query(
body={
"yql": (
f"select {self.SELECT_FIELDS} from {self.VESPA_SCHEMA_NAME} where userQuery();"
f"select {self.get_fields(sim_map=sim_map)} from {self.VESPA_SCHEMA_NAME} where userQuery();"
),
"ranking": "bm25",
"ranking": self.get_rank_profile("bm25", sim_map),
"query": query,
"timeout": timeout,
"hits": hits,
Expand Down Expand Up @@ -266,30 +275,54 @@ async def get_result_from_query(
Returns:
Dict[str, Any]: The query results.
"""
print(query)
print(token_to_idx)

if ranking == "nn+colpali":
result = await self.query_vespa_nearest_neighbor(query, q_embs)
elif ranking == "bm25+colpali":
result = await self.query_vespa_default(query, q_embs)
elif ranking == "bm25":
result = await self.query_vespa_bm25(query, q_embs)
rank_method = ranking.split("_")[0]
sim_map: bool = len(ranking.split("_")) > 1 and ranking.split("_")[1] == "sim"
if rank_method == "nn+colpali":
result = await self.query_vespa_nearest_neighbor(
query, q_embs, sim_map=sim_map
)
elif rank_method == "bm25+colpali":
result = await self.query_vespa_default(query, q_embs, sim_map=sim_map)
elif rank_method == "bm25":
result = await self.query_vespa_bm25(query, q_embs, sim_map=sim_map)
else:
raise ValueError(f"Unsupported ranking: {ranking}")

raise ValueError(f"Unsupported ranking: {rank_method}")
# Print score, title id, and text of the results
if "root" not in result or "children" not in result["root"]:
result["root"] = {"children": []}
return result
for idx, child in enumerate(result["root"]["children"]):
print(
f"Result {idx+1}: {child['relevance']}, {child['fields']['title']}, {child['fields']['id']}"
)
for single_result in result["root"]["children"]:
print(single_result["fields"].keys())
return result

def get_sim_maps_from_query(
    self, query: str, q_embs: torch.Tensor, ranking: str, token_to_idx: dict
) -> list:
    """
    Fetch the ``summaryfeatures`` of every hit returned by Vespa for a query.

    This is a synchronous wrapper: it drives the ``get_result_from_query``
    coroutine to completion with ``asyncio.run``, so it has to be called from
    a thread that does not already have a running event loop.

    Args:
        query (str): The query text.
        q_embs (torch.Tensor): Query embeddings.
        ranking (str): The ranking method, forwarded unchanged to
            ``get_result_from_query``.
        token_to_idx (dict): Token to index mapping, forwarded unchanged.

    Returns:
        list: The ``summaryfeatures`` value of each hit, in hit order.

    Raises:
        ValueError: If any hit carries no ``summaryfeatures`` field.
    """
    result = asyncio.run(
        self.get_result_from_query(query, q_embs, ranking, token_to_idx)
    )
    vespa_sim_maps = []
    for single_result in result["root"]["children"]:
        vespa_sim_map = single_result["fields"].get("summaryfeatures")
        # A single hit without summaryfeatures fails the whole call; no partial
        # list is returned.
        if vespa_sim_map is None:
            raise ValueError("No sim_map found in Vespa response")
        vespa_sim_maps.append(vespa_sim_map)
    return vespa_sim_maps

async def get_full_image_from_vespa(self, doc_id: str) -> str:
"""
Retrieve the full image from Vespa for a given document ID.
Expand Down Expand Up @@ -317,6 +350,23 @@ async def get_full_image_from_vespa(self, doc_id: str) -> str:
)
return response.json["root"]["children"][0]["fields"]["full_image"]

def get_results_children(self, result: VespaQueryResponse) -> list:
    """Return the hit list stored under ``result["root"]["children"]``."""
    root = result["root"]
    return root["children"]

def results_to_search_results(
    self, result: VespaQueryResponse, token_to_idx: dict
) -> list:
    """
    Pre-populate every hit with empty similarity-map placeholders.

    For each token in ``token_to_idx`` that ``is_special_token`` does not
    reject, a key ``sim_map_<token>_<position>`` is set to ``None`` in the
    hit's ``fields``. ``<position>`` is the token's enumeration position
    among the keys of ``token_to_idx``, not the value mapped to it.

    Args:
        result (VespaQueryResponse): Query result; mutated in place.
        token_to_idx (dict): Token to index mapping; only its keys are used.

    Returns:
        list: The hits under ``result["root"]["children"]``.
    """
    placeholder_keys = []
    for position, token in enumerate(token_to_idx):
        if is_special_token(token):
            continue
        placeholder_keys.append(f"sim_map_{token}_{position}")

    for hit in result["root"]["children"]:
        for placeholder in placeholder_keys:
            hit["fields"][placeholder] = None

    return self.get_results_children(result)

async def get_suggestions(self, query: str) -> list:
async with self.app.asyncio(connections=1) as session:
start = time.perf_counter()
Expand Down Expand Up @@ -348,13 +398,20 @@ async def get_suggestions(self, query: str) -> list:
flat_questions = [item for sublist in questions for item in sublist]
return flat_questions

def get_rank_profile(self, ranking: str, sim_map: bool) -> str:
    """
    Resolve the rank-profile name to send to Vespa.

    Args:
        ranking (str): Base rank-profile name.
        sim_map (bool): Whether the similarity-map variant is wanted.

    Returns:
        str: ``ranking`` itself, or ``ranking`` with ``_sim`` appended when
        ``sim_map`` is truthy.
    """
    if not sim_map:
        return ranking
    return f"{ranking}_sim"

async def query_vespa_nearest_neighbor(
self,
query: str,
q_emb: torch.Tensor,
target_hits_per_query_tensor: int = 20,
hits: int = 3,
timeout: str = "10s",
sim_map: bool = False,
**kwargs,
) -> dict:
"""
Expand Down Expand Up @@ -385,15 +442,16 @@ async def query_vespa_nearest_neighbor(
binary_query_embeddings, target_hits_per_query_tensor
)
query_tensors.update(nn_query_dict)

response: VespaQueryResponse = await session.query(
body={
**query_tensors,
"presentation.timing": True,
"yql": (
f"select {self.SELECT_FIELDS} from {self.VESPA_SCHEMA_NAME} where {nn_string} or userQuery()"
f"select {self.get_fields(sim_map=sim_map)} from {self.VESPA_SCHEMA_NAME} where {nn_string} or userQuery()"
),
"ranking.profile": self.get_rank_profile(
"retrieval-and-rerank", sim_map
),
"ranking.profile": "retrieval-and-rerank",
"timeout": timeout,
"hits": hits,
"query": query,
Expand Down
15 changes: 8 additions & 7 deletions visual-retrieval-colpali/frontend/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,14 +323,13 @@ def SimMapButtonReady(query_id, idx, token, img_src):
)


def SimMapButtonPoll(query_id, idx, token):
def SimMapButtonPoll(query_id, idx, token, token_idx):
return Button(
Lucide(icon="loader-circle", size="15", cls="animate-spin"),
size="sm",
disabled=True,
hx_get=f"/get_sim_map?query_id={query_id}&idx={idx}&token={token}",
# Poll every x seconds, where x is 0.3 x idx, formatted to 2 decimals
hx_trigger=f"every {(idx+1)*0.3:.2f}s",
hx_get=f"/get_sim_map?query_id={query_id}&idx={idx}&token={token}&token_idx={token_idx}",
hx_trigger="every 0.5s",
hx_swap="outerHTML",
cls="pointer-events-auto text-xs h-5 rounded-none px-2",
)
Expand All @@ -352,7 +351,6 @@ def SearchResult(results: list, query_id: Optional[str] = None):
fields = result["fields"] # Extract the 'fields' part of each result
blur_image_base64 = f"data:image/jpeg;base64,{fields['blur_image']}"

# Filter sim_map fields that are words with 4 or more characters
sim_map_fields = {
key: value
for key, value in fields.items()
Expand All @@ -370,14 +368,17 @@ def SearchResult(results: list, query_id: Optional[str] = None):
SimMapButtonReady(
query_id=query_id,
idx=idx,
token=key.split("_")[-1],
token=key.split("_")[-2],
img_src=sim_map_base64,
)
)
else:
sim_map_buttons.append(
SimMapButtonPoll(
query_id=query_id, idx=idx, token=key.split("_")[-1]
query_id=query_id,
idx=idx,
token=key.split("_")[-2],
token_idx=int(key.split("_")[-1]),
)
)

Expand Down
Loading

0 comments on commit 9570d72

Please sign in to comment.