Skip to content

Commit

Permalink
Set PostgreSQL as the Default Database for CWYD Deployment & unittest…
Browse files Browse the repository at this point in the history
…cases
  • Loading branch information
Pavan Kumar committed Dec 3, 2024
1 parent 6198de7 commit 11d78d0
Show file tree
Hide file tree
Showing 6 changed files with 270 additions and 9 deletions.
4 changes: 3 additions & 1 deletion code/backend/batch/utilities/helpers/env_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ def __load_config(self, **kwargs) -> None:

# Chat History DB Integration Settings
# Set default values based on DATABASE_TYPE
self.DATABASE_TYPE = os.getenv("DATABASE_TYPE", "").strip() or "CosmosDB"
self.DATABASE_TYPE = (
os.getenv("DATABASE_TYPE", "").strip() or DatabaseType.POSTGRESQL.value
)
# Cosmos DB configuration
if self.DATABASE_TYPE == DatabaseType.COSMOSDB.value:
azure_cosmosdb_info = self.get_info_from_env("AZURE_COSMOSDB_INFO", "")
Expand Down
49 changes: 49 additions & 0 deletions code/tests/search_utilities/test_postgres_search_handler.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import pytest
from unittest.mock import MagicMock, patch
from backend.batch.utilities.common.source_document import SourceDocument
Expand Down Expand Up @@ -124,6 +125,54 @@ def test_get_files(handler):
assert result[1] == "test2.txt"


def test_output_results(handler):
results = [
{"id": "1", "title": "file1.txt"},
{"id": "2", "title": "file2.txt"},
{"id": "3", "title": "file1.txt"},
{"id": "4", "title": "file3.txt"},
{"id": "5", "title": "file2.txt"},
]

expected_output = {
"file1.txt": ["1", "3"],
"file2.txt": ["2", "5"],
"file3.txt": ["4"],
}

result = handler.output_results(results)

assert result == expected_output
assert len(result) == 3
assert "file1.txt" in result
assert result["file2.txt"] == ["2", "5"]


def test_process_results(handler):
results = [
{"metadata": json.dumps({"chunk": "Chunk1"}), "content": "Content1"},
{"metadata": json.dumps({"chunk": "Chunk2"}), "content": "Content2"},
]
expected_output = [["Chunk1", "Content1"], ["Chunk2", "Content2"]]
result = handler.process_results(results)
assert result == expected_output


def test_process_results_none(handler):
result = handler.process_results(None)
assert result == []


def test_process_results_missing_chunk(handler):
results = [
{"metadata": json.dumps({}), "content": "Content1"},
{"metadata": json.dumps({"chunk": "Chunk2"}), "content": "Content2"},
]
expected_output = [[0, "Content1"], ["Chunk2", "Content2"]]
result = handler.process_results(results)
assert result == expected_output


def test_delete_files(handler):
files_to_delete = {"test1.txt": [1, 2], "test2.txt": [3]}
mock_delete_documents = MagicMock()
Expand Down
119 changes: 118 additions & 1 deletion code/tests/test_chat_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
This module tests the entry point for the application.
"""

from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from create_app import create_app
Expand Down Expand Up @@ -555,6 +555,54 @@ def test_update_conversation_success(
"success": True,
}

@patch("backend.api.chat_history.AsyncAzureOpenAI")
@patch(
"backend.batch.utilities.helpers.config.config_helper.ConfigHelper.get_active_config_or_default"
)
def test_update_conversation_new_success(
self,
get_active_config_or_default_mock,
azure_openai_mock: MagicMock,
mock_conversation_client,
client,
):
get_active_config_or_default_mock.return_value.enable_chat_history = True
mock_conversation_client.get_conversation.return_value = []
mock_conversation_client.create_message.return_value = "success"
mock_conversation_client.create_conversation.return_value = {
"title": "Test Title",
"updatedAt": "2024-12-01",
"id": "conv1",
}
request_json = {
"conversation_id": "conv1",
"messages": [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi!"},
],
}

openai_client_mock = azure_openai_mock.return_value

mock_response = MagicMock()
mock_response.choices = [MagicMock(message=MagicMock(content="Test Title"))]

openai_client_mock.chat.completions.create = AsyncMock(
return_value=mock_response
)

response = client.post("/api/history/update", json=request_json)

assert response.status_code == 200
assert response.json == {
"data": {
"conversation_id": "conv1",
"date": "2024-12-01",
"title": "Test Title",
},
"success": True,
}

@patch(
"backend.batch.utilities.helpers.config.config_helper.ConfigHelper.get_active_config_or_default"
)
Expand All @@ -568,6 +616,75 @@ def test_update_conversation_no_chat_history(
assert response.status_code == 400
assert response.json == {"error": "Chat history is not available"}

@patch(
"backend.batch.utilities.helpers.config.config_helper.ConfigHelper.get_active_config_or_default"
)
def test_update_conversation_connect_error(
self, get_active_config_or_default_mock, mock_conversation_client, client
):
get_active_config_or_default_mock.return_value.enable_chat_history = True
mock_conversation_client.get_conversation.return_value = {
"title": "Test Title",
"updatedAt": "2024-12-01",
"id": "conv1",
}
request_json = {
"conversation_id": "conv1",
"messages": [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi!"},
],
}
mock_conversation_client.connect.side_effect = Exception("Unexpected error")

# Make the API call
response = client.post(
"/api/history/update",
json=request_json,
headers={"Content-Type": "application/json"},
)

# Assert response
assert response.status_code == 500
assert response.json == {
"error": "Error while updating the conversation history"
}

@patch(
"backend.batch.utilities.helpers.config.config_helper.ConfigHelper.get_active_config_or_default"
)
def test_update_conversation_error(
self, get_active_config_or_default_mock, mock_conversation_client, client
):
get_active_config_or_default_mock.return_value.enable_chat_history = True
mock_conversation_client.create_message.side_effect = Exception(
"Unexpected error"
)
mock_conversation_client.get_conversation.return_value = {
"title": "Test Title",
"updatedAt": "2024-12-01",
"id": "conv1",
}
request_json = {
"conversation_id": "conv1",
"messages": [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi!"},
],
}

response = client.post(
"/api/history/update",
json=request_json,
headers={"Content-Type": "application/json"},
)

# Assert response
assert response.status_code == 500
assert response.json == {
"error": "Error while updating the conversation history"
}

@patch(
"backend.batch.utilities.helpers.config.config_helper.ConfigHelper.get_active_config_or_default"
)
Expand Down
89 changes: 89 additions & 0 deletions code/tests/utilities/helpers/test_database_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import pytest
from unittest.mock import patch, MagicMock
from backend.batch.utilities.helpers.config.database_type import DatabaseType
from backend.batch.utilities.chat_history.cosmosdb import CosmosConversationClient
from backend.batch.utilities.chat_history.database_factory import DatabaseFactory
from backend.batch.utilities.chat_history.postgresdbservice import (
PostgresConversationClient,
)


@patch("backend.batch.utilities.chat_history.database_factory.DefaultAzureCredential")
@patch("backend.batch.utilities.chat_history.database_factory.EnvHelper")
@patch(
"backend.batch.utilities.chat_history.database_factory.CosmosConversationClient",
autospec=True,
)
def test_get_conversation_client_cosmos(
mock_cosmos_client, mock_env_helper, mock_credential
):
# Configure the EnvHelper mock
mock_env_instance = mock_env_helper.return_value
mock_env_instance.DATABASE_TYPE = DatabaseType.COSMOSDB.value
mock_env_instance.AZURE_COSMOSDB_ACCOUNT = "cosmos_account"
mock_env_instance.AZURE_COSMOSDB_DATABASE = "cosmos_database"
mock_env_instance.AZURE_COSMOSDB_CONVERSATIONS_CONTAINER = "conversations_container"
mock_env_instance.AZURE_COSMOSDB_ENABLE_FEEDBACK = False
mock_env_instance.AZURE_COSMOSDB_ACCOUNT_KEY = None

mock_access_token = MagicMock()
mock_access_token.token = "mock-access-token"
mock_credential.return_value.get_token.return_value = mock_access_token
mock_credential_instance = mock_credential.return_value

# Mock the CosmosConversationClient instance
mock_cosmos_instance = MagicMock(spec=CosmosConversationClient)
mock_cosmos_client.return_value = mock_cosmos_instance

# Call the method
client = DatabaseFactory.get_conversation_client()

# Assert the CosmosConversationClient was called with correct arguments
mock_cosmos_client.assert_called_once_with(
cosmosdb_endpoint="https://cosmos_account.documents.azure.com:443/",
credential=mock_credential_instance,
database_name="cosmos_database",
container_name="conversations_container",
enable_message_feedback=False,
)
assert isinstance(client, CosmosConversationClient)
assert client == mock_cosmos_instance


@patch("backend.batch.utilities.chat_history.database_factory.DefaultAzureCredential")
@patch("backend.batch.utilities.chat_history.database_factory.EnvHelper")
@patch(
"backend.batch.utilities.chat_history.database_factory.PostgresConversationClient",
autospec=True,
)
def test_get_conversation_client_postgres(
mock_postgres_client, mock_env_helper, mock_credential
):
mock_env_instance = mock_env_helper.return_value
mock_env_instance.DATABASE_TYPE = DatabaseType.POSTGRESQL.value
mock_env_instance.POSTGRESQL_USER = "postgres_user"
mock_env_instance.POSTGRESQL_HOST = "postgres_host"
mock_env_instance.POSTGRESQL_DATABASE = "postgres_database"

mock_access_token = MagicMock()
mock_access_token.token = "mock-access-token"
mock_credential.return_value.get_token.return_value = mock_access_token

mock_postgres_instance = MagicMock(spec=PostgresConversationClient)
mock_postgres_client.return_value = mock_postgres_instance

client = DatabaseFactory.get_conversation_client()

mock_postgres_client.assert_called_once_with(
user="postgres_user", host="postgres_host", database="postgres_database"
)
assert isinstance(client, PostgresConversationClient)


@patch("backend.batch.utilities.chat_history.database_factory.EnvHelper")
def test_get_conversation_client_invalid_database_type(mock_env_helper):
mock_env_instance = mock_env_helper.return_value
mock_env_instance.DATABASE_TYPE = "INVALID_DB"

with pytest.raises(ValueError, match="Unsupported DATABASE_TYPE"):
DatabaseFactory.get_conversation_client()
14 changes: 9 additions & 5 deletions infra/main.bicep
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ param azureMachineLearningName string = 'aml-${resourceToken}'
'CosmosDB'
'PostgreSQL'
])
param databaseType string = 'CosmosDB'
param databaseType string = 'PostgreSQL'

@description('Azure Cosmos DB Account Name')
param azureCosmosDBAccountName string = 'cosmos-${resourceToken}'
Expand Down Expand Up @@ -1258,13 +1258,17 @@ module createIndex './core/database/deploy_create_table_script.bicep' = if (data
keyVaultName: keyvault.outputs.name
postgresSqlServerName: postgresDBModule.outputs.postgresDbOutput.postgresSQLName
webAppPrincipalName: hostingModel == 'code' ? web.outputs.FRONTEND_API_NAME : web_docker.outputs.FRONTEND_API_NAME
adminAppPrincipalName: hostingModel == 'code' ? adminweb.outputs.WEBSITE_ADMIN_NAME : adminweb_docker.outputs.WEBSITE_ADMIN_NAME
adminAppPrincipalName: hostingModel == 'code'
? adminweb.outputs.WEBSITE_ADMIN_NAME
: adminweb_docker.outputs.WEBSITE_ADMIN_NAME
managedIdentityName: managedIdentityModule.outputs.managedIdentityOutput.name
}
scope: rg
dependsOn: hostingModel == 'code' ? [keyvault, postgresDBModule, storekeys, web, adminweb] : [
[keyvault, postgresDBModule, storekeys, web_docker, adminweb_docker]
]
dependsOn: hostingModel == 'code'
? [keyvault, postgresDBModule, storekeys, web, adminweb]
: [
[keyvault, postgresDBModule, storekeys, web_docker, adminweb_docker]
]
}

output APPLICATIONINSIGHTS_CONNECTION_STRING string = monitoring.outputs.applicationInsightsConnectionString
Expand Down
4 changes: 2 additions & 2 deletions infra/main.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"_generator": {
"name": "bicep",
"version": "0.30.23.60470",
"templateHash": "15176273744623029817"
"templateHash": "10654587243799689217"
}
},
"parameters": {
Expand Down Expand Up @@ -614,7 +614,7 @@
},
"databaseType": {
"type": "string",
"defaultValue": "CosmosDB",
"defaultValue": "PostgreSQL",
"allowedValues": [
"CosmosDB",
"PostgreSQL"
Expand Down

0 comments on commit 11d78d0

Please sign in to comment.