diff --git a/main.py b/main.py
index 29394f5..e547f1f 100644
--- a/main.py
+++ b/main.py
@@ -7,7 +7,7 @@
 from langchain.schema import Document
 from langchain_core.runnables.config import run_in_executor
 
-from models import DocumentModel, DocumentResponse
+from models import DocumentModel, DocumentResponse, StoreDocument
 from store import AsyncPgVector
 
 load_dotenv(find_dotenv())
@@ -200,8 +200,6 @@ async def store_data_in_vector_db(data, file_id, overwrite: bool = False) -> boo
             return {"message": "Documents exist. Overwrite not implemented.", "error": str(e)}
         return {"message": "An error occurred while adding documents.", "error": str(e)}
 
-
-
 def get_loader(filename: str, file_content_type: str, filepath: str):
     file_ext = filename.split(".")[-1].lower()
 
@@ -240,32 +238,26 @@ def get_loader(filename: str, file_content_type: str, filepath: str):
 
     return loader, known_type
 
-
 @app.post("/doc")
-async def store_doc(
-    filepath: str,
-    filename: str,
-    file_content_type: str,
-    file_id: str,
-):
+async def store_doc(document: StoreDocument):
     # Check if the file exists
-    if not os.path.exists(filepath):
+    if not os.path.exists(document.filepath):
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
             detail=ERROR_MESSAGES.FILE_NOT_FOUND(),
         )
 
     try:
-        loader, known_type = get_loader(filename, file_content_type, filepath)
+        loader, known_type = get_loader(document.filename, document.file_content_type, document.filepath)
         data = loader.load()
 
-        result = await store_data_in_vector_db(data, file_id)
+        result = await store_data_in_vector_db(data, document.file_id)
 
         if result:
             return {
                 "status": True,
-                "file_id": file_id,
-                "filename": filename,
+                "file_id": document.file_id,
+                "filename": document.filename,
                 "known_type": known_type,
             }
         else:
@@ -285,4 +277,3 @@ async def store_doc(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=ERROR_MESSAGES.DEFAULT(e),
         )
-
diff --git a/models.py b/models.py
index 464caeb..30441d8 100644
--- a/models.py
+++ b/models.py
@@ -1,9 +1,8 @@
 import hashlib
+from enum import Enum
 from typing import Optional
-
 from pydantic import BaseModel
-
 
 class DocumentResponse(BaseModel):
     page_content: str
     metadata: dict
@@ -16,3 +15,13 @@ class DocumentModel(BaseModel):
     def generate_digest(self):
         hash_obj = hashlib.md5(self.page_content.encode())
         return hash_obj.hexdigest()
+
+class StoreDocument(BaseModel):
+    filepath: str
+    filename: str
+    file_content_type: str
+    file_id: str
+
+class CleanupMethod(str, Enum):
+    incremental = "incremental"
+    full = "full"
\ No newline at end of file
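
Note on the API change: `store_doc` now takes a single Pydantic model (`StoreDocument`) instead of four loose `str` parameters. In FastAPI, scalar parameters not bound to the path arrive as query parameters, while a single model parameter is read from the JSON request body, so this is a breaking change for existing callers. A minimal client sketch of the new request shape, assuming the app is served at `http://localhost:8000` (hypothetical host/port) and using the `requests` library:

```python
import requests

# Hypothetical base URL; point this at wherever the FastAPI app runs.
BASE_URL = "http://localhost:8000"

# The four former query parameters now travel as a JSON body
# matching the StoreDocument model.
payload = {
    "filepath": "/tmp/example.pdf",  # must exist on the server's filesystem
    "filename": "example.pdf",
    "file_content_type": "application/pdf",
    "file_id": "doc-123",
}

response = requests.post(f"{BASE_URL}/doc", json=payload)
print(response.status_code, response.json())
```

Grouping the parameters into a model also buys request validation for free: a missing or mistyped field is rejected by FastAPI with a 422 before the handler body runs.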
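
On the new `CleanupMethod` enum: nothing else in this diff references it yet, but subclassing both `str` and `Enum` means its members compare equal to their plain-string values and serialize cleanly in JSON and OpenAPI schemas. A small sketch of what the `str` mixin buys, using only the enum as defined in `models.py`:

```python
from models import CleanupMethod

# Lookup by value validates against the allowed members.
method = CleanupMethod("incremental")
assert method is CleanupMethod.incremental

# The str mixin makes members interchangeable with plain strings.
assert CleanupMethod.full == "full"
assert "mode=" + CleanupMethod.full.value == "mode=full"

# An unknown value raises ValueError; when the enum is used as a
# FastAPI parameter type, this surfaces to the client as a 422.
try:
    CleanupMethod("partial")
except ValueError as err:
    print(err)  # 'partial' is not a valid CleanupMethod
```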