test-xent.py
#!/usr/bin/env python3
"""Finnish Parliament ASR"""
import os
import sys
import logging
from types import SimpleNamespace

import torch
import tqdm
import kaldi_io
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml


def setup(hparams, run_opts):
    """Kind of mimics what the SpeechBrain Brain class does."""
    # Pick the device: command-line run options take precedence over hparams.
    if "device" in run_opts:
        device = run_opts["device"]
    elif "device" in hparams:
        device = hparams["device"]
    else:
        device = "cpu"
    print("Device is:", device)
    if "cuda" in device:
        # Expects a device string such as "cuda:0".
        torch.cuda.set_device(int(device[-1]))
    modules = torch.nn.ModuleDict(hparams["modules"]).to(device)
    # Wrap hparams in a namespace for attribute-style access.
    hparams = SimpleNamespace(**hparams)
    # Recover a checkpoint if a checkpointer is configured, preferring the best
    # checkpoint according to test_max_key / test_min_key when given.
    if hasattr(hparams, "checkpointer"):
        if hasattr(hparams, "test_max_key"):
            ckpt = hparams.checkpointer.find_checkpoint(max_key=hparams.test_max_key)
        elif hasattr(hparams, "test_min_key"):
            ckpt = hparams.checkpointer.find_checkpoint(min_key=hparams.test_min_key)
        else:
            ckpt = hparams.checkpointer.find_checkpoint()
        hparams.checkpointer.load_checkpoint(ckpt)
        epoch = hparams.epoch_counter.current
        print("Loaded checkpoint from epoch", epoch, "at path", ckpt.path)
    return modules, hparams, device


def count_scp_lines(scpfile):
    """Count the entries (utterances) in a Kaldi .scp file, one per line."""
    lines = 0
    with open(scpfile) as fin:
        for _ in fin:
            lines += 1
    return lines


def run_test(modules, hparams, device):
    """Run the model over the test features and write prior-normalized log-probs."""
    # Log-prior over the output units, subtracted from the log-softmax outputs.
    prior = torch.load(hparams.prior_file).to(device)
    num_utts = count_scp_lines(hparams.test_feats)
    with open(hparams.test_probs_out, "wb") as fo:
        with torch.no_grad():
            for uttid, feats in tqdm.tqdm(kaldi_io.read_mat_scp(hparams.test_feats), total=num_utts):
                # Process one full utterance at a time; add a batch dimension.
                feats = torch.from_numpy(feats).to(device).unsqueeze(0)
                normalized = modules.normalize(feats, lengths=torch.tensor([1.0]), epoch=1000)
                encoded = modules.encoder(normalized)
                out = modules.lin_out(encoded)
                normalized_predictions = hparams.log_softmax(out) - prior
                kaldi_io.write_mat(fo, normalized_predictions.squeeze(0).cpu().numpy(), key=uttid)


if __name__ == "__main__":
    # Read command-line arguments (hyperparameter file, run options, overrides).
    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
    # Load the hyperparameters file, applying any command-line overrides.
    with open(hparams_file) as fin:
        hparams = load_hyperpyyaml(fin, overrides)
    modules, hparams, device = setup(hparams, run_opts)
    run_test(modules, hparams, device)
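
# A minimal usage sketch (the YAML path below is hypothetical; the real
# hyperparameter file and data paths come from the training recipe):
#   python test-xent.py hparams/test-xent.yaml --device=cuda:0
# The YAML is expected to define at least: modules (with normalize, encoder,
# and lin_out), log_softmax, prior_file, test_feats, and test_probs_out.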