Merge pull request #10 from mmcauliffe/autovot-support
Autovot support
mmcauliffe authored Nov 15, 2018
2 parents 6e8c5dc + 91992e2 commit 3664ad8
Showing 8 changed files with 151 additions and 2 deletions.
87 changes: 87 additions & 0 deletions conch/analysis/autovot.py
@@ -0,0 +1,87 @@
from .functions import BaseAnalysisFunction
import wave
import subprocess
import textgrid
import os
import tempfile


def is_autovot_friendly_file(sound_file):
    rate = subprocess.run(["soxi", "-r", sound_file], encoding="UTF-8", stdout=subprocess.PIPE).stdout
    if int(rate) != 16000:
        return False

    channels = subprocess.run(["soxi", "-c", sound_file], encoding="UTF-8", stdout=subprocess.PIPE).stdout
    if int(channels) != 1:
        return False
    return True


def resample_for_autovot(soundfile, tmpdir):
    output_file = os.path.join(tmpdir, "sound_file.wav")
    subprocess.call(["sox", soundfile, "-c", "1", "-r", "16000", output_file])
    return output_file


class MeasureVOTPretrained(object):
    def __init__(self, classifier_to_use=None, min_vot_length=15, max_vot_length=250, window_max=30, window_min=30, debug=False):
        if classifier_to_use is None:
            raise ValueError("There must be a classifier to run AutoVOT")
        else:
            self.classifier_to_use = classifier_to_use
        self.min_vot_length = min_vot_length
        self.max_vot_length = max_vot_length
        self.debug = debug
        self.window_max = window_max
        self.window_min = window_min

    def __call__(self, segment):
        file_path = os.path.expanduser(segment["file_path"])
        begin = segment["begin"]
        end = segment["end"]
        vot_marks = sorted(segment["vot_marks"], key=lambda x: x[0])
        grid = textgrid.TextGrid(maxTime=end)
        vot_tier = textgrid.IntervalTier(name='vot', maxTime=end)
        for vot_begin, vot_end, *extra_data in vot_marks:
            vot_tier.add(vot_begin, vot_end, 'vot')
        grid.append(vot_tier)
        with tempfile.TemporaryDirectory() as tmpdirname:
            grid_path = "{}/file.TextGrid".format(tmpdirname)
            csv_path = "{}/file.csv".format(tmpdirname)
            wav_filenames = "{}/wavs.txt".format(tmpdirname)
            textgrid_filenames = "{}/textgrids.txt".format(tmpdirname)

            if not is_autovot_friendly_file(file_path):
                file_path = resample_for_autovot(file_path, tmpdirname)

            with open(wav_filenames, 'w') as f:
                f.write("{}\n".format(file_path))

            with open(textgrid_filenames, 'w') as f:
                f.write("{}\n".format(grid_path))

            grid.write(grid_path)

            if self.debug:
                grid.write('/tmp/textgrid_from_conch.csv')
                with open('/tmp/alt_wordlist.txt', 'w') as f:
                    f.write("{}\n".format('/tmp/textgrid_from_conch.csv'))
                subprocess.run(["auto_vot_decode.py", wav_filenames, '/tmp/alt_wordlist.txt', self.classifier_to_use, '--vot_tier', 'vot', '--vot_mark', 'vot', "--min_vot_length", str(self.min_vot_length), "--max_vot_length", str(self.max_vot_length), "--window_max", str(self.window_max), "--window_min", str(self.window_min)])
            subprocess.run(["auto_vot_decode.py", wav_filenames, textgrid_filenames, self.classifier_to_use, '--vot_tier', 'vot', '--vot_mark', 'vot', '--csv_file', csv_path, "--min_vot_length", str(self.min_vot_length), "--max_vot_length", str(self.max_vot_length), "--window_max", str(self.window_max), "--window_min", str(self.window_min)])

            return_list = []
            with open(csv_path, "r") as f:
                f.readline()
                for l, (b, e, *extra_data) in zip(f, vot_marks):
                    _, time, vot, confidence = l.split(',')
                    if "neg 0\n" == confidence:
                        confidence = 0
                    return_list.append((float(time), float(vot), float(confidence), *extra_data))
        return return_list

class AutoVOTAnalysisFunction(BaseAnalysisFunction):
    def __init__(self, classifier_to_use=None, min_vot_length=15, max_vot_length=250, window_max=30, window_min=30, debug=False, arguments=None):
        super(AutoVOTAnalysisFunction, self).__init__()
        self._function = MeasureVOTPretrained(classifier_to_use=classifier_to_use, min_vot_length=min_vot_length, max_vot_length=max_vot_length, window_max=window_max, window_min=window_min, debug=debug)
        self.requires_file = True
        self.uses_segments = True
        self.requires_segment_as_arg = True
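For context, a minimal usage sketch of the new AutoVOTAnalysisFunction, modeled on the test added in this pull request; the wav path, classifier path, segment duration, and VOT marks below are placeholders:

from conch import analyze_segments
from conch.analysis.autovot import AutoVOTAnalysisFunction
from conch.analysis.segments import SegmentMapping

# Placeholder inputs; see tests/test_analysis_autovot.py in this commit for the real fixtures
wav_path = '/path/to/acoustic_corpus.wav'
classifier = '/path/to/sotc_voiceless.classifier'
vot_marks = [(1.50346, 1.65870), (1.85687, 1.90566)]  # (begin, end) search windows in seconds

mapping = SegmentMapping()
mapping.add_file_segment(wav_path, 0, 26.0, channel=0, vot_marks=vot_marks)
func = AutoVOTAnalysisFunction(classifier_to_use=classifier, window_min=-30, window_max=30,
                               min_vot_length=5, max_vot_length=100)
output = analyze_segments(mapping, func, multiprocessing=False)
# output[mapping[0]] is a list of (time, vot_duration, confidence) tuples, one per VOT mark

Note that auto_vot_decode.py, sox, and soxi need to be on the PATH: the decoder is invoked as a subprocess, and any file that is not 16 kHz mono is resampled with sox first.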
3 changes: 3 additions & 0 deletions conch/analysis/functions.py
@@ -15,6 +15,7 @@ class BaseAnalysisFunction(object):
    def __init__(self):
        self._function = print
        self.requires_file = False
        self.requires_segment_as_arg = False
        self.uses_segments = False
        self.arguments = []

@@ -39,6 +40,8 @@ def __call__(self, segment):
        elif isinstance(segment, str) and not self.requires_file:
            signal, sr = librosa.load(safe_path(segment))
            return self._function(signal, sr, *self.arguments)
        elif isinstance(segment, FileSegment) and self.requires_segment_as_arg:
            return self._function(segment, *self.arguments)
        elif isinstance(segment, FileSegment) and self.requires_file and not self.uses_segments:
            beg, end = segment.begin, segment.end
            padding = segment['padding']
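The new requires_segment_as_arg flag makes BaseAnalysisFunction hand the FileSegment itself to the wrapped function (so it can read file_path, begin, end, and extra properties such as vot_marks) rather than a loaded signal. As a rough sketch, a custom function could opt in the same way AutoVOTAnalysisFunction does above; the class here is hypothetical, not part of this commit:

from conch.analysis.functions import BaseAnalysisFunction

class SegmentSpanFunction(BaseAnalysisFunction):
    # Hypothetical example: measures the duration of each segment it is handed
    def __init__(self):
        super(SegmentSpanFunction, self).__init__()
        self._function = lambda segment: segment['end'] - segment['begin']
        self.requires_file = True
        self.uses_segments = True
        self.requires_segment_as_arg = True  # receive the FileSegment, not (signal, sr)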
2 changes: 2 additions & 0 deletions conch/analysis/segments.py
@@ -44,6 +44,8 @@ def __eq__(self, other):
            return False
        if self.channel != other.channel:
            return False
        if self.properties != other.properties:
            return False
        return True

    def __lt__(self, other):
21 changes: 19 additions & 2 deletions tests/conftest.py
@@ -37,6 +37,9 @@ def praat_script_test_dir(test_dir):
def soundfiles_dir(test_dir):
    return os.path.join(test_dir, 'soundfiles')

@pytest.fixture(scope='session')
def autovot_dir(test_dir):
    return os.path.join(test_dir, 'autovot')

@pytest.fixture(scope='session')
def tts_dir(test_dir):
@@ -62,7 +65,6 @@ def y_path(soundfiles_dir):
def acoustic_corpus_path(soundfiles_dir):
    return os.path.join(soundfiles_dir, 'acoustic_corpus.wav')


@pytest.fixture(scope='session')
def call_back():
    def function(*args):
@@ -87,6 +89,22 @@ def base_filenames(soundfiles_dir):
                 if x.endswith('.wav')]
    return filenames

@pytest.fixture(scope='session')
def autovot_markings(test_dir):
    vot_markings = []
    with open(os.path.join(test_dir, "vot_marks"), "r") as f:
        for x in f:
            vots = x.split(' ')
            vot_markings.append((float(vots[0]), float(vots[1])))
    return vot_markings

@pytest.fixture(scope='session')
def classifier_path(test_dir):
    return os.path.join(test_dir, "vot_model", "sotc_voiceless.classifier")

@pytest.fixture(scope='session')
def autovot_correct_times():
    return [(1.593, 0.056, 180.344), (1.828, 0.008, 126.073), (1.909, 0.071, 90.8671), (2.041, 0.005, 45.6481), (2.687, 0.016, 212.67), (2.859, 0.005, 22.646), (2.951, 0.005, 78.2495), (3.351, 0.052, 84.7406), (5.574, 0.02, 96.0191), (6.212, 0.01, 72.1773), (6.736, 0.02, 114.721), (7.02, 0.029, 224.901), (9.255, 0.032, 123.367), (9.498, 0.017, 92.7151), (11.424, 0.056, 85.1062), (13.144, 0.012, 191.111), (13.55, 0.012, 59.8446), (25.125, 0.014, 165.632)]

@pytest.fixture(scope='session')
def praatpath():
@@ -114,7 +132,6 @@ def formants_func():
                                window_length=0.025)
    return func


@pytest.fixture(scope='session')
def pitch_func():
    func = PitchTrackFunction(min_pitch=50, max_pitch=500, time_step=0.01)
18 changes: 18 additions & 0 deletions tests/data/vot_marks
@@ -0,0 +1,18 @@
1.50346 1.65870
1.85687 1.90566
1.90566 1.98664
2.06736 2.14425
2.65964 2.70424
2.78255 2.85937
2.93352 2.95891
3.32791 3.38926
5.53054 5.60142
6.18190 6.23417
6.76380 6.79755
6.89678 7.06540
9.21705 9.32077
9.44277 9.50559
11.44430 11.50088
13.10701 13.22000
13.51863 13.55111
25.09728 25.14633
1 change: 1 addition & 0 deletions tests/data/vot_model/sotc_voiceless.classifier.neg
@@ -0,0 +1 @@
59 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
1 change: 1 addition & 0 deletions tests/data/vot_model/sotc_voiceless.classifier.pos
@@ -0,0 +1 @@
77 -5.19517 0.974839 8.37886 4.43399 -0.291111 -1.03706 -0.851605 -2.99899 -0.988911 0.238412 -6.02394 7.69937 1.91178 12.1159 17.234 3.06435 -4.58158 -6.52092 0.704001 -9.8874 -4.83043 -1.23498 4.27573 -2.88435 -4.91017 -3.34528 -1.34791 11.5056 28.1743 1.4589 0.965129 -11.9287 0.431235 0.497913 -4.02346 -7.29419 -1.02066 12.9648 0.510423 0.892902 -2.05145 -1.73262 -8.13197 14.6691 5.57997 0.225761 5.1901 -7.44362 21.9976 0.278337 1.93388 -36.4286 10.1835 -3.78564 -4.12507 -17.2654 -1.25228 -18.5517 -51.1295 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
20 changes: 20 additions & 0 deletions tests/test_analysis_autovot.py
@@ -0,0 +1,20 @@
from conch.analysis.autovot import AutoVOTAnalysisFunction
import librosa
from statistics import mean
import wave
import pytest
from conch.analysis.segments import SegmentMapping
from conch import analyze_segments


def test_autovot(acoustic_corpus_path, autovot_markings, classifier_path, autovot_correct_times):
    mapping = SegmentMapping()
    with wave.open(acoustic_corpus_path, 'r') as f:
        length = f.getnframes() / float(f.getframerate())
    mapping.add_file_segment(acoustic_corpus_path, 0, length, channel=0, vot_marks=autovot_markings)
    func = AutoVOTAnalysisFunction(classifier_to_use=classifier_path, window_min=-30, window_max=30, min_vot_length=5, max_vot_length=100)
    output = analyze_segments(mapping, func, multiprocessing=False)
    output = output[mapping[0]]
    for o, truth in zip(output, autovot_correct_times):
        assert o == truth
