diff --git a/.github/scripts/test-online-ctc.sh b/.github/scripts/test-online-ctc.sh new file mode 100755 index 000000000..f74ee3c3e --- /dev/null +++ b/.github/scripts/test-online-ctc.sh @@ -0,0 +1,55 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +echo "EXE is $EXE" +echo "PATH: $PATH" + +which $EXE + +log "------------------------------------------------------------" +log "Run streaming Conformer CTC from WeNet" +log "------------------------------------------------------------" +wenet_models=( +sherpa-onnx-zh-wenet-aishell +sherpa-onnx-zh-wenet-aishell2 +sherpa-onnx-zh-wenet-wenetspeech +sherpa-onnx-zh-wenet-multi-cn +sherpa-onnx-en-wenet-librispeech +sherpa-onnx-en-wenet-gigaspeech +) +for name in ${wenet_models[@]}; do + repo_url=https://huggingface.co/csukuangfj/$name + log "Start testing ${repo_url}" + repo=$(basename $repo_url) + log "Download pretrained model and test-data from $repo_url" + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + pushd $repo + git lfs pull --include "*.onnx" + ls -lh *.onnx + popd + + log "test float32 models" + time $EXE \ + --tokens=$repo/tokens.txt \ + --wenet-ctc-model=$repo/model-streaming.onnx \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/8k.wav + + log "test int8 models" + time $EXE \ + --tokens=$repo/tokens.txt \ + --wenet-ctc-model=$repo/model-streaming.int8.onnx \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/8k.wav + + rm -rf $repo +done diff --git a/.github/workflows/linux-gpu.yaml b/.github/workflows/linux-gpu.yaml index ba7b4a004..6a7d2589e 100644 --- a/.github/workflows/linux-gpu.yaml +++ b/.github/workflows/linux-gpu.yaml @@ -12,6 +12,7 @@ on: - '.github/scripts/test-online-paraformer.sh' - '.github/scripts/test-offline-transducer.sh' - '.github/scripts/test-offline-ctc.sh' + - 
'.github/scripts/test-online-ctc.sh' - '.github/scripts/test-offline-tts.sh' - 'CMakeLists.txt' - 'cmake/**' @@ -27,6 +28,7 @@ on: - '.github/scripts/test-online-paraformer.sh' - '.github/scripts/test-offline-transducer.sh' - '.github/scripts/test-offline-ctc.sh' + - '.github/scripts/test-online-ctc.sh' - '.github/scripts/test-offline-tts.sh' - 'CMakeLists.txt' - 'cmake/**' @@ -88,6 +91,14 @@ jobs: file build/bin/sherpa-onnx readelf -d build/bin/sherpa-onnx + - name: Test online CTC + shell: bash + run: | + export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx + + .github/scripts/test-online-ctc.sh + - name: Test offline TTS shell: bash run: | diff --git a/.github/workflows/linux.yaml b/.github/workflows/linux.yaml index 88c82c3eb..cbe4dac97 100644 --- a/.github/workflows/linux.yaml +++ b/.github/workflows/linux.yaml @@ -12,6 +12,7 @@ on: - '.github/scripts/test-online-paraformer.sh' - '.github/scripts/test-offline-transducer.sh' - '.github/scripts/test-offline-ctc.sh' + - '.github/scripts/test-online-ctc.sh' - '.github/scripts/test-offline-tts.sh' - 'CMakeLists.txt' - 'cmake/**' @@ -27,6 +28,7 @@ on: - '.github/scripts/test-online-paraformer.sh' - '.github/scripts/test-offline-transducer.sh' - '.github/scripts/test-offline-ctc.sh' + - '.github/scripts/test-online-ctc.sh' - '.github/scripts/test-offline-tts.sh' - 'CMakeLists.txt' - 'cmake/**' @@ -89,6 +91,14 @@ jobs: file build/bin/sherpa-onnx readelf -d build/bin/sherpa-onnx + - name: Test online CTC + shell: bash + run: | + export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx + + .github/scripts/test-online-ctc.sh + - name: Test offline CTC shell: bash run: | diff --git a/.github/workflows/macos.yaml b/.github/workflows/macos.yaml index e952b38a8..b73b0a504 100644 --- a/.github/workflows/macos.yaml +++ b/.github/workflows/macos.yaml @@ -13,6 +13,7 @@ on: - '.github/scripts/test-offline-transducer.sh' - '.github/scripts/test-offline-ctc.sh' - 
'.github/scripts/test-offline-tts.sh' + - '.github/scripts/test-online-ctc.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -26,6 +27,7 @@ on: - '.github/scripts/test-offline-transducer.sh' - '.github/scripts/test-offline-ctc.sh' - '.github/scripts/test-offline-tts.sh' + - '.github/scripts/test-online-ctc.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -96,6 +98,15 @@ jobs: otool -L build/bin/sherpa-onnx otool -l build/bin/sherpa-onnx + - name: Test online CTC + shell: bash + run: | + export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx + + .github/scripts/test-online-ctc.sh + + - name: Test offline TTS shell: bash run: | diff --git a/.github/workflows/windows-x64-cuda.yaml b/.github/workflows/windows-x64-cuda.yaml index e89f0ddd4..936990fea 100644 --- a/.github/workflows/windows-x64-cuda.yaml +++ b/.github/workflows/windows-x64-cuda.yaml @@ -12,6 +12,7 @@ on: - '.github/scripts/test-online-paraformer.sh' - '.github/scripts/test-offline-transducer.sh' - '.github/scripts/test-offline-ctc.sh' + - '.github/scripts/test-online-ctc.sh' - '.github/scripts/test-offline-tts.sh' - 'CMakeLists.txt' - 'cmake/**' @@ -25,6 +26,7 @@ on: - '.github/scripts/test-online-paraformer.sh' - '.github/scripts/test-offline-transducer.sh' - '.github/scripts/test-offline-ctc.sh' + - '.github/scripts/test-online-ctc.sh' - '.github/scripts/test-offline-tts.sh' - 'CMakeLists.txt' - 'cmake/**' @@ -66,6 +68,14 @@ jobs: ls -lh ./bin/Release/sherpa-onnx.exe + - name: Test online CTC + shell: bash + run: | + export PATH=$PWD/build/bin/Release:$PATH + export EXE=sherpa-onnx.exe + + .github/scripts/test-online-ctc.sh + - name: Test offline TTS shell: bash run: | diff --git a/.github/workflows/windows-x64.yaml b/.github/workflows/windows-x64.yaml index e929d1d01..c9a9c0f68 100644 --- a/.github/workflows/windows-x64.yaml +++ b/.github/workflows/windows-x64.yaml @@ -12,6 +12,7 @@ on: - '.github/scripts/test-online-paraformer.sh' - 
'.github/scripts/test-offline-transducer.sh' - '.github/scripts/test-offline-ctc.sh' + - '.github/scripts/test-online-ctc.sh' - '.github/scripts/test-offline-tts.sh' - 'CMakeLists.txt' - 'cmake/**' @@ -25,6 +26,7 @@ on: - '.github/scripts/test-online-paraformer.sh' - '.github/scripts/test-offline-transducer.sh' - '.github/scripts/test-offline-ctc.sh' + - '.github/scripts/test-online-ctc.sh' - '.github/scripts/test-offline-tts.sh' - 'CMakeLists.txt' - 'cmake/**' @@ -67,6 +69,14 @@ jobs: ls -lh ./bin/Release/sherpa-onnx.exe + - name: Test online CTC + shell: bash + run: | + export PATH=$PWD/build/bin/Release:$PATH + export EXE=sherpa-onnx.exe + + .github/scripts/test-online-ctc.sh + - name: Test offline TTS shell: bash run: | diff --git a/.github/workflows/windows-x86.yaml b/.github/workflows/windows-x86.yaml index 1117ea257..0a1cf9222 100644 --- a/.github/workflows/windows-x86.yaml +++ b/.github/workflows/windows-x86.yaml @@ -13,6 +13,7 @@ on: - '.github/scripts/test-offline-transducer.sh' - '.github/scripts/test-offline-ctc.sh' - '.github/scripts/test-offline-tts.sh' + - '.github/scripts/test-online-ctc.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -26,6 +27,7 @@ on: - '.github/scripts/test-offline-transducer.sh' - '.github/scripts/test-offline-ctc.sh' - '.github/scripts/test-offline-tts.sh' + - '.github/scripts/test-online-ctc.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -67,6 +69,14 @@ jobs: ls -lh ./bin/Release/sherpa-onnx.exe + - name: Test online CTC + shell: bash + run: | + export PATH=$PWD/build/bin/Release:$PATH + export EXE=sherpa-onnx.exe + + .github/scripts/test-online-ctc.sh + - name: Test offline TTS shell: bash run: | diff --git a/scripts/wenet/export-onnx-streaming.py b/scripts/wenet/export-onnx-streaming.py index 27d8afde5..8ef7b2172 100755 --- a/scripts/wenet/export-onnx-streaming.py +++ b/scripts/wenet/export-onnx-streaming.py @@ -164,6 +164,7 @@ def main(): dynamic_axes={ "x": {0: "N", 1: "T"}, "attn_cache": {2: 
"T"}, + "attn_mask": {2: "T"}, + "log_probs": {0: "N"}, + "new_attn_cache": {2: "T"}, }, diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 0e870b301..403931028 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -49,6 +49,8 @@ set(sources offline-zipformer-ctc-model-config.cc offline-zipformer-ctc-model.cc online-conformer-transducer-model.cc + online-ctc-greedy-search-decoder.cc + online-ctc-model.cc online-lm-config.cc online-lm.cc online-lstm-transducer-model.cc @@ -64,6 +66,8 @@ set(sources online-transducer-model-config.cc online-transducer-model.cc online-transducer-modified-beam-search-decoder.cc + online-wenet-ctc-model-config.cc + online-wenet-ctc-model.cc online-zipformer-transducer-model.cc online-zipformer2-transducer-model.cc onnx-utils.cc diff --git a/sherpa-onnx/csrc/online-ctc-decoder.h b/sherpa-onnx/csrc/online-ctc-decoder.h new file mode 100644 index 000000000..3e701bb37 --- /dev/null +++ b/sherpa-onnx/csrc/online-ctc-decoder.h @@ -0,0 +1,42 @@ +// sherpa-onnx/csrc/online-ctc-decoder.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_ +#define SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_ + +#include + +#include "onnxruntime_cxx_api.h" // NOLINT + +namespace sherpa_onnx { + +struct OnlineCtcDecoderResult { + /// The decoded token IDs + std::vector tokens; + + /// timestamps[i] contains the output frame index where tokens[i] is decoded. + /// Note: The index is after subsampling + std::vector timestamps; + + int32_t num_trailing_blanks = 0; +}; + +class OnlineCtcDecoder { + public: + virtual ~OnlineCtcDecoder() = default; + + /** Run streaming CTC decoding given the output from the encoder model. + * + * @param log_probs A 3-D tensor of shape (N, T, vocab_size) containing + * log_probs. + * + * @param results Input & Output parameters. 
+ */ + virtual void Decode(Ort::Value log_probs, + std::vector *results) = 0; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_ diff --git a/sherpa-onnx/csrc/online-ctc-greedy-search-decoder.cc b/sherpa-onnx/csrc/online-ctc-greedy-search-decoder.cc new file mode 100644 index 000000000..8a5a606aa --- /dev/null +++ b/sherpa-onnx/csrc/online-ctc-greedy-search-decoder.cc @@ -0,0 +1,60 @@ +// sherpa-onnx/csrc/online-ctc-greedy-search-decoder.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h" + +#include +#include +#include + +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +void OnlineCtcGreedySearchDecoder::Decode( + Ort::Value log_probs, std::vector *results) { + std::vector log_probs_shape = + log_probs.GetTensorTypeAndShapeInfo().GetShape(); + + if (log_probs_shape[0] != results->size()) { + SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, results.size(0): %d", + static_cast(log_probs_shape[0]), + static_cast(results->size())); + exit(-1); + } + + int32_t batch_size = static_cast(log_probs_shape[0]); + int32_t num_frames = static_cast(log_probs_shape[1]); + int32_t vocab_size = static_cast(log_probs_shape[2]); + + const float *p = log_probs.GetTensorData(); + + for (int32_t b = 0; b != batch_size; ++b) { + auto &r = (*results)[b]; + + int32_t prev_id = -1; + + for (int32_t t = 0; t != num_frames; ++t, p += vocab_size) { + int32_t y = static_cast(std::distance( + static_cast(p), + std::max_element(static_cast(p), + static_cast(p) + vocab_size))); + + if (y == blank_id_) { + r.num_trailing_blanks += 1; + } else { + r.num_trailing_blanks = 0; + } + + if (y != blank_id_ && y != prev_id) { + r.tokens.push_back(y); + r.timestamps.push_back(t); + } + + prev_id = y; + } // for (int32_t t = 0; t != num_frames; ++t) { + } // for (int32_t b = 0; b != batch_size; ++b) +} + +} // namespace sherpa_onnx diff --git 
a/sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h b/sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h new file mode 100644 index 000000000..fc724f2c3 --- /dev/null +++ b/sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h @@ -0,0 +1,28 @@ +// sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_GREEDY_SEARCH_DECODER_H_ +#define SHERPA_ONNX_CSRC_ONLINE_CTC_GREEDY_SEARCH_DECODER_H_ + +#include + +#include "sherpa-onnx/csrc/online-ctc-decoder.h" + +namespace sherpa_onnx { + +class OnlineCtcGreedySearchDecoder : public OnlineCtcDecoder { + public: + explicit OnlineCtcGreedySearchDecoder(int32_t blank_id) + : blank_id_(blank_id) {} + + void Decode(Ort::Value log_probs, + std::vector *results) override; + + private: + int32_t blank_id_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_CTC_GREEDY_SEARCH_DECODER_H_ diff --git a/sherpa-onnx/csrc/online-ctc-model.cc b/sherpa-onnx/csrc/online-ctc-model.cc new file mode 100644 index 000000000..4ec094e04 --- /dev/null +++ b/sherpa-onnx/csrc/online-ctc-model.cc @@ -0,0 +1,138 @@ +// sherpa-onnx/csrc/online-ctc-model.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/online-ctc-model.h" + +#include +#include +#include +#include + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/online-wenet-ctc-model.h" +#include "sherpa-onnx/csrc/onnx-utils.h" + +namespace { + +enum class ModelType { + kZipformerCtc, + kWenetCtc, + kUnkown, +}; + +} // namespace + +namespace sherpa_onnx { + +static ModelType GetModelType(char *model_data, size_t model_data_length, + bool debug) { + Ort::Env env(ORT_LOGGING_LEVEL_WARNING); + Ort::SessionOptions sess_opts; + + auto sess = std::make_unique(env, model_data, model_data_length, + sess_opts); + + Ort::ModelMetadata meta_data = sess->GetModelMetadata(); + if (debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); + 
SHERPA_ONNX_LOGE("%s", os.str().c_str()); + } + + Ort::AllocatorWithDefaultOptions allocator; + auto model_type = + meta_data.LookupCustomMetadataMapAllocated("model_type", allocator); + if (!model_type) { + SHERPA_ONNX_LOGE( + "No model_type in the metadata!\n" + "If you are using models from WeNet, please refer to\n" + "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/" + "run.sh\n" + "\n" + "for how to add metadata to model.onnx\n"); + return ModelType::kUnkown; + } + + if (model_type.get() == std::string("zipformer2")) { + return ModelType::kZipformerCtc; + } else if (model_type.get() == std::string("wenet_ctc")) { + return ModelType::kWenetCtc; + } else { + SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); + return ModelType::kUnkown; + } +} + +std::unique_ptr OnlineCtcModel::Create( + const OnlineModelConfig &config) { + ModelType model_type = ModelType::kUnkown; + + std::string filename; + if (!config.wenet_ctc.model.empty()) { + filename = config.wenet_ctc.model; + } else { + SHERPA_ONNX_LOGE("Please specify a CTC model"); + exit(-1); + } + + { + auto buffer = ReadFile(filename); + + model_type = GetModelType(buffer.data(), buffer.size(), config.debug); + } + + switch (model_type) { + case ModelType::kZipformerCtc: + return nullptr; + // return std::make_unique(config); + break; + case ModelType::kWenetCtc: + return std::make_unique(config); + break; + case ModelType::kUnkown: + SHERPA_ONNX_LOGE("Unknown model type in online CTC!"); + return nullptr; + } + + return nullptr; +} + +#if __ANDROID_API__ >= 9 + +std::unique_ptr OnlineCtcModel::Create( + AAssetManager *mgr, const OnlineModelConfig &config) { + ModelType model_type = ModelType::kUnkown; + + std::string filename; + if (!config.wenet_ctc.model.empty()) { + filename = config.wenet_ctc.model; + } else { + SHERPA_ONNX_LOGE("Please specify a CTC model"); + exit(-1); + } + + { + auto buffer = ReadFile(mgr, filename); + + model_type = GetModelType(buffer.data(), 
buffer.size(), config.debug); + } + + switch (model_type) { + case ModelType::kZipformerCtc: + return nullptr; + // return std::make_unique(mgr, config); + break; + case ModelType::kWenetCtc: + return std::make_unique(mgr, config); + break; + case ModelType::kUnkown: + SHERPA_ONNX_LOGE("Unknown model type in online CTC!"); + return nullptr; + } + + return nullptr; +} +#endif + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-ctc-model.h b/sherpa-onnx/csrc/online-ctc-model.h new file mode 100644 index 000000000..c89cf6054 --- /dev/null +++ b/sherpa-onnx/csrc/online-ctc-model.h @@ -0,0 +1,67 @@ +// sherpa-onnx/csrc/online-ctc-model.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_MODEL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_CTC_MODEL_H_ + +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/online-model-config.h" + +namespace sherpa_onnx { + +class OnlineCtcModel { + public: + virtual ~OnlineCtcModel() = default; + + static std::unique_ptr Create( + const OnlineModelConfig &config); + +#if __ANDROID_API__ >= 9 + static std::unique_ptr Create( + AAssetManager *mgr, const OnlineModelConfig &config); +#endif + + // Return a list of tensors containing the initial states + virtual std::vector GetInitStates() const = 0; + + /** + * + * @param x A 3-D tensor of shape (N, T, C). N has to be 1. + * @param states It is from GetInitStates() or returned from this method. 
+ * + * @return Return a list of tensors + * - ans[0] contains log_probs, of shape (N, T, C) + * - ans[1:] contains next_states + */ + virtual std::vector Forward( + Ort::Value x, std::vector states) const = 0; + + /** Return the vocabulary size of the model + */ + virtual int32_t VocabSize() const = 0; + + /** Return an allocator for allocating memory + */ + virtual OrtAllocator *Allocator() const = 0; + + // The model accepts this number of frames before subsampling as input + virtual int32_t ChunkLength() const = 0; + + // Similar to frame_shift in feature extractor, after processing + // ChunkLength() frames, we advance by ChunkShift() frames + // before we process the next chunk. + virtual int32_t ChunkShift() const = 0; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_CTC_MODEL_H_ diff --git a/sherpa-onnx/csrc/online-model-config.cc b/sherpa-onnx/csrc/online-model-config.cc index 9c1f8c493..a81ce375d 100644 --- a/sherpa-onnx/csrc/online-model-config.cc +++ b/sherpa-onnx/csrc/online-model-config.cc @@ -13,6 +13,7 @@ namespace sherpa_onnx { void OnlineModelConfig::Register(ParseOptions *po) { transducer.Register(po); paraformer.Register(po); + wenet_ctc.Register(po); po->Register("tokens", &tokens, "Path to tokens.txt"); @@ -46,6 +47,10 @@ bool OnlineModelConfig::Validate() const { return paraformer.Validate(); } + if (!wenet_ctc.model.empty()) { + return wenet_ctc.Validate(); + } + return transducer.Validate(); } @@ -55,6 +60,7 @@ std::string OnlineModelConfig::ToString() const { os << "OnlineModelConfig("; os << "transducer=" << transducer.ToString() << ", "; os << "paraformer=" << paraformer.ToString() << ", "; + os << "wenet_ctc=" << wenet_ctc.ToString() << ", "; os << "tokens=\"" << tokens << "\", "; os << "num_threads=" << num_threads << ", "; os << "debug=" << (debug ? 
"True" : "False") << ", "; diff --git a/sherpa-onnx/csrc/online-model-config.h b/sherpa-onnx/csrc/online-model-config.h index 2afd66176..34369b959 100644 --- a/sherpa-onnx/csrc/online-model-config.h +++ b/sherpa-onnx/csrc/online-model-config.h @@ -8,12 +8,14 @@ #include "sherpa-onnx/csrc/online-paraformer-model-config.h" #include "sherpa-onnx/csrc/online-transducer-model-config.h" +#include "sherpa-onnx/csrc/online-wenet-ctc-model-config.h" namespace sherpa_onnx { struct OnlineModelConfig { OnlineTransducerModelConfig transducer; OnlineParaformerModelConfig paraformer; + OnlineWenetCtcModelConfig wenet_ctc; std::string tokens; int32_t num_threads = 1; bool debug = false; @@ -31,10 +33,12 @@ struct OnlineModelConfig { OnlineModelConfig() = default; OnlineModelConfig(const OnlineTransducerModelConfig &transducer, const OnlineParaformerModelConfig ¶former, + const OnlineWenetCtcModelConfig &wenet_ctc, const std::string &tokens, int32_t num_threads, bool debug, const std::string &provider, const std::string &model_type) : transducer(transducer), paraformer(paraformer), + wenet_ctc(wenet_ctc), tokens(tokens), num_threads(num_threads), debug(debug), diff --git a/sherpa-onnx/csrc/online-recognizer-ctc-impl.h b/sherpa-onnx/csrc/online-recognizer-ctc-impl.h new file mode 100644 index 000000000..16b8ca48a --- /dev/null +++ b/sherpa-onnx/csrc/online-recognizer-ctc-impl.h @@ -0,0 +1,226 @@ +// sherpa-onnx/csrc/online-recognizer-ctc-impl.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_CTC_IMPL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_CTC_IMPL_H_ + +#include +#include +#include +#include +#include + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/online-ctc-decoder.h" +#include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h" +#include "sherpa-onnx/csrc/online-ctc-model.h" +#include "sherpa-onnx/csrc/online-recognizer-impl.h" +#include 
"sherpa-onnx/csrc/symbol-table.h" + +namespace sherpa_onnx { + +static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src, + const SymbolTable &sym_table, + float frame_shift_ms, + int32_t subsampling_factor, + int32_t segment, + int32_t frames_since_start) { + OnlineRecognizerResult r; + r.tokens.reserve(src.tokens.size()); + r.timestamps.reserve(src.tokens.size()); + + for (auto i : src.tokens) { + auto sym = sym_table[i]; + + r.text.append(sym); + r.tokens.push_back(std::move(sym)); + } + + float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; + for (auto t : src.timestamps) { + float time = frame_shift_s * t; + r.timestamps.push_back(time); + } + + r.segment = segment; + r.start_time = frames_since_start * frame_shift_ms / 1000.; + + return r; +} + +class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { + public: + explicit OnlineRecognizerCtcImpl(const OnlineRecognizerConfig &config) + : config_(config), + model_(OnlineCtcModel::Create(config.model_config)), + sym_(config.model_config.tokens), + endpoint_(config_.endpoint_config) { + if (!config.model_config.wenet_ctc.model.empty()) { + // WeNet CTC models assume input samples are in the range + // [-32768, 32767], so we set normalize_samples to false + config_.feat_config.normalize_samples = false; + } + + InitDecoder(); + } + +#if __ANDROID_API__ >= 9 + explicit OnlineRecognizerCtcImpl(AAssetManager *mgr, + const OnlineRecognizerConfig &config) + : config_(config), + model_(OnlineCtcModel::Create(mgr, config.model_config)), + sym_(mgr, config.model_config.tokens), + endpoint_(config_.endpoint_config) { + if (!config.model_config.wenet_ctc.model.empty()) { + // WeNet CTC models assume input samples are in the range + // [-32768, 32767], so we set normalize_samples to false + config_.feat_config.normalize_samples = false; + } + + InitDecoder(); + } +#endif + + std::unique_ptr CreateStream() const override { + auto stream = std::make_unique(config_.feat_config); + 
stream->SetStates(model_->GetInitStates()); + + return stream; + } + + bool IsReady(OnlineStream *s) const override { + return s->GetNumProcessedFrames() + model_->ChunkLength() < + s->NumFramesReady(); + } + + void DecodeStreams(OnlineStream **ss, int32_t n) const override { + for (int32_t i = 0; i != n; ++i) { + DecodeStream(ss[i]); + } + } + + OnlineRecognizerResult GetResult(OnlineStream *s) const override { + OnlineCtcDecoderResult decoder_result = s->GetCtcResult(); + + // TODO(fangjun): Remember to change these constants if needed + int32_t frame_shift_ms = 10; + int32_t subsampling_factor = 4; + return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor, + s->GetCurrentSegment(), s->GetNumFramesSinceStart()); + } + + bool IsEndpoint(OnlineStream *s) const override { + if (!config_.enable_endpoint) { + return false; + } + + int32_t num_processed_frames = s->GetNumProcessedFrames(); + + // frame shift is 10 milliseconds + float frame_shift_in_seconds = 0.01; + + // subsampling factor is 4 + int32_t trailing_silence_frames = s->GetCtcResult().num_trailing_blanks * 4; + + return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames, + frame_shift_in_seconds); + } + + void Reset(OnlineStream *s) const override { + // segment is incremented only when the last + // result is not empty + const auto &r = s->GetCtcResult(); + if (!r.tokens.empty()) { + s->GetCurrentSegment() += 1; + } + + // clear result + s->SetCtcResult({}); + + // clear states + s->SetStates(model_->GetInitStates()); + + // Note: We only update counters. The underlying audio samples + // are not discarded. 
+ s->Reset(); + } + + private: + void InitDecoder() { + if (config_.decoding_method == "greedy_search") { + if (!sym_.contains("") && !sym_.contains("") && + !sym_.contains("")) { + SHERPA_ONNX_LOGE( + "We expect that tokens.txt contains " + "the symbol or or and its ID."); + exit(-1); + } + + int32_t blank_id = 0; + if (sym_.contains("")) { + blank_id = sym_[""]; + } else if (sym_.contains("")) { + // for tdnn models of the yesno recipe from icefall + blank_id = sym_[""]; + } else if (sym_.contains("")) { + // for WeNet CTC models + blank_id = sym_[""]; + } + + decoder_ = std::make_unique(blank_id); + } else { + SHERPA_ONNX_LOGE("Unsupported decoding method: %s", + config_.decoding_method.c_str()); + exit(-1); + } + } + + void DecodeStream(OnlineStream *s) const { + int32_t chunk_length = model_->ChunkLength(); + int32_t chunk_shift = model_->ChunkShift(); + + int32_t feat_dim = s->FeatureDim(); + + const auto num_processed_frames = s->GetNumProcessedFrames(); + std::vector frames = + s->GetFrames(num_processed_frames, chunk_length); + s->GetNumProcessedFrames() += chunk_shift; + + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + std::array x_shape{1, chunk_length, feat_dim}; + Ort::Value x = + Ort::Value::CreateTensor(memory_info, frames.data(), frames.size(), + x_shape.data(), x_shape.size()); + auto out = model_->Forward(std::move(x), std::move(s->GetStates())); + int32_t num_states = static_cast(out.size()) - 1; + + std::vector states; + states.reserve(num_states); + + for (int32_t i = 0; i != num_states; ++i) { + states.push_back(std::move(out[i + 1])); + } + s->SetStates(std::move(states)); + + std::vector results(1); + results[0] = std::move(s->GetCtcResult()); + + decoder_->Decode(std::move(out[0]), &results); + s->SetCtcResult(results[0]); + } + + private: + OnlineRecognizerConfig config_; + std::unique_ptr model_; + std::unique_ptr decoder_; + SymbolTable sym_; + Endpoint endpoint_; +}; + +} // namespace 
sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_CTC_IMPL_H_ diff --git a/sherpa-onnx/csrc/online-recognizer-impl.cc b/sherpa-onnx/csrc/online-recognizer-impl.cc index 1eb16c031..59d8658b8 100644 --- a/sherpa-onnx/csrc/online-recognizer-impl.cc +++ b/sherpa-onnx/csrc/online-recognizer-impl.cc @@ -4,6 +4,7 @@ #include "sherpa-onnx/csrc/online-recognizer-impl.h" +#include "sherpa-onnx/csrc/online-recognizer-ctc-impl.h" #include "sherpa-onnx/csrc/online-recognizer-paraformer-impl.h" #include "sherpa-onnx/csrc/online-recognizer-transducer-impl.h" @@ -19,6 +20,10 @@ std::unique_ptr OnlineRecognizerImpl::Create( return std::make_unique(config); } + if (!config.model_config.wenet_ctc.model.empty()) { + return std::make_unique(config); + } + SHERPA_ONNX_LOGE("Please specify a model"); exit(-1); } @@ -34,6 +39,10 @@ std::unique_ptr OnlineRecognizerImpl::Create( return std::make_unique(mgr, config); } + if (!config.model_config.wenet_ctc.model.empty()) { + return std::make_unique(mgr, config); + } + SHERPA_ONNX_LOGE("Please specify a model"); exit(-1); } diff --git a/sherpa-onnx/csrc/online-recognizer-paraformer-impl.h b/sherpa-onnx/csrc/online-recognizer-paraformer-impl.h index d5034d135..fbbeb2ada 100644 --- a/sherpa-onnx/csrc/online-recognizer-paraformer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-paraformer-impl.h @@ -120,11 +120,7 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl { model_(mgr, config.model_config), sym_(mgr, config.model_config.tokens), endpoint_(config_.endpoint_config) { - if (config.decoding_method == "greedy_search") { - // add greedy search decoder - // SHERPA_ONNX_LOGE("to be implemented"); - // exit(-1); - } else { + if (config.decoding_method != "greedy_search") { SHERPA_ONNX_LOGE("Unsupported decoding method: %s", config.decoding_method.c_str()); exit(-1); diff --git a/sherpa-onnx/csrc/online-stream.cc b/sherpa-onnx/csrc/online-stream.cc index 8cec219b5..e165189d6 100644 --- a/sherpa-onnx/csrc/online-stream.cc 
+++ b/sherpa-onnx/csrc/online-stream.cc @@ -51,6 +51,10 @@ class OnlineStream::Impl { OnlineTransducerDecoderResult &GetResult() { return result_; } + OnlineCtcDecoderResult &GetCtcResult() { return ctc_result_; } + + void SetCtcResult(const OnlineCtcDecoderResult &r) { ctc_result_ = r; } + void SetParaformerResult(const OnlineParaformerDecoderResult &r) { paraformer_result_ = r; } @@ -89,7 +93,8 @@ class OnlineStream::Impl { int32_t start_frame_index_ = 0; // never reset int32_t segment_ = 0; OnlineTransducerDecoderResult result_; - std::vector states_; + OnlineCtcDecoderResult ctc_result_; + std::vector states_; // states for transducer or ctc models std::vector paraformer_feat_cache_; std::vector paraformer_encoder_out_cache_; std::vector paraformer_alpha_cache_; @@ -144,6 +149,14 @@ OnlineTransducerDecoderResult &OnlineStream::GetResult() { return impl_->GetResult(); } +OnlineCtcDecoderResult &OnlineStream::GetCtcResult() { + return impl_->GetCtcResult(); +} + +void OnlineStream::SetCtcResult(const OnlineCtcDecoderResult &r) { + impl_->SetCtcResult(r); +} + void OnlineStream::SetParaformerResult(const OnlineParaformerDecoderResult &r) { impl_->SetParaformerResult(r); } diff --git a/sherpa-onnx/csrc/online-stream.h b/sherpa-onnx/csrc/online-stream.h index b9214b1dc..09be33151 100644 --- a/sherpa-onnx/csrc/online-stream.h +++ b/sherpa-onnx/csrc/online-stream.h @@ -11,6 +11,7 @@ #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/context-graph.h" #include "sherpa-onnx/csrc/features.h" +#include "sherpa-onnx/csrc/online-ctc-decoder.h" #include "sherpa-onnx/csrc/online-paraformer-decoder.h" #include "sherpa-onnx/csrc/online-transducer-decoder.h" @@ -75,6 +76,9 @@ class OnlineStream { void SetResult(const OnlineTransducerDecoderResult &r); OnlineTransducerDecoderResult &GetResult(); + void SetCtcResult(const OnlineCtcDecoderResult &r); + OnlineCtcDecoderResult &GetCtcResult(); + void SetParaformerResult(const OnlineParaformerDecoderResult &r); 
OnlineParaformerDecoderResult &GetParaformerResult(); diff --git a/sherpa-onnx/csrc/online-wenet-ctc-model-config.cc b/sherpa-onnx/csrc/online-wenet-ctc-model-config.cc new file mode 100644 index 000000000..6098be626 --- /dev/null +++ b/sherpa-onnx/csrc/online-wenet-ctc-model-config.cc @@ -0,0 +1,59 @@ +// sherpa-onnx/csrc/online-wenet-ctc-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/online-wenet-ctc-model-config.h" + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +void OnlineWenetCtcModelConfig::Register(ParseOptions *po) { + po->Register("wenet-ctc-model", &model, + "Path to CTC model.onnx from WeNet. Please see " + "https://github.com/k2-fsa/sherpa-onnx/pull/425"); + po->Register("wenet-ctc-chunk-size", &chunk_size, + "Chunk size after subsampling used for decoding."); + po->Register("wenet-ctc-num-left-chunks", &num_left_chunks, + "Number of left chunks after subsampling used for decoding."); +} + +bool OnlineWenetCtcModelConfig::Validate() const { + if (!FileExists(model)) { + SHERPA_ONNX_LOGE("WeNet CTC model %s does not exist", model.c_str()); + return false; + } + + if (chunk_size <= 0) { + SHERPA_ONNX_LOGE( + "Please specify a positive value for --wenet-ctc-chunk-size. Currently " + "given: %d", + chunk_size); + return false; + } + + if (num_left_chunks <= 0) { + SHERPA_ONNX_LOGE( + "Please specify a positive value for --wenet-ctc-num-left-chunks. " + "Currently given: %d. 
Note that if you want to use -1, please consider "
+        "using a non-streaming model.",
+        num_left_chunks);
+    return false;
+  }
+
+  return true;
+}
+
+std::string OnlineWenetCtcModelConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "OnlineWenetCtcModelConfig(";
+  os << "model=\"" << model << "\", ";
+  os << "chunk_size=" << chunk_size << ", ";
+  os << "num_left_chunks=" << num_left_chunks << ")";
+
+  return os.str();
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/online-wenet-ctc-model-config.h b/sherpa-onnx/csrc/online-wenet-ctc-model-config.h
new file mode 100644
index 000000000..b90295ea6
--- /dev/null
+++ b/sherpa-onnx/csrc/online-wenet-ctc-model-config.h
@@ -0,0 +1,38 @@
+// sherpa-onnx/csrc/online-wenet-ctc-model-config.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_ONLINE_WENET_CTC_MODEL_CONFIG_H_
+#define SHERPA_ONNX_CSRC_ONLINE_WENET_CTC_MODEL_CONFIG_H_
+
+#include <string>
+
+#include "sherpa-onnx/csrc/parse-options.h"
+
+namespace sherpa_onnx {
+
+struct OnlineWenetCtcModelConfig {
+  std::string model;
+
+  // --chunk_size from wenet
+  int32_t chunk_size = 16;
+
+  // --num_left_chunks from wenet
+  int32_t num_left_chunks = 4;
+
+  OnlineWenetCtcModelConfig() = default;
+
+  OnlineWenetCtcModelConfig(const std::string &model, int32_t chunk_size,
+                            int32_t num_left_chunks)
+      : model(model),
+        chunk_size(chunk_size),
+        num_left_chunks(num_left_chunks) {}
+
+  void Register(ParseOptions *po);
+  bool Validate() const;
+
+  std::string ToString() const;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_ONLINE_WENET_CTC_MODEL_CONFIG_H_
diff --git a/sherpa-onnx/csrc/online-wenet-ctc-model.cc b/sherpa-onnx/csrc/online-wenet-ctc-model.cc
new file mode 100644
index 000000000..5d7e90964
--- /dev/null
+++ b/sherpa-onnx/csrc/online-wenet-ctc-model.cc
@@ -0,0 +1,242 @@
+// sherpa-onnx/csrc/online-wenet-ctc-model.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/online-wenet-ctc-model.h"
+
+#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/session.h" +#include "sherpa-onnx/csrc/text-utils.h" + +namespace sherpa_onnx { + +class OnlineWenetCtcModel::Impl { + public: + explicit Impl(const OnlineModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(config.wenet_ctc.model); + Init(buf.data(), buf.size()); + } + } + +#if __ANDROID_API__ >= 9 + Impl(AAssetManager *mgr, const OnlineModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_WARNING), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(mgr, config.wenet_ctc.model); + Init(buf.data(), buf.size()); + } + } +#endif + + std::vector Forward(Ort::Value x, + std::vector states) { + Ort::Value &attn_cache = states[0]; + Ort::Value &conv_cache = states[1]; + Ort::Value &offset = states[2]; + + int32_t chunk_size = config_.wenet_ctc.chunk_size; + int32_t left_chunks = config_.wenet_ctc.num_left_chunks; + // build attn_mask + std::array attn_mask_shape{1, 1, + required_cache_size_ + chunk_size}; + Ort::Value attn_mask = Ort::Value::CreateTensor( + allocator_, attn_mask_shape.data(), attn_mask_shape.size()); + bool *p = attn_mask.GetTensorMutableData(); + int32_t chunk_idx = + offset.GetTensorData()[0] / chunk_size - left_chunks; + if (chunk_idx < left_chunks) { + std::fill(p, p + required_cache_size_ - chunk_idx * chunk_size, 0); + std::fill(p + required_cache_size_ - chunk_idx * chunk_size, + p + attn_mask_shape[2], 1); + } else { + std::fill(p, p + attn_mask_shape[2], 1); + } + + std::array inputs = {std::move(x), + View(&offset), + View(&required_cache_size_tensor_), + std::move(attn_cache), + std::move(conv_cache), + std::move(attn_mask)}; + + auto out = + 
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), + output_names_ptr_.data(), output_names_ptr_.size()); + + offset.GetTensorMutableData()[0] += + out[0].GetTensorTypeAndShapeInfo().GetShape()[1]; + out.push_back(std::move(offset)); + + return out; + } + + int32_t VocabSize() const { return vocab_size_; } + + int32_t ChunkLength() const { + // When chunk_size is 16, subsampling_factor_ is 4, right_context_ is 6, + // the returned value is (16 - 1)*4 + 6 + 1 = 67 + return (config_.wenet_ctc.chunk_size - 1) * subsampling_factor_ + + right_context_ + 1; + } + + int32_t ChunkShift() const { return required_cache_size_; } + + OrtAllocator *Allocator() const { return allocator_; } + + // Return a vector containing 3 tensors + // - attn_cache + // - conv_cache + // - offset + std::vector GetInitStates() const { + std::vector ans; + ans.reserve(3); + ans.push_back(Clone(Allocator(), &attn_cache_)); + ans.push_back(Clone(Allocator(), &conv_cache_)); + + int64_t offset_shape = 1; + + Ort::Value offset = + Ort::Value::CreateTensor(allocator_, &offset_shape, 1); + + offset.GetTensorMutableData()[0] = required_cache_size_; + + ans.push_back(std::move(offset)); + + return ans; + } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::make_unique(env_, model_data, model_data_length, + sess_opts_); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); + } + + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(head_, "head"); + SHERPA_ONNX_READ_META_DATA(num_blocks_, "num_blocks"); + SHERPA_ONNX_READ_META_DATA(output_size_, "output_size"); + SHERPA_ONNX_READ_META_DATA(cnn_module_kernel_, 
"cnn_module_kernel"); + SHERPA_ONNX_READ_META_DATA(right_context_, "right_context"); + SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor"); + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); + + required_cache_size_ = + config_.wenet_ctc.chunk_size * config_.wenet_ctc.num_left_chunks; + + InitStates(); + } + + void InitStates() { + std::array attn_cache_shape{ + num_blocks_, head_, required_cache_size_, output_size_ / head_ * 2}; + attn_cache_ = Ort::Value::CreateTensor( + allocator_, attn_cache_shape.data(), attn_cache_shape.size()); + + Fill(&attn_cache_, 0); + + std::array conv_cache_shape{num_blocks_, 1, output_size_, + cnn_module_kernel_ - 1}; + conv_cache_ = Ort::Value::CreateTensor( + allocator_, conv_cache_shape.data(), conv_cache_shape.size()); + + Fill(&conv_cache_, 0); + + int64_t shape = 1; + required_cache_size_tensor_ = + Ort::Value::CreateTensor(allocator_, &shape, 1); + + required_cache_size_tensor_.GetTensorMutableData()[0] = + required_cache_size_; + } + + private: + OnlineModelConfig config_; + Ort::Env env_; + Ort::SessionOptions sess_opts_; + Ort::AllocatorWithDefaultOptions allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + int32_t head_; + int32_t num_blocks_; + int32_t output_size_; + int32_t cnn_module_kernel_; + int32_t right_context_; + int32_t subsampling_factor_; + int32_t vocab_size_; + + int32_t required_cache_size_; + + Ort::Value attn_cache_{nullptr}; + Ort::Value conv_cache_{nullptr}; + Ort::Value required_cache_size_tensor_{nullptr}; +}; + +OnlineWenetCtcModel::OnlineWenetCtcModel(const OnlineModelConfig &config) + : impl_(std::make_unique(config)) {} + +#if __ANDROID_API__ >= 9 +OnlineWenetCtcModel::OnlineWenetCtcModel(AAssetManager *mgr, + const OnlineModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} +#endif + +OnlineWenetCtcModel::~OnlineWenetCtcModel() = default; + 
+std::vector OnlineWenetCtcModel::Forward( + Ort::Value x, std::vector states) const { + return impl_->Forward(std::move(x), std::move(states)); +} + +int32_t OnlineWenetCtcModel::VocabSize() const { return impl_->VocabSize(); } + +int32_t OnlineWenetCtcModel::ChunkLength() const { + return impl_->ChunkLength(); +} + +int32_t OnlineWenetCtcModel::ChunkShift() const { return impl_->ChunkShift(); } + +OrtAllocator *OnlineWenetCtcModel::Allocator() const { + return impl_->Allocator(); +} + +std::vector OnlineWenetCtcModel::GetInitStates() const { + return impl_->GetInitStates(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-wenet-ctc-model.h b/sherpa-onnx/csrc/online-wenet-ctc-model.h new file mode 100644 index 000000000..ccfd378b6 --- /dev/null +++ b/sherpa-onnx/csrc/online-wenet-ctc-model.h @@ -0,0 +1,73 @@ +// sherpa-onnx/csrc/online-wenet-ctc-model.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_ONLINE_WENET_CTC_MODEL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_WENET_CTC_MODEL_H_ + +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/online-ctc-model.h" +#include "sherpa-onnx/csrc/online-model-config.h" + +namespace sherpa_onnx { + +class OnlineWenetCtcModel : public OnlineCtcModel { + public: + explicit OnlineWenetCtcModel(const OnlineModelConfig &config); + +#if __ANDROID_API__ >= 9 + OnlineWenetCtcModel(AAssetManager *mgr, const OnlineModelConfig &config); +#endif + + ~OnlineWenetCtcModel() override; + + // A list of 3 tensors: + // - attn_cache + // - conv_cache + // - offset + std::vector GetInitStates() const override; + + /** + * + * @param x A 3-D tensor of shape (N, T, C). N has to be 1. + * @param states It is from GetInitStates() or returned from this method. 
+ * + * @return Return a list of tensors + * - ans[0] contains log_probs, of shape (N, T, C) + * - ans[1:] contains next_states + */ + std::vector Forward( + Ort::Value x, std::vector states) const override; + + /** Return the vocabulary size of the model + */ + int32_t VocabSize() const override; + + /** Return an allocator for allocating memory + */ + OrtAllocator *Allocator() const override; + + // The model accepts this number of frames before subsampling as input + int32_t ChunkLength() const override; + + // Similar to frame_shift in feature extractor, after processing + // ChunkLength() frames, we advance by ChunkShift() frames + // before we process the next chunk. + int32_t ChunkShift() const override; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_WENET_CTC_MODEL_H_ diff --git a/sherpa-onnx/csrc/onnx-utils.cc b/sherpa-onnx/csrc/onnx-utils.cc index 80f01e49c..2d5c20823 100644 --- a/sherpa-onnx/csrc/onnx-utils.cc +++ b/sherpa-onnx/csrc/onnx-utils.cc @@ -125,6 +125,34 @@ Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v) { } } +Ort::Value View(Ort::Value *v) { + auto type_and_shape = v->GetTensorTypeAndShapeInfo(); + std::vector shape = type_and_shape.GetShape(); + + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + switch (type_and_shape.GetElementType()) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + return Ort::Value::CreateTensor( + memory_info, v->GetTensorMutableData(), + type_and_shape.GetElementCount(), shape.data(), shape.size()); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + return Ort::Value::CreateTensor( + memory_info, v->GetTensorMutableData(), + type_and_shape.GetElementCount(), shape.data(), shape.size()); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + return Ort::Value::CreateTensor( + memory_info, v->GetTensorMutableData(), + type_and_shape.GetElementCount(), shape.data(), shape.size()); + default: + 
fprintf(stderr, "Unsupported type: %d\n", + static_cast(type_and_shape.GetElementType())); + exit(-1); + // unreachable code + return Ort::Value{nullptr}; + } +} + void Print1D(Ort::Value *v) { std::vector shape = v->GetTensorTypeAndShapeInfo().GetShape(); const float *d = v->GetTensorData(); diff --git a/sherpa-onnx/csrc/onnx-utils.h b/sherpa-onnx/csrc/onnx-utils.h index 3dc0e0fc5..0b7fcc750 100644 --- a/sherpa-onnx/csrc/onnx-utils.h +++ b/sherpa-onnx/csrc/onnx-utils.h @@ -65,6 +65,9 @@ void PrintModelMetadata(std::ostream &os, // Return a deep copy of v Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v); +// Return a shallow copy +Ort::Value View(Ort::Value *v); + // Print a 1-D tensor to stderr void Print1D(Ort::Value *v); diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt index 26c44db07..e346922e5 100644 --- a/sherpa-onnx/python/csrc/CMakeLists.txt +++ b/sherpa-onnx/python/csrc/CMakeLists.txt @@ -26,6 +26,7 @@ pybind11_add_module(_sherpa_onnx online-recognizer.cc online-stream.cc online-transducer-model-config.cc + online-wenet-ctc-model-config.cc sherpa-onnx.cc silero-vad-model-config.cc vad-model-config.cc diff --git a/sherpa-onnx/python/csrc/online-model-config.cc b/sherpa-onnx/python/csrc/online-model-config.cc index 8699e56d0..bd4c6798f 100644 --- a/sherpa-onnx/python/csrc/online-model-config.cc +++ b/sherpa-onnx/python/csrc/online-model-config.cc @@ -11,24 +11,29 @@ #include "sherpa-onnx/csrc/online-transducer-model-config.h" #include "sherpa-onnx/python/csrc/online-paraformer-model-config.h" #include "sherpa-onnx/python/csrc/online-transducer-model-config.h" +#include "sherpa-onnx/python/csrc/online-wenet-ctc-model-config.h" namespace sherpa_onnx { void PybindOnlineModelConfig(py::module *m) { PybindOnlineTransducerModelConfig(m); PybindOnlineParaformerModelConfig(m); + PybindOnlineWenetCtcModelConfig(m); using PyClass = OnlineModelConfig; py::class_(*m, "OnlineModelConfig") .def(py::init(), 
py::arg("transducer") = OnlineTransducerModelConfig(), py::arg("paraformer") = OnlineParaformerModelConfig(), + py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(), py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, py::arg("provider") = "cpu", py::arg("model_type") = "") .def_readwrite("transducer", &PyClass::transducer) .def_readwrite("paraformer", &PyClass::paraformer) + .def_readwrite("wenet_ctc", &PyClass::wenet_ctc) .def_readwrite("tokens", &PyClass::tokens) .def_readwrite("num_threads", &PyClass::num_threads) .def_readwrite("debug", &PyClass::debug) diff --git a/sherpa-onnx/python/csrc/online-wenet-ctc-model-config.cc b/sherpa-onnx/python/csrc/online-wenet-ctc-model-config.cc new file mode 100644 index 000000000..a05a782ae --- /dev/null +++ b/sherpa-onnx/python/csrc/online-wenet-ctc-model-config.cc @@ -0,0 +1,25 @@ +// sherpa-onnx/python/csrc/online-wenet-ctc-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/python/csrc/online-wenet-ctc-model-config.h" + +#include +#include + +#include "sherpa-onnx/csrc/online-wenet-ctc-model-config.h" + +namespace sherpa_onnx { + +void PybindOnlineWenetCtcModelConfig(py::module *m) { + using PyClass = OnlineWenetCtcModelConfig; + py::class_(*m, "OnlineWenetCtcModelConfig") + .def(py::init(), py::arg("model"), + py::arg("chunk_size") = 16, py::arg("num_left_chunks") = 4) + .def_readwrite("model", &PyClass::model) + .def_readwrite("chunk_size", &PyClass::chunk_size) + .def_readwrite("num_left_chunks", &PyClass::num_left_chunks) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/online-wenet-ctc-model-config.h b/sherpa-onnx/python/csrc/online-wenet-ctc-model-config.h new file mode 100644 index 000000000..2026537a9 --- /dev/null +++ b/sherpa-onnx/python/csrc/online-wenet-ctc-model-config.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/online-wenet-ctc-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + 
+#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_WENET_CTC_MODEL_CONFIG_H_
+#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_WENET_CTC_MODEL_CONFIG_H_
+
+#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
+
+namespace sherpa_onnx {
+
+void PybindOnlineWenetCtcModelConfig(py::module *m);
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_PYTHON_CSRC_ONLINE_WENET_CTC_MODEL_CONFIG_H_