Skip to content

Commit

Permalink
Loading medical notes work
Browse files Browse the repository at this point in the history
  • Loading branch information
actions-user committed Aug 28, 2024
1 parent faaca0f commit 74cb674
Show file tree
Hide file tree
Showing 15 changed files with 458 additions and 896 deletions.
12 changes: 9 additions & 3 deletions backend/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from api.notes.db import check_database_connection
from api.routes import router as api_router
from api.users.crud import create_initial_admin
from api.users.db import get_async_session, init_db
Expand Down Expand Up @@ -33,9 +34,14 @@ async def startup_event() -> None:
This function is called when the FastAPI application starts up. It initializes
the database and creates an initial admin user if one doesn't already exist.
"""
await init_db()
async for session in get_async_session():
await create_initial_admin(session)
try:
await check_database_connection()
await init_db()
async for session in get_async_session():
await create_initial_admin(session)
except Exception as e:
logger.error(f"Startup failed: {str(e)}")
raise


@app.get("/")
Expand Down
24 changes: 24 additions & 0 deletions backend/api/notes/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Data models for medical notes."""

from pydantic import BaseModel, Field

class MedicalNote(BaseModel):
"""
Represents a medical note.
Attributes
----------
note_id : str
The unique identifier for the note.
subject_id : int
The subject (patient) identifier.
hadm_id : str
The hospital admission identifier.
text : str
The content of the medical note.
"""

note_id: str = Field(..., description="Unique identifier for the note")
subject_id: int = Field(..., description="Subject (patient) identifier")
hadm_id: str = Field(..., description="Hospital admission identifier")
text: str = Field(..., description="Content of the medical note")
72 changes: 72 additions & 0 deletions backend/api/notes/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Database module for medical notes."""

import logging
import os
from typing import AsyncGenerator

from motor.motor_asyncio import AsyncIOMotorClient

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

MONGO_USERNAME = os.getenv("MONGO_USERNAME")
MONGO_PASSWORD = os.getenv("MONGO_PASSWORD")
MONGO_HOST = os.getenv("MONGO_HOST", "mongodb")
MONGO_PORT = os.getenv("MONGO_PORT", "27017")
MONGO_URL = f"mongodb://{MONGO_USERNAME}:{MONGO_PASSWORD}@{MONGO_HOST}:{MONGO_PORT}"
DB_NAME = "medical_db"

async def get_database() -> AsyncGenerator[AsyncIOMotorClient, None]:
"""
Create and yield a database client.
Yields
------
AsyncIOMotorClient
An asynchronous MongoDB client.
Raises
------
ConnectionError
If unable to connect to the database.
"""
client = AsyncIOMotorClient(MONGO_URL)
try:
# Check the connection
await client.admin.command("ismaster")
logger.info("Successfully connected to the database")
yield client
except Exception as e:
logger.error(f"Unable to connect to the database: {str(e)}")
raise ConnectionError(f"Database connection failed: {str(e)}")
finally:
client.close()
logger.info("Database connection closed")

async def check_database_connection():
"""
Check the database connection on startup.
Raises
------
ConnectionError
If unable to connect to the database.
"""
client = AsyncIOMotorClient(MONGO_URL)
try:
await client.admin.command("ismaster")
db = client[DB_NAME]
collections = await db.list_collection_names()
if "medical_notes" in collections:
logger.info(
f"Database connection check passed. Found 'medical_notes' collection in {DB_NAME}"
)
else:
logger.warning(f"'medical_notes' collection not found in {DB_NAME}")
logger.info("Database connection check passed")
except Exception as e:
logger.error(f"Database connection check failed: {str(e)}")
raise ConnectionError(f"Database connection check failed: {str(e)}")
finally:
client.close()
67 changes: 67 additions & 0 deletions backend/api/routes.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
"""Backend API routes."""

import logging
from datetime import timedelta
from typing import Any, Dict, List

from fastapi import APIRouter, Depends, HTTPException, Request, status
from motor.motor_asyncio import AsyncIOMotorClient
from sqlalchemy.ext.asyncio import AsyncSession

from api.notes.data import MedicalNote
from api.notes.db import DB_NAME, get_database
from api.users.auth import (
ACCESS_TOKEN_EXPIRE_MINUTES,
authenticate_user,
Expand All @@ -24,9 +28,72 @@
from api.users.utils import verify_password


# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


router = APIRouter()


@router.get("/medical_notes/{patient_id}", response_model=List[MedicalNote])
async def get_medical_notes(
patient_id: str,
db: AsyncIOMotorClient = Depends(get_database),
current_user: User = Depends(get_current_active_user),
) -> List[MedicalNote]:
"""
Retrieve medical notes for a specific patient.
Parameters
----------
patient_id : str
The ID of the patient to fetch medical notes for.
db : AsyncIOMotorClient
The database client.
current_user : User
The authenticated user making the request.
Returns
-------
List[MedicalNote]
A list of medical notes for the specified patient.
Raises
------
HTTPException
If no medical notes are found for the patient or if there's a database error.
"""
try:
patient_id_int = int(patient_id)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid patient ID format. Must be an integer."
)

try:
collection = db[DB_NAME].medical_notes
cursor = collection.find({"subject_id": patient_id_int})
notes = await cursor.to_list(length=None)

if not notes:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="No medical notes found for this patient"
)

return [MedicalNote(**note) for note in notes]
except HTTPException:
raise
except Exception as e:
logger.error(f"Error retrieving medical notes for patient ID {patient_id_int}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An error occurred while retrieving medical notes"
)


@router.post("/auth/signin")
async def signin(
request: Request,
Expand Down
22 changes: 18 additions & 4 deletions backend/clinical_llm_service/api/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import gc

import torch
import transformers
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import gc


# Initialize FastAPI app
app = FastAPI()
Expand All @@ -18,7 +20,7 @@
torch_dtype=torch.float16,
device_map="auto",
low_cpu_mem_usage=True,
offload_folder="offload"
offload_folder="offload",
)

# Enable gradient checkpointing for memory efficiency
Expand All @@ -34,11 +36,13 @@
# System prompt
system_prompt = """You are an expert and experienced from the healthcare and biomedical domain with extensive medical knowledge and practical experience. Your name is adrenaline AI. who's willing to help answer the user's query with explanation. In your explanation, leverage your deep medical expertise such as relevant anatomical structures, physiological processes, diagnostic criteria, treatment guidelines, or other pertinent medical concepts. Use precise medical terminology while still aiming to make the explanation clear and accessible to a general audience."""


# Pydantic model for request body
class Query(BaseModel):
prompt: str
context: str = ""


@app.post("/generate")
async def generate_text(query: Query):
try:
Expand All @@ -47,11 +51,19 @@ async def generate_text(query: Query):

# Create a generator function for streaming
def generate():
inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=1024, padding=True)
inputs = tokenizer(
full_prompt,
return_tensors="pt",
truncation=True,
max_length=1024,
padding=True,
)
input_ids = inputs.input_ids.to(model.device)
attention_mask = inputs.attention_mask.to(model.device)

streamer = transformers.TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=10.0)
streamer = transformers.TextIteratorStreamer(
tokenizer, skip_prompt=True, timeout=10.0
)
generation_kwargs = dict(
input_ids=input_ids,
attention_mask=attention_mask,
Expand All @@ -66,6 +78,7 @@ def generate():

# Start the generation in a separate thread
from threading import Thread

thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

Expand All @@ -80,5 +93,6 @@ def generate():
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8003)
Loading

0 comments on commit 74cb674

Please sign in to comment.