Skip to content

Commit

Permalink
Updated CUDA check to be faster on machines without CUDA capable hard…
Browse files Browse the repository at this point in the history
…ware
  • Loading branch information
jonathanfox5 committed Dec 2, 2024
1 parent 0eba62f commit 1152f5d
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 22 deletions.
15 changes: 3 additions & 12 deletions src/gogadget/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,16 +441,6 @@ def frequency_analysis(
)


# @app.command(
# no_args_is_help=True,
# help=HelpText.interactive_transcript,
# rich_help_panel="Primary Functions",
# epilog=ffmpeg_warning(),
# )
# def interactive_transcript():
# return


@app.command(
no_args_is_help=True,
rich_help_panel="Primary Functions",
Expand Down Expand Up @@ -605,15 +595,16 @@ def install(
"Some transcriber functions may appear to freeze for a few minutes if you haven't run them before!"
)
CliUtils.print_status("Transcriber: Checking CUDA status")
transcriber = import_module(".transcriber", APP_NAME)
utils = import_module(".utils", APP_NAME)

cuda = transcriber.cuda_available()
cuda = utils.is_cuda_available()
if cuda:
CliUtils.print_rich("CUDA enabled, can use GPU processing")
else:
CliUtils.print_rich("CUDA disabled, using CPU processing")

CliUtils.print_status("Transcriber: Initialising models")
transcriber = import_module(".transcriber", APP_NAME)
dummy_transcribe_file = get_resources_directory() / "a.mp3"
transcriber.transcriber(
input_path=dummy_transcribe_file,
Expand Down
13 changes: 3 additions & 10 deletions src/gogadget/transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from .cli_utils import CliUtils
from .config import SUPPORTED_AUDIO_EXTS, SUPPORTED_VIDEO_EXTS
from .utils import get_cpu_cores, list_files_with_extension
from .utils import get_cpu_cores, is_cuda_available, list_files_with_extension


def transcriber(
Expand Down Expand Up @@ -53,7 +53,7 @@ def transcriber(
compute_type = "int8"
device = "cpu"
if use_gpu:
if cuda_available():
if is_cuda_available():
device = "cuda"
compute_type = "float16"
else:
Expand Down Expand Up @@ -296,17 +296,10 @@ def reclaim_memory_gpu():
from torch.cuda import empty_cache as empty_cuda_cache

gc.collect()
if cuda_available():
if is_cuda_available():
empty_cuda_cache()


def reclaim_memory_cpu():
"""Force clear model from memory"""
gc.collect()


def cuda_available() -> bool:
"""Check if the current system supports CUDA"""
from torch.cuda import is_available as is_cuda_available

return is_cuda_available()
44 changes: 44 additions & 0 deletions src/gogadget/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import pandas as pd

from .cli_utils import CliUtils
from .command_runner import get_platform, program_exists


def get_cpu_cores(minus_one: bool = False):
Expand Down Expand Up @@ -173,3 +174,46 @@ def sanitise_string_html(input_string: str) -> str:
return output_string
else:
return ""


def is_cuda_available() -> bool:
"""Check if the current system supports CUDA
Do a command line check first to avoid the loading times of torch"""

cli_available = check_cuda_cli_tools()

if not cli_available:
return False

cuda_available = check_cuda_torch()

if not cuda_available:
CliUtils.print_rich(
"CUDA not available, could not initialise using torch. If you are trying to use CUDA, do you have the CUDA specific version of torch installed?"
)

return cuda_available


def check_cuda_cli_tools() -> bool:
"""Check if the current system supports CUDA by testing for cli tools installed"""

# Macos should always return false
current_platform = get_platform()
if current_platform == "Darwin":
CliUtils.print_rich("CUDA not available, running on macOS")
return False

nvidia_smi = program_exists("nvidia-smi")

if not nvidia_smi:
CliUtils.print_rich("CUDA not available, nvidia-smi not found")

return nvidia_smi


def check_cuda_torch() -> bool:
"""Check if the current system supports CUDA using torch"""
from torch.cuda import is_available as is_cuda_available

return is_cuda_available()

0 comments on commit 1152f5d

Please sign in to comment.