Skip to content

Commit

Permalink
added mock testing and local model
Browse files Browse the repository at this point in the history
  • Loading branch information
jim-gyas committed Jan 3, 2025
1 parent 0b51e2a commit dfca747
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 32 deletions.
52 changes: 45 additions & 7 deletions src/stt_data_with_llm/audio_parser.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import io
import json
import logging
import os

import librosa
import requests
import torchaudio
from dotenv import load_dotenv
from pyannote.audio import Pipeline
from pydub import AudioSegment

Expand All @@ -13,10 +15,13 @@
AUDIO_SEG_UPPER_LIMIT,
HEADERS,
HYPER_PARAMETERS,
USE_AUTH_TOKEN,
)
from stt_data_with_llm.util import setup_logging

# Load environment variables (e.g. from a local .env file) into the process.
load_dotenv()

# Token handed to Pipeline.from_pretrained as use_auth_token; os.getenv
# yields None when the "use_auth_token" variable is not set.
USE_AUTH_TOKEN = os.getenv("use_auth_token")
# Call the setup_logging function at the beginning of your script
setup_logging("audio_parser.log")

Expand Down Expand Up @@ -62,15 +67,21 @@ def sec_to_frame(sec, sr):
def initialize_vad_pipeline():
    """
    Initializes the Voice Activity Detection (VAD) pipeline using Pyannote.

    The hosted "pyannote/voice-activity-detection" model is tried first,
    authenticated with USE_AUTH_TOKEN. If that load fails, the model stored
    under "tests/data/pyannote_vad_model" is loaded instead. The chosen
    pipeline is then instantiated with HYPER_PARAMETERS.

    Returns:
        Pipeline: Initialized VAD pipeline
    """
    logging.info("Initializing Voice Activity Detection pipeline...")
    try:
        vad_pipeline = Pipeline.from_pretrained(
            "pyannote/voice-activity-detection",
            use_auth_token=USE_AUTH_TOKEN,
        )
        # NOTE(review): some pyannote.audio releases appear to return None
        # (instead of raising) when the model cannot be downloaded; treat
        # that as a failure too so the fallback below is used.
        if vad_pipeline is None:
            raise RuntimeError("Pipeline.from_pretrained returned None")
    except Exception as e:
        # Deliberate best-effort: any failure of the online load switches to
        # the local copy. The path is relative to the working directory.
        logging.warning("Failed to load online model: %s. Using local model.", e)
        vad_pipeline = Pipeline.from_pretrained(
            "tests/data/pyannote_vad_model",
            use_auth_token=False,
        )
    vad_pipeline.instantiate(HYPER_PARAMETERS)
    logging.info("VAD pipeline initialized successfully.")
    return vad_pipeline
Expand Down Expand Up @@ -287,6 +298,26 @@ def process_non_mute_segments(
return counter


def generate_vad_output(audio_file, output_json):
    """Generate VAD output for a given audio file and save it to a JSON file.

    A VAD pipeline is initialized, run on ``audio_file``, and the result is
    written as ``{"timeline": [{"start": ..., "end": ...}, ...]}`` with one
    entry per segment yielded by ``vad.get_timeline().support()``.

    Args:
        audio_file: Audio input handed directly to the VAD pipeline
            (get_split_audio passes its temporary audio file here).
        output_json: Path of the JSON file to write. Missing parent
            directories are created.
    """
    pipeline = initialize_vad_pipeline()
    vad = pipeline(audio_file)
    vad_output = {
        "timeline": [
            {"start": segment.start, "end": segment.end}
            for segment in vad.get_timeline().support()
        ]
    }

    # Make sure the destination directory exists so the function does not
    # depend on the caller having created it beforehand.
    parent_dir = os.path.dirname(output_json)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(output_json, "w", encoding="utf-8") as file:
        json.dump(vad_output, file, ensure_ascii=False, indent=2)


def get_split_audio(
audio_data,
full_audio_id,
Expand Down Expand Up @@ -315,10 +346,17 @@ def get_split_audio(

if not os.path.exists(output_folder):
os.makedirs(output_folder)

vad_output_folder = "tests/data/vad_output"
if not os.path.exists(vad_output_folder):
os.makedirs(vad_output_folder)

# initialize vad pipeline
pipeline = initialize_vad_pipeline()
vad = pipeline(temp_audio_file)

generate_vad_output(
temp_audio_file, f"{vad_output_folder}/{full_audio_id}_vad_output.json"
)
original_audio_segment = AudioSegment.from_file(temp_audio_file)
original_audio_ndarray, sampling_rate = torchaudio.load(temp_audio_file)
original_audio_ndarray = original_audio_ndarray[0]
Expand Down
2 changes: 0 additions & 2 deletions src/stt_data_with_llm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,3 @@
"upgrade-insecure-requests": "1",
"user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36 Edg/129.0.0.0", # noqa: E501
}

# Credentials must not be committed to source control: read the token from
# the "use_auth_token" environment variable instead (None when it is unset).
# Any token that was previously hard-coded here should be revoked.
import os  # imported here because this file's import block is not visible

USE_AUTH_TOKEN = os.getenv("use_auth_token")
106 changes: 83 additions & 23 deletions tests/test_audio_parser.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,92 @@
import json
import logging
from unittest import TestCase, mock

from stt_data_with_llm.audio_parser import get_audio, get_split_audio
from stt_data_with_llm.config import AUDIO_SEG_LOWER_LIMIT, AUDIO_SEG_UPPER_LIMIT


def test_get_split_audio():
    """
    Split each sample recording and compare the per-recording segment
    counts with the expected counts stored in
    tests/data/expected_audio_data.json.
    """
    sources = {
        "NW_001": "https://www.rfa.org/tibetan/sargyur/golok-china-religious-restriction-08202024054225.html/@@stream",  # noqa
        "NW_002": "https://vot.org/wp-content/uploads/2024/03/tc88888888888888.mp3",
        "NW_003": "https://voa-audio-ns.akamaized.net/vti/2024/04/13/01000000-0aff-0242-a7bb-08dc5bc45613.mp3",
    }

    # Fetch and split one recording at a time, keeping only the segment count.
    segment_counts = {
        audio_id: len(
            get_split_audio(
                get_audio(url),
                audio_id,
                AUDIO_SEG_LOWER_LIMIT,
                AUDIO_SEG_UPPER_LIMIT,
            )
        )
        for audio_id, url in sources.items()
    }

    with open("tests/data/expected_audio_data.json", encoding="utf-8") as handle:
        expected_counts = json.load(handle)
    assert segment_counts == expected_counts
class TestGetSplitAudio(TestCase):
    """Checks get_split_audio segment counts against recorded VAD timelines."""

    @mock.patch("stt_data_with_llm.audio_parser.initialize_vad_pipeline")
    @mock.patch("stt_data_with_llm.audio_parser.Pipeline")
    def test_get_split_audio(self, mock_pipeline, mock_initialize_vad):
        """
        Test function for the get_split_audio functionality.

        initialize_vad_pipeline is patched to return a stand-in pipeline that
        replays the timeline stored in tests/data/vad_output for each audio
        id, so no pyannote model is loaded. get_audio is not patched and is
        still called with the real URLs.
        """
        # Recorded VAD output for each audio file
        vad_outputs = {
            "NW_001": "tests/data/vad_output/NW_001_vad_output.json",
            "NW_002": "tests/data/vad_output/NW_002_vad_output.json",
            "NW_003": "tests/data/vad_output/NW_003_vad_output.json",
        }
        mock_vad_results = {}
        for seg_id, vad_path in vad_outputs.items():
            with open(vad_path, encoding="utf-8") as file:
                mock_vad_results[seg_id] = json.load(file)

        class MockSegment:
            """Minimal segment exposing only ``start`` and ``end``."""

            def __init__(self, start, end):
                self.start = start
                self.end = end

        class MockTimeline:
            """Stand-in for the object returned by ``get_timeline()``."""

            def __init__(self, timeline):
                self.timeline = timeline

            def support(self):
                return [
                    MockSegment(seg["start"], seg["end"]) for seg in self.timeline
                ]

        class MockVADResult:
            """Stand-in for the value returned by calling the pipeline."""

            def __init__(self, seg_id):
                self.vad_output = mock_vad_results[seg_id]

            def get_timeline(self):
                return MockTimeline(self.vad_output["timeline"])

        class MockVADPipeline:
            """Callable stand-in for the VAD pipeline of one audio id."""

            def __init__(self, seg_id):
                self.seg_id = seg_id

            def __call__(self, audio_file):
                return MockVADResult(self.seg_id)

        audio_urls = {
            "NW_001": "https://www.rfa.org/tibetan/sargyur/golok-china-religious-restriction-08202024054225.html/@@stream",  # noqa
            "NW_002": "https://vot.org/wp-content/uploads/2024/03/tc88888888888888.mp3",
            "NW_003": "https://voa-audio-ns.akamaized.net/vti/2024/04/13/01000000-0aff-0242-a7bb-08dc5bc45613.mp3",
        }
        num_of_seg_in_audios = {}
        for seg_id, audio_url in audio_urls.items():
            # Every initialize_vad_pipeline() call made while this audio is
            # processed must replay the timeline recorded for this id.
            mock_initialize_vad.return_value = MockVADPipeline(seg_id)

            audio_data = get_audio(audio_url)
            split_audio_data = get_split_audio(
                audio_data, seg_id, AUDIO_SEG_LOWER_LIMIT, AUDIO_SEG_UPPER_LIMIT
            )
            num_of_seg_in_audios[seg_id] = len(split_audio_data)

        expected_num_of_seg_in_audios = "tests/data/expected_audio_data.json"
        with open(expected_num_of_seg_in_audios, encoding="utf-8") as file:
            expected_num_split = json.load(file)
        # assertEqual reports a readable diff and is not stripped under -O.
        self.assertEqual(num_of_seg_in_audios, expected_num_split)


if __name__ == "__main__":
    # Allow running this file directly, without a test runner.
    test_get_split_audio()
    # Called directly on an instance; the mock.patch decorators on the method
    # still apply and supply the two mock arguments.
    TestGetSplitAudio().test_get_split_audio()

0 comments on commit dfca747

Please sign in to comment.