Skip to content

Commit

Permalink
return all documents when no source specified (#104)
Browse files Browse the repository at this point in the history
* return all documents when no source specified
  • Loading branch information
jerpint authored Jun 1, 2023
1 parent 59ba24f commit a9c1cb4
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 11 deletions.
2 changes: 1 addition & 1 deletion buster/busterbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(self, retriever: Retriever, completer: Completer, validator: Valida
self.retriever = retriever
self.validator = validator

def process_input(self, user_input: str, source: str = "") -> Completion:
def process_input(self, user_input: str, source: str = None) -> Completion:
"""
Main function to process the input question and generate a formatted output.
"""
Expand Down
2 changes: 1 addition & 1 deletion buster/retriever/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self, top_k, thresh, max_tokens, embedding_model, *args, **kwargs):
self.embedding_model = embedding_model

@abstractmethod
def get_documents(self, source: str) -> pd.DataFrame:
def get_documents(self, source: str = None) -> pd.DataFrame:
"""Get all current documents from a given source."""
...

Expand Down
29 changes: 22 additions & 7 deletions buster/retriever/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,28 @@ def __init__(
self.db = self.client[mongo_db_name]

def get_source_id(self, source: str) -> str:
"""Get the id of a source."""
return str(self.db.sources.find_one({"name": source})["_id"])
"""Get the id of a source. Returns empty string if the source does not exist."""
source_pointer = self.db.sources.find_one({"name": source})
return "" if source_pointer is None else str(source_pointer["_id"])

def get_documents(self, source: str) -> pd.DataFrame:
"""Get all current documents from a given source."""
source_id = self.get_source_id(source)
return self.db.documents.find({"source_id": source_id})
def get_documents(self, source: str = None) -> pd.DataFrame:
"""Get all current documents from a given source.
If source is None, returns all documents. If source does not exist, returns empty dataframe."""

if source is None:
# No source specified, return all documents
documents = self.db.documents.find()
else:
assert isinstance(source, str), "source must be a valid string."
source_id = self.get_source_id(source)

if source_id == "":
logger.warning(f"{source=} not found.")

documents = self.db.documents.find({"source_id": source_id})

return pd.DataFrame(list(documents))

def get_source_display_name(self, source: str) -> str:
"""Get the display name of a source."""
Expand All @@ -52,7 +67,7 @@ def retrieve(self, query: str, top_k: int = None, source: str = None) -> pd.Data
if top_k is None:
# use default top_k value
top_k = self.top_k
if source is "" or source is None:
if source == "" or source is None:
filter = None
else:
filter = {"source": {"$eq": source}}
Expand Down
4 changes: 2 additions & 2 deletions buster/retriever/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ def __del__(self):
if self.db_path is not None:
self.conn.close()

def get_documents(self, source: str) -> pd.DataFrame:
def get_documents(self, source: str = None) -> pd.DataFrame:
"""Get all current documents from a given source."""
# Execute the SQL statement and fetch the results.
if source == "":
if source is None:
results = self.conn.execute("SELECT * FROM documents")
else:
results = self.conn.execute("SELECT * FROM documents WHERE source = ?", (source,))
Expand Down

0 comments on commit a9c1cb4

Please sign in to comment.