-
Notifications
You must be signed in to change notification settings - Fork 302
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added scripts for testing pretrained models
- Loading branch information
Showing
2 changed files
with
680 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,278 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang, | ||
# Zengwei Yao, | ||
# Zengrui Jin,) | ||
# | ||
# See ../../../../LICENSE for clarification regarding multiple authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
""" | ||
This script loads torchscript models, exported by `torch.jit.script()` | ||
and uses them to decode waves. | ||
You can use the following command to get the exported models: | ||
./zipformer/export.py \ | ||
--exp-dir ./zipformer_bbpe/exp \ | ||
--bpe ./data/lang_bbpe_500/bbpe.model \ | ||
--epoch 30 \ | ||
--avg 9 \ | ||
--jit 1 | ||
Usage of this script: | ||
./zipformer/jit_pretrained.py \ | ||
--nn-model-filename ./zipformer_bbpe/exp/cpu_jit.pt \ | ||
--bpe ./data/lang_bbpe_500/bbpe.model \ | ||
/path/to/foo.wav \ | ||
/path/to/bar.wav | ||
""" | ||
|
||
import argparse | ||
import logging | ||
import math | ||
from typing import List | ||
|
||
import kaldifeat | ||
import sentencepiece as spm | ||
import torch | ||
import torchaudio | ||
from torch.nn.utils.rnn import pad_sequence | ||
|
||
from icefall import smart_byte_decode | ||
|
||
|
||
def get_parser(): | ||
parser = argparse.ArgumentParser( | ||
formatter_class=argparse.ArgumentDefaultsHelpFormatter | ||
) | ||
|
||
parser.add_argument( | ||
"--nn-model-filename", | ||
type=str, | ||
required=True, | ||
help="Path to the torchscript model cpu_jit.pt", | ||
) | ||
|
||
parser.add_argument( | ||
"--bpe-model", | ||
type=str, | ||
help="""Path to the bbpe.model.""", | ||
) | ||
|
||
parser.add_argument( | ||
"sound_files", | ||
type=str, | ||
nargs="+", | ||
help="The input sound file(s) to transcribe. " | ||
"Supported formats are those supported by torchaudio.load(). " | ||
"For example, wav and flac are supported. " | ||
"The sample rate has to be 16kHz.", | ||
) | ||
|
||
return parser | ||
|
||
|
||
def read_sound_files( | ||
filenames: List[str], expected_sample_rate: float = 16000 | ||
) -> List[torch.Tensor]: | ||
"""Read a list of sound files into a list 1-D float32 torch tensors. | ||
Args: | ||
filenames: | ||
A list of sound filenames. | ||
expected_sample_rate: | ||
The expected sample rate of the sound files. | ||
Returns: | ||
Return a list of 1-D float32 torch tensors. | ||
""" | ||
ans = [] | ||
for f in filenames: | ||
wave, sample_rate = torchaudio.load(f) | ||
assert ( | ||
sample_rate == expected_sample_rate | ||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" | ||
# We use only the first channel | ||
ans.append(wave[0].contiguous()) | ||
return ans | ||
|
||
|
||
def greedy_search( | ||
model: torch.jit.ScriptModule, | ||
encoder_out: torch.Tensor, | ||
encoder_out_lens: torch.Tensor, | ||
) -> List[List[int]]: | ||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. | ||
Args: | ||
model: | ||
The transducer model. | ||
encoder_out: | ||
A 3-D tensor of shape (N, T, C) | ||
encoder_out_lens: | ||
A 1-D tensor of shape (N,). | ||
Returns: | ||
Return the decoded results for each utterance. | ||
""" | ||
assert encoder_out.ndim == 3 | ||
assert encoder_out.size(0) >= 1, encoder_out.size(0) | ||
|
||
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( | ||
input=encoder_out, | ||
lengths=encoder_out_lens.cpu(), | ||
batch_first=True, | ||
enforce_sorted=False, | ||
) | ||
|
||
device = encoder_out.device | ||
blank_id = model.decoder.blank_id | ||
|
||
batch_size_list = packed_encoder_out.batch_sizes.tolist() | ||
N = encoder_out.size(0) | ||
|
||
assert torch.all(encoder_out_lens > 0), encoder_out_lens | ||
assert N == batch_size_list[0], (N, batch_size_list) | ||
|
||
context_size = model.decoder.context_size | ||
hyps = [[blank_id] * context_size for _ in range(N)] | ||
|
||
decoder_input = torch.tensor( | ||
hyps, | ||
device=device, | ||
dtype=torch.int64, | ||
) # (N, context_size) | ||
|
||
decoder_out = model.decoder( | ||
decoder_input, | ||
need_pad=torch.tensor([False]), | ||
).squeeze(1) | ||
|
||
offset = 0 | ||
for batch_size in batch_size_list: | ||
start = offset | ||
end = offset + batch_size | ||
current_encoder_out = packed_encoder_out.data[start:end] | ||
current_encoder_out = current_encoder_out | ||
# current_encoder_out's shape: (batch_size, encoder_out_dim) | ||
offset = end | ||
|
||
decoder_out = decoder_out[:batch_size] | ||
|
||
logits = model.joiner( | ||
current_encoder_out, | ||
decoder_out, | ||
) | ||
# logits'shape (batch_size, vocab_size) | ||
|
||
assert logits.ndim == 2, logits.shape | ||
y = logits.argmax(dim=1).tolist() | ||
emitted = False | ||
for i, v in enumerate(y): | ||
if v != blank_id: | ||
hyps[i].append(v) | ||
emitted = True | ||
if emitted: | ||
# update decoder output | ||
decoder_input = [h[-context_size:] for h in hyps[:batch_size]] | ||
decoder_input = torch.tensor( | ||
decoder_input, | ||
device=device, | ||
dtype=torch.int64, | ||
) | ||
decoder_out = model.decoder( | ||
decoder_input, | ||
need_pad=torch.tensor([False]), | ||
) | ||
decoder_out = decoder_out.squeeze(1) | ||
|
||
sorted_ans = [h[context_size:] for h in hyps] | ||
ans = [] | ||
unsorted_indices = packed_encoder_out.unsorted_indices.tolist() | ||
for i in range(N): | ||
ans.append(sorted_ans[unsorted_indices[i]]) | ||
|
||
return ans | ||
|
||
|
||
@torch.no_grad() | ||
def main(): | ||
parser = get_parser() | ||
args = parser.parse_args() | ||
logging.info(vars(args)) | ||
|
||
device = torch.device("cpu") | ||
if torch.cuda.is_available(): | ||
device = torch.device("cuda", 0) | ||
|
||
logging.info(f"device: {device}") | ||
|
||
model = torch.jit.load(args.nn_model_filename) | ||
|
||
model.eval() | ||
|
||
model.to(device) | ||
|
||
sp = spm.SentencePieceProcessor() | ||
sp.load(args.bpe_model) | ||
|
||
logging.info("Constructing Fbank computer") | ||
opts = kaldifeat.FbankOptions() | ||
opts.device = device | ||
opts.frame_opts.dither = 0 | ||
opts.frame_opts.snip_edges = False | ||
opts.frame_opts.samp_freq = 16000 | ||
opts.mel_opts.num_bins = 80 | ||
opts.mel_opts.high_freq = -400 | ||
|
||
fbank = kaldifeat.Fbank(opts) | ||
|
||
logging.info(f"Reading sound files: {args.sound_files}") | ||
waves = read_sound_files( | ||
filenames=args.sound_files, | ||
) | ||
waves = [w.to(device) for w in waves] | ||
|
||
logging.info("Decoding started") | ||
features = fbank(waves) | ||
feature_lengths = [f.size(0) for f in features] | ||
|
||
features = pad_sequence( | ||
features, | ||
batch_first=True, | ||
padding_value=math.log(1e-10), | ||
) | ||
|
||
feature_lengths = torch.tensor(feature_lengths, device=device) | ||
|
||
encoder_out, encoder_out_lens = model.encoder( | ||
features=features, | ||
feature_lengths=feature_lengths, | ||
) | ||
|
||
hyps = greedy_search( | ||
model=model, | ||
encoder_out=encoder_out, | ||
encoder_out_lens=encoder_out_lens, | ||
) | ||
|
||
s = "\n" | ||
for filename, hyp in zip(args.sound_files, hyps): | ||
words = smart_byte_decode(sp.decode(hyp)) | ||
s += f"{filename}:\n{words}\n\n" | ||
logging.info(s) | ||
|
||
logging.info("Decoding Done") | ||
|
||
|
||
if __name__ == "__main__": | ||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" | ||
|
||
logging.basicConfig(format=formatter, level=logging.INFO) | ||
main() |
Oops, something went wrong.