-
Notifications
You must be signed in to change notification settings - Fork 303
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ff2975d
commit 7bd679f
Showing
1 changed file
with
250 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,250 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) | ||
# 2022 Xiaomi Corp. (authors: Xiaoyu Yang) | ||
# | ||
# See ../../../../LICENSE for clarification regarding multiple authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
""" | ||
This script loads ONNX models and uses them to decode waves. | ||
You can use the following command to get the exported models: | ||
We use the pre-trained model from | ||
https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12#/ | ||
as an example to show how to use this file. | ||
1. Download the pre-trained model | ||
cd egs/librispeech/ASR | ||
repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12#/ | ||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url | ||
repo=$(basename $repo_url) | ||
pushd $repo | ||
git lfs pull --include "exp/pretrained.pt" | ||
cd exp | ||
ln -s pretrained.pt epoch-99.pt | ||
popd | ||
2. Export the model to ONNX | ||
./zipformer/export-onnx.py \ | ||
--use-averaged-model 0 \ | ||
--epoch 99 \ | ||
--avg 1 \ | ||
--exp-dir $repo/exp \ | ||
--causal False | ||
It will generate the following 3 files inside $repo/exp: | ||
- model-epoch-99-avg-1.onnx | ||
3. Run this file | ||
./zipformer/onnx_pretrained.py \ | ||
--model-filename $repo/exp/model-epoch-99-avg-1.onnx \ | ||
--tokens $repo/data/lang_bpe_500/tokens.txt \ | ||
$repo/test_wavs/1089-134686-0001.wav \ | ||
$repo/test_wavs/1221-135766-0001.wav \ | ||
$repo/test_wavs/1221-135766-0002.wav | ||
""" | ||
|
||
import argparse | ||
import csv | ||
import logging | ||
import math | ||
from typing import List, Tuple | ||
|
||
import k2 | ||
import kaldifeat | ||
import onnxruntime as ort | ||
import torch | ||
import torchaudio | ||
from torch.nn.utils.rnn import pad_sequence | ||
|
||
|
||
def get_parser(): | ||
parser = argparse.ArgumentParser( | ||
formatter_class=argparse.ArgumentDefaultsHelpFormatter | ||
) | ||
|
||
parser.add_argument( | ||
"--model-filename", | ||
type=str, | ||
required=True, | ||
help="Path to the onnx model. ", | ||
) | ||
|
||
parser.add_argument( | ||
"--label-dict", | ||
type=str, | ||
help="""class_labels_indices.csv.""", | ||
) | ||
|
||
parser.add_argument( | ||
"sound_files", | ||
type=str, | ||
nargs="+", | ||
help="The input sound file(s) to transcribe. " | ||
"Supported formats are those supported by torchaudio.load(). " | ||
"For example, wav and flac are supported. " | ||
"The sample rate has to be 16kHz.", | ||
) | ||
|
||
parser.add_argument( | ||
"--sample-rate", | ||
type=int, | ||
default=16000, | ||
help="The sample rate of the input sound file", | ||
) | ||
|
||
return parser | ||
|
||
|
||
class OnnxModel: | ||
def __init__( | ||
self, | ||
nn_model: str, | ||
): | ||
session_opts = ort.SessionOptions() | ||
session_opts.inter_op_num_threads = 1 | ||
session_opts.intra_op_num_threads = 4 | ||
|
||
self.session_opts = session_opts | ||
|
||
self.init_model(nn_model) | ||
|
||
def init_model(self, nn_model: str): | ||
self.model = ort.InferenceSession( | ||
nn_model, | ||
sess_options=self.session_opts, | ||
providers=["CPUExecutionProvider"], | ||
) | ||
meta = self.model.get_modelmeta().custom_metadata_map | ||
print(meta) | ||
|
||
|
||
def __call__( | ||
self, | ||
x: torch.Tensor, | ||
x_lens: torch.Tensor, | ||
) -> torch.Tensor: | ||
""" | ||
Args: | ||
x: | ||
A 3-D tensor of shape (N, T, C) | ||
x_lens: | ||
A 2-D tensor of shape (N,). Its dtype is torch.int64 | ||
Returns: | ||
Return a Tensor: | ||
- logits, its shape is (N, num_classes) | ||
""" | ||
out = self.model.run( | ||
[ | ||
self.model.get_outputs()[0].name, | ||
], | ||
{ | ||
self.model.get_inputs()[0].name: x.numpy(), | ||
self.model.get_inputs()[1].name: x_lens.numpy(), | ||
}, | ||
) | ||
return torch.from_numpy(out[0]) | ||
|
||
def read_sound_files( | ||
filenames: List[str], expected_sample_rate: float | ||
) -> List[torch.Tensor]: | ||
"""Read a list of sound files into a list 1-D float32 torch tensors. | ||
Args: | ||
filenames: | ||
A list of sound filenames. | ||
expected_sample_rate: | ||
The expected sample rate of the sound files. | ||
Returns: | ||
Return a list of 1-D float32 torch tensors. | ||
""" | ||
ans = [] | ||
for f in filenames: | ||
wave, sample_rate = torchaudio.load(f) | ||
assert ( | ||
sample_rate == expected_sample_rate | ||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" | ||
# We use only the first channel | ||
ans.append(wave[0]) | ||
return ans | ||
|
||
|
||
@torch.no_grad() | ||
def main(): | ||
parser = get_parser() | ||
args = parser.parse_args() | ||
logging.info(vars(args)) | ||
model = OnnxModel( | ||
nn_model=args.model_filename, | ||
) | ||
|
||
# get the label dictionary | ||
label_dict = {} | ||
with open(args.label_dict, "r") as f: | ||
reader = csv.reader(f, delimiter=",") | ||
for i, row in enumerate(reader): | ||
if i == 0: | ||
continue | ||
label_dict[int(row[0])] = row[2] | ||
|
||
logging.info("Constructing Fbank computer") | ||
opts = kaldifeat.FbankOptions() | ||
opts.device = "cpu" | ||
opts.frame_opts.dither = 0 | ||
opts.frame_opts.snip_edges = False | ||
opts.frame_opts.samp_freq = args.sample_rate | ||
opts.mel_opts.num_bins = 80 | ||
opts.mel_opts.high_freq = -400 | ||
|
||
fbank = kaldifeat.Fbank(opts) | ||
|
||
logging.info(f"Reading sound files: {args.sound_files}") | ||
waves = read_sound_files( | ||
filenames=args.sound_files, | ||
expected_sample_rate=args.sample_rate, | ||
) | ||
|
||
logging.info("Decoding started") | ||
features = fbank(waves) | ||
feature_lengths = [f.size(0) for f in features] | ||
|
||
features = pad_sequence( | ||
features, | ||
batch_first=True, | ||
padding_value=math.log(1e-10), | ||
) | ||
|
||
feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) | ||
logits = model(features, feature_lengths) | ||
|
||
for filename, logit in zip(args.sound_files, logits): | ||
topk_prob, topk_index = logit.sigmoid().topk(5) | ||
topk_labels = [label_dict[index.item()] for index in topk_index] | ||
logging.info( | ||
f"{filename}: Top 5 predicted labels are {topk_labels} with probability of {topk_prob.tolist()}" | ||
) | ||
|
||
logging.info("Decoding Done") | ||
|
||
|
||
if __name__ == "__main__": | ||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" | ||
|
||
logging.basicConfig(format=formatter, level=logging.INFO) | ||
main() |