Skip to content

Commit

Permalink
refactor(tests): improve unit test setup for prompt and document extr…
Browse files Browse the repository at this point in the history
…action

- Utilize `patch` decorator for better isolation in prompt transform tests.
- Simplify `mock_download` call by removing unnecessary `tenant_id` parameter in document extractor tests.
  • Loading branch information
laipz8200 committed Oct 20, 2024
1 parent 8cbc26e commit 4578695
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 13 deletions.
24 changes: 13 additions & 11 deletions api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import pytest

Expand Down Expand Up @@ -140,16 +140,18 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg

prompt_transform = AdvancedPromptTransform()
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
prompt_messages = prompt_transform._get_chat_model_prompt_messages(
prompt_template=messages,
inputs=inputs,
query=None,
files=files,
context=context,
memory_config=None,
memory=None,
model_config=model_config_mock,
)
with patch("core.file.file_manager.to_prompt_message_content") as mock_get_encoded_string:
mock_get_encoded_string.return_value = ImagePromptMessageContent(data=str(files[0].remote_url))
prompt_messages = prompt_transform._get_chat_model_prompt_messages(
prompt_template=messages,
inputs=inputs,
query=None,
files=files,
context=context,
memory_config=None,
memory=None,
model_config=model_config_mock,
)

assert len(prompt_messages) == 4
assert prompt_messages[0].role == PromptMessageRole.SYSTEM
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ def test_run_extract_text(
mock_file.transfer_method = transfer_method
mock_file.related_id = "test_file_id" if transfer_method == FileTransferMethod.LOCAL_FILE else None
mock_file.remote_url = "https://example.com/file.txt" if transfer_method == FileTransferMethod.REMOTE_URL else None
mock_file.tenant_id = "test_tenant_id"

mock_array_file_segment = Mock(spec=ArrayFileSegment)
mock_array_file_segment.value = [mock_file]
Expand Down Expand Up @@ -128,7 +127,7 @@ def test_run_extract_text(
if transfer_method == FileTransferMethod.REMOTE_URL:
mock_ssrf_proxy_get.assert_called_once_with("https://example.com/file.txt")
elif transfer_method == FileTransferMethod.LOCAL_FILE:
mock_download.assert_called_once_with(upload_file_id="test_file_id", tenant_id="test_tenant_id")
mock_download.assert_called_once_with(mock_file)


def test_extract_text_from_plain_text():
Expand Down

0 comments on commit 4578695

Please sign in to comment.