This repository has been archived by the owner on Dec 7, 2023. It is now read-only.
forked from pytorch/audio
-
Notifications
You must be signed in to change notification settings - Fork 1
/
gtzan_test.py
120 lines (101 loc) · 4.25 KB
/
gtzan_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
from pathlib import Path
from torchaudio.datasets import gtzan
from torchaudio_unittest.common_utils import get_whitenoise, normalize_wav, save_wav, TempDirMixin, TorchaudioTestCase
def get_mock_dataset(root_dir):
    """Write a mocked GTZAN dataset under ``root_dir`` and return its samples.

    For every genre in ``gtzan.gtzan_genres`` a ``genres/<genre>`` directory
    is created (if missing) and 100 white-noise WAV files named
    ``<genre>.<index:05d>.wav`` are saved into it. Each file is generated
    with its own seed; seeds increase by one per file across all genres.

    Args:
        root_dir: directory to the mocked dataset.

    Returns:
        A 4-tuple ``(all, training, validation, testing)``. Each element is a
        list of ``(normalize_wav(data), sample_rate, genre)`` tuples in
        creation order. The last three lists hold only the samples whose file
        stem appears in ``gtzan.filtered_train``, ``gtzan.filtered_valid``
        and ``gtzan.filtered_test`` respectively.
    """
    sample_rate = 22050
    files_per_genre = 100

    all_samples = []
    train_split = []
    valid_split = []
    test_split = []

    # (reference stems, destination list) pairs, checked in the same order
    # for every generated file.
    subset_table = (
        (gtzan.filtered_test, test_split),
        (gtzan.filtered_train, train_split),
        (gtzan.filtered_valid, valid_split),
    )

    for genre_index, genre in enumerate(gtzan.gtzan_genres):
        genre_dir = os.path.join(root_dir, "genres", genre)
        os.makedirs(genre_dir, exist_ok=True)

        for file_index in range(files_per_genre):
            stem = f"{genre}.{file_index:05d}"
            noise = get_whitenoise(
                sample_rate=sample_rate,
                duration=0.01,
                n_channels=1,
                dtype="int16",
                # Running file counter: 0, 1, 2, ... across all genres.
                seed=genre_index * files_per_genre + file_index,
            )
            save_wav(os.path.join(genre_dir, f"{stem}.wav"), noise, sample_rate)

            entry = (normalize_wav(noise), sample_rate, genre)
            all_samples.append(entry)
            for reference_stems, destination in subset_table:
                if stem in reference_stems:
                    destination.append(entry)

    return (all_samples, train_split, valid_split, test_split)
class TestGTZAN(TempDirMixin, TorchaudioTestCase):
    """Check that ``gtzan.GTZAN`` reads back the mocked dataset correctly.

    ``setUpClass`` writes the mocked dataset once into a temporary directory
    via ``get_mock_dataset``; each test then builds a ``gtzan.GTZAN`` over
    that directory (with the root given as ``str`` or ``pathlib.Path``) and
    compares what it yields against the expected samples.
    """

    backend = "default"

    # Filled in by setUpClass; the defaults only keep attribute access safe.
    root_dir = None
    samples = []
    training = []
    validation = []
    testing = []

    @classmethod
    def setUpClass(cls):
        """Create the mocked dataset once and keep the expected samples."""
        cls.root_dir = cls.get_base_temp_dir()
        mocked_data = get_mock_dataset(cls.root_dir)
        cls.samples, cls.training, cls.validation, cls.testing = mocked_data

    def _assert_dataset_matches(self, dataset, expected):
        """Assert that ``dataset`` yields exactly the items in ``expected``.

        Args:
            dataset: iterable yielding ``(waveform, sample_rate, label)``.
            expected: list of ``(waveform, sample_rate, label)`` tuples in
                the order the dataset is expected to yield them.
        """
        n_ite = 0
        for i, (waveform, sample_rate, label) in enumerate(dataset):
            # Fail with a readable message instead of an IndexError when the
            # dataset yields more items than were mocked.
            assert i < len(expected), (
                f"dataset yielded more than the {len(expected)} expected samples"
            )
            expected_waveform, expected_sample_rate, expected_label = expected[i]
            # NOTE(review): tolerances presumably absorb the int16 save/load
            # round trip -- confirm before tightening.
            self.assertEqual(waveform, expected_waveform, atol=5e-5, rtol=1e-8)
            assert sample_rate == expected_sample_rate
            assert label == expected_label
            n_ite += 1
        assert n_ite == len(expected)

    def test_no_subset(self):
        """Without ``subset`` the dataset yields every mocked sample."""
        dataset = gtzan.GTZAN(self.root_dir)
        self._assert_dataset_matches(dataset, self.samples)

    def _test_training(self, dataset):
        """Compare ``dataset`` against the expected training samples."""
        self._assert_dataset_matches(dataset, self.training)

    def _test_validation(self, dataset):
        """Compare ``dataset`` against the expected validation samples."""
        self._assert_dataset_matches(dataset, self.validation)

    def _test_testing(self, dataset):
        """Compare ``dataset`` against the expected testing samples."""
        self._assert_dataset_matches(dataset, self.testing)

    def test_training_str(self):
        """Training subset, root directory given as ``str``."""
        train_dataset = gtzan.GTZAN(self.root_dir, subset="training")
        self._test_training(train_dataset)

    def test_validation_str(self):
        """Validation subset, root directory given as ``str``."""
        val_dataset = gtzan.GTZAN(self.root_dir, subset="validation")
        self._test_validation(val_dataset)

    def test_testing_str(self):
        """Testing subset, root directory given as ``str``."""
        test_dataset = gtzan.GTZAN(self.root_dir, subset="testing")
        self._test_testing(test_dataset)

    def test_training_path(self):
        """Training subset, root directory given as ``pathlib.Path``."""
        root_dir = Path(self.root_dir)
        train_dataset = gtzan.GTZAN(root_dir, subset="training")
        self._test_training(train_dataset)

    def test_validation_path(self):
        """Validation subset, root directory given as ``pathlib.Path``."""
        root_dir = Path(self.root_dir)
        val_dataset = gtzan.GTZAN(root_dir, subset="validation")
        self._test_validation(val_dataset)

    def test_testing_path(self):
        """Testing subset, root directory given as ``pathlib.Path``."""
        root_dir = Path(self.root_dir)
        test_dataset = gtzan.GTZAN(root_dir, subset="testing")
        self._test_testing(test_dataset)