Skip to content

Commit

Permalink
Refactor the code to boost efficiency
Browse files Browse the repository at this point in the history
  • Loading branch information
yihong1120 committed Dec 3, 2024
1 parent fa59881 commit ec6ca22
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 73 deletions.
90 changes: 48 additions & 42 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import argparse
import asyncio
import base64
import gc
import json
import logging
Expand All @@ -12,7 +11,6 @@
from multiprocessing import Process
from typing import TypedDict

import anyio
import cv2
from dotenv import load_dotenv
from watchdog.observers import Observer
Expand All @@ -33,10 +31,6 @@

is_windows = os.name == 'nt'

# Initialise Redis manager
if not is_windows:
redis_manager = RedisManager()


class AppConfig(TypedDict, total=False):
"""
Expand Down Expand Up @@ -94,6 +88,9 @@ def compute_config_hash(self, config: dict) -> str:

async def reload_configurations(self):
async with self.lock:
if not is_windows:
redis_manager = RedisManager()

self.logger.info('Reloading configurations...')
with open(self.config_file, encoding='utf-8') as file:
configurations = json.load(file)
Expand Down Expand Up @@ -123,12 +120,8 @@ async def reload_configurations(self):
'stream_name', 'prediction_visual',
)

site = base64.urlsafe_b64encode(
site.encode('utf-8'),
).decode('utf-8')
stream_name = base64.urlsafe_b64encode(
stream_name.encode('utf-8'),
).decode('utf-8')
site = Utils.encode(site)
stream_name = Utils.encode(stream_name)

key_to_delete = f"stream_frame:{site}|{stream_name}"

Expand Down Expand Up @@ -194,6 +187,9 @@ async def reload_configurations(self):
)
)

# Close Redis connection
await redis_manager.close_connection()

async def run_multiple_streams(self) -> None:
"""
Manage multiple video streams based on a config file.
Expand All @@ -218,7 +214,7 @@ async def run_multiple_streams(self) -> None:

try:
while True:
await anyio.sleep(1)
await asyncio.sleep(1)
except KeyboardInterrupt:
self.logger.info(
'\n[INFO] Received KeyboardInterrupt. Stopping observer...',
Expand Down Expand Up @@ -257,15 +253,15 @@ async def process_single_stream(
detect_with_server (bool): If run detection with server api or not.
detection_items (dict): The detection items to check for.
"""
if not is_windows:
redis_manager = RedisManager()

# Initialise the stream capture object
streaming_capture = StreamCapture(stream_url=video_url)

# Get the API URL from environment variables
api_url = os.getenv('API_URL', 'http://localhost:5000')

# Initialise the live stream detector
live_stream_detector = LiveStreamDetector(
api_url=api_url,
api_url=os.getenv('API_URL', 'http://localhost:5000'),
model_key=model_key,
output_folder=site,
detect_with_server=detect_with_server,
Expand Down Expand Up @@ -345,7 +341,7 @@ async def process_single_stream(

# Translate the warnings
translated_warnings = Translator.translate_warning(
warnings, language,
tuple(warnings), language,
)

# Draw the detections on the frame
Expand All @@ -368,7 +364,7 @@ async def process_single_stream(
):
translated_controlled_zone_warning: list[str] = (
Translator.translate_warning(
controlled_zone_warning, language,
tuple(controlled_zone_warning), language,
)
)
message = (
Expand Down Expand Up @@ -452,13 +448,10 @@ async def process_single_stream(
try:
# Encode site and stream_name to avoid issues
# with special characters
encoded_site = base64.urlsafe_b64encode(
(site or 'default_site').encode('utf-8'),
).decode('utf-8')

encoded_stream_name = base64.urlsafe_b64encode(
stream_name.encode('utf-8'),
).decode('utf-8')
encoded_site = Utils.encode(site or 'default site')
encoded_stream_name = Utils.encode(
stream_name,
) or 'default stream name'

# Use a unique key for each thread or process
key = f"stream_frame:{encoded_site}|{encoded_stream_name}"
Expand All @@ -468,7 +461,7 @@ async def process_single_stream(

# Translate the warnings
translated_warnings = Translator.translate_warning(
warnings=warnings, language='zh-TW',
warnings=tuple(warnings), language='zh-TW',
)

# Combine warnings into a single string for storage
Expand Down Expand Up @@ -500,6 +493,11 @@ async def process_single_stream(

# Release resources after processing
await streaming_capture.release_resources()

# Close the Redis connection
if not is_windows:
await redis_manager.close_connection()

gc.collect()

async def process_streams(self, config: AppConfig) -> None:
Expand Down Expand Up @@ -551,20 +549,24 @@ async def process_streams(self, config: AppConfig) -> None:
)
finally:
if not is_windows:
site = config.get('site')
stream_name = config.get('stream_name', 'prediction_visual')

site = base64.urlsafe_b64encode(
(site or 'default_site').encode('utf-8'),
).decode('utf-8')
stream_name = base64.urlsafe_b64encode(
stream_name.encode('utf-8'),
).decode('utf-8')
redis_manager = RedisManager()
site = config.get('site') or 'default site'
stream_name = config.get(
'stream_name',
) or 'default stream name'

site = Utils.encode(site)
stream_name = Utils.encode(
stream_name,
)

key = f"stream_frame:{site}|{stream_name}"
await redis_manager.delete(key)
self.logger.info(f"Deleted Redis key: {key}")

# Close the Redis connection
await redis_manager.close_connection()

def start_process(self, config: AppConfig) -> Process:
"""
Start a new process for processing a video stream.
Expand Down Expand Up @@ -628,9 +630,10 @@ async def process_single_image(

# Initialise the live stream detector,
# but here used for a single image
api_url = os.getenv('API_URL', 'http://localhost:5000')
live_stream_detector = LiveStreamDetector(
api_url=api_url, model_key=model_key, output_folder=output_folder,
api_url=os.getenv('API_URL', 'http://localhost:5000'),
model_key=model_key,
output_folder=output_folder,
)

# Initialise the drawing manager
Expand Down Expand Up @@ -717,10 +720,13 @@ async def main():
print('\n[INFO] Received KeyboardInterrupt. Shutting down...')
finally:
# Perform necessary cleanup if needed
if not is_windows:
await redis_manager.close_connection()
print('[INFO] Redis connection closed.')
# if not is_windows:
# await redis_manager.close_connection()
# print('[INFO] Redis connection closed.')
print('[INFO] Application stopped.')
# Clear the asyncio event loop
await asyncio.sleep(0)


if __name__ == '__main__':
anyio.run(main)
asyncio.run(main())
29 changes: 15 additions & 14 deletions src/lang_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from functools import lru_cache

LANGUAGES = {
'zh-TW': {
'warning_people_in_controlled_area': '警告: 有{count}個人進入受控區域!',
Expand Down Expand Up @@ -129,12 +131,13 @@ class Translator:
"""

@staticmethod
def translate_warning(warnings: list[str], language: str) -> list[str]:
@lru_cache(maxsize=128)
def translate_warning(warnings: tuple[str, ...], language: str) -> list[str]:
"""
Translate warnings from English to the specified language.
Args:
warnings (list[str]): A list of warnings in English.
warnings (tuple[str, ...]): A tuple of warnings in English.
language (str): The target language code (e.g., 'zh-TW', 'en').
Returns:
Expand Down Expand Up @@ -162,21 +165,21 @@ def translate_warning(warnings: list[str], language: str) -> list[str]:
'warning_no_safety_vest', warning,
)
elif 'Someone is too close to' in warning:
label_key = (
'machinery'
if 'machinery' in warning
else 'vehicle'
)
label_key = 'machinery' if 'machinery' in warning else 'vehicle'
translated_warning = LANGUAGES[language].get(
'warning_close_to_machinery', warning,
).replace(
)
translated_warning = translated_warning.replace(
'{label}', LANGUAGES[language].get(label_key, label_key),
)
elif 'people have entered the controlled area!' in warning:
count = warning.split(' ')[1] # Extract count of people
translated_warning = LANGUAGES[language].get(
'warning_people_in_controlled_area', warning,
).replace('{count}', count)
)
translated_warning = translated_warning.replace(
'{count}', count,
)
else:
# Keep the original warning if no match
translated_warning = warning
Expand All @@ -196,11 +199,9 @@ def main():
# Specify the language to translate to
# (e.g., 'zh-TW' for Traditional Chinese)
language = 'zh-TW'

# Translate the warnings
translated_warnings = Translator.translate_warning(warnings, language)

# Output the translated warnings
translated_warnings = Translator.translate_warning(
tuple(warnings), language,
)
print('Original Warnings:', warnings)
print('Translated Warnings:', translated_warnings)

Expand Down
20 changes: 19 additions & 1 deletion src/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import base64
import logging
import os
from datetime import datetime
Expand Down Expand Up @@ -35,6 +36,21 @@ def is_expired(expire_date_str: str | None) -> bool:
return False
return False

@staticmethod
def encode(value: str) -> str:
"""
Encode a value into a URL-safe Base64 string.
Args:
value (str): The value to encode.
Returns:
str: The encoded string.
"""
return base64.urlsafe_b64encode(
value.encode('utf-8'),
).decode('utf-8')


class FileEventHandler(FileSystemEventHandler):
"""
Expand Down Expand Up @@ -64,7 +80,9 @@ def on_modified(self, event):
event_path = os.path.abspath(event.src_path)
if event_path == self.file_path:
print(f"[DEBUG] Configuration file modified: {event_path}")
asyncio.run_coroutine_threadsafe(self.callback(), self.loop)
asyncio.run_coroutine_threadsafe(
self.callback(), self.loop, # Ensure the callback is run in the loop
)


class RedisManager:
Expand Down
Loading

0 comments on commit ec6ca22

Please sign in to comment.