Reranker implementation #20

Merged · 20 commits · Oct 30, 2024
2 changes: 1 addition & 1 deletion eval/evaluate_topkacc.py
@@ -7,7 +7,7 @@
from typing import List, Dict, Any, DefaultDict


def load_embeddings(path: str) -> torch.Tensor:
def load_embeddings(path: str) -> Any:
return torch.load(path)


40 changes: 39 additions & 1 deletion health_rec/api/routes.py
@@ -1,7 +1,7 @@
"""Backend API routes."""

import logging
from typing import Any, Dict, List
from typing import Any, Dict, List, Union

from fastapi import APIRouter, Body, Depends, HTTPException

@@ -10,10 +10,12 @@
RecommendationResponse,
RefineRequest,
Service,
ServiceDocument,
)
from services.dev.data import ChromaService
from services.rag import RAGService
from services.refine import RefineService
from services.rerank import RerankingConfig, ReRankingService


# Configure logging
@@ -25,6 +27,7 @@

rag_service = RAGService()
refine_service = RefineService()
rerank_service = ReRankingService(RerankingConfig())


@router.get("/questions", response_model=dict)
@@ -151,3 +154,38 @@ async def get_services_count(
The number of services.
"""
return await chroma_service.get_services_count()


@router.get("/rerank", response_model=List[Service])
async def rerank_recommendations(
query: str, retrieval_k: int = 20, output_k: int = 5
) -> Union[ServiceDocument, List[ServiceDocument]]:
"""
Generate re-ranked list of services based on the input query.

Parameters
----------
query : str
The user's input query.
retrieval_k : int, optional
Number of services to retrieve initially. Default is 20.
output_k : int, optional
Number of services to return after re-ranking. Default is 5.

Returns
-------
services: List[ServiceDocument]
A list of services ordered by relevance to the query.
"""
try:
config = RerankingConfig(
retrieval_k=min(max(1, retrieval_k), 20),  # Limit between 1 and 20
output_k=min(max(1, output_k), retrieval_k),  # Cannot be more than retrieval_k
)
rerank_service.config = config
return rerank_service.rerank(query)
except Exception as e:
logger.error(f"Error in rerank_recommendations: {str(e)}")
raise HTTPException(status_code=422, detail=str(e)) from e
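
As written, ReRankingService.rerank takes both the query text and a query embedding, while the route above calls rerank_service.rerank(query) with the text alone, so the handler still needs to produce an embedding before delegating. A minimal sketch of one way to wire that up, assuming the same OpenAI embeddings call used in load_data.py and the Config.OPENAI_EMBEDDING model name from rerank.py; the helper rerank_with_embedding is hypothetical, not part of this PR:

from typing import List

import openai

from api.config import Config
from api.data import ServiceDocument
from services.rerank import RerankingConfig, ReRankingService

client = openai.Client(api_key=Config.OPENAI_API_KEY)
rerank_service = ReRankingService(RerankingConfig())


def rerank_with_embedding(
    query: str, retrieval_k: int = 20, output_k: int = 5
) -> List[ServiceDocument]:
    """Clamp the parameters, embed the query, then delegate to the re-ranking service."""
    rerank_service.config = RerankingConfig(
        retrieval_k=min(max(1, retrieval_k), 20),
        output_k=min(max(1, output_k), retrieval_k),
    )
    # rerank() retrieves candidates from Chroma itself, so it needs the query embedding too.
    embedding = client.embeddings.create(
        input=[query], model=Config.OPENAI_EMBEDDING
    ).data[0].embedding
    return rerank_service.rerank(query, embedding)
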
2 changes: 1 addition & 1 deletion health_rec/load_data.py
@@ -45,7 +45,7 @@ def __call__(self, texts: Documents) -> Embeddings:
"""
try:
response = self.client.embeddings.create(input=texts, model=self.model)
return [data.embedding for data in response.data]
return [data.embedding for data in response.data] # type: ignore
except Exception as e:
logger.error(f"Error generating embeddings: {e}")
raise
2 changes: 1 addition & 1 deletion health_rec/poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion health_rec/pyproject.toml
@@ -15,7 +15,7 @@ python = "^3.11"
fastapi = "^0.115.2"
uvicorn = "^0.30.6"
openai = "^1.45.1"
chromadb = "^0.5.5"
chromadb = "0.5.15"
python-dotenv = "^1.0.1"

[tool.poetry.group.test]
10 changes: 5 additions & 5 deletions health_rec/services/rag.py
@@ -11,7 +11,8 @@
from api.data import Query, RecommendationResponse, Service, ServiceDocument
from services.emergency import get_emergency_services_message
from services.ranking import RankingService
from services.utils import _metadata_to_service, _parse_chroma_result
from services.rerank import RerankingConfig, ReRankingService
from services.utils import _metadata_to_service


logging.basicConfig(
@@ -34,6 +35,7 @@ def __init__(self) -> None:
name=Config.COLLECTION_NAME
)
self.ranking_service = RankingService(relevancy_weight=Config.RELEVANCY_WEIGHT)
self.reranking_service = ReRankingService(RerankingConfig())

def generate(self, query: Query) -> RecommendationResponse:
"""
@@ -95,10 +97,8 @@ def _retrieve_and_rank_services(
self, query: Query, query_embedding: List[float]
) -> List[ServiceDocument]:
"""Retrieve and rank services based on the query."""
chroma_results = self.services_collection.query(
query_embeddings=query_embedding, n_results=5
)
service_documents = _parse_chroma_result(chroma_results)
service_documents = self.reranking_service.rerank(query.query, query_embedding)

user_location = (
(query.latitude, query.longitude)
if query.latitude and query.longitude
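
With this change, _retrieve_and_rank_services no longer takes the raw top-5 nearest neighbours from Chroma; it asks the re-ranking service for retrieval_k candidates (20 by default) and keeps the output_k (5) the LLM ranks highest, which the distance-aware RankingService then processes as before. A small sketch contrasting the old and new retrieval step under those default settings (retrieve_candidates is an illustrative stand-in, not a function in this PR):

from typing import List

from api.data import ServiceDocument
from services.rerank import RerankingConfig, ReRankingService

reranking_service = ReRankingService(RerankingConfig())  # retrieval_k=20, output_k=5


def retrieve_candidates(query_text: str, query_embedding: List[float]) -> List[ServiceDocument]:
    """Return the candidate services handed to the downstream RankingService."""
    # Previously: the five nearest neighbours from Chroma, in embedding-distance order,
    # via services_collection.query(query_embeddings=query_embedding, n_results=5)
    # followed by _parse_chroma_result.
    #
    # Now: retrieve 20 candidates and let the LLM keep the 5 it ranks most relevant.
    return reranking_service.rerank(query_text, query_embedding)
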
161 changes: 161 additions & 0 deletions health_rec/services/rerank.py
@@ -0,0 +1,161 @@
"""Service for re-ranking retrieved health services using LLM."""

import logging
from dataclasses import dataclass
from typing import Dict, List, Optional

import chromadb
import openai
from chromadb.api.models.Collection import Collection
from chromadb.api.types import QueryResult

from api.config import Config
from api.data import Service, ServiceDocument
from services.utils import _metadata_to_service, _parse_chroma_result


logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


@dataclass
class RerankingConfig:
"""Configuration for re-ranking parameters."""

retrieval_k: int = 20 # Number of services to retrieve initially
output_k: int = 5 # Number of services to return after re-ranking
max_content_length: int = 150 # Maximum length of service content to consider


class ReRankingService:
"""Service for re-ranking retrieved health services using LLM.

Based on RankGPT: https://arxiv.org/abs/2304.09542

"""

def __init__(self, config: Optional[RerankingConfig] = None) -> None:
"""Initialize the re-ranking service."""
self.client = openai.Client(api_key=Config.OPENAI_API_KEY)
self.embedding_model = Config.OPENAI_EMBEDDING
self.chroma_client = chromadb.HttpClient(
host=Config.CHROMA_HOST, port=Config.CHROMA_PORT
)
self.services_collection: Collection = self.chroma_client.get_collection(
name=Config.COLLECTION_NAME
)
self.config = config or RerankingConfig()

def _create_ranking_prompt(
self, query: str, services: List[Service]
) -> List[Dict[str, str]]:
"""Create the prompt for ranking services."""
messages = [
{
"role": "system",
"content": "You are a health service recommender that ranks services based on their relevance to a user's query. Respond only with the ranked service numbers in descending order of relevance, separated by '>'.",
},
{
"role": "user",
"content": f"I will provide you with {len(services)} services, each indicated by number identifier []. \nRank these services based on their relevance to query: {query}",
},
{
"role": "assistant",
"content": "I'll rank the services. Please provide them.",
},
]

# Add each service as a separate message
for i, service in enumerate(services, 1):
content = f"{service.public_name or ''}\n{service.description or ''}\n{service.eligibility or ''}"
content = " ".join(content.split()[: self.config.max_content_length])
messages.append({"role": "user", "content": f"[{i}] {content}"})
messages.append(
{"role": "assistant", "content": f"Received service [{i}]."}
)

# Add final ranking request
messages.append(
{
"role": "user",
"content": f"""For the query "{query}", rank the services from most to least relevant.
Respond only with service numbers in the format: [X] > [Y] > [Z]""",
}
)

return messages

def _process_ranking_response(
self, response: str, services: QueryResult
) -> List[ServiceDocument]:
"""Process the LLM's ranking response and return reordered services."""
documents: List[ServiceDocument] = _parse_chroma_result(services)
try:
# Extract numbers from the response
rankings = [
int(x) - 1
for x in "".join(c if c.isdigit() else " " for c in response).split()
]

# Remove any invalid indices
valid_rankings = [i for i in rankings if 0 <= i < len(documents)]

# Add any missing indices at the end
all_indices = set(range(len(documents)))
missing_indices = [i for i in all_indices if i not in valid_rankings]
rankings = valid_rankings + missing_indices

# Return reordered services limited to output_k
return [documents[i] for i in rankings[: self.config.output_k]]

except Exception as e:
logger.error(f"Error processing ranking response: {e}")
# Fall back to original order if parsing fails
return documents[: self.config.output_k]

def rerank(self, query: str, query_embedding: List[float]) -> List[ServiceDocument]:
"""
Generate re-ranked list of services for the query.

Parameters
----------
query : str
The user's input query.
query_embedding : List[float]
Embedding of the query, used for the initial Chroma retrieval.

Returns
-------
List[ServiceDocument]
The re-ranked list of services, limited to output_k entries.
"""
# Retrieve initial services
results = self.services_collection.query(
query_embeddings=query_embedding, n_results=self.config.retrieval_k
)

# Convert to Service objects
services = [
_metadata_to_service(meta)
for meta in (results["metadatas"][0] if results["metadatas"] else [])
]

if not services:
return []

# Create ranking prompt
messages = self._create_ranking_prompt(query, services)

# Get re-ranking from LLM
completion = self.client.chat.completions.create(
model="gpt-4",
messages=messages, # type: ignore
temperature=0,
)

response_content = completion.choices[0].message.content
if response_content is None:
raise ValueError("Received empty response from OpenAI API")

# Process and return results
return self._process_ranking_response(response_content, results)
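
_process_ranking_response recovers the ordering by replacing every non-digit character in the model reply with a space, so replies like "[3] > [1] > [2]" and "3 > 1 > 2" parse identically; out-of-range indices are dropped, and any candidate the model omitted is appended in its original retrieval order before truncating to output_k. A standalone illustration of that parsing step (the reply string below is invented for the example):

from typing import List


def parse_ranking(response: str, n_documents: int, output_k: int) -> List[int]:
    """Mirror the digit-extraction logic in _process_ranking_response."""
    # Replace every non-digit with a space, then read 1-based ranks as 0-based indices.
    rankings = [
        int(token) - 1
        for token in "".join(c if c.isdigit() else " " for c in response).split()
    ]
    # Keep only indices that point at a retrieved document.
    valid = [i for i in rankings if 0 <= i < n_documents]
    # Append anything the model left out, preserving retrieval order.
    missing = [i for i in range(n_documents) if i not in valid]
    return (valid + missing)[:output_k]


# Four documents retrieved; the LLM ranks document 3 first, then 1, then 2.
print(parse_ranking("[3] > [1] > [2]", n_documents=4, output_k=5))  # [2, 0, 1, 3]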