diff --git a/api/routes.py b/api/routes.py index 64d52f3..d373fbb 100644 --- a/api/routes.py +++ b/api/routes.py @@ -5,15 +5,15 @@ import uvicorn from datastew import DataDictionarySource -from datastew.embedding import MPNetAdapter +from datastew.embedding import GPT4Adapter, MPNetAdapter from datastew.process.ols import OLSTerminologyImportTask from datastew.repository import WeaviateRepository -from datastew.repository.model import Terminology, Concept, Mapping +from datastew.repository.model import Concept, Mapping, Terminology from datastew.visualisation import get_plot_for_current_database_state -from fastapi import FastAPI, HTTPException, File, UploadFile +from fastapi import FastAPI, File, HTTPException, UploadFile from starlette.background import BackgroundTasks from starlette.middleware.cors import CORSMiddleware -from starlette.responses import RedirectResponse, HTMLResponse +from starlette.responses import HTMLResponse, RedirectResponse app = FastAPI( title="INDEX", @@ -200,11 +200,25 @@ async def get_closest_mappings_for_text(text: str, terminology_name: str = "SNOM # Endpoint to get mappings for a data dictionary source @app.post("/mappings/dict", tags=["mappings"], description="Get mappings for a data dictionary source.") -async def get_closest_mappings_for_dictionary(file: UploadFile = File(...), variable_field: str = 'variable', - description_field: str = 'description'): +async def get_closest_mappings_for_dictionary(file: UploadFile = File(...), + selected_model: str = "sentence-transformers/all-mpnet-base-v2", + selected_terminology: str = "SNOMED CT", + variable_field: str = "variable", + description_field: str = "description"): try: + if selected_model == "text-embedding-ada-002": + embedding_model = GPT4Adapter(selected_model) + elif selected_model == "sentence-transformers/all-mpnet-base-v2": + embedding_model = MPNetAdapter(selected_model) + else: + raise HTTPException(status_code=400, detail="Unsupported embedding model.") + # Determine file extension and create a temporary file with the correct extension - _, file_extension = os.path.splitext(file.filename) + if file.filename is not None: + file_extension = os.path.splitext(file.filename)[1].lower() + else: + raise HTTPException(status_code=400, detail="Invalid file type. The file must have a suffix.") + with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file: tmp_file.write(await file.read()) tmp_file_path = tmp_file.name @@ -219,15 +233,17 @@ async def get_closest_mappings_for_dictionary(file: UploadFile = File(...), vari variable = row['variable'] description = row['description'] embedding = embedding_model.get_embedding(description) - closest_mappings, similarities = repository.get_closest_mappings(embedding, limit=5) + closest_mappings = repository.get_terminology_and_model_specific_closest_mappings( + embedding, selected_terminology, selected_model, limit=5 + ) mappings_list = [] - for mapping, similarity in zip(closest_mappings, similarities): + for mapping, similarity in closest_mappings: concept = mapping.concept terminology = concept.terminology mappings_list.append({ "concept": { - "id": concept.concept_id, - "name": concept.name, + "id": concept.concept_identifier, + "name": concept.pref_label, "terminology": { "id": terminology.id, "name": terminology.name