diff --git a/backend/app/api/v1/chat.py b/backend/app/api/v1/chat.py index dd3c31e..f6854e0 100644 --- a/backend/app/api/v1/chat.py +++ b/backend/app/api/v1/chat.py @@ -11,6 +11,7 @@ user_repository, ) from app.requests import chat_query +from app.utils import clean_text from app.vectorstore.chroma import ChromaDB from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel @@ -91,7 +92,10 @@ def chat(project_id: int, chat_request: ChatRequest, db: Session = Depends(get_d conversation_id = str(conversation.id) content = response["response"] + content_length = len(content) + clean_content = clean_text(content) text_references = [] + not_exact_matched_refs = [] for reference in response["references"]: sentence = reference["sentence"] @@ -102,28 +106,51 @@ def chat(project_id: int, chat_request: ChatRequest, db: Session = Depends(get_d doc_sent, doc_ids, doc_metadata = vectorstore.get_relevant_segments( original_sentence, - k=1, + k=5, num_surrounding_sentences=0, - # TODO: uncomment this when we fix the filename in the metadata - # metadata_filter={"filename": original_filename} + metadata_filter={"filename": original_filename}, ) - for sent, id, metadata in zip(doc_sent, doc_ids, doc_metadata): - index = content.find(sentence) - if index != -1: - text_reference = { - "asset_id": metadata["asset_id"], - "project_id": metadata["project_id"], - "page_number": metadata["page_number"], - "filename": original_filename, - "source": [sent], - "start": index, - "end": index + len(sentence), - } - text_references.append(text_reference) + # Search for exact match + best_match_index = 0 + + for index, sent in enumerate(doc_sent): + if clean_text(original_sentence) in clean_text(sent): + best_match_index = index + + metadata = doc_metadata[best_match_index] + sent = doc_sent[best_match_index] + + index = clean_content.find(clean_text(sentence)) + + if index != -1: + text_reference = { + "asset_id": metadata["asset_id"], + "project_id": 
metadata["project_id"], + "page_number": metadata["page_number"], + "filename": original_filename, + "source": [sent], + "start": index, + "end": index + len(sentence), + } + text_references.append(text_reference) + else: + no_exact_reference = { + "asset_id": metadata["asset_id"], + "project_id": metadata["project_id"], + "page_number": metadata["page_number"], + "filename": original_filename, + "source": [sent], + "start": 0, + "end": content_length, + } + not_exact_matched_refs.append(no_exact_reference) # group text references based on start and end - refs = group_by_start_end(text_references) + if len(text_references) > 0: + refs = group_by_start_end(text_references) + else: + refs = group_by_start_end(not_exact_matched_refs) conversation_repository.create_conversation_message( db, diff --git a/backend/app/api/v1/projects.py b/backend/app/api/v1/projects.py index 830ee49..2a95586 100644 --- a/backend/app/api/v1/projects.py +++ b/backend/app/api/v1/projects.py @@ -26,7 +26,9 @@ @project_router.post("/", status_code=201) def create_project(project: ProjectCreate, db: Session = Depends(get_db)): if not project.name.strip(): - raise HTTPException(status_code=400, detail="Project name is required and cannot be empty.") + raise HTTPException( + status_code=400, detail="Project name is required and cannot be empty." + ) db_project = project_repository.create_project(db=db, project=project) return { @@ -39,7 +41,7 @@ def create_project(project: ProjectCreate, db: Session = Depends(get_db)): @project_router.get("/") def get_projects( page: int = Query(1, ge=1), - page_size: int = Query(20, ge=1, le=100), + page_size: int = Query(100, ge=1, le=100), db: Session = Depends(get_db), ): try: @@ -67,7 +69,10 @@ def get_projects( } except Exception: logger.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail="An internal server error occurred while processing your request. 
Please try again later.") + raise HTTPException( + status_code=500, + detail="An internal server error occurred while processing your request. Please try again later.", + ) @project_router.get("/{id}") @@ -75,7 +80,9 @@ def get_project(id: int, db: Session = Depends(get_db)): try: project = project_repository.get_project(db=db, project_id=id) if project is None: - raise HTTPException(status_code=404, detail="The requested project could not be found.") + raise HTTPException( + status_code=404, detail="The requested project could not be found." + ) return { "status": "success", @@ -92,7 +99,10 @@ def get_project(id: int, db: Session = Depends(get_db)): raise except Exception: logger.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail="An internal server error occurred while retrieving the project. Please try again later.") + raise HTTPException( + status_code=500, + detail="An internal server error occurred while retrieving the project. Please try again later.", + ) @project_router.get("/{id}/assets") @@ -127,7 +137,10 @@ def get_assets( } except Exception: logger.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail="An internal server error occurred while retrieving assets. Please try again later.") + raise HTTPException( + status_code=500, + detail="An internal server error occurred while retrieving assets. Please try again later.", + ) @project_router.post("/{id}/assets") @@ -137,7 +150,9 @@ async def upload_files( try: project = project_repository.get_project(db=db, project_id=id) if project is None: - raise HTTPException(status_code=404, detail="The specified project could not be found.") + raise HTTPException( + status_code=404, detail="The specified project could not be found." 
+ ) # Ensure the upload directory exists os.makedirs(os.path.join(settings.upload_dir, str(id)), exist_ok=True) @@ -147,14 +162,15 @@ async def upload_files( # Check if the uploaded file is a PDF if file.content_type != "application/pdf": raise HTTPException( - status_code=400, detail=f"The file '{file.filename}' is not a valid PDF. Please upload only PDF files." + status_code=400, + detail=f"The file '{file.filename}' is not a valid PDF. Please upload only PDF files.", ) # Check if the file size is greater than 20MB if file.size > settings.MAX_FILE_SIZE: raise HTTPException( status_code=400, - detail=f"The file '{file.filename}' exceeds the maximum allowed size of 20MB. Please upload a smaller file." + detail=f"The file '{file.filename}' exceeds the maximum allowed size of 20MB. Please upload a smaller file.", ) # Generate a secure filename @@ -190,7 +206,10 @@ async def upload_files( raise except Exception: logger.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail="An error occurred while uploading files. Please try again later.") + raise HTTPException( + status_code=500, + detail="An error occurred while uploading files. Please try again later.", + ) @project_router.post("/{id}/assets/url") @@ -199,14 +218,22 @@ async def add_url_asset(id: int, data: UrlAssetCreate, db: Session = Depends(get urls = data.url project = project_repository.get_project(db=db, project_id=id) if project is None: - raise HTTPException(status_code=404, detail="The specified project could not be found.") + raise HTTPException( + status_code=404, detail="The specified project could not be found." + ) if not urls: - raise HTTPException(status_code=400, detail="No URLs provided. Please provide at least one valid URL.") + raise HTTPException( + status_code=400, + detail="No URLs provided. Please provide at least one valid URL.", + ) for url in urls: if not is_valid_url(url): - raise HTTPException(status_code=400, detail=f"Invalid URL format: {url}. 
Please provide a valid URL.") + raise HTTPException( + status_code=400, + detail=f"Invalid URL format: {url}. Please provide a valid URL.", + ) url_assets = [] for url in urls: @@ -244,7 +271,10 @@ async def add_url_asset(id: int, data: UrlAssetCreate, db: Session = Depends(get raise except Exception: logger.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail="An error occurred while processing the URL asset. Please try again later.") + raise HTTPException( + status_code=500, + detail="An error occurred while processing the URL asset. Please try again later.", + ) @project_router.get("/{id}/assets/{asset_id}") @@ -254,14 +284,18 @@ async def get_file(asset_id: int, db: Session = Depends(get_db)): if asset is None: raise HTTPException( - status_code=404, detail="The requested file could not be found in the database." + status_code=404, + detail="The requested file could not be found in the database.", ) filepath = asset.path # Check if the file exists if not os.path.isfile(filepath): - raise HTTPException(status_code=404, detail="The requested file could not be found on the server.") + raise HTTPException( + status_code=404, + detail="The requested file could not be found on the server.", + ) # Return the file return FileResponse( @@ -272,7 +306,10 @@ async def get_file(asset_id: int, db: Session = Depends(get_db)): raise except Exception as e: logger.error(f"Error retrieving file: {str(e)}\n{traceback.format_exc()}") - raise HTTPException(status_code=500, detail="An error occurred while retrieving the file. Please try again later.") + raise HTTPException( + status_code=500, + detail="An error occurred while retrieving the file. 
Please try again later.", + ) @project_router.get("/{id}/processes") @@ -317,7 +354,9 @@ def update_project(id: int, project: ProjectUpdate, db: Session = Depends(get_db try: db_project = project_repository.get_project(db=db, project_id=id) if db_project is None: - raise HTTPException(status_code=404, detail="The specified project could not be found.") + raise HTTPException( + status_code=404, detail="The specified project could not be found." + ) updated_project = project_repository.update_project( db=db, project_id=id, project=project ) @@ -330,7 +369,10 @@ def update_project(id: int, project: ProjectUpdate, db: Session = Depends(get_db raise except Exception: logger.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail="An error occurred while updating the project. Please try again later.") + raise HTTPException( + status_code=500, + detail="An error occurred while updating the project. Please try again later.", + ) @project_router.delete("/{project_id}") @@ -338,7 +380,9 @@ async def delete_project(project_id: int, db: Session = Depends(get_db)): try: project = project_repository.get_project(db, project_id) if not project: - raise HTTPException(status_code=404, detail="The specified project could not be found.") + raise HTTPException( + status_code=404, detail="The specified project could not be found." + ) project.deleted_at = datetime.now(tz=timezone.utc) @@ -359,7 +403,10 @@ async def delete_project(project_id: int, db: Session = Depends(get_db)): except Exception: logger.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail="An error occurred while deleting the project. Please try again later.") + raise HTTPException( + status_code=500, + detail="An error occurred while deleting the project. 
Please try again later.", + ) @project_router.delete("/{project_id}/assets/{asset_id}") @@ -367,20 +414,28 @@ async def delete_asset(project_id: int, asset_id: int, db: Session = Depends(get try: project = project_repository.get_project(db, project_id) if not project: - raise HTTPException(status_code=404, detail="The specified project could not be found.") + raise HTTPException( + status_code=404, detail="The specified project could not be found." + ) asset = project_repository.get_asset(db, asset_id) if asset is None: - raise HTTPException(status_code=404, detail="The specified asset could not be found in the database.") + raise HTTPException( + status_code=404, + detail="The specified asset could not be found in the database.", + ) if asset.project_id != project_id: - raise HTTPException(status_code=400, detail="The specified asset does not belong to the given project.") + raise HTTPException( + status_code=400, + detail="The specified asset does not belong to the given project.", + ) # Store asset information before deletion asset_info = { "id": asset.id, "filename": asset.filename, - "project_id": asset.project_id + "project_id": asset.project_id, } try: @@ -393,8 +448,11 @@ async def delete_asset(project_id: int, asset_id: int, db: Session = Depends(get # Soft delete the asset asset.deleted_at = datetime.now(tz=timezone.utc) db.commit() - - logger.log("info", f"Asset {asset_info['id']} successfully marked as deleted in the database") + + logger.log( + "info", + f"Asset {asset_info['id']} successfully marked as deleted in the database", + ) return {"message": "Asset deleted successfully"} except HTTPException as http_error: @@ -404,4 +462,7 @@ async def delete_asset(project_id: int, asset_id: int, db: Session = Depends(get db.rollback() error_msg = f"An error occurred while deleting the asset: {str(e)}" logger.error(f"{error_msg}\n{traceback.format_exc()}") - raise HTTPException(status_code=500, detail="An error occurred while deleting the asset. 
Please try again later.") \ No newline at end of file + raise HTTPException( + status_code=500, + detail="An error occurred while deleting the asset. Please try again later.", + ) diff --git a/backend/app/processing/file_preprocessing.py b/backend/app/processing/file_preprocessing.py index 0d66ba6..514877b 100644 --- a/backend/app/processing/file_preprocessing.py +++ b/backend/app/processing/file_preprocessing.py @@ -22,10 +22,10 @@ def process_file(asset_id: int): file_preprocessor.submit(preprocess_file, asset_id) -def process_segmentation(project_id: int, asset_content_id: int, asset_file_name: str): +def process_segmentation(project_id: int, asset_id: int, asset_file_name: str): try: with SessionLocal() as db: - asset_content = project_repository.get_asset_content(db, asset_content_id) + asset_content = project_repository.get_asset_content(db, asset_id) # segmentation = extract_file_segmentation( # api_token=api_key, pdf_content=asset_content.content @@ -36,7 +36,7 @@ def process_segmentation(project_id: int, asset_content_id: int, asset_file_name docs=asset_content.content["content"], metadatas=[ { - "asset_id": asset_content.asset_id, + "asset_id": asset_id, "filename": asset_file_name, "project_id": project_id, "page_number": asset_content.content["page_number_data"][index], @@ -47,16 +47,16 @@ def process_segmentation(project_id: int, asset_content_id: int, asset_file_name project_repository.update_asset_content_status( db, - asset_content_id=asset_content_id, + asset_id=asset_id, status=AssetProcessingStatus.COMPLETED, ) except Exception as e: - logger.error(f"Error during segmentation for asset {asset_content_id}: {e}") + logger.error(f"Error during segmentation for asset {asset_id}: {e}") with SessionLocal() as db: project_repository.update_asset_content_status( db, - asset_content_id=asset_content_id, + asset_id=asset_id, status=AssetProcessingStatus.FAILED, ) @@ -117,7 +117,7 @@ def preprocess_file(asset_id: int): file_segmentation_executor.submit( 
process_segmentation, asset.project_id, - asset_content.id, + asset_content.asset_id, asset.filename, ) diff --git a/backend/app/processing/process_queue.py b/backend/app/processing/process_queue.py index 4bd3cca..36d3671 100644 --- a/backend/app/processing/process_queue.py +++ b/backend/app/processing/process_queue.py @@ -21,6 +21,7 @@ from app.logger import Logger import traceback +from app.utils import clean_text from app.vectorstore.chroma import ChromaDB @@ -68,7 +69,9 @@ def process_step_task( while retries < settings.max_retries and not success: try: if process.type == "extractive_summary": - data = extractive_summary_process(api_key, process, process_step, asset_content) + data = extractive_summary_process( + api_key, process, process_step, asset_content + ) if data["summary"]: summaries.append(data["summary"]) @@ -81,7 +84,9 @@ def process_step_task( elif process.type == "extract": # Handle non-extractive summary process - data = extract_process(api_key, process, process_step, asset_content) + data = extract_process( + api_key, process, process_step, asset_content + ) # Update process step output outside the expensive operations with SessionLocal() as db: @@ -197,6 +202,7 @@ def process_task(process_id: int): process.message = str(e) db.commit() + def handle_exceptions(func): @wraps(func) def wrapper(*args, **kwargs): @@ -209,8 +215,10 @@ def wrapper(*args, **kwargs): logger.error(f"Error in {func.__name__}: {str(e)}") logger.error(traceback.format_exc()) raise + return wrapper + @handle_exceptions def extractive_summary_process(api_key, process, process_step, asset_content): try: @@ -265,11 +273,12 @@ def extractive_summary_process(api_key, process, process_step, asset_content): @handle_exceptions def extract_process(api_key, process, process_step, asset_content): pdf_content = "" - vectorstore = ChromaDB( - f"panda-etl-{process.project_id}", similary_threshold=3 - ) + vectorstore = ChromaDB(f"panda-etl-{process.project_id}", similary_threshold=3) if ( - 
("multiple_fields" not in process.details or not process.details["multiple_fields"]) + ( + "multiple_fields" not in process.details + or not process.details["multiple_fields"] + ) and asset_content.content and asset_content.content.get("word_count", 0) > 500 ): @@ -284,6 +293,7 @@ def extract_process(api_key, process, process_step, asset_content): }, k=5, ) + for index, metadata in enumerate(relevant_docs["metadatas"][0]): segment_data = [relevant_docs["documents"][0][index]] if metadata["previous_sentence_id"] != -1: @@ -317,7 +327,8 @@ def extract_process(api_key, process, process_step, asset_content): for context in data["context"]: for sources in context: page_numbers = [] - for source in sources["sources"]: + for source_index, source in enumerate(sources["sources"]): + relevant_docs = vectorstore.get_relevant_docs( source, where={ @@ -326,12 +337,26 @@ def extract_process(api_key, process, process_step, asset_content): {"project_id": process.project_id}, ] }, - k=1, + k=5, ) + most_relevant_index = 0 + match = False + clean_source = clean_text(source) + # search for exact match Index + for index, relevant_doc in enumerate(relevant_docs["documents"][0]): + if clean_source in clean_text(relevant_doc): + most_relevant_index = index + match = True + + if not match and len(relevant_docs["documents"][0]) > 0: + sources["sources"][source_index] = relevant_docs["documents"][0][0] + if len(relevant_docs["metadatas"][0]) > 0: page_numbers.append( - relevant_docs["metadatas"][0][0]["page_number"] + relevant_docs["metadatas"][0][most_relevant_index][ + "page_number" + ] ) sources["page_numbers"] = page_numbers @@ -341,10 +366,13 @@ def extract_process(api_key, process, process_step, asset_content): "context": data["context"], } -def update_process_step_status(db, process_step, status, output=None, output_references=None): + +def update_process_step_status( + db, process_step, status, output=None, output_references=None +): """ Update the status of a process step. 
# Built once at import time instead of on every call: maps "\n" to a single
# space and every ASCII punctuation character to None (i.e. deletion), so
# clean_text needs only one translate pass over its input.
_CLEAN_TEXT_TABLE = str.maketrans(
    {"\n": " ", **dict.fromkeys(string.punctuation)}
)


def clean_text(text: str) -> str:
    """Normalize text for lenient substring comparison.

    Each newline is replaced by a single space (it is not removed) and every
    character in ``string.punctuation`` is dropped. Case and all other
    whitespace are left untouched.

    Because punctuation is deleted, the result can be shorter than the
    input: character offsets into the cleaned string do not line up with
    offsets into ``text``.

    Args:
        text: The string to normalize.

    Returns:
        The normalized copy of ``text``.
    """
    return text.translate(_CLEAN_TEXT_TABLE)