diff --git a/src/gogadget/main.py b/src/gogadget/main.py index ecebf41..c56fa98 100644 --- a/src/gogadget/main.py +++ b/src/gogadget/main.py @@ -441,16 +441,6 @@ def frequency_analysis( ) -# @app.command( -# no_args_is_help=True, -# help=HelpText.interactive_transcript, -# rich_help_panel="Primary Functions", -# epilog=ffmpeg_warning(), -# ) -# def interactive_transcript(): -# return - - @app.command( no_args_is_help=True, rich_help_panel="Primary Functions", @@ -605,15 +595,16 @@ def install( "Some transcriber functions may appear to freeze for a few minutes if you haven't run them before!" ) CliUtils.print_status("Transcriber: Checking CUDA status") - transcriber = import_module(".transcriber", APP_NAME) + utils = import_module(".utils", APP_NAME) - cuda = transcriber.cuda_available() + cuda = utils.is_cuda_available() if cuda: CliUtils.print_rich("CUDA enabled, can use GPU processing") else: CliUtils.print_rich("CUDA disabled, using CPU processing") CliUtils.print_status("Transcriber: Initialising models") + transcriber = import_module(".transcriber", APP_NAME) dummy_transcribe_file = get_resources_directory() / "a.mp3" transcriber.transcriber( input_path=dummy_transcribe_file, diff --git a/src/gogadget/transcriber.py b/src/gogadget/transcriber.py index 3f30b4e..24b2584 100644 --- a/src/gogadget/transcriber.py +++ b/src/gogadget/transcriber.py @@ -12,7 +12,7 @@ from .cli_utils import CliUtils from .config import SUPPORTED_AUDIO_EXTS, SUPPORTED_VIDEO_EXTS -from .utils import get_cpu_cores, list_files_with_extension +from .utils import get_cpu_cores, is_cuda_available, list_files_with_extension def transcriber( @@ -53,7 +53,7 @@ def transcriber( compute_type = "int8" device = "cpu" if use_gpu: - if cuda_available(): + if is_cuda_available(): device = "cuda" compute_type = "float16" else: @@ -296,17 +296,10 @@ def reclaim_memory_gpu(): from torch.cuda import empty_cache as empty_cuda_cache gc.collect() - if cuda_available(): + if is_cuda_available(): 
empty_cuda_cache() def reclaim_memory_cpu(): """Force clear model from memory""" gc.collect() - - -def cuda_available() -> bool: - """Check if the current system supports CUDA""" - from torch.cuda import is_available as is_cuda_available - - return is_cuda_available() diff --git a/src/gogadget/utils.py b/src/gogadget/utils.py index c0fd67d..cdf6c3e 100644 --- a/src/gogadget/utils.py +++ b/src/gogadget/utils.py @@ -12,6 +12,7 @@ import pandas as pd from .cli_utils import CliUtils +from .command_runner import get_platform, program_exists def get_cpu_cores(minus_one: bool = False): @@ -173,3 +174,46 @@ def sanitise_string_html(input_string: str) -> str: return output_string else: return "" + + +def is_cuda_available() -> bool: + """Check if the current system supports CUDA + Do a command line check first to avoid the loading times of torch""" + + cli_available = check_cuda_cli_tools() + + if not cli_available: + return False + + cuda_available = check_cuda_torch() + + if not cuda_available: + CliUtils.print_rich( + "CUDA not available, could not initialise using torch. If you are trying to use CUDA, do you have the CUDA specific version of torch installed?" + ) + + return cuda_available + + +def check_cuda_cli_tools() -> bool: + """Check if the current system supports CUDA by testing for CLI tools installed""" + + # macOS should always return false + current_platform = get_platform() + if current_platform == "Darwin": + CliUtils.print_rich("CUDA not available, running on macOS") + return False + + nvidia_smi = program_exists("nvidia-smi") + + if not nvidia_smi: + CliUtils.print_rich("CUDA not available, nvidia-smi not found") + + return nvidia_smi + + +def check_cuda_torch() -> bool: + """Check if the current system supports CUDA using torch""" + from torch.cuda import is_available as is_cuda_available + + return is_cuda_available()