Skip to content

Commit

Permalink
Improve test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
yihong1120 committed Dec 4, 2024
1 parent ec6ca22 commit 5f1b0d1
Show file tree
Hide file tree
Showing 2 changed files with 218 additions and 95 deletions.
272 changes: 191 additions & 81 deletions tests/examples/streaming_web/backend/utils_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

import base64
import json
import unittest
from unittest.mock import AsyncMock
from unittest.mock import MagicMock

import redis
from fastapi import HTTPException
from fastapi import WebSocket

from examples.streaming_web.backend.utils import RedisManager
Expand All @@ -24,6 +26,9 @@ def setUp(self) -> None:
self.redis_mock = MagicMock(spec=redis.Redis)
self.redis_mock.xrevrange = AsyncMock()
self.redis_mock.scan = AsyncMock()
self.redis_mock.get = AsyncMock()
self.redis_mock.set = AsyncMock()
self.redis_mock.delete = AsyncMock()

async def test_fetch_latest_frame_for_key_with_data(self) -> None:
"""
Expand Down Expand Up @@ -113,27 +118,6 @@ async def test_fetch_latest_frames(self) -> None:
]
self.assertEqual(result, expected_result)


class TestUtils(unittest.IsolatedAsyncioTestCase):
"""
Test suite for utility functions in the streaming_web module.
"""

def setUp(self) -> None:
"""
Set up the test environment before each test.
"""
self.redis_mock = MagicMock(spec=redis.Redis)
self.redis_mock.scan = AsyncMock()
self.redis_mock.mget = AsyncMock()
self.redis_mock.get = AsyncMock()

def tearDown(self) -> None:
"""
Clean up after each test.
"""
self.redis_mock.reset_mock()

async def test_get_labels(self) -> None:
"""
Test the get_labels function to ensure it returns expected labels.
Expand All @@ -160,96 +144,139 @@ async def test_get_labels(self) -> None:
expected_result = ['label1', 'label2', 'label3']
self.assertEqual(result, expected_result)

async def test_get_image_data(self) -> None:
async def test_get_keys_for_label(self) -> None:
"""
Test the get_image_data function
to ensure it returns correct image data.
Test the get_keys_for_label function to ensure it returns correct keys.
"""
# Mock the Redis scan method to return keys matching the label
label = 'label1'
encoded_label = Utils.encode(label)

# Mock the Redis scan method to return keys matching the label
self.redis_mock.scan.return_value = (
0, [
b'stream_frame:label1_image1',
b'stream_frame:label1_image2',
f'stream_frame:{encoded_label}|image1'.encode(),
f'stream_frame:{encoded_label}|image2'.encode(),
],
)

# Mock the Redis mget method to return image data
self.redis_mock.mget.return_value = [
b'image_data_1',
b'image_data_2',
]

# Call the function
redis_manager = RedisManager('localhost', 6379, 'password')
redis_manager.client = self.redis_mock
result = await redis_manager.get_keys_for_label(label)

# Check the expected result
expected_result = [
'stream_frame:label1_image1',
'stream_frame:label1_image2',
f'stream_frame:{encoded_label}|image1',
f'stream_frame:{encoded_label}|image2',
]
self.assertEqual(result, expected_result)

async def test_get_image_data_no_image(
self,
) -> None:
async def test_update_partial_config(self) -> None:
"""
Test get_image_data function when some images are missing.
Test the update_partial_config function to ensure it updates correctly.
"""
# Mock the Redis scan method to return keys matching the label
label = 'label1'
self.redis_mock.scan.return_value = (
0, [
b'stream_frame:label1_image1',
b'stream_frame:label1_image2',
],
key = 'new_key'
value = 'new_value'
cached_config = {'existing_key': 'existing_value'}

# Mock the Redis get and set methods
self.redis_mock.get.return_value = json.dumps(
cached_config,
).encode('utf-8')

redis_manager = RedisManager('localhost', 6379, 'password')
redis_manager.client = self.redis_mock
await redis_manager.update_partial_config(key, value)

# Check if the set method was called with the updated config
cached_config[key] = value
self.redis_mock.set.assert_called_once_with(
'config_cache', json.dumps(cached_config), ex=3600,
)

# Mock the Redis mget method to return None for an image
self.redis_mock.mget.return_value = [
None, # Simulate missing image
b'image_data_2',
]
async def test_get_partial_config(self) -> None:
"""
Test the get_partial_config function to ensure it retrieves correctly.
"""
key = 'existing_key'
cached_config = {'existing_key': 'existing_value'}

# Mock the Redis get method
self.redis_mock.get.return_value = json.dumps(
cached_config,
).encode('utf-8')

# Call the function
redis_manager = RedisManager('localhost', 6379, 'password')
redis_manager.client = self.redis_mock
keys = await redis_manager.get_keys_for_label(label)

# Simulate processing image data
result = []
for key, image in zip(keys, self.redis_mock.mget.return_value):
if image:
# 提取圖像名稱
image_name = key.split('_')[-1]
encoded_image = base64.b64encode(image).decode('utf-8')
result.append((encoded_image, image_name))

# 預先計算預期結果
expected_encoded_image = base64.b64encode(
b'image_data_2',
).decode('utf-8')
expected_result = [
(expected_encoded_image, 'image2'),
]
result = await redis_manager.get_partial_config(key)

self.assertEqual(result, expected_result)
# Check the expected result
self.assertEqual(result, cached_config[key])

async def test_process_image_data(self) -> None:
async def test_delete_config_cache(self) -> None:
"""
Test the process_image_data function to ensure image data
processesed correctly.
Test the delete_config_cache function to ensure it deletes correctly.
"""
image = b'image_data_1'
redis_manager = RedisManager('localhost', 6379, 'password')
redis_manager.client = self.redis_mock
await redis_manager.delete_config_cache()

# Call the function
result = base64.b64encode(image).decode('utf-8')
# Check if the delete method was called with the correct key
self.redis_mock.delete.assert_called_once_with('config_cache')

async def test_get_config_cache(self) -> None:
"""
Test the get_config_cache function to ensure it retrieves correctly.
"""
cached_config = {'existing_key': 'existing_value'}

# Mock the Redis get method
self.redis_mock.get.return_value = json.dumps(
cached_config,
).encode('utf-8')

redis_manager = RedisManager('localhost', 6379, 'password')
redis_manager.client = self.redis_mock
result = await redis_manager.get_config_cache()

# Check the expected result
expected_result = base64.b64encode(image).decode('utf-8')
self.assertEqual(result, expected_result)
self.assertEqual(result, cached_config)

async def test_set_config_cache(self) -> None:
"""
Test the set_config_cache function to ensure it sets correctly.
"""
config = {'new_key': 'new_value'}

redis_manager = RedisManager('localhost', 6379, 'password')
redis_manager.client = self.redis_mock
await redis_manager.set_config_cache(config)

# Check if the set method was called with the correct config
self.redis_mock.set.assert_called_once_with(
'config_cache', json.dumps(config), ex=3600,
)


class TestUtils(unittest.IsolatedAsyncioTestCase):
"""
Test suite for utility functions in the streaming_web module.
"""

def setUp(self) -> None:
"""
Set up the test environment before each test.
"""
self.redis_mock = MagicMock(spec=redis.Redis)
self.redis_mock.scan = AsyncMock()
self.redis_mock.mget = AsyncMock()
self.redis_mock.get = AsyncMock()

def tearDown(self) -> None:
"""
Clean up after each test.
"""
self.redis_mock.reset_mock()

async def test_send_frames(self) -> None:
"""
Expand Down Expand Up @@ -299,7 +326,7 @@ async def test_encode(self) -> None:
# Check if encoding and underscore replacement work as expected
expected_encoded = base64.urlsafe_b64encode(
input_string.encode('utf-8'),
).decode('utf-8').replace('_', '-')
).decode('utf-8')
self.assertEqual(encoded_string, expected_encoded)

async def test_decode_valid_base64(self) -> None:
Expand All @@ -308,7 +335,7 @@ async def test_decode_valid_base64(self) -> None:
"""
input_string = base64.urlsafe_b64encode(
b'test_label',
).decode('utf-8').replace('_', '-')
).decode('utf-8')
decoded_string = Utils.decode(input_string)
self.assertEqual(decoded_string, 'test_label')

Expand All @@ -320,6 +347,89 @@ async def test_decode_invalid_base64(self) -> None:
decoded_string = Utils.decode(input_string)
self.assertEqual(decoded_string, input_string)

async def test_load_configuration(self) -> None:
"""
Test the load_configuration function to ensure it loads correctly.
"""
config_path = 'test_config.json'
config_data = [{'key': 'value'}]

# Mock the open function to return the config data
with unittest.mock.patch(
'builtins.open',
unittest.mock.mock_open(read_data=json.dumps(config_data)),
):
result = Utils.load_configuration(config_path)

self.assertEqual(result, config_data)

async def test_save_configuration(self) -> None:
"""
Test the save_configuration function to ensure it saves correctly.
"""
config_path = 'test_config.json'
config_data = [{'key': 'value'}]

# Mock the open function
with unittest.mock.patch(
'builtins.open',
unittest.mock.mock_open(),
) as mock_file:
Utils.save_configuration(config_path, config_data)

# Check if the file was written with the correct data
mock_file().write.assert_called_once_with(
json.dumps(config_data, indent=4, ensure_ascii=False),
)

async def test_verify_localhost(self) -> None:
"""
Test the verify_localhost function to ensure it verifies correctly.
"""
request_mock = MagicMock()
request_mock.client.host = '127.0.0.1'

# Should not raise an exception
Utils.verify_localhost(request_mock)

request_mock.client.host = '192.168.1.1'
with self.assertRaises(HTTPException):
Utils.verify_localhost(request_mock)

async def test_update_configuration(self) -> None:
"""
Test the update_configuration function to ensure it updates correctly.
"""
config_path = 'test_config.json'
current_config = [{'video_url': 'url1', 'key': 'value1'}]
new_config = [
{'video_url': 'url1', 'key': 'new_value1'}, {
'video_url': 'url2', 'key': 'value2',
},
]

# Mock the load_configuration and save_configuration functions
with unittest.mock.patch(
'examples.streaming_web.backend.utils.Utils.load_configuration',
return_value=current_config,
):
with unittest.mock.patch(
'examples.streaming_web.backend.utils.Utils.'
'save_configuration',
) as mock_save:
result = Utils.update_configuration(config_path, new_config)

# Check the expected result
expected_result = [
{'video_url': 'url1', 'key': 'new_value1'},
{'video_url': 'url2', 'key': 'value2'},
]
self.assertEqual(result, expected_result)

# Check if the save_configuration function
# was called with the correct data
mock_save.assert_called_once_with(config_path, expected_result)


if __name__ == '__main__':
unittest.main()
Loading

0 comments on commit 5f1b0d1

Please sign in to comment.