From d63513a0a220d0a79b984443ab76348a79197757 Mon Sep 17 00:00:00 2001
From: Arslan Saleem
Date: Tue, 29 Oct 2024 11:17:29 +0100
Subject: [PATCH] feat[chat]: vectorize extraction result for improved chat
 content (#45)

* feat[chat]: vectorize extraction result for improved chat content

* feat[chat]: use references from extraction results as well
---
 backend/app/api/v1/chat.py                  |  18 +++
 backend/app/config.py                       |   4 +
 backend/app/processing/process_queue.py     |  48 +++++++-
 .../tests/processing/test_process_queue.py  | 107 +++++++++++++++++-
 4 files changed, 175 insertions(+), 2 deletions(-)

diff --git a/backend/app/api/v1/chat.py b/backend/app/api/v1/chat.py
index 357ee2a..ced0456 100644
--- a/backend/app/api/v1/chat.py
+++ b/backend/app/api/v1/chat.py
@@ -65,6 +65,24 @@ def chat(project_id: int, chat_request: ChatRequest, db: Session = Depends(get_d
 
     ordered_file_names = [doc_id_to_filename[doc_id] for doc_id in doc_ids]
 
+    extract_vectorstore = ChromaDB(f"panda-etl-extraction-{project_id}",
+                                   similarity_threshold=settings.chat_extraction_doc_threshold)
+
+    # Retrieve reference documents from the stored extraction results
+    extraction_docs = extract_vectorstore.get_relevant_docs(
+        chat_request.query,
+        k=settings.chat_extraction_max_docs
+    )
+
+    # Merge extraction references into the chat context, one entry per filename
+    for extraction_doc in extraction_docs["metadatas"][0]:
+        index = next((i for i, item in enumerate(ordered_file_names) if item == extraction_doc["filename"]), None)
+        if index is None:
+            ordered_file_names.append(extraction_doc["filename"])
+            docs.append(extraction_doc["reference"])
+        else:
+            docs[index] = f'{extraction_doc["reference"]}\n\n{docs[index]}'
+
     docs_formatted = [
         {"filename": filename, "quote": quote}
         for filename, quote in zip(ordered_file_names, docs)
diff --git a/backend/app/config.py b/backend/app/config.py
index ed49d66..df11117 100644
--- a/backend/app/config.py
+++ b/backend/app/config.py
@@ -26,6 +26,10 @@ class Settings(BaseSettings):
     openai_api_key: str = ""
     openai_embedding_model: str = "text-embedding-ada-002"
 
+    # Extraction references for chat
+    chat_extraction_doc_threshold: float = 0.5
+    chat_extraction_max_docs: int = 50
+
     class Config:
         env_file = ".env"
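A note on the chat.py hunk above: get_relevant_docs is assumed to return Chroma-style nested lists (one inner list of hits per query), which is why the loop indexes extraction_docs["metadatas"][0]. The sketch below replays that merge logic on hand-built data; the filenames and quotes are illustrative only, not part of the patch.

    # Chroma-style result: "metadatas" holds one list of hits per query
    extraction_docs = {
        "metadatas": [[
            {"filename": "report.pdf", "reference": "Revenue was $1.2M"},
            {"filename": "notes.pdf", "reference": "Founded in 2019"},
        ]]
    }

    ordered_file_names = ["report.pdf"]    # files already matched by the chat vectorstore
    docs = ["Total revenue grew 40% YoY"]  # and their quotes

    for extraction_doc in extraction_docs["metadatas"][0]:
        index = next((i for i, item in enumerate(ordered_file_names)
                      if item == extraction_doc["filename"]), None)
        if index is None:
            # Unseen file: append the filename and its extraction reference
            ordered_file_names.append(extraction_doc["filename"])
            docs.append(extraction_doc["reference"])
        else:
            # Known file: prepend the extraction reference to the existing quote
            docs[index] = f'{extraction_doc["reference"]}\n\n{docs[index]}'

    assert ordered_file_names == ["report.pdf", "notes.pdf"]
    assert docs[0] == "Revenue was $1.2M\n\nTotal revenue grew 40% YoY"

Files already present in the chat context get the extraction reference prepended to their quote; files that only appear in extraction results are appended with their reference as the quote.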
diff --git a/backend/app/processing/process_queue.py b/backend/app/processing/process_queue.py
index 4d96471..98b4be5 100644
--- a/backend/app/processing/process_queue.py
+++ b/backend/app/processing/process_queue.py
@@ -46,8 +46,9 @@ def process_step_task(
     # Initial DB operations (open and fetch relevant data)
     with SessionLocal() as db:
         process = process_repository.get_process(db, process_id)
+        project_id = process.project_id
         process_step = process_repository.get_process_step(db, process_step_id)
-
+        filename = process_step.asset.filename
         if process.status == ProcessStatus.STOPPED:
             return False  # Stop processing if the process is stopped
 
@@ -84,6 +85,15 @@
                 output_references=data["context"],
             )
 
+            # Vectorize the extraction result so chat can reference it
+            try:
+                vectorize_extraction_process_step(project_id=project_id,
+                                                  process_step_id=process_step_id,
+                                                  filename=filename,
+                                                  references=data["context"])
+            except Exception:
+                logger.error(f"Failed to vectorize extraction results for chat: {traceback.format_exc()}")
+
             success = True
 
         except CreditLimitExceededException:
@@ -361,3 +371,39 @@ def update_process_step_status(
     process_repository.update_process_step_status(
         db, process_step, status, output=output, output_references=output_references
     )
+
+def vectorize_extraction_process_step(project_id: int, process_step_id: int, filename: str, references: list) -> None:
+    # Vectorize extraction results and store them for chat retrieval
+    field_references = {}
+
+    # Concatenate the sources gathered for each extracted field
+    for extraction_references in references:
+        for extraction_reference in extraction_references:
+            sources = extraction_reference.get("sources", [])
+            if sources:
+                sources_catenated = "\n".join(sources)
+                field_references.setdefault(extraction_reference["name"], "")
+                field_references[extraction_reference["name"]] += (
+                    "\n" + sources_catenated if field_references[extraction_reference["name"]] else sources_catenated
+                )
+
+    # Only proceed if there are references to add
+    if not field_references:
+        return
+
+    # Initialize the per-project extraction vectorstore
+    vectorstore = ChromaDB(f"panda-etl-extraction-{project_id}")
+
+    docs = [f"{filename} {key}" for key in field_references]
+    metadatas = [
+        {
+            "project_id": project_id,
+            "process_step_id": process_step_id,
+            "filename": filename,
+            "reference": reference
+        }
+        for reference in field_references.values()
+    ]
+
+    # Add documents to the vectorstore
+    vectorstore.add_docs(docs=docs, metadatas=metadatas)
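A quick illustration of the data shape vectorize_extraction_process_step consumes: references mirrors data["context"], a list of lists where each inner dict carries a field name and its source snippets. The field names and filename below are hypothetical, and the loop is condensed from the function above.

    references = [
        [
            {"name": "invoice_total", "sources": ["Total: $500", "Amount due: $500"]},
            {"name": "invoice_total", "sources": ["Grand total $500"]},
        ],
        [
            {"name": "vendor", "sources": ["ACME Corp"]},
        ],
    ]

    field_references = {}
    for group in references:
        for ref in group:
            sources = ref.get("sources", [])
            if sources:
                joined = "\n".join(sources)
                existing = field_references.get(ref["name"], "")
                field_references[ref["name"]] = f"{existing}\n{joined}" if existing else joined

    # One vector document per field, keyed by "<filename> <field name>"
    docs = [f"invoice.pdf {key}" for key in field_references]

    assert docs == ["invoice.pdf invoice_total", "invoice.pdf vendor"]
    assert field_references["invoice_total"] == "Total: $500\nAmount due: $500\nGrand total $500"

Because documents are keyed by filename plus field name, repeated extractions of the same field collapse into a single vector entry whose metadata carries the concatenated sources.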
diff --git a/backend/tests/processing/test_process_queue.py b/backend/tests/processing/test_process_queue.py
index 07f269b..4aa6e5e 100644
--- a/backend/tests/processing/test_process_queue.py
+++ b/backend/tests/processing/test_process_queue.py
@@ -1,11 +1,12 @@
 from app.requests.schemas import ExtractFieldsResponse
 import pytest
-from unittest.mock import Mock, patch
+from unittest.mock import MagicMock, Mock, patch
 from app.processing.process_queue import (
     handle_exceptions,
     extract_process,
     update_process_step_status,
     find_best_match_for_short_reference,
+    vectorize_extraction_process_step,
 )
 from app.exceptions import CreditLimitExceededException
 from app.models import ProcessStepStatus
@@ -180,3 +181,107 @@ def test_chroma_db_initialization(mock_extract_data, mock_chroma):
     mock_chroma.assert_called_with(f"panda-etl-{process.project_id}", similarity_threshold=3)
 
     assert mock_chroma.call_count >= 1
+
+@patch('app.processing.process_queue.ChromaDB')
+def test_vectorize_extraction_process_step_single_reference(mock_chroma_db):
+    # Mock the ChromaDB instance
+    mock_vectorstore = MagicMock()
+    mock_chroma_db.return_value = mock_vectorstore
+
+    # Inputs
+    project_id = 123
+    process_step_id = 1
+    filename = "sample_file"
+    references = [
+        [
+            {"name": "field1", "sources": ["source1", "source2"]}
+        ]
+    ]
+
+    # Call the function under test
+    vectorize_extraction_process_step(project_id, process_step_id, filename, references)
+
+    # Expected docs and metadata added to ChromaDB
+    expected_docs = ["sample_file field1"]
+    expected_metadatas = [
+        {
+            "project_id": project_id,
+            "process_step_id": process_step_id,
+            "filename": filename,
+            "reference": "source1\nsource2"
+        }
+    ]
+
+    # Assertions
+    mock_vectorstore.add_docs.assert_called_once_with(
+        docs=expected_docs,
+        metadatas=expected_metadatas
+    )
+
+@patch('app.processing.process_queue.ChromaDB')
+def test_vectorize_extraction_process_step_multiple_references_concatenation(mock_chroma_db):
+    # Mock the ChromaDB instance
+    mock_vectorstore = MagicMock()
+    mock_chroma_db.return_value = mock_vectorstore
+
+    # Inputs
+    project_id = 456
+    process_step_id = 2
+    filename = "test_file"
+    references = [
+        [
+            {"name": "field1", "sources": ["source1", "source2"]},
+            {"name": "field1", "sources": ["source3"]}
+        ],
+        [
+            {"name": "field2", "sources": ["source4"]}
+        ]
+    ]
+
+    # Call the function under test
+    vectorize_extraction_process_step(project_id, process_step_id, filename, references)
+
+    # Expected docs and metadata added to ChromaDB
+    expected_docs = ["test_file field1", "test_file field2"]
+    expected_metadatas = [
+        {
+            "project_id": project_id,
+            "process_step_id": process_step_id,
+            "filename": filename,
+            "reference": "source1\nsource2\nsource3"
+        },
+        {
+            "project_id": project_id,
+            "process_step_id": process_step_id,
+            "filename": filename,
+            "reference": "source4"
+        }
+    ]
+
+    # Assertions
+    mock_vectorstore.add_docs.assert_called_once_with(
+        docs=expected_docs,
+        metadatas=expected_metadatas
+    )
+
+@patch('app.processing.process_queue.ChromaDB')
+def test_vectorize_extraction_process_step_empty_sources(mock_chroma_db):
+    # Mock the ChromaDB instance
+    mock_vectorstore = MagicMock()
+    mock_chroma_db.return_value = mock_vectorstore
+
+    # Inputs
+    project_id = 789
+    process_step_id = 3
+    filename = "empty_sources_file"
+    references = [
+        [
+            {"name": "field1", "sources": []}
+        ]
+    ]
+
+    # Call the function under test
+    vectorize_extraction_process_step(project_id, process_step_id, filename, references)
+
+    # add_docs must not be called when every reference has empty sources
+    mock_vectorstore.add_docs.assert_not_called()
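Finally, the retrieval side is tuned by the two new Settings fields from config.py. Assuming standard pydantic BaseSettings behavior (not shown in the patch), they can be overridden per deployment, either programmatically or through the .env file that the Config class already points at:

    from app.config import Settings

    # Programmatic override: keyword arguments take precedence over .env values
    settings = Settings(chat_extraction_doc_threshold=0.7, chat_extraction_max_docs=25)

    # Equivalent .env entries (field names are matched case-insensitively):
    #   CHAT_EXTRACTION_DOC_THRESHOLD=0.7
    #   CHAT_EXTRACTION_MAX_DOCS=25

Raising chat_extraction_doc_threshold narrows the merge to closer matches, while chat_extraction_max_docs caps how many extraction references a single chat query can pull in.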