Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reranker implementation #20

Merged
merged 20 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,17 @@ Health Recommendations
"query": "I need mental health support",
"latitude": 43.6532,
"longitude": -79.3832,
"radius": 5000
"radius": 5000,
"rerank": false
}

:<json string query: The user's health-related query (required)
:<json number latitude: Optional latitude for location-based search
:<json number longitude: Optional longitude for location-based search
:<json number radius: Optional search radius in meters
:<json number radius: Optional search radius in meters (default: 5000)
:<json boolean rerank: Optional flag to enable/disable reranking of the services (default: false)
:>json string recommendation: Generated recommendation text
:>json array services: List of relevant health services

**Response Body**
a-kore marked this conversation as resolved.
Show resolved Hide resolved

Expand Down
2 changes: 1 addition & 1 deletion eval/evaluate_topkacc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import List, Dict, Any, DefaultDict


def load_embeddings(path: str) -> torch.Tensor:
def load_embeddings(path: str) -> Any:
return torch.load(path)


Expand Down
4 changes: 4 additions & 0 deletions health_rec/api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,7 @@ class Config:
CHROMA_PORT: int = 8000
COLLECTION_NAME: str = getenv("COLLECTION_NAME", "211_gta")
RELEVANCY_WEIGHT: float = float(getenv("RELEVANCY_WEIGHT", "0.5"))
MAX_CONTEXT_LENGTH: int = 300
TOP_K: int = 5
RERANKER_MAX_CONTEXT_LENGTH: int = 150
RERANKER_MAX_SERVICES: int = 20
3 changes: 3 additions & 0 deletions health_rec/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,12 +303,15 @@ class Query(BaseModel):
The latitude coordinate of the user.
radius : Optional[float]
The radius of the search.
rerank : Optional[bool]
Whether to use reranking for the recommendations.
"""

query: str
latitude: Optional[float] = Field(default=None)
longitude: Optional[float] = Field(default=None)
radius: Optional[float] = Field(default=None)
rerank: Optional[bool] = Field(default=False)


class RefineRequest(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion health_rec/load_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __call__(self, texts: Documents) -> Embeddings:
"""
try:
response = self.client.embeddings.create(input=texts, model=self.model)
return [data.embedding for data in response.data]
return [data.embedding for data in response.data] # type: ignore
except Exception as e:
logger.error(f"Error generating embeddings: {e}")
raise
Expand Down
2 changes: 1 addition & 1 deletion health_rec/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion health_rec/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ python = "^3.11"
fastapi = "^0.115.2"
uvicorn = "^0.30.6"
openai = "^1.45.1"
chromadb = "^0.5.5"
chromadb = "0.5.15"
python-dotenv = "^1.0.1"

[tool.poetry.group.test]
Expand Down
Loading