Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add RAG LLMReranker #784

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
22 changes: 22 additions & 0 deletions presets/ragengine/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@

import os

"""
=========================================================================
"""

# Embedding configuration
EMBEDDING_SOURCE_TYPE = os.getenv("EMBEDDING_SOURCE_TYPE", "local") # Determines local or remote embedding source

Expand All @@ -19,11 +23,29 @@
REMOTE_EMBEDDING_URL = os.getenv("REMOTE_EMBEDDING_URL", "http://localhost:5000/embedding")
REMOTE_EMBEDDING_ACCESS_SECRET = os.getenv("REMOTE_EMBEDDING_ACCESS_SECRET", "default-access-secret")

# =========================================================================
# Section separators are plain comments; the previous bare triple-quoted
# strings were executed as (discarded) expression statements.
# =========================================================================

# Reranking configuration
# For now we support simple LLMReranker, future additions would include
# FlagEmbeddingReranker, SentenceTransformerReranker, CohereReranker
LLM_RERANKER_BATCH_SIZE = int(os.getenv("LLM_RERANKER_BATCH_SIZE", 5)) # Default LLM batch size
LLM_RERANKER_TOP_N = int(os.getenv("LLM_RERANKER_TOP_N", 3)) # Default top 3 reranked nodes

# =========================================================================

# LLM (Large Language Model) configuration
LLM_INFERENCE_URL = os.getenv("LLM_INFERENCE_URL", "http://localhost:5000/chat")
LLM_ACCESS_SECRET = os.getenv("LLM_ACCESS_SECRET", "default-access-secret")
# LLM_RESPONSE_FIELD = os.getenv("LLM_RESPONSE_FIELD", "result") # Uncomment if needed in the future

# =========================================================================
# (comment separator — was a bare triple-quoted string expression)
# =========================================================================

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

low priority: we can reorg these params to cli args and using structured params instead of raw key-value dict.

# Vector database configuration
VECTOR_DB_IMPLEMENTATION = os.getenv("VECTOR_DB_IMPLEMENTATION", "faiss")  # Backend selector; "faiss" is the default implementation
VECTOR_DB_PERSIST_DIR = os.getenv("VECTOR_DB_PERSIST_DIR", "storage")  # Directory where indexes are persisted on disk
3 changes: 2 additions & 1 deletion presets/ragengine/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ async def index_documents(request: IndexRequest): # TODO: Research async/sync wh
async def query_index(request: QueryRequest):
    """Query an index and return the LLM response with (optionally reranked) source nodes.

    Missing ``llm_params`` / ``rerank_params`` default to empty dicts so the
    backend always receives dictionaries. Any backend failure (e.g. unknown
    index, inference error) is surfaced as an HTTP 500 with the error message.
    """
    try:
        llm_params = request.llm_params or {}  # Default to empty dict if no params provided
        rerank_params = request.rerank_params or {}  # Default to empty dict if no params provided
        return rag_ops.query(request.index_name, request.query, request.top_k, llm_params, rerank_params)
    except Exception as e:
        # Broad catch is intentional: every backend error becomes a 500 for the client.
        raise HTTPException(status_code=500, detail=str(e)) from e

Expand Down
1 change: 1 addition & 0 deletions presets/ragengine/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class QueryRequest(BaseModel):
query: str  # The natural-language query to run against the index
top_k: int = 10  # Number of nearest-neighbor results to retrieve before any reranking
llm_params: Optional[Dict] = None # Accept a dictionary for parameters
rerank_params: Optional[Dict] = None # Accept a dictionary for parameters

class ListDocumentsResponse(BaseModel):
    """Response model mapping index name -> document id -> document fields."""
    documents: Dict[str, Dict[str, Dict[str, str]]]
Expand Down
96 changes: 95 additions & 1 deletion presets/ragengine/tests/api/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_query_index_success(mock_post):
"result": "This is the completion from the API"
}
mock_post.return_value.json.return_value = mock_response
# Index
# Index
request_data = {
"index_name": "test_index",
"documents": [
Expand Down Expand Up @@ -74,6 +74,100 @@ def test_query_index_success(mock_post):
assert response.json()["source_nodes"][0]["metadata"] == {}
assert mock_post.call_count == 1


@patch('requests.post')
def test_reranker_and_query_with_index(mock_post):
    """
    Test reranker and query functionality with indexed documents.

    This test ensures the following:
    1. The custom reranker returns a relevance-sorted list of documents.
    2. The query response matches the expected format and contains the correct top results.

    The reranker prompts the LLM to answer in the form:
        Doc: 9, Relevance: 7
        Doc: 3, Relevance: 4
    so the mocked reranker response below keeps documents 4 and 5.
    """
    # requests.post(...).json() is invoked once by the reranker and once by the
    # final LLM completion, so side_effect yields the reranker payload first.
    reranker_mock_response = "Doc: 4, Relevance: 10\nDoc: 5, Relevance: 10"
    query_mock_response = {"result": "This is the completion from the API"}
    mock_http_responses = [reranker_mock_response, query_mock_response]
    mock_post.return_value.json.side_effect = mock_http_responses

    # Define input documents for indexing; docs 4 and 5 are the ones the
    # mocked reranker selects as relevant.
    documents = [
        "The capital of France is great.",
        "The capital of France is huge.",
        "The capital of France is beautiful.",
        """Have you ever visited Paris? It is a beautiful city where you can eat delicious food and see the Eiffel Tower.
        I really enjoyed all the cities in France, but its capital with the Eiffel Tower is my favorite city.""",
        "I really enjoyed my trip to Paris, France. The city is beautiful and the food is delicious. I would love to visit again. "
        "Such a great capital city."
    ]

    # Perform indexing (local embedding model — no HTTP call expected here).
    index_request_payload = {
        "index_name": "test_index",
        "documents": [{"text": doc} for doc in documents]
    }
    response = client.post("/index", json=index_request_payload)
    assert response.status_code == 200

    # top_n is derived from the number of lines in the mocked reranker output (2).
    top_n = len(reranker_mock_response.split("\n"))
    query_request_payload = {
        "index_name": "test_index",
        "query": "what is the capital of france?",
        "top_k": 5,
        "llm_params": {"temperature": 0.7},
        "rerank_params": {"top_n": top_n}
    }

    # Perform query with reranking enabled.
    response = client.post("/query", json=query_request_payload)
    assert response.status_code == 200
    query_response = response.json()

    # Fix: the original asserted query_response["response"] == query_response["result"],
    # but the /query payload has no "result" key (it carries "response" and
    # "source_nodes"). The inference layer stringifies the raw LLM JSON, so the
    # response text equals str(query_mock_response) — same pattern asserted in
    # tests/vector_store/test_base_store.py.
    assert query_response["response"] == str(query_mock_response)
    assert len(query_response["source_nodes"]) == top_n

    # Validate each reranked source node (both mocked with relevance score 10).
    expected_source_nodes = [
        {"text": "Have you ever visited Paris? It is a beautiful city where you can eat "
                 "delicious food and see the Eiffel Tower. I really enjoyed all the cities in "
                 "France, but its capital with the Eiffel Tower is my favorite city.",
         "score": 10.0, "metadata": {}},
        {"text": "I really enjoyed my trip to Paris, France. The city is beautiful and the "
                 "food is delicious. I would love to visit again. Such a great capital city.",
         "score": 10.0, "metadata": {}},
    ]
    for i, expected_node in enumerate(expected_source_nodes):
        actual_node = query_response["source_nodes"][i]
        assert actual_node["text"] == expected_node["text"]
        assert actual_node["score"] == expected_node["score"]
        assert actual_node["metadata"] == expected_node["metadata"]

    # One HTTP call for reranking + one for the final completion.
    assert mock_post.call_count == len(mock_http_responses)

def test_query_index_failure():
# Prepare request data for querying.
request_data = {
Expand Down
3 changes: 2 additions & 1 deletion presets/ragengine/tests/vector_store/test_base_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ def test_query_documents(self, mock_post, vector_store_manager):
vector_store_manager.index_documents("test_index", documents)

params = {"temperature": 0.7}
query_result = vector_store_manager.query("test_index", "First", top_k=1, llm_params=params)
query_result = vector_store_manager.query("test_index", "First", top_k=1,
llm_params=params, rerank_params={})

assert query_result is not None
assert query_result["response"] == "{'result': 'This is the completion from the API'}"
Expand Down
50 changes: 45 additions & 5 deletions presets/ragengine/vector_store/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
from llama_index.core import Document as LlamaDocument
from llama_index.core.storage.index_store import SimpleIndexStore
from llama_index.core import (StorageContext, VectorStoreIndex)
from llama_index.core.postprocessor import LLMRerank # Query with LLM Reranking

from ragengine.models import Document
from ragengine.embedding.base import BaseEmbeddingModel
from ragengine.inference.inference import Inference
from ragengine.config import VECTOR_DB_PERSIST_DIR
from ragengine.config import (LLM_RERANKER_BATCH_SIZE, LLM_RERANKER_TOP_N, VECTOR_DB_PERSIST_DIR)

# Configure logging
logging.basicConfig(level=logging.INFO)
Expand Down Expand Up @@ -86,15 +87,54 @@ def _create_index_common(self, index_name: str, documents: List[Document], vecto
self._persist(index_name)
return list(indexed_doc_ids)

def query(self, index_name: str, query: str, top_k: int, llm_params: dict):
"""Common query logic for all vector stores."""
def query(self,
index_name: str,
query: str,
top_k: int,
llm_params: dict,
rerank_params: dict
):
"""
Query the indexed documents

Args:
index_name (str): Name of the index to query
query (str): Query string
top_k (int): Number of initial top results to retrieve
llm_params (dict): Optional parameters for the language model
rerank_params (dict): Optional configuration for reranking
- 'top_n' (int): Number of top documents to return after reranking
- 'choice_batch_size' (int): Number of documents to process in each batch

Returns:
dict: A dictionary containing the response and source nodes.
"""
if index_name not in self.index_map:
raise ValueError(f"No such index: '{index_name}' exists.")
self.llm.set_params(llm_params)

node_postprocessors = []
if rerank_params:
# Set default reranking parameters and merge with provided params
default_rerank_params = {
'choice_batch_size': LLM_RERANKER_BATCH_SIZE, # Default batch size
'top_n': min(LLM_RERANKER_TOP_N, top_k) # Limit top_n to top_k by default
}
rerank_params = {**default_rerank_params, **rerank_params}

# Add LLMRerank to postprocessors
node_postprocessors.append(
LLMRerank(
llm=self.llm,
choice_batch_size=rerank_params['choice_batch_size'],
top_n=rerank_params['top_n']
)
)

query_engine = self.index_map[index_name].as_query_engine(
llm=self.llm,
similarity_top_k=top_k
llm=self.llm,
similarity_top_k=top_k,
node_postprocessors=node_postprocessors
)
query_result = query_engine.query(query)
return {
Expand Down
10 changes: 8 additions & 2 deletions presets/ragengine/vector_store_manager/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,15 @@ def index(self, index_name: str, documents: List[Document]) -> List[str]:
"""Index new documents."""
return self.vector_store.index_documents(index_name, documents)

def query(self,
          index_name: str,
          query: str,
          top_k: int,
          llm_params: dict,
          rerank_params: dict
          ):
    """Query the indexed documents.

    Thin delegation to the underlying vector store implementation.

    Args:
        index_name (str): Name of the index to query.
        query (str): Query string.
        top_k (int): Number of initial top results to retrieve.
        llm_params (dict): Optional parameters for the language model.
        rerank_params (dict): Optional reranking configuration (e.g. 'top_n').

    Returns:
        The vector store's query result (response plus source nodes).
    """
    return self.vector_store.query(index_name, query, top_k, llm_params, rerank_params)

def list_all_indexed_documents(self) -> Dict[str, Dict[str, Dict[str, str]]]:
"""List all documents."""
Expand Down
Loading