Skip to content

Commit

Permalink
Fix audio trimming and remove redundant length checks
Browse files Browse the repository at this point in the history
  • Loading branch information
Yohrog committed Dec 2, 2024
1 parent 2cbbb81 commit 56f08ef
Showing 1 changed file with 82 additions and 111 deletions.
193 changes: 82 additions & 111 deletions finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -1588,7 +1588,7 @@ def extend_segment(wav, start, end, sr, min_duration):
if (samples_needed + current_duration * sr) > wav.size(-1):
# Pad the file symmetrically with zeroes
padding_amount = (samples_needed + current_duration * sr - wav.size(-1))
padding = torch.zeros(padding_amount // 2 + 1, dtype=wav.dtype)
padding = torch.zeros(int(padding_amount // 2 + 1), dtype=wav.dtype)
return torch.cat([padding, wav, padding], dim=-1)

# Try to extend equally on both sides
Expand Down Expand Up @@ -1700,135 +1700,106 @@ def create_dataset_splits(df, eval_percentage, random_seed=42):


def save_audio_segment(
sas_audio,
sas_sr,
sas_start_time,
sas_end_time,
sas_transcription,
sas_audio_file_name_without_ext,
sas_segment_idx,
sas_speaker_name,
sas_audio_folder,
sas_metadata,
sas_max_duration,
_sas_buffer,
sas_too_long_files,
sas_target_language,
audio,
sr,
start_time,
end_time,
transcription,
audio_file_name_without_ext,
segment_idx,
speaker_name,
audio_folder,
metadata,
target_language,
):
"""Helper function to save audio segments and update metadata"""
sas_transcription = sas_transcription.strip()
sas_sentence = multilingual_cleaners(sas_sentence, sas_target_language)
sas_audio_file_name = f"{sas_audio_file_name_without_ext}_{str(sas_segment_idx).zfill(8)}.wav"
transcription = transcription.strip()
sentence = multilingual_cleaners(sentence, target_language)
audio_file_name = f"{audio_file_name_without_ext}_{str(segment_idx).zfill(8)}.wav"

sas_absolute_path = os.path.join(sas_audio_folder, sas_audio_file_name)
os.makedirs(os.path.dirname(sas_absolute_path), exist_ok=True)
absolute_path = os.path.join(audio_folder, audio_file_name)
os.makedirs(os.path.dirname(absolute_path), exist_ok=True)

# Extract audio segment
sas_audio_start = int(sas_sr * sas_start_time)
sas_audio_end = int(sas_sr * (sas_end_time + _sas_buffer))
sas_audio_segment = sas_audio[sas_audio_start:sas_audio_end].unsqueeze(0)

# Handle long audio segments
if sas_audio_segment.size(-1) > sas_max_duration * sas_sr:
sas_too_long_files.append(
(sas_audio_file_name, sas_audio_segment.size(-1) / sas_sr))

# Skip if audio is too short
if sas_audio_segment.size(-1) < sas_sr:
debug_print(
f"Skipping short audio segment: {sas_audio_file_name}",
"SEGMENTS")
return

os.makedirs(
os.path.dirname(sas_split_absolute_path),
exist_ok=True)
torchaudio.save(str(sas_split_absolute_path), sas_split_audio, sas_sr)

sas_metadata["audio_file"].append(
f"wavs/{sas_split_relative_path}")
sas_metadata["text"].append(sas_sentence)
sas_metadata["speaker_name"].append(sas_speaker_name)
sas_segment_idx += 1
audio_start = int(sr * start_time)
audio_end = int(sr * end_time)
audio_segment = audio[audio_start:audio_end].unsqueeze(0)

# Only save if segment is at least 1 second
if sas_audio_segment.size(-1) >= sas_sr:
torchaudio.save(str(sas_absolute_path), sas_audio_segment, sas_sr)
sas_metadata["audio_file"].append(f"wavs/{sas_audio_file_name}")
sas_metadata["text"].append(sas_sentence)
sas_metadata["speaker_name"].append(sas_speaker_name)
if audio_segment.size(-1) >= sr:
torchaudio.save(str(absolute_path), audio_segment, sr)
metadata["audio_file"].append(f"wavs/{audio_file_name}")
metadata["text"].append(sentence)
metadata["speaker_name"].append(speaker_name)


def process_transcription_result(
ptr_result,
ptr_audio,
ptr_sr,
ptr_segment_idx,
ptr_audio_file_name_without_ext,
ptr_metadata,
ptr_whisper_words,
ptr_max_duration,
ptr_buffer,
ptr_speaker_name,
ptr_audio_folder,
ptr_too_long_files,
ptr_create_bpe_tokenizer,
ptr_target_language,
result,
audio,
sr,
segment_idx,
audio_file_name_without_ext,
metadata,
whisper_words,
buffer_time,
speaker_name,
audio_folder,
create_bpe_tokenizer,
target_language,
):
"""Helper function to process transcription results and save audio segments"""
ptr_i = ptr_segment_idx + 1
ptr_sentence = ""
ptr_sentence_start = None
ptr_first_word = True
ptr_current_words = []

for ptr_segment in ptr_result["segments"]:
if "words" not in ptr_segment:
i = segment_idx + 1
sentence = ""
first_word = True
segment_content = ""

for segment in result["segments"]:

if "words" not in segment:
continue

for ptr_word_info in ptr_segment["words"]:
ptr_word = ptr_word_info.get("word", "").strip()
if not ptr_word:
segment_content = ""
segment_start = segment["words"][0].get("start", 0) - buffer_time
segment_end = segment_start

for word_info in segment["words"]:
word = word_info.get("word", "").strip()
if not word:
continue

ptr_start_time = ptr_word_info.get("start", 0)
ptr_end_time = ptr_word_info.get("end", 0)
end_time = word_info.get("end", 0)

if ptr_create_bpe_tokenizer:
ptr_whisper_words.append(ptr_word)
if create_bpe_tokenizer:
whisper_words.append(word)

if ptr_first_word:
ptr_sentence_start = ptr_start_time
ptr_sentence = ptr_word
ptr_first_word = False
if first_word:
sentence = word
first_word = False
else:
ptr_sentence += " " + ptr_word

ptr_current_words.append(
{"word": ptr_word, "start": ptr_start_time, "end": ptr_end_time})

# Handle sentence splitting and audio saving
if ptr_word[-1] in ["!", ".",
"?"] or (ptr_end_time - ptr_sentence_start) > ptr_max_duration:
save_audio_segment(
ptr_audio,
ptr_sr,
ptr_sentence_start,
ptr_end_time,
ptr_sentence,
ptr_audio_file_name_without_ext,
ptr_i,
ptr_speaker_name,
ptr_audio_folder,
ptr_metadata,
ptr_max_duration,
ptr_buffer,
ptr_too_long_files,
ptr_target_language,
)
ptr_first_word = True
ptr_current_words = []
ptr_sentence = ""
sentence += " " + word

if word[-1] in ["!", ".", "?"]:
segment_content += sentence + " "
segment_end = end_time + buffer_time
first_word = True
sentence = ""

segment_content = segment_content.strip()
if segment_content:
save_audio_segment(
audio,
sr,
segment_start,
segment_end,
sentence,
audio_file_name_without_ext,
i,
speaker_name,
audio_folder,
metadata,
target_language,
)



def process_audio_with_vad(wav, sr, vad_model, get_speech_timestamps, max_duration=float("inf")):
Expand Down

0 comments on commit 56f08ef

Please sign in to comment.