forked from buriburisuri/speech-to-text-wavenet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
executable file
·90 lines (63 loc) · 2.12 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import sugartensor as tf
from data import SpeechCorpus, voca_size
from model import *
import numpy as np
from tqdm import tqdm
__author__ = '[email protected]'

# Verbosity 10 = debug-level logging.
tf.sg_verbosity(10)

# Command-line flags:
#   --set   which data split to evaluate ('train', 'valid' or 'test')
#   --frac  fraction of that split's batches to evaluate (1.0 = all of them)
tf.sg_arg_def(
    set=('valid', "'train', 'valid', or 'test'. The default is 'valid'")
)
tf.sg_arg_def(
    frac=(1.0, "test fraction ratio to whole data set. The default is 1.0(=whole set)")
)
# ---------------------------------------------------------------------------
# Hyper-parameters
# ---------------------------------------------------------------------------

batch_size = 16  # examples per evaluation batch

# ---------------------------------------------------------------------------
# Input pipeline (the corpus tensors are fed by a QueueRunner)
# ---------------------------------------------------------------------------

data = SpeechCorpus(set_name=tf.sg_arg().set, batch_size=batch_size)
x = data.mfcc   # MFCC features of the audio
y = data.label  # target sentence labels

# Per-example sequence length: count the time steps whose summed features are
# non-zero, i.e. exclude zero-padding.
# NOTE(review): assumes x is (batch, time, feature) -- TODO confirm.
seq_len = (tf.not_equal(x.sg_sum(axis=2), 0.)
           .sg_int()
           .sg_sum(axis=1))

# ---------------------------------------------------------------------------
# Evaluation graph
# ---------------------------------------------------------------------------

logit = get_logit(x, voca_size=voca_size)       # encoded audio -> logits
loss = logit.sg_ctc(target=y, seq_len=seq_len)  # CTC loss against labels
#
# run network
#

with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:

    # Initialise all variables, then overwrite them with the newest
    # checkpoint found under 'asset/train'.
    tf.sg_init(sess)
    saver = tf.train.Saver()
    saver.restore(sess, tf.train.latest_checkpoint('asset/train'))

    tf.sg_info('Testing started on %s set at global step[%08d].' %
               (tf.sg_arg().set.upper(), sess.run(tf.sg_global_step())))

    # Number of batches to evaluate: --frac selects a fraction of the split.
    # The same integer count drives the loop and (indirectly) the average.
    num_test_batch = int(data.num_batch * tf.sg_arg().frac)

    # Presumably keeps the QueueRunner-fed input pipeline running for the
    # duration of the loop -- NOTE(review): confirm sg_queue_context semantics.
    with tf.sg_queue_context():
        iterator = tqdm(range(num_test_batch), total=num_test_batch,
                        initial=0, desc='test', ncols=70, unit='b',
                        leave=False)

        loss_sum = 0.
        num_valid = 0
        for _ in iterator:
            batch_loss = sess.run(loss)
            # Skip any batch whose loss contains NaN/Inf so a single diverged
            # batch cannot poison the average. The check must be element-wise:
            # a reduced boolean such as batch_loss.all() is never NaN/Inf.
            if batch_loss is not None and np.all(np.isfinite(batch_loss)):
                loss_sum += np.mean(batch_loss)
                num_valid += 1

    # Average over the batches actually accumulated (skipped batches and the
    # int() truncation of num_batch * frac no longer skew the denominator);
    # NaN when no batch was usable instead of a ZeroDivisionError.
    loss_avg = loss_sum / num_valid if num_valid else float('nan')

    tf.sg_info('Testing finished on %s.(CTC loss=%f)' % (tf.sg_arg().set.upper(), loss_avg))