Skip to content

Commit

Permalink
add onnx pretrained
Browse files Browse the repository at this point in the history
  • Loading branch information
marcoyang1998 committed Mar 29, 2024
1 parent ff2975d commit 7bd679f
Showing 1 changed file with 250 additions and 0 deletions.
250 changes: 250 additions & 0 deletions egs/audioset/AT/zipformer/onnx_pretrained.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
# 2022 Xiaomi Corp. (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads ONNX models and uses them to decode waves.
You can use the following command to get the exported models:
We use the pre-trained model from
https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12#/
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12#/
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "exp/pretrained.pt"
cd exp
ln -s pretrained.pt epoch-99.pt
popd
2. Export the model to ONNX
./zipformer/export-onnx.py \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp \
--causal False
It will generate the following 3 files inside $repo/exp:
- model-epoch-99-avg-1.onnx
3. Run this file
./zipformer/onnx_pretrained.py \
--model-filename $repo/exp/model-epoch-99-avg-1.onnx \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
"""

import argparse
import csv
import logging
import math
from typing import List, Tuple

import k2
import kaldifeat
import onnxruntime as ort
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence


def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument(
"--model-filename",
type=str,
required=True,
help="Path to the onnx model. ",
)

parser.add_argument(
"--label-dict",
type=str,
help="""class_labels_indices.csv.""",
)

parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)

parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)

return parser


class OnnxModel:
def __init__(
self,
nn_model: str,
):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 4

self.session_opts = session_opts

self.init_model(nn_model)

def init_model(self, nn_model: str):
self.model = ort.InferenceSession(
nn_model,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
meta = self.model.get_modelmeta().custom_metadata_map
print(meta)


def __call__(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C)
x_lens:
A 2-D tensor of shape (N,). Its dtype is torch.int64
Returns:
Return a Tensor:
- logits, its shape is (N, num_classes)
"""
out = self.model.run(
[
self.model.get_outputs()[0].name,
],
{
self.model.get_inputs()[0].name: x.numpy(),
self.model.get_inputs()[1].name: x_lens.numpy(),
},
)
return torch.from_numpy(out[0])

def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans


@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
model = OnnxModel(
nn_model=args.model_filename,
)

# get the label dictionary
label_dict = {}
with open(args.label_dict, "r") as f:
reader = csv.reader(f, delimiter=",")
for i, row in enumerate(reader):
if i == 0:
continue
label_dict[int(row[0])] = row[2]

logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = args.sample_rate
opts.mel_opts.num_bins = 80
opts.mel_opts.high_freq = -400

fbank = kaldifeat.Fbank(opts)

logging.info(f"Reading sound files: {args.sound_files}")
waves = read_sound_files(
filenames=args.sound_files,
expected_sample_rate=args.sample_rate,
)

logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]

features = pad_sequence(
features,
batch_first=True,
padding_value=math.log(1e-10),
)

feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
logits = model(features, feature_lengths)

for filename, logit in zip(args.sound_files, logits):
topk_prob, topk_index = logit.sigmoid().topk(5)
topk_labels = [label_dict[index.item()] for index in topk_index]
logging.info(
f"{filename}: Top 5 predicted labels are {topk_labels} with probability of {topk_prob.tolist()}"
)

logging.info("Decoding Done")


if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

logging.basicConfig(format=formatter, level=logging.INFO)
main()

0 comments on commit 7bd679f

Please sign in to comment.