diff --git a/finetune.py b/finetune.py index 5077a94..41c1d39 100644 --- a/finetune.py +++ b/finetune.py @@ -1164,11 +1164,11 @@ def format_audio_list( debug_print("Processing with VAD", "AUDIO") # Get VAD segments with resampling vad_segments = process_audio_with_vad( - wav, sr, vad_model, get_speech_timestamps) + wav, sr, vad_model, get_speech_timestamps, max_duration=max_duration) # Group short segments that are close together merged_segments = merge_short_segments( - vad_segments, min_duration, max_gap=0.3) + vad_segments, min_duration, sr=sr, max_gap=0.3) debug_print( f"Merged {len(vad_segments)-len(merged_segments)} short segments", "SEGMENTS") @@ -1515,7 +1515,7 @@ def _create_bpe_tokenizer(bpe_whisper_words, bpe_out_path, bpe_base_path): raise -def merge_short_segments(segments, min_duration, max_gap=0.5): +def merge_short_segments(segments, min_duration, sr, max_gap=0.5): """ More aggressive merge strategy for short segments - Increases max_gap to 0.5s (from 0.3s) @@ -1527,7 +1527,7 @@ def merge_short_segments(segments, min_duration, max_gap=0.5): merged = [] current_group = [] - target_duration = (min_duration + 10.0) / 2 # Target middle of range + target_duration = sr * (min_duration + 10.0) / 2 # Target middle of range for i, segment in enumerate(segments): current_duration = sum(s["end"] - s["start"] @@ -1566,16 +1566,15 @@ def merge_short_segments(segments, min_duration, max_gap=0.5): merged.append(merged_segment) debug_print( - f"Merged {len(segments) - len(merged)} segments into {len(merged)} segments with mid-range preference", + f"Merged {len(segments) - len(merged)} segments with mid-range preference, for a new total of {len(merged)}", "SEGMENTS" ) return merged -def extend_segment(wav, start, end, sr, min_duration, context_window=1.0): +def extend_segment(wav, start, end, sr, min_duration): """ Improved segment extension with better context handling - - Adds context_window parameter for smoother extensions - More balanced extension on both sides - Checks audio content when extending """ @@ -1585,22 +1584,24 @@ def extend_segment(wav, start, end, sr, min_duration, context_window=1.0): samples_needed = int((min_duration - current_duration) * sr) + # Check if we have enough samples in the file + if (samples_needed + current_duration * sr) > wav.size(-1): + # Pad the file symmetrically with zeroes + padding_amount = (samples_needed + current_duration * sr - wav.size(-1)) + padding = torch.zeros(int(padding_amount // 2 + 1), dtype=wav.dtype) + return torch.cat([padding, wav, padding], dim=-1) + # Try to extend equally on both sides extend_left = samples_needed // 2 - extend_right = samples_needed - extend_left + new_start = max(0, start - extend_left) - # Add some context window - context_samples = int(context_window * sr) - new_start = max(0, start - extend_left - context_samples) - new_end = min(wav.size(-1), end + extend_right + context_samples) + # If there weren't enough samples on the left, extend more on the right side + extend_right = samples_needed - (start - new_start) + new_end = min(wav.size(-1), end + extend_right) - # Check if we got enough duration - if (new_end - new_start) / sr < min_duration: - # If still too short, try to compensate from the other side - if new_start == 0: - new_end = min(wav.size(-1), end + samples_needed + context_samples) - elif new_end == wav.size(-1): - new_start = max(0, start - samples_needed - context_samples) + # If there weren't enough samples on the right, extend more on the left side + extend_left = samples_needed - (new_end - end) + new_start = max(0, start - extend_left) debug_print( f"Extended segment from {current_duration:.2f}s to {(new_end - new_start) / sr:.2f}s", @@ -1699,151 +1700,110 @@ def create_dataset_splits(df, eval_percentage, random_seed=42): def save_audio_segment( - sas_audio, - sas_sr, - sas_start_time, - sas_end_time, - sas_sentence, - sas_audio_file_name_without_ext, - sas_segment_idx, - sas_speaker_name, - sas_audio_folder, - sas_metadata, - sas_max_duration, - _sas_buffer, - sas_too_long_files, - sas_target_language, + audio, + sr, + start_time, + end_time, + transcription, + audio_file_name_without_ext, + segment_idx, + speaker_name, + audio_folder, + metadata, + target_language, ): """Helper function to save audio segments and update metadata""" - sas_sentence = sas_sentence.strip() - sas_sentence = multilingual_cleaners(sas_sentence, sas_target_language) - sas_audio_file_name = f"{sas_audio_file_name_without_ext}_{str(sas_segment_idx).zfill(8)}.wav" + transcription = transcription.strip() + sentence = multilingual_cleaners(sentence, target_language) + audio_file_name = f"{audio_file_name_without_ext}_{str(segment_idx).zfill(8)}.wav" - sas_absolute_path = os.path.join(sas_audio_folder, sas_audio_file_name) - os.makedirs(os.path.dirname(sas_absolute_path), exist_ok=True) + absolute_path = os.path.join(audio_folder, audio_file_name) + os.makedirs(os.path.dirname(absolute_path), exist_ok=True) # Extract audio segment - sas_audio_start = int(sas_sr * sas_start_time) - sas_audio_end = int(sas_sr * sas_end_time) - sas_audio_segment = sas_audio[sas_audio_start:sas_audio_end].unsqueeze(0) - - # Handle long audio segments - if sas_audio_segment.size(-1) > sas_max_duration * sas_sr: - sas_too_long_files.append( - (sas_audio_file_name, sas_audio_segment.size(-1) / sas_sr)) - - while sas_audio_segment.size(-1) > sas_max_duration * sas_sr: - sas_split_audio = sas_audio_segment[:, : int( - sas_max_duration * sas_sr)] - sas_audio_segment = sas_audio_segment[:, int( - sas_max_duration * sas_sr):] - sas_split_file_name = f"{sas_audio_file_name_without_ext}_{str(sas_segment_idx).zfill(8)}.wav" - sas_split_relative_path = os.path.join(sas_split_file_name) - sas_split_absolute_path = os.path.normpath( - os.path.join(sas_audio_folder, sas_split_relative_path)) - - os.makedirs( - os.path.dirname(sas_split_absolute_path), - exist_ok=True) - torchaudio.save(str(sas_split_absolute_path), sas_split_audio, sas_sr) - - sas_metadata["audio_file"].append( - f"wavs/{sas_split_relative_path}") - sas_metadata["text"].append(sas_sentence) - sas_metadata["speaker_name"].append(sas_speaker_name) - sas_segment_idx += 1 + audio_start = int(sr * start_time) + audio_end = int(sr * end_time) + audio_segment = audio[audio_start:audio_end].unsqueeze(0) # Only save if segment is at least 1 second - if sas_audio_segment.size(-1) >= sas_sr: - torchaudio.save(str(sas_absolute_path), sas_audio_segment, sas_sr) - sas_metadata["audio_file"].append(f"wavs/{sas_audio_file_name}") - sas_metadata["text"].append(sas_sentence) - sas_metadata["speaker_name"].append(sas_speaker_name) + if audio_segment.size(-1) >= sr: + torchaudio.save(str(absolute_path), audio_segment, sr) + metadata["audio_file"].append(f"wavs/{audio_file_name}") + metadata["text"].append(sentence) + metadata["speaker_name"].append(speaker_name) def process_transcription_result( - ptr_result, - ptr_audio, - ptr_sr, - ptr_segment_idx, - ptr_audio_file_name_without_ext, - ptr_metadata, - ptr_whisper_words, - ptr_max_duration, - ptr_buffer, - ptr_speaker_name, - ptr_audio_folder, - ptr_too_long_files, - ptr_create_bpe_tokenizer, - ptr_target_language, + result, + audio, + sr, + segment_idx, + audio_file_name_without_ext, + metadata, + whisper_words, + buffer_time, + speaker_name, + audio_folder, + create_bpe_tokenizer, + target_language, ): """Helper function to process transcription results and save audio segments""" - ptr_i = ptr_segment_idx + 1 - ptr_sentence = "" - ptr_sentence_start = None - ptr_first_word = True - ptr_current_words = [] - - for ptr_segment in ptr_result["segments"]: - if "words" not in ptr_segment: + i = segment_idx + 1 + sentence = "" + first_word = True + segment_content = "" + + for segment in result["segments"]: + + if "words" not in segment: continue - for ptr_word_info in ptr_segment["words"]: - ptr_word = ptr_word_info.get("word", "").strip() - if not ptr_word: + segment_content = "" + segment_start = segment["words"][0].get("start", 0) - buffer_time + segment_start = max(0, segment_start) + segment_end = segment_start + + for word_info in segment["words"]: + word = word_info.get("word", "").strip() + if not word: continue - ptr_start_time = ptr_word_info.get("start", 0) - ptr_end_time = ptr_word_info.get("end", 0) + end_time = word_info.get("end", 0) - if ptr_create_bpe_tokenizer: - ptr_whisper_words.append(ptr_word) + if create_bpe_tokenizer: + whisper_words.append(word) - if ptr_first_word: - ptr_sentence_start = ptr_start_time - if len(ptr_current_words) == 0: - ptr_sentence_start = max( - ptr_sentence_start - ptr_buffer, 0) - else: - ptr_previous_end = ptr_current_words[-1].get( - "end", 0) if ptr_current_words else 0 - ptr_sentence_start = max( - ptr_sentence_start - ptr_buffer, - (ptr_previous_end + ptr_start_time) / 2) - ptr_sentence = ptr_word - ptr_first_word = False + if first_word: + sentence = word + first_word = False else: - ptr_sentence += " " + ptr_word - - ptr_current_words.append( - {"word": ptr_word, "start": ptr_start_time, "end": ptr_end_time}) - - # Handle sentence splitting and audio saving - if ptr_word[-1] in ["!", ".", - "?"] or (ptr_end_time - ptr_sentence_start) > ptr_max_duration: - save_audio_segment( - ptr_audio, - ptr_sr, - ptr_sentence_start, - ptr_end_time, - ptr_sentence, - ptr_audio_file_name_without_ext, - ptr_i, - ptr_speaker_name, - ptr_audio_folder, - ptr_metadata, - ptr_max_duration, - ptr_buffer, - ptr_too_long_files, - ptr_target_language, - ) - ptr_i += 1 - ptr_first_word = True - ptr_current_words = [] - ptr_sentence = "" + sentence += " " + word + + if word[-1] in ["!", ".", "?"]: + segment_content += sentence + " " + segment_end = end_time + buffer_time + first_word = True + sentence = "" + + segment_content = segment_content.strip() + if segment_content: + save_audio_segment( + audio, + sr, + segment_start, + segment_end, + sentence, + audio_file_name_without_ext, + i, + speaker_name, + audio_folder, + metadata, + target_language, + ) + -def process_audio_with_vad(wav, sr, vad_model, get_speech_timestamps): +def process_audio_with_vad(wav, sr, vad_model, get_speech_timestamps, max_duration=float("inf")): """ Enhanced VAD processing with better end-of-speech detection """ @@ -1859,11 +1819,12 @@ def process_audio_with_vad(wav, sr, vad_model, get_speech_timestamps): vad_model, sampling_rate=16000, threshold=0.2, # Lower threshold to be more sensitive to speech + neg_threshold=0.001, # Negative threshold needs to be set explicitly due to a bug in Silero VAD min_speech_duration_ms=200, # Shorter to catch brief utterances - max_speech_duration_s=float("inf"), + max_speech_duration_s= max_duration, min_silence_duration_ms=300, # Shorter silence duration - window_size_samples=1024, # Smaller window for more precise detection - speech_pad_ms=300, # Add padding to end of speech segments + #window_size_samples=1024, # Smaller window for more precise detection # DEPRECATED: does nothing + speech_pad_ms=100, # Add padding to end of speech segments ) # Scale timestamps back to original sample rate @@ -1871,11 +1832,11 @@ def process_audio_with_vad(wav, sr, vad_model, get_speech_timestamps): for segment in vad_segments: segment["start"] = int(segment["start"] * scale_factor) # Add extra padding at the end - segment["end"] = int(segment["end"] * scale_factor) + \ - int(0.2 * sr) # Add 200ms padding + segment["end"] = int(segment["end"] * scale_factor) #+ \ + # int(0.2 * sr) # Add 200ms padding merged_segments = merge_short_segments( - vad_segments, min_duration=6.0, max_gap=0.5) + vad_segments, min_duration=6.0, sr=sr, max_gap=0.5) debug_print( f"VAD processing: {len(vad_segments)} original segments, {len(merged_segments)} after merging",