diff --git a/software/source/server/services/stt/local-whisper/stt.py b/software/source/server/services/stt/local-whisper/stt.py index 1c2743b2..b1ea522b 100644 --- a/software/source/server/services/stt/local-whisper/stt.py +++ b/software/source/server/services/stt/local-whisper/stt.py @@ -5,11 +5,11 @@ from datetime import datetime import os import contextlib +import platform import tempfile import shutil import ffmpeg import subprocess - import urllib.request @@ -56,21 +56,92 @@ def install(service_dir): print("Whisper Rust executable already exists. Skipping build.") WHISPER_MODEL_PATH = os.path.join(service_dir, "model") - WHISPER_MODEL_NAME = os.getenv("WHISPER_MODEL_NAME", "ggml-tiny.en.bin") - WHISPER_MODEL_URL = os.getenv( - "WHISPER_MODEL_URL", - "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/", - ) - - if not os.path.isfile(os.path.join(WHISPER_MODEL_PATH, WHISPER_MODEL_NAME)): + while not valid_model(WHISPER_MODEL_PATH, WHISPER_MODEL_NAME): + print(f"Downloading Whisper model '{WHISPER_MODEL_NAME}'.") + WHISPER_MODEL_URL = os.getenv( + "WHISPER_MODEL_URL", + "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/", + ) os.makedirs(WHISPER_MODEL_PATH, exist_ok=True) urllib.request.urlretrieve( f"{WHISPER_MODEL_URL}{WHISPER_MODEL_NAME}", os.path.join(WHISPER_MODEL_PATH, WHISPER_MODEL_NAME), ) else: - print("Whisper model already exists. Skipping download.") + print(f"Whisper model '{WHISPER_MODEL_NAME}' installed.") + + +def valid_model(model_path: str, model_file: str) -> bool: + # Try to validate model through cryptographic hash comparison + + model_file_path = os.path.join(model_path, model_file) + if not os.path.isfile(model_file_path): + return False + + # Download details file and get hash + details_file = f"https://huggingface.co/ggerganov/whisper.cpp/raw/main/{model_file}" + try: + with urllib.request.urlopen(details_file) as response: + body_bytes = response.read() + except: + print("Internet connection not detected. Skipping validation.") + return True + + lines = body_bytes.splitlines() + colon_index = lines[1].find(b':') + details_hash = lines[1][colon_index + 1:].decode() + + # Generate model hash using native commands + model_hash = None + system = platform.system() + if system == 'Darwin': + shasum_path = shutil.which('shasum') + model_hash = subprocess.check_output( + f"{shasum_path} -a 256 {model_file_path} | cut -d' ' -f1", + text=True, + shell=True + ) + elif system == 'Linux': + sha256sum_path = shutil.which('sha256sum') + model_hash = subprocess.check_output( + f"{sha256sum_path} {model_file_path} | cut -d' ' -f1", + text=True, + shell=True + ) + elif system == 'Windows': + comspec = os.getenv("COMSPEC") + if comspec.endswith('cmd.exe'): # Most likely + certutil_path = shutil.which('certutil') + first_op = f"{certutil_path} -hashfile {model_file_path} sha256" + second_op = 'findstr /v "SHA256 CertUtil"' # Prints only lines that do not contain a match. + model_hash = subprocess.check_output(f"{first_op} | {second_op}", text=True, shell=True) + else: + first_op = f"Get-FileHash -LiteralPath {model_file_path} -Algorithm SHA256" + subsequent_ops = "Select-Object Hash | Format-Table -HideTableHeaders | Out-String" + model_hash = subprocess.check_output([ + 'pwsh', + '-Command', + f"({first_op} | {subsequent_ops}).trim().toLower()" + ], + text=True + ) + else: + print(f"System '{system}' not supported. Skipping validation.") + return True + + if details_hash == model_hash.strip(): + print(f"Whisper model '{model_file}' file is valid.") + else: + msg = f''' + The model '{model_file}' did not validate. STT may not function correctly. + The model path is '{model_path}'. + Manually download and verify the model's hash to get better functionality. + Continuing. + ''' + print(msg) + + return True def convert_mime_type_to_format(mime_type: str) -> str: