diff --git a/finetune.py b/finetune.py
index 96c0471..ce34e8f 100644
--- a/finetune.py
+++ b/finetune.py
@@ -1588,7 +1588,7 @@ def extend_segment(wav, start, end, sr, min_duration):
     if (samples_needed + current_duration * sr) > wav.size(-1):
         # Pad the file symmetrically with zeroes
         padding_amount = (samples_needed + current_duration * sr - wav.size(-1))
-        padding = torch.zeros(padding_amount // 2 + 1, dtype=wav.dtype)
+        padding = torch.zeros(int(padding_amount // 2 + 1), dtype=wav.dtype)
         return torch.cat([padding, wav, padding], dim=-1)

     # Try to extend equally on both sides
@@ -1700,135 +1700,111 @@ def create_dataset_splits(df, eval_percentage, random_seed=42):


 def save_audio_segment(
-    sas_audio,
-    sas_sr,
-    sas_start_time,
-    sas_end_time,
-    sas_transcription,
-    sas_audio_file_name_without_ext,
-    sas_segment_idx,
-    sas_speaker_name,
-    sas_audio_folder,
-    sas_metadata,
-    sas_max_duration,
-    _sas_buffer,
-    sas_too_long_files,
-    sas_target_language,
+    audio,
+    sr,
+    start_time,
+    end_time,
+    transcription,
+    audio_file_name_without_ext,
+    segment_idx,
+    speaker_name,
+    audio_folder,
+    metadata,
+    target_language,
 ):
     """Helper function to save audio segments and update metadata"""
-    sas_transcription = sas_transcription.strip()
-    sas_sentence = multilingual_cleaners(sas_sentence, sas_target_language)
-    sas_audio_file_name = f"{sas_audio_file_name_without_ext}_{str(sas_segment_idx).zfill(8)}.wav"
+    transcription = transcription.strip()
+    sentence = multilingual_cleaners(transcription, target_language)
+    audio_file_name = f"{audio_file_name_without_ext}_{str(segment_idx).zfill(8)}.wav"

-    sas_absolute_path = os.path.join(sas_audio_folder, sas_audio_file_name)
-    os.makedirs(os.path.dirname(sas_absolute_path), exist_ok=True)
+    absolute_path = os.path.join(audio_folder, audio_file_name)
+    os.makedirs(os.path.dirname(absolute_path), exist_ok=True)

     # Extract audio segment
-    sas_audio_start = int(sas_sr * sas_start_time)
-    sas_audio_end = int(sas_sr * (sas_end_time + _sas_buffer))
-    sas_audio_segment = sas_audio[sas_audio_start:sas_audio_end].unsqueeze(0)
-
-    # Handle long audio segments
-    if sas_audio_segment.size(-1) > sas_max_duration * sas_sr:
-        sas_too_long_files.append(
-            (sas_audio_file_name, sas_audio_segment.size(-1) / sas_sr))
-
-    # Skip if audio is too short
-    if sas_audio_segment.size(-1) < sas_sr:
-        debug_print(
-            f"Skipping short audio segment: {sas_audio_file_name}",
-            "SEGMENTS")
-        return
-
-    os.makedirs(
-        os.path.dirname(sas_split_absolute_path),
-        exist_ok=True)
-    torchaudio.save(str(sas_split_absolute_path), sas_split_audio, sas_sr)
-
-    sas_metadata["audio_file"].append(
-        f"wavs/{sas_split_relative_path}")
-    sas_metadata["text"].append(sas_sentence)
-    sas_metadata["speaker_name"].append(sas_speaker_name)
-    sas_segment_idx += 1
+    audio_start = int(sr * start_time)
+    audio_end = int(sr * end_time)
+    audio_segment = audio[audio_start:audio_end].unsqueeze(0)

     # Only save if segment is at least 1 second
-    if sas_audio_segment.size(-1) >= sas_sr:
-        torchaudio.save(str(sas_absolute_path), sas_audio_segment, sas_sr)
-        sas_metadata["audio_file"].append(f"wavs/{sas_audio_file_name}")
-        sas_metadata["text"].append(sas_sentence)
-        sas_metadata["speaker_name"].append(sas_speaker_name)
+    if audio_segment.size(-1) >= sr:
+        torchaudio.save(str(absolute_path), audio_segment, sr)
+        metadata["audio_file"].append(f"wavs/{audio_file_name}")
+        metadata["text"].append(sentence)
+        metadata["speaker_name"].append(speaker_name)


 def process_transcription_result(
-    ptr_result,
-    ptr_audio,
-    ptr_sr,
-    ptr_segment_idx,
-    ptr_audio_file_name_without_ext,
-    ptr_metadata,
-    ptr_whisper_words,
-    ptr_max_duration,
-    ptr_buffer,
-    ptr_speaker_name,
-    ptr_audio_folder,
-    ptr_too_long_files,
-    ptr_create_bpe_tokenizer,
-    ptr_target_language,
+    result,
+    audio,
+    sr,
+    segment_idx,
+    audio_file_name_without_ext,
+    metadata,
+    whisper_words,
+    buffer_time,
+    speaker_name,
+    audio_folder,
+    create_bpe_tokenizer,
+    target_language,
 ):
     """Helper function to process transcription results and save audio segments"""
-    ptr_i = ptr_segment_idx + 1
-    ptr_sentence = ""
-    ptr_sentence_start = None
-    ptr_first_word = True
-    ptr_current_words = []
-
-    for ptr_segment in ptr_result["segments"]:
-        if "words" not in ptr_segment:
+    i = segment_idx + 1
+    sentence = ""
+    first_word = True
+    segment_content = ""
+
+    for segment in result["segments"]:
+
+        if "words" not in segment:
             continue

-        for ptr_word_info in ptr_segment["words"]:
-            ptr_word = ptr_word_info.get("word", "").strip()
-            if not ptr_word:
+        # Reset per-segment state so text cannot leak across segments
+        sentence = ""
+        first_word = True
+        segment_content = ""
+        segment_start = max(0, segment["words"][0].get("start", 0) - buffer_time)
+        segment_end = segment_start
+
+        for word_info in segment["words"]:
+            word = word_info.get("word", "").strip()
+            if not word:
                 continue

-            ptr_start_time = ptr_word_info.get("start", 0)
-            ptr_end_time = ptr_word_info.get("end", 0)
+            end_time = word_info.get("end", 0)

-            if ptr_create_bpe_tokenizer:
-                ptr_whisper_words.append(ptr_word)
+            if create_bpe_tokenizer:
+                whisper_words.append(word)

-            if ptr_first_word:
-                ptr_sentence_start = ptr_start_time
-                ptr_sentence = ptr_word
-                ptr_first_word = False
+            if first_word:
+                sentence = word
+                first_word = False
             else:
-                ptr_sentence += " " + ptr_word
-
-            ptr_current_words.append(
-                {"word": ptr_word, "start": ptr_start_time, "end": ptr_end_time})
-
-            # Handle sentence splitting and audio saving
-            if ptr_word[-1] in ["!", ".",
-                                "?"] or (ptr_end_time - ptr_sentence_start) > ptr_max_duration:
-                save_audio_segment(
-                    ptr_audio,
-                    ptr_sr,
-                    ptr_sentence_start,
-                    ptr_end_time,
-                    ptr_sentence,
-                    ptr_audio_file_name_without_ext,
-                    ptr_i,
-                    ptr_speaker_name,
-                    ptr_audio_folder,
-                    ptr_metadata,
-                    ptr_max_duration,
-                    ptr_buffer,
-                    ptr_too_long_files,
-                    ptr_target_language,
-                )
-                ptr_first_word = True
-                ptr_current_words = []
-                ptr_sentence = ""
+                sentence += " " + word
+
+            # Terminal punctuation closes the current sentence
+            if word[-1] in ["!", ".", "?"]:
+                segment_content += sentence + " "
+                segment_end = end_time + buffer_time
+                first_word = True
+                sentence = ""
+
+        segment_content = segment_content.strip()
+        if segment_content:
+            save_audio_segment(
+                audio,
+                sr,
+                segment_start,
+                segment_end,
+                segment_content,
+                audio_file_name_without_ext,
+                i,
+                speaker_name,
+                audio_folder,
+                metadata,
+                target_language,
+            )
+            i += 1
+

 def process_audio_with_vad(wav, sr, vad_model, get_speech_timestamps, max_duration=float("inf")):
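
Note on the first hunk: `padding_amount` is a float whenever `current_duration` is a float number of seconds, and `torch.zeros` rejects float sizes, hence the `int()` cast. A minimal repro of the failure mode (the values below are made up for illustration, not taken from finetune.py):

import torch

# A float duration times an integer sample rate yields a float padding size.
samples_needed, current_duration, sr, wav_len = 48_000, 1.5, 24_000, 60_000
padding_amount = samples_needed + current_duration * sr - wav_len  # 24000.0, a float
padding = torch.zeros(int(padding_amount // 2 + 1))  # ok; without int() this raises TypeError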
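The refactored `process_transcription_result` now cuts one clip per Whisper segment, accumulating complete sentences between the first word and the last terminal punctuation mark, with `buffer_time` of slack on each side. Below is a self-contained sketch of that grouping logic, assuming a whisper-style result dict with word-level timestamps; `group_sentences` and the fake data are illustrative stand-ins, with the torchaudio save step omitted so the example runs without audio files:

def group_sentences(result, buffer_time=0.2):
    """Yield (start, end, text) tuples, one per transcription segment."""
    for segment in result["segments"]:
        words = segment.get("words", [])
        if not words:
            continue
        # Segment starts slightly before the first word, clamped at zero.
        seg_start = max(0.0, words[0].get("start", 0) - buffer_time)
        seg_end = seg_start
        sentence, content = "", ""
        for info in words:
            word = info.get("word", "").strip()
            if not word:
                continue
            sentence = word if not sentence else sentence + " " + word
            if word[-1] in ("!", ".", "?"):
                # Terminal punctuation closes the sentence and extends the cut.
                content += sentence + " "
                seg_end = info.get("end", 0) + buffer_time
                sentence = ""
        if content.strip():
            yield seg_start, seg_end, content.strip()


fake_result = {"segments": [{"words": [
    {"word": "Hello", "start": 0.00, "end": 0.40},
    {"word": "there.", "start": 0.45, "end": 0.90},
    {"word": "How", "start": 1.20, "end": 1.45},
    {"word": "are", "start": 1.50, "end": 1.65},
    {"word": "you?", "start": 1.70, "end": 2.10},
]}]}

for start, end, text in group_sentences(fake_result):
    print(f"{start:.2f}s-{end:.2f}s: {text}")
# -> 0.00s-2.30s: Hello there. How are you?

One consequence of this design worth noting: a segment whose final words never hit `!`, `.`, or `?` contributes nothing after its last complete sentence, so trailing fragments are dropped rather than saved with misleading timestamps.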