Skip to content

Commit

Permalink
PR: retrieve all sources (#100)
Browse files Browse the repository at this point in the history
* retrieve all sources

* warning no source
  • Loading branch information
hbertrand authored May 19, 2023
1 parent 78d65d2 commit b436fee
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion buster/retriever/service.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

import pandas as pd
import pinecone
from bson.objectid import ObjectId
Expand All @@ -6,6 +8,9 @@

from buster.retriever.base import ALL_SOURCES, Retriever

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class ServiceRetriever(Retriever):
def __init__(
Expand Down Expand Up @@ -44,8 +49,17 @@ def get_source_display_name(self, source: str) -> str:
return display_name

def retrieve(self, query_embedding: list[float], top_k: int, source: str = None) -> pd.DataFrame:
if source is "" or source is None:
filter = None
else:
filter = {"source": {"$eq": source}}
source_exists = self.db.sources.find_one({"name": source})
if source_exists is None:
logger.warning(f"Source {source} does not exist. Returning empty dataframe.")
return pd.DataFrame()

# Pinecone retrieval
matches = self.index.query(query_embedding, top_k=top_k, filter={"source": {"$eq": source}})["matches"]
matches = self.index.query(query_embedding, top_k=top_k, filter=filter)["matches"]
matching_ids = [ObjectId(match.id) for match in matches]
matching_scores = {match.id: match.score for match in matches}

Expand Down

0 comments on commit b436fee

Please sign in to comment.