#!/usr/bin/env python
# from fairseq.checkpoint_utils import load_model_ensemble_and_task, load_checkpoint_to_cpu
from __future__ import division

import argparse
import sys

import numpy as np
import torch
from torch.cuda.amp import autocast

import onmt
import onmt.markdown

parser = argparse.ArgumentParser(description='extract_wav2vec2_tdnn.py')
onmt.markdown.add_md_help_argument(parser)

parser.add_argument('-model', required=True,
                    help='Path to the model .pt file')
parser.add_argument('-lm', required=False,
                    help='Path to a language model .pt file. Used for cold fusion')
parser.add_argument('-vocab_list', default="",
                    help='A vocabulary list (one word per line). Only these words are generated during translation.')
parser.add_argument('-autoencoder', required=False,
                    help='Path to an autoencoder .pt file')
parser.add_argument('-input_type', default="word",
                    help='Input type: word/char')
parser.add_argument('-src', required=True,
                    help='Source sequences to decode (one line per sequence)')
parser.add_argument('-attributes', default="",
                    help='Attributes for the decoder. Separate them with |')
parser.add_argument('-ensemble_weight', default="",
                    help='Weights for the ensemble. Uniform by default. Separate them with |; they will be normalized later')
parser.add_argument('-sub_ensemble_weight', default="",
                    help='Weights for the sub-ensemble. Uniform by default. Separate them with |; they will be normalized later')
parser.add_argument('-stride', type=int, default=1,
                    help='Stride on input features')
parser.add_argument('-concat', type=str, default="1",
                    help='Concatenate sequential audio features to decrease sequence length')
parser.add_argument('-asr_format', default="h5", required=False,
                    help='Format of the ASR data: h5 or scp')
parser.add_argument('-encoder_type', default='text',
                    help='Type of encoder to use. Options are [text|img|audio].')
parser.add_argument('-previous_context', type=int, default=0,
                    help='Number of previous sentences to use as context')
parser.add_argument('-max_memory_size', type=int, default=512,
                    help='Number of memory states stored in the buffer for XL models')
parser.add_argument('-tgt',
                    help='True target sequence (optional)')
parser.add_argument('-scp_output', default='output.scp',
                    help='Path to write the feature index (scp) file')
parser.add_argument('-ark_output', default='output.ark',
                    help='Path to write the features (ark) file')
parser.add_argument('-batch_size', type=int, default=30,
                    help='Batch size (in audio samples)')
parser.add_argument('-gpu', type=int, default=-1,
                    help='Device to run on (-1 for CPU)')
parser.add_argument('-fp16', action='store_true',
                    help='Use 16-bit floating point during decoding')
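
# Example invocation (a sketch; file names and the batch size are placeholders):
#
#   python extract_wav2vec2_tdnn.py -model wav2vec2.pt -src wav.scp \
#       -ark_output feats.ark -scp_output feats.scp \
#       -batch_size 320000 -gpu 0 -fp16
#
# Each line of the -src file names one utterance:
#
#   <utt_id> <wav_path> [<start_seconds> <end_seconds>]
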
def _is_oversized(batch, new_sent_size, batch_size):
    """
    Check whether adding a new sentence would make the current batch oversized.
    :param batch: list of source tensors currently in the batch
    :param new_sent_size: length (in samples) of the sentence to be added
    :param batch_size: maximum padded batch area (max length * batch count)
    :return: True if the padded batch would exceed batch_size
    """
    # An empty batch can never be oversized
    if len(batch) == 0:
        return False

    current_max_length = max([sent.size(0) for sent in batch])

    # Batches are padded to the longest sentence, so the memory footprint is
    # the area of a rectangle: max length * number of sentences. Adding a new
    # sentence can potentially enlarge that rectangle in both dimensions.
    if max(current_max_length, new_sent_size) * (len(batch) + 1) > batch_size:
        return True

    return False
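
# A worked example of the check above (illustrative numbers only): with two
# utterances of 400 and 350 samples already batched and batch_size=1200,
# adding a 500-sample utterance would give a padded area of
# max(400, 500) * (2 + 1) = 1500 > 1200, so the current batch is flushed first.
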
def write_ark(utts, features, padding_mask, out_ark, out_scp, opt):
    from onmt.data.kaldiio.io import write_ark_file

    features = features.cpu()
    bsz, seq_len, feat_size = features.size()

    # padding_mask marks padded positions, so the number of valid frames per
    # utterance is the count of non-padded (zero) entries along the time axis.
    lengths = (padding_mask == 0).sum(dim=1)

    assert len(utts) == bsz

    for i in range(bsz):
        # Strip the padding before writing the features
        feature_ = features[i, 0:lengths[i]]
        feature_ = feature_.numpy()
        # if opt.fp16:
        #     feature_ = feature_.astype(np.float16)

        seg_name = utts[i]
        write_ark_file(out_ark, out_scp, {seg_name: feature_})
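
# Assuming onmt's kaldiio writer follows the standard Kaldi convention, each
# line of the scp file points at the byte offset of one matrix in the ark file:
#
#   utt0001 /path/to/output.ark:17
#   utt0002 /path/to/output.ark:54321
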
def build_data(src_sents):
    """Wrap a list of raw waveform tensors in an onmt.Dataset of type 'wav'."""
    src_data = src_sents
    tgt_data = None
    src_lang_data = [torch.Tensor([0])]
    tgt_lang_data = None

    # All batching limits are disabled (sys.maxsize): the caller controls the
    # batch size via _is_oversized, so the whole list becomes a single batch.
    return onmt.Dataset(src_data, tgt_data,
                        src_langs=src_lang_data, tgt_langs=tgt_lang_data,
                        batch_size_words=sys.maxsize,
                        max_src_len=sys.maxsize,
                        data_type='wav',
                        batch_size_sents=sys.maxsize,
                        src_align_right=False,
                        past_src_data=None)
if __name__ == '__main__':
    opt = parser.parse_args()
    opt.cuda = opt.gpu > -1

    if opt.cuda:
        torch.cuda.set_device(opt.gpu)

    from onmt.models.speech_recognizer.wav2vec2 import FairseqWav2VecExtractor

    model = FairseqWav2VecExtractor(opt.model)
    # Half precision is handled by autocast below rather than model.half()

    if opt.cuda:
        model = model.cuda()
    model.eval()

    ark_out = open(opt.ark_output, 'wb')
    scp_out = open(opt.scp_output, 'w')
    audio_data = open(opt.src)

    from onmt.utils import safe_readaudio

    src_batch = list()
    src_utts = list()

    while True:
        try:
            # Each input line is "<utt_id> <wav_path> [<start_sec> <end_sec>]";
            # without explicit timestamps the whole file is read.
            line = next(audio_data).strip().split()
            utt = line[0]
            if len(line) == 2:
                wav_path = line[1]
                start, end = 0, 0
            else:
                wav_path, start, end = line[1], float(line[2]), float(line[3])
            wav = safe_readaudio(wav_path, start=start, end=end, sample_rate=16000)
        except StopIteration:
            break

        src_length = wav.size(0)

        # Read features from the wav2vec model and write them into scp/ark
        # files, just like Kaldi does with log-mel features.
        if _is_oversized(src_batch, src_length, opt.batch_size):
            # Adding this utterance would make the batch oversized, so extract
            # features for the current batch now, then start a new one.
            print("Batch size:", len(src_batch))
            dataset = build_data(src_batch)
            batch = dataset.get_batch(0)
            if opt.cuda:
                batch.cuda()
            with autocast(enabled=opt.fp16):
                features, padding_mask = model(batch)
            write_ark(src_utts, features, padding_mask, ark_out, scp_out, opt)
            src_batch = []
            src_utts = []

        src_batch.append(wav)
        src_utts.append(utt)

    # Catch the last (possibly partial) batch
    if len(src_batch) != 0:
        print("Batch size:", len(src_batch))
        dataset = build_data(src_batch)
        batch = dataset.get_batch(0)
        if opt.cuda:
            batch.cuda()
        with autocast(enabled=opt.fp16):
            features, padding_mask = model(batch)
        write_ark(src_utts, features, padding_mask, ark_out, scp_out, opt)

    ark_out.close()
    scp_out.close()
    audio_data.close()
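
    # To sanity-check the output, the features can be read back with any
    # Kaldi-compatible reader, e.g. (an assumption, not used by this script)
    # the external kaldiio package:
    #
    #   import kaldiio
    #   for utt, feat in kaldiio.load_scp_sequential('output.scp'):
    #       print(utt, feat.shape)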