diff --git a/.github/scripts/test-audio-tagging.sh b/.github/scripts/test-audio-tagging.sh new file mode 100755 index 000000000..57e6663fe --- /dev/null +++ b/.github/scripts/test-audio-tagging.sh @@ -0,0 +1,32 @@ +#!/usr/bin/env bash + +set -ex + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +echo "EXE is $EXE" +echo "PATH: $PATH" + +which $EXE + +log "------------------------------------------------------------" +log "Run zipformer for audio tagging " +log "------------------------------------------------------------" + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 +tar xvf sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 +rm sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 +repo=sherpa-onnx-zipformer-audio-tagging-2024-04-09 +ls -lh $repo + +for w in 1.wav 2.wav 3.wav 4.wav; do + $EXE \ + --zipformer-model=$repo/model.onnx \ + --labels=$repo/class_labels_indices.csv \ + $repo/test_wavs/$w +done +rm -rf $repo diff --git a/.github/workflows/linux.yaml b/.github/workflows/linux.yaml index b32362a3d..ae0aec470 100644 --- a/.github/workflows/linux.yaml +++ b/.github/workflows/linux.yaml @@ -15,6 +15,7 @@ on: - '.github/scripts/test-offline-ctc.sh' - '.github/scripts/test-online-ctc.sh' - '.github/scripts/test-offline-tts.sh' + - '.github/scripts/test-audio-tagging.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -32,6 +33,7 @@ on: - '.github/scripts/test-offline-ctc.sh' - '.github/scripts/test-online-ctc.sh' - '.github/scripts/test-offline-tts.sh' + - '.github/scripts/test-audio-tagging.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -124,6 +126,14 @@ jobs: name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }} path: build/bin/* + - name: Test Audio 
tagging + shell: bash + run: | + export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx-offline-audio-tagging + + .github/scripts/test-audio-tagging.sh + - name: Test online CTC shell: bash run: | diff --git a/.github/workflows/macos.yaml b/.github/workflows/macos.yaml index 0d0980619..9dfcb7c9d 100644 --- a/.github/workflows/macos.yaml +++ b/.github/workflows/macos.yaml @@ -15,6 +15,7 @@ on: - '.github/scripts/test-offline-ctc.sh' - '.github/scripts/test-offline-tts.sh' - '.github/scripts/test-online-ctc.sh' + - '.github/scripts/test-audio-tagging.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -31,6 +32,7 @@ on: - '.github/scripts/test-offline-ctc.sh' - '.github/scripts/test-offline-tts.sh' - '.github/scripts/test-online-ctc.sh' + - '.github/scripts/test-audio-tagging.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -103,6 +105,14 @@ jobs: otool -L build/bin/sherpa-onnx otool -l build/bin/sherpa-onnx + - name: Test Audio tagging + shell: bash + run: | + export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx-offline-audio-tagging + + .github/scripts/test-audio-tagging.sh + - name: Test C API shell: bash run: | diff --git a/.github/workflows/windows-x64.yaml b/.github/workflows/windows-x64.yaml index ea7cf7458..8f1715591 100644 --- a/.github/workflows/windows-x64.yaml +++ b/.github/workflows/windows-x64.yaml @@ -14,6 +14,7 @@ on: - '.github/scripts/test-offline-ctc.sh' - '.github/scripts/test-online-ctc.sh' - '.github/scripts/test-offline-tts.sh' + - '.github/scripts/test-audio-tagging.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -28,6 +29,7 @@ on: - '.github/scripts/test-offline-ctc.sh' - '.github/scripts/test-online-ctc.sh' - '.github/scripts/test-offline-tts.sh' + - '.github/scripts/test-audio-tagging.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -70,6 +72,14 @@ jobs: ls -lh ./bin/Release/sherpa-onnx.exe + - name: Test Audio tagging + shell: bash + run: | + export PATH=$PWD/build/bin/Release:$PATH 
+ export EXE=sherpa-onnx-offline-audio-tagging.exe + + .github/scripts/test-audio-tagging.sh + - name: Test C API shell: bash run: | diff --git a/.github/workflows/windows-x86.yaml b/.github/workflows/windows-x86.yaml index 69ad7cd97..65d1bea62 100644 --- a/.github/workflows/windows-x86.yaml +++ b/.github/workflows/windows-x86.yaml @@ -14,6 +14,7 @@ on: - '.github/scripts/test-offline-ctc.sh' - '.github/scripts/test-offline-tts.sh' - '.github/scripts/test-online-ctc.sh' + - '.github/scripts/test-audio-tagging.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -28,6 +29,7 @@ on: - '.github/scripts/test-offline-ctc.sh' - '.github/scripts/test-offline-tts.sh' - '.github/scripts/test-online-ctc.sh' + - '.github/scripts/test-audio-tagging.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -85,6 +87,13 @@ jobs: # export EXE=sherpa-onnx-offline-language-identification.exe # # .github/scripts/test-spoken-language-identification.sh + - name: Test Audio tagging + shell: bash + run: | + export PATH=$PWD/build/bin/Release:$PATH + export EXE=sherpa-onnx-offline-audio-tagging.exe + + .github/scripts/test-audio-tagging.sh - name: Test online CTC shell: bash diff --git a/cmake/cmake_extension.py b/cmake/cmake_extension.py index 75b09a5c5..b78129b21 100644 --- a/cmake/cmake_extension.py +++ b/cmake/cmake_extension.py @@ -46,6 +46,7 @@ def enable_alsa(): def get_binaries(): binaries = [ "sherpa-onnx", + "sherpa-onnx-offline-audio-tagging", "sherpa-onnx-keyword-spotter", "sherpa-onnx-microphone", "sherpa-onnx-microphone-offline", diff --git a/go-api-examples/vad-asr-paraformer/.gitignore b/go-api-examples/vad-asr-paraformer/.gitignore new file mode 100644 index 000000000..66786c69b --- /dev/null +++ b/go-api-examples/vad-asr-paraformer/.gitignore @@ -0,0 +1,2 @@ +go.sum +vad-asr-paraformer diff --git a/nodejs-examples/test-offline-tts-zh.js b/nodejs-examples/test-offline-tts-zh.js index a53748c77..d777d490e 100644 --- a/nodejs-examples/test-offline-tts-zh.js +++ 
b/nodejs-examples/test-offline-tts-zh.js @@ -4,7 +4,7 @@ const sherpa_onnx = require('sherpa-onnx'); function createOfflineTts() { let offlineTtsVitsModelConfig = { - model: './vits-icefall-zh-aishell3/vits-aishell3.onnx', + model: './vits-icefall-zh-aishell3/model.onnx', lexicon: './vits-icefall-zh-aishell3/lexicon.txt', tokens: './vits-icefall-zh-aishell3/tokens.txt', dataDir: '', diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index bedd1ed2a..5b2e5941c 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -111,6 +111,16 @@ list(APPEND sources speaker-embedding-manager.cc ) +# audio tagging +list(APPEND sources + audio-tagging-impl.cc + audio-tagging-label-file.cc + audio-tagging-model-config.cc + audio-tagging.cc + offline-zipformer-audio-tagging-model-config.cc + offline-zipformer-audio-tagging-model.cc +) + if(SHERPA_ONNX_ENABLE_TTS) list(APPEND sources lexicon.cc @@ -193,6 +203,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc) add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc) add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc) + add_executable(sherpa-onnx-offline-audio-tagging sherpa-onnx-offline-audio-tagging.cc) if(SHERPA_ONNX_ENABLE_TTS) add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc) @@ -204,6 +215,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) sherpa-onnx-offline sherpa-onnx-offline-parallel sherpa-onnx-offline-language-identification + sherpa-onnx-offline-audio-tagging ) if(SHERPA_ONNX_ENABLE_TTS) list(APPEND main_exes diff --git a/sherpa-onnx/csrc/audio-tagging-impl.cc b/sherpa-onnx/csrc/audio-tagging-impl.cc new file mode 100644 index 000000000..33e8dbb78 --- /dev/null +++ b/sherpa-onnx/csrc/audio-tagging-impl.cc @@ -0,0 +1,23 @@ +// sherpa-onnx/csrc/audio-tagging-impl.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include 
"sherpa-onnx/csrc/audio-tagging-impl.h" + +#include "sherpa-onnx/csrc/audio-tagging-zipformer-impl.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +std::unique_ptr AudioTaggingImpl::Create( + const AudioTaggingConfig &config) { + if (!config.model.zipformer.model.empty()) { + return std::make_unique(config); + } + + SHERPA_ONNX_LOG( + "Please specify an audio tagging model! Return a null pointer"); + return nullptr; +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/audio-tagging-impl.h b/sherpa-onnx/csrc/audio-tagging-impl.h new file mode 100644 index 000000000..e5e192457 --- /dev/null +++ b/sherpa-onnx/csrc/audio-tagging-impl.h @@ -0,0 +1,29 @@ +// sherpa-onnx/csrc/audio-tagging-impl.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_IMPL_H_ +#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_IMPL_H_ + +#include +#include + +#include "sherpa-onnx/csrc/audio-tagging.h" + +namespace sherpa_onnx { + +class AudioTaggingImpl { + public: + virtual ~AudioTaggingImpl() = default; + + static std::unique_ptr Create( + const AudioTaggingConfig &config); + + virtual std::unique_ptr CreateStream() const = 0; + + virtual std::vector Compute(OfflineStream *s, + int32_t top_k = -1) const = 0; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_IMPL_H_ diff --git a/sherpa-onnx/csrc/audio-tagging-label-file.cc b/sherpa-onnx/csrc/audio-tagging-label-file.cc new file mode 100644 index 000000000..24846a174 --- /dev/null +++ b/sherpa-onnx/csrc/audio-tagging-label-file.cc @@ -0,0 +1,70 @@ +// sherpa-onnx/csrc/audio-tagging-label-file.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/audio-tagging-label-file.h" + +#include +#include +#include + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/text-utils.h" + +namespace sherpa_onnx { + +AudioTaggingLabels::AudioTaggingLabels(const std::string &filename) { + std::ifstream is(filename); + Init(is); +} + +// 
Format of a label file +/* +index,mid,display_name +0,/m/09x0r,"Speech" +1,/m/05zppz,"Male speech, man speaking" +*/ +void AudioTaggingLabels::Init(std::istream &is) { + std::string line; + std::getline(is, line); // skip the header + + std::string index; + std::string tmp; + std::string name; + + while (std::getline(is, line)) { + index.clear(); + name.clear(); + std::istringstream input2(line); + + std::getline(input2, index, ','); + std::getline(input2, tmp, ','); + std::getline(input2, name); + + std::size_t pos{}; + int32_t i = std::stoi(index, &pos); + if (index.size() == 0 || pos != index.size()) { + SHERPA_ONNX_LOGE("Invalid line: %s", line.c_str()); + exit(-1); + } + + if (i != names_.size()) { + SHERPA_ONNX_LOGE( + "Index should be sorted and contiguous. Expected index: %d, given: " + "%d.", + static_cast(names_.size()), i); + } + if (name.empty() || name.front() != '"' || name.back() != '"') { + SHERPA_ONNX_LOGE("Invalid line: %s", line.c_str()); + exit(-1); + } + + names_.emplace_back(name.begin() + 1, name.end() - 1); + } +} + +const std::string &AudioTaggingLabels::GetEventName(int32_t index) const { + return names_.at(index); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/audio-tagging-label-file.h b/sherpa-onnx/csrc/audio-tagging-label-file.h new file mode 100644 index 000000000..9e71557f5 --- /dev/null +++ b/sherpa-onnx/csrc/audio-tagging-label-file.h @@ -0,0 +1,31 @@ +// sherpa-onnx/csrc/audio-tagging-label-file.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_LABEL_FILE_H_ +#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_LABEL_FILE_H_ + +#include +#include +#include + +namespace sherpa_onnx { + +class AudioTaggingLabels { + public: + explicit AudioTaggingLabels(const std::string &filename); + + // Return the event name for the given index. 
+ // The returned reference is valid as long as this object is alive + const std::string &GetEventName(int32_t index) const; + int32_t NumEventClasses() const { return names_.size(); } + + private: + void Init(std::istream &is); + + private: + std::vector names_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_LABEL_FILE_H_ diff --git a/sherpa-onnx/csrc/audio-tagging-model-config.cc b/sherpa-onnx/csrc/audio-tagging-model-config.cc new file mode 100644 index 000000000..f1f526f80 --- /dev/null +++ b/sherpa-onnx/csrc/audio-tagging-model-config.cc @@ -0,0 +1,42 @@ +// sherpa-onnx/csrc/audio-tagging-model-config.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/audio-tagging-model-config.h" + +namespace sherpa_onnx { + +void AudioTaggingModelConfig::Register(ParseOptions *po) { + zipformer.Register(po); + + po->Register("num-threads", &num_threads, + "Number of threads to run the neural network"); + + po->Register("debug", &debug, + "true to print model information while loading it."); + + po->Register("provider", &provider, + "Specify a provider to use: cpu, cuda, coreml"); +} + +bool AudioTaggingModelConfig::Validate() const { + if (!zipformer.model.empty() && !zipformer.Validate()) { + return false; + } + + return true; +} + +std::string AudioTaggingModelConfig::ToString() const { + std::ostringstream os; + + os << "AudioTaggingModelConfig("; + os << "zipformer=" << zipformer.ToString() << ", "; + os << "num_threads=" << num_threads << ", "; + os << "debug=" << (debug ? 
"True" : "False") << ", "; + os << "provider=\"" << provider << "\")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/audio-tagging-model-config.h b/sherpa-onnx/csrc/audio-tagging-model-config.h new file mode 100644 index 000000000..862e9bf9e --- /dev/null +++ b/sherpa-onnx/csrc/audio-tagging-model-config.h @@ -0,0 +1,39 @@ +// sherpa-onnx/csrc/audio-tagging-model-config.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_MODEL_CONFIG_H_ + +#include + +#include "sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.h" +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct AudioTaggingModelConfig { + struct OfflineZipformerAudioTaggingModelConfig zipformer; + + int32_t num_threads = 1; + bool debug = false; + std::string provider = "cpu"; + + AudioTaggingModelConfig() = default; + + AudioTaggingModelConfig( + const OfflineZipformerAudioTaggingModelConfig &zipformer, + int32_t num_threads, bool debug, const std::string &provider) + : zipformer(zipformer), + num_threads(num_threads), + debug(debug), + provider(provider) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/csrc/audio-tagging-zipformer-impl.h b/sherpa-onnx/csrc/audio-tagging-zipformer-impl.h new file mode 100644 index 000000000..639f644c8 --- /dev/null +++ b/sherpa-onnx/csrc/audio-tagging-zipformer-impl.h @@ -0,0 +1,95 @@ +// sherpa-onnx/csrc/audio-tagging-zipformer-impl.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_ +#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_ + +#include +#include +#include + +#include "sherpa-onnx/csrc/audio-tagging-impl.h" +#include "sherpa-onnx/csrc/audio-tagging-label-file.h" +#include 
"sherpa-onnx/csrc/audio-tagging.h" +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/math.h" +#include "sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h" + +namespace sherpa_onnx { + +class AudioTaggingZipformerImpl : public AudioTaggingImpl { + public: + explicit AudioTaggingZipformerImpl(const AudioTaggingConfig &config) + : config_(config), model_(config.model), labels_(config.labels) { + if (model_.NumEventClasses() != labels_.NumEventClasses()) { + SHERPA_ONNX_LOGE("number of classes: %d (model) != %d (label file)", + model_.NumEventClasses(), labels_.NumEventClasses()); + exit(-1); + } + } + + std::unique_ptr CreateStream() const override { + return std::make_unique(); + } + + std::vector Compute(OfflineStream *s, + int32_t top_k = -1) const override { + if (top_k < 0) { + top_k = config_.top_k; + } + + int32_t num_event_classes = model_.NumEventClasses(); + + if (top_k > num_event_classes) { + top_k = num_event_classes; + } + + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + // WARNING(fangjun): It is fixed to 80 for all models from icefall + int32_t feat_dim = 80; + std::vector f = s->GetFrames(); + + int32_t num_frames = f.size() / feat_dim; + + std::array shape = {1, num_frames, feat_dim}; + + Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(), + shape.data(), shape.size()); + + int64_t x_length_scalar = num_frames; + std::array x_length_shape = {1}; + Ort::Value x_length = + Ort::Value::CreateTensor(memory_info, &x_length_scalar, 1, + x_length_shape.data(), x_length_shape.size()); + + Ort::Value probs = model_.Forward(std::move(x), std::move(x_length)); + + const float *p = probs.GetTensorData(); + + std::vector top_k_indexes = TopkIndex(p, num_event_classes, top_k); + + std::vector ans(top_k); + + int32_t i = 0; + + for (int32_t index : top_k_indexes) { + ans[i].name = labels_.GetEventName(index); + ans[i].index = index; + ans[i].prob = p[index]; + i += 1; + } + + 
return ans; + } + + private: + AudioTaggingConfig config_; + OfflineZipformerAudioTaggingModel model_; + AudioTaggingLabels labels_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_ diff --git a/sherpa-onnx/csrc/audio-tagging.cc b/sherpa-onnx/csrc/audio-tagging.cc new file mode 100644 index 000000000..34d558dd9 --- /dev/null +++ b/sherpa-onnx/csrc/audio-tagging.cc @@ -0,0 +1,75 @@ +// sherpa-onnx/csrc/audio-tagging.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/audio-tagging.h" + +#include "sherpa-onnx/csrc/audio-tagging-impl.h" +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +std::string AudioEvent::ToString() const { + std::ostringstream os; + os << "AudioEvent("; + os << "name=\"" << name << "\", "; + os << "index=" << index << ", "; + os << "prob=" << prob << ")"; + return os.str(); +} + +void AudioTaggingConfig::Register(ParseOptions *po) { + model.Register(po); + po->Register("labels", &labels, "Event label file"); + po->Register("top-k", &top_k, "Top k events to return in the result"); +} + +bool AudioTaggingConfig::Validate() const { + if (!model.Validate()) { + return false; + } + + if (top_k < 1) { + SHERPA_ONNX_LOGE("--top-k should be >= 1. 
Given: %d", top_k); + return false; + } + + if (labels.empty()) { + SHERPA_ONNX_LOGE("Please provide --labels"); + return false; + } + + if (!FileExists(labels)) { + SHERPA_ONNX_LOGE("--labels %s does not exist", labels.c_str()); + return false; + } + + return true; +} +std::string AudioTaggingConfig::ToString() const { + std::ostringstream os; + + os << "AudioTaggingConfig("; + os << "model=" << model.ToString() << ", "; + os << "labels=\"" << labels << "\", "; + os << "top_k=" << top_k << ")"; + + return os.str(); +} + +AudioTagging::AudioTagging(const AudioTaggingConfig &config) + : impl_(AudioTaggingImpl::Create(config)) {} + +AudioTagging::~AudioTagging() = default; + +std::unique_ptr AudioTagging::CreateStream() const { + return impl_->CreateStream(); +} + +std::vector AudioTagging::Compute(OfflineStream *s, + int32_t top_k /*= -1*/) const { + return impl_->Compute(s, top_k); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/audio-tagging.h b/sherpa-onnx/csrc/audio-tagging.h new file mode 100644 index 000000000..50cfea02c --- /dev/null +++ b/sherpa-onnx/csrc/audio-tagging.h @@ -0,0 +1,65 @@ +// sherpa-onnx/csrc/audio-tagging.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_H_ +#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_H_ + +#include +#include +#include + +#include "sherpa-onnx/csrc/audio-tagging-model-config.h" +#include "sherpa-onnx/csrc/offline-stream.h" +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct AudioTaggingConfig { + AudioTaggingModelConfig model; + std::string labels; + + int32_t top_k = 5; + + AudioTaggingConfig() = default; + + AudioTaggingConfig(const AudioTaggingModelConfig &model, + const std::string &labels, int32_t top_k) + : model(model), labels(labels), top_k(top_k) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +struct AudioEvent { + std::string name; // name of the event + int32_t index; // index of 
the event in the label file + float prob; // probability of the event + + std::string ToString() const; +}; + +class AudioTaggingImpl; + +class AudioTagging { + public: + explicit AudioTagging(const AudioTaggingConfig &config); + + ~AudioTagging(); + + std::unique_ptr CreateStream() const; + + // If top_k is -1, then config.top_k is used. + // Otherwise, config.top_k is ignored + // + // Return top_k AudioEvent. ans[0].prob is the largest of all returned events. + std::vector Compute(OfflineStream *s, int32_t top_k = -1) const; + + private: + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_H_ diff --git a/sherpa-onnx/csrc/math.h b/sherpa-onnx/csrc/math.h index ba01835fe..121a05aeb 100644 --- a/sherpa-onnx/csrc/math.h +++ b/sherpa-onnx/csrc/math.h @@ -97,8 +97,8 @@ void LogSoftmax(T *in, int32_t w, int32_t h) { } template -void SubtractBlank(T *in, int32_t w, int32_t h, - int32_t blank_idx, float blank_penalty) { +void SubtractBlank(T *in, int32_t w, int32_t h, int32_t blank_idx, + float blank_penalty) { for (int32_t i = 0; i != h; ++i) { in[blank_idx] -= blank_penalty; in += w; @@ -116,8 +116,7 @@ std::vector TopkIndex(const T *vec, int32_t size, int32_t topk) { }); int32_t k_num = std::min(size, topk); - std::vector index(vec_index.begin(), vec_index.begin() + k_num); - return index; + return {vec_index.begin(), vec_index.begin() + k_num}; } } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-stream.cc b/sherpa-onnx/csrc/offline-stream.cc index 08e601363..0eea103c9 100644 --- a/sherpa-onnx/csrc/offline-stream.cc +++ b/sherpa-onnx/csrc/offline-stream.cc @@ -234,7 +234,7 @@ OfflineStream::OfflineStream( : impl_(std::make_unique(config, context_graph)) {} OfflineStream::OfflineStream(WhisperTag tag, - ContextGraphPtr context_graph /*= nullptr*/) + ContextGraphPtr context_graph /*= {}*/) : impl_(std::make_unique(tag, context_graph)) {} OfflineStream::~OfflineStream() = default; diff --git 
a/sherpa-onnx/csrc/offline-stream.h b/sherpa-onnx/csrc/offline-stream.h index 26b890b60..08ddbd316 100644 --- a/sherpa-onnx/csrc/offline-stream.h +++ b/sherpa-onnx/csrc/offline-stream.h @@ -71,10 +71,9 @@ struct WhisperTag {}; class OfflineStream { public: explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {}, - ContextGraphPtr context_graph = nullptr); + ContextGraphPtr context_graph = {}); - explicit OfflineStream(WhisperTag tag, - ContextGraphPtr context_graph = nullptr); + explicit OfflineStream(WhisperTag tag, ContextGraphPtr context_graph = {}); ~OfflineStream(); /** diff --git a/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.cc b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.cc new file mode 100644 index 000000000..3034ff77f --- /dev/null +++ b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.cc @@ -0,0 +1,40 @@ +// sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.h" + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +void OfflineZipformerAudioTaggingModelConfig::Register(ParseOptions *po) { + po->Register("zipformer-model", &model, + "Path to zipformer model for audio tagging"); +} + +bool OfflineZipformerAudioTaggingModelConfig::Validate() const { + if (model.empty()) { + SHERPA_ONNX_LOGE("Please provide --zipformer-model"); + return false; + } + + if (!FileExists(model)) { + SHERPA_ONNX_LOGE("--zipformer-model: %s does not exist", model.c_str()); + return false; + } + + return true; +} + +std::string OfflineZipformerAudioTaggingModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineZipformerAudioTaggingModelConfig("; + os << "model=\"" << model << "\")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git 
a/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.h b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.h new file mode 100644 index 000000000..4f60e832e --- /dev/null +++ b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.h @@ -0,0 +1,29 @@ +// sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_CONFIG_H_ + +#include + +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct OfflineZipformerAudioTaggingModelConfig { + std::string model; + + OfflineZipformerAudioTaggingModelConfig() = default; + + explicit OfflineZipformerAudioTaggingModelConfig(const std::string &model) + : model(model) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc new file mode 100644 index 000000000..8a2e80dc2 --- /dev/null +++ b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc @@ -0,0 +1,118 @@ +// sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h" + +#include +#include + +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/session.h" +#include "sherpa-onnx/csrc/text-utils.h" + +namespace sherpa_onnx { + +class OfflineZipformerAudioTaggingModel::Impl { + public: + explicit Impl(const AudioTaggingModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(config_.zipformer.model); + Init(buf.data(), 
buf.size()); + } + +#if __ANDROID_API__ >= 9 + Impl(AAssetManager *mgr, const AudioTaggingModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(mgr, config_.zipformer.model); + Init(buf.data(), buf.size()); + } +#endif + + Ort::Value Forward(Ort::Value features, Ort::Value features_length) { + std::array inputs = {std::move(features), + std::move(features_length)}; + + auto ans = + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), + output_names_ptr_.data(), output_names_ptr_.size()); + return std::move(ans[0]); + } + + int32_t NumEventClasses() const { return num_event_classes_; } + + OrtAllocator *Allocator() const { return allocator_; } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::make_unique(env_, model_data, model_data_length, + sess_opts_); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); + } + + // get num_event_classes from the output[0].shape, + // which is (N, num_event_classes) + num_event_classes_ = + sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape()[1]; + } + + private: + AudioTaggingModelConfig config_; + Ort::Env env_; + Ort::SessionOptions sess_opts_; + Ort::AllocatorWithDefaultOptions allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + int32_t num_event_classes_ = 0; +}; + +OfflineZipformerAudioTaggingModel::OfflineZipformerAudioTaggingModel( + const AudioTaggingModelConfig &config) + : impl_(std::make_unique(config)) {} + +#if __ANDROID_API__ >= 9 
+OfflineZipformerAudioTaggingModel::OfflineZipformerAudioTaggingModel( + AAssetManager *mgr, const AudioTaggingModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} +#endif + +OfflineZipformerAudioTaggingModel::~OfflineZipformerAudioTaggingModel() = + default; + +Ort::Value OfflineZipformerAudioTaggingModel::Forward( + Ort::Value features, Ort::Value features_length) const { + return impl_->Forward(std::move(features), std::move(features_length)); +} + +int32_t OfflineZipformerAudioTaggingModel::NumEventClasses() const { + return impl_->NumEventClasses(); +} + +OrtAllocator *OfflineZipformerAudioTaggingModel::Allocator() const { + return impl_->Allocator(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h new file mode 100644 index 000000000..282823499 --- /dev/null +++ b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h @@ -0,0 +1,64 @@ +// sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_H_ +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/audio-tagging-model-config.h" + +namespace sherpa_onnx { + +/** This class implements the zipformer CTC model of the librispeech recipe + * from icefall. 
+ * + * See + * https://github.com/k2-fsa/icefall/blob/master/egs/audioset/AT/zipformer/export-onnx.py + */ +class OfflineZipformerAudioTaggingModel { + public: + explicit OfflineZipformerAudioTaggingModel( + const AudioTaggingModelConfig &config); + +#if __ANDROID_API__ >= 9 + OfflineZipformerAudioTaggingModel(AAssetManager *mgr, + const AudioTaggingModelConfig &config); +#endif + + ~OfflineZipformerAudioTaggingModel(); + + /** Run the forward method of the model. + * + * @param features A tensor of shape (N, T, C). + * @param features_length A 1-D tensor of shape (N,) containing number of + * valid frames in `features` before padding. + * Its dtype is int64_t. + * + * @return Return a tensor + * - probs: A 2-D tensor of shape (N, num_event_classes). + */ + Ort::Value Forward(Ort::Value features, Ort::Value features_length) const; + + /** Return the number of event classes of the model + */ + int32_t NumEventClasses() const; + + /** Return an allocator for allocating memory + */ + OrtAllocator *Allocator() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_H_ diff --git a/sherpa-onnx/csrc/offline-zipformer-ctc-model.cc b/sherpa-onnx/csrc/offline-zipformer-ctc-model.cc index a82ef6255..8db9439e4 100644 --- a/sherpa-onnx/csrc/offline-zipformer-ctc-model.cc +++ b/sherpa-onnx/csrc/offline-zipformer-ctc-model.cc @@ -4,6 +4,8 @@ #include "sherpa-onnx/csrc/offline-zipformer-ctc-model.h" +#include + #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" diff --git a/sherpa-onnx/csrc/offline-zipformer-ctc-model.h b/sherpa-onnx/csrc/offline-zipformer-ctc-model.h index e3b9a05ce..c4e835636 100644 --- a/sherpa-onnx/csrc/offline-zipformer-ctc-model.h +++ b/sherpa-onnx/csrc/offline-zipformer-ctc-model.h @@ -4,7 +4,6 @@ #ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_H_ #define 
SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_H_ #include -#include #include #include diff --git a/sherpa-onnx/csrc/session.cc b/sherpa-onnx/csrc/session.cc index aacd1e158..d555ed7a7 100644 --- a/sherpa-onnx/csrc/session.cc +++ b/sherpa-onnx/csrc/session.cc @@ -140,9 +140,11 @@ Ort::SessionOptions GetSessionOptions(const VadModelConfig &config) { return GetSessionOptionsImpl(config.num_threads, config.provider); } +#if SHERPA_ONNX_ENABLE_TTS Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config) { return GetSessionOptionsImpl(config.num_threads, config.provider); } +#endif Ort::SessionOptions GetSessionOptions( const SpeakerEmbeddingExtractorConfig &config) { @@ -154,4 +156,8 @@ Ort::SessionOptions GetSessionOptions( return GetSessionOptionsImpl(config.num_threads, config.provider); } +Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config) { + return GetSessionOptionsImpl(config.num_threads, config.provider); +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/session.h b/sherpa-onnx/csrc/session.h index 9bb3e4371..94f263fd9 100644 --- a/sherpa-onnx/csrc/session.h +++ b/sherpa-onnx/csrc/session.h @@ -6,15 +6,19 @@ #define SHERPA_ONNX_CSRC_SESSION_H_ #include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/audio-tagging-model-config.h" #include "sherpa-onnx/csrc/offline-lm-config.h" #include "sherpa-onnx/csrc/offline-model-config.h" -#include "sherpa-onnx/csrc/offline-tts-model-config.h" #include "sherpa-onnx/csrc/online-lm-config.h" #include "sherpa-onnx/csrc/online-model-config.h" #include "sherpa-onnx/csrc/speaker-embedding-extractor.h" #include "sherpa-onnx/csrc/spoken-language-identification.h" #include "sherpa-onnx/csrc/vad-model-config.h" +#if SHERPA_ONNX_ENABLE_TTS +#include "sherpa-onnx/csrc/offline-tts-model-config.h" +#endif + namespace sherpa_onnx { Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config); @@ -27,7 +31,9 @@ Ort::SessionOptions GetSessionOptions(const 
OnlineLMConfig &config); Ort::SessionOptions GetSessionOptions(const VadModelConfig &config); +#if SHERPA_ONNX_ENABLE_TTS Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config); +#endif Ort::SessionOptions GetSessionOptions( const SpeakerEmbeddingExtractorConfig &config); @@ -35,6 +41,8 @@ Ort::SessionOptions GetSessionOptions( Ort::SessionOptions GetSessionOptions( const SpokenLanguageIdentificationConfig &config); +Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config); + } // namespace sherpa_onnx #endif  // SHERPA_ONNX_CSRC_SESSION_H_ diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-audio-tagging.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-audio-tagging.cc new file mode 100644 index 000000000..862818f5c --- /dev/null +++ b/sherpa-onnx/csrc/sherpa-onnx-offline-audio-tagging.cc @@ -0,0 +1,97 @@ +// sherpa-onnx/csrc/sherpa-onnx-offline-audio-tagging.cc +// +// Copyright (c) 2024 Xiaomi Corporation +#include <chrono> + +#include "sherpa-onnx/csrc/audio-tagging.h" +#include "sherpa-onnx/csrc/parse-options.h" +#include "sherpa-onnx/csrc/wave-reader.h" + +int32_t main(int32_t argc, char *argv[]) { + const char *kUsageMessage = R"usage( +Audio tagging from a file. + +Usage: + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 +tar xvf sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 +rm sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 + +./bin/sherpa-onnx-offline-audio-tagging \ + --zipformer-model=./sherpa-onnx-zipformer-audio-tagging-2024-04-09/model.onnx \ + --labels=./sherpa-onnx-zipformer-audio-tagging-2024-04-09/class_labels_indices.csv \ + sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/0.wav + +Input wave files should be of single channel, 16-bit PCM encoded wave file; its +sampling rate can be arbitrary and does not need to be 16kHz.
+ +Please see +https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models +for more models. +)usage"; + + sherpa_onnx::ParseOptions po(kUsageMessage); + sherpa_onnx::AudioTaggingConfig config; + config.Register(&po); + po.Read(argc, argv); + + if (po.NumArgs() != 1) { + fprintf(stderr, "\nError: Please provide 1 wave file\n\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + fprintf(stderr, "%s\n", config.ToString().c_str()); + + if (!config.Validate()) { + fprintf(stderr, "Errors in config!\n"); + return -1; + } + + sherpa_onnx::AudioTagging tagger(config); + std::string wav_filename = po.GetArg(1); + + int32_t sampling_rate = -1; + + bool is_ok = false; + const std::vector<float> samples = + sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); + + if (!is_ok) { + fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); + return -1; + } + + const float duration = samples.size() / static_cast<float>(sampling_rate); + + fprintf(stderr, "Start to compute\n"); + const auto begin = std::chrono::steady_clock::now(); + + auto stream = tagger.CreateStream(); + + stream->AcceptWaveform(sampling_rate, samples.data(), samples.size()); + + auto results = tagger.Compute(stream.get()); + const auto end = std::chrono::steady_clock::now(); + fprintf(stderr, "Done\n"); + + int32_t i = 0; + + for (const auto &event : results) { + fprintf(stderr, "%d: %s\n", i, event.ToString().c_str()); + i += 1; + } + + float elapsed_seconds = + std::chrono::duration_cast<std::chrono::milliseconds>(end - begin) + .count() / + 1000.; + float rtf = elapsed_seconds / duration; + fprintf(stderr, "Num threads: %d\n", config.model.num_threads); + fprintf(stderr, "Wave duration: %.3f\n", duration); + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); + fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n", + elapsed_seconds, duration, rtf); + + return 0; +}