Skip to content

Commit

Permalink
Support streaming conformer CTC models from wenet (#427)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Nov 16, 2023
1 parent b83b3e3 commit fac4f6b
Show file tree
Hide file tree
Showing 31 changed files with 1,212 additions and 7 deletions.
55 changes: 55 additions & 0 deletions .github/scripts/test-online-ctc.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#!/usr/bin/env bash

set -e

log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

echo "EXE is $EXE"
echo "PATH: $PATH"

which $EXE

log "------------------------------------------------------------"
log "Run streaming Conformer CTC from WeNet"
log "------------------------------------------------------------"
wenet_models=(
sherpa-onnx-zh-wenet-aishell
sherpa-onnx-zh-wenet-aishell2
sherpa-onnx-zh-wenet-wenetspeech
sherpa-onnx-zh-wenet-multi-cn
sherpa-onnx-en-wenet-librispeech
sherpa-onnx-en-wenet-gigaspeech
)
for name in ${wenet_models[@]}; do
repo_url=https://huggingface.co/csukuangfj/$name
log "Start testing ${repo_url}"
repo=$(basename $repo_url)
log "Download pretrained model and test-data from $repo_url"
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
pushd $repo
git lfs pull --include "*.onnx"
ls -lh *.onnx
popd

log "test float32 models"
time $EXE \
--tokens=$repo/tokens.txt \
--wenet-ctc-model=$repo/model-streaming.onnx \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/8k.wav

log "test int8 models"
time $EXE \
--tokens=$repo/tokens.txt \
--wenet-ctc-model=$repo/model-streaming.int8.onnx \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/8k.wav

rm -rf $repo
done
11 changes: 11 additions & 0 deletions .github/workflows/linux-gpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ on:
- '.github/scripts/test-online-paraformer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- '.github/scripts/test-online-ctc.sh'
- '.github/scripts/test-offline-tts.sh'
- 'CMakeLists.txt'
- 'cmake/**'
Expand All @@ -27,6 +28,8 @@ on:
- '.github/scripts/test-online-paraformer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- '.github/scripts/test-online-ctc.sh'
- '.github/scripts/test-online-ctc.sh'
- '.github/scripts/test-offline-tts.sh'
- 'CMakeLists.txt'
- 'cmake/**'
Expand Down Expand Up @@ -88,6 +91,14 @@ jobs:
file build/bin/sherpa-onnx
readelf -d build/bin/sherpa-onnx
- name: Test online CTC
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx
.github/scripts/test-online-ctc.sh
- name: Test offline TTS
shell: bash
run: |
Expand Down
10 changes: 10 additions & 0 deletions .github/workflows/linux.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ on:
- '.github/scripts/test-online-paraformer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- '.github/scripts/test-online-ctc.sh'
- '.github/scripts/test-offline-tts.sh'
- 'CMakeLists.txt'
- 'cmake/**'
Expand All @@ -27,6 +28,7 @@ on:
- '.github/scripts/test-online-paraformer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- '.github/scripts/test-online-ctc.sh'
- '.github/scripts/test-offline-tts.sh'
- 'CMakeLists.txt'
- 'cmake/**'
Expand Down Expand Up @@ -89,6 +91,14 @@ jobs:
file build/bin/sherpa-onnx
readelf -d build/bin/sherpa-onnx
- name: Test online CTC
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx
.github/scripts/test-online-ctc.sh
- name: Test offline CTC
shell: bash
run: |
Expand Down
11 changes: 11 additions & 0 deletions .github/workflows/macos.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ on:
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- '.github/scripts/test-offline-tts.sh'
- '.github/scripts/test-online-ctc.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
Expand All @@ -26,6 +27,7 @@ on:
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- '.github/scripts/test-offline-tts.sh'
- '.github/scripts/test-online-ctc.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
Expand Down Expand Up @@ -96,6 +98,15 @@ jobs:
otool -L build/bin/sherpa-onnx
otool -l build/bin/sherpa-onnx
- name: Test online CTC
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx
.github/scripts/test-online-ctc.sh
- name: Test offline TTS
shell: bash
run: |
Expand Down
10 changes: 10 additions & 0 deletions .github/workflows/windows-x64-cuda.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ on:
- '.github/scripts/test-online-paraformer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- '.github/scripts/test-online-ctc.sh'
- '.github/scripts/test-offline-tts.sh'
- 'CMakeLists.txt'
- 'cmake/**'
Expand All @@ -25,6 +26,7 @@ on:
- '.github/scripts/test-online-paraformer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- '.github/scripts/test-online-ctc.sh'
- '.github/scripts/test-offline-tts.sh'
- 'CMakeLists.txt'
- 'cmake/**'
Expand Down Expand Up @@ -66,6 +68,14 @@ jobs:
ls -lh ./bin/Release/sherpa-onnx.exe
- name: Test online CTC
shell: bash
run: |
export PATH=$PWD/build/bin/Release:$PATH
export EXE=sherpa-onnx.exe
.github/scripts/test-online-ctc.sh
- name: Test offline TTS
shell: bash
run: |
Expand Down
10 changes: 10 additions & 0 deletions .github/workflows/windows-x64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ on:
- '.github/scripts/test-online-paraformer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- '.github/scripts/test-online-ctc.sh'
- '.github/scripts/test-offline-tts.sh'
- 'CMakeLists.txt'
- 'cmake/**'
Expand All @@ -25,6 +26,7 @@ on:
- '.github/scripts/test-online-paraformer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- '.github/scripts/test-online-ctc.sh'
- '.github/scripts/test-offline-tts.sh'
- 'CMakeLists.txt'
- 'cmake/**'
Expand Down Expand Up @@ -67,6 +69,14 @@ jobs:
ls -lh ./bin/Release/sherpa-onnx.exe
- name: Test online CTC
shell: bash
run: |
export PATH=$PWD/build/bin/Release:$PATH
export EXE=sherpa-onnx.exe
.github/scripts/test-online-ctc.sh
- name: Test offline TTS
shell: bash
run: |
Expand Down
10 changes: 10 additions & 0 deletions .github/workflows/windows-x86.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ on:
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- '.github/scripts/test-offline-tts.sh'
- '.github/scripts/test-online-ctc.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
Expand All @@ -26,6 +27,7 @@ on:
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- '.github/scripts/test-offline-tts.sh'
- '.github/scripts/test-online-ctc.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
Expand Down Expand Up @@ -67,6 +69,14 @@ jobs:
ls -lh ./bin/Release/sherpa-onnx.exe
- name: Test online CTC
shell: bash
run: |
export PATH=$PWD/build/bin/Release:$PATH
export EXE=sherpa-onnx.exe
.github/scripts/test-online-ctc.sh
- name: Test offline TTS
shell: bash
run: |
Expand Down
1 change: 1 addition & 0 deletions scripts/wenet/export-onnx-streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def main():
dynamic_axes={
"x": {0: "N", 1: "T"},
"attn_cache": {2: "T"},
"attn_mask": {2: "T"},
"log_probs": {0: "N"},
"new_attn_cache": {2: "T"},
},
Expand Down
4 changes: 4 additions & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ set(sources
offline-zipformer-ctc-model-config.cc
offline-zipformer-ctc-model.cc
online-conformer-transducer-model.cc
online-ctc-greedy-search-decoder.cc
online-ctc-model.cc
online-lm-config.cc
online-lm.cc
online-lstm-transducer-model.cc
Expand All @@ -64,6 +66,8 @@ set(sources
online-transducer-model-config.cc
online-transducer-model.cc
online-transducer-modified-beam-search-decoder.cc
online-wenet-ctc-model-config.cc
online-wenet-ctc-model.cc
online-zipformer-transducer-model.cc
online-zipformer2-transducer-model.cc
onnx-utils.cc
Expand Down
42 changes: 42 additions & 0 deletions sherpa-onnx/csrc/online-ctc-decoder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// sherpa-onnx/csrc/online-ctc-decoder.h
//
// Copyright (c) 2023 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_
#define SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_

#include <vector>

#include "onnxruntime_cxx_api.h" // NOLINT

namespace sherpa_onnx {

struct OnlineCtcDecoderResult {
/// The decoded token IDs
std::vector<int64_t> tokens;

/// timestamps[i] contains the output frame index where tokens[i] is decoded.
/// Note: The index is after subsampling
std::vector<int32_t> timestamps;

int32_t num_trailing_blanks = 0;
};

class OnlineCtcDecoder {
public:
virtual ~OnlineCtcDecoder() = default;

/** Run streaming CTC decoding given the output from the encoder model.
*
* @param log_probs A 3-D tensor of shape (N, T, vocab_size) containing
* lob_probs.
*
* @param results Input & Output parameters..
*/
virtual void Decode(Ort::Value log_probs,
std::vector<OnlineCtcDecoderResult> *results) = 0;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_
60 changes: 60 additions & 0 deletions sherpa-onnx/csrc/online-ctc-greedy-search-decoder.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// sherpa-onnx/csrc/online-ctc-greedy-search-decoder.cc
//
// Copyright (c) 2023 Xiaomi Corporation

#include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h"

#include <algorithm>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/macros.h"

namespace sherpa_onnx {

void OnlineCtcGreedySearchDecoder::Decode(
Ort::Value log_probs, std::vector<OnlineCtcDecoderResult> *results) {
std::vector<int64_t> log_probs_shape =
log_probs.GetTensorTypeAndShapeInfo().GetShape();

if (log_probs_shape[0] != results->size()) {
SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, results.size(0): %d",
static_cast<int32_t>(log_probs_shape[0]),
static_cast<int32_t>(results->size()));
exit(-1);
}

int32_t batch_size = static_cast<int32_t>(log_probs_shape[0]);
int32_t num_frames = static_cast<int32_t>(log_probs_shape[1]);
int32_t vocab_size = static_cast<int32_t>(log_probs_shape[2]);

const float *p = log_probs.GetTensorData<float>();

for (int32_t b = 0; b != batch_size; ++b) {
auto &r = (*results)[b];

int32_t prev_id = -1;

for (int32_t t = 0; t != num_frames; ++t, p += vocab_size) {
int32_t y = static_cast<int32_t>(std::distance(
static_cast<const float *>(p),
std::max_element(static_cast<const float *>(p),
static_cast<const float *>(p) + vocab_size)));

if (y == blank_id_) {
r.num_trailing_blanks += 1;
} else {
r.num_trailing_blanks = 0;
}

if (y != blank_id_ && y != prev_id) {
r.tokens.push_back(y);
r.timestamps.push_back(t);
}

prev_id = y;
} // for (int32_t t = 0; t != num_frames; ++t) {
} // for (int32_t b = 0; b != batch_size; ++b)
}

} // namespace sherpa_onnx
Loading

0 comments on commit fac4f6b

Please sign in to comment.