From 74cb67499b649197a1c54e918f178125257abd61 Mon Sep 17 00:00:00 2001 From: GitHub Actions Date: Wed, 28 Aug 2024 16:51:07 -0400 Subject: [PATCH] Loading medical notes work --- backend/api/main.py | 12 +- backend/api/notes/data.py | 24 ++ backend/api/notes/db.py | 72 ++++ backend/api/routes.py | 67 ++++ backend/clinical_llm_service/api/main.py | 22 +- backend/instruction_tune_dataset.ipynb | 488 ----------------------- backend/poetry.lock | 98 ++++- backend/pyproject.toml | 1 + docker-compose.dev.yml | 5 +- frontend/src/app/context/endpoint.tsx | 178 --------- frontend/src/app/context/model.tsx | 161 -------- frontend/src/app/home/page.tsx | 166 +++++--- frontend/src/app/layout.tsx | 8 +- frontend/src/app/login/page.tsx | 2 +- scripts/load_mimic_data.py | 50 +++ 15 files changed, 458 insertions(+), 896 deletions(-) create mode 100644 backend/api/notes/data.py create mode 100644 backend/api/notes/db.py delete mode 100644 backend/instruction_tune_dataset.ipynb delete mode 100644 frontend/src/app/context/endpoint.tsx delete mode 100644 frontend/src/app/context/model.tsx create mode 100644 scripts/load_mimic_data.py diff --git a/backend/api/main.py b/backend/api/main.py index 83e31b3..3943d86 100644 --- a/backend/api/main.py +++ b/backend/api/main.py @@ -6,6 +6,7 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware +from api.notes.db import check_database_connection from api.routes import router as api_router from api.users.crud import create_initial_admin from api.users.db import get_async_session, init_db @@ -33,9 +34,14 @@ async def startup_event() -> None: This function is called when the FastAPI application starts up. It initializes the database and creates an initial admin user if one doesn't already exist. """ - await init_db() - async for session in get_async_session(): - await create_initial_admin(session) + try: + await check_database_connection() + await init_db() + async for session in get_async_session(): + await create_initial_admin(session) + except Exception as e: + logger.error(f"Startup failed: {str(e)}") + raise @app.get("/") diff --git a/backend/api/notes/data.py b/backend/api/notes/data.py new file mode 100644 index 0000000..68c56bf --- /dev/null +++ b/backend/api/notes/data.py @@ -0,0 +1,24 @@ +"""Data models for medical notes.""" + +from pydantic import BaseModel, Field + +class MedicalNote(BaseModel): + """ + Represents a medical note. + + Attributes + ---------- + note_id : str + The unique identifier for the note. + subject_id : int + The subject (patient) identifier. + hadm_id : str + The hospital admission identifier. + text : str + The content of the medical note. + """ + + note_id: str = Field(..., description="Unique identifier for the note") + subject_id: int = Field(..., description="Subject (patient) identifier") + hadm_id: str = Field(..., description="Hospital admission identifier") + text: str = Field(..., description="Content of the medical note") \ No newline at end of file diff --git a/backend/api/notes/db.py b/backend/api/notes/db.py new file mode 100644 index 0000000..d49c004 --- /dev/null +++ b/backend/api/notes/db.py @@ -0,0 +1,72 @@ +"""Database module for medical notes.""" + +import logging +import os +from typing import AsyncGenerator + +from motor.motor_asyncio import AsyncIOMotorClient + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +MONGO_USERNAME = os.getenv("MONGO_USERNAME") +MONGO_PASSWORD = os.getenv("MONGO_PASSWORD") +MONGO_HOST = os.getenv("MONGO_HOST", "mongodb") +MONGO_PORT = os.getenv("MONGO_PORT", "27017") +MONGO_URL = f"mongodb://{MONGO_USERNAME}:{MONGO_PASSWORD}@{MONGO_HOST}:{MONGO_PORT}" +DB_NAME = "medical_db" + +async def get_database() -> AsyncGenerator[AsyncIOMotorClient, None]: + """ + Create and yield a database client. + + Yields + ------ + AsyncIOMotorClient + An asynchronous MongoDB client. + + Raises + ------ + ConnectionError + If unable to connect to the database. + """ + client = AsyncIOMotorClient(MONGO_URL) + try: + # Check the connection + await client.admin.command("ismaster") + logger.info("Successfully connected to the database") + yield client + except Exception as e: + logger.error(f"Unable to connect to the database: {str(e)}") + raise ConnectionError(f"Database connection failed: {str(e)}") + finally: + client.close() + logger.info("Database connection closed") + +async def check_database_connection(): + """ + Check the database connection on startup. + + Raises + ------ + ConnectionError + If unable to connect to the database. + """ + client = AsyncIOMotorClient(MONGO_URL) + try: + await client.admin.command("ismaster") + db = client[DB_NAME] + collections = await db.list_collection_names() + if "medical_notes" in collections: + logger.info( + f"Database connection check passed. Found 'medical_notes' collection in {DB_NAME}" + ) + else: + logger.warning(f"'medical_notes' collection not found in {DB_NAME}") + logger.info("Database connection check passed") + except Exception as e: + logger.error(f"Database connection check failed: {str(e)}") + raise ConnectionError(f"Database connection check failed: {str(e)}") + finally: + client.close() \ No newline at end of file diff --git a/backend/api/routes.py b/backend/api/routes.py index a920726..6f3c560 100644 --- a/backend/api/routes.py +++ b/backend/api/routes.py @@ -1,11 +1,15 @@ """Backend API routes.""" +import logging from datetime import timedelta from typing import Any, Dict, List from fastapi import APIRouter, Depends, HTTPException, Request, status +from motor.motor_asyncio import AsyncIOMotorClient from sqlalchemy.ext.asyncio import AsyncSession +from api.notes.data import MedicalNote +from api.notes.db import DB_NAME, get_database from api.users.auth import ( ACCESS_TOKEN_EXPIRE_MINUTES, authenticate_user, @@ -24,9 +28,72 @@ from api.users.utils import verify_password +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + router = APIRouter() +@router.get("/medical_notes/{patient_id}", response_model=List[MedicalNote]) +async def get_medical_notes( + patient_id: str, + db: AsyncIOMotorClient = Depends(get_database), + current_user: User = Depends(get_current_active_user), +) -> List[MedicalNote]: + """ + Retrieve medical notes for a specific patient. + + Parameters + ---------- + patient_id : str + The ID of the patient to fetch medical notes for. + db : AsyncIOMotorClient + The database client. + current_user : User + The authenticated user making the request. + + Returns + ------- + List[MedicalNote] + A list of medical notes for the specified patient. + + Raises + ------ + HTTPException + If no medical notes are found for the patient or if there's a database error. + """ + try: + patient_id_int = int(patient_id) + except ValueError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid patient ID format. Must be an integer." + ) + + try: + collection = db[DB_NAME].medical_notes + cursor = collection.find({"subject_id": patient_id_int}) + notes = await cursor.to_list(length=None) + + if not notes: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="No medical notes found for this patient" + ) + + return [MedicalNote(**note) for note in notes] + except HTTPException: + raise + except Exception as e: + logger.error(f"Error retrieving medical notes for patient ID {patient_id_int}: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An error occurred while retrieving medical notes" + ) + + @router.post("/auth/signin") async def signin( request: Request, diff --git a/backend/clinical_llm_service/api/main.py b/backend/clinical_llm_service/api/main.py index 8d5ad0a..d15fe0c 100644 --- a/backend/clinical_llm_service/api/main.py +++ b/backend/clinical_llm_service/api/main.py @@ -1,10 +1,12 @@ +import gc + import torch import transformers import uvicorn from fastapi import FastAPI, HTTPException from fastapi.responses import StreamingResponse from pydantic import BaseModel -import gc + # Initialize FastAPI app app = FastAPI() @@ -18,7 +20,7 @@ torch_dtype=torch.float16, device_map="auto", low_cpu_mem_usage=True, - offload_folder="offload" + offload_folder="offload", ) # Enable gradient checkpointing for memory efficiency @@ -34,11 +36,13 @@ # System prompt system_prompt = """You are an expert and experienced from the healthcare and biomedical domain with extensive medical knowledge and practical experience. Your name is adrenaline AI. who's willing to help answer the user's query with explanation. In your explanation, leverage your deep medical expertise such as relevant anatomical structures, physiological processes, diagnostic criteria, treatment guidelines, or other pertinent medical concepts. Use precise medical terminology while still aiming to make the explanation clear and accessible to a general audience.""" + # Pydantic model for request body class Query(BaseModel): prompt: str context: str = "" + @app.post("/generate") async def generate_text(query: Query): try: @@ -47,11 +51,19 @@ async def generate_text(query: Query): # Create a generator function for streaming def generate(): - inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=1024, padding=True) + inputs = tokenizer( + full_prompt, + return_tensors="pt", + truncation=True, + max_length=1024, + padding=True, + ) input_ids = inputs.input_ids.to(model.device) attention_mask = inputs.attention_mask.to(model.device) - streamer = transformers.TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=10.0) + streamer = transformers.TextIteratorStreamer( + tokenizer, skip_prompt=True, timeout=10.0 + ) generation_kwargs = dict( input_ids=input_ids, attention_mask=attention_mask, @@ -66,6 +78,7 @@ def generate(): # Start the generation in a separate thread from threading import Thread + thread = Thread(target=model.generate, kwargs=generation_kwargs) thread.start() @@ -80,5 +93,6 @@ def generate(): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) + if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=8003) diff --git a/backend/instruction_tune_dataset.ipynb b/backend/instruction_tune_dataset.ipynb deleted file mode 100644 index 230ba51..0000000 --- a/backend/instruction_tune_dataset.ipynb +++ /dev/null @@ -1,488 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "0d8ee493-bade-44ce-8fa6-5d81783ca9b1", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-08-27 21:47:45,046 \u001b[1;37mINFO\u001b[0m cycquery.orm - Database setup, ready to run queries!\n" - ] - }, - { - "data": { - "text/plain": [ - "['fhir_etl',\n", - " 'fhir_trm',\n", - " 'information_schema',\n", - " 'mimic_fhir',\n", - " 'mimiciv_derived',\n", - " 'mimiciv_ed',\n", - " 'mimiciv_hosp',\n", - " 'mimiciv_icu',\n", - " 'mimiciv_note',\n", - " 'public']" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import requests\n", - "import cycquery.ops as qo\n", - "from cycquery import MIMICIVQuerier\n", - "import httpx\n", - "import json\n", - "import asyncio\n", - "import nest_asyncio\n", - "import re\n", - "\n", - "\n", - "def preprocess_medical_note(note):\n", - " # Remove only excessive whitespace within lines, preserving newlines\n", - " lines = note.split('\\n')\n", - " cleaned_lines = [' '.join(line.split()) for line in lines]\n", - " return '\\n'.join(cleaned_lines)\n", - " \n", - "\n", - "querier = MIMICIVQuerier(\n", - " dbms=\"postgresql\",\n", - " port=5432,\n", - " host=\"localhost\",\n", - " database=\"mimiciv-2.0\",\n", - " user=\"postgres\",\n", - " password=\"pwd\",\n", - ")\n", - "# List all schemas.\n", - "querier.list_schemas()" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "e4c68607-e743-4b67-a774-50e13f5a6961", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-08-27 21:50:03,648 \u001b[1;37mINFO\u001b[0m cycquery.orm - Query returned successfully!\n", - "2024-08-27 21:50:03,649 \u001b[1;37mINFO\u001b[0m cycquery.utils.profile - Finished executing function run_query in 0.355598 s\n" - ] - } - ], - "source": [ - "# List all tables.\n", - "querier.list_tables(\"mimiciv_note\")\n", - "ops = qo.Sequential(qo.DropEmpty(\"text\"), qo.DropNulls(\"text\"))\n", - "notes = querier.mimiciv_note.discharge().ops(ops).run(limit=100)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "fdc2a79a-5de1-4d42-b7a0-f33cd9558b1f", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
note_idsubject_idhadm_idnote_typenote_seqcharttimestoretimetext
010000032-DS-211000003222595853DS212180-05-07 00:00:002180-05-09 15:26:00Name: ___ Unit No: __...
110000032-DS-221000003222841357DS222180-06-27 00:00:002180-07-01 10:15:00Name: ___ Unit No: __...
210000032-DS-231000003229079034DS232180-07-25 00:00:002180-07-25 21:42:00Name: ___ Unit No: __...
310000032-DS-241000003225742920DS242180-08-07 00:00:002180-08-10 05:43:00Name: ___ Unit No: __...
410000084-DS-171000008423052089DS172160-11-25 00:00:002160-11-25 15:09:00Name: ___ Unit No: ___...
...........................
9510002430-DS-71000243026295318DS72129-06-24 00:00:002129-06-24 16:51:00Name: ___ Unit No: ___\n", - " \n", - "Adm...
9610002443-DS-151000244321329021DS152183-10-20 00:00:002183-10-20 19:18:00Name: ___ Unit No: ...
9710002495-DS-131000249524982426DS132141-05-29 00:00:002141-05-30 02:29:00Name: ___ Unit No: ___\n", - " ...
9810002528-DS-171000252823193578DS172168-12-20 00:00:002168-12-20 17:28:00Name: ___ Unit No: ___\n", - " \n", - "...
9910002528-DS-181000252828605730DS182170-03-18 00:00:002170-03-18 16:39:00Name: ___ Unit No: ___\n", - " \n", - "...
\n", - "

100 rows × 8 columns

\n", - "
" - ], - "text/plain": [ - " note_id subject_id hadm_id note_type note_seq \\\n", - "0 10000032-DS-21 10000032 22595853 DS 21 \n", - "1 10000032-DS-22 10000032 22841357 DS 22 \n", - "2 10000032-DS-23 10000032 29079034 DS 23 \n", - "3 10000032-DS-24 10000032 25742920 DS 24 \n", - "4 10000084-DS-17 10000084 23052089 DS 17 \n", - ".. ... ... ... ... ... \n", - "95 10002430-DS-7 10002430 26295318 DS 7 \n", - "96 10002443-DS-15 10002443 21329021 DS 15 \n", - "97 10002495-DS-13 10002495 24982426 DS 13 \n", - "98 10002528-DS-17 10002528 23193578 DS 17 \n", - "99 10002528-DS-18 10002528 28605730 DS 18 \n", - "\n", - " charttime storetime \\\n", - "0 2180-05-07 00:00:00 2180-05-09 15:26:00 \n", - "1 2180-06-27 00:00:00 2180-07-01 10:15:00 \n", - "2 2180-07-25 00:00:00 2180-07-25 21:42:00 \n", - "3 2180-08-07 00:00:00 2180-08-10 05:43:00 \n", - "4 2160-11-25 00:00:00 2160-11-25 15:09:00 \n", - ".. ... ... \n", - "95 2129-06-24 00:00:00 2129-06-24 16:51:00 \n", - "96 2183-10-20 00:00:00 2183-10-20 19:18:00 \n", - "97 2141-05-29 00:00:00 2141-05-30 02:29:00 \n", - "98 2168-12-20 00:00:00 2168-12-20 17:28:00 \n", - "99 2170-03-18 00:00:00 2170-03-18 16:39:00 \n", - "\n", - " text \n", - "0 \n", - "Name: ___ Unit No: __... \n", - "1 \n", - "Name: ___ Unit No: __... \n", - "2 \n", - "Name: ___ Unit No: __... \n", - "3 \n", - "Name: ___ Unit No: __... \n", - "4 \n", - "Name: ___ Unit No: ___... \n", - ".. ... \n", - "95 \n", - "Name: ___ Unit No: ___\n", - " \n", - "Adm... \n", - "96 \n", - "Name: ___ Unit No: ... \n", - "97 \n", - "Name: ___ Unit No: ___\n", - " ... \n", - "98 \n", - "Name: ___ Unit No: ___\n", - " \n", - "... \n", - "99 \n", - "Name: ___ Unit No: ___\n", - " \n", - "... \n", - "\n", - "[100 rows x 8 columns]" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "notes" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8555774c-ac51-4e91-9340-e9944923fc83", - "metadata": {}, - "outputs": [], - "source": [ - "example_note = \"\"\"\n", - "Patient Name: Jane Smith\n", - "Age: 58\n", - "Gender: Female\n", - "Date of Visit: August 25, 2024\n", - "\n", - "Chief Complaint: Persistent headaches and dizziness for the past 3 weeks\n", - "\n", - "History of Present Illness:\n", - "Jane Smith, a 58-year-old female, presents with a 3-week history of persistent headaches and dizziness. The headaches are described as throbbing, primarily located in the frontal and temporal regions, and rate 7/10 on a pain scale. They occur daily, lasting for several hours, and are partially relieved by over-the-counter pain medications. The patient reports associated symptoms of nausea, photophobia, and phonophobia. The dizziness is described as a sensation of the room spinning, exacerbated by sudden movements, and lasting for a few minutes at a time.\n", - "\n", - "Past Medical History:\n", - "1. Hypertension (diagnosed 10 years ago)\n", - "2. Type 2 Diabetes Mellitus (diagnosed 5 years ago)\n", - "3. Hyperlipidemia\n", - "\n", - "Medications:\n", - "1. Lisinopril 20mg daily\n", - "2. Metformin 1000mg twice daily\n", - "3. Atorvastatin 40mg daily\n", - "4. Aspirin 81mg daily\n", - "\n", - "Allergies: Penicillin (rash)\n", - "\n", - "Social History:\n", - "- Non-smoker\n", - "- Occasional alcohol use (1-2 glasses of wine per week)\n", - "- Works as a high school teacher\n", - "- Lives with husband, two children have moved out\n", - "\n", - "Family History:\n", - "- Father: Myocardial infarction at age 65\n", - "- Mother: Stroke at age 70\n", - "- Sister: Migraine headaches\n", - "\n", - "Review of Systems:\n", - "- General: Reports fatigue and decreased appetite\n", - "- HEENT: Denies vision changes, ear pain, or sinus congestion\n", - "- Cardiovascular: Denies chest pain or palpitations\n", - "- Respiratory: Denies shortness of breath or cough\n", - "- Gastrointestinal: Reports occasional nausea, denies vomiting or abdominal pain\n", - "- Musculoskeletal: Denies joint pain or muscle weakness\n", - "- Neurological: Reports headaches and dizziness as described above\n", - "\n", - "Physical Examination:\n", - "- Vital Signs: BP 145/90, HR 78, RR 16, Temp 37.0°C, SpO2 98% on room air\n", - "- General: Alert and oriented, appears uncomfortable\n", - "- HEENT: PERRLA, EOM intact, no nystagmus, tympanic membranes clear bilaterally\n", - "- Cardiovascular: Regular rate and rhythm, no murmurs\n", - "- Respiratory: Clear to auscultation bilaterally\n", - "- Neurological: Cranial nerves II-XII intact, normal gait, negative Romberg test\n", - "\n", - "Assessment and Plan:\n", - "1. Persistent headaches with associated symptoms suggestive of migraine\n", - " - Order MRI brain to rule out intracranial pathology\n", - " - Start sumatriptan 50mg as needed for acute attacks\n", - " - Discuss lifestyle modifications and headache diary\n", - "2. Vertigo, possibly peripheral in origin\n", - " - Refer to ENT for further evaluation\n", - " - Consider vestibular rehabilitation therapy\n", - "3. Hypertension, currently uncontrolled\n", - " - Increase lisinopril to 40mg daily\n", - " - Encourage home blood pressure monitoring\n", - " - Follow up in 2 weeks to reassess\n", - "4. Type 2 Diabetes Mellitus and Hyperlipidemia\n", - " - Continue current management\n", - " - Order HbA1c and lipid panel for next visit\n", - "\n", - "Follow-up: Schedule follow-up appointment in 2 weeks to reassess headaches, dizziness, and blood pressure control. Advise patient to return sooner if symptoms worsen or new symptoms develop.\n", - "\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "32e23fe8-ad85-4ff8-8052-f6e655f3c70d", - "metadata": {}, - "outputs": [], - "source": [ - "import httpx\n", - "import asyncio\n", - "from IPython import get_ipython\n", - "\n", - "\n", - "async def generate_text():\n", - " url = 'http://localhost:8003/generate'\n", - " input_prompt = 'I am trying to curate a careful list of questions and answers that can be used to train a LLM that can understand EHR data and answer questions based on the data. Use the medical note to create question answer pairs. For example: What medications do the patient take? Provide the output as Question: \\n Answer: format.'\n", - "\n", - " payload = {\n", - " 'prompt': input_prompt,\n", - " 'context': preprocess_medical_note(example_note)\n", - " }\n", - "\n", - " timeout = httpx.Timeout(30.0, connect=60.0)\n", - " async with httpx.AsyncClient(timeout=timeout) as client:\n", - " try:\n", - " async with client.stream('POST', url, json=payload) as response:\n", - " if response.status_code == 200:\n", - " print(\"Response:\")\n", - " full_response = \"\"\n", - " async for chunk in response.aiter_text():\n", - " print(chunk, end='', flush=True)\n", - " full_response += chunk\n", - " print(\"\\n\\nFull Response:\", full_response)\n", - " else:\n", - " print('Error:', response.status_code, await response.text())\n", - " except httpx.ReadTimeout:\n", - " print(\"The request timed out. The server is taking too long to respond.\")\n", - " except httpx.ConnectTimeout:\n", - " print(\"Failed to connect to the server. Make sure the server is running and accessible.\")\n", - " except httpx.RequestError as exc:\n", - " print(f\"An error occurred while requesting {exc.request.url!r}.\")\n", - " except Exception as e:\n", - " print(f\"An unexpected error occurred: {str(e)}\")\n", - "\n", - "def run_async_code():\n", - " ipython = get_ipython()\n", - " if ipython is None:\n", - " # We're not in IPython, use the default event loop\n", - " asyncio.run(generate_text())\n", - " else:\n", - " # We're in IPython, use the IPython event loop\n", - " import nest_asyncio\n", - " nest_asyncio.apply()\n", - " asyncio.get_event_loop().run_until_complete(generate_text())\n", - "\n", - "# This allows the script to be imported without running the code\n", - "if __name__ == \"__main__\":\n", - " run_async_code()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.8" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/backend/poetry.lock b/backend/poetry.lock index 88b23fb..c5f6cf4 100644 --- a/backend/poetry.lock +++ b/backend/poetry.lock @@ -1681,6 +1681,30 @@ files = [ {file = "mistune-3.0.2.tar.gz", hash = "sha256:fc7f93ded930c92394ef2cb6f04a8aabab4117a91449e72dcc8dfa646a508be8"}, ] +[[package]] +name = "motor" +version = "3.5.1" +description = "Non-blocking MongoDB driver for Tornado or asyncio" +optional = false +python-versions = ">=3.8" +files = [ + {file = "motor-3.5.1-py3-none-any.whl", hash = "sha256:f95a9ea0f011464235e0bd72910baa291db3a6009e617ac27b82f57885abafb8"}, + {file = "motor-3.5.1.tar.gz", hash = "sha256:1622bd7b39c3e6375607c14736f6e1d498128eadf6f5f93f8786cf17d37062ac"}, +] + +[package.dependencies] +pymongo = ">=4.5,<5" + +[package.extras] +aws = ["pymongo[aws] (>=4.5,<5)"] +docs = ["aiohttp", "readthedocs-sphinx-search (>=0.3,<1.0)", "sphinx (>=5.3,<8)", "sphinx-rtd-theme (>=2,<3)", "tornado"] +encryption = ["pymongo[encryption] (>=4.5,<5)"] +gssapi = ["pymongo[gssapi] (>=4.5,<5)"] +ocsp = ["pymongo[ocsp] (>=4.5,<5)"] +snappy = ["pymongo[snappy] (>=4.5,<5)"] +test = ["aiohttp (!=3.8.6)", "mockupdb", "pymongo[encryption] (>=4.5,<5)", "pytest (>=7)", "tornado (>=5)"] +zstd = ["pymongo[zstd] (>=4.5,<5)"] + [[package]] name = "msgpack" version = "1.0.8" @@ -2347,6 +2371,78 @@ dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pyte docs = ["sphinx", "sphinx-rtd-theme", "zope.interface"] tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] +[[package]] +name = "pymongo" +version = "4.8.0" +description = "Python driver for MongoDB " +optional = false +python-versions = ">=3.8" +files = [ + {file = "pymongo-4.8.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f2b7bec27e047e84947fbd41c782f07c54c30c76d14f3b8bf0c89f7413fac67a"}, + {file = "pymongo-4.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3c68fe128a171493018ca5c8020fc08675be130d012b7ab3efe9e22698c612a1"}, + {file = "pymongo-4.8.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:920d4f8f157a71b3cb3f39bc09ce070693d6e9648fb0e30d00e2657d1dca4e49"}, + {file = "pymongo-4.8.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:52b4108ac9469febba18cea50db972605cc43978bedaa9fea413378877560ef8"}, + {file = "pymongo-4.8.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:180d5eb1dc28b62853e2f88017775c4500b07548ed28c0bd9c005c3d7bc52526"}, + {file = "pymongo-4.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aec2b9088cdbceb87e6ca9c639d0ff9b9d083594dda5ca5d3c4f6774f4c81b33"}, + {file = "pymongo-4.8.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d0cf61450feadca81deb1a1489cb1a3ae1e4266efd51adafecec0e503a8dcd84"}, + {file = "pymongo-4.8.0-cp310-cp310-win32.whl", hash = "sha256:8b18c8324809539c79bd6544d00e0607e98ff833ca21953df001510ca25915d1"}, + {file = "pymongo-4.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:e5df28f74002e37bcbdfdc5109799f670e4dfef0fb527c391ff84f078050e7b5"}, + {file = "pymongo-4.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6b50040d9767197b77ed420ada29b3bf18a638f9552d80f2da817b7c4a4c9c68"}, + {file = "pymongo-4.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:417369ce39af2b7c2a9c7152c1ed2393edfd1cbaf2a356ba31eb8bcbd5c98dd7"}, + {file = "pymongo-4.8.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf821bd3befb993a6db17229a2c60c1550e957de02a6ff4dd0af9476637b2e4d"}, + {file = "pymongo-4.8.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9365166aa801c63dff1a3cb96e650be270da06e3464ab106727223123405510f"}, + {file = "pymongo-4.8.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cc8b8582f4209c2459b04b049ac03c72c618e011d3caa5391ff86d1bda0cc486"}, + {file = "pymongo-4.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:16e5019f75f6827bb5354b6fef8dfc9d6c7446894a27346e03134d290eb9e758"}, + {file = "pymongo-4.8.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3b5802151fc2b51cd45492c80ed22b441d20090fb76d1fd53cd7760b340ff554"}, + {file = "pymongo-4.8.0-cp311-cp311-win32.whl", hash = "sha256:4bf58e6825b93da63e499d1a58de7de563c31e575908d4e24876234ccb910eba"}, + {file = "pymongo-4.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:b747c0e257b9d3e6495a018309b9e0c93b7f0d65271d1d62e572747f4ffafc88"}, + {file = "pymongo-4.8.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e6a720a3d22b54183352dc65f08cd1547204d263e0651b213a0a2e577e838526"}, + {file = "pymongo-4.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:31e4d21201bdf15064cf47ce7b74722d3e1aea2597c6785882244a3bb58c7eab"}, + {file = "pymongo-4.8.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6b804bb4f2d9dc389cc9e827d579fa327272cdb0629a99bfe5b83cb3e269ebf"}, + {file = "pymongo-4.8.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f2fbdb87fe5075c8beb17a5c16348a1ea3c8b282a5cb72d173330be2fecf22f5"}, + {file = "pymongo-4.8.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cd39455b7ee70aabee46f7399b32ab38b86b236c069ae559e22be6b46b2bbfc4"}, + {file = "pymongo-4.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:940d456774b17814bac5ea7fc28188c7a1338d4a233efbb6ba01de957bded2e8"}, + {file = "pymongo-4.8.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:236bbd7d0aef62e64caf4b24ca200f8c8670d1a6f5ea828c39eccdae423bc2b2"}, + {file = "pymongo-4.8.0-cp312-cp312-win32.whl", hash = "sha256:47ec8c3f0a7b2212dbc9be08d3bf17bc89abd211901093e3ef3f2adea7de7a69"}, + {file = "pymongo-4.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:e84bc7707492f06fbc37a9f215374d2977d21b72e10a67f1b31893ec5a140ad8"}, + {file = "pymongo-4.8.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:519d1bab2b5e5218c64340b57d555d89c3f6c9d717cecbf826fb9d42415e7750"}, + {file = "pymongo-4.8.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:87075a1feb1e602e539bdb1ef8f4324a3427eb0d64208c3182e677d2c0718b6f"}, + {file = "pymongo-4.8.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77f53429515d2b3e86dcc83dadecf7ff881e538c168d575f3688698a8707b80a"}, + {file = "pymongo-4.8.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fdc20cd1e1141b04696ffcdb7c71e8a4a665db31fe72e51ec706b3bdd2d09f36"}, + {file = "pymongo-4.8.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:284d0717d1a7707744018b0b6ee7801b1b1ff044c42f7be7a01bb013de639470"}, + {file = "pymongo-4.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5bf0eb8b6ef40fa22479f09375468c33bebb7fe49d14d9c96c8fd50355188b0"}, + {file = "pymongo-4.8.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2ecd71b9226bd1d49416dc9f999772038e56f415a713be51bf18d8676a0841c8"}, + {file = "pymongo-4.8.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e0061af6e8c5e68b13f1ec9ad5251247726653c5af3c0bbdfbca6cf931e99216"}, + {file = "pymongo-4.8.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:658d0170f27984e0d89c09fe5c42296613b711a3ffd847eb373b0dbb5b648d5f"}, + {file = "pymongo-4.8.0-cp38-cp38-win32.whl", hash = "sha256:3ed1c316718a2836f7efc3d75b4b0ffdd47894090bc697de8385acd13c513a70"}, + {file = "pymongo-4.8.0-cp38-cp38-win_amd64.whl", hash = "sha256:7148419eedfea9ecb940961cfe465efaba90595568a1fb97585fb535ea63fe2b"}, + {file = "pymongo-4.8.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:e8400587d594761e5136a3423111f499574be5fd53cf0aefa0d0f05b180710b0"}, + {file = "pymongo-4.8.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:af3e98dd9702b73e4e6fd780f6925352237f5dce8d99405ff1543f3771201704"}, + {file = "pymongo-4.8.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de3a860f037bb51f968de320baef85090ff0bbb42ec4f28ec6a5ddf88be61871"}, + {file = "pymongo-4.8.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0fc18b3a093f3db008c5fea0e980dbd3b743449eee29b5718bc2dc15ab5088bb"}, + {file = "pymongo-4.8.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:18c9d8f975dd7194c37193583fd7d1eb9aea0c21ee58955ecf35362239ff31ac"}, + {file = "pymongo-4.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:408b2f8fdbeca3c19e4156f28fff1ab11c3efb0407b60687162d49f68075e63c"}, + {file = "pymongo-4.8.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b6564780cafd6abeea49759fe661792bd5a67e4f51bca62b88faab497ab5fe89"}, + {file = "pymongo-4.8.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d18d86bc9e103f4d3d4f18b85a0471c0e13ce5b79194e4a0389a224bb70edd53"}, + {file = "pymongo-4.8.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:9097c331577cecf8034422956daaba7ec74c26f7b255d718c584faddd7fa2e3c"}, + {file = "pymongo-4.8.0-cp39-cp39-win32.whl", hash = "sha256:d5428dbcd43d02f6306e1c3c95f692f68b284e6ee5390292242f509004c9e3a8"}, + {file = "pymongo-4.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:ef7225755ed27bfdb18730c68f6cb023d06c28f2b734597480fb4c0e500feb6f"}, + {file = "pymongo-4.8.0.tar.gz", hash = "sha256:454f2295875744dc70f1881e4b2eb99cdad008a33574bc8aaf120530f66c0cde"}, +] + +[package.dependencies] +dnspython = ">=1.16.0,<3.0.0" + +[package.extras] +aws = ["pymongo-auth-aws (>=1.1.0,<2.0.0)"] +docs = ["furo (==2023.9.10)", "readthedocs-sphinx-search (>=0.3,<1.0)", "sphinx (>=5.3,<8)", "sphinx-rtd-theme (>=2,<3)", "sphinxcontrib-shellcheck (>=1,<2)"] +encryption = ["certifi", "pymongo-auth-aws (>=1.1.0,<2.0.0)", "pymongocrypt (>=1.6.0,<2.0.0)"] +gssapi = ["pykerberos", "winkerberos (>=0.5.0)"] +ocsp = ["certifi", "cryptography (>=2.5)", "pyopenssl (>=17.2.0)", "requests (<3.0.0)", "service-identity (>=18.1.0)"] +snappy = ["python-snappy"] +test = ["pytest (>=7)"] +zstd = ["zstandard"] + [[package]] name = "pyparsing" version = "3.1.4" @@ -3631,4 +3727,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = ">=3.11, <3.12" -content-hash = "15debef560f218caa646fc90838c8e9aa2e91588762d115626759af19acf178a" +content-hash = "2d42bd2ddcc99462e465d6789a775d9b8a479960ecb9a607979f809717963773" diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 7b13e21..7761a80 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -21,6 +21,7 @@ passlib = "^1.7.4" pyjwt = "^2.9.0" sqlalchemy = { version = "^2.0.32", extras = ["asyncio"] } aiosqlite = "^0.20.0" +motor = "^3.5.1" [tool.poetry.group.jupyterlab] optional = true diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml index f174f0e..21baa61 100644 --- a/docker-compose.dev.yml +++ b/docker-compose.dev.yml @@ -38,7 +38,10 @@ services: - BACKEND_PORT=${BACKEND_PORT} - FRONTEND_PORT=${FRONTEND_PORT} - JWT_SECRET_KEY=${JWT_SECRET_KEY} - - MONGO_URI=mongodb://${MONGO_USERNAME}:${MONGO_PASSWORD}@mongodb:27017 + - MONGO_USERNAME=${MONGO_USERNAME} + - MONGO_PASSWORD=${MONGO_PASSWORD} + - MONGO_HOST=mongodb + - MONGO_PORT=27017 volumes: - ./backend:/app networks: diff --git a/frontend/src/app/context/endpoint.tsx b/frontend/src/app/context/endpoint.tsx deleted file mode 100644 index 11da025..0000000 --- a/frontend/src/app/context/endpoint.tsx +++ /dev/null @@ -1,178 +0,0 @@ -'use client' - -import React, { createContext, useState, useContext, ReactNode, useCallback, useMemo, useEffect } from 'react'; -import { EndpointConfig } from '../configure/types/configure'; -import { ModelFacts } from '../configure/types/facts'; -import { useAuth } from './auth'; - -interface Endpoint { - name: string; - metrics: string[]; - models: string[]; -} - -interface EndpointContextType { - endpoints: Endpoint[]; - addEndpoint: (config: EndpointConfig) => Promise; - removeEndpoint: (name: string) => Promise; - addModelToEndpoint: (endpointName: string, modelName: string, modelVersion: string, isExistingModel: boolean) => Promise; - removeModelFromEndpoint: (endpointName: string, modelId: string) => Promise; - updateModelFacts: (modelId: string, modelFacts: ModelFacts) => Promise; - isLoading: boolean; -} - -const EndpointContext = createContext(undefined); - -export const useEndpointContext = () => { - const context = useContext(EndpointContext); - if (!context) { - throw new Error('useEndpointContext must be used within an EndpointProvider'); - } - return context; -}; - -export const EndpointProvider: React.FC<{ children: ReactNode }> = ({ children }) => { - const [endpoints, setEndpoints] = useState([]); - const [isLoading, setIsLoading] = useState(false); - const { getToken, isAuthenticated } = useAuth(); - - const apiRequest = useCallback(async (url: string, options: RequestInit = {}): Promise => { - const token = getToken(); - if (!token) { - throw new Error('No authentication token available'); - } - const response = await fetch(url, { - ...options, - headers: { - ...options.headers, - 'Authorization': `Bearer ${token}`, - }, - }); - if (!response.ok) { - throw new Error(`API request failed: ${response.statusText}`); - } - return response.json(); - }, [getToken]); - - const fetchEndpoints = useCallback(async () => { - if (!isAuthenticated()) return; - setIsLoading(true); - try { - const data = await apiRequest<{ endpoints: Endpoint[] }>('/api/endpoints'); - setEndpoints(data.endpoints); - } catch (error) { - console.error('Error fetching endpoints:', error); - } finally { - setIsLoading(false); - } - }, [apiRequest, isAuthenticated]); - - useEffect(() => { - fetchEndpoints(); - }, [fetchEndpoints]); - - const addEndpoint = useCallback(async (config: EndpointConfig) => { - setIsLoading(true); - try { - await apiRequest('/api/endpoints', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify(config), - }); - await fetchEndpoints(); - } catch (error) { - console.error('Error adding endpoint:', error); - throw error; - } finally { - setIsLoading(false); - } - }, [apiRequest, fetchEndpoints]); - - const removeEndpoint = useCallback(async (name: string) => { - setIsLoading(true); - try { - await apiRequest(`/api/endpoints/${name}`, { method: 'DELETE' }); - await fetchEndpoints(); - } catch (error) { - console.error('Error removing endpoint:', error); - throw error; - } finally { - setIsLoading(false); - } - }, [apiRequest, fetchEndpoints]); - - const addModelToEndpoint = useCallback(async (endpointName: string, modelName: string, modelVersion: string, isExistingModel: boolean) => { - setIsLoading(true); - try { - const endpoint = endpoints.find(e => e.name === endpointName); - if (endpoint) { - const isDuplicate = endpoint.models.some(modelId => { - const [name, version] = modelId.split('|'); - return name === modelName && version === modelVersion; - }); - - if (isDuplicate) { - throw new Error('A model with the same name and version already exists in this endpoint.'); - } - } - - await apiRequest(`/api/endpoints/${endpointName}/models`, { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ name: modelName, version: modelVersion, isExistingModel }), - }); - await fetchEndpoints(); - } catch (error) { - console.error('Error adding model to endpoint:', error); - throw error; - } finally { - setIsLoading(false); - } - }, [endpoints, apiRequest, fetchEndpoints]); - - const updateModelFacts = useCallback(async (modelId: string, modelFacts: ModelFacts) => { - setIsLoading(true); - try { - await apiRequest(`/api/models/${modelId}/facts`, { - method: 'PUT', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify(modelFacts), - }); - await fetchEndpoints(); - } catch (error) { - console.error('Error updating model facts:', error); - throw error; - } finally { - setIsLoading(false); - } - }, [apiRequest, fetchEndpoints]); - - const removeModelFromEndpoint = useCallback(async (endpointName: string, modelId: string) => { - setIsLoading(true); - try { - await apiRequest(`/api/endpoints/${endpointName}/models/${modelId}`, { method: 'DELETE' }); - await fetchEndpoints(); - } catch (error) { - console.error('Error removing model from endpoint:', error); - throw error; - } finally { - setIsLoading(false); - } - }, [apiRequest, fetchEndpoints]); - - const contextValue = useMemo(() => ({ - endpoints, - addEndpoint, - removeEndpoint, - addModelToEndpoint, - removeModelFromEndpoint, - updateModelFacts, - isLoading - }), [endpoints, addEndpoint, removeEndpoint, addModelToEndpoint, removeModelFromEndpoint, updateModelFacts, isLoading]); - - return ( - - {children} - - ); -}; diff --git a/frontend/src/app/context/model.tsx b/frontend/src/app/context/model.tsx deleted file mode 100644 index 418a800..0000000 --- a/frontend/src/app/context/model.tsx +++ /dev/null @@ -1,161 +0,0 @@ -'use client' - -import React, { createContext, useState, useContext, ReactNode, useCallback, useMemo, useEffect } from 'react'; -import { ModelFacts } from '../configure/types/facts'; -import { useAuth } from './auth'; - -interface ModelBasicInfo { - name: string; - version: string; -} - -interface ModelData { - id: string; - endpoints: string[]; - basic_info: ModelBasicInfo; - facts: ModelFacts | null; - overall_status: string; -} - -interface ModelContextType { - models: ModelData[]; - fetchModels: () => Promise; - getModelById: (id: string) => Promise; - updateModelFacts: (id: string, facts: ModelFacts) => Promise; - isLoading: boolean; -} - -const ModelContext = createContext(undefined); - -export const useModelContext = () => { - const context = useContext(ModelContext); - if (!context) { - throw new Error('useModelContext must be used within a ModelProvider'); - } - return context; -}; - -export const ModelProvider: React.FC<{ children: ReactNode }> = ({ children }) => { - const [models, setModels] = useState([]); - const [isLoading, setIsLoading] = useState(false); - const { getToken, isAuthenticated } = useAuth(); - - const apiRequest = useCallback(async (url: string, options: RequestInit = {}): Promise => { - const token = getToken(); - if (!token) { - throw new Error('No authentication token available'); - } - const response = await fetch(url, { - ...options, - headers: { - ...options.headers, - 'Authorization': `Bearer ${token}`, - }, - }); - if (!response.ok) { - throw new Error(`API request failed: ${response.statusText}`); - } - return response.json(); - }, [getToken]); - - const fetchModels = useCallback(async () => { - if (!isAuthenticated()) return; - setIsLoading(true); - try { - const data = await apiRequest>('/api/models'); - const modelArray = await Promise.all(Object.entries(data).map(async ([id, modelInfo]: [string, any]) => { - const safetyData = await apiRequest<{ overall_status: string }>(`/api/model/${id}/safety`); - return { - id, - ...modelInfo, - overall_status: safetyData.overall_status - }; - })); - setModels(modelArray); - } catch (error) { - console.error('Error fetching models:', error); - setModels([]); - } finally { - setIsLoading(false); - } - }, [apiRequest, isAuthenticated]); - - useEffect(() => { - fetchModels(); - }, [fetchModels]); - - const getModelById = useCallback(async (id: string): Promise => { - setIsLoading(true); - try { - const cachedModel = models.find(m => m.id === id); - if (cachedModel && cachedModel.facts) { - setIsLoading(false); - return cachedModel; - } - - const data = await apiRequest(`/api/models/${id}`); - const safetyData = await apiRequest<{ overall_status: string }>(`/api/model/${id}/safety`); - const factsData = await apiRequest(`/api/models/${id}/facts`); - - const newModel: ModelData = { - id, - ...data, - overall_status: safetyData.overall_status, - facts: factsData - }; - - setModels(prevModels => { - const index = prevModels.findIndex(m => m.id === id); - if (index !== -1) { - const updatedModels = [...prevModels]; - updatedModels[index] = newModel; - return updatedModels; - } - return [...prevModels, newModel]; - }); - - return newModel; - } catch (error) { - console.error('Error fetching model:', error); - throw error; - } finally { - setIsLoading(false); - } - }, [models, apiRequest]); - - const updateModelFacts = useCallback(async (id: string, facts: ModelFacts) => { - setIsLoading(true); - try { - await apiRequest(`/api/models/${id}/facts`, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify(facts), - }); - - setModels(prevModels => prevModels.map(model => - model.id === id ? { ...model, facts } : model - )); - } catch (error) { - console.error('Error updating model facts:', error); - throw error; - } finally { - setIsLoading(false); - } - }, [apiRequest]); - - const contextValue = useMemo(() => ({ - models, - fetchModels, - getModelById, - updateModelFacts, - isLoading - }), [models, fetchModels, getModelById, updateModelFacts, isLoading]); - - return ( - - {children} - - ); -}; diff --git a/frontend/src/app/home/page.tsx b/frontend/src/app/home/page.tsx index 1253869..f15828b 100644 --- a/frontend/src/app/home/page.tsx +++ b/frontend/src/app/home/page.tsx @@ -1,5 +1,5 @@ 'use client' -import React, { useState, useEffect } from 'react' +import React, { useState } from 'react' import { Box, Text, @@ -8,54 +8,102 @@ import { VStack, useColorModeValue, Button, - Textarea, Input, Container, Divider, + Table, + Thead, + Tbody, + Tr, + Th, + Td, + useToast, } from '@chakra-ui/react' import Sidebar from '../components/sidebar' import { withAuth } from '../components/with-auth' +interface MedicalNote { + note_id: string; + subject_id: number; + hadm_id: string; + text: string; +} + function HomePage() { - const [medicalNote, setMedicalNote] = useState('') - const [userPrompt, setUserPrompt] = useState('') - const [response, setResponse] = useState('') + const [patientId, setPatientId] = useState('') + const [medicalNotes, setMedicalNotes] = useState([]) + const [isLoading, setIsLoading] = useState(false) const bgColor = useColorModeValue('gray.50', 'gray.900') const cardBgColor = useColorModeValue('white', 'gray.800') const textColor = useColorModeValue('gray.800', 'gray.100') - const accentColor = useColorModeValue('blue.500', 'blue.300') - useEffect(() => { - // Load a sample medical note or fetch from an API - const sampleNote = "Patient presents with..." - setMedicalNote(sampleNote) - }, []) + const toast = useToast() - const handleSubmit = async () => { - // Combine the medical note and user prompt - const combinedPrompt = `Medical Note: ${medicalNote}\n\nUser Query: ${userPrompt}` + const loadMedicalNotes = async () => { + if (!patientId.trim()) { + toast({ + title: "Error", + description: "Please enter a patient ID", + status: "error", + duration: 3000, + isClosable: true, + }) + return + } + + // Validate that patientId is a number + if (isNaN(Number(patientId))) { + toast({ + title: "Error", + description: "Patient ID must be a number", + status: "error", + duration: 3000, + isClosable: true, + }) + return + } - // Send to LLM backend endpoint + setIsLoading(true) try { - const result = await sendToLLM(combinedPrompt) - setResponse(result) + const response = await fetch(`/api/medical_notes/${patientId}`, { + headers: { + 'Authorization': `Bearer ${localStorage.getItem('token')}`, + }, + }) + + if (!response.ok) { + if (response.status === 404) { + throw new Error('No medical notes found for this patient') + } else { + throw new Error('Failed to fetch medical notes') + } + } + + const data = await response.json() + setMedicalNotes(data) + toast({ + title: "Success", + description: "Medical notes loaded successfully", + status: "success", + duration: 3000, + isClosable: true, + }) } catch (error) { - console.error('Error querying LLM:', error) - setResponse('An error occurred while processing your request.') + console.error('Error loading medical notes:', error) + toast({ + title: "Error", + description: error.message || "An error occurred while loading medical notes", + status: "error", + duration: 3000, + isClosable: true, + }) + setMedicalNotes([]) + } finally { + setIsLoading(false) } } - const sendToLLM = async (prompt) => { - // Implement your LLM API call here - // This is a placeholder function - return new Promise((resolve) => { - setTimeout(() => { - resolve('This is a sample response from the LLM.') - }, 1000) - }) - } - return ( @@ -63,34 +111,48 @@ function HomePage() { - Clinical LLM Dataset Curation - Create and curate question-answer pairs for clinical LLMs. + Medical Notes Dashboard + Load and view medical notes for a specific patient. - Medical Note -