
Commit

[pre-commit.ci] Add auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Nov 1, 2024
1 parent a877363 commit 241b0fc
Showing 5 changed files with 92 additions and 59 deletions.
1 change: 0 additions & 1 deletion adrenaline/api/main.py
@@ -7,7 +7,6 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

-from api.patients.answer import initialize_llm
from api.patients.db import check_database_connection
from api.routes.answer import router as answer_router
from api.routes.auth import router as auth_router
30 changes: 17 additions & 13 deletions adrenaline/api/patients/rag.py
@@ -1,14 +1,15 @@
"""RAG for patients and cohort search."""

-import os
import asyncio
import logging
+import os
from typing import Any, Dict, List, Tuple

import chromadb
import httpx
from chromadb.config import Settings


COLLECTION_NAME = "patient_notes"
CHROMA_HOST = "localhost"
CHROMA_PORT = os.getenv("CHROMA_SERVICE_PORT", 8000)
@@ -71,10 +72,7 @@ def __init__(self, host: str, port: int):
self.port = port
self.collection_name = COLLECTION_NAME
self.client = chromadb.HttpClient(
-            Settings(
-                chroma_server_host=self.host,
-                chroma_server_http_port=self.port
-            )
+            Settings(chroma_server_host=self.host, chroma_server_http_port=self.port)
)
self.collection = None

@@ -83,7 +81,9 @@ def connect(self):
try:
self.collection = self.client.get_collection(self.collection_name)
except ValueError:
raise ValueError(f"Collection {self.collection_name} does not exist in ChromaDB")
raise ValueError(
f"Collection {self.collection_name} does not exist in ChromaDB"
)

def get_collection(self):
"""Get the collection."""
@@ -99,27 +99,29 @@ async def search(
) -> List[Dict[str, Any]]:
"""Retrieve the relevant notes from ChromaDB."""
collection = self.get_collection()

where_clause = {"patient_id": patient_id} if patient_id else None

results = await asyncio.to_thread(
collection.query,
query_embeddings=[query_vector],
n_results=top_k,
where=where_clause,
include=["metadatas", "distances"]
include=["metadatas", "distances"],
)

filtered_results = []
-        for idx, (metadata, distance) in enumerate(zip(results['metadatas'][0], results['distances'][0])):
+        for idx, (metadata, distance) in enumerate(
+            zip(results["metadatas"][0], results["distances"][0])
+        ):
result = {
"patient_id": metadata["patient_id"],
"note_id": metadata["note_id"],
"note_text": metadata["note_text"],
"note_type": metadata["note_type"],
"timestamp": metadata["timestamp"],
"encounter_id": metadata["encounter_id"],
"distance": 1 - distance # Convert distance to similarity score
"distance": 1 - distance, # Convert distance to similarity score
}
filtered_results.append(result)

@@ -130,7 +132,9 @@ async def cohort_search(
self, query_vector: List[float], top_k: int = 2
) -> List[Tuple[int, Dict[str, Any]]]:
"""Retrieve the cohort search results from ChromaDB."""
-        search_results = await self.search(query_vector, top_k=top_k * 2)  # Get more results initially
+        search_results = await self.search(
+            query_vector, top_k=top_k * 2
+        )  # Get more results initially

# Group results by patient_id and keep only the top result for each patient
patient_results = {}
@@ -250,4 +254,4 @@ async def retrieve_relevant_notes(
logger.info(f"Retrieved {len(search_results)} relevant notes")
for i, result in enumerate(search_results):
logger.info(f"Result {i+1}: Distance = {result['distance']}")
-    return search_results
\ No newline at end of file
+    return search_results
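
The rag.py edits above are formatter-style rewrites (line wrapping, double quotes, trailing commas) with no behavioral change. For orientation, here is a minimal caller sketch for the retrieval methods touched by these hunks; the manager construction, the exact search() parameter names beyond those visible above, and the example values are assumptions for illustration, not part of this commit.

# Hypothetical caller for the ChromaDB wrapper reformatted above.
# Assumes the wrapper exposes connect(), search(query_vector, patient_id=..., top_k=...),
# and cohort_search() as shown in the hunks; instance names and IDs are made up.
async def retrieve_for_patient(manager, query_vector):
    manager.connect()  # raises ValueError if the "patient_notes" collection is missing

    # Patient-scoped retrieval: where_clause becomes {"patient_id": 123}.
    notes = await manager.search(query_vector, patient_id=123, top_k=5)
    for note in notes:
        # "distance" has already been converted to a similarity score (1 - distance).
        print(note["note_id"], note["note_type"], round(note["distance"], 3))

    # Cohort search over-fetches (top_k * 2), then keeps the best note per patient.
    cohort = await manager.cohort_search(query_vector, top_k=2)
    print(f"{len(cohort)} patients in cohort")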
43 changes: 22 additions & 21 deletions scripts/fetch_pmc.py
@@ -3,54 +3,55 @@
import csv
from datetime import datetime, timedelta


def fetch_pmc_articles(search_terms, days=7, max_results=100):
base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/"

# Get the date range for the last week
end_date = datetime.now().strftime("%Y/%m/%d")
start_date = (datetime.now() - timedelta(days=days)).strftime("%Y/%m/%d")

# Construct the search query
search_query = f"({' OR '.join(search_terms)}) AND {start_date}:{end_date}[PDAT]"

# First, search for PMCIDs
search_url = f"{base_url}esearch.fcgi?db=pmc&term={search_query}&retmax={max_results}&usehistory=y"
search_response = requests.get(search_url)
search_root = ET.fromstring(search_response.content)

# Extract WebEnv and QueryKey
web_env = search_root.find("WebEnv").text
query_key = search_root.find("QueryKey").text

# Now, fetch the details for these PMCIDs
fetch_url = f"{base_url}efetch.fcgi?db=pmc&query_key={query_key}&WebEnv={web_env}&retmax={max_results}&retmode=xml"
fetch_response = requests.get(fetch_url)
fetch_root = ET.fromstring(fetch_response.content)

articles = []
for article in fetch_root.findall(".//article"):
pmcid = article.find(".//article-id[@pub-id-type='pmc']").text
title = article.find(".//article-title").text
abstract = article.find(".//abstract/p")
-        abstract_text = abstract.text if abstract is not None else "No abstract available"

-        articles.append({
-            "PMCID": pmcid,
-            "Title": title,
-            "Abstract": abstract_text
-        })

+        abstract_text = (
+            abstract.text if abstract is not None else "No abstract available"
+        )

+        articles.append({"PMCID": pmcid, "Title": title, "Abstract": abstract_text})

return articles


def save_to_csv(articles, filename):
-    with open(filename, 'w', newline='', encoding='utf-8') as csvfile:
-        fieldnames = ['PMCID', 'Title', 'Abstract']
+    with open(filename, "w", newline="", encoding="utf-8") as csvfile:
+        fieldnames = ["PMCID", "Title", "Abstract"]
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

writer.writeheader()
for article in articles:
writer.writerow(article)


if __name__ == "__main__":
search_terms = [
"medicine",
@@ -67,12 +68,12 @@ def save_to_csv(articles, filename):
"personalized medicine",
"clinical trials",
"medical imaging",
"genomics"
"genomics",
]

articles = fetch_pmc_articles(search_terms)

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"pmc_medical_articles_{timestamp}.csv"
save_to_csv(articles, filename)
print(f"Saved {len(articles)} articles to {filename}")
print(f"Saved {len(articles)} articles to {filename}")
75 changes: 52 additions & 23 deletions scripts/load_entities.py
@@ -10,14 +10,14 @@
from rich.console import Console
from rich.logging import RichHandler
from rich.progress import Progress, TaskID
-from pydantic import BaseModel, Field
+from pydantic import BaseModel

# Configure logging with rich
logging.basicConfig(
level="INFO",
format="%(message)s",
datefmt="[%X]",
-    handlers=[RichHandler(rich_tracebacks=True)]
+    handlers=[RichHandler(rich_tracebacks=True)],
)
logger = logging.getLogger("rich")
console = Console()
@@ -30,6 +30,7 @@
NER_SERVICE_URL = "http://localhost:8003/extract_entities"
NER_SERVICE_TIMEOUT = 300 # 5 minutes


class Entity(BaseModel):
pretty_name: str
cui: str
@@ -47,11 +48,13 @@ class Entity(BaseModel):
id: int
meta_anns: Dict[str, Any]


class NERResponse(BaseModel):
note_id: str
text: str
entities: List[Entity]


class DatabaseManager:
def __init__(self, mongo_uri: str, db_name: str):
self.client: AsyncIOMotorClient = AsyncIOMotorClient(mongo_uri)
@@ -69,23 +72,30 @@ async def ensure_indexes(self) -> None:
await self.patients_collection.create_indexes(indexes)

async def get_all_notes(self) -> List[Dict[str, Any]]:
-        cursor = self.patients_collection.aggregate([
-            {"$unwind": "$notes"},
-            {"$project": {
-                "patient_id": 1,
-                "note_id": "$notes.note_id",
-                "text": "$notes.text",
-                "entities_exist": {"$ifNull": ["$notes.entities", False]}
-            }}
-        ])
+        cursor = self.patients_collection.aggregate(
+            [
+                {"$unwind": "$notes"},
+                {
+                    "$project": {
+                        "patient_id": 1,
+                        "note_id": "$notes.note_id",
+                        "text": "$notes.text",
+                        "entities_exist": {"$ifNull": ["$notes.entities", False]},
+                    }
+                },
+            ]
+        )
return await cursor.to_list(length=None)

-    async def update_note_with_entities(self, patient_id: int, note_id: str, entities: List[Entity]) -> None:
+    async def update_note_with_entities(
+        self, patient_id: int, note_id: str, entities: List[Entity]
+    ) -> None:
await self.patients_collection.update_one(
{"patient_id": patient_id, "notes.note_id": note_id},
{"$set": {"notes.$.entities": [entity.dict() for entity in entities]}}
{"$set": {"notes.$.entities": [entity.dict() for entity in entities]}},
)


async def extract_entities(note_text: str, note_id: str) -> NERResponse:
async with httpx.AsyncClient(timeout=httpx.Timeout(NER_SERVICE_TIMEOUT)) as client:
try:
@@ -104,31 +114,45 @@ async def extract_entities(note_text: str, note_id: str) -> NERResponse:
logger.error("Request to clinical NER service timed out")
raise

-async def process_notes(db_manager: DatabaseManager, progress: Progress, task: TaskID, recreate: bool) -> None:

+async def process_notes(
+    db_manager: DatabaseManager, progress: Progress, task: TaskID, recreate: bool
+) -> None:
notes = await db_manager.get_all_notes()
total_notes = len(notes)
progress.update(task, total=total_notes)

for i, note in enumerate(notes):
if not recreate and note["entities_exist"]:
logger.info(f"Skipping note {note['note_id']} as entities already exist")
progress.update(task, advance=1, description=f"Skipped note {i+1}/{total_notes}")
progress.update(
task, advance=1, description=f"Skipped note {i+1}/{total_notes}"
)
continue

try:
ner_response = await extract_entities(note["text"], note["note_id"])
-            await db_manager.update_note_with_entities(note["patient_id"], note["note_id"], ner_response.entities)
-            progress.update(task, advance=1, description=f"Processed note {i+1}/{total_notes}")
+            await db_manager.update_note_with_entities(
+                note["patient_id"], note["note_id"], ner_response.entities
+            )
+            progress.update(
+                task, advance=1, description=f"Processed note {i+1}/{total_notes}"
+            )
except Exception as e:
logger.error(f"Error processing note {note['note_id']}: {str(e)}")
progress.update(task, advance=1, description=f"Error on note {i+1}/{total_notes}")
progress.update(
task, advance=1, description=f"Error on note {i+1}/{total_notes}"
)


async def main(recreate: bool) -> None:
start_time = datetime.now()
console.print("[bold green]Starting NER processing and database update...[/bold green]")
console.print(
"[bold green]Starting NER processing and database update...[/bold green]"
)

db_manager = DatabaseManager(MONGO_URI, DB_NAME)

with Progress() as progress:
index_task = progress.add_task("[cyan]Ensuring database indexes...", total=1)
await db_manager.ensure_indexes()
@@ -139,11 +163,16 @@ async def main(recreate: bool) -> None:

end_time = datetime.now()
duration = end_time - start_time
console.print(f"[bold green]NER processing and database update completed in {duration}[/bold green]")
console.print(
f"[bold green]NER processing and database update completed in {duration}[/bold green]"
)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Process clinical notes with NER")
parser.add_argument("--recreate", action="store_true", help="Recreate entities for all notes")
parser.add_argument(
"--recreate", action="store_true", help="Recreate entities for all notes"
)
args = parser.parse_args()

-    asyncio.run(main(args.recreate))
\ No newline at end of file
+    asyncio.run(main(args.recreate))
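
The aggregation and update reformatted above assume patient documents that embed their notes as an array. Below is a sketch of that assumed shape and of what get_all_notes() then yields; every field value is invented, and only the field names come from the pipeline and update filter shown in this diff (each returned dict also carries MongoDB's default _id, omitted here).

# Assumed document shape in patients_collection (values are made up).
patient_doc = {
    "patient_id": 42,
    "notes": [
        {"note_id": "n-001", "text": "Chest pain on exertion.", "entities": [{"pretty_name": "chest pain"}]},
        {"note_id": "n-002", "text": "Follow-up visit, symptoms resolved."},
    ],
}

# After $unwind + $project, get_all_notes() returns one dict per note:
expected_rows = [
    # entities already present -> truthy value, so the note is skipped unless --recreate is passed
    {"patient_id": 42, "note_id": "n-001", "text": "Chest pain on exertion.",
     "entities_exist": [{"pretty_name": "chest pain"}]},
    # no entities field -> $ifNull falls back to False, so the note is sent to the NER service
    {"patient_id": 42, "note_id": "n-002", "text": "Follow-up visit, symptoms resolved.",
     "entities_exist": False},
]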
2 changes: 1 addition & 1 deletion services/docker-compose.dev_gpu-services.yml
@@ -101,4 +101,4 @@ volumes:
networks:
services-network:
name: services-network
-    driver: bridge
\ No newline at end of file
+    driver: bridge
