This repository has been archived by the owner on Dec 7, 2023. It is now read-only.
forked from pytorch/audio
-
Notifications
You must be signed in to change notification settings - Fork 1
/
cmuarctic_test.py
77 lines (65 loc) · 2.51 KB
/
cmuarctic_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
from pathlib import Path
from torchaudio.datasets import cmuarctic
from torchaudio_unittest.common_utils import get_whitenoise, normalize_wav, save_wav, TempDirMixin, TorchaudioTestCase
def get_mock_dataset(root_dir):
    """Write a small mock CMU ARCTIC speaker directory and return its samples.

    The layout created under ``root_dir`` is::

        ARCTIC/cmu_us_aew_arctic/etc/txt.done.data
        ARCTIC/cmu_us_aew_arctic/wav/arctic_{a,b}{0000..0004}.wav

    Args:
        root_dir: directory in which the mocked dataset is created.

    Returns:
        A list with one tuple per utterance, in the order the utterances were
        written: ``(normalize_wav(data), sample_rate, transcript, short_id)``,
        where ``short_id`` is the utterance id without its ``arctic_`` prefix
        (e.g. ``"a0000"``).
    """
    sample_rate = 16000
    transcript = "This is a test transcript."

    speaker_dir = os.path.join(root_dir, "ARCTIC", "cmu_us_aew_arctic")
    etc_dir = os.path.join(speaker_dir, "etc")
    wav_dir = os.path.join(speaker_dir, "wav")
    for directory in (etc_dir, wav_dir):
        os.makedirs(directory, exist_ok=True)

    # Ten utterances: arctic_a0000..arctic_a0004 followed by arctic_b0000..arctic_b0004.
    utterance_ids = [f"arctic_{prefix}{index:04d}" for prefix in ("a", "b") for index in range(5)]

    expected = []
    with open(os.path.join(etc_dir, "txt.done.data"), "w") as transcript_file:
        # Each utterance gets its own seed (42, 43, ...) so that the generated
        # clips are requested with different seeds.
        for seed, utterance_id in enumerate(utterance_ids, start=42):
            waveform = get_whitenoise(
                sample_rate=sample_rate,
                duration=3,
                n_channels=1,
                dtype="int16",
                seed=seed,
            )
            save_wav(os.path.join(wav_dir, f"{utterance_id}.wav"), waveform, sample_rate)
            short_id = utterance_id.split("_")[1]
            expected.append((normalize_wav(waveform), sample_rate, transcript, short_id))
            transcript_file.write(f'( {utterance_id} "{transcript}" )\n')
    return expected
class TestCMUARCTIC(TempDirMixin, TorchaudioTestCase):
    """Check that ``cmuarctic.CMUARCTIC`` yields the samples of the mock dataset."""

    # NOTE(review): presumably consumed by TorchaudioTestCase to select the
    # audio backend -- confirm against the base class.
    backend = "default"

    # Both are placeholders that setUpClass overwrites before any test runs.
    root_dir = None
    samples = []

    @classmethod
    def setUpClass(cls):
        """Create the mock dataset once and keep its expected samples on the class."""
        cls.root_dir = cls.get_base_temp_dir()
        cls.samples = get_mock_dataset(cls.root_dir)

    def _test_cmuarctic(self, dataset):
        """Compare every item of ``dataset`` with the expected sample at the same index.

        Also checks that the dataset produced exactly as many items as there
        are expected samples.
        """
        seen = 0
        for index, item in enumerate(dataset):
            waveform, sample_rate, transcript, utterance_id = item
            want_waveform, want_rate, want_transcript, want_id = self.samples[index]
            assert sample_rate == want_rate
            assert transcript == want_transcript
            assert utterance_id == want_id
            self.assertEqual(want_waveform, waveform, atol=5e-5, rtol=1e-8)
            seen += 1
        assert seen == len(self.samples)

    def test_cmuarctic_str(self):
        """The dataset root is passed as the value returned by ``get_base_temp_dir``."""
        self._test_cmuarctic(cmuarctic.CMUARCTIC(self.root_dir))

    def test_cmuarctic_path(self):
        """The dataset root is passed wrapped in a ``pathlib.Path``."""
        root = Path(self.root_dir)
        self._test_cmuarctic(cmuarctic.CMUARCTIC(root))