This repository has been archived by the owner on Dec 7, 2023. It is now read-only.
forked from pytorch/audio
-
Notifications
You must be signed in to change notification settings - Fork 1
/
ljspeech_test.py
77 lines (64 loc) · 2.87 KB
/
ljspeech_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import csv
import os
from pathlib import Path
from torchaudio.datasets import ljspeech
from torchaudio_unittest.common_utils import get_whitenoise, normalize_wav, save_wav, TempDirMixin, TorchaudioTestCase
# Raw transcripts written to the second column of the mocked metadata.csv.
# get_mock_dataset zips this list with _NORMALIZED_TRANSCRIPT, so the two lists
# are paired index-by-index and must stay the same length and order.
_TRANSCRIPTS = [
    "Test transcript 1",
    "Test transcript 2",
    "Test transcript 3",
    # Numbers are written as digits here; the matching _NORMALIZED_TRANSCRIPT
    # entry spells them out. The trailing comma is part of the string.
    "In 1465 Sweynheim and Pannartz began printing in the monastery of Subiaco near Rome,",
]
# Normalized transcripts written to the third column of the mocked metadata.csv:
# the same text as _TRANSCRIPTS with every number spelled out in words.
_NORMALIZED_TRANSCRIPT = [
    "Test transcript one",
    "Test transcript two",
    "Test transcript three",
    "In fourteen sixty-five Sweynheim and Pannartz began printing in the monastery of Subiaco near Rome,",
]
def get_mock_dataset(root_dir):
    """Write a small fake LJSpeech-1.1 corpus under ``root_dir``.

    Layout produced::

        <root_dir>/LJSpeech-1.1/metadata.csv
        <root_dir>/LJSpeech-1.1/wavs/LJ001-0000.wav, LJ001-0001.wav, ...

    ``metadata.csv`` receives one ``|``-separated, unquoted row per transcript,
    ``file id | transcript | normalized transcript``. Each row has a matching
    1-second, single-channel, 22050 Hz white-noise WAV. The noise is generated
    as int16, seeded with the row index.

    Args:
        root_dir: path to the mocked dataset (missing directories are created).

    Returns:
        tuple: ``(mocked_data, _TRANSCRIPTS, _NORMALIZED_TRANSCRIPT)``.
        ``mocked_data[i]`` is ``normalize_wav`` applied to the samples saved
        for row ``i``. The transcript lists are the module-level objects
        themselves, not copies.
    """
    corpus_dir = os.path.join(root_dir, "LJSpeech-1.1")
    wav_dir = os.path.join(corpus_dir, "wavs")
    os.makedirs(wav_dir, exist_ok=True)
    csv_path = os.path.join(corpus_dir, "metadata.csv")

    rate = 22050
    waveforms = []
    with open(csv_path, mode="w", newline="") as csv_file:
        # QUOTE_NONE writes fields verbatim. With no escapechar, csv raises if
        # a field ever contains the "|" delimiter.
        writer = csv.writer(csv_file, delimiter="|", quoting=csv.QUOTE_NONE)
        rows = zip(_TRANSCRIPTS, _NORMALIZED_TRANSCRIPT)
        for seed, (raw_text, norm_text) in enumerate(rows):
            file_id = f"LJ001-{seed:04d}"
            writer.writerow([file_id, raw_text, norm_text])
            noise = get_whitenoise(sample_rate=rate, duration=1, n_channels=1, dtype="int16", seed=seed)
            save_wav(os.path.join(wav_dir, f"{file_id}.wav"), noise, rate)
            waveforms.append(normalize_wav(noise))
    return waveforms, _TRANSCRIPTS, _NORMALIZED_TRANSCRIPT
class TestLJSpeech(TempDirMixin, TorchaudioTestCase):
    """Check ``ljspeech.LJSPEECH`` against a mocked on-disk corpus.

    ``setUpClass`` builds the corpus once with ``get_mock_dataset``. Each test
    then loads it with a different form of the ``root`` argument and compares
    every yielded item with what the mock recorded.
    """

    # NOTE(review): presumably consumed by the test base classes to pick the
    # audio backend; confirm.
    backend = "default"
    root_dir = None
    data, _transcripts, _normalized_transcript = [], [], []
    # Must equal the sample rate get_mock_dataset uses when writing the WAV files.
    _expected_sample_rate = 22050

    @classmethod
    def setUpClass(cls):
        # Run any base-class class-level setup before building the fixture.
        super().setUpClass()
        cls.root_dir = cls.get_base_temp_dir()
        cls.data, cls._transcripts, cls._normalized_transcript = get_mock_dataset(cls.root_dir)

    def _test_ljspeech(self, dataset):
        """Assert that iterating ``dataset`` yields exactly the mocked items, in order.

        Each item is unpacked as
        ``(waveform, sample_rate, transcript, normalized_transcript)``.
        """
        n_ite = 0
        for i, (waveform, sample_rate, transcript, normalized_transcript) in enumerate(dataset):
            expected_transcript = self._transcripts[i]
            expected_normalized_transcript = self._normalized_transcript[i]
            expected_data = self.data[i]
            self.assertEqual(expected_data, waveform, atol=5e-5, rtol=1e-8)
            # Compare against the known rate. The former check,
            # ``sample_rate == sample_rate``, compared the value with itself
            # and could never fail.
            assert sample_rate == self._expected_sample_rate
            assert transcript == expected_transcript
            assert normalized_transcript == expected_normalized_transcript
            n_ite += 1
        # Catches a dataset that yields fewer items than were mocked. Extra
        # items would already fail on the index lookups above.
        assert n_ite == len(self.data)

    def test_ljspeech_str(self):
        """Load the corpus with ``root_dir`` passed exactly as ``get_base_temp_dir`` returned it."""
        dataset = ljspeech.LJSPEECH(self.root_dir)
        self._test_ljspeech(dataset)

    def test_ljspeech_path(self):
        """Load the corpus with ``root_dir`` wrapped in ``pathlib.Path``."""
        dataset = ljspeech.LJSPEECH(Path(self.root_dir))
        self._test_ljspeech(dataset)