Skip to content

Commit

Permalink
Improvements and type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
amrit110 committed Nov 21, 2024
1 parent bdf5646 commit 77fef4c
Show file tree
Hide file tree
Showing 8 changed files with 57 additions and 40 deletions.
36 changes: 7 additions & 29 deletions adrenaline/api/patients/answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@

import json
import logging
import os
from typing import Tuple

from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableSequence
from langchain_openai import ChatOpenAI

from api.pages.data import Answer
from api.patients.llm import LLM
from api.patients.prompts import (
general_answer_template,
patient_answer_template,
Expand All @@ -21,27 +20,6 @@
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set up OpenAI client with custom endpoint
LLM_SERVICE_URL = os.getenv("LLM_SERVICE_URL")
if not LLM_SERVICE_URL:
raise ValueError("LLM_SERVICE_URL is not set")
logger.info(f"LLM_SERVICE_URL is set to: {LLM_SERVICE_URL}")

os.environ["OPENAI_API_KEY"] = "EMPTY"

# Initialize LLM with increased timeout
try:
llm = ChatOpenAI(
base_url=LLM_SERVICE_URL,
model_name="Meta-Llama-3.1-70B-Instruct",
temperature=0.3,
max_tokens=4096,
request_timeout=60,
)
logger.info("ChatOpenAI initialized successfully")
except Exception as e:
logger.error(f"Error initializing ChatOpenAI: {str(e)}")
raise

answer_parser = PydanticOutputParser(pydantic_object=Answer)
patient_answer_prompt = PromptTemplate(
Expand All @@ -54,8 +32,8 @@
)

# Initialize the LLMChains
patient_answer_chain = RunnableSequence(patient_answer_prompt | llm)
general_answer_chain = RunnableSequence(general_answer_prompt | llm)
patient_answer_chain = RunnableSequence(patient_answer_prompt | LLM)
general_answer_chain = RunnableSequence(general_answer_prompt | LLM)


def parse_llm_output_answer(output: str) -> Tuple[str, str]:
Expand Down Expand Up @@ -139,7 +117,7 @@ async def generate_answer(
raise


async def test_llm_connection():
async def test_llm_connection() -> bool:
"""Test the connection to the LLM.
Returns
Expand All @@ -158,12 +136,12 @@ async def test_llm_connection():
return False


async def initialize_llm():
async def initialize_llm() -> bool:
"""Initialize the LLM.
Returns
-------
bool
True if the connection is successful, False otherwise.
True if the LLM is initialized successfully, False otherwise.
"""
await test_llm_connection()
return await test_llm_connection()
2 changes: 1 addition & 1 deletion adrenaline/api/patients/ehr.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def fetch_latest_medications(self, patient_id: int) -> str:
for event in processed_events
if (
event["event_type"] == "MEDICATION"
and event["encounter_id"] == latest_encounter
and event["timestamp"] > latest_timestamp
)
}

Expand Down
33 changes: 33 additions & 0 deletions adrenaline/api/patients/llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""LLM module for patients API."""

import logging
import os

from langchain_openai import ChatOpenAI


# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set up OpenAI client with custom endpoint
LLM_SERVICE_URL = os.getenv("LLM_SERVICE_URL")
if not LLM_SERVICE_URL:
raise ValueError("LLM_SERVICE_URL is not set")
logger.info(f"LLM_SERVICE_URL is set to: {LLM_SERVICE_URL}")

os.environ["OPENAI_API_KEY"] = "EMPTY"

# Initialize LLM with increased timeout
try:
LLM = ChatOpenAI(
base_url=LLM_SERVICE_URL,
model_name="Meta-Llama-3.1-70B-Instruct",
temperature=0.3,
max_tokens=4096,
request_timeout=60,
)
logger.info("ChatOpenAI initialized successfully")
except Exception as e:
logger.error(f"Error initializing ChatOpenAI: {str(e)}")
raise
3 changes: 3 additions & 0 deletions adrenaline/api/patients/workflows.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""EHR workflow functions."""


4 changes: 2 additions & 2 deletions adrenaline/api/routes/answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import os
from datetime import datetime
from typing import Dict, List
from typing import Any, Dict, List

from fastapi import APIRouter, Body, Depends, HTTPException
from motor.motor_asyncio import AsyncIOMotorDatabase
Expand Down Expand Up @@ -109,7 +109,7 @@ async def format_medications(
@router.post("/generate_answer")
async def generate_answer_endpoint(
query: Query = Body(...), # noqa: B008
db: AsyncIOMotorDatabase = Depends(get_database), # noqa: B008
db: AsyncIOMotorDatabase[Any] = Depends(get_database), # noqa: B008
current_user: User = Depends(get_current_active_user), # noqa: B008
) -> Dict[str, str]:
"""Generate an answer using RAG."""
Expand Down
12 changes: 6 additions & 6 deletions adrenaline/api/routes/pages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
from datetime import UTC, datetime
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional

from bson import ObjectId
from fastapi import APIRouter, Body, Depends, HTTPException
Expand Down Expand Up @@ -32,7 +32,7 @@ class CreatePageRequest(BaseModel):
@router.post("/pages/create")
async def create_page(
request: CreatePageRequest,
db: AsyncIOMotorDatabase = Depends(get_database), # noqa: B008
db: AsyncIOMotorDatabase[Any] = Depends(get_database), # noqa: B008
current_user: User = Depends(get_current_active_user), # noqa: B008
) -> Dict[str, str]:
"""Create a new page for a user."""
Expand Down Expand Up @@ -63,9 +63,9 @@ async def append_to_page(
page_id: str,
question: str = Body(...),
answer: str = Body(...),
db: AsyncIOMotorDatabase = Depends(get_database), # noqa: B008
db: AsyncIOMotorDatabase[Any] = Depends(get_database), # noqa: B008
current_user: User = Depends(get_current_active_user), # noqa: B008
) -> dict:
) -> Dict[str, str]:
"""Append a follow-up question and answer to an existing page."""
existing_page = await db.pages.find_one(
{"_id": ObjectId(page_id), "user_id": str(current_user.id)}
Expand All @@ -89,7 +89,7 @@ async def append_to_page(

@router.get("/pages/history", response_model=List[Page])
async def get_user_page_history(
db: AsyncIOMotorDatabase = Depends(get_database), # noqa: B008
db: AsyncIOMotorDatabase[Any] = Depends(get_database), # noqa: B008
current_user: User = Depends(get_current_active_user), # noqa: B008
) -> List[Page]:
"""Retrieve all pages for the current user.
Expand All @@ -114,7 +114,7 @@ async def get_user_page_history(
@router.get("/pages/{page_id}", response_model=Page)
async def get_page(
page_id: str,
db: AsyncIOMotorDatabase = Depends(get_database), # noqa: B008
db: AsyncIOMotorDatabase[Any] = Depends(get_database), # noqa: B008
current_user: User = Depends(get_current_active_user), # noqa: B008
) -> Page:
"""Retrieve a specific page.
Expand Down
5 changes: 4 additions & 1 deletion scripts/cot_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
"""Test a chain of thought endpoint."""

import os
from typing import Dict
from langchain_openai import OpenAI
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
Expand Down Expand Up @@ -47,7 +50,7 @@ class Query(BaseModel):

# Define endpoint
@app.post("/cot")
async def chain_of_thought(query: Query):
async def chain_of_thought(query: Query) -> Dict[str, str]:
try:
result = chain.invoke({"query": query.text})

Expand Down
2 changes: 1 addition & 1 deletion services/embedding/api/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ async def create_embeddings(request: EmbeddingRequest) -> Dict[str, List[List[fl
raise HTTPException(status_code=500, detail=str(e)) from e


def initialize_model():
def initialize_model() -> None:
"""Initialize the model."""
global model
model = load_model()

0 comments on commit 77fef4c

Please sign in to comment.