Merge pull request #427 from hykilpikonna/alltalkbeta
[F] Fix #251: Expected a value of type 'str'
erew123 authored Nov 29, 2024
2 parents 94f90e1 + 0c6e113 commit 239b656
Showing 3 changed files with 9 additions and 9 deletions.
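The error referenced in #251 ("Expected a value of type 'str'") appears to be raised when a path-like object (such as a pathlib.Path) reaches torchaudio's file I/O where a plain string is expected. The fix is the same at every touched call site: wrap the path in str() before handing it to torchaudio.load or torchaudio.save. A minimal sketch of the pattern, with an illustrative file name rather than one taken from the repository:

from pathlib import Path

import torch
import torchaudio

# Hypothetical path for illustration; finetune.py builds its paths from the
# dataset and temp folders.
audio_path = Path("wavs") / "example_chunk_0.wav"

# Converting with str() avoids the "Expected a value of type 'str'" failure
# seen when a non-string path object is passed through to torchaudio's loader.
wav, sr = torchaudio.load(str(audio_path))

# Downmix to mono, mirroring the handling in format_audio_list().
if wav.size(0) != 1:
    wav = torch.mean(wav, dim=0, keepdim=True)

# The same str() conversion is applied on the save side; torchaudio.save
# expects a (channels, samples) tensor plus the sample rate.
out_path = audio_path.with_name("example_mono.wav")  # hypothetical output name
torchaudio.save(str(out_path), wav, sr)
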
14 changes: 7 additions & 7 deletions finetune.py
@@ -1148,7 +1148,7 @@ def format_audio_list(
continue

# Load and process audio
- wav, sr = torchaudio.load(audio_path)
+ wav, sr = torchaudio.load(str(audio_path))
if wav.size(0) != 1:
wav = torch.mean(wav, dim=0, keepdim=True)
wav = wav.squeeze()
@@ -1209,7 +1209,7 @@ def format_audio_list(

chunk_path = os.path.join(
temp_folder, f"{audio_file_name_without_ext}_chunk_{chunk_idx}.wav")
- torchaudio.save(chunk_path, chunk.unsqueeze(0), sr)
+ torchaudio.save(str(chunk_path), chunk.unsqueeze(0), sr)

# Transcribe with appropriate precision
if fal_precision == "mixed" and device == "cuda":
@@ -1745,7 +1745,7 @@ def save_audio_segment(
os.makedirs(
os.path.dirname(sas_split_absolute_path),
exist_ok=True)
- torchaudio.save(sas_split_absolute_path, sas_split_audio, sas_sr)
+ torchaudio.save(str(sas_split_absolute_path), sas_split_audio, sas_sr)

sas_metadata["audio_file"].append(
f"wavs/{sas_split_relative_path}")
@@ -1755,7 +1755,7 @@ def save_audio_segment(

# Only save if segment is at least 1 second
if sas_audio_segment.size(-1) >= sas_sr:
- torchaudio.save(sas_absolute_path, sas_audio_segment, sas_sr)
+ torchaudio.save(str(sas_absolute_path), sas_audio_segment, sas_sr)
sas_metadata["audio_file"].append(f"wavs/{sas_audio_file_name}")
sas_metadata["text"].append(sas_sentence)
sas_metadata["speaker_name"].append(sas_speaker_name)
@@ -2221,7 +2221,7 @@ def save_audio_and_correction(
f"Saving edited audio: {sr}Hz, length: {len(audio)}",
"DATA_PROCESS")
audio_tensor = torch.tensor(audio).unsqueeze(0)
- torchaudio.save(audio_path, audio_tensor, sr)
+ torchaudio.save(str(audio_path), audio_tensor, sr)
save_status_msg.append("Audio saved successfully")
debug_print(
f"Saved edited audio to {audio_path}",
@@ -2959,7 +2959,7 @@ def run_tts(lang, tts_text, speaker_audio_file):
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
out["wav"] = torch.tensor(out["wav"]).unsqueeze(0)
out_path = fp.name
- torchaudio.save(out_path, out["wav"], 24000)
+ torchaudio.save(str(out_path), out["wav"], 24000)

return "Speech generated !", out_path, speaker_audio_file

@@ -3171,7 +3171,7 @@ def compact_custom_model(
if file_path.is_file() and file_path.suffix.lower() == ".wav":
try:
# Load audio file and get duration
- waveform, sample_rate = torchaudio.load(file_path)
+ waveform, sample_rate = torchaudio.load(str(file_path))
duration = waveform.size(
1) / sample_rate # Duration in seconds

2 changes: 1 addition & 1 deletion system/tts_engines/f5tts/model_engine.py
@@ -756,7 +756,7 @@ async def infer_process(
):
"""Process text and prepare for batch inference"""
# Split the input text into batches
- audio, sr = torchaudio.load(ref_audio)
+ audio, sr = torchaudio.load(str(ref_audio))
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
gen_text_batches = self.chunk_text(gen_text, max_chars=max_chars)

2 changes: 1 addition & 1 deletion system/tts_engines/xtts/model_engine.py
@@ -1113,7 +1113,7 @@ async def generate_tts(self, text, voice, language, temperature, repetition_pena
else:
self.print_message("Starting non-streaming generation", message_type="debug_tts")
output = self.model.inference(**common_args)
- torchaudio.save(output_file, torch.tensor(output["wav"]).unsqueeze(0), 24000)
+ torchaudio.save(str(output_file), torch.tensor(output["wav"]).unsqueeze(0), 24000)
self.print_message(f"Saved audio to: {output_file}", message_type="debug_tts")

elif self.current_model_loaded.startswith("apitts"):
