Commit 06d7948
First draft of improving the dataset generation
Yohrog committed Nov 24, 2024
1 parent bb55313 commit 06d7948
Showing 1 changed file with 45 additions and 65 deletions.
110 changes: 45 additions & 65 deletions finetune.py
@@ -1164,11 +1164,11 @@ def format_audio_list(
debug_print("Processing with VAD", "AUDIO")
# Get VAD segments with resampling
vad_segments = process_audio_with_vad(
- wav, sr, vad_model, get_speech_timestamps)
+ wav, sr, vad_model, get_speech_timestamps, max_duration=max_duration)

# Group short segments that are close together
merged_segments = merge_short_segments(
- vad_segments, min_duration, max_gap=0.3)
+ vad_segments, min_duration, sr=sr, max_gap=0.3)
debug_print(
f"Merged {len(vad_segments)-len(merged_segments)} short segments",
"SEGMENTS")
@@ -1515,7 +1515,7 @@ def _create_bpe_tokenizer(bpe_whisper_words, bpe_out_path, bpe_base_path):
raise


- def merge_short_segments(segments, min_duration, max_gap=0.5):
+ def merge_short_segments(segments, min_duration, sr, max_gap=0.5):
"""
More aggressive merge strategy for short segments
- Increases max_gap to 0.5s (from 0.3s)
@@ -1527,7 +1527,7 @@ def merge_short_segments(segments, min_duration, max_gap=0.5):

merged = []
current_group = []
- target_duration = (min_duration + 10.0) / 2  # Target middle of range
+ target_duration = sr * (min_duration + 10.0) / 2  # Target middle of range, in samples

for i, segment in enumerate(segments):
current_duration = sum(s["end"] - s["start"]
@@ -1566,16 +1566,15 @@ def merge_short_segments(segments, min_duration, max_gap=0.5):
merged.append(merged_segment)

debug_print(
f"Merged {len(segments) - len(merged)} segments into {len(merged)} segments with mid-range preference",
f"Merged {len(segments) - len(merged)} segments with mid-range preference, for a new total of {len(merged)}",
"SEGMENTS"
)
return merged

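The sr factor added above matters because segment boundaries are sample indices, so durations derived from them come out in samples; a target expressed in seconds has to be converted before the comparison. A small sketch of the arithmetic (values illustrative):

    sr = 24_000
    min_duration = 6.0  # seconds
    segment = {"start": 12_000, "end": 150_000}  # sample indices

    duration = segment["end"] - segment["start"]      # 138_000 samples
    target_duration = sr * (min_duration + 10.0) / 2  # 192_000 samples, mid-range
    print(duration >= target_duration)                # False -> keep merging
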

- def extend_segment(wav, start, end, sr, min_duration, context_window=1.0):
+ def extend_segment(wav, start, end, sr, min_duration):
"""
Improved segment extension with better context handling
- - Adds context_window parameter for smoother extensions
- More balanced extension on both sides
- Checks audio content when extending
"""
@@ -1585,22 +1584,24 @@ def extend_segment(wav, start, end, sr, min_duration, context_window=1.0):

samples_needed = int((min_duration - current_duration) * sr)

+ # Check if we have enough samples in the file
+ if (samples_needed + current_duration * sr) > wav.size(-1):
+     # Pad the file symmetrically with zeroes
+     padding_amount = int(samples_needed + current_duration * sr - wav.size(-1))
+     padding = torch.zeros(padding_amount // 2 + 1, dtype=wav.dtype)
+     return torch.cat([padding, wav, padding], dim=-1)

# Try to extend equally on both sides
extend_left = samples_needed // 2
- extend_right = samples_needed - extend_left
+ new_start = max(0, start - extend_left)

- # Add some context window
- context_samples = int(context_window * sr)
- new_start = max(0, start - extend_left - context_samples)
- new_end = min(wav.size(-1), end + extend_right + context_samples)
+ # If there weren't enough samples on the left, extend more on the right side
+ extend_right = samples_needed - (start - new_start)
+ new_end = min(wav.size(-1), end + extend_right)

- # Check if we got enough duration
- if (new_end - new_start) / sr < min_duration:
-     # If still too short, try to compensate from the other side
-     if new_start == 0:
-         new_end = min(wav.size(-1), end + samples_needed + context_samples)
-     elif new_end == wav.size(-1):
-         new_start = max(0, start - samples_needed - context_samples)
+ # If there weren't enough samples on the right, extend more on the left side
+ extend_left = samples_needed - (new_end - end)
+ new_start = max(0, start - extend_left)

debug_print(
f"Extended segment from {current_duration:.2f}s to {(new_end - new_start) / sr:.2f}s",
@@ -1706,6 +1707,7 @@ def save_audio_segment(
sas_sentence,
sas_audio_file_name_without_ext,
sas_segment_idx,
+ sas_sentence_idx,
sas_speaker_name,
sas_audio_folder,
sas_metadata,
@@ -1717,48 +1719,32 @@ def save_audio_segment(
"""Helper function to save audio segments and update metadata"""
sas_sentence = sas_sentence.strip()
sas_sentence = multilingual_cleaners(sas_sentence, sas_target_language)
sas_audio_file_name = f"{sas_audio_file_name_without_ext}_{str(sas_segment_idx).zfill(8)}.wav"
sas_audio_file_name = f"{sas_audio_file_name_without_ext}_{str(sas_segment_idx).zfill(8)}_{str(sas_sentence_idx).zfill(8)}.wav"

sas_absolute_path = os.path.join(sas_audio_folder, sas_audio_file_name)
os.makedirs(os.path.dirname(sas_absolute_path), exist_ok=True)

# Extract audio segment
sas_audio_start = int(sas_sr * sas_start_time)
sas_audio_end = int(sas_sr * sas_end_time)
sas_audio_end = int(sas_sr * (sas_end_time + _sas_buffer))
sas_audio_segment = sas_audio[sas_audio_start:sas_audio_end].unsqueeze(0)

# Handle long audio segments
if sas_audio_segment.size(-1) > sas_max_duration * sas_sr:
sas_too_long_files.append(
(sas_audio_file_name, sas_audio_segment.size(-1) / sas_sr))

- while sas_audio_segment.size(-1) > sas_max_duration * sas_sr:
-     sas_split_audio = sas_audio_segment[:, : int(sas_max_duration * sas_sr)]
-     sas_audio_segment = sas_audio_segment[:, int(sas_max_duration * sas_sr):]
-     sas_split_file_name = f"{sas_audio_file_name_without_ext}_{str(sas_segment_idx).zfill(8)}.wav"
-     sas_split_relative_path = os.path.join(sas_split_file_name)
-     sas_split_absolute_path = os.path.normpath(os.path.join(sas_audio_folder, sas_split_relative_path))
-     os.makedirs(os.path.dirname(sas_split_absolute_path), exist_ok=True)
-     torchaudio.save(sas_split_absolute_path, sas_split_audio, sas_sr)
-     sas_metadata["audio_file"].append(f"wavs/{sas_split_relative_path}")
-     sas_metadata["text"].append(sas_sentence)
-     sas_metadata["speaker_name"].append(sas_speaker_name)
-     sas_segment_idx += 1

- # Only save if segment is at least 1 second
- if sas_audio_segment.size(-1) >= sas_sr:
-     torchaudio.save(sas_absolute_path, sas_audio_segment, sas_sr)
-     sas_metadata["audio_file"].append(f"wavs/{sas_audio_file_name}")
-     sas_metadata["text"].append(sas_sentence)
-     sas_metadata["speaker_name"].append(sas_speaker_name)
+ # Skip if audio is too short
+ if sas_audio_segment.size(-1) < sas_sr:
+     debug_print(f"Skipping short audio segment: {sas_audio_file_name}", "SEGMENTS")
+     return
+
+ torchaudio.save(sas_absolute_path, sas_audio_segment, sas_sr)
+ sas_metadata["audio_file"].append(f"wavs/{sas_audio_file_name}")
+ sas_metadata["text"].append(sas_sentence)
+ sas_metadata["speaker_name"].append(sas_speaker_name)

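sas_metadata accumulates three parallel lists, one entry per saved clip. A sketch of how a consumer might export it; the pandas call, the pipe delimiter, and the file name are assumptions, not shown in this diff:

    import pandas as pd  # assumption: not imported in this excerpt

    metadata = {"audio_file": [], "text": [], "speaker_name": []}
    metadata["audio_file"].append("wavs/sample_00000001_00000000.wav")
    metadata["text"].append("A cleaned example sentence.")
    metadata["speaker_name"].append("speaker_1")

    # Hypothetical export in a Coqui-style metadata layout
    pd.DataFrame(metadata).to_csv("metadata_train.csv", sep="|", index=False)
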

def process_transcription_result(
@@ -1780,6 +1766,7 @@ def process_transcription_result(
"""Helper function to process transcription results and save audio segments"""
ptr_i = ptr_segment_idx + 1
ptr_sentence = ""
+ ptr_sentence_idx = 0
ptr_sentence_start = None
ptr_first_word = True
ptr_current_words = []
@@ -1801,15 +1788,6 @@

if ptr_first_word:
ptr_sentence_start = ptr_start_time
- if len(ptr_current_words) == 0:
-     ptr_sentence_start = max(ptr_sentence_start - ptr_buffer, 0)
- else:
-     ptr_previous_end = ptr_current_words[-1].get("end", 0) if ptr_current_words else 0
-     ptr_sentence_start = max(ptr_sentence_start - ptr_buffer, (ptr_previous_end + ptr_start_time) / 2)
ptr_sentence = ptr_word
ptr_first_word = False
else:
@@ -1829,6 +1807,7 @@
ptr_sentence,
ptr_audio_file_name_without_ext,
ptr_i,
+ ptr_sentence_idx,
ptr_speaker_name,
ptr_audio_folder,
ptr_metadata,
Expand All @@ -1837,13 +1816,13 @@ def process_transcription_result(
ptr_too_long_files,
ptr_target_language,
)
- ptr_i += 1
+ ptr_sentence_idx += 1
ptr_first_word = True
ptr_current_words = []
ptr_sentence = ""


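With the two-level index introduced above, the segment number stays fixed while the per-segment sentence counter increments, so each file name records both which segment and which sentence within it produced the clip. A tiny sketch of the resulting names (f"{n:08d}" is equivalent to the str(n).zfill(8) used here):

    segment_idx, sentence_idx = 3, 0
    for sentence in ["First sentence.", "Second sentence."]:
        print(f"clip_{segment_idx:08d}_{sentence_idx:08d}.wav")
        sentence_idx += 1
    # clip_00000003_00000000.wav
    # clip_00000003_00000001.wav
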
- def process_audio_with_vad(wav, sr, vad_model, get_speech_timestamps):
+ def process_audio_with_vad(wav, sr, vad_model, get_speech_timestamps, max_duration=float("inf")):
"""
Enhanced VAD processing with better end-of-speech detection
"""
@@ -1859,23 +1838,24 @@
vad_model,
sampling_rate=16000,
threshold=0.2, # Lower threshold to be more sensitive to speech
+ neg_threshold=0.001,  # Negative threshold needs to be set explicitly due to a bug in Silero VAD
min_speech_duration_ms=200, # Shorter to catch brief utterances
- max_speech_duration_s=float("inf"),
+ max_speech_duration_s=max_duration,
min_silence_duration_ms=300, # Shorter silence duration
- window_size_samples=1024,  # Smaller window for more precise detection
- speech_pad_ms=300,  # Add padding to end of speech segments
+ # window_size_samples=1024,  # DEPRECATED in current Silero VAD: has no effect
+ speech_pad_ms=100,  # Add padding to end of speech segments
)

# Scale timestamps back to original sample rate
scale_factor = sr / 16000
for segment in vad_segments:
segment["start"] = int(segment["start"] * scale_factor)
- # Add extra padding at the end
- segment["end"] = int(segment["end"] * scale_factor) + int(0.2 * sr)  # Add 200ms padding
+ segment["end"] = int(segment["end"] * scale_factor)  # extra 200 ms end padding removed

merged_segments = merge_short_segments(
- vad_segments, min_duration=6.0, max_gap=0.5)
+ vad_segments, min_duration=6.0, sr=sr, max_gap=0.5)

debug_print(
f"VAD processing: {len(vad_segments)} original segments, {len(merged_segments)} after merging",

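Since Silero VAD runs on a 16 kHz copy of the audio, its sample indices have to be rescaled before slicing the original waveform. A small worked example of the scale_factor arithmetic used above (numbers illustrative):

    sr = 44_100
    scale_factor = sr / 16_000
    segment_16k = {"start": 8_000, "end": 40_000}     # 0.5 s .. 2.5 s at 16 kHz
    start = int(segment_16k["start"] * scale_factor)  # 22_050 -> 0.5 s at 44.1 kHz
    end = int(segment_16k["end"] * scale_factor)      # 110_250 -> 2.5 s
    print(start / sr, end / sr)                       # 0.5 2.5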