Skip to content

Commit

Permalink
Updated firstrun better docker/colab support
Browse files Browse the repository at this point in the history
In Docker or Google colab environments where people may not have access to a keyboard, just download piper (as its small) and dont wait the 60 seconds for a keypress.
  • Loading branch information
erew123 authored Nov 25, 2024
1 parent 50cafba commit 856d67f
Show file tree
Hide file tree
Showing 2 changed files with 402 additions and 194 deletions.
139 changes: 115 additions & 24 deletions system/config/firstrun.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os
import shutil
import time
import sys
import zipfile
import requests
Expand All @@ -15,18 +16,75 @@

this_dir = Path(__file__).parent.resolve().parent.resolve().parent.resolve()

def download_file(url, dest_path):
response = requests.get(url, stream=True)
total_size = int(response.headers.get('content-length', 0))
block_size = 1024 # 1 Kibibyte
t = tqdm(total=total_size, unit='iB', unit_scale=True, desc=str(dest_path))
with open(dest_path, 'wb') as file:
for data in response.iter_content(block_size):
t.update(len(data))
file.write(data)
t.close()
def is_running_in_colab():
"""Test for google colab"""
try:
import google.colab
return True
except ImportError:
return False

def is_running_in_docker():
"""Test for google colab"""
path = '/proc/self/cgroup'
return os.path.exists('/.dockerenv') or (
os.path.exists(path) and any('docker' in line for line in open(path))
)

def download_file(url: str, dest_path: str, timeout: int = 30, retries: int = 3) -> bool:
"""
Download a file from a URL to a specified destination path with progress bar.
Args:
url (str): URL of the file to download
dest_path (str): Destination path where the file will be saved
timeout (int, optional): Connection and read timeout in seconds. Defaults to 30.
retries (int, optional): Number of retry attempts for failed downloads. Defaults to 3.
"""
progress_bar = None
for attempt in range(retries):
try:
response = requests.get(url, stream=True,
timeout=(timeout, timeout))
response.raise_for_status()
total_size = int(response.headers.get('content-length', 0))
block_size = 1024 # 1 Kibibyte
# Initialize progress bar
progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True,
desc=f"Downloading {Path(dest_path).name} (Attempt {attempt + 1}/{retries})")
Path(dest_path).parent.mkdir(parents=True, exist_ok=True)
with open(dest_path, 'wb') as dwn_file:
for data in response.iter_content(block_size):
if data:
dwn_file.write(data)
progress_bar.update(len(data))
if progress_bar:
progress_bar.close()
if progress_bar.n != total_size and total_size > 0:
raise requests.exceptions.RequestException(
f"Downloaded size ({progress_bar.n} bytes) does not match expected size ({total_size} bytes)"
)
return True
except requests.exceptions.RequestException as e:
if progress_bar:
progress_bar.close()
if Path(dest_path).exists():
Path(dest_path).unlink()
if attempt == retries - 1:
raise requests.exceptions.RequestException(
f"Failed to download {url} after {retries} attempts: {str(e)}"
)
print(f"Download attempt {attempt + 1}/{retries} failed: {str(e)}")
print("Retrying in 5 seconds...")
time.sleep(5)
except Exception as e:
if progress_bar:
progress_bar.close()
raise Exception(f"Unexpected error downloading {url}: {str(e)}")
return False

def setup_piper():
"""Download base piper files"""
os.makedirs(this_dir / 'models/piper', exist_ok=True)
os.makedirs(this_dir / 'models/xtts', exist_ok=True)
download_file("https://huggingface.co/rhasspy/piper-voices/resolve/v1.0.0/en/en_US/ljspeech/high/en_US-ljspeech-high.onnx?download=true",
Expand All @@ -35,6 +93,7 @@ def setup_piper():
this_dir / "models/piper/en_US-ljspeech-high.onnx.json")

def setup_vits():
"""Download base vits files"""
os.makedirs(this_dir / 'models/vits', exist_ok=True)
os.makedirs(this_dir / 'models/xtts', exist_ok=True)
zip_path = this_dir / 'models/vits/vits.zip'
Expand All @@ -44,6 +103,7 @@ def setup_vits():
os.remove(zip_path)

def setup_xtts():
"""Download base xtts files"""
os.makedirs(this_dir / 'models/xtts/xttsv2_2.0.3', exist_ok=True)
base_url = "https://huggingface.co/coqui/XTTS-v2/resolve/v2.0.3/"
files = [
Expand All @@ -56,17 +116,39 @@ def setup_xtts():
"speakers_xtts.pth",
"vocab.json"
]
for file in files:
download_file(base_url + file + "?download=true", this_dir / f"models/xtts/xttsv2_2.0.3/{file}")


for each_file in files:
download_file(base_url + file + "?download=true", this_dir / f"models/xtts/xttsv2_2.0.3/{each_file}")

def update_tts_engines(engine):
AlltalkTTSEnginesConfig.get_instance().change_engine(engine).save()
"""Set engine to users choice that matches the download"""
try:
config_instance = AlltalkTTSEnginesConfig.get_instance()
config_instance.change_engine(engine)
config_instance.save()
except Exception as e:
print(f"[{branding}TTS] Error updating TTS engine configuration: {str(e)}")

def set_firstrun_model_false():
config.firstrun_model = False
config.save()
try:
# Get a fresh instance of the config
config_fr = AlltalkConfig.get_instance()
config_fr.firstrun_model = False

# Force a flush to disk immediately
config_fr.save()

# Verify the save worked
config_path = AlltalkConfig.default_config_path()
with open(config_path, 'r', encoding="utf-8") as f:
saved_config = json.load(f)
if saved_config.get('firstrun_model', True):
print(f"[{branding}TTS] Warning: Failed to save firstrun_model=False")
# Try one more time with direct file write
saved_config['firstrun_model'] = False
with open(config_path, 'w', encoding="utf-8") as f:
json.dump(saved_config, f, indent=2)
except Exception as e:
print(f"[{branding}TTS] Error saving firstrun configuration: {str(e)}")

def warning_message():
print(f"[{branding}TTS]")
Expand All @@ -77,7 +159,7 @@ def warning_message():
print(f"[{branding}TTS]")

# Load confignew.json (might be version 1 or 2):
with open(AlltalkConfig.default_config_path(), 'r') as file:
with open(AlltalkConfig.default_config_path(), 'r', encoding="utf-8") as file:
config_json = json.load(file)

branding = config_json['branding']
Expand All @@ -102,19 +184,28 @@ def warning_message():

# Check if firstrun_model is true
if config.firstrun_model:
# Check for Colab or Docker environment
if is_running_in_colab() or is_running_in_docker():
print(f"[{branding}TTS] Detected Colab/Docker environment - automatically selecting Piper")
setup_piper()
update_tts_engines('piper')
set_firstrun_model_false()
warning_message()
exit()

print(f"[{branding}TTS]")
print(f"[{branding}TTS] \033[92mThis is the first time startup.. Please download a start TTS model. Other TTS engines\033[0m")
print(f"[{branding}TTS] \033[92mand TTS models can be downloaded/managed in the Gradio Interface `TTS Engines Settings`\033[0m")
print(f"[{branding}TTS] \033[92mtab after inital setup.\033[0m")
print(f"[{branding}TTS]")

# List of available models
models = [
{"name": "piper", "model": "piper"},
{"name": "vits", "model": "tts_models--en--vctk--vits"},
{"name": "xtts", "model": "xttsv2_2.0.3"},
]

# Display models to the user
print(f"[{branding}TTS] \033[94mAvailable First Time Start-up models:\033[0m")
print(f"[{branding}TTS]")
Expand All @@ -124,8 +215,8 @@ def warning_message():
print(f"[{branding}TTS]")
print(f"[{branding}TTS] \033[94mIn \033[91m60 seconds\033[0m \033[94ma Piper model will be \033[91mdownloaded automatically.\033[0m")
print(f"[{branding}TTS]")
# Auto-select model after 30 seconds

# Auto-select model after 60 seconds
selected_model = models[0] # Default to the first model (piper)
try:
user_choice = inputimeout(prompt=f"[{branding}TTS] \033[92mEnter your choice 1-4: \033[0m ", timeout=60)
Expand Down Expand Up @@ -167,6 +258,6 @@ def warning_message():
# Set firstrun_model to false
set_firstrun_model_false()
warning_message()
exit()
sys.exit()
else:
exit()
sys.exit()
Loading

0 comments on commit 856d67f

Please sign in to comment.