diff --git a/buster/busterbot.py b/buster/busterbot.py index 90a80cd..51d8d2a 100644 --- a/buster/busterbot.py +++ b/buster/busterbot.py @@ -70,7 +70,7 @@ def __init__(self, retriever: Retriever, completer: Completer, validator: Valida self.retriever = retriever self.validator = validator - def process_input(self, user_input: str, source: str = "") -> Completion: + def process_input(self, user_input: str, source: str = None) -> Completion: """ Main function to process the input question and generate a formatted output. """ diff --git a/buster/retriever/base.py b/buster/retriever/base.py index e2d89a9..7a8f03d 100644 --- a/buster/retriever/base.py +++ b/buster/retriever/base.py @@ -21,7 +21,7 @@ def __init__(self, top_k, thresh, max_tokens, embedding_model, *args, **kwargs): self.embedding_model = embedding_model @abstractmethod - def get_documents(self, source: str) -> pd.DataFrame: + def get_documents(self, source: str = None) -> pd.DataFrame: """Get all current documents from a given source.""" ... diff --git a/buster/retriever/service.py b/buster/retriever/service.py index 6cdc51f..4a0d83b 100644 --- a/buster/retriever/service.py +++ b/buster/retriever/service.py @@ -32,13 +32,28 @@ def __init__( self.db = self.client[mongo_db_name] def get_source_id(self, source: str) -> str: - """Get the id of a source.""" - return str(self.db.sources.find_one({"name": source})["_id"]) + """Get the id of a source. Returns empty string if the source does not exist.""" + source_pointer = self.db.sources.find_one({"name": source}) + return "" if source_pointer is None else str(source_pointer["_id"]) - def get_documents(self, source: str) -> pd.DataFrame: - """Get all current documents from a given source.""" - source_id = self.get_source_id(source) - return self.db.documents.find({"source_id": source_id}) + def get_documents(self, source: str = None) -> pd.DataFrame: + """Get all current documents from a given source. + + If source is None, returns all documents. If source does not exist, returns empty dataframe.""" + + if source is None: + # No source specified, return all documents + documents = self.db.documents.find() + else: + assert isinstance(source, str), "source must be a valid string." + source_id = self.get_source_id(source) + + if source_id == "": + logger.warning(f"{source=} not found.") + + documents = self.db.documents.find({"source_id": source_id}) + + return pd.DataFrame(list(documents)) def get_source_display_name(self, source: str) -> str: """Get the display name of a source.""" @@ -52,7 +67,7 @@ def retrieve(self, query: str, top_k: int = None, source: str = None) -> pd.Data if top_k is None: # use default top_k value top_k = self.top_k - if source is "" or source is None: + if source == "" or source is None: filter = None else: filter = {"source": {"$eq": source}} diff --git a/buster/retriever/sqlite.py b/buster/retriever/sqlite.py index fbcec3c..c2f3996 100644 --- a/buster/retriever/sqlite.py +++ b/buster/retriever/sqlite.py @@ -44,10 +44,10 @@ def __del__(self): if self.db_path is not None: self.conn.close() - def get_documents(self, source: str) -> pd.DataFrame: + def get_documents(self, source: str = None) -> pd.DataFrame: """Get all current documents from a given source.""" # Execute the SQL statement and fetch the results. - if source == "": + if source is None: results = self.conn.execute("SELECT * FROM documents") else: results = self.conn.execute("SELECT * FROM documents WHERE source = ?", (source,))