diff --git a/health_rec/services/retriever.py b/health_rec/services/retriever.py index 771caac..9262814 100644 --- a/health_rec/services/retriever.py +++ b/health_rec/services/retriever.py @@ -1,5 +1,6 @@ +"""Service for retrieving documents using ChromaDB vector similarity search.""" + import logging -from dataclasses import dataclass from typing import List import chromadb @@ -26,8 +27,7 @@ def __init__(self) -> None: self.client = openai.Client(api_key=Config.OPENAI_API_KEY) self.embedding_model = Config.OPENAI_EMBEDDING self.chroma_client = chromadb.HttpClient( - host=Config.CHROMA_HOST, - port=Config.CHROMA_PORT + host=Config.CHROMA_HOST, port=Config.CHROMA_PORT ) self.collection: Collection = self.chroma_client.get_collection( name=Config.COLLECTION_NAME @@ -54,10 +54,7 @@ def _generate_embedding(self, text: str) -> List[float]: """ try: return ( - self.client.embeddings.create( - input=[text], - model=self.embedding_model - ) + self.client.embeddings.create(input=[text], model=self.embedding_model) .data[0] .embedding ) @@ -73,6 +70,8 @@ def retrieve(self, query: str, n_results: int = 5) -> List[ServiceDocument]: ---------- query : str The search query. + n_results : int, optional + The number of results to retrieve, by default 5. Returns ------- @@ -90,13 +89,12 @@ def retrieve(self, query: str, n_results: int = 5) -> List[ServiceDocument]: # Retrieve documents from ChromaDB results: QueryResult = self.collection.query( - query_embeddings=query_embedding, - n_results=n_results + query_embeddings=query_embedding, n_results=n_results ) # Parse and return results - return _parse_chroma_result(results) + return _parse_chroma_result(results) # type: ignore except Exception as e: logger.error(f"Error in Retriever.retrieve: {e}") - raise \ No newline at end of file + raise