From 5808121c26c3b0b73f524b154f3a115096213e94 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 9 Apr 2024 17:01:37 +0800 Subject: [PATCH 01/10] Begin to add audio tagging --- nodejs-examples/test-offline-tts-zh.js | 2 +- sherpa-onnx/csrc/CMakeLists.txt | 6 +++ sherpa-onnx/csrc/audio-tagging-impl.cc | 14 ++++++ sherpa-onnx/csrc/audio-tagging-impl.h | 32 +++++++++++++ .../csrc/audio-tagging-model-config.cc | 0 sherpa-onnx/csrc/audio-tagging-model-config.h | 0 .../csrc/audio-tagging-model-zipformer2.cc | 0 .../csrc/audio-tagging-model-zipformer2.h | 0 .../csrc/audio-tagging-zipformer-impl.h | 0 sherpa-onnx/csrc/audio-tagging.cc | 25 ++++++++++ sherpa-onnx/csrc/audio-tagging.h | 46 +++++++++++++++++++ 11 files changed, 124 insertions(+), 1 deletion(-) create mode 100644 sherpa-onnx/csrc/audio-tagging-impl.cc create mode 100644 sherpa-onnx/csrc/audio-tagging-impl.h create mode 100644 sherpa-onnx/csrc/audio-tagging-model-config.cc create mode 100644 sherpa-onnx/csrc/audio-tagging-model-config.h create mode 100644 sherpa-onnx/csrc/audio-tagging-model-zipformer2.cc create mode 100644 sherpa-onnx/csrc/audio-tagging-model-zipformer2.h create mode 100644 sherpa-onnx/csrc/audio-tagging-zipformer-impl.h create mode 100644 sherpa-onnx/csrc/audio-tagging.cc create mode 100644 sherpa-onnx/csrc/audio-tagging.h diff --git a/nodejs-examples/test-offline-tts-zh.js b/nodejs-examples/test-offline-tts-zh.js index a53748c77..d777d490e 100644 --- a/nodejs-examples/test-offline-tts-zh.js +++ b/nodejs-examples/test-offline-tts-zh.js @@ -4,7 +4,7 @@ const sherpa_onnx = require('sherpa-onnx'); function createOfflineTts() { let offlineTtsVitsModelConfig = { - model: './vits-icefall-zh-aishell3/vits-aishell3.onnx', + model: './vits-icefall-zh-aishell3/model.onnx', lexicon: './vits-icefall-zh-aishell3/lexicon.txt', tokens: './vits-icefall-zh-aishell3/tokens.txt', dataDir: '', diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index bedd1ed2a..74dbe66ee 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -111,6 +111,12 @@ list(APPEND sources speaker-embedding-manager.cc ) +# audio tagging +list(APPEND sources + audio-tagging-impl.cc + audio-tagging.cc +) + if(SHERPA_ONNX_ENABLE_TTS) list(APPEND sources lexicon.cc diff --git a/sherpa-onnx/csrc/audio-tagging-impl.cc b/sherpa-onnx/csrc/audio-tagging-impl.cc new file mode 100644 index 000000000..0a8526407 --- /dev/null +++ b/sherpa-onnx/csrc/audio-tagging-impl.cc @@ -0,0 +1,14 @@ +// sherpa-onnx/csrc/audio-tagging-impl.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/audio-tagging-impl.h" + +namespace sherpa_onnx { + +std::unique_ptr AudioTaggingImpl::Create( + const AudioTaggingConfig &config) { + return {}; +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/audio-tagging-impl.h b/sherpa-onnx/csrc/audio-tagging-impl.h new file mode 100644 index 000000000..0d770cc26 --- /dev/null +++ b/sherpa-onnx/csrc/audio-tagging-impl.h @@ -0,0 +1,32 @@ +// sherpa-onnx/csrc/audio-tagging-impl.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_IMPL_H_ +#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_IMPL_H_ + +#include "sherpa-onnx/csrc/audio-tagging.h" + +namespace sherpa_onnx { + +class AudioTaggingImpl { + public: + virtual ~AudioTaggingImpl() = default; + + static std::unique_ptr Create( + const AudioTaggingConfig &config); + + virtual std::unique_ptr CreateStream() const { + return nullptr; + } + + virtual std::vector Compute(OfflineStream *s, + int32_t top_k = -1) const { + return {}; + } + + private: +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_IMPL_H_ diff --git a/sherpa-onnx/csrc/audio-tagging-model-config.cc b/sherpa-onnx/csrc/audio-tagging-model-config.cc new file mode 100644 index 000000000..e69de29bb diff --git a/sherpa-onnx/csrc/audio-tagging-model-config.h b/sherpa-onnx/csrc/audio-tagging-model-config.h new file mode 100644 index 000000000..e69de29bb diff --git a/sherpa-onnx/csrc/audio-tagging-model-zipformer2.cc b/sherpa-onnx/csrc/audio-tagging-model-zipformer2.cc new file mode 100644 index 000000000..e69de29bb diff --git a/sherpa-onnx/csrc/audio-tagging-model-zipformer2.h b/sherpa-onnx/csrc/audio-tagging-model-zipformer2.h new file mode 100644 index 000000000..e69de29bb diff --git a/sherpa-onnx/csrc/audio-tagging-zipformer-impl.h b/sherpa-onnx/csrc/audio-tagging-zipformer-impl.h new file mode 100644 index 000000000..e69de29bb diff --git a/sherpa-onnx/csrc/audio-tagging.cc b/sherpa-onnx/csrc/audio-tagging.cc new file mode 100644 index 000000000..2c546cace --- /dev/null +++ b/sherpa-onnx/csrc/audio-tagging.cc @@ -0,0 +1,25 @@ +// sherpa-onnx/csrc/audio-tagging.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/audio-tagging.h" + +#include "sherpa-onnx/csrc/audio-tagging-impl.h" + +namespace sherpa_onnx { + +AudioTagging::AudioTagging(const AudioTaggingConfig &config) + : impl_(AudioTaggingImpl::Create(config)) {} + +AudioTagging::~AudioTagging() = default; + +std::unique_ptr AudioTagging::CreateStream() const { + return impl_->CreateStream(); +} + +std::vector AudioTagging::Compute(OfflineStream *s, + int32_t top_k /*= -1*/) const { + return impl_->Compute(s, top_k); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/audio-tagging.h b/sherpa-onnx/csrc/audio-tagging.h new file mode 100644 index 000000000..8856a8946 --- /dev/null +++ b/sherpa-onnx/csrc/audio-tagging.h @@ -0,0 +1,46 @@ +// sherpa-onnx/csrc/audio-tagging.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_H_ +#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_H_ + +#include +#include +#include + +#include "sherpa-onnx/csrc/offline-stream.h" + +namespace sherpa_onnx { + +struct AudioTaggingConfig { + int32_t top_k = 5; +}; + +struct AudioEvent { + std::string name; // name of the event + float prob; // probability of the event +}; + +class AudioTaggingImpl; + +class AudioTagging { + public: + explicit AudioTagging(const AudioTaggingConfig &config); + + ~AudioTagging(); + + std::unique_ptr CreateStream() const; + + // If top_k is -1, then config.top_k is used. + // Otherwise, config.top_k is ignored + // + // Return top_k AudioEvent. ans[0].prob is the largest of all returned events. + std::vector Compute(OfflineStream *s, int32_t top_k = -1) const; + + private: + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_H_ From e97aaf062ccd42d6b7a8d6d02280e98d4ca94356 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 9 Apr 2024 19:35:02 +0800 Subject: [PATCH 02/10] add model config --- sherpa-onnx/csrc/CMakeLists.txt | 2 + sherpa-onnx/csrc/audio-tagging-impl.h | 3 ++ .../csrc/audio-tagging-model-config.cc | 30 ++++++++++++++ sherpa-onnx/csrc/audio-tagging-model-config.h | 31 ++++++++++++++ .../csrc/audio-tagging-model-zipformer2.cc | 0 .../csrc/audio-tagging-model-zipformer2.h | 0 .../csrc/audio-tagging-zipformer-impl.h | 6 +++ sherpa-onnx/csrc/audio-tagging.cc | 28 +++++++++++++ sherpa-onnx/csrc/audio-tagging.h | 14 +++++++ ...ne-zipformer-audio-tagging-model-config.cc | 40 +++++++++++++++++++ ...ine-zipformer-audio-tagging-model-config.h | 29 ++++++++++++++ 11 files changed, 183 insertions(+) delete mode 100644 sherpa-onnx/csrc/audio-tagging-model-zipformer2.cc delete mode 100644 sherpa-onnx/csrc/audio-tagging-model-zipformer2.h create mode 100644 sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.cc create mode 100644 sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.h diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 74dbe66ee..4e7d4ea6e 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -114,7 +114,9 @@ list(APPEND sources # audio tagging list(APPEND sources audio-tagging-impl.cc + audio-tagging-model-config.cc audio-tagging.cc + offline-zipformer-audio-tagging-model-config.cc ) if(SHERPA_ONNX_ENABLE_TTS) diff --git a/sherpa-onnx/csrc/audio-tagging-impl.h b/sherpa-onnx/csrc/audio-tagging-impl.h index 0d770cc26..4fba4c579 100644 --- a/sherpa-onnx/csrc/audio-tagging-impl.h +++ b/sherpa-onnx/csrc/audio-tagging-impl.h @@ -4,6 +4,9 @@ #ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_IMPL_H_ #define SHERPA_ONNX_CSRC_AUDIO_TAGGING_IMPL_H_ +#include +#include + #include "sherpa-onnx/csrc/audio-tagging.h" namespace sherpa_onnx { diff --git a/sherpa-onnx/csrc/audio-tagging-model-config.cc b/sherpa-onnx/csrc/audio-tagging-model-config.cc index e69de29bb..64fbe1ac5 100644 --- a/sherpa-onnx/csrc/audio-tagging-model-config.cc +++ b/sherpa-onnx/csrc/audio-tagging-model-config.cc @@ -0,0 +1,30 @@ +// sherpa-onnx/csrc/audio-tagging-model-config.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/audio-tagging-model-config.h" + +namespace sherpa_onnx { + +void AudioTaggingModelConfig::Register(ParseOptions *po) { + zipformer.Register(po); +} + +bool AudioTaggingModelConfig::Validate() const { + if (!zipformer.model.empty() && !zipformer.Validate()) { + return false; + } + + return true; +} + +std::string AudioTaggingModelConfig::ToString() const { + std::ostringstream os; + + os << "AudioTaggingModelConfig("; + os << "zipformer=" << zipformer.ToString() << ")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/audio-tagging-model-config.h b/sherpa-onnx/csrc/audio-tagging-model-config.h index e69de29bb..35bfe391b 100644 --- a/sherpa-onnx/csrc/audio-tagging-model-config.h +++ b/sherpa-onnx/csrc/audio-tagging-model-config.h @@ -0,0 +1,31 @@ +// sherpa-onnx/csrc/audio-tagging-model-config.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_MODEL_CONFIG_H_ + +#include + +#include "sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.h" +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct AudioTaggingModelConfig { + struct OfflineZipformerAudioTaggingModelConfig zipformer; + + AudioTaggingModelConfig() = default; + + explicit AudioTaggingModelConfig( + const OfflineZipformerAudioTaggingModelConfig &zipformer) + : zipformer(zipformer) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/csrc/audio-tagging-model-zipformer2.cc b/sherpa-onnx/csrc/audio-tagging-model-zipformer2.cc deleted file mode 100644 index e69de29bb..000000000 diff --git a/sherpa-onnx/csrc/audio-tagging-model-zipformer2.h b/sherpa-onnx/csrc/audio-tagging-model-zipformer2.h deleted file mode 100644 index e69de29bb..000000000 diff --git a/sherpa-onnx/csrc/audio-tagging-zipformer-impl.h b/sherpa-onnx/csrc/audio-tagging-zipformer-impl.h index e69de29bb..c3be51152 100644 --- a/sherpa-onnx/csrc/audio-tagging-zipformer-impl.h +++ b/sherpa-onnx/csrc/audio-tagging-zipformer-impl.h @@ -0,0 +1,6 @@ +// sherpa-onnx/csrc/audio-tagging-zipformer-impl.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_ +#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_ +#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_ diff --git a/sherpa-onnx/csrc/audio-tagging.cc b/sherpa-onnx/csrc/audio-tagging.cc index 2c546cace..94512170b 100644 --- a/sherpa-onnx/csrc/audio-tagging.cc +++ b/sherpa-onnx/csrc/audio-tagging.cc @@ -5,9 +5,37 @@ #include "sherpa-onnx/csrc/audio-tagging.h" #include "sherpa-onnx/csrc/audio-tagging-impl.h" +#include "sherpa-onnx/csrc/macros.h" namespace sherpa_onnx { +void AudioTaggingConfig::Register(ParseOptions *po) { + model.Register(po); + po->Register("top-k", &top_k, "Top k events to return in the result"); +} + +bool AudioTaggingConfig::Validate() const { + if (!model.Validate()) { + return false; + } + + if (top_k < 1) { + SHERPA_ONNX_LOGE("--top-k should be >= 1. Given: %d", top_k); + return false; + } + + return true; +} +std::string AudioTaggingConfig::ToString() const { + std::ostringstream os; + + os << "AudioTaggingConfig("; + os << "model=" << model.ToString() << ", "; + os << "top_k=" << top_k << ")"; + + return os.str(); +} + AudioTagging::AudioTagging(const AudioTaggingConfig &config) : impl_(AudioTaggingImpl::Create(config)) {} diff --git a/sherpa-onnx/csrc/audio-tagging.h b/sherpa-onnx/csrc/audio-tagging.h index 8856a8946..d3a78d9e4 100644 --- a/sherpa-onnx/csrc/audio-tagging.h +++ b/sherpa-onnx/csrc/audio-tagging.h @@ -8,12 +8,26 @@ #include #include +#include "sherpa-onnx/csrc/audio-tagging-model-config.h" #include "sherpa-onnx/csrc/offline-stream.h" +#include "sherpa-onnx/csrc/parse-options.h" namespace sherpa_onnx { struct AudioTaggingConfig { + AudioTaggingModelConfig model; + int32_t top_k = 5; + + AudioTaggingConfig() = default; + + AudioTaggingConfig(const AudioTaggingModelConfig &model, int32_t top_k) + : model(model), top_k(top_k) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; }; struct AudioEvent { diff --git a/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.cc b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.cc new file mode 100644 index 000000000..3034ff77f --- /dev/null +++ b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.cc @@ -0,0 +1,40 @@ +// sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.h" + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +void OfflineZipformerAudioTaggingModelConfig::Register(ParseOptions *po) { + po->Register("zipformer-model", &model, + "Path to zipformer model for audio tagging"); +} + +bool OfflineZipformerAudioTaggingModelConfig::Validate() const { + if (model.empty()) { + SHERPA_ONNX_LOGE("Please provide --zipformer-model"); + return false; + } + + if (!FileExists(model)) { + SHERPA_ONNX_LOGE("--zipformer-model: %s does not exist", model.c_str()); + return false; + } + + return true; +} + +std::string OfflineZipformerAudioTaggingModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineZipformerAudioTaggingModelConfig("; + os << "model=\"" << model << "\")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.h b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.h new file mode 100644 index 000000000..4f60e832e --- /dev/null +++ b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.h @@ -0,0 +1,29 @@ +// sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_CONFIG_H_ + +#include + +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct OfflineZipformerAudioTaggingModelConfig { + std::string model; + + OfflineZipformerAudioTaggingModelConfig() = default; + + explicit OfflineZipformerAudioTaggingModelConfig(const std::string &model) + : model(model) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_CONFIG_H_ From 6b6eb7a9675c411baaaebb6f496b35d5a8a91ae4 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 9 Apr 2024 19:39:12 +0800 Subject: [PATCH 03/10] add audio zipformer impl --- sherpa-onnx/csrc/audio-tagging-impl.cc | 4 ++- sherpa-onnx/csrc/audio-tagging-impl.h | 10 ++----- .../csrc/audio-tagging-zipformer-impl.h | 26 +++++++++++++++++++ 3 files changed, 31 insertions(+), 9 deletions(-) diff --git a/sherpa-onnx/csrc/audio-tagging-impl.cc b/sherpa-onnx/csrc/audio-tagging-impl.cc index 0a8526407..75aa559b6 100644 --- a/sherpa-onnx/csrc/audio-tagging-impl.cc +++ b/sherpa-onnx/csrc/audio-tagging-impl.cc @@ -4,11 +4,13 @@ #include "sherpa-onnx/csrc/audio-tagging-impl.h" +#include "sherpa-onnx/csrc/audio-tagging-zipformer-impl.h" + namespace sherpa_onnx { std::unique_ptr AudioTaggingImpl::Create( const AudioTaggingConfig &config) { - return {}; + return std::make_unique(config); } } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/audio-tagging-impl.h b/sherpa-onnx/csrc/audio-tagging-impl.h index 4fba4c579..e5e192457 100644 --- a/sherpa-onnx/csrc/audio-tagging-impl.h +++ b/sherpa-onnx/csrc/audio-tagging-impl.h @@ -18,16 +18,10 @@ class AudioTaggingImpl { static std::unique_ptr Create( const AudioTaggingConfig &config); - virtual std::unique_ptr CreateStream() const { - return nullptr; - } + virtual std::unique_ptr CreateStream() const = 0; virtual std::vector Compute(OfflineStream *s, - int32_t top_k = -1) const { - return {}; - } - - private: + int32_t top_k = -1) const = 0; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/audio-tagging-zipformer-impl.h b/sherpa-onnx/csrc/audio-tagging-zipformer-impl.h index c3be51152..7369c9b35 100644 --- a/sherpa-onnx/csrc/audio-tagging-zipformer-impl.h +++ b/sherpa-onnx/csrc/audio-tagging-zipformer-impl.h @@ -3,4 +3,30 @@ // Copyright (c) 2024 Xiaomi Corporation #ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_ #define SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_ + +#include + +#include "sherpa-onnx/csrc/audio-tagging-impl.h" +#include "sherpa-onnx/csrc/audio-tagging.h" + +namespace sherpa_onnx { + +class AudioTaggingZipformerImpl : public AudioTaggingImpl { + public: + explicit AudioTaggingZipformerImpl(const AudioTaggingConfig &config) + : config_(config) {} + + std::unique_ptr CreateStream() const override { return {}; } + + std::vector Compute(OfflineStream *s, + int32_t top_k = -1) const override { + return {}; + } + + private: + AudioTaggingConfig config_; +}; + +} // namespace sherpa_onnx + #endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_ From 33ae6820d7e8b05401e4fcfa776919f21a4b0956 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 9 Apr 2024 19:51:12 +0800 Subject: [PATCH 04/10] begin to add audio tagging zipformer model --- .../csrc/audio-tagging-model-config.cc | 5 +- sherpa-onnx/csrc/audio-tagging-model-config.h | 14 +++- .../offline-zipformer-audio-tagging-model.cc | 0 .../offline-zipformer-audio-tagging-model.h | 71 +++++++++++++++++++ sherpa-onnx/csrc/session.cc | 4 ++ sherpa-onnx/csrc/session.h | 3 + 6 files changed, 93 insertions(+), 4 deletions(-) create mode 100644 sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc create mode 100644 sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h diff --git a/sherpa-onnx/csrc/audio-tagging-model-config.cc b/sherpa-onnx/csrc/audio-tagging-model-config.cc index 64fbe1ac5..f95bfadd5 100644 --- a/sherpa-onnx/csrc/audio-tagging-model-config.cc +++ b/sherpa-onnx/csrc/audio-tagging-model-config.cc @@ -22,7 +22,10 @@ std::string AudioTaggingModelConfig::ToString() const { std::ostringstream os; os << "AudioTaggingModelConfig("; - os << "zipformer=" << zipformer.ToString() << ")"; + os << "zipformer=" << zipformer.ToString() << ", "; + os << "num_threads=" << num_threads << ", "; + os << "debug=" << (debug ? "True" : "False") << ", "; + os << "provider=\"" << provider << "\")"; return os.str(); } diff --git a/sherpa-onnx/csrc/audio-tagging-model-config.h b/sherpa-onnx/csrc/audio-tagging-model-config.h index 35bfe391b..862e9bf9e 100644 --- a/sherpa-onnx/csrc/audio-tagging-model-config.h +++ b/sherpa-onnx/csrc/audio-tagging-model-config.h @@ -14,11 +14,19 @@ namespace sherpa_onnx { struct AudioTaggingModelConfig { struct OfflineZipformerAudioTaggingModelConfig zipformer; + int32_t num_threads = 1; + bool debug = false; + std::string provider = "cpu"; + AudioTaggingModelConfig() = default; - explicit AudioTaggingModelConfig( - const OfflineZipformerAudioTaggingModelConfig &zipformer) - : zipformer(zipformer) {} + AudioTaggingModelConfig( + const OfflineZipformerAudioTaggingModelConfig &zipformer, + int32_t num_threads, bool debug, const std::string &provider) + : zipformer(zipformer), + num_threads(num_threads), + debug(debug), + provider(provider) {} void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc new file mode 100644 index 000000000..e69de29bb diff --git a/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h new file mode 100644 index 000000000..dab75e59a --- /dev/null +++ b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h @@ -0,0 +1,71 @@ +// sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGER_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGER_MODEL_H_ +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/offline-ctc-model.h" +#include "sherpa-onnx/csrc/offline-model-config.h" + +namespace sherpa_onnx { + +/** This class implements the zipformer CTC model of the librispeech recipe + * from icefall. + * + * See + * https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/zipformer/export-onnx-ctc.py + */ +class OfflineZipformerAudioTaggingModel { + public: + explicit OfflineZipformerAudioTaggingModel( + const AudioTaggingModelConfig &config); + +#if __ANDROID_API__ >= 9 + OfflineZipformerCtcModel(AAssetManager *mgr, + const OfflineModelConfig &config); +#endif + + ~OfflineZipformerCtcModel() override; + + /** Run the forward method of the model. + * + * @param features A tensor of shape (N, T, C). + * @param features_length A 1-D tensor of shape (N,) containing number of + * valid frames in `features` before padding. + * Its dtype is int64_t. + * + * @return Return a vector containing: + * - log_probs: A 3-D tensor of shape (N, T', vocab_size). + * - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t + */ + std::vector Forward(Ort::Value features, + Ort::Value features_length) override; + + /** Return the vocabulary size of the model + */ + int32_t VocabSize() const override; + + /** Return an allocator for allocating memory + */ + OrtAllocator *Allocator() const override; + + int32_t SubsamplingFactor() const override; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGER_MODEL_H_ diff --git a/sherpa-onnx/csrc/session.cc b/sherpa-onnx/csrc/session.cc index aacd1e158..f08d9adfe 100644 --- a/sherpa-onnx/csrc/session.cc +++ b/sherpa-onnx/csrc/session.cc @@ -154,4 +154,8 @@ Ort::SessionOptions GetSessionOptions( return GetSessionOptionsImpl(config.num_threads, config.provider); } +Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config) { + return GetSessionOptionsImpl(config.num_threads, config.provider); +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/session.h b/sherpa-onnx/csrc/session.h index 9bb3e4371..6dacf7da0 100644 --- a/sherpa-onnx/csrc/session.h +++ b/sherpa-onnx/csrc/session.h @@ -6,6 +6,7 @@ #define SHERPA_ONNX_CSRC_SESSION_H_ #include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/audio-tagging-model-config.h" #include "sherpa-onnx/csrc/offline-lm-config.h" #include "sherpa-onnx/csrc/offline-model-config.h" #include "sherpa-onnx/csrc/offline-tts-model-config.h" @@ -35,6 +36,8 @@ Ort::SessionOptions GetSessionOptions( Ort::SessionOptions GetSessionOptions( const SpokenLanguageIdentificationConfig &config); +Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config); + } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_SESSION_H_ From 6e3abb563ba7ceeec68953e9556326ac0bb1478a Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 10 Apr 2024 09:52:16 +0800 Subject: [PATCH 05/10] begin to add zipformer model for audio tagging --- sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h | 9 ++++----- sherpa-onnx/csrc/session.cc | 2 ++ sherpa-onnx/csrc/session.h | 7 ++++++- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h index dab75e59a..e8ae8753f 100644 --- a/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h +++ b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h @@ -14,8 +14,7 @@ #endif #include "onnxruntime_cxx_api.h" // NOLINT -#include "sherpa-onnx/csrc/offline-ctc-model.h" -#include "sherpa-onnx/csrc/offline-model-config.h" +#include "sherpa-onnx/csrc/audio-tagging-model-config.h" namespace sherpa_onnx { @@ -31,11 +30,11 @@ class OfflineZipformerAudioTaggingModel { const AudioTaggingModelConfig &config); #if __ANDROID_API__ >= 9 - OfflineZipformerCtcModel(AAssetManager *mgr, - const OfflineModelConfig &config); + OfflineZipformerAudioTaggingModel(AAssetManager *mgr, + const AudioTaggingModelConfig &config); #endif - ~OfflineZipformerCtcModel() override; + ~OfflineZipformerAudioTaggingModel(); /** Run the forward method of the model. * diff --git a/sherpa-onnx/csrc/session.cc b/sherpa-onnx/csrc/session.cc index f08d9adfe..d555ed7a7 100644 --- a/sherpa-onnx/csrc/session.cc +++ b/sherpa-onnx/csrc/session.cc @@ -140,9 +140,11 @@ Ort::SessionOptions GetSessionOptions(const VadModelConfig &config) { return GetSessionOptionsImpl(config.num_threads, config.provider); } +#if SHERPA_ONNX_ENABLE_TTS Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config) { return GetSessionOptionsImpl(config.num_threads, config.provider); } +#endif Ort::SessionOptions GetSessionOptions( const SpeakerEmbeddingExtractorConfig &config) { diff --git a/sherpa-onnx/csrc/session.h b/sherpa-onnx/csrc/session.h index 6dacf7da0..94f263fd9 100644 --- a/sherpa-onnx/csrc/session.h +++ b/sherpa-onnx/csrc/session.h @@ -9,13 +9,16 @@ #include "sherpa-onnx/csrc/audio-tagging-model-config.h" #include "sherpa-onnx/csrc/offline-lm-config.h" #include "sherpa-onnx/csrc/offline-model-config.h" -#include "sherpa-onnx/csrc/offline-tts-model-config.h" #include "sherpa-onnx/csrc/online-lm-config.h" #include "sherpa-onnx/csrc/online-model-config.h" #include "sherpa-onnx/csrc/speaker-embedding-extractor.h" #include "sherpa-onnx/csrc/spoken-language-identification.h" #include "sherpa-onnx/csrc/vad-model-config.h" +#if SHERPA_ONNX_ENABLE_TTS +#include "sherpa-onnx/csrc/offline-tts-model-config.h" +#endif + namespace sherpa_onnx { Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config); @@ -28,7 +31,9 @@ Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config); Ort::SessionOptions GetSessionOptions(const VadModelConfig &config); +#if SHERPA_ONNX_ENABLE_TTS Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config); +#endif Ort::SessionOptions GetSessionOptions( const SpeakerEmbeddingExtractorConfig &config); From b0f939223159f2a7f7a42ce386f1ab8f5ff7bfdc Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 10 Apr 2024 12:25:32 +0800 Subject: [PATCH 06/10] First working example --- cmake/cmake_extension.py | 1 + go-api-examples/vad-asr-paraformer/.gitignore | 2 + sherpa-onnx/csrc/CMakeLists.txt | 3 + sherpa-onnx/csrc/audio-tagging-impl.cc | 9 +- .../csrc/audio-tagging-zipformer-impl.h | 59 ++++++++- sherpa-onnx/csrc/audio-tagging.cc | 9 ++ sherpa-onnx/csrc/audio-tagging.h | 3 + sherpa-onnx/csrc/math.h | 7 +- sherpa-onnx/csrc/offline-stream.cc | 2 +- sherpa-onnx/csrc/offline-stream.h | 5 +- .../offline-zipformer-audio-tagging-model.cc | 114 ++++++++++++++++++ .../offline-zipformer-audio-tagging-model.h | 16 +-- .../csrc/offline-zipformer-ctc-model.h | 1 - .../csrc/sherpa-onnx-offline-audio-tagging.cc | 85 +++++++++++++ 14 files changed, 292 insertions(+), 24 deletions(-) create mode 100644 go-api-examples/vad-asr-paraformer/.gitignore create mode 100644 sherpa-onnx/csrc/sherpa-onnx-offline-audio-tagging.cc diff --git a/cmake/cmake_extension.py b/cmake/cmake_extension.py index 75b09a5c5..b78129b21 100644 --- a/cmake/cmake_extension.py +++ b/cmake/cmake_extension.py @@ -46,6 +46,7 @@ def enable_alsa(): def get_binaries(): binaries = [ "sherpa-onnx", + "sherpa-onnx-offline-audio-tagging", "sherpa-onnx-keyword-spotter", "sherpa-onnx-microphone", "sherpa-onnx-microphone-offline", diff --git a/go-api-examples/vad-asr-paraformer/.gitignore b/go-api-examples/vad-asr-paraformer/.gitignore new file mode 100644 index 000000000..66786c69b --- /dev/null +++ b/go-api-examples/vad-asr-paraformer/.gitignore @@ -0,0 +1,2 @@ +go.sum +vad-asr-paraformer diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 4e7d4ea6e..e527ea183 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -117,6 +117,7 @@ list(APPEND sources audio-tagging-model-config.cc audio-tagging.cc offline-zipformer-audio-tagging-model-config.cc + offline-zipformer-audio-tagging-model.cc ) if(SHERPA_ONNX_ENABLE_TTS) @@ -201,6 +202,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc) add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc) add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc) + add_executable(sherpa-onnx-offline-audio-tagging sherpa-onnx-offline-audio-tagging.cc) if(SHERPA_ONNX_ENABLE_TTS) add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc) @@ -212,6 +214,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) sherpa-onnx-offline sherpa-onnx-offline-parallel sherpa-onnx-offline-language-identification + sherpa-onnx-offline-audio-tagging ) if(SHERPA_ONNX_ENABLE_TTS) list(APPEND main_exes diff --git a/sherpa-onnx/csrc/audio-tagging-impl.cc b/sherpa-onnx/csrc/audio-tagging-impl.cc index 75aa559b6..33e8dbb78 100644 --- a/sherpa-onnx/csrc/audio-tagging-impl.cc +++ b/sherpa-onnx/csrc/audio-tagging-impl.cc @@ -5,12 +5,19 @@ #include "sherpa-onnx/csrc/audio-tagging-impl.h" #include "sherpa-onnx/csrc/audio-tagging-zipformer-impl.h" +#include "sherpa-onnx/csrc/macros.h" namespace sherpa_onnx { std::unique_ptr AudioTaggingImpl::Create( const AudioTaggingConfig &config) { - return std::make_unique(config); + if (!config.model.zipformer.model.empty()) { + return std::make_unique(config); + } + + SHERPA_ONNX_LOG( + "Please specify an audio tagging model! Return a null pointer"); + return nullptr; } } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/audio-tagging-zipformer-impl.h b/sherpa-onnx/csrc/audio-tagging-zipformer-impl.h index 7369c9b35..6680bc066 100644 --- a/sherpa-onnx/csrc/audio-tagging-zipformer-impl.h +++ b/sherpa-onnx/csrc/audio-tagging-zipformer-impl.h @@ -8,23 +8,76 @@ #include "sherpa-onnx/csrc/audio-tagging-impl.h" #include "sherpa-onnx/csrc/audio-tagging.h" +#include "sherpa-onnx/csrc/math.h" +#include "sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h" namespace sherpa_onnx { class AudioTaggingZipformerImpl : public AudioTaggingImpl { public: explicit AudioTaggingZipformerImpl(const AudioTaggingConfig &config) - : config_(config) {} + : config_(config), model_(config.model) {} - std::unique_ptr CreateStream() const override { return {}; } + std::unique_ptr CreateStream() const override { + return std::make_unique(); + } std::vector Compute(OfflineStream *s, int32_t top_k = -1) const override { - return {}; + if (top_k < 0) { + top_k = config_.top_k; + } + + int32_t num_event_classes = model_.NumEventClasses(); + + if (top_k > num_event_classes) { + top_k = num_event_classes; + } + + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + // WARNING(fangjun): It is fixed to 80 for all models from icefall + int32_t feat_dim = 80; + std::vector f = s->GetFrames(); + + int32_t num_frames = f.size() / feat_dim; + + std::array shape = {1, num_frames, feat_dim}; + + Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(), + shape.data(), shape.size()); + + int64_t x_length_scalar = num_frames; + std::array x_length_shape = {1}; + Ort::Value x_length = + Ort::Value::CreateTensor(memory_info, &x_length_scalar, 1, + x_length_shape.data(), x_length_shape.size()); + + Ort::Value probs = model_.Forward(std::move(x), std::move(x_length)); + + const float *p = probs.GetTensorData(); + + std::vector top_k_indexes = TopkIndex(p, num_event_classes, top_k); + + std::vector ans(top_k); + + int32_t i = 0; + + for (int32_t index : top_k_indexes) { + ans[i].index = index; + ans[i].prob = p[index]; + ans[i].name = ""; // TODO(fangjun): fix it + i += 1; + } + + return ans; } private: AudioTaggingConfig config_; + OfflineZipformerAudioTaggingModel model_; + ; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/audio-tagging.cc b/sherpa-onnx/csrc/audio-tagging.cc index 94512170b..0fe340e25 100644 --- a/sherpa-onnx/csrc/audio-tagging.cc +++ b/sherpa-onnx/csrc/audio-tagging.cc @@ -9,6 +9,15 @@ namespace sherpa_onnx { +std::string AudioEvent::ToString() const { + std::ostringstream os; + os << "AudioEvent("; + os << "name=\"" << name << "\", "; + os << "index=" << index << ", "; + os << "prob=" << prob << ")"; + return os.str(); +} + void AudioTaggingConfig::Register(ParseOptions *po) { model.Register(po); po->Register("top-k", &top_k, "Top k events to return in the result"); diff --git a/sherpa-onnx/csrc/audio-tagging.h b/sherpa-onnx/csrc/audio-tagging.h index d3a78d9e4..e57375331 100644 --- a/sherpa-onnx/csrc/audio-tagging.h +++ b/sherpa-onnx/csrc/audio-tagging.h @@ -32,7 +32,10 @@ struct AudioTaggingConfig { struct AudioEvent { std::string name; // name of the event + int32_t index; // index of the event in the label file float prob; // probability of the event + + std::string ToString() const; }; class AudioTaggingImpl; diff --git a/sherpa-onnx/csrc/math.h b/sherpa-onnx/csrc/math.h index ba01835fe..121a05aeb 100644 --- a/sherpa-onnx/csrc/math.h +++ b/sherpa-onnx/csrc/math.h @@ -97,8 +97,8 @@ void LogSoftmax(T *in, int32_t w, int32_t h) { } template -void SubtractBlank(T *in, int32_t w, int32_t h, - int32_t blank_idx, float blank_penalty) { +void SubtractBlank(T *in, int32_t w, int32_t h, int32_t blank_idx, + float blank_penalty) { for (int32_t i = 0; i != h; ++i) { in[blank_idx] -= blank_penalty; in += w; @@ -116,8 +116,7 @@ std::vector TopkIndex(const T *vec, int32_t size, int32_t topk) { }); int32_t k_num = std::min(size, topk); - std::vector index(vec_index.begin(), vec_index.begin() + k_num); - return index; + return {vec_index.begin(), vec_index.begin() + k_num}; } } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-stream.cc b/sherpa-onnx/csrc/offline-stream.cc index 08e601363..0eea103c9 100644 --- a/sherpa-onnx/csrc/offline-stream.cc +++ b/sherpa-onnx/csrc/offline-stream.cc @@ -234,7 +234,7 @@ OfflineStream::OfflineStream( : impl_(std::make_unique(config, context_graph)) {} OfflineStream::OfflineStream(WhisperTag tag, - ContextGraphPtr context_graph /*= nullptr*/) + ContextGraphPtr context_graph /*= {}*/) : impl_(std::make_unique(tag, context_graph)) {} OfflineStream::~OfflineStream() = default; diff --git a/sherpa-onnx/csrc/offline-stream.h b/sherpa-onnx/csrc/offline-stream.h index 26b890b60..08ddbd316 100644 --- a/sherpa-onnx/csrc/offline-stream.h +++ b/sherpa-onnx/csrc/offline-stream.h @@ -71,10 +71,9 @@ struct WhisperTag {}; class OfflineStream { public: explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {}, - ContextGraphPtr context_graph = nullptr); + ContextGraphPtr context_graph = {}); - explicit OfflineStream(WhisperTag tag, - ContextGraphPtr context_graph = nullptr); + explicit OfflineStream(WhisperTag tag, ContextGraphPtr context_graph = {}); ~OfflineStream(); /** diff --git a/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc index e69de29bb..9cefcf663 100644 --- a/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc +++ b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc @@ -0,0 +1,114 @@ +// sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h" + +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/session.h" +#include "sherpa-onnx/csrc/text-utils.h" + +namespace sherpa_onnx { + +class OfflineZipformerAudioTaggingModel::Impl { + public: + explicit Impl(const AudioTaggingModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(config_.zipformer.model); + Init(buf.data(), buf.size()); + } + +#if __ANDROID_API__ >= 9 + Impl(AAssetManager *mgr, const AudioTaggingModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(mgr, config_.zipformer.model); + Init(buf.data(), buf.size()); + } +#endif + + Ort::Value Forward(Ort::Value features, Ort::Value features_length) { + std::array inputs = {std::move(features), + std::move(features_length)}; + + auto ans = + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), + output_names_ptr_.data(), output_names_ptr_.size()); + return std::move(ans[0]); + } + + int32_t NumEventClasses() const { return num_event_classes_; } + + OrtAllocator *Allocator() const { return allocator_; } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::make_unique(env_, model_data, model_data_length, + sess_opts_); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); + } + + // get vocab size from the output[0].shape, which is (N, num_event_classes) + num_event_classes_ = + sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape()[1]; + } + + private: + AudioTaggingModelConfig config_; + Ort::Env env_; + Ort::SessionOptions sess_opts_; + Ort::AllocatorWithDefaultOptions allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + int32_t num_event_classes_ = 0; +}; + +OfflineZipformerAudioTaggingModel::OfflineZipformerAudioTaggingModel( + const AudioTaggingModelConfig &config) + : impl_(std::make_unique(config)) {} + +#if __ANDROID_API__ >= 9 +OfflineZipformerAudioTaggingModel::OfflineZipformerAudioTaggingModel( + AAssetManager *mgr, const AudioTaggingModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} +#endif + +OfflineZipformerAudioTaggingModel::~OfflineZipformerAudioTaggingModel() = + default; + +Ort::Value OfflineZipformerAudioTaggingModel::Forward( + Ort::Value features, Ort::Value features_length) const { + return impl_->Forward(std::move(features), std::move(features_length)); +} + +int32_t OfflineZipformerAudioTaggingModel::NumEventClasses() const { + return impl_->NumEventClasses(); +} + +OrtAllocator *OfflineZipformerAudioTaggingModel::Allocator() const { + return impl_->Allocator(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h index e8ae8753f..ce8220762 100644 --- a/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h +++ b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h @@ -4,9 +4,7 @@ #ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGER_MODEL_H_ #define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGER_MODEL_H_ #include -#include #include -#include #if __ANDROID_API__ >= 9 #include "android/asset_manager.h" @@ -43,22 +41,18 @@ class OfflineZipformerAudioTaggingModel { * valid frames in `features` before padding. * Its dtype is int64_t. * - * @return Return a vector containing: - * - log_probs: A 3-D tensor of shape (N, T', vocab_size). - * - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t + * @return Return a tensor + * - probs: A 2-D tensor of shape (N, num_event_classes). */ - std::vector Forward(Ort::Value features, - Ort::Value features_length) override; + Ort::Value Forward(Ort::Value features, Ort::Value features_length) const; /** Return the vocabulary size of the model */ - int32_t VocabSize() const override; + int32_t NumEventClasses() const; /** Return an allocator for allocating memory */ - OrtAllocator *Allocator() const override; - - int32_t SubsamplingFactor() const override; + OrtAllocator *Allocator() const; private: class Impl; diff --git a/sherpa-onnx/csrc/offline-zipformer-ctc-model.h b/sherpa-onnx/csrc/offline-zipformer-ctc-model.h index e3b9a05ce..c4e835636 100644 --- a/sherpa-onnx/csrc/offline-zipformer-ctc-model.h +++ b/sherpa-onnx/csrc/offline-zipformer-ctc-model.h @@ -4,7 +4,6 @@ #ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_H_ #define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_H_ #include -#include #include #include diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-audio-tagging.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-audio-tagging.cc new file mode 100644 index 000000000..567dc59a4 --- /dev/null +++ b/sherpa-onnx/csrc/sherpa-onnx-offline-audio-tagging.cc @@ -0,0 +1,85 @@ +// sherpa-onnx/csrc/sherpa-onnx-offline-audio-tagging.cc +// +// Copyright (c) 2024 Xiaomi Corporation +#include + +#include "sherpa-onnx/csrc/audio-tagging.h" +#include "sherpa-onnx/csrc/parse-options.h" +#include "sherpa-onnx/csrc/wave-reader.h" + +int32_t main(int32_t argc, char *argv[]) { + const char *kUsageMessage = R"usage( +Audio tagging from a file. + +Usage: + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 +tar xvf sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 +rm sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 + +./bin/sherpa-onnx-offline-audio-tagging \ + --zipformer-model=./sherpa-onnx-zipformer-audio-tagging-2024-04-09/model.onnx + sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/0.wav + +Input wave files should be of single channel, 16-bit PCM encoded wave file; its +sampling rate can be arbitrary and does not need to be 16kHz. + +Please see +https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models +for more models. +)usage"; + + sherpa_onnx::ParseOptions po(kUsageMessage); + sherpa_onnx::AudioTaggingConfig config; + config.Register(&po); + po.Read(argc, argv); + + if (po.NumArgs() != 1) { + fprintf(stderr, "\nError: Please provide 1 wave file\n\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + fprintf(stderr, "%s\n", config.ToString().c_str()); + + if (!config.Validate()) { + fprintf(stderr, "Errors in config!\n"); + return -1; + } + + sherpa_onnx::AudioTagging tagger(config); + std::string wav_filename = po.GetArg(1); + + int32_t sampling_rate = -1; + + bool is_ok = false; + const std::vector samples = + sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); + + if (!is_ok) { + fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); + return -1; + } + + const float duration = samples.size() / static_cast(sampling_rate); + + fprintf(stderr, "Start to compute\n"); + const auto start = std::chrono::steady_clock::now(); + + auto stream = tagger.CreateStream(); + + stream->AcceptWaveform(sampling_rate, samples.data(), samples.size()); + + auto results = tagger.Compute(stream.get()); + const auto end = std::chrono::steady_clock::now(); + fprintf(stderr, "Done\n"); + + int32_t i = 0; + + for (const auto &event : results) { + fprintf(stderr, "%d: %s\n", i, event.ToString().c_str()); + i += 1; + } + + return 0; +} From e17a4658ac4116f962a6d13f1f2ea6ee18f32c5c Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 10 Apr 2024 13:09:02 +0800 Subject: [PATCH 07/10] add audio tagging labels --- sherpa-onnx/csrc/CMakeLists.txt | 1 + sherpa-onnx/csrc/audio-tagging-label-file.cc | 70 +++++++++++++++++++ sherpa-onnx/csrc/audio-tagging-label-file.h | 31 ++++++++ .../csrc/audio-tagging-zipformer-impl.h | 16 ++++- sherpa-onnx/csrc/audio-tagging.cc | 12 ++++ sherpa-onnx/csrc/audio-tagging.h | 6 +- .../offline-zipformer-audio-tagging-model.cc | 3 + .../offline-zipformer-audio-tagging-model.h | 6 +- .../csrc/sherpa-onnx-offline-audio-tagging.cc | 13 +++- 9 files changed, 149 insertions(+), 9 deletions(-) create mode 100644 sherpa-onnx/csrc/audio-tagging-label-file.cc create mode 100644 sherpa-onnx/csrc/audio-tagging-label-file.h diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index e527ea183..5b2e5941c 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -114,6 +114,7 @@ list(APPEND sources # audio tagging list(APPEND sources audio-tagging-impl.cc + audio-tagging-label-file.cc audio-tagging-model-config.cc audio-tagging.cc offline-zipformer-audio-tagging-model-config.cc diff --git a/sherpa-onnx/csrc/audio-tagging-label-file.cc b/sherpa-onnx/csrc/audio-tagging-label-file.cc new file mode 100644 index 000000000..24846a174 --- /dev/null +++ b/sherpa-onnx/csrc/audio-tagging-label-file.cc @@ -0,0 +1,70 @@ +// sherpa-onnx/csrc/audio-tagging-label-file.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/audio-tagging-label-file.h" + +#include +#include +#include + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/text-utils.h" + +namespace sherpa_onnx { + +AudioTaggingLabels::AudioTaggingLabels(const std::string &filename) { + std::ifstream is(filename); + Init(is); +} + +// Format of a label file +/* +index,mid,display_name +0,/m/09x0r,"Speech" +1,/m/05zppz,"Male speech, man speaking" +*/ +void AudioTaggingLabels::Init(std::istream &is) { + std::string line; + std::getline(is, line); // skip the header + + std::string index; + std::string tmp; + std::string name; + + while (std::getline(is, line)) { + index.clear(); + name.clear(); + std::istringstream input2(line); + + std::getline(input2, index, ','); + std::getline(input2, tmp, ','); + std::getline(input2, name); + + std::size_t pos{}; + int32_t i = std::stoi(index, &pos); + if (index.size() == 0 || pos != index.size()) { + SHERPA_ONNX_LOGE("Invalid line: %s", line.c_str()); + exit(-1); + } + + if (i != names_.size()) { + SHERPA_ONNX_LOGE( + "Index should be sorted and contiguous. Expected index: %d, given: " + "%d.", + static_cast(names_.size()), i); + } + if (name.empty() || name.front() != '"' || name.back() != '"') { + SHERPA_ONNX_LOGE("Invalid line: %s", line.c_str()); + exit(-1); + } + + names_.emplace_back(name.begin() + 1, name.end() - 1); + } +} + +const std::string &AudioTaggingLabels::GetEventName(int32_t index) const { + return names_.at(index); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/audio-tagging-label-file.h b/sherpa-onnx/csrc/audio-tagging-label-file.h new file mode 100644 index 000000000..9e71557f5 --- /dev/null +++ b/sherpa-onnx/csrc/audio-tagging-label-file.h @@ -0,0 +1,31 @@ +// sherpa-onnx/csrc/audio-tagging-label-file.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_LABEL_FILE_H_ +#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_LABEL_FILE_H_ + +#include +#include +#include + +namespace sherpa_onnx { + +class AudioTaggingLabels { + public: + explicit AudioTaggingLabels(const std::string &filename); + + // Return the event name for the given index. + // The returned reference is valid as long as this object is alive + const std::string &GetEventName(int32_t index) const; + int32_t NumEventClasses() const { return names_.size(); } + + private: + void Init(std::istream &is); + + private: + std::vector names_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_LABEL_FILE_H_ diff --git a/sherpa-onnx/csrc/audio-tagging-zipformer-impl.h b/sherpa-onnx/csrc/audio-tagging-zipformer-impl.h index 6680bc066..639f644c8 100644 --- a/sherpa-onnx/csrc/audio-tagging-zipformer-impl.h +++ b/sherpa-onnx/csrc/audio-tagging-zipformer-impl.h @@ -5,9 +5,13 @@ #define SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_ #include +#include +#include #include "sherpa-onnx/csrc/audio-tagging-impl.h" +#include "sherpa-onnx/csrc/audio-tagging-label-file.h" #include "sherpa-onnx/csrc/audio-tagging.h" +#include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/math.h" #include "sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h" @@ -16,7 +20,13 @@ namespace sherpa_onnx { class AudioTaggingZipformerImpl : public AudioTaggingImpl { public: explicit AudioTaggingZipformerImpl(const AudioTaggingConfig &config) - : config_(config), model_(config.model) {} + : config_(config), model_(config.model), labels_(config.labels) { + if (model_.NumEventClasses() != labels_.NumEventClasses()) { + SHERPA_ONNX_LOGE("number of classes: %d (model) != %d (label file)", + model_.NumEventClasses(), labels_.NumEventClasses()); + exit(-1); + } + } std::unique_ptr CreateStream() const override { return std::make_unique(); @@ -65,9 +75,9 @@ class AudioTaggingZipformerImpl : public AudioTaggingImpl { int32_t i = 0; for (int32_t index : top_k_indexes) { + ans[i].name = labels_.GetEventName(index); ans[i].index = index; ans[i].prob = p[index]; - ans[i].name = ""; // TODO(fangjun): fix it i += 1; } @@ -77,7 +87,7 @@ class AudioTaggingZipformerImpl : public AudioTaggingImpl { private: AudioTaggingConfig config_; OfflineZipformerAudioTaggingModel model_; - ; + AudioTaggingLabels labels_; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/audio-tagging.cc b/sherpa-onnx/csrc/audio-tagging.cc index 0fe340e25..ee1e2b3f7 100644 --- a/sherpa-onnx/csrc/audio-tagging.cc +++ b/sherpa-onnx/csrc/audio-tagging.cc @@ -5,6 +5,7 @@ #include "sherpa-onnx/csrc/audio-tagging.h" #include "sherpa-onnx/csrc/audio-tagging-impl.h" +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" namespace sherpa_onnx { @@ -20,6 +21,7 @@ std::string AudioEvent::ToString() const { void AudioTaggingConfig::Register(ParseOptions *po) { model.Register(po); + po->Register("labels", &labels, "Event label file"); po->Register("top-k", &top_k, "Top k events to return in the result"); } @@ -33,6 +35,16 @@ bool AudioTaggingConfig::Validate() const { return false; } + if (labels.empty()) { + SHERPA_ONNX_LOGE("Please provide --labels"); + return false; + } + + if (!FileExists(labels)) { + SHERPA_ONNX_LOGE("--labels %s does not exist", labels.c_str()); + return false; + } + return true; } std::string AudioTaggingConfig::ToString() const { diff --git a/sherpa-onnx/csrc/audio-tagging.h b/sherpa-onnx/csrc/audio-tagging.h index e57375331..50cfea02c 100644 --- a/sherpa-onnx/csrc/audio-tagging.h +++ b/sherpa-onnx/csrc/audio-tagging.h @@ -16,13 +16,15 @@ namespace sherpa_onnx { struct AudioTaggingConfig { AudioTaggingModelConfig model; + std::string labels; int32_t top_k = 5; AudioTaggingConfig() = default; - AudioTaggingConfig(const AudioTaggingModelConfig &model, int32_t top_k) - : model(model), top_k(top_k) {} + AudioTaggingConfig(const AudioTaggingModelConfig &model, + const std::string &labels, int32_t top_k) + : model(model), labels(labels), top_k(top_k) {} void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc index 9cefcf663..519821a03 100644 --- a/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc +++ b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc @@ -4,6 +4,9 @@ #include "sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h" +#include +#include + #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" #include "sherpa-onnx/csrc/text-utils.h" diff --git a/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h index ce8220762..d2ae6963a 100644 --- a/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h +++ b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h @@ -1,8 +1,8 @@ // sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h // // Copyright (c) 2024 Xiaomi Corporation -#ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGER_MODEL_H_ -#define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGER_MODEL_H_ +#ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_H_ #include #include @@ -61,4 +61,4 @@ class OfflineZipformerAudioTaggingModel { } // namespace sherpa_onnx -#endif // SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGER_MODEL_H_ +#endif // SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_H_ diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-audio-tagging.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-audio-tagging.cc index 567dc59a4..0a364cd23 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-offline-audio-tagging.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-offline-audio-tagging.cc @@ -64,7 +64,7 @@ for more models. const float duration = samples.size() / static_cast(sampling_rate); fprintf(stderr, "Start to compute\n"); - const auto start = std::chrono::steady_clock::now(); + const auto begin = std::chrono::steady_clock::now(); auto stream = tagger.CreateStream(); @@ -81,5 +81,16 @@ for more models. i += 1; } + float elapsed_seconds = + std::chrono::duration_cast(end - begin) + .count() / + 1000.; + float rtf = elapsed_seconds / duration; + fprintf(stderr, "Num threads: %d\n", config.model.num_threads); + fprintf(stderr, "Wave duration: %.3f\n", duration); + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); + fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n", + elapsed_seconds, duration, rtf); + return 0; } From 8ddc627a05ac225219ea405ad4a0e6db0a5f907c Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 10 Apr 2024 13:17:52 +0800 Subject: [PATCH 08/10] add CI tests for audio tagging --- .github/scripts/test-audio-tagging.sh | 35 +++++++++++++++++++ .github/workflows/linux.yaml | 10 ++++++ .github/workflows/macos.yaml | 10 ++++++ .github/workflows/windows-x64.yaml | 10 ++++++ .github/workflows/windows-x86.yaml | 9 +++++ .../csrc/offline-zipformer-ctc-model.cc | 2 ++ .../csrc/sherpa-onnx-offline-audio-tagging.cc | 3 +- 7 files changed, 78 insertions(+), 1 deletion(-) create mode 100755 .github/scripts/test-audio-tagging.sh diff --git a/.github/scripts/test-audio-tagging.sh b/.github/scripts/test-audio-tagging.sh new file mode 100755 index 000000000..6ef1de465 --- /dev/null +++ b/.github/scripts/test-audio-tagging.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash + +set -ex + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +echo "EXE is $EXE" +echo "PATH: $PATH" + +which $EXE + +log "------------------------------------------------------------" +log "Run zipformer for audio tagging " +log "------------------------------------------------------------" + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 +tar xvf sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 +rm sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 +repo=sherpa-onnx-zipformer-audio-tagging-2024-04-09 +ls -lh $repo + +$EXE \ + --zipformer-model=$repo/model.onnx + +for w in 1.wav 2.wav 3.wav 4.wav; do + $EXE \ + --zipformer-model=$repo/model.onnx \ + --labels=$repo/class_labels_indices.csv \ + $repo/test_wavs/$w +done +rm -rf $repo diff --git a/.github/workflows/linux.yaml b/.github/workflows/linux.yaml index b32362a3d..ae0aec470 100644 --- a/.github/workflows/linux.yaml +++ b/.github/workflows/linux.yaml @@ -15,6 +15,7 @@ on: - '.github/scripts/test-offline-ctc.sh' - '.github/scripts/test-online-ctc.sh' - '.github/scripts/test-offline-tts.sh' + - '.github/scripts/test-audio-tagging.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -32,6 +33,7 @@ on: - '.github/scripts/test-offline-ctc.sh' - '.github/scripts/test-online-ctc.sh' - '.github/scripts/test-offline-tts.sh' + - '.github/scripts/test-audio-tagging.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -124,6 +126,14 @@ jobs: name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }} path: build/bin/* + - name: Test Audio tagging + shell: bash + run: | + export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx-offline-audio-tagging + + .github/scripts/test-audio-tagging.sh + - name: Test online CTC shell: bash run: | diff --git a/.github/workflows/macos.yaml b/.github/workflows/macos.yaml index 0d0980619..9dfcb7c9d 100644 --- a/.github/workflows/macos.yaml +++ b/.github/workflows/macos.yaml @@ -15,6 +15,7 @@ on: - '.github/scripts/test-offline-ctc.sh' - '.github/scripts/test-offline-tts.sh' - '.github/scripts/test-online-ctc.sh' + - '.github/scripts/test-audio-tagging.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -31,6 +32,7 @@ on: - '.github/scripts/test-offline-ctc.sh' - '.github/scripts/test-offline-tts.sh' - '.github/scripts/test-online-ctc.sh' + - '.github/scripts/test-audio-tagging.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -103,6 +105,14 @@ jobs: otool -L build/bin/sherpa-onnx otool -l build/bin/sherpa-onnx + - name: Test Audio tagging + shell: bash + run: | + export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx-offline-audio-tagging + + .github/scripts/test-audio-tagging.sh + - name: Test C API shell: bash run: | diff --git a/.github/workflows/windows-x64.yaml b/.github/workflows/windows-x64.yaml index ea7cf7458..8f1715591 100644 --- a/.github/workflows/windows-x64.yaml +++ b/.github/workflows/windows-x64.yaml @@ -14,6 +14,7 @@ on: - '.github/scripts/test-offline-ctc.sh' - '.github/scripts/test-online-ctc.sh' - '.github/scripts/test-offline-tts.sh' + - '.github/scripts/test-audio-tagging.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -28,6 +29,7 @@ on: - '.github/scripts/test-offline-ctc.sh' - '.github/scripts/test-online-ctc.sh' - '.github/scripts/test-offline-tts.sh' + - '.github/scripts/test-audio-tagging.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -70,6 +72,14 @@ jobs: ls -lh ./bin/Release/sherpa-onnx.exe + - name: Test Audio tagging + shell: bash + run: | + export PATH=$PWD/build/bin/Release:$PATH + export EXE=sherpa-onnx-offline-audio-tagging.exe + + .github/scripts/test-audio-tagging.sh + - name: Test C API shell: bash run: | diff --git a/.github/workflows/windows-x86.yaml b/.github/workflows/windows-x86.yaml index 69ad7cd97..65d1bea62 100644 --- a/.github/workflows/windows-x86.yaml +++ b/.github/workflows/windows-x86.yaml @@ -14,6 +14,7 @@ on: - '.github/scripts/test-offline-ctc.sh' - '.github/scripts/test-offline-tts.sh' - '.github/scripts/test-online-ctc.sh' + - '.github/scripts/test-audio-tagging.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -28,6 +29,7 @@ on: - '.github/scripts/test-offline-ctc.sh' - '.github/scripts/test-offline-tts.sh' - '.github/scripts/test-online-ctc.sh' + - '.github/scripts/test-audio-tagging.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -85,6 +87,13 @@ jobs: # export EXE=sherpa-onnx-offline-language-identification.exe # # .github/scripts/test-spoken-language-identification.sh + - name: Test Audio tagging + shell: bash + run: | + export PATH=$PWD/build/bin/Release:$PATH + export EXE=sherpa-onnx-offline-audio-tagging.exe + + .github/scripts/test-audio-tagging.sh - name: Test online CTC shell: bash diff --git a/sherpa-onnx/csrc/offline-zipformer-ctc-model.cc b/sherpa-onnx/csrc/offline-zipformer-ctc-model.cc index a82ef6255..8db9439e4 100644 --- a/sherpa-onnx/csrc/offline-zipformer-ctc-model.cc +++ b/sherpa-onnx/csrc/offline-zipformer-ctc-model.cc @@ -4,6 +4,8 @@ #include "sherpa-onnx/csrc/offline-zipformer-ctc-model.h" +#include + #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-audio-tagging.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-audio-tagging.cc index 0a364cd23..862818f5c 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-offline-audio-tagging.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-offline-audio-tagging.cc @@ -18,7 +18,8 @@ tar xvf sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 rm sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 ./bin/sherpa-onnx-offline-audio-tagging \ - --zipformer-model=./sherpa-onnx-zipformer-audio-tagging-2024-04-09/model.onnx + --zipformer-model=./sherpa-onnx-zipformer-audio-tagging-2024-04-09/model.onnx \ + --labels=./sherpa-onnx-zipformer-audio-tagging-2024-04-09/class_labels_indices.csv \ sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/0.wav Input wave files should be of single channel, 16-bit PCM encoded wave file; its From 158619116b84c909641916c6bd18190094b834bd Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 10 Apr 2024 14:21:23 +0800 Subject: [PATCH 09/10] fix typos --- .github/scripts/test-audio-tagging.sh | 3 --- 1 file changed, 3 deletions(-) diff --git a/.github/scripts/test-audio-tagging.sh b/.github/scripts/test-audio-tagging.sh index 6ef1de465..57e6663fe 100755 --- a/.github/scripts/test-audio-tagging.sh +++ b/.github/scripts/test-audio-tagging.sh @@ -23,9 +23,6 @@ rm sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 repo=sherpa-onnx-zipformer-audio-tagging-2024-04-09 ls -lh $repo -$EXE \ - --zipformer-model=$repo/model.onnx - for w in 1.wav 2.wav 3.wav 4.wav; do $EXE \ --zipformer-model=$repo/model.onnx \ From 70284c6a5d59d3398ec777db84ac6c9b9554b29c Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 10 Apr 2024 14:42:14 +0800 Subject: [PATCH 10/10] fix typos --- sherpa-onnx/csrc/audio-tagging-model-config.cc | 9 +++++++++ sherpa-onnx/csrc/audio-tagging.cc | 1 + .../csrc/offline-zipformer-audio-tagging-model.cc | 3 ++- sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h | 4 ++-- 4 files changed, 14 insertions(+), 3 deletions(-) diff --git a/sherpa-onnx/csrc/audio-tagging-model-config.cc b/sherpa-onnx/csrc/audio-tagging-model-config.cc index f95bfadd5..f1f526f80 100644 --- a/sherpa-onnx/csrc/audio-tagging-model-config.cc +++ b/sherpa-onnx/csrc/audio-tagging-model-config.cc @@ -8,6 +8,15 @@ namespace sherpa_onnx { void AudioTaggingModelConfig::Register(ParseOptions *po) { zipformer.Register(po); + + po->Register("num-threads", &num_threads, + "Number of threads to run the neural network"); + + po->Register("debug", &debug, + "true to print model information while loading it."); + + po->Register("provider", &provider, + "Specify a provider to use: cpu, cuda, coreml"); } bool AudioTaggingModelConfig::Validate() const { diff --git a/sherpa-onnx/csrc/audio-tagging.cc b/sherpa-onnx/csrc/audio-tagging.cc index ee1e2b3f7..34d558dd9 100644 --- a/sherpa-onnx/csrc/audio-tagging.cc +++ b/sherpa-onnx/csrc/audio-tagging.cc @@ -52,6 +52,7 @@ std::string AudioTaggingConfig::ToString() const { os << "AudioTaggingConfig("; os << "model=" << model.ToString() << ", "; + os << "labels=\"" << labels << "\", "; os << "top_k=" << top_k << ")"; return os.str(); diff --git a/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc index 519821a03..8a2e80dc2 100644 --- a/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc +++ b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc @@ -66,7 +66,8 @@ class OfflineZipformerAudioTaggingModel::Impl { SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); } - // get vocab size from the output[0].shape, which is (N, num_event_classes) + // get num_event_classes from the output[0].shape, + // which is (N, num_event_classes) num_event_classes_ = sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape()[1]; } diff --git a/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h index d2ae6963a..282823499 100644 --- a/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h +++ b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h @@ -20,7 +20,7 @@ namespace sherpa_onnx { * from icefall. * * See - * https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/zipformer/export-onnx-ctc.py + * https://github.com/k2-fsa/icefall/blob/master/egs/audioset/AT/zipformer/export-onnx.py */ class OfflineZipformerAudioTaggingModel { public: @@ -46,7 +46,7 @@ class OfflineZipformerAudioTaggingModel { */ Ort::Value Forward(Ort::Value features, Ort::Value features_length) const; - /** Return the vocabulary size of the model + /** Return the number of event classes of the model */ int32_t NumEventClasses() const;