Skip to content

Commit

Permalink
added scripts for testing pretrained models
Browse files Browse the repository at this point in the history
  • Loading branch information
JinZr committed Jan 16, 2024
1 parent d7f284a commit b63576c
Show file tree
Hide file tree
Showing 2 changed files with 680 additions and 0 deletions.
278 changes: 278 additions & 0 deletions egs/aishell/ASR/zipformer_bbpe/jit_pretrained.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,278 @@
#!/usr/bin/env python3
# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao,
# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads torchscript models, exported by `torch.jit.script()`
and uses them to decode waves.
You can use the following command to get the exported models:
./zipformer/export.py \
--exp-dir ./zipformer_bbpe/exp \
--bpe ./data/lang_bbpe_500/bbpe.model \
--epoch 30 \
--avg 9 \
--jit 1
Usage of this script:
./zipformer/jit_pretrained.py \
--nn-model-filename ./zipformer_bbpe/exp/cpu_jit.pt \
--bpe ./data/lang_bbpe_500/bbpe.model \
/path/to/foo.wav \
/path/to/bar.wav
"""

import argparse
import logging
import math
from typing import List

import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence

from icefall import smart_byte_decode


def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument(
"--nn-model-filename",
type=str,
required=True,
help="Path to the torchscript model cpu_jit.pt",
)

parser.add_argument(
"--bpe-model",
type=str,
help="""Path to the bbpe.model.""",
)

parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)

return parser


def read_sound_files(
filenames: List[str], expected_sample_rate: float = 16000
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0].contiguous())
return ans


def greedy_search(
model: torch.jit.ScriptModule,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
) -> List[List[int]]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
model:
The transducer model.
encoder_out:
A 3-D tensor of shape (N, T, C)
encoder_out_lens:
A 1-D tensor of shape (N,).
Returns:
Return the decoded results for each utterance.
"""
assert encoder_out.ndim == 3
assert encoder_out.size(0) >= 1, encoder_out.size(0)

packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)

device = encoder_out.device
blank_id = model.decoder.blank_id

batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)

assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)

context_size = model.decoder.context_size
hyps = [[blank_id] * context_size for _ in range(N)]

decoder_input = torch.tensor(
hyps,
device=device,
dtype=torch.int64,
) # (N, context_size)

decoder_out = model.decoder(
decoder_input,
need_pad=torch.tensor([False]),
).squeeze(1)

offset = 0
for batch_size in batch_size_list:
start = offset
end = offset + batch_size
current_encoder_out = packed_encoder_out.data[start:end]
current_encoder_out = current_encoder_out
# current_encoder_out's shape: (batch_size, encoder_out_dim)
offset = end

decoder_out = decoder_out[:batch_size]

logits = model.joiner(
current_encoder_out,
decoder_out,
)
# logits'shape (batch_size, vocab_size)

assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
hyps[i].append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
decoder_input = torch.tensor(
decoder_input,
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(
decoder_input,
need_pad=torch.tensor([False]),
)
decoder_out = decoder_out.squeeze(1)

sorted_ans = [h[context_size:] for h in hyps]
ans = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])

return ans


@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))

device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)

logging.info(f"device: {device}")

model = torch.jit.load(args.nn_model_filename)

model.eval()

model.to(device)

sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)

logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
opts.mel_opts.high_freq = -400

fbank = kaldifeat.Fbank(opts)

logging.info(f"Reading sound files: {args.sound_files}")
waves = read_sound_files(
filenames=args.sound_files,
)
waves = [w.to(device) for w in waves]

logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]

features = pad_sequence(
features,
batch_first=True,
padding_value=math.log(1e-10),
)

feature_lengths = torch.tensor(feature_lengths, device=device)

encoder_out, encoder_out_lens = model.encoder(
features=features,
feature_lengths=feature_lengths,
)

hyps = greedy_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)

s = "\n"
for filename, hyp in zip(args.sound_files, hyps):
words = smart_byte_decode(sp.decode(hyp))
s += f"{filename}:\n{words}\n\n"
logging.info(s)

logging.info("Decoding Done")


if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

logging.basicConfig(format=formatter, level=logging.INFO)
main()
Loading

0 comments on commit b63576c

Please sign in to comment.