diff --git a/finetune.py b/finetune.py index 3eb3217..12e6306 100644 --- a/finetune.py +++ b/finetune.py @@ -1164,11 +1164,11 @@ def format_audio_list( debug_print("Processing with VAD", "AUDIO") # Get VAD segments with resampling vad_segments = process_audio_with_vad( - wav, sr, vad_model, get_speech_timestamps) + wav, sr, vad_model, get_speech_timestamps, max_duration=max_duration) # Group short segments that are close together merged_segments = merge_short_segments( - vad_segments, min_duration, max_gap=0.3) + vad_segments, min_duration, sr=sr, max_gap=0.3) debug_print( f"Merged {len(vad_segments)-len(merged_segments)} short segments", "SEGMENTS") @@ -1515,7 +1515,7 @@ def _create_bpe_tokenizer(bpe_whisper_words, bpe_out_path, bpe_base_path): raise -def merge_short_segments(segments, min_duration, max_gap=0.5): +def merge_short_segments(segments, min_duration, sr, max_gap=0.5): """ More aggressive merge strategy for short segments - Increases max_gap to 0.5s (from 0.3s) @@ -1527,7 +1527,7 @@ def merge_short_segments(segments, min_duration, max_gap=0.5): merged = [] current_group = [] - target_duration = (min_duration + 10.0) / 2 # Target middle of range + target_duration = sr * (min_duration + 10.0) / 2 # Target middle of range for i, segment in enumerate(segments): current_duration = sum(s["end"] - s["start"] @@ -1566,16 +1566,15 @@ def merge_short_segments(segments, min_duration, max_gap=0.5): merged.append(merged_segment) debug_print( - f"Merged {len(segments) - len(merged)} segments into {len(merged)} segments with mid-range preference", + f"Merged {len(segments) - len(merged)} segments with mid-range preference, for a new total of {len(merged)}", "SEGMENTS" ) return merged -def extend_segment(wav, start, end, sr, min_duration, context_window=1.0): +def extend_segment(wav, start, end, sr, min_duration): """ Improved segment extension with better context handling - - Adds context_window parameter for smoother extensions - More balanced extension on both sides - Checks audio content when extending """ @@ -1585,22 +1584,24 @@ def extend_segment(wav, start, end, sr, min_duration, context_window=1.0): samples_needed = int((min_duration - current_duration) * sr) + # Check if we have enough samples in the file + if (samples_needed + current_duration * sr) > wav.size(-1): + # Pad the file symmetrically with zeroes + padding_amount = (samples_needed + current_duration * sr - wav.size(-1)) + padding = torch.zeros(padding_amount // 2 + 1, dtype=wav.dtype) + return torch.cat([padding, wav, padding], dim=-1) + # Try to extend equally on both sides extend_left = samples_needed // 2 - extend_right = samples_needed - extend_left + new_start = max(0, start - extend_left) - # Add some context window - context_samples = int(context_window * sr) - new_start = max(0, start - extend_left - context_samples) - new_end = min(wav.size(-1), end + extend_right + context_samples) + # If there weren't enough samples on the left, extend more on the right side + extend_right = samples_needed - (start - new_start) + new_end = min(wav.size(-1), end + extend_right) - # Check if we got enough duration - if (new_end - new_start) / sr < min_duration: - # If still too short, try to compensate from the other side - if new_start == 0: - new_end = min(wav.size(-1), end + samples_needed + context_samples) - elif new_end == wav.size(-1): - new_start = max(0, start - samples_needed - context_samples) + # If there weren't enough samples on the right, extend more on the left side + extend_left = samples_needed - (new_end - end) + new_start = max(0, start - extend_left) debug_print( f"Extended segment from {current_duration:.2f}s to {(new_end - new_start) / sr:.2f}s", @@ -1706,6 +1707,7 @@ def save_audio_segment( sas_sentence, sas_audio_file_name_without_ext, sas_segment_idx, + sas_sentence_idx, sas_speaker_name, sas_audio_folder, sas_metadata, @@ -1717,14 +1719,14 @@ def save_audio_segment( """Helper function to save audio segments and update metadata""" sas_sentence = sas_sentence.strip() sas_sentence = multilingual_cleaners(sas_sentence, sas_target_language) - sas_audio_file_name = f"{sas_audio_file_name_without_ext}_{str(sas_segment_idx).zfill(8)}.wav" + sas_audio_file_name = f"{sas_audio_file_name_without_ext}_{str(sas_segment_idx).zfill(8)}_{str(sas_sentence_idx).zfill(8)}.wav" sas_absolute_path = os.path.join(sas_audio_folder, sas_audio_file_name) os.makedirs(os.path.dirname(sas_absolute_path), exist_ok=True) # Extract audio segment sas_audio_start = int(sas_sr * sas_start_time) - sas_audio_end = int(sas_sr * sas_end_time) + sas_audio_end = int(sas_sr * (sas_end_time + _sas_buffer)) sas_audio_segment = sas_audio[sas_audio_start:sas_audio_end].unsqueeze(0) # Handle long audio segments @@ -1732,33 +1734,17 @@ def save_audio_segment( sas_too_long_files.append( (sas_audio_file_name, sas_audio_segment.size(-1) / sas_sr)) - while sas_audio_segment.size(-1) > sas_max_duration * sas_sr: - sas_split_audio = sas_audio_segment[:, : int( - sas_max_duration * sas_sr)] - sas_audio_segment = sas_audio_segment[:, int( - sas_max_duration * sas_sr):] - sas_split_file_name = f"{sas_audio_file_name_without_ext}_{str(sas_segment_idx).zfill(8)}.wav" - sas_split_relative_path = os.path.join(sas_split_file_name) - sas_split_absolute_path = os.path.normpath( - os.path.join(sas_audio_folder, sas_split_relative_path)) - - os.makedirs( - os.path.dirname(sas_split_absolute_path), - exist_ok=True) - torchaudio.save(sas_split_absolute_path, sas_split_audio, sas_sr) - - sas_metadata["audio_file"].append( - f"wavs/{sas_split_relative_path}") - sas_metadata["text"].append(sas_sentence) - sas_metadata["speaker_name"].append(sas_speaker_name) - sas_segment_idx += 1 - - # Only save if segment is at least 1 second - if sas_audio_segment.size(-1) >= sas_sr: - torchaudio.save(sas_absolute_path, sas_audio_segment, sas_sr) - sas_metadata["audio_file"].append(f"wavs/{sas_audio_file_name}") - sas_metadata["text"].append(sas_sentence) - sas_metadata["speaker_name"].append(sas_speaker_name) + # Skip if audio is too short + if sas_audio_segment.size(-1) < sas_sr: + debug_print( + f"Skipping short audio segment: {sas_audio_file_name}", + "SEGMENTS") + return + + torchaudio.save(sas_absolute_path, sas_audio_segment, sas_sr) + sas_metadata["audio_file"].append(f"wavs/{sas_audio_file_name}") + sas_metadata["text"].append(sas_sentence) + sas_metadata["speaker_name"].append(sas_speaker_name) def process_transcription_result( @@ -1780,6 +1766,7 @@ def process_transcription_result( """Helper function to process transcription results and save audio segments""" ptr_i = ptr_segment_idx + 1 ptr_sentence = "" + ptr_sentence_idx = 0 ptr_sentence_start = None ptr_first_word = True ptr_current_words = [] @@ -1801,15 +1788,6 @@ def process_transcription_result( if ptr_first_word: ptr_sentence_start = ptr_start_time - if len(ptr_current_words) == 0: - ptr_sentence_start = max( - ptr_sentence_start - ptr_buffer, 0) - else: - ptr_previous_end = ptr_current_words[-1].get( - "end", 0) if ptr_current_words else 0 - ptr_sentence_start = max( - ptr_sentence_start - ptr_buffer, - (ptr_previous_end + ptr_start_time) / 2) ptr_sentence = ptr_word ptr_first_word = False else: @@ -1829,6 +1807,7 @@ def process_transcription_result( ptr_sentence, ptr_audio_file_name_without_ext, ptr_i, + ptr_sentence_idx, ptr_speaker_name, ptr_audio_folder, ptr_metadata, @@ -1837,13 +1816,13 @@ def process_transcription_result( ptr_too_long_files, ptr_target_language, ) - ptr_i += 1 + ptr_sentence_idx += 1 ptr_first_word = True ptr_current_words = [] ptr_sentence = "" -def process_audio_with_vad(wav, sr, vad_model, get_speech_timestamps): +def process_audio_with_vad(wav, sr, vad_model, get_speech_timestamps, max_duration=float("inf")): """ Enhanced VAD processing with better end-of-speech detection """ @@ -1859,11 +1838,12 @@ def process_audio_with_vad(wav, sr, vad_model, get_speech_timestamps): vad_model, sampling_rate=16000, threshold=0.2, # Lower threshold to be more sensitive to speech + neg_threshold=0.001, # Negative threshold needs to be set explicitly due to a bug in Silero VAD min_speech_duration_ms=200, # Shorter to catch brief utterances - max_speech_duration_s=float("inf"), + max_speech_duration_s= max_duration, min_silence_duration_ms=300, # Shorter silence duration - window_size_samples=1024, # Smaller window for more precise detection - speech_pad_ms=300, # Add padding to end of speech segments + #window_size_samples=1024, # Smaller window for more precise detection # DEPRECATED: does nothing + speech_pad_ms=100, # Add padding to end of speech segments ) # Scale timestamps back to original sample rate @@ -1871,11 +1851,11 @@ def process_audio_with_vad(wav, sr, vad_model, get_speech_timestamps): for segment in vad_segments: segment["start"] = int(segment["start"] * scale_factor) # Add extra padding at the end - segment["end"] = int(segment["end"] * scale_factor) + \ - int(0.2 * sr) # Add 200ms padding + segment["end"] = int(segment["end"] * scale_factor) #+ \ + # int(0.2 * sr) # Add 200ms padding merged_segments = merge_short_segments( - vad_segments, min_duration=6.0, max_gap=0.5) + vad_segments, min_duration=6.0, sr=sr, max_gap=0.5) debug_print( f"VAD processing: {len(vad_segments)} original segments, {len(merged_segments)} after merging",