diff --git a/code/backend/batch/utilities/helpers/env_helper.py b/code/backend/batch/utilities/helpers/env_helper.py index e1d40a162..db2f74ebb 100644 --- a/code/backend/batch/utilities/helpers/env_helper.py +++ b/code/backend/batch/utilities/helpers/env_helper.py @@ -91,7 +91,9 @@ def __load_config(self, **kwargs) -> None: # Chat History DB Integration Settings # Set default values based on DATABASE_TYPE - self.DATABASE_TYPE = os.getenv("DATABASE_TYPE", "").strip() or "CosmosDB" + self.DATABASE_TYPE = ( + os.getenv("DATABASE_TYPE", "").strip() or DatabaseType.POSTGRESQL.value + ) # Cosmos DB configuration if self.DATABASE_TYPE == DatabaseType.COSMOSDB.value: azure_cosmosdb_info = self.get_info_from_env("AZURE_COSMOSDB_INFO", "") diff --git a/code/tests/search_utilities/test_postgres_search_handler.py b/code/tests/search_utilities/test_postgres_search_handler.py index eead10dd3..1c8117791 100644 --- a/code/tests/search_utilities/test_postgres_search_handler.py +++ b/code/tests/search_utilities/test_postgres_search_handler.py @@ -1,3 +1,4 @@ +import json import pytest from unittest.mock import MagicMock, patch from backend.batch.utilities.common.source_document import SourceDocument @@ -124,6 +125,54 @@ def test_get_files(handler): assert result[1] == "test2.txt" +def test_output_results(handler): + results = [ + {"id": "1", "title": "file1.txt"}, + {"id": "2", "title": "file2.txt"}, + {"id": "3", "title": "file1.txt"}, + {"id": "4", "title": "file3.txt"}, + {"id": "5", "title": "file2.txt"}, + ] + + expected_output = { + "file1.txt": ["1", "3"], + "file2.txt": ["2", "5"], + "file3.txt": ["4"], + } + + result = handler.output_results(results) + + assert result == expected_output + assert len(result) == 3 + assert "file1.txt" in result + assert result["file2.txt"] == ["2", "5"] + + +def test_process_results(handler): + results = [ + {"metadata": json.dumps({"chunk": "Chunk1"}), "content": "Content1"}, + {"metadata": json.dumps({"chunk": "Chunk2"}), "content": "Content2"}, + ] + expected_output = [["Chunk1", "Content1"], ["Chunk2", "Content2"]] + result = handler.process_results(results) + assert result == expected_output + + +def test_process_results_none(handler): + result = handler.process_results(None) + assert result == [] + + +def test_process_results_missing_chunk(handler): + results = [ + {"metadata": json.dumps({}), "content": "Content1"}, + {"metadata": json.dumps({"chunk": "Chunk2"}), "content": "Content2"}, + ] + expected_output = [[0, "Content1"], ["Chunk2", "Content2"]] + result = handler.process_results(results) + assert result == expected_output + + def test_delete_files(handler): files_to_delete = {"test1.txt": [1, 2], "test2.txt": [3]} mock_delete_documents = MagicMock() diff --git a/code/tests/test_chat_history.py b/code/tests/test_chat_history.py index f1b8bdcb1..6ef805d50 100644 --- a/code/tests/test_chat_history.py +++ b/code/tests/test_chat_history.py @@ -2,7 +2,7 @@ This module tests the entry point for the application. """ -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from create_app import create_app @@ -555,6 +555,54 @@ def test_update_conversation_success( "success": True, } + @patch("backend.api.chat_history.AsyncAzureOpenAI") + @patch( + "backend.batch.utilities.helpers.config.config_helper.ConfigHelper.get_active_config_or_default" + ) + def test_update_conversation_new_success( + self, + get_active_config_or_default_mock, + azure_openai_mock: MagicMock, + mock_conversation_client, + client, + ): + get_active_config_or_default_mock.return_value.enable_chat_history = True + mock_conversation_client.get_conversation.return_value = [] + mock_conversation_client.create_message.return_value = "success" + mock_conversation_client.create_conversation.return_value = { + "title": "Test Title", + "updatedAt": "2024-12-01", + "id": "conv1", + } + request_json = { + "conversation_id": "conv1", + "messages": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi!"}, + ], + } + + openai_client_mock = azure_openai_mock.return_value + + mock_response = MagicMock() + mock_response.choices = [MagicMock(message=MagicMock(content="Test Title"))] + + openai_client_mock.chat.completions.create = AsyncMock( + return_value=mock_response + ) + + response = client.post("/api/history/update", json=request_json) + + assert response.status_code == 200 + assert response.json == { + "data": { + "conversation_id": "conv1", + "date": "2024-12-01", + "title": "Test Title", + }, + "success": True, + } + @patch( "backend.batch.utilities.helpers.config.config_helper.ConfigHelper.get_active_config_or_default" ) @@ -568,6 +616,75 @@ def test_update_conversation_no_chat_history( assert response.status_code == 400 assert response.json == {"error": "Chat history is not available"} + @patch( + "backend.batch.utilities.helpers.config.config_helper.ConfigHelper.get_active_config_or_default" + ) + def test_update_conversation_connect_error( + self, get_active_config_or_default_mock, mock_conversation_client, client + ): + get_active_config_or_default_mock.return_value.enable_chat_history = True + mock_conversation_client.get_conversation.return_value = { + "title": "Test Title", + "updatedAt": "2024-12-01", + "id": "conv1", + } + request_json = { + "conversation_id": "conv1", + "messages": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi!"}, + ], + } + mock_conversation_client.connect.side_effect = Exception("Unexpected error") + + # Make the API call + response = client.post( + "/api/history/update", + json=request_json, + headers={"Content-Type": "application/json"}, + ) + + # Assert response + assert response.status_code == 500 + assert response.json == { + "error": "Error while updating the conversation history" + } + + @patch( + "backend.batch.utilities.helpers.config.config_helper.ConfigHelper.get_active_config_or_default" + ) + def test_update_conversation_error( + self, get_active_config_or_default_mock, mock_conversation_client, client + ): + get_active_config_or_default_mock.return_value.enable_chat_history = True + mock_conversation_client.create_message.side_effect = Exception( + "Unexpected error" + ) + mock_conversation_client.get_conversation.return_value = { + "title": "Test Title", + "updatedAt": "2024-12-01", + "id": "conv1", + } + request_json = { + "conversation_id": "conv1", + "messages": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi!"}, + ], + } + + response = client.post( + "/api/history/update", + json=request_json, + headers={"Content-Type": "application/json"}, + ) + + # Assert response + assert response.status_code == 500 + assert response.json == { + "error": "Error while updating the conversation history" + } + @patch( "backend.batch.utilities.helpers.config.config_helper.ConfigHelper.get_active_config_or_default" ) diff --git a/code/tests/utilities/helpers/test_database_factory.py b/code/tests/utilities/helpers/test_database_factory.py new file mode 100644 index 000000000..0a1734171 --- /dev/null +++ b/code/tests/utilities/helpers/test_database_factory.py @@ -0,0 +1,89 @@ +import pytest +from unittest.mock import patch, MagicMock +from backend.batch.utilities.helpers.config.database_type import DatabaseType +from backend.batch.utilities.chat_history.cosmosdb import CosmosConversationClient +from backend.batch.utilities.chat_history.database_factory import DatabaseFactory +from backend.batch.utilities.chat_history.postgresdbservice import ( + PostgresConversationClient, +) + + +@patch("backend.batch.utilities.chat_history.database_factory.DefaultAzureCredential") +@patch("backend.batch.utilities.chat_history.database_factory.EnvHelper") +@patch( + "backend.batch.utilities.chat_history.database_factory.CosmosConversationClient", + autospec=True, +) +def test_get_conversation_client_cosmos( + mock_cosmos_client, mock_env_helper, mock_credential +): + # Configure the EnvHelper mock + mock_env_instance = mock_env_helper.return_value + mock_env_instance.DATABASE_TYPE = DatabaseType.COSMOSDB.value + mock_env_instance.AZURE_COSMOSDB_ACCOUNT = "cosmos_account" + mock_env_instance.AZURE_COSMOSDB_DATABASE = "cosmos_database" + mock_env_instance.AZURE_COSMOSDB_CONVERSATIONS_CONTAINER = "conversations_container" + mock_env_instance.AZURE_COSMOSDB_ENABLE_FEEDBACK = False + mock_env_instance.AZURE_COSMOSDB_ACCOUNT_KEY = None + + mock_access_token = MagicMock() + mock_access_token.token = "mock-access-token" + mock_credential.return_value.get_token.return_value = mock_access_token + mock_credential_instance = mock_credential.return_value + + # Mock the CosmosConversationClient instance + mock_cosmos_instance = MagicMock(spec=CosmosConversationClient) + mock_cosmos_client.return_value = mock_cosmos_instance + + # Call the method + client = DatabaseFactory.get_conversation_client() + + # Assert the CosmosConversationClient was called with correct arguments + mock_cosmos_client.assert_called_once_with( + cosmosdb_endpoint="https://cosmos_account.documents.azure.com:443/", + credential=mock_credential_instance, + database_name="cosmos_database", + container_name="conversations_container", + enable_message_feedback=False, + ) + assert isinstance(client, CosmosConversationClient) + assert client == mock_cosmos_instance + + +@patch("backend.batch.utilities.chat_history.database_factory.DefaultAzureCredential") +@patch("backend.batch.utilities.chat_history.database_factory.EnvHelper") +@patch( + "backend.batch.utilities.chat_history.database_factory.PostgresConversationClient", + autospec=True, +) +def test_get_conversation_client_postgres( + mock_postgres_client, mock_env_helper, mock_credential +): + mock_env_instance = mock_env_helper.return_value + mock_env_instance.DATABASE_TYPE = DatabaseType.POSTGRESQL.value + mock_env_instance.POSTGRESQL_USER = "postgres_user" + mock_env_instance.POSTGRESQL_HOST = "postgres_host" + mock_env_instance.POSTGRESQL_DATABASE = "postgres_database" + + mock_access_token = MagicMock() + mock_access_token.token = "mock-access-token" + mock_credential.return_value.get_token.return_value = mock_access_token + + mock_postgres_instance = MagicMock(spec=PostgresConversationClient) + mock_postgres_client.return_value = mock_postgres_instance + + client = DatabaseFactory.get_conversation_client() + + mock_postgres_client.assert_called_once_with( + user="postgres_user", host="postgres_host", database="postgres_database" + ) + assert isinstance(client, PostgresConversationClient) + + +@patch("backend.batch.utilities.chat_history.database_factory.EnvHelper") +def test_get_conversation_client_invalid_database_type(mock_env_helper): + mock_env_instance = mock_env_helper.return_value + mock_env_instance.DATABASE_TYPE = "INVALID_DB" + + with pytest.raises(ValueError, match="Unsupported DATABASE_TYPE"): + DatabaseFactory.get_conversation_client() diff --git a/infra/main.bicep b/infra/main.bicep index 2d061e9d9..3a4f426a7 100644 --- a/infra/main.bicep +++ b/infra/main.bicep @@ -306,7 +306,7 @@ param azureMachineLearningName string = 'aml-${resourceToken}' 'CosmosDB' 'PostgreSQL' ]) -param databaseType string = 'CosmosDB' +param databaseType string = 'PostgreSQL' @description('Azure Cosmos DB Account Name') param azureCosmosDBAccountName string = 'cosmos-${resourceToken}' @@ -1258,13 +1258,17 @@ module createIndex './core/database/deploy_create_table_script.bicep' = if (data keyVaultName: keyvault.outputs.name postgresSqlServerName: postgresDBModule.outputs.postgresDbOutput.postgresSQLName webAppPrincipalName: hostingModel == 'code' ? web.outputs.FRONTEND_API_NAME : web_docker.outputs.FRONTEND_API_NAME - adminAppPrincipalName: hostingModel == 'code' ? adminweb.outputs.WEBSITE_ADMIN_NAME : adminweb_docker.outputs.WEBSITE_ADMIN_NAME + adminAppPrincipalName: hostingModel == 'code' + ? adminweb.outputs.WEBSITE_ADMIN_NAME + : adminweb_docker.outputs.WEBSITE_ADMIN_NAME managedIdentityName: managedIdentityModule.outputs.managedIdentityOutput.name } scope: rg - dependsOn: hostingModel == 'code' ? [keyvault, postgresDBModule, storekeys, web, adminweb] : [ - [keyvault, postgresDBModule, storekeys, web_docker, adminweb_docker] - ] + dependsOn: hostingModel == 'code' + ? [keyvault, postgresDBModule, storekeys, web, adminweb] + : [ + [keyvault, postgresDBModule, storekeys, web_docker, adminweb_docker] + ] } output APPLICATIONINSIGHTS_CONNECTION_STRING string = monitoring.outputs.applicationInsightsConnectionString diff --git a/infra/main.json b/infra/main.json index 328f4b9a1..c7bd6a8e7 100644 --- a/infra/main.json +++ b/infra/main.json @@ -5,7 +5,7 @@ "_generator": { "name": "bicep", "version": "0.30.23.60470", - "templateHash": "15176273744623029817" + "templateHash": "10654587243799689217" } }, "parameters": { @@ -614,7 +614,7 @@ }, "databaseType": { "type": "string", - "defaultValue": "CosmosDB", + "defaultValue": "PostgreSQL", "allowedValues": [ "CosmosDB", "PostgreSQL"