Skip to content

Commit

Permalink
fix: Converting RAG Service to using Async
Browse files Browse the repository at this point in the history
  • Loading branch information
ishaansehgal99 committed Jan 24, 2025
1 parent c111179 commit 4aa23db
Show file tree
Hide file tree
Showing 11 changed files with 149 additions and 54 deletions.
26 changes: 20 additions & 6 deletions presets/ragengine/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,8 @@
rag_ops = VectorStoreManager(vector_store_handler)

@app.get("/health", response_model=HealthStatus)
async def health_check():
def health_check():
try:

if embedding_manager is None:
raise HTTPException(status_code=500, detail="Embedding manager not initialized")

Expand Down Expand Up @@ -72,14 +71,29 @@ async def query_index(request: QueryRequest):
status_code=500, detail=f"An unexpected error occurred: {str(e)}"
)

@app.get("/indexed-documents", response_model=ListDocumentsResponse)
async def list_all_indexed_documents():
@app.get("/indexes", response_model=List[str])
def list_indexes():
    """Return the names of all indexes held by the vector store manager.

    Raises:
        HTTPException: status 500, with the underlying exception text as the
            detail, if the lookup raises.
    """
    try:
        index_names = rag_ops.list_indexes()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return index_names

@app.get("/indexes/{index_name}/documents", response_model=ListDocumentsResponse)
async def list_documents_in_index(index_name: str):
    """Return the documents stored in the index named ``index_name``.

    Raises:
        HTTPException: status 500, with the underlying exception text as the
            detail, if fetching the documents or building the response raises.
    """
    try:
        # Response construction stays inside the try so that any error it
        # raises is reported the same way as a lookup failure.
        return ListDocumentsResponse(
            documents=await rag_ops.list_documents_in_index(index_name)
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))

@app.get("/documents", response_model=ListDocumentsResponse)
async def list_all_documents():
try:
documents = rag_ops.list_all_indexed_documents()
documents = await rag_ops.list_all_documents()
return ListDocumentsResponse(documents=documents)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=5000)
uvicorn.run(app, host="0.0.0.0", port=8000)
2 changes: 0 additions & 2 deletions presets/ragengine/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,3 @@ llama-index-vector-stores-faiss
llama-index-vector-stores-chroma
llama-index-vector-stores-azurecosmosmongo
uvicorn
# For UTs
pytest
6 changes: 3 additions & 3 deletions presets/ragengine/tests/api/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,8 @@ def test_query_index_failure():
assert response.json()["detail"] == "No such index: 'non_existent_index' exists."


def test_list_all_indexed_documents_success():
response = client.get("/indexed-documents")
def test_list_all_documents_success():
response = client.get("/documents")
assert response.status_code == 200
assert response.json() == {'documents': {}}

Expand All @@ -195,7 +195,7 @@ def test_list_all_indexed_documents_success():
response = client.post("/index", json=request_data)
assert response.status_code == 200

response = client.get("/indexed-documents")
response = client.get("/documents")
assert response.status_code == 200
assert "test_index" in response.json()["documents"]
response_idx = response.json()["documents"]["test_index"]
Expand Down
7 changes: 7 additions & 0 deletions presets/ragengine/tests/requirements-test.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Common dependencies
-r requirements.txt

# Test dependencies
pytest
pytest-asyncio

5 changes: 5 additions & 0 deletions presets/ragengine/tests/vector_store/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@

import sys
import os
import nest_asyncio

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # Force CPU-only execution for testing
os.environ["OMP_NUM_THREADS"] = "1" # Force single-threaded for testing to prevent segfault while loading embedding model
os.environ["MKL_NUM_THREADS"] = "1" # Force MKL to use a single thread

# Apply nest_asyncio to allow nested event loops
nest_asyncio.apply()
44 changes: 25 additions & 19 deletions presets/ragengine/tests/vector_store/test_base_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,22 @@ def expected_query_score(self):
"""Override this in implementation-specific test classes."""
pass

def test_index_documents(self, vector_store_manager):
@pytest.mark.asyncio
async def test_index_documents(self, vector_store_manager):
first_doc_text, second_doc_text = "First document", "Second document"
documents = [
Document(text=first_doc_text, metadata={"type": "text"}),
Document(text=second_doc_text, metadata={"type": "text"})
]

doc_ids = vector_store_manager.index_documents("test_index", documents)
doc_ids = await vector_store_manager.index_documents("test_index", documents)

assert len(doc_ids) == 2
assert set(doc_ids) == {BaseVectorStore.generate_doc_id(first_doc_text),
BaseVectorStore.generate_doc_id(second_doc_text)}

def test_index_documents_isolation(self, vector_store_manager):
@pytest.mark.asyncio
async def test_index_documents_isolation(self, vector_store_manager):
documents1 = [
Document(text="First document in index1", metadata={"type": "text"}),
]
Expand All @@ -54,19 +56,20 @@ def test_index_documents_isolation(self, vector_store_manager):

# Index documents in separate indices
index_name_1, index_name_2 = "index1", "index2"
vector_store_manager.index_documents(index_name_1, documents1)
vector_store_manager.index_documents(index_name_2, documents2)
await vector_store_manager.index_documents(index_name_1, documents1)
await vector_store_manager.index_documents(index_name_2, documents2)

# Call the backend-specific check method
self.check_indexed_documents(vector_store_manager)
await self.check_indexed_documents(vector_store_manager)

@abstractmethod
def check_indexed_documents(self, vector_store_manager):
"""Abstract method to check indexed documents in backend-specific format."""
pass

@pytest.mark.asyncio
@patch('requests.post')
def test_query_documents(self, mock_post, vector_store_manager):
async def test_query_documents(self, mock_post, vector_store_manager):
mock_response = {
"result": "This is the completion from the API"
}
Expand All @@ -76,10 +79,10 @@ def test_query_documents(self, mock_post, vector_store_manager):
Document(text="First document", metadata={"type": "text"}),
Document(text="Second document", metadata={"type": "text"})
]
vector_store_manager.index_documents("test_index", documents)
await vector_store_manager.index_documents("test_index", documents)

params = {"temperature": 0.7}
query_result = vector_store_manager.query("test_index", "First", top_k=1,
query_result = await vector_store_manager.query("test_index", "First", top_k=1,
llm_params=params, rerank_params={})

assert query_result is not None
Expand All @@ -93,28 +96,31 @@ def test_query_documents(self, mock_post, vector_store_manager):
headers={"Authorization": f"Bearer {LLM_ACCESS_SECRET}", 'Content-Type': 'application/json'}
)

def test_add_document(self, vector_store_manager):
@pytest.mark.asyncio
async def test_add_document(self, vector_store_manager):
documents = [Document(text="Third document", metadata={"type": "text"})]
vector_store_manager.index_documents("test_index", documents)
await vector_store_manager.index_documents("test_index", documents)

new_document = [Document(text="Fourth document", metadata={"type": "text"})]
vector_store_manager.index_documents("test_index", new_document)
await vector_store_manager.index_documents("test_index", new_document)

assert vector_store_manager.document_exists("test_index", new_document[0],
BaseVectorStore.generate_doc_id("Fourth document"))

def test_persist_index_1(self, vector_store_manager):
@pytest.mark.asyncio
async def test_persist_index_1(self, vector_store_manager):
documents = [Document(text="Test document", metadata={"type": "text"})]
vector_store_manager.index_documents("test_index", documents)
vector_store_manager._persist("test_index")
await vector_store_manager.index_documents("test_index", documents)
await vector_store_manager._persist("test_index")
assert os.path.exists(VECTOR_DB_PERSIST_DIR)

def test_persist_index_2(self, vector_store_manager):
@pytest.mark.asyncio
async def test_persist_index_2(self, vector_store_manager):
documents = [Document(text="Test document", metadata={"type": "text"})]
vector_store_manager.index_documents("test_index", documents)
await vector_store_manager.index_documents("test_index", documents)

documents = [Document(text="Another Test document", metadata={"type": "text"})]
vector_store_manager.index_documents("another_test_index", documents)
await vector_store_manager.index_documents("another_test_index", documents)

vector_store_manager._persist_all()
await vector_store_manager._persist_all()
assert os.path.exists(VECTOR_DB_PERSIST_DIR)
5 changes: 3 additions & 2 deletions presets/ragengine/tests/vector_store/test_chromadb_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ def vector_store_manager(self, init_embed_manager):
manager._clear_collection_and_indexes()
yield manager

def check_indexed_documents(self, vector_store_manager):
indexed_docs = vector_store_manager.list_all_indexed_documents()
@pytest.mark.asyncio
async def check_indexed_documents(self, vector_store_manager):
indexed_docs = await vector_store_manager.list_all_documents()
assert len(indexed_docs) == 2
assert list(indexed_docs["index1"].values())[0]["text"] == "First document in index1"
assert list(indexed_docs["index2"].values())[0]["text"] == "First document in index2"
Expand Down
5 changes: 3 additions & 2 deletions presets/ragengine/tests/vector_store/test_faiss_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ def vector_store_manager(self, init_embed_manager):
os.environ['PERSIST_DIR'] = temp_dir
yield FaissVectorStoreHandler(init_embed_manager)

def check_indexed_documents(self, vector_store_manager):
@pytest.mark.asyncio
async def check_indexed_documents(self, vector_store_manager):
expected_output = {
'index1': {"87117028123498eb7d757b1507aa3e840c63294f94c27cb5ec83c939dedb32fd": {
'hash': '1e64a170be48c45efeaa8667ab35919106da0489ec99a11d0029f2842db133aa',
Expand All @@ -29,7 +30,7 @@ def check_indexed_documents(self, vector_store_manager):
'text': 'First document in index2'
}}
}
assert vector_store_manager.list_all_indexed_documents() == expected_output
assert await vector_store_manager.list_all_documents() == expected_output

@property
def expected_query_score(self):
Expand Down
72 changes: 57 additions & 15 deletions presets/ragengine/vector_store/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from ragengine.inference.inference import Inference
from ragengine.config import (LLM_RERANKER_BATCH_SIZE, LLM_RERANKER_TOP_N, VECTOR_DB_PERSIST_DIR)

from llama_index.core.storage.docstore import SimpleDocumentStore

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -48,7 +50,8 @@ async def _append_documents_to_index(self, index_name: str, documents: List[Docu

for doc in documents:
doc_id = self.generate_doc_id(doc.text)
if not self.document_exists(index_name, doc, doc_id):
doc = await self.index_map[index_name].docstore.aget_document(doc_id)
if not doc:
await self.add_document_to_index(index_name, doc, doc_id)
indexed_doc_ids.add(doc_id)
else:
Expand Down Expand Up @@ -82,7 +85,7 @@ async def _create_index_common(self, index_name: str, documents: List[Document],
embed_model=self.embed_model,
use_async=True,
)
await index.set_index_id(index_name)
index.set_index_id(index_name)
self.index_map[index_name] = index
self.index_store.add_index_struct(index.index_struct)
await self._persist(index_name)
Expand Down Expand Up @@ -159,17 +162,56 @@ async def add_document_to_index(self, index_name: str, document: Document, doc_i
llama_doc = LlamaDocument(text=document.text, metadata=document.metadata, id_=doc_id)
await self.index_map[index_name].insert(llama_doc)

def list_all_indexed_documents(self) -> Dict[str, Dict[str, Dict[str, str]]]:
"""Common logic for listing all documents."""
return {
index_name: {
doc_info.ref_doc_id: {
"text": doc_info.text,
def list_indexes(self) -> List[str]:
    """Return the names of every index currently present in ``index_map``."""
    # Iterating a dict yields its keys, in insertion order.
    return [index_name for index_name in self.index_map]

async def list_documents_in_index(self, index_name: str) -> Dict[str, Dict[str, str]]:
"""Return a dictionary of document metadata for the given index."""
vector_store_index = self.index_map[index_name]
doc_store = vector_store_index.docstore

is_simple_doc_store = isinstance(doc_store, SimpleDocumentStore)
doc_map: Dict[str, Dict[str, str]] = {}

for doc_id, doc_stub in doc_store.docs.items():
if is_simple_doc_store:
# Here 'doc_stub' should already be the full doc info
doc_map[doc_stub.ref_doc_id] = {
"text": doc_stub.text,
"hash": doc_stub.hash
}
else:
# Use async retrieval for non-simple doc_store
doc_info = await doc_store.aget_document(doc_id)
doc_map[doc_info.ref_doc_id] = {
"text": doc_info.text,
"hash": doc_info.hash
} for _, doc_info in vector_store_index.docstore.docs.items()
}
for index_name, vector_store_index in self.index_map.items()
}
}
return doc_map

async def list_all_documents(self) -> Dict[str, Dict[str, Dict[str, str]]]:
    """List the documents of every index in ``index_map``.

    Returns:
        A mapping of index name -> ref_doc_id -> {"text": ..., "hash": ...}.
    """
    documents_by_index: Dict[str, Dict[str, Dict[str, str]]] = {}
    for index_name, vector_store_index in self.index_map.items():
        doc_store = vector_store_index.docstore
        per_index: Dict[str, Dict[str, str]] = {}

        for doc_id, record in doc_store.docs.items():
            if not isinstance(doc_store, SimpleDocumentStore):
                # Entries of a non-simple doc store are not used directly;
                # the document is fetched through the async API instead.
                record = await doc_store.aget_document(doc_id)
            per_index[record.ref_doc_id] = {
                "text": record.text,
                "hash": record.hash
            }

        documents_by_index[index_name] = per_index
    return documents_by_index

def document_exists(self, index_name: str, doc: Document, doc_id: str) -> bool:
"""Common logic for checking document existence."""
Expand All @@ -178,12 +220,12 @@ def document_exists(self, index_name: str, doc: Document, doc_id: str) -> bool:
return False
return doc_id in self.index_map[index_name].ref_doc_info

def _persist_all(self):
async def _persist_all(self):
"""Common persistence logic."""
logger.info("Persisting all indexes.")
self.index_store.persist(os.path.join(VECTOR_DB_PERSIST_DIR, "store.json"))
for idx in self.index_store.index_structs():
self._persist(idx.index_id)
await self._persist(idx.index_id)

async def _persist(self, index_name: str):
"""Common persistence logic for individual index."""
Expand All @@ -193,7 +235,7 @@ async def _persist(self, index_name: str):
assert index_name in self.index_map, f"No such index: '{index_name}' exists."
storage_context = self.index_map[index_name].storage_context
# Persist the specific index
await storage_context.persist(persist_dir=os.path.join(VECTOR_DB_PERSIST_DIR, index_name))
storage_context.persist(persist_dir=os.path.join(VECTOR_DB_PERSIST_DIR, index_name))
logger.info(f"Successfully persisted index {index_name}.")
except Exception as e:
logger.error(f"Failed to persist index {index_name}. Error: {str(e)}")
19 changes: 16 additions & 3 deletions presets/ragengine/vector_store/chromadb_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,26 @@ def document_exists(self, index_name: str, doc: Document, doc_id: str) -> bool:
if index_name not in self.index_map:
logger.warning(f"No such index: '{index_name}' exists in vector store.")
return False
return doc.text in self.chroma_client.get_collection(name=index_name).get()["documents"]
return doc.text in self.chroma_client.get_collection(index_name).get()["documents"]

def list_all_indexed_documents(self) -> Dict[str, Dict[str, Dict[str, str]]]:
async def list_documents_in_index(self, index_name: str) -> Dict[str, Dict[str, str]]:
    """Return the documents stored in the collection named ``index_name``.

    Args:
        index_name: Name of the collection to read.

    Returns:
        A mapping of document id -> {"text": <document text>,
        "metadata": <metadata serialized as a JSON string>}. If reading the
        collection fails, the error is logged and whatever was gathered up
        to that point (possibly nothing) is returned; no exception is raised.
    """
    doc_map: Dict[str, Dict[str, str]] = {}
    try:
        # NOTE(review): confirm that the collection object returned by this
        # client exposes an awaitable `aget()`. If it does not, the resulting
        # AttributeError is swallowed below and this always returns {}.
        collection_info = await self.chroma_client.get_collection(index_name).aget()
        for doc_id, text, metadata in zip(
            collection_info["ids"],
            collection_info["documents"],
            collection_info["metadatas"],
        ):
            doc_map[doc_id] = {
                "text": text,
                "metadata": json.dumps(metadata)
            }
    except Exception as e:
        # Report through the module logger (as the rest of this file does)
        # rather than print, so the failure reaches the service logs.
        logger.error(f"Failed to get documents from collection '{index_name}': {e}")
    return doc_map

async def list_all_documents(self) -> Dict[str, Dict[str, Dict[str, str]]]:
indexed_docs = {} # Accumulate documents across all indexes
try:
for collection_name in self.chroma_client.list_collections():
collection_info = self.chroma_client.get_collection(collection_name).get()
collection_info = await self.chroma_client.get_collection(collection_name).aget()
for doc in zip(collection_info["ids"], collection_info["documents"], collection_info["metadatas"]):
indexed_docs.setdefault(collection_name, {})[doc[0]] = {
"text": doc[1],
Expand Down
12 changes: 10 additions & 2 deletions presets/ragengine/vector_store_manager/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ async def query(self,
"""Query the indexed documents."""
return await self.vector_store.query(index_name, query, top_k, llm_params, rerank_params)

def list_all_indexed_documents(self) -> Dict[str, Dict[str, Dict[str, str]]]:
def list_indexes(self):
    """Return whatever the underlying vector store reports as its indexes."""
    store = self.vector_store
    return store.list_indexes()

async def list_documents_in_index(self, index_name: str):
    """Await the vector store's document listing for ``index_name`` and return it."""
    documents = await self.vector_store.list_documents_in_index(index_name)
    return documents

async def list_all_documents(self) -> Dict[str, Dict[str, Dict[str, str]]]:
"""List all documents."""
return self.vector_store.list_all_indexed_documents()
return await self.vector_store.list_all_documents()

0 comments on commit 4aa23db

Please sign in to comment.