From a30ae88257849510aa8c491c4e4a29a828ee03de Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 7 Oct 2024 10:54:27 +0800 Subject: [PATCH 01/17] Begin to add C++ implementation for offline speaker diarization. --- sherpa-onnx/csrc/CMakeLists.txt | 3 ++ sherpa-onnx/csrc/macros.h | 2 + .../csrc/offline-speaker-diarization-impl.cc | 17 +++++++ .../csrc/offline-speaker-diarization-impl.h | 27 ++++++++++ .../offline-speaker-diarization-result.cc | 23 +++++++++ .../csrc/offline-speaker-diarization-result.h | 51 +++++++++++++++++++ .../csrc/offline-speaker-diarization.cc | 22 ++++++++ .../csrc/offline-speaker-diarization.h | 39 ++++++++++++++ 8 files changed, 184 insertions(+) create mode 100644 sherpa-onnx/csrc/offline-speaker-diarization-impl.cc create mode 100644 sherpa-onnx/csrc/offline-speaker-diarization-impl.h create mode 100644 sherpa-onnx/csrc/offline-speaker-diarization-result.cc create mode 100644 sherpa-onnx/csrc/offline-speaker-diarization-result.h create mode 100644 sherpa-onnx/csrc/offline-speaker-diarization.cc create mode 100644 sherpa-onnx/csrc/offline-speaker-diarization.h diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index e49fdeed4..bd52e5c35 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -164,6 +164,9 @@ if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION) list(APPEND sources fast-clustering-config.cc fast-clustering.cc + offline-speaker-diarization-impl.cc + offline-speaker-diarization-result.cc + offline-speaker-diarization.cc ) endif() diff --git a/sherpa-onnx/csrc/macros.h b/sherpa-onnx/csrc/macros.h index b5dfb99e3..e5966fe41 100644 --- a/sherpa-onnx/csrc/macros.h +++ b/sherpa-onnx/csrc/macros.h @@ -169,4 +169,6 @@ } \ } while (0) +#define SHERPA_ONNX_EXIT(code) exit(code) + #endif // SHERPA_ONNX_CSRC_MACROS_H_ diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-impl.cc b/sherpa-onnx/csrc/offline-speaker-diarization-impl.cc new file mode 100644 index 000000000..cefd30a7d 
--- /dev/null +++ b/sherpa-onnx/csrc/offline-speaker-diarization-impl.cc @@ -0,0 +1,17 @@ +// sherpa-onnx/csrc/offline-speaker-diarization-impl.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h" + +#include + +namespace sherpa_onnx { + +std::unique_ptr +OfflineSpeakerDiarizationImpl::Create( + const OfflineSpeakerDiarizationConfig &config) { + return nullptr; +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-impl.h b/sherpa-onnx/csrc/offline-speaker-diarization-impl.h new file mode 100644 index 000000000..6e7ac47c3 --- /dev/null +++ b/sherpa-onnx/csrc/offline-speaker-diarization-impl.h @@ -0,0 +1,27 @@ +// sherpa-onnx/csrc/offline-speaker-diarization-impl.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_IMPL_H_ + +#include + +#include "sherpa-onnx/csrc/offline-speaker-diarization.h" +namespace sherpa_onnx { + +class OfflineSpeakerDiarizationImpl { + public: + static std::unique_ptr Create( + const OfflineSpeakerDiarizationConfig &config); + + virtual ~OfflineSpeakerDiarizationImpl() = default; + + virtual OfflineSpeakerDiarizationResult Process( + const float *audio, int32_t n, + OfflineSpeakerDiarizationProgressCallback callback = nullptr) const = 0; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_IMPL_H_ diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-result.cc b/sherpa-onnx/csrc/offline-speaker-diarization-result.cc new file mode 100644 index 000000000..c36d793ac --- /dev/null +++ b/sherpa-onnx/csrc/offline-speaker-diarization-result.cc @@ -0,0 +1,23 @@ +// sherpa-onnx/csrc/offline-speaker-diarization-result.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-speaker-diarization-result.h" + +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + 
+OfflineSpeakerDiarizationSegment::OfflineSpeakerDiarizationSegment( + float start, float end, int32_t speaker) { + if (start > end) { + SHERPA_ONNX_LOGE("start %.3f should be less than end %.3f", start, end); + SHERPA_ONNX_EXIT(-1); + } + + start_ = start; + end_ = end; + speaker_ = speaker; +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-result.h b/sherpa-onnx/csrc/offline-speaker-diarization-result.h new file mode 100644 index 000000000..942ecf7e5 --- /dev/null +++ b/sherpa-onnx/csrc/offline-speaker-diarization-result.h @@ -0,0 +1,51 @@ +// sherpa-onnx/csrc/offline-speaker-diarization-result.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_ + +#include +#include +#include + +namespace sherpa_onnx { + +class OfflineSpeakerDiarizationSegment { + public: + OfflineSpeakerDiarizationSegment(float start, float end, int32_t speaker); + + // If the gap between the two segments is less than the given gap, then we + // merge them and return a new segment. Otherwise, it returns null. + std::optional Merge( + const OfflineSpeakerDiarizationSegment &other, float gap) const; + + private: + float start_; // in seconds + float end_; // in seconds + int32_t speaker_; // ID of the speaker, starting from 0 +}; + +class OfflineSpeakerDiarizationResult { + public: + // Add a new segment + void Add(const OfflineSpeakerDiarizationSegment &segment); + + // Number of distinct speakers contained in this object at this point + int32_t NumSpeakers() const; + + // Return a list of segments sorted by segment.start time + std::vector SortByStartTime() const; + + // ans.size() == NumSpeakers(). 
+ // ans[i] is for speaker_i and is sorted by start time + std::vector> SortBySpeaker() + const; + + private: + std::vector segments_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_ diff --git a/sherpa-onnx/csrc/offline-speaker-diarization.cc b/sherpa-onnx/csrc/offline-speaker-diarization.cc new file mode 100644 index 000000000..9739222c7 --- /dev/null +++ b/sherpa-onnx/csrc/offline-speaker-diarization.cc @@ -0,0 +1,22 @@ +// sherpa-onnx/csrc/offline-speaker-diarization.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-speaker-diarization.h" + +#include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h" + +namespace sherpa_onnx { + +OfflineSpeakerDiarization::OfflineSpeakerDiarization( + const OfflineSpeakerDiarizationConfig &config) {} + +OfflineSpeakerDiarization::~OfflineSpeakerDiarization() = default; + +OfflineSpeakerDiarizationResult OfflineSpeakerDiarization::Process( + const float *audio, int32_t n, + OfflineSpeakerDiarizationProgressCallback callback /*= nullptr*/) const { + return impl_->Process(audio, n, callback); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-speaker-diarization.h b/sherpa-onnx/csrc/offline-speaker-diarization.h new file mode 100644 index 000000000..b072b16a6 --- /dev/null +++ b/sherpa-onnx/csrc/offline-speaker-diarization.h @@ -0,0 +1,39 @@ +// sherpa-onnx/csrc/offline-speaker-diarization.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_ + +#include +#include + +#include "sherpa-onnx/csrc/offline-speaker-diarization-result.h" + +namespace sherpa_onnx { + +struct OfflineSpeakerDiarizationConfig {}; + +class OfflineSpeakerDiarizationImpl; + +using OfflineSpeakerDiarizationProgressCallback = std::function; + +class OfflineSpeakerDiarization { + public: + explicit OfflineSpeakerDiarization( + const 
OfflineSpeakerDiarizationConfig &config); + + ~OfflineSpeakerDiarization(); + + OfflineSpeakerDiarizationResult Process( + const float *audio, int32_t n, + OfflineSpeakerDiarizationProgressCallback callback = nullptr) const; + + private: + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_ From dd65064d1ee8845e8908d165ea282b6f3f262706 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 7 Oct 2024 11:45:57 +0800 Subject: [PATCH 02/17] add config for speaker diarization --- cmake/cmake_extension.py | 1 + sherpa-onnx/csrc/CMakeLists.txt | 12 ++++ ...ffline-speaker-diarization-pyannote-impl.h | 3 + .../csrc/offline-speaker-diarization.cc | 30 ++++++++++ .../csrc/offline-speaker-diarization.h | 16 +++++- ...fline-speaker-segmentation-model-config.cc | 56 +++++++++++++++++++ ...ffline-speaker-segmentation-model-config.h | 38 +++++++++++++ ...aker-segmentation-pyannote-model-config.cc | 37 ++++++++++++ ...eaker-segmentation-pyannote-model-config.h | 29 ++++++++++ ...sherpa-onnx-offline-speaker-diarization.cc | 19 +++++++ sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc | 7 ++- 11 files changed, 244 insertions(+), 4 deletions(-) create mode 100644 sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h create mode 100644 sherpa-onnx/csrc/offline-speaker-segmentation-model-config.cc create mode 100644 sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h create mode 100644 sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.cc create mode 100644 sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h create mode 100644 sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc diff --git a/cmake/cmake_extension.py b/cmake/cmake_extension.py index 672e3d17a..c49c32555 100644 --- a/cmake/cmake_extension.py +++ b/cmake/cmake_extension.py @@ -55,6 +55,7 @@ def get_binaries(): "sherpa-onnx-offline-audio-tagging", "sherpa-onnx-offline-language-identification", 
"sherpa-onnx-offline-punctuation", + "sherpa-onnx-offline-speaker-diarization", "sherpa-onnx-offline-tts", "sherpa-onnx-offline-tts-play", "sherpa-onnx-offline-websocket-server", diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index bd52e5c35..222dbf151 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -167,6 +167,8 @@ if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION) offline-speaker-diarization-impl.cc offline-speaker-diarization-result.cc offline-speaker-diarization.cc + offline-speaker-segmentation-model-config.cc + offline-speaker-segmentation-pyannote-model-config.cc ) endif() @@ -263,6 +265,10 @@ if(SHERPA_ONNX_ENABLE_BINARY) add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc) endif() + if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION) + add_executable(sherpa-onnx-offline-speaker-diarization sherpa-onnx-offline-speaker-diarization.cc) + endif() + set(main_exes sherpa-onnx sherpa-onnx-keyword-spotter @@ -279,6 +285,12 @@ if(SHERPA_ONNX_ENABLE_BINARY) ) endif() + if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION) + list(APPEND main_exes + sherpa-onnx-offline-speaker-diarization + ) + endif() + foreach(exe IN LISTS main_exes) target_link_libraries(${exe} sherpa-onnx-core) endforeach() diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h new file mode 100644 index 000000000..f4be51495 --- /dev/null +++ b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h @@ -0,0 +1,3 @@ +// sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.cc +// +// Copyright (c) 2024 Xiaomi Corporation diff --git a/sherpa-onnx/csrc/offline-speaker-diarization.cc b/sherpa-onnx/csrc/offline-speaker-diarization.cc index 9739222c7..4eaf85498 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization.cc +++ b/sherpa-onnx/csrc/offline-speaker-diarization.cc @@ -8,6 +8,36 @@ namespace sherpa_onnx { +void 
OfflineSpeakerDiarizationConfig::Register(ParseOptions *po) { + ParseOptions po_segmentation("segmentation", po); + segmentation.Register(&po_segmentation); + + ParseOptions po_embedding("embedding", po); + embedding.Register(&po_embedding); +} + +bool OfflineSpeakerDiarizationConfig::Validate() const { + if (!segmentation.Validate()) { + return false; + } + + if (!embedding.Validate()) { + return false; + } + + return true; +} + +std::string OfflineSpeakerDiarizationConfig::ToString() const { + std::ostringstream os; + + os << "OfflineSpeakerDiarizationConfig("; + os << "segmentation=" << segmentation.ToString() << ", "; + os << "embedding=" << embedding.ToString() << ")"; + + return os.str(); +} + OfflineSpeakerDiarization::OfflineSpeakerDiarization( const OfflineSpeakerDiarizationConfig &config) {} diff --git a/sherpa-onnx/csrc/offline-speaker-diarization.h b/sherpa-onnx/csrc/offline-speaker-diarization.h index b072b16a6..641886db4 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization.h @@ -9,10 +9,24 @@ #include #include "sherpa-onnx/csrc/offline-speaker-diarization-result.h" +#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h" +#include "sherpa-onnx/csrc/speaker-embedding-extractor.h" namespace sherpa_onnx { -struct OfflineSpeakerDiarizationConfig {}; +struct OfflineSpeakerDiarizationConfig { + OfflineSpeakerSegmentationModelConfig segmentation; + SpeakerEmbeddingExtractorConfig embedding; + + OfflineSpeakerDiarizationConfig() = default; + OfflineSpeakerDiarizationConfig( + const OfflineSpeakerSegmentationModelConfig &segmentation, + const SpeakerEmbeddingExtractorConfig &embedding); + + void Register(ParseOptions *po); + bool Validate() const; + std::string ToString() const; +}; class OfflineSpeakerDiarizationImpl; diff --git a/sherpa-onnx/csrc/offline-speaker-segmentation-model-config.cc b/sherpa-onnx/csrc/offline-speaker-segmentation-model-config.cc new file mode 100644 index 
000000000..214892f05 --- /dev/null +++ b/sherpa-onnx/csrc/offline-speaker-segmentation-model-config.cc @@ -0,0 +1,56 @@ +// sherpa-onnx/csrc/offline-speaker-segmentation-model-config.cc +// +// Copyright (c) 2024 Xiaomi Corporation +#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h" + +#include + +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +void OfflineSpeakerSegmentationModelConfig::Register(ParseOptions *po) { + pyannote.Register(po); + + po->Register("num-threads", &num_threads, + "Number of threads to run the neural network"); + + po->Register("debug", &debug, + "true to print model information while loading it."); + + po->Register("provider", &provider, + "Specify a provider to use: cpu, cuda, coreml"); +} + +bool OfflineSpeakerSegmentationModelConfig::Validate() const { + if (num_threads < 1) { + SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads); + return false; + } + + if (!pyannote.model.empty()) { + return pyannote.Validate(); + } + + if (pyannote.model.empty()) { + SHERPA_ONNX_LOGE( + "You have to provide at least one speaker segmentation model"); + return false; + } + + return true; +} + +std::string OfflineSpeakerSegmentationModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineSpeakerSegmentationModelConfig("; + os << "pyannote=" << pyannote.ToString() << ", "; + os << "num_threads=" << num_threads << ", "; + os << "debug=" << (debug ? 
"True" : "False") << ", "; + os << "provider=\"" << provider << "\")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h b/sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h new file mode 100644 index 000000000..c3ba7d7de --- /dev/null +++ b/sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h @@ -0,0 +1,38 @@ +// sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_MODEL_CONFIG_H_ + +#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h" +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct OfflineSpeakerSegmentationModelConfig { + OfflineSpeakerSegmentationPyannoteModelConfig pyannote; + + int32_t num_threads = 1; + bool debug = false; + std::string provider = "cpu"; + + OfflineSpeakerSegmentationModelConfig() = default; + + explicit OfflineSpeakerSegmentationModelConfig( + const OfflineSpeakerSegmentationPyannoteModelConfig &pyannote, + int32_t num_threads, bool debug, const std::string &provider) + : pyannote(pyannote), + num_threads(num_threads), + debug(debug), + provider(provider) {} + + void Register(ParseOptions *po); + + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.cc b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.cc new file mode 100644 index 000000000..848524b1a --- /dev/null +++ b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.cc @@ -0,0 +1,37 @@ +// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.cc +// +// Copyright (c) 2024 Xiaomi Corporation +#include 
"sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h" + +#include + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +void OfflineSpeakerSegmentationPyannoteModelConfig::Register(ParseOptions *po) { + po->Register("pyannote-model", &model, + "Path to model.onnx of the Pyannote segmentation model."); +} + +bool OfflineSpeakerSegmentationPyannoteModelConfig::Validate() const { + if (!FileExists(model)) { + SHERPA_ONNX_LOGE("Pyannote segmentation model: '%s' does not exist", + model.c_str()); + return false; + } + + return true; +} + +std::string OfflineSpeakerSegmentationPyannoteModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineSpeakerSegmentationPyannoteModelConfig("; + os << "model=\"" << model << "\")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h new file mode 100644 index 000000000..a0467c53d --- /dev/null +++ b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h @@ -0,0 +1,29 @@ +// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_CONFIG_H_ +#include + +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct OfflineSpeakerSegmentationPyannoteModelConfig { + std::string model; + + OfflineSpeakerSegmentationPyannoteModelConfig() = default; + + OfflineSpeakerSegmentationPyannoteModelConfig(const std::string &model) + : model(model){}; + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_CONFIG_H_ 
diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc new file mode 100644 index 000000000..b084a4b1f --- /dev/null +++ b/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc @@ -0,0 +1,19 @@ +// sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-speaker-diarization.h" + +int main(int32_t argc, char *argv[]) { + const char *kUsageMessage = R"usage( +Offline/Non-streaming speaker diarization with sherpa-onnx +Usage example: + + )usage"; + sherpa_onnx::OfflineSpeakerDiarizationConfig config; + sherpa_onnx::ParseOptions po(kUsageMessage); + config.Register(&po); + po.PrintUsage(); + po.Read(argc, argv); + std::cout << config.ToString() << "\n"; +} diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc index 442ec1813..1ab8b68de 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc @@ -9,14 +9,15 @@ #include "sherpa-onnx/csrc/parse-options.h" #include "sherpa-onnx/csrc/wave-writer.h" -int32_t audioCallback(const float * /*samples*/, int32_t n, float progress) { +static int32_t AudioCallback(const float * /*samples*/, int32_t n, + float progress) { printf("sample=%d, progress=%f\n", n, progress); return 1; } int main(int32_t argc, char *argv[]) { const char *kUsageMessage = R"usage( -Offline text-to-speech with sherpa-onnx +Offline/Non-streaming text-to-speech with sherpa-onnx Usage example: @@ -79,7 +80,7 @@ or details. 
sherpa_onnx::OfflineTts tts(config); const auto begin = std::chrono::steady_clock::now(); - auto audio = tts.Generate(po.GetArg(1), sid, 1.0, audioCallback); + auto audio = tts.Generate(po.GetArg(1), sid, 1.0, AudioCallback); const auto end = std::chrono::steady_clock::now(); if (audio.samples.empty()) { From 23fd7ec4b19f74d34f96d2269f22eb7c81307481 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 7 Oct 2024 23:42:37 +0800 Subject: [PATCH 03/17] begin to add pyannote segmentation model --- scripts/pyannote/segmentation/export-onnx.py | 2 +- sherpa-onnx/csrc/CMakeLists.txt | 1 + sherpa-onnx/csrc/offline-sense-voice-model.cc | 1 + .../csrc/offline-speaker-diarization-impl.cc | 9 ++ ...ffline-speaker-diarization-pyannote-impl.h | 32 ++++++- .../csrc/offline-speaker-diarization.cc | 5 +- .../csrc/offline-speaker-diarization.h | 1 + ...fline-speaker-segmentation-model-config.cc | 1 + ...ffline-speaker-segmentation-model-config.h | 2 + ...aker-segmentation-pyannote-model-config.cc | 1 + ...eaker-segmentation-pyannote-model-config.h | 5 +- ...er-segmentation-pyannote-model-meta-data.h | 29 ++++++ ...ine-speaker-segmentation-pyannote-model.cc | 92 +++++++++++++++++++ ...line-speaker-segmentation-pyannote-model.h | 39 ++++++++ sherpa-onnx/csrc/provider-config.cc | 6 +- sherpa-onnx/csrc/session.cc | 42 +-------- sherpa-onnx/csrc/session.h | 45 ++------- ...sherpa-onnx-offline-speaker-diarization.cc | 9 +- .../csrc/speaker-embedding-extractor.cc | 4 +- 19 files changed, 241 insertions(+), 85 deletions(-) create mode 100644 sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h create mode 100644 sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc create mode 100644 sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h diff --git a/scripts/pyannote/segmentation/export-onnx.py b/scripts/pyannote/segmentation/export-onnx.py index 5f6e79c7e..7ebcae960 100755 --- a/scripts/pyannote/segmentation/export-onnx.py +++ 
b/scripts/pyannote/segmentation/export-onnx.py @@ -72,7 +72,7 @@ def main(): model.receptive_field.duration * 16000 ) - opset_version = 18 + opset_version = 15 filename = "model.onnx" torch.onnx.export( diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 222dbf151..3e6526563 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -169,6 +169,7 @@ if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION) offline-speaker-diarization.cc offline-speaker-segmentation-model-config.cc offline-speaker-segmentation-pyannote-model-config.cc + offline-speaker-segmentation-pyannote-model.cc ) endif() diff --git a/sherpa-onnx/csrc/offline-sense-voice-model.cc b/sherpa-onnx/csrc/offline-sense-voice-model.cc index 1d2a14ef5..24903a41a 100644 --- a/sherpa-onnx/csrc/offline-sense-voice-model.cc +++ b/sherpa-onnx/csrc/offline-sense-voice-model.cc @@ -9,6 +9,7 @@ #include #include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" #include "sherpa-onnx/csrc/text-utils.h" diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-impl.cc b/sherpa-onnx/csrc/offline-speaker-diarization-impl.cc index cefd30a7d..e41a7767a 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization-impl.cc +++ b/sherpa-onnx/csrc/offline-speaker-diarization-impl.cc @@ -6,11 +6,20 @@ #include +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h" + namespace sherpa_onnx { std::unique_ptr OfflineSpeakerDiarizationImpl::Create( const OfflineSpeakerDiarizationConfig &config) { + if (!config.segmentation.pyannote.model.empty()) { + return std::make_unique(config); + } + + SHERPA_ONNX_LOGE("Please specify a speaker segmentation model."); + return nullptr; } diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h index f4be51495..186f6ea1d 100644 --- 
a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h @@ -1,3 +1,33 @@ -// sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.cc +// sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h // // Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_ + +#include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h" +#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h" + +namespace sherpa_onnx { +class OfflineSpeakerDiarizationPyannoteImpl + : public OfflineSpeakerDiarizationImpl { + public: + ~OfflineSpeakerDiarizationPyannoteImpl() override = default; + + explicit OfflineSpeakerDiarizationPyannoteImpl( + const OfflineSpeakerDiarizationConfig &config) + : config_(config), segmentation_model_(config_.segmentation) {} + + OfflineSpeakerDiarizationResult Process( + const float *audio, int32_t n, + OfflineSpeakerDiarizationProgressCallback callback = + nullptr) const override { + return {}; + } + + private: + OfflineSpeakerDiarizationConfig config_; + OfflineSpeakerSegmentationPyannoteModel segmentation_model_; +}; + +} // namespace sherpa_onnx +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_ diff --git a/sherpa-onnx/csrc/offline-speaker-diarization.cc b/sherpa-onnx/csrc/offline-speaker-diarization.cc index 4eaf85498..bc992a295 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization.cc +++ b/sherpa-onnx/csrc/offline-speaker-diarization.cc @@ -4,6 +4,8 @@ #include "sherpa-onnx/csrc/offline-speaker-diarization.h" +#include + #include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h" namespace sherpa_onnx { @@ -39,7 +41,8 @@ std::string OfflineSpeakerDiarizationConfig::ToString() const { } OfflineSpeakerDiarization::OfflineSpeakerDiarization( - const OfflineSpeakerDiarizationConfig &config) {} + const 
OfflineSpeakerDiarizationConfig &config) + : impl_(OfflineSpeakerDiarizationImpl::Create(config)) {} OfflineSpeakerDiarization::~OfflineSpeakerDiarization() = default; diff --git a/sherpa-onnx/csrc/offline-speaker-diarization.h b/sherpa-onnx/csrc/offline-speaker-diarization.h index 641886db4..1c4b88a69 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization.h @@ -7,6 +7,7 @@ #include #include +#include #include "sherpa-onnx/csrc/offline-speaker-diarization-result.h" #include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h" diff --git a/sherpa-onnx/csrc/offline-speaker-segmentation-model-config.cc b/sherpa-onnx/csrc/offline-speaker-segmentation-model-config.cc index 214892f05..f1c9f7d4a 100644 --- a/sherpa-onnx/csrc/offline-speaker-segmentation-model-config.cc +++ b/sherpa-onnx/csrc/offline-speaker-segmentation-model-config.cc @@ -4,6 +4,7 @@ #include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h" #include +#include #include "sherpa-onnx/csrc/macros.h" diff --git a/sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h b/sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h index c3ba7d7de..8e9e4a96e 100644 --- a/sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h +++ b/sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h @@ -4,6 +4,8 @@ #ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_MODEL_CONFIG_H_ #define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_MODEL_CONFIG_H_ +#include + #include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h" #include "sherpa-onnx/csrc/parse-options.h" diff --git a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.cc b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.cc index 848524b1a..f7417ea83 100644 --- a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.cc +++ b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.cc @@ 
-4,6 +4,7 @@ #include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h" #include +#include #include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" diff --git a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h index a0467c53d..fb5ca4a48 100644 --- a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h +++ b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h @@ -15,8 +15,9 @@ struct OfflineSpeakerSegmentationPyannoteModelConfig { OfflineSpeakerSegmentationPyannoteModelConfig() = default; - OfflineSpeakerSegmentationPyannoteModelConfig(const std::string &model) - : model(model){}; + explicit OfflineSpeakerSegmentationPyannoteModelConfig( + const std::string &model) + : model(model) {} void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h new file mode 100644 index 000000000..728ed7ff4 --- /dev/null +++ b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h @@ -0,0 +1,29 @@ +// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_ + +#include +#include + +namespace sherpa_onnx { + +// If you are not sure what each field means, please +// have a look of the Python file in the model directory that +// you have downloaded. 
+struct OfflineSpeakerSegmentationPyannoteModelMetaData { + int32_t sample_rate = 0; + int32_t window_size = 0; // in samples + int32_t window_shift = 0; // in samples + int32_t receptive_field_size = 0; // in samples + int32_t receptive_field_shift = 0; // in samples + int32_t num_speakers = 0; + int32_t powerset_max_classes = 0; + int32_t num_classes = 0; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_ diff --git a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc new file mode 100644 index 000000000..620f22b8e --- /dev/null +++ b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc @@ -0,0 +1,92 @@ +// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h" + +#include +#include +#include + +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/session.h" + +namespace sherpa_onnx { + +class OfflineSpeakerSegmentationPyannoteModel::Impl { + public: + explicit Impl(const OfflineSpeakerSegmentationModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(config_.pyannote.model); + Init(buf.data(), buf.size()); + } + + Ort::Value Forward(Ort::Value x) { return Ort::Value(nullptr); } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::make_unique(env_, model_data, model_data_length, + sess_opts_); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s\n", 
os.str().c_str()); + } + + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate"); + SHERPA_ONNX_READ_META_DATA(meta_data_.window_size, "window_size"); + + meta_data_.window_shift = + static_cast(0.1 * meta_data_.window_size); + + SHERPA_ONNX_READ_META_DATA(meta_data_.receptive_field_size, + "receptive_field_size"); + SHERPA_ONNX_READ_META_DATA(meta_data_.receptive_field_shift, + "receptive_field_shift"); + SHERPA_ONNX_READ_META_DATA(meta_data_.num_speakers, "num_speakers"); + SHERPA_ONNX_READ_META_DATA(meta_data_.powerset_max_classes, + "powerset_max_classes"); + SHERPA_ONNX_READ_META_DATA(meta_data_.num_classes, "num_classes"); + } + + private: + OfflineSpeakerSegmentationModelConfig config_; + Ort::Env env_; + Ort::SessionOptions sess_opts_; + Ort::AllocatorWithDefaultOptions allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + OfflineSpeakerSegmentationPyannoteModelMetaData meta_data_; +}; + +OfflineSpeakerSegmentationPyannoteModel:: + OfflineSpeakerSegmentationPyannoteModel( + const OfflineSpeakerSegmentationModelConfig &config) + : impl_(std::make_unique(config)) {} + +OfflineSpeakerSegmentationPyannoteModel:: + ~OfflineSpeakerSegmentationPyannoteModel() = default; + +Ort::Value OfflineSpeakerSegmentationPyannoteModel::Forward(Ort::Value x) { + return impl_->Forward(std::move(x)); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h new file mode 100644 index 000000000..7ea8f44c5 --- /dev/null +++ b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h @@ -0,0 +1,39 @@ +// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef 
SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_ + +#include + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h" +#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h" + +namespace sherpa_onnx { + +class OfflineSpeakerSegmentationPyannoteModel { + public: + explicit OfflineSpeakerSegmentationPyannoteModel( + const OfflineSpeakerSegmentationModelConfig &config); + + ~OfflineSpeakerSegmentationPyannoteModel(); + + const OfflineSpeakerSegmentationPyannoteModelMetaData &GetMetaData() const; + + /** + * @param x A 3-D float tensor of shape (batch_size, 1, num_samples) + * @return Return a float tensor of + * shape (batch_size, num_frames, num_speakers). Note that + * num_speakers here uses powerset encoding. + */ + Ort::Value Forward(Ort::Value x); + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_ diff --git a/sherpa-onnx/csrc/provider-config.cc b/sherpa-onnx/csrc/provider-config.cc index 1db62aa6b..165e2d9a2 100644 --- a/sherpa-onnx/csrc/provider-config.cc +++ b/sherpa-onnx/csrc/provider-config.cc @@ -61,8 +61,10 @@ void TensorrtConfig::Register(ParseOptions *po) { bool TensorrtConfig::Validate() const { if (trt_max_workspace_size < 0) { - SHERPA_ONNX_LOGE("trt_max_workspace_size: %ld is not valid.", - trt_max_workspace_size); + std::ostringstream os; + os << "trt_max_workspace_size: " << trt_max_workspace_size + << " is not valid."; + SHERPA_ONNX_LOGE("%s", os.str().c_str()); return false; } if (trt_max_partition_iterations < 0) { diff --git a/sherpa-onnx/csrc/session.cc b/sherpa-onnx/csrc/session.cc index 7f6f685e0..9c5eb2b1a 100644 --- a/sherpa-onnx/csrc/session.cc +++ b/sherpa-onnx/csrc/session.cc @@ -35,9 +35,9 @@ static void OrtStatusFailure(OrtStatus *status, const char 
*s) { api.ReleaseStatus(status); } -static Ort::SessionOptions GetSessionOptionsImpl( +Ort::SessionOptions GetSessionOptionsImpl( int32_t num_threads, const std::string &provider_str, - const ProviderConfig *provider_config = nullptr) { + const ProviderConfig *provider_config /*= nullptr*/) { Provider p = StringToProvider(provider_str); Ort::SessionOptions sess_opts; @@ -259,10 +259,6 @@ Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config, &config.provider_config); } -Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) { - return GetSessionOptionsImpl(config.num_threads, config.provider); -} - Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config) { return GetSessionOptionsImpl(config.lm_num_threads, config.lm_provider); } @@ -271,38 +267,4 @@ Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config) { return GetSessionOptionsImpl(config.lm_num_threads, config.lm_provider); } -Ort::SessionOptions GetSessionOptions(const VadModelConfig &config) { - return GetSessionOptionsImpl(config.num_threads, config.provider); -} - -#if SHERPA_ONNX_ENABLE_TTS -Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config) { - return GetSessionOptionsImpl(config.num_threads, config.provider); -} -#endif - -Ort::SessionOptions GetSessionOptions( - const SpeakerEmbeddingExtractorConfig &config) { - return GetSessionOptionsImpl(config.num_threads, config.provider); -} - -Ort::SessionOptions GetSessionOptions( - const SpokenLanguageIdentificationConfig &config) { - return GetSessionOptionsImpl(config.num_threads, config.provider); -} - -Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config) { - return GetSessionOptionsImpl(config.num_threads, config.provider); -} - -Ort::SessionOptions GetSessionOptions( - const OfflinePunctuationModelConfig &config) { - return GetSessionOptionsImpl(config.num_threads, config.provider); -} - -Ort::SessionOptions GetSessionOptions( - const 
OnlinePunctuationModelConfig &config) { - return GetSessionOptionsImpl(config.num_threads, config.provider); -} - } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/session.h b/sherpa-onnx/csrc/session.h index 1e8beb114..e19db6c20 100644 --- a/sherpa-onnx/csrc/session.h +++ b/sherpa-onnx/csrc/session.h @@ -8,53 +8,28 @@ #include #include "onnxruntime_cxx_api.h" // NOLINT -#include "sherpa-onnx/csrc/audio-tagging-model-config.h" #include "sherpa-onnx/csrc/offline-lm-config.h" -#include "sherpa-onnx/csrc/offline-model-config.h" -#include "sherpa-onnx/csrc/offline-punctuation-model-config.h" -#include "sherpa-onnx/csrc/online-punctuation-model-config.h" #include "sherpa-onnx/csrc/online-lm-config.h" #include "sherpa-onnx/csrc/online-model-config.h" -#include "sherpa-onnx/csrc/speaker-embedding-extractor.h" -#include "sherpa-onnx/csrc/spoken-language-identification.h" -#include "sherpa-onnx/csrc/vad-model-config.h" - -#if SHERPA_ONNX_ENABLE_TTS -#include "sherpa-onnx/csrc/offline-tts-model-config.h" -#endif namespace sherpa_onnx { -Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config); - -Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config, - const std::string &model_type); - -Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config); +Ort::SessionOptions GetSessionOptionsImpl( + int32_t num_threads, const std::string &provider_str, + const ProviderConfig *provider_config = nullptr); Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config); - Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config); -Ort::SessionOptions GetSessionOptions(const VadModelConfig &config); - -#if SHERPA_ONNX_ENABLE_TTS -Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config); -#endif - -Ort::SessionOptions GetSessionOptions( - const SpeakerEmbeddingExtractorConfig &config); - -Ort::SessionOptions GetSessionOptions( - const SpokenLanguageIdentificationConfig &config); - -Ort::SessionOptions 
GetSessionOptions(const AudioTaggingModelConfig &config); +Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config); -Ort::SessionOptions GetSessionOptions( - const OfflinePunctuationModelConfig &config); +Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config, + const std::string &model_type); -Ort::SessionOptions GetSessionOptions( - const OnlinePunctuationModelConfig &config); +template +Ort::SessionOptions GetSessionOptions(const T &config) { + return GetSessionOptionsImpl(config.num_threads, config.provider); +} } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc index b084a4b1f..db921000c 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc @@ -13,7 +13,14 @@ Usage example: sherpa_onnx::OfflineSpeakerDiarizationConfig config; sherpa_onnx::ParseOptions po(kUsageMessage); config.Register(&po); - po.PrintUsage(); po.Read(argc, argv); + std::cout << config.ToString() << "\n"; + + if (!config.Validate()) { + po.PrintUsage(); + std::cerr << "Errors in config!\n"; + exit(-1); + } + sherpa_onnx::OfflineSpeakerDiarization sd(config); } diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor.cc b/sherpa-onnx/csrc/speaker-embedding-extractor.cc index 1c99de1a0..d90b0b1e0 100644 --- a/sherpa-onnx/csrc/speaker-embedding-extractor.cc +++ b/sherpa-onnx/csrc/speaker-embedding-extractor.cc @@ -26,12 +26,12 @@ void SpeakerEmbeddingExtractorConfig::Register(ParseOptions *po) { bool SpeakerEmbeddingExtractorConfig::Validate() const { if (model.empty()) { - SHERPA_ONNX_LOGE("Please provide --model"); + SHERPA_ONNX_LOGE("Please provide a speaker embedding extractor model"); return false; } if (!FileExists(model)) { - SHERPA_ONNX_LOGE("--speaker-embedding-model: '%s' does not exist", + SHERPA_ONNX_LOGE("speaker embedding extractor model: '%s' does not exist", 
model.c_str()); return false; } From d0e0f9993b012c545ae29e379b4931ccf65fe017 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 8 Oct 2024 11:49:41 +0800 Subject: [PATCH 04/17] Run the speaker segmentation model --- sherpa-onnx/csrc/macros.h | 1 + .../csrc/offline-speaker-diarization-impl.h | 4 +- ...ffline-speaker-diarization-pyannote-impl.h | 124 +++++++++++++++++- .../csrc/offline-speaker-diarization.cc | 5 +- .../csrc/offline-speaker-diarization.h | 3 +- ...ine-speaker-segmentation-pyannote-model.cc | 20 ++- ...line-speaker-segmentation-pyannote-model.h | 5 +- ...sherpa-onnx-offline-speaker-diarization.cc | 56 +++++++- 8 files changed, 207 insertions(+), 11 deletions(-) diff --git a/sherpa-onnx/csrc/macros.h b/sherpa-onnx/csrc/macros.h index e5966fe41..6bd6f62a6 100644 --- a/sherpa-onnx/csrc/macros.h +++ b/sherpa-onnx/csrc/macros.h @@ -5,6 +5,7 @@ #ifndef SHERPA_ONNX_CSRC_MACROS_H_ #define SHERPA_ONNX_CSRC_MACROS_H_ #include +#include #if __ANDROID_API__ >= 8 #include "android/log.h" diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-impl.h b/sherpa-onnx/csrc/offline-speaker-diarization-impl.h index 6e7ac47c3..5a4579efb 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization-impl.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization-impl.h @@ -6,6 +6,7 @@ #define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_IMPL_H_ #include +#include #include "sherpa-onnx/csrc/offline-speaker-diarization.h" namespace sherpa_onnx { @@ -19,7 +20,8 @@ class OfflineSpeakerDiarizationImpl { virtual OfflineSpeakerDiarizationResult Process( const float *audio, int32_t n, - OfflineSpeakerDiarizationProgressCallback callback = nullptr) const = 0; + OfflineSpeakerDiarizationProgressCallback callback = nullptr, + void *callback_arg = nullptr) const = 0; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h index 186f6ea1d..c08c62ad1 100644 --- 
a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h @@ -4,10 +4,19 @@ #ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_ #define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_ +#include +#include +#include + +#include "Eigen/Dense" #include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h" #include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h" namespace sherpa_onnx { + +using Matrix2D = + Eigen::Matrix; + class OfflineSpeakerDiarizationPyannoteImpl : public OfflineSpeakerDiarizationImpl { public: @@ -19,11 +28,122 @@ class OfflineSpeakerDiarizationPyannoteImpl OfflineSpeakerDiarizationResult Process( const float *audio, int32_t n, - OfflineSpeakerDiarizationProgressCallback callback = - nullptr) const override { + OfflineSpeakerDiarizationProgressCallback callback = nullptr, + void *callback_arg = nullptr) const override { + std::vector segmentations = + RunSpeakerSegmentationModel(audio, n, callback, callback_arg); + // segmentations[i] is for chunk_i + // Each matrix is of shape (num_frames, num_powerset_classes) + if (segmentations.empty()) { + return {}; + } + + std::cout << "segmentations.size() " << segmentations.size() << "\n"; + for (const auto &m : segmentations) { + std::cout << m.rows() << ", " << m.cols() << "\n"; + } + return {}; } + private: + std::vector RunSpeakerSegmentationModel( + const float *audio, int32_t n, + OfflineSpeakerDiarizationProgressCallback callback, + void *callback_arg) const { + std::vector ans; + + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + const auto &meta_data = segmentation_model_.GetModelMetaData(); + int32_t window_size = meta_data.window_size; + int32_t window_shift = meta_data.window_shift; + + if (n <= 0) { + SHERPA_ONNX_LOGE( + "number of audio samples is %d (<= 0). 
Please provide a positive " + "number", + n); + return {}; + } + + if (n <= window_size) { + std::vector buf(window_size); + std::copy(audio, audio + n, buf.data()); + + std::array shape = {1, 1, window_size}; + + Ort::Value x = Ort::Value::CreateTensor( + memory_info, buf.data(), buf.size(), shape.data(), shape.size()); + + Ort::Value out = segmentation_model_.Forward(std::move(x)); + std::vector out_shape = + out.GetTensorTypeAndShapeInfo().GetShape(); + Matrix2D m(out_shape[1], out_shape[2]); + std::copy(out.GetTensorData(), + out.GetTensorData() + m.size(), &m(0, 0)); + + ans.push_back(std::move(m)); + + if (callback) { + callback(1, 1, callback_arg); + } + + return ans; + } + + int32_t num_chunks = (n - window_size) / window_shift + 1; + bool has_last_chunk = (n - window_size) % window_shift > 0; + + ans.reserve(num_chunks + has_last_chunk); + + std::array shape = {1, 1, window_size}; + + const float *p = audio; + for (int32_t i = 0; i != num_chunks; ++i, p += window_shift) { + Ort::Value x = + Ort::Value::CreateTensor(memory_info, const_cast(p), + window_size, shape.data(), shape.size()); + + Ort::Value out = segmentation_model_.Forward(std::move(x)); + std::vector out_shape = + out.GetTensorTypeAndShapeInfo().GetShape(); + + Matrix2D m(out_shape[1], out_shape[2]); + std::copy(out.GetTensorData(), + out.GetTensorData() + m.size(), &m(0, 0)); + + ans.push_back(std::move(m)); + + if (callback) { + callback(i + 1, num_chunks + has_last_chunk, callback_arg); + } + } + + if (has_last_chunk) { + std::vector buf(window_size); + std::copy(p, audio + n, buf.data()); + + Ort::Value x = Ort::Value::CreateTensor( + memory_info, buf.data(), buf.size(), shape.data(), shape.size()); + + Ort::Value out = segmentation_model_.Forward(std::move(x)); + std::vector out_shape = + out.GetTensorTypeAndShapeInfo().GetShape(); + Matrix2D m(out_shape[1], out_shape[2]); + std::copy(out.GetTensorData(), + out.GetTensorData() + m.size(), &m(0, 0)); + + ans.push_back(std::move(m)); + if 
(callback) { + callback(num_chunks + 1, num_chunks + 1, callback_arg); + } + } + + return ans; + } + private: OfflineSpeakerDiarizationConfig config_; OfflineSpeakerSegmentationPyannoteModel segmentation_model_; diff --git a/sherpa-onnx/csrc/offline-speaker-diarization.cc b/sherpa-onnx/csrc/offline-speaker-diarization.cc index bc992a295..f7595e482 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization.cc +++ b/sherpa-onnx/csrc/offline-speaker-diarization.cc @@ -48,8 +48,9 @@ OfflineSpeakerDiarization::~OfflineSpeakerDiarization() = default; OfflineSpeakerDiarizationResult OfflineSpeakerDiarization::Process( const float *audio, int32_t n, - OfflineSpeakerDiarizationProgressCallback callback /*= nullptr*/) const { - return impl_->Process(audio, n, callback); + OfflineSpeakerDiarizationProgressCallback callback /*= nullptr*/, + void *callback_arg /*= nullptr*/) const { + return impl_->Process(audio, n, callback, callback_arg); } } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-speaker-diarization.h b/sherpa-onnx/csrc/offline-speaker-diarization.h index 1c4b88a69..c7ff685a5 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization.h @@ -43,7 +43,8 @@ class OfflineSpeakerDiarization { OfflineSpeakerDiarizationResult Process( const float *audio, int32_t n, - OfflineSpeakerDiarizationProgressCallback callback = nullptr) const; + OfflineSpeakerDiarizationProgressCallback callback = nullptr, + void *callback_arg = nullptr) const; private: std::unique_ptr impl_; diff --git a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc index 620f22b8e..3f3323698 100644 --- a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc +++ b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc @@ -24,7 +24,17 @@ class OfflineSpeakerSegmentationPyannoteModel::Impl { Init(buf.data(), buf.size()); } - Ort::Value 
Forward(Ort::Value x) { return Ort::Value(nullptr); } + const OfflineSpeakerSegmentationPyannoteModelMetaData &GetModelMetaData() + const { + return meta_data_; + } + + Ort::Value Forward(Ort::Value x) { + auto out = sess_->Run({}, input_names_ptr_.data(), &x, 1, + output_names_ptr_.data(), output_names_ptr_.size()); + + return std::move(out[0]); + } private: void Init(void *model_data, size_t model_data_length) { @@ -85,7 +95,13 @@ OfflineSpeakerSegmentationPyannoteModel:: OfflineSpeakerSegmentationPyannoteModel:: ~OfflineSpeakerSegmentationPyannoteModel() = default; -Ort::Value OfflineSpeakerSegmentationPyannoteModel::Forward(Ort::Value x) { +const OfflineSpeakerSegmentationPyannoteModelMetaData & +OfflineSpeakerSegmentationPyannoteModel::GetModelMetaData() const { + return impl_->GetModelMetaData(); +} + +Ort::Value OfflineSpeakerSegmentationPyannoteModel::Forward( + Ort::Value x) const { return impl_->Forward(std::move(x)); } diff --git a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h index 7ea8f44c5..b504c373f 100644 --- a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h +++ b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h @@ -19,7 +19,8 @@ class OfflineSpeakerSegmentationPyannoteModel { ~OfflineSpeakerSegmentationPyannoteModel(); - const OfflineSpeakerSegmentationPyannoteModelMetaData &GetMetaData() const; + const OfflineSpeakerSegmentationPyannoteModelMetaData &GetModelMetaData() + const; /** * @param x A 3-D float tensor of shape (batch_size, 1, num_samples) @@ -27,7 +28,7 @@ class OfflineSpeakerSegmentationPyannoteModel { * shape (batch_size, num_frames, num_speakers). Note that * num_speakers here uses powerset encoding. 
*/ - Ort::Value Forward(Ort::Value x); + Ort::Value Forward(Ort::Value x) const; private: class Impl; diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc index db921000c..af95ea002 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc @@ -3,6 +3,17 @@ // Copyright (c) 2024 Xiaomi Corporation #include "sherpa-onnx/csrc/offline-speaker-diarization.h" +#include "sherpa-onnx/csrc/parse-options.h" +#include "sherpa-onnx/csrc/wave-reader.h" + +static int32_t ProgressCallback(int32_t processed_chunks, int32_t num_chunks, + void *arg) { + float progress = 100.0 * processed_chunks / num_chunks; + fprintf(stderr, "progress %.2f%%\n", progress); + + // the return value is currently ignored + return 0; +} int main(int32_t argc, char *argv[]) { const char *kUsageMessage = R"usage( @@ -20,7 +31,50 @@ Usage example: if (!config.Validate()) { po.PrintUsage(); std::cerr << "Errors in config!\n"; - exit(-1); + return -1; + } + + if (po.NumArgs() != 1) { + std::cerr << "Error: Please provide exactly 1 wave file.\n\n"; + po.PrintUsage(); + return -1; } + sherpa_onnx::OfflineSpeakerDiarization sd(config); + + std::cout << "Started\n"; + const auto begin = std::chrono::steady_clock::now(); + const std::string wav_filename = po.GetArg(1); + int32_t sample_rate = -1; + bool is_ok = false; + const std::vector samples = + sherpa_onnx::ReadWave(wav_filename, &sample_rate, &is_ok); + if (!is_ok) { + std::cerr << "Failed to read " << wav_filename.c_str() << "\n"; + return -1; + } + + if (sample_rate != 16000) { + std::cerr << "Expect sample rate 16000. Given: " << sample_rate << "\n"; + return -1; + } + + float duration = samples.size() / static_cast(sample_rate); + + // sd.Process(samples.data(), samples.size() < 160000 ? 
samples.size() : + // 160000); + sd.Process(samples.data(), samples.size(), ProgressCallback, nullptr); + + const auto end = std::chrono::steady_clock::now(); + float elapsed_seconds = + std::chrono::duration_cast(end - begin) + .count() / + 1000.; + + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); + float rtf = elapsed_seconds / duration; + fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n", + elapsed_seconds, duration, rtf); + + return 0; } From b9bcbfaf996af5e4b7cefa7480ca9b396632c886 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 8 Oct 2024 12:00:05 +0800 Subject: [PATCH 05/17] refactoring --- ...s-spotter-buffered-tokens-keywords-c-api.c | 2 +- .../streaming-ctc-buffered-tokens-c-api.c | 2 +- ...reaming-paraformer-buffered-tokens-c-api.c | 2 +- ...zipformer-buffered-tokens-hotwords-c-api.c | 2 +- scripts/pyannote/segmentation/README.md | 2 +- ...ffline-speaker-diarization-pyannote-impl.h | 63 ++++++++----------- 6 files changed, 32 insertions(+), 41 deletions(-) diff --git a/c-api-examples/keywords-spotter-buffered-tokens-keywords-c-api.c b/c-api-examples/keywords-spotter-buffered-tokens-keywords-c-api.c index ec8be3b07..45a0bb87a 100644 --- a/c-api-examples/keywords-spotter-buffered-tokens-keywords-c-api.c +++ b/c-api-examples/keywords-spotter-buffered-tokens-keywords-c-api.c @@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) { fprintf(stderr, "Memory error\n"); return -1; } - size_t read_bytes = fread(*buffer_out, 1, size, file); + size_t read_bytes = fread((void *)*buffer_out, 1, size, file); if (read_bytes != size) { printf("Errors occured in reading the file %s\n", filename); free((void *)*buffer_out); diff --git a/c-api-examples/streaming-ctc-buffered-tokens-c-api.c b/c-api-examples/streaming-ctc-buffered-tokens-c-api.c index 3223772a8..33690e008 100644 --- a/c-api-examples/streaming-ctc-buffered-tokens-c-api.c +++ b/c-api-examples/streaming-ctc-buffered-tokens-c-api.c @@ -36,7 
+36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) { fprintf(stderr, "Memory error\n"); return -1; } - size_t read_bytes = fread(*buffer_out, 1, size, file); + size_t read_bytes = fread((void *)*buffer_out, 1, size, file); if (read_bytes != size) { printf("Errors occured in reading the file %s\n", filename); free((void *)*buffer_out); diff --git a/c-api-examples/streaming-paraformer-buffered-tokens-c-api.c b/c-api-examples/streaming-paraformer-buffered-tokens-c-api.c index cd87177b5..a597374df 100644 --- a/c-api-examples/streaming-paraformer-buffered-tokens-c-api.c +++ b/c-api-examples/streaming-paraformer-buffered-tokens-c-api.c @@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) { fprintf(stderr, "Memory error\n"); return -1; } - size_t read_bytes = fread(*buffer_out, 1, size, file); + size_t read_bytes = fread((void *)*buffer_out, 1, size, file); if (read_bytes != size) { printf("Errors occured in reading the file %s\n", filename); free((void *)*buffer_out); diff --git a/c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c b/c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c index d5092c5cc..c991d4999 100644 --- a/c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c +++ b/c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c @@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) { fprintf(stderr, "Memory error\n"); return -1; } - size_t read_bytes = fread(*buffer_out, 1, size, file); + size_t read_bytes = fread((void *)*buffer_out, 1, size, file); if (read_bytes != size) { printf("Errors occured in reading the file %s\n", filename); free((void *)*buffer_out); diff --git a/scripts/pyannote/segmentation/README.md b/scripts/pyannote/segmentation/README.md index a2e35b2de..689b7a9be 100644 --- a/scripts/pyannote/segmentation/README.md +++ b/scripts/pyannote/segmentation/README.md @@ -40,5 +40,5 @@ commands to 
convert it to `3-two-speakers-en.wav` ```bash -sox ML16091-Audio.mp3 3-two-speakers-en.wav +sox ML16091-Audio.mp3 -r 16k 3-two-speakers-en.wav ``` diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h index c08c62ad1..6818e2113 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h @@ -53,9 +53,6 @@ class OfflineSpeakerDiarizationPyannoteImpl void *callback_arg) const { std::vector ans; - auto memory_info = - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - const auto &meta_data = segmentation_model_.GetModelMetaData(); int32_t window_size = meta_data.window_size; int32_t window_shift = meta_data.window_shift; @@ -70,19 +67,11 @@ class OfflineSpeakerDiarizationPyannoteImpl if (n <= window_size) { std::vector buf(window_size); - std::copy(audio, audio + n, buf.data()); - - std::array shape = {1, 1, window_size}; + // NOTE: buf is zero initialized by default - Ort::Value x = Ort::Value::CreateTensor( - memory_info, buf.data(), buf.size(), shape.data(), shape.size()); + std::copy(audio, audio + n, buf.data()); - Ort::Value out = segmentation_model_.Forward(std::move(x)); - std::vector out_shape = - out.GetTensorTypeAndShapeInfo().GetShape(); - Matrix2D m(out_shape[1], out_shape[2]); - std::copy(out.GetTensorData(), - out.GetTensorData() + m.size(), &m(0, 0)); + Matrix2D m = ProcessChunk(buf.data()); ans.push_back(std::move(m)); @@ -98,21 +87,10 @@ class OfflineSpeakerDiarizationPyannoteImpl ans.reserve(num_chunks + has_last_chunk); - std::array shape = {1, 1, window_size}; - const float *p = audio; - for (int32_t i = 0; i != num_chunks; ++i, p += window_shift) { - Ort::Value x = - Ort::Value::CreateTensor(memory_info, const_cast(p), - window_size, shape.data(), shape.size()); - Ort::Value out = segmentation_model_.Forward(std::move(x)); - std::vector out_shape = - 
out.GetTensorTypeAndShapeInfo().GetShape(); - - Matrix2D m(out_shape[1], out_shape[2]); - std::copy(out.GetTensorData(), - out.GetTensorData() + m.size(), &m(0, 0)); + for (int32_t i = 0; i != num_chunks; ++i, p += window_shift) { + Matrix2D m = ProcessChunk(p); ans.push_back(std::move(m)); @@ -125,15 +103,7 @@ class OfflineSpeakerDiarizationPyannoteImpl std::vector buf(window_size); std::copy(p, audio + n, buf.data()); - Ort::Value x = Ort::Value::CreateTensor( - memory_info, buf.data(), buf.size(), shape.data(), shape.size()); - - Ort::Value out = segmentation_model_.Forward(std::move(x)); - std::vector out_shape = - out.GetTensorTypeAndShapeInfo().GetShape(); - Matrix2D m(out_shape[1], out_shape[2]); - std::copy(out.GetTensorData(), - out.GetTensorData() + m.size(), &m(0, 0)); + Matrix2D m = ProcessChunk(buf.data()); ans.push_back(std::move(m)); if (callback) { @@ -144,6 +114,27 @@ class OfflineSpeakerDiarizationPyannoteImpl return ans; } + Matrix2D ProcessChunk(const float *p) const { + const auto &meta_data = segmentation_model_.GetModelMetaData(); + int32_t window_size = meta_data.window_size; + + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + std::array shape = {1, 1, window_size}; + + Ort::Value x = + Ort::Value::CreateTensor(memory_info, const_cast(p), + window_size, shape.data(), shape.size()); + + Ort::Value out = segmentation_model_.Forward(std::move(x)); + std::vector out_shape = out.GetTensorTypeAndShapeInfo().GetShape(); + Matrix2D m(out_shape[1], out_shape[2]); + std::copy(out.GetTensorData(), out.GetTensorData() + m.size(), + &m(0, 0)); + return m; + } + private: OfflineSpeakerDiarizationConfig config_; OfflineSpeakerSegmentationPyannoteModel segmentation_model_; From 1cf26d461e51e251ad448eb46a9f00e1981bc56e Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 8 Oct 2024 14:55:55 +0800 Subject: [PATCH 06/17] Convert segmentations to multi-labels. 
--- ...ffline-speaker-diarization-pyannote-impl.h | 71 ++++++++++++++++++- 1 file changed, 68 insertions(+), 3 deletions(-) diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h index 6818e2113..8c6aed3e7 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h @@ -17,6 +17,9 @@ namespace sherpa_onnx { using Matrix2D = Eigen::Matrix; +using Matrix2DInt32 = + Eigen::Matrix; + class OfflineSpeakerDiarizationPyannoteImpl : public OfflineSpeakerDiarizationImpl { public: @@ -24,7 +27,9 @@ class OfflineSpeakerDiarizationPyannoteImpl explicit OfflineSpeakerDiarizationPyannoteImpl( const OfflineSpeakerDiarizationConfig &config) - : config_(config), segmentation_model_(config_.segmentation) {} + : config_(config), segmentation_model_(config_.segmentation) { + Init(); + } OfflineSpeakerDiarizationResult Process( const float *audio, int32_t n, @@ -38,15 +43,59 @@ class OfflineSpeakerDiarizationPyannoteImpl return {}; } - std::cout << "segmentations.size() " << segmentations.size() << "\n"; + std::cout << "segmentations.size() " << segmentations.size() << "---" + << segmentations[0].rows() << ", " << segmentations[1].cols() + << "\n"; + + std::vector labels; + labels.reserve(segmentations.size()); + for (const auto &m : segmentations) { - std::cout << m.rows() << ", " << m.cols() << "\n"; + labels.push_back(ToMultiLabel(m)); } + segmentations.clear(); + + // labels[i] is a 0-1 matrix of shape (num_frames, num_speakers) + return {}; } private: + void Init() { InitPowersetMapping(); } + + // see also + // https://github.com/pyannote/pyannote-audio/blob/develop/pyannote/audio/utils/powerset.py#L68 + void InitPowersetMapping() { + const auto &meta_data = segmentation_model_.GetModelMetaData(); + int32_t num_classes = meta_data.num_classes; + int32_t powerset_max_classes = meta_data.powerset_max_classes; + int32_t 
num_speakers = meta_data.num_speakers; + + powerset_mapping_ = Matrix2DInt32(num_classes, num_speakers); + powerset_mapping_.setZero(); + + int32_t k = 1; + for (int32_t i = 1; i <= powerset_max_classes; ++i) { + if (i == 1) { + for (int32_t j = 0; j != num_speakers; ++j, ++k) { + powerset_mapping_(k, j) = 1; + } + } else if (i == 2) { + for (int32_t j = 0; j != num_speakers; ++j) { + for (int32_t m = j + 1; m < num_speakers; ++m, ++k) { + powerset_mapping_(k, j) = 1; + powerset_mapping_(k, m) = 1; + } + } + } else { + SHERPA_ONNX_LOGE( + "powerset_max_classes = %d is currently not supported!", i); + SHERPA_ONNX_EXIT(-1); + } + } + } + std::vector RunSpeakerSegmentationModel( const float *audio, int32_t n, OfflineSpeakerDiarizationProgressCallback callback, @@ -135,9 +184,25 @@ class OfflineSpeakerDiarizationPyannoteImpl return m; } + Matrix2DInt32 ToMultiLabel(const Matrix2D &m) const { + int32_t num_rows = m.rows(); + Matrix2DInt32 ans(num_rows, powerset_mapping_.cols()); + + std::ptrdiff_t col_id; + + for (int32_t i = 0; i != num_rows; ++i) { + m.row(i).maxCoeff(&col_id); + ans.row(i) = powerset_mapping_.row(col_id); + } + + std::cout << "sum labels: " << ans.colwise().sum() << "\n"; + return ans; + } + private: OfflineSpeakerDiarizationConfig config_; OfflineSpeakerSegmentationPyannoteModel segmentation_model_; + Matrix2DInt32 powerset_mapping_; }; } // namespace sherpa_onnx From 42e1eff6d061f3a213c9690c965c9e56707c19f7 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 8 Oct 2024 16:00:32 +0800 Subject: [PATCH 07/17] Compute speaker count per frame --- ...ffline-speaker-diarization-pyannote-impl.h | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h index 8c6aed3e7..3dd66bdc3 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h 
@@ -20,6 +20,9 @@ using Matrix2D = using Matrix2DInt32 = Eigen::Matrix; +using FloatRowVector = Eigen::Matrix; +using Int32RowVector = Eigen::Matrix; + class OfflineSpeakerDiarizationPyannoteImpl : public OfflineSpeakerDiarizationImpl { public: @@ -58,6 +61,16 @@ class OfflineSpeakerDiarizationPyannoteImpl // labels[i] is a 0-1 matrix of shape (num_frames, num_speakers) + // speaker count per frame + Int32RowVector speaker_count = ComputeSpeakerCount(labels); + std::cout << "speaker count: " << speaker_count.cast().sum() << ", " + << speaker_count.cast().mean() << "\n"; + + if (speaker_count.maxCoeff() == 0) { + SHERPA_ONNX_LOGE("No speakers found in the audio samples"); + return {}; + } + return {}; } @@ -199,6 +212,39 @@ class OfflineSpeakerDiarizationPyannoteImpl return ans; } + // See also + // https://github.com/pyannote/pyannote-audio/blob/develop/pyannote/audio/pipelines/utils/diarization.py#L122 + Int32RowVector ComputeSpeakerCount( + const std::vector &labels) const { + const auto &meta_data = segmentation_model_.GetModelMetaData(); + int32_t window_size = meta_data.window_size; + int32_t window_shift = meta_data.window_shift; + int32_t receptive_field_shift = meta_data.receptive_field_shift; + + int32_t num_chunks = labels.size(); + + int32_t num_frames = (window_size + (num_chunks - 1) * window_shift) / + receptive_field_shift + + 1; + + FloatRowVector count(num_frames); + FloatRowVector weight(num_frames); + count.setZero(); + weight.setZero(); + + for (int32_t i = 0; i != num_chunks; ++i) { + int32_t start = float(i) * window_shift / receptive_field_shift + 0.5; + + auto seq = Eigen::seqN(start, labels[i].rows()); + + count(seq).array() += labels[i].rowwise().sum().array().cast(); + + weight(seq).array() += 1; + } + + return ((count.array() / (weight.array() + 1e-12f)) + 0.5).cast(); + } + private: OfflineSpeakerDiarizationConfig config_; OfflineSpeakerSegmentationPyannoteModel segmentation_model_; From 9b4f6913176648f00e55b5e145740cca64bcc377 Mon Sep 
17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 8 Oct 2024 16:59:36 +0800 Subject: [PATCH 08/17] Get chunk-speaker pair and its sample indexes --- ...ffline-speaker-diarization-pyannote-impl.h | 112 ++++++++++++++++++ 1 file changed, 112 insertions(+) diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h index 3dd66bdc3..558010b8f 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h @@ -23,6 +23,8 @@ using Matrix2DInt32 = using FloatRowVector = Eigen::Matrix; using Int32RowVector = Eigen::Matrix; +using Int32Pair = std::pair; + class OfflineSpeakerDiarizationPyannoteImpl : public OfflineSpeakerDiarizationImpl { public: @@ -71,6 +73,18 @@ class OfflineSpeakerDiarizationPyannoteImpl return {}; } + auto chunk_speaker_samples_list_pair = GetChunkSpeakerSampleIndexes(labels); + std::cout << "pair size: " << chunk_speaker_samples_list_pair.first.size() + << "\n"; + int32_t kk = 0; + for (const auto &p : chunk_speaker_samples_list_pair.first) { + std::cout << p.first << ", " << p.second << "\n"; + for (const auto &pp : chunk_speaker_samples_list_pair.second[kk]) { + std::cout << " " << pp.first << ", " << pp.second << "\n"; + } + kk += 1; + } + return {}; } @@ -245,6 +259,104 @@ class OfflineSpeakerDiarizationPyannoteImpl return ((count.array() / (weight.array() + 1e-12f)) + 0.5).cast(); } + // ans.first: a list of (chunk_id, speaker_id) + // ans.second: a list of list of (start_sample_index, end_sample_index) + // + // ans.first[i] corresponds to ans.second[i] + std::pair, std::vector>> + GetChunkSpeakerSampleIndexes(const std::vector &labels) const { + auto new_labels = ExcludeOverlap(labels); + + std::vector chunk_speaker_list; + std::vector> samples_index_list; + + const auto &meta_data = segmentation_model_.GetModelMetaData(); + int32_t window_size = meta_data.window_size; + int32_t window_shift = 
meta_data.window_shift; + int32_t receptive_field_shift = meta_data.receptive_field_shift; + int32_t num_speakers = meta_data.num_speakers; + + int32_t chunk_index = 0; + for (const auto &label : new_labels) { + Matrix2DInt32 tmp = label.transpose(); + // tmp: (num_speakers, num_frames) + int32_t num_frames = tmp.cols(); + + int32_t sample_offset = chunk_index * window_shift; + + for (int32_t speaker_index = 0; speaker_index != num_speakers; + ++speaker_index) { + auto d = tmp.row(speaker_index); + if (d.sum() < 10) { + // skip segments less than 10 frames + continue; + } + + Int32Pair this_chunk_speaker = {chunk_index, speaker_index}; + std::vector this_speaker_samples; + + bool started = false; + int32_t start_index; + + for (int32_t k = 0; k != num_frames; ++k) { + if (d[k] != 0) { + if (!started) { + started = true; + start_index = k; + } + } else if (started) { + started = false; + + int32_t start_samples = + float(start_index) / num_frames * window_size + sample_offset; + int32_t end_samples = + float(k) / num_frames * window_size + sample_offset; + + this_speaker_samples.emplace_back(start_samples, end_samples); + } + } + + if (started) { + int32_t start_samples = + float(start_index) / num_frames * window_size + sample_offset; + int32_t end_samples = + float(num_frames - 1) / num_frames * window_size + sample_offset; + this_speaker_samples.emplace_back(start_samples, end_samples); + } + + chunk_speaker_list.push_back(std::move(this_chunk_speaker)); + samples_index_list.push_back(std::move(this_speaker_samples)); + } // for (int32_t speaker_index = 0; + chunk_index += 1; + } // for (const auto &label : new_labels) + + return {chunk_speaker_list, samples_index_list}; + } + + // If there are multiple speakers at a frame, then this frame is excluded. 
+ std::vector ExcludeOverlap( + const std::vector &labels) const { + int32_t num_chunks = labels.size(); + std::vector ans; + ans.reserve(num_chunks); + + for (const auto &label : labels) { + Matrix2DInt32 new_label(label.rows(), label.cols()); + new_label.setZero(); + Int32RowVector v = label.rowwise().sum(); + + for (int32_t i = 0; i != v.cols(); ++i) { + if (v[i] < 2) { + new_label.row(i) = label.row(i); + } + } + + ans.push_back(std::move(new_label)); + } + + return ans; + } + private: OfflineSpeakerDiarizationConfig config_; OfflineSpeakerSegmentationPyannoteModel segmentation_model_; From b91a4be2b60d257087baa252d90eb637be69a415 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 8 Oct 2024 17:43:43 +0800 Subject: [PATCH 09/17] compute embeddings --- ...ffline-speaker-diarization-pyannote-impl.h | 108 ++++++++++++------ 1 file changed, 75 insertions(+), 33 deletions(-) diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h index 558010b8f..70788e00b 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h @@ -11,6 +11,7 @@ #include "Eigen/Dense" #include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h" #include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h" +#include "sherpa-onnx/csrc/speaker-embedding-extractor.h" namespace sherpa_onnx { @@ -32,7 +33,9 @@ class OfflineSpeakerDiarizationPyannoteImpl explicit OfflineSpeakerDiarizationPyannoteImpl( const OfflineSpeakerDiarizationConfig &config) - : config_(config), segmentation_model_(config_.segmentation) { + : config_(config), + segmentation_model_(config_.segmentation), + embedding_extractor_(config_.embedding) { Init(); } @@ -40,8 +43,7 @@ class OfflineSpeakerDiarizationPyannoteImpl const float *audio, int32_t n, OfflineSpeakerDiarizationProgressCallback callback = nullptr, void *callback_arg = nullptr) const override 
{ - std::vector segmentations = - RunSpeakerSegmentationModel(audio, n, callback, callback_arg); + std::vector segmentations = RunSpeakerSegmentationModel(audio, n); // segmentations[i] is for chunk_i // Each matrix is of shape (num_frames, num_powerset_classes) if (segmentations.empty()) { @@ -74,16 +76,13 @@ class OfflineSpeakerDiarizationPyannoteImpl } auto chunk_speaker_samples_list_pair = GetChunkSpeakerSampleIndexes(labels); - std::cout << "pair size: " << chunk_speaker_samples_list_pair.first.size() - << "\n"; - int32_t kk = 0; - for (const auto &p : chunk_speaker_samples_list_pair.first) { - std::cout << p.first << ", " << p.second << "\n"; - for (const auto &pp : chunk_speaker_samples_list_pair.second[kk]) { - std::cout << " " << pp.first << ", " << pp.second << "\n"; - } - kk += 1; - } + Matrix2D embeddings = + ComputeEmbeddings(audio, n, chunk_speaker_samples_list_pair.second, + callback, callback_arg); + + std::cout << "embeddings.shape " << embeddings.rows() << ", " + << embeddings.cols() << "\n" + << embeddings.rowwise().sum() << "\n"; return {}; } @@ -123,10 +122,8 @@ class OfflineSpeakerDiarizationPyannoteImpl } } - std::vector RunSpeakerSegmentationModel( - const float *audio, int32_t n, - OfflineSpeakerDiarizationProgressCallback callback, - void *callback_arg) const { + std::vector RunSpeakerSegmentationModel(const float *audio, + int32_t n) const { std::vector ans; const auto &meta_data = segmentation_model_.GetModelMetaData(); @@ -151,10 +148,6 @@ class OfflineSpeakerDiarizationPyannoteImpl ans.push_back(std::move(m)); - if (callback) { - callback(1, 1, callback_arg); - } - return ans; } @@ -169,10 +162,6 @@ class OfflineSpeakerDiarizationPyannoteImpl Matrix2D m = ProcessChunk(p); ans.push_back(std::move(m)); - - if (callback) { - callback(i + 1, num_chunks + has_last_chunk, callback_arg); - } } if (has_last_chunk) { @@ -182,9 +171,6 @@ class OfflineSpeakerDiarizationPyannoteImpl Matrix2D m = ProcessChunk(buf.data()); 
ans.push_back(std::move(m)); - if (callback) { - callback(num_chunks + 1, num_chunks + 1, callback_arg); - } } return ans; @@ -247,7 +233,8 @@ class OfflineSpeakerDiarizationPyannoteImpl weight.setZero(); for (int32_t i = 0; i != num_chunks; ++i) { - int32_t start = float(i) * window_shift / receptive_field_shift + 0.5; + int32_t start = + static_cast(i) * window_shift / receptive_field_shift + 0.5; auto seq = Eigen::seqN(start, labels[i].rows()); @@ -308,9 +295,11 @@ class OfflineSpeakerDiarizationPyannoteImpl started = false; int32_t start_samples = - float(start_index) / num_frames * window_size + sample_offset; + static_cast(start_index) / num_frames * window_size + + sample_offset; int32_t end_samples = - float(k) / num_frames * window_size + sample_offset; + static_cast(k) / num_frames * window_size + + sample_offset; this_speaker_samples.emplace_back(start_samples, end_samples); } @@ -318,9 +307,11 @@ class OfflineSpeakerDiarizationPyannoteImpl if (started) { int32_t start_samples = - float(start_index) / num_frames * window_size + sample_offset; + static_cast(start_index) / num_frames * window_size + + sample_offset; int32_t end_samples = - float(num_frames - 1) / num_frames * window_size + sample_offset; + static_cast(num_frames - 1) / num_frames * window_size + + sample_offset; this_speaker_samples.emplace_back(start_samples, end_samples); } @@ -357,9 +348,60 @@ class OfflineSpeakerDiarizationPyannoteImpl return ans; } + /** + * @param sample_indexes[i] contains the sample segment start and end indexes + * for the i-th (chunk, speaker) pair + * @return Return a matrix of shape (sample_indexes.size(), embedding_dim) + * where ans.row[i] contains the embedding for the + * i-th (chunk, speaker) pair + */ + Matrix2D ComputeEmbeddings( + const float *audio, int32_t n, + const std::vector> &sample_indexes, + OfflineSpeakerDiarizationProgressCallback callback, + void *callback_arg) const { + const auto &meta_data = segmentation_model_.GetModelMetaData(); + 
int32_t sample_rate = meta_data.sample_rate; + Matrix2D ans(sample_indexes.size(), embedding_extractor_.Dim()); + + int32_t k = 0; + for (const auto &v : sample_indexes) { + auto stream = embedding_extractor_.CreateStream(); + for (const auto &p : v) { + int32_t end = (p.second <= n) ? p.second : n; + int32_t num_samples = end - p.first; + + if (num_samples > 0) { + stream->AcceptWaveform(sample_rate, audio + p.first, num_samples); + } + } + + stream->InputFinished(); + if (!embedding_extractor_.IsReady(stream.get())) { + SHERPA_ONNX_LOGE( + "This segment is too short, which should not happen since we have " + "already filtered short segments"); + SHERPA_ONNX_EXIT(-1); + } + + std::vector embedding = embedding_extractor_.Compute(stream.get()); + + std::copy(embedding.begin(), embedding.end(), &ans(k, 0)); + + k += 1; + + if (callback) { + callback(k, ans.rows(), callback_arg); + } + } + + return ans; + } + private: OfflineSpeakerDiarizationConfig config_; OfflineSpeakerSegmentationPyannoteModel segmentation_model_; + SpeakerEmbeddingExtractor embedding_extractor_; Matrix2DInt32 powerset_mapping_; }; From 9931a26383a709582b8be1263895c4870afe0160 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 8 Oct 2024 18:04:54 +0800 Subject: [PATCH 10/17] Add clustering --- sherpa-onnx/csrc/fast-clustering-config.cc | 22 +++++++++---------- ...ffline-speaker-diarization-pyannote-impl.h | 14 ++++++++---- .../csrc/offline-speaker-diarization.cc | 10 ++++++++- .../csrc/offline-speaker-diarization.h | 9 +++++++- 4 files changed, 37 insertions(+), 18 deletions(-) diff --git a/sherpa-onnx/csrc/fast-clustering-config.cc b/sherpa-onnx/csrc/fast-clustering-config.cc index e8382e598..e4f64fbbb 100644 --- a/sherpa-onnx/csrc/fast-clustering-config.cc +++ b/sherpa-onnx/csrc/fast-clustering-config.cc @@ -21,18 +21,16 @@ std::string FastClusteringConfig::ToString() const { } void FastClusteringConfig::Register(ParseOptions *po) { - std::string prefix = "ctc"; - ParseOptions p(prefix, 
po); - - p.Register("num-clusters", &num_clusters, - "Number of cluster. If greater than 0, then --cluster-thresold is " - "ignored. Please provide it if you know the actual number of " - "clusters in advance."); - - p.Register("cluster-threshold", &threshold, - "If --num-clusters is not specified, then it specifies the " - "distance threshold for clustering. smaller value -> more " - "clusters. larger value -> fewer clusters"); + po->Register( + "num-clusters", &num_clusters, + "Number of cluster. If greater than 0, then cluster threshold is " + "ignored. Please provide it if you know the actual number of " + "clusters in advance."); + + po->Register("cluster-threshold", &threshold, + "If num_clusters is not specified, then it specifies the " + "distance threshold for clustering. smaller value -> more " + "clusters. larger value -> fewer clusters"); } bool FastClusteringConfig::Validate() const { diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h index 70788e00b..fc63f7f26 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h @@ -9,6 +9,7 @@ #include #include "Eigen/Dense" +#include "sherpa-onnx/csrc/fast-clustering.h" #include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h" #include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h" #include "sherpa-onnx/csrc/speaker-embedding-extractor.h" @@ -35,7 +36,8 @@ class OfflineSpeakerDiarizationPyannoteImpl const OfflineSpeakerDiarizationConfig &config) : config_(config), segmentation_model_(config_.segmentation), - embedding_extractor_(config_.embedding) { + embedding_extractor_(config_.embedding), + clustering_(config_.clustering) { Init(); } @@ -80,9 +82,12 @@ class OfflineSpeakerDiarizationPyannoteImpl ComputeEmbeddings(audio, n, chunk_speaker_samples_list_pair.second, callback, callback_arg); - std::cout << "embeddings.shape 
" << embeddings.rows() << ", " - << embeddings.cols() << "\n" - << embeddings.rowwise().sum() << "\n"; + std::vector cluster_labels = clustering_.Cluster( + &embeddings(0, 0), embeddings.rows(), embeddings.cols()); + + for (int32_t i = 0; i != cluster_labels.size(); ++i) { + std::cout << i << "->" << cluster_labels[i] << "\n"; + } return {}; } @@ -402,6 +407,7 @@ class OfflineSpeakerDiarizationPyannoteImpl OfflineSpeakerDiarizationConfig config_; OfflineSpeakerSegmentationPyannoteModel segmentation_model_; SpeakerEmbeddingExtractor embedding_extractor_; + FastClustering clustering_; Matrix2DInt32 powerset_mapping_; }; diff --git a/sherpa-onnx/csrc/offline-speaker-diarization.cc b/sherpa-onnx/csrc/offline-speaker-diarization.cc index f7595e482..df2426a10 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization.cc +++ b/sherpa-onnx/csrc/offline-speaker-diarization.cc @@ -16,6 +16,9 @@ void OfflineSpeakerDiarizationConfig::Register(ParseOptions *po) { ParseOptions po_embedding("embedding", po); embedding.Register(&po_embedding); + + ParseOptions po_clustering("clustering", po); + clustering.Register(&po_clustering); } bool OfflineSpeakerDiarizationConfig::Validate() const { @@ -27,6 +30,10 @@ bool OfflineSpeakerDiarizationConfig::Validate() const { return false; } + if (!clustering.Validate()) { + return false; + } + return true; } @@ -35,7 +42,8 @@ std::string OfflineSpeakerDiarizationConfig::ToString() const { os << "OfflineSpeakerDiarizationConfig("; os << "segmentation=" << segmentation.ToString() << ", "; - os << "embedding=" << embedding.ToString() << ")"; + os << "embedding=" << embedding.ToString() << ", "; + os << "clustering=" << clustering.ToString() << ")"; return os.str(); } diff --git a/sherpa-onnx/csrc/offline-speaker-diarization.h b/sherpa-onnx/csrc/offline-speaker-diarization.h index c7ff685a5..41968fe99 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization.h @@ -9,6 +9,7 @@ #include #include 
+#include "sherpa-onnx/csrc/fast-clustering-config.h" #include "sherpa-onnx/csrc/offline-speaker-diarization-result.h" #include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h" #include "sherpa-onnx/csrc/speaker-embedding-extractor.h" @@ -18,11 +19,17 @@ namespace sherpa_onnx { struct OfflineSpeakerDiarizationConfig { OfflineSpeakerSegmentationModelConfig segmentation; SpeakerEmbeddingExtractorConfig embedding; + FastClusteringConfig clustering; OfflineSpeakerDiarizationConfig() = default; + OfflineSpeakerDiarizationConfig( const OfflineSpeakerSegmentationModelConfig &segmentation, - const SpeakerEmbeddingExtractorConfig &embedding); + const SpeakerEmbeddingExtractorConfig &embedding, + const FastClusteringConfig &clustering) + : segmentation(segmentation), + embedding(embedding), + clustering(clustering) {} void Register(ParseOptions *po); bool Validate() const; From 06c0d45ad72894ab29e0bbb22d666b7cc776e0e0 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 8 Oct 2024 20:31:40 +0800 Subject: [PATCH 11/17] first working version --- ...ffline-speaker-diarization-pyannote-impl.h | 263 +++++++++++++++++- .../offline-speaker-diarization-result.cc | 39 +++ .../csrc/offline-speaker-diarization-result.h | 9 +- .../csrc/offline-speaker-diarization.cc | 13 +- .../csrc/offline-speaker-diarization.h | 8 + 5 files changed, 317 insertions(+), 15 deletions(-) diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h index fc63f7f26..8d94e4d21 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h @@ -5,17 +5,40 @@ #define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_ #include +#include #include #include #include "Eigen/Dense" #include "sherpa-onnx/csrc/fast-clustering.h" +#include "sherpa-onnx/csrc/math.h" #include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h" #include 
"sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h" #include "sherpa-onnx/csrc/speaker-embedding-extractor.h" namespace sherpa_onnx { +namespace { + +// copied from https://github.com/k2-fsa/k2/blob/master/k2/csrc/host/util.h#L41 +template +inline void hash_combine(std::size_t *seed, const T &v) { // NOLINT + std::hash hasher; + *seed ^= hasher(v) + 0x9e3779b9 + ((*seed) << 6) + ((*seed) >> 2); // NOLINT +} + +// copied from https://github.com/k2-fsa/k2/blob/master/k2/csrc/host/util.h#L47 +struct PairHash { + template + std::size_t operator()(const std::pair &pair) const { + std::size_t result = 0; + hash_combine(&result, pair.first); + hash_combine(&result, pair.second); + return result; + } +}; +} // namespace + using Matrix2D = Eigen::Matrix; @@ -68,11 +91,11 @@ class OfflineSpeakerDiarizationPyannoteImpl // labels[i] is a 0-1 matrix of shape (num_frames, num_speakers) // speaker count per frame - Int32RowVector speaker_count = ComputeSpeakerCount(labels); - std::cout << "speaker count: " << speaker_count.cast().sum() << ", " - << speaker_count.cast().mean() << "\n"; + Int32RowVector speakers_per_frame = ComputeSpeakersPerFrame(labels); + std::cout << "speaker count: " << speakers_per_frame.cast().sum() + << ", " << speakers_per_frame.cast().mean() << "\n"; - if (speaker_count.maxCoeff() == 0) { + if (speakers_per_frame.maxCoeff() == 0) { SHERPA_ONNX_LOGE("No speakers found in the audio samples"); return {}; } @@ -85,8 +108,24 @@ class OfflineSpeakerDiarizationPyannoteImpl std::vector cluster_labels = clustering_.Cluster( &embeddings(0, 0), embeddings.rows(), embeddings.cols()); - for (int32_t i = 0; i != cluster_labels.size(); ++i) { - std::cout << i << "->" << cluster_labels[i] << "\n"; + int32_t max_cluster_index = + *std::max_element(cluster_labels.begin(), cluster_labels.end()); + + auto chunk_speaker_to_cluster = ConvertChunkSpeakerToCluster( + chunk_speaker_samples_list_pair.first, cluster_labels); + + auto new_labels = + ReLabel(labels, 
max_cluster_index, chunk_speaker_to_cluster); + + Matrix2DInt32 speaker_count = ComputeSpeakerCount(new_labels, n); + + Matrix2DInt32 final_labels = + FinalizeLabels(speaker_count, speakers_per_frame); + + auto result = ComputeResult(final_labels); + + for (const auto &r : result.segments_) { + std::cout << r.ToString() << "\n"; } return {}; @@ -219,7 +258,7 @@ class OfflineSpeakerDiarizationPyannoteImpl // See also // https://github.com/pyannote/pyannote-audio/blob/develop/pyannote/audio/pipelines/utils/diarization.py#L122 - Int32RowVector ComputeSpeakerCount( + Int32RowVector ComputeSpeakersPerFrame( const std::vector &labels) const { const auto &meta_data = segmentation_model_.GetModelMetaData(); int32_t window_size = meta_data.window_size; @@ -287,17 +326,17 @@ class OfflineSpeakerDiarizationPyannoteImpl Int32Pair this_chunk_speaker = {chunk_index, speaker_index}; std::vector this_speaker_samples; - bool started = false; + bool is_active = false; int32_t start_index; for (int32_t k = 0; k != num_frames; ++k) { if (d[k] != 0) { - if (!started) { - started = true; + if (!is_active) { + is_active = true; start_index = k; } - } else if (started) { - started = false; + } else if (is_active) { + is_active = false; int32_t start_samples = static_cast(start_index) / num_frames * window_size + @@ -310,7 +349,7 @@ class OfflineSpeakerDiarizationPyannoteImpl } } - if (started) { + if (is_active) { int32_t start_samples = static_cast(start_index) / num_frames * window_size + sample_offset; @@ -403,6 +442,204 @@ class OfflineSpeakerDiarizationPyannoteImpl return ans; } + std::unordered_map ConvertChunkSpeakerToCluster( + const std::vector &chunk_speaker_pair, + const std::vector &cluster_labels) const { + std::unordered_map ans; + + int32_t k = 0; + for (const auto &p : chunk_speaker_pair) { + ans[p] = cluster_labels[k]; + k += 1; + } + + return ans; + } + + std::vector ReLabel( + const std::vector &labels, int32_t max_cluster_index, + std::unordered_map 
chunk_speaker_to_cluster) + const { + std::vector new_labels; + new_labels.reserve(labels.size()); + + int32_t chunk_index = 0; + for (const auto &label : labels) { + Matrix2DInt32 new_label(label.rows(), max_cluster_index + 1); + new_label.setZero(); + + Matrix2DInt32 t = label.transpose(); + // t: (num_speakers, num_frames) + + for (int32_t speaker_index = 0; speaker_index != t.rows(); + ++speaker_index) { + if (chunk_speaker_to_cluster.count({chunk_index, speaker_index}) == 0) { + continue; + } + + int32_t new_speaker_index = + chunk_speaker_to_cluster.at({chunk_index, speaker_index}); + + for (int32_t k = 0; k != t.cols(); ++k) { + if (t(speaker_index, k) == 1) { + new_label(k, new_speaker_index) = 1; + } + } + } + + std::cout << "chunk " << chunk_index << ", " << new_label.colwise().sum() + << "\n"; + + new_labels.push_back(std::move(new_label)); + + chunk_index += 1; + } + + return new_labels; + } + + Matrix2DInt32 ComputeSpeakerCount(const std::vector &labels, + int32_t num_samples) const { + const auto &meta_data = segmentation_model_.GetModelMetaData(); + int32_t window_size = meta_data.window_size; + int32_t window_shift = meta_data.window_shift; + int32_t receptive_field_shift = meta_data.receptive_field_shift; + + int32_t num_chunks = labels.size(); + + int32_t num_frames = (window_size + (num_chunks - 1) * window_shift) / + receptive_field_shift + + 1; + + Matrix2DInt32 count(num_frames, labels[0].cols()); + count.setZero(); + + for (int32_t i = 0; i != num_chunks; ++i) { + int32_t start = + static_cast(i) * window_shift / receptive_field_shift + 0.5; + + auto seq = Eigen::seqN(start, labels[i].rows()); + + count(seq, Eigen::all).array() += labels[i].array(); + } + + bool has_last_chunk = (num_samples - window_size) % window_shift > 0; + + if (has_last_chunk) { + return count; + } + + int32_t last_frame = num_samples / receptive_field_shift; + return count(Eigen::seq(0, last_frame), Eigen::all); + } + + Matrix2DInt32 FinalizeLabels(const Matrix2DInt32 
&count, + const Int32RowVector &speakers_per_frame) const { + int32_t num_rows = count.rows(); + int32_t num_cols = count.cols(); + + Matrix2DInt32 ans(num_rows, num_cols); + ans.setZero(); + + for (int32_t i = 0; i != num_rows; ++i) { + int32_t k = speakers_per_frame[i]; + if (k == 0) { + continue; + } + auto top_k = TopkIndex(&count(i, 0), num_cols, k); + + for (int32_t m : top_k) { + ans(i, m) = 1; + } + } + + return ans; + } + + OfflineSpeakerDiarizationResult ComputeResult( + const Matrix2DInt32 &final_labels) const { + Matrix2DInt32 final_labels_t = final_labels.transpose(); + int32_t num_speakers = final_labels_t.rows(); + int32_t num_frames = final_labels_t.cols(); + + const auto &meta_data = segmentation_model_.GetModelMetaData(); + int32_t window_size = meta_data.window_size; + int32_t window_shift = meta_data.window_shift; + int32_t receptive_field_shift = meta_data.receptive_field_shift; + int32_t receptive_field_size = meta_data.receptive_field_size; + int32_t sample_rate = meta_data.sample_rate; + + float scale = static_cast(receptive_field_shift) / sample_rate; + float scale_offset = 0.5 * receptive_field_size / sample_rate; + + OfflineSpeakerDiarizationResult ans; + + for (int32_t speaker_index = 0; speaker_index != num_speakers; + ++speaker_index) { + std::vector this_speaker; + + bool is_active = final_labels_t(speaker_index, 0) > 0; + int32_t start_index = is_active ? 
0 : -1; + + for (int32_t frame_index = 1; frame_index != num_frames; ++frame_index) { + if (is_active) { + if (final_labels_t(speaker_index, frame_index) == 0) { + float start_time = start_index * scale + scale_offset; + float end_time = frame_index * scale + scale_offset; + + OfflineSpeakerDiarizationSegment segment(start_time, end_time, + speaker_index); + this_speaker.push_back(segment); + + is_active = false; + } + } else if (final_labels_t(speaker_index, frame_index) == 1) { + is_active = true; + start_index = frame_index; + } + } + + if (is_active) { + float start_time = start_index * scale + scale_offset; + float end_time = (num_frames - 1) * scale + scale_offset; + + OfflineSpeakerDiarizationSegment segment(start_time, end_time, + speaker_index); + this_speaker.push_back(segment); + } + + // merge segments if the gap between them is less than min_duration_off + MergeSegments(&this_speaker); + + for (const auto &seg : this_speaker) { + if (seg.Duration() > config_.min_duration_on) { + ans.Add(seg); + } + } + } // for (int32_t speaker_index = 0; speaker_index != num_speakers; + + return ans; + } + + void MergeSegments( + std::vector *segments) const { + float min_duration_off = config_.min_duration_off; + bool changed = true; + while (changed) { + changed = false; + for (int32_t i = 0; i < static_cast(segments->size()) - 1; ++i) { + auto s = (*segments)[i].Merge((*segments)[i + 1], min_duration_off); + if (s) { + (*segments)[i] = s.value(); + segments->erase(segments->begin() + i + 1); + + changed = true; + break; + } + } + } + } + private: OfflineSpeakerDiarizationConfig config_; OfflineSpeakerSegmentationPyannoteModel segmentation_model_; diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-result.cc b/sherpa-onnx/csrc/offline-speaker-diarization-result.cc index c36d793ac..46399e451 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization-result.cc +++ b/sherpa-onnx/csrc/offline-speaker-diarization-result.cc @@ -4,6 +4,8 @@ #include 
"sherpa-onnx/csrc/offline-speaker-diarization-result.h" +#include + #include "sherpa-onnx/csrc/macros.h" namespace sherpa_onnx { @@ -20,4 +22,41 @@ OfflineSpeakerDiarizationSegment::OfflineSpeakerDiarizationSegment( speaker_ = speaker; } +std::optional +OfflineSpeakerDiarizationSegment::Merge( + const OfflineSpeakerDiarizationSegment &other, float gap) const { + if (other.speaker_ != speaker_) { + SHERPA_ONNX_LOGE( + "The two segments should have the same speaker. this->speaker: %d, " + "other.speaker: %d", + speaker_, other.speaker_); + return std::nullopt; + } + + if (end_ < other.start_ && end_ + gap >= other.start_) { + return OfflineSpeakerDiarizationSegment(start_, other.end_, speaker_); + } else if (other.end_ < start_ && other.end_ + gap >= start_) { + return OfflineSpeakerDiarizationSegment(other.start_, end_, speaker_); + } else { + return std::nullopt; + } +} + +std::string OfflineSpeakerDiarizationSegment::ToString() const { + char s[128]; + int32_t n = snprintf(s, sizeof(s), "%.3f -- %.3f speaker_%02d", start_, end_, + speaker_); + + return {&s[0]}; +} + +void OfflineSpeakerDiarizationResult::Add( + const OfflineSpeakerDiarizationSegment &segment) { + segments_.push_back(segment); +} + +int32_t OfflineSpeakerDiarizationResult::NumSegments() const { + return segments_.size(); +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-result.h b/sherpa-onnx/csrc/offline-speaker-diarization-result.h index 942ecf7e5..1037bf640 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization-result.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization-result.h @@ -7,6 +7,7 @@ #include #include +#include #include namespace sherpa_onnx { @@ -20,6 +21,10 @@ class OfflineSpeakerDiarizationSegment { std::optional Merge( const OfflineSpeakerDiarizationSegment &other, float gap) const; + float Duration() const { return end_ - start_; } + + std::string ToString() const; + private: float start_; // in seconds float end_; // in seconds @@ -34,6 
+39,8 @@ class OfflineSpeakerDiarizationResult { // Number of distinct speakers contained in this object at this point int32_t NumSpeakers() const; + int32_t NumSegments() const; + // Return a list of segments sorted by segment.start time std::vector SortByStartTime() const; @@ -42,7 +49,7 @@ class OfflineSpeakerDiarizationResult { std::vector> SortBySpeaker() const; - private: + public: std::vector segments_; }; diff --git a/sherpa-onnx/csrc/offline-speaker-diarization.cc b/sherpa-onnx/csrc/offline-speaker-diarization.cc index df2426a10..8a4a54cc8 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization.cc +++ b/sherpa-onnx/csrc/offline-speaker-diarization.cc @@ -19,6 +19,15 @@ void OfflineSpeakerDiarizationConfig::Register(ParseOptions *po) { ParseOptions po_clustering("clustering", po); clustering.Register(&po_clustering); + + po->Register("min-duration-on", &min_duration_on, + "if a segment is less than this value, then it is discarded. " + "Set it to 0 so that no segment is discarded"); + + po->Register("min-duration-off", &min_duration_off, + "if the gap between to segments of the same speaker is less " + "than this value, then these two segments are merged into a " + "single segment. 
We do it recursively."); } bool OfflineSpeakerDiarizationConfig::Validate() const { @@ -43,7 +52,9 @@ std::string OfflineSpeakerDiarizationConfig::ToString() const { os << "OfflineSpeakerDiarizationConfig("; os << "segmentation=" << segmentation.ToString() << ", "; os << "embedding=" << embedding.ToString() << ", "; - os << "clustering=" << clustering.ToString() << ")"; + os << "clustering=" << clustering.ToString() << ", "; + os << "min_duration_on=" << min_duration_on << ", "; + os << "min_duration_off=" << min_duration_off << ")"; return os.str(); } diff --git a/sherpa-onnx/csrc/offline-speaker-diarization.h b/sherpa-onnx/csrc/offline-speaker-diarization.h index 41968fe99..8c52fa269 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization.h @@ -21,6 +21,14 @@ struct OfflineSpeakerDiarizationConfig { SpeakerEmbeddingExtractorConfig embedding; FastClusteringConfig clustering; + // if a segment is less than this value, then it is discarded + float min_duration_on = 0.3; // in seconds + + // if the gap between two segments of the same speaker is less than this value, + // then these two segments are merged into a single segment. + // We do this recursively. 
+ float min_duration_off = 0.5; // in seconds + OfflineSpeakerDiarizationConfig() = default; OfflineSpeakerDiarizationConfig( From ee263e620c581b165e56bfc10514a45b5910eca0 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 8 Oct 2024 20:35:10 +0800 Subject: [PATCH 12/17] refactor --- .../offline-speaker-diarization-pyannote-impl.h | 16 +--------------- .../sherpa-onnx-offline-speaker-diarization.cc | 8 +++++++- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h index 8d94e4d21..62af2e6ec 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h @@ -75,10 +75,6 @@ class OfflineSpeakerDiarizationPyannoteImpl return {}; } - std::cout << "segmentations.size() " << segmentations.size() << "---" - << segmentations[0].rows() << ", " << segmentations[1].cols() - << "\n"; - std::vector labels; labels.reserve(segmentations.size()); @@ -92,8 +88,6 @@ class OfflineSpeakerDiarizationPyannoteImpl // speaker count per frame Int32RowVector speakers_per_frame = ComputeSpeakersPerFrame(labels); - std::cout << "speaker count: " << speakers_per_frame.cast().sum() - << ", " << speakers_per_frame.cast().mean() << "\n"; if (speakers_per_frame.maxCoeff() == 0) { SHERPA_ONNX_LOGE("No speakers found in the audio samples"); @@ -124,11 +118,7 @@ class OfflineSpeakerDiarizationPyannoteImpl auto result = ComputeResult(final_labels); - for (const auto &r : result.segments_) { - std::cout << r.ToString() << "\n"; - } - - return {}; + return result; } private: @@ -252,7 +242,6 @@ class OfflineSpeakerDiarizationPyannoteImpl ans.row(i) = powerset_mapping_.row(col_id); } - std::cout << "sum labels: " << ans.colwise().sum() << "\n"; return ans; } @@ -487,9 +476,6 @@ class OfflineSpeakerDiarizationPyannoteImpl } } - std::cout << "chunk " << chunk_index << ", " << 
new_label.colwise().sum() - << "\n"; - new_labels.push_back(std::move(new_label)); chunk_index += 1; diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc index af95ea002..55e6f9a48 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc @@ -63,7 +63,12 @@ Usage example: // sd.Process(samples.data(), samples.size() < 160000 ? samples.size() : // 160000); - sd.Process(samples.data(), samples.size(), ProgressCallback, nullptr); + auto result = + sd.Process(samples.data(), samples.size(), ProgressCallback, nullptr); + + for (const auto &r : result.segments_) { + std::cout << r.ToString() << "\n"; + } const auto end = std::chrono::steady_clock::now(); float elapsed_seconds = @@ -71,6 +76,7 @@ Usage example: .count() / 1000.; + fprintf(stderr, "Duration : %.3f s\n", duration); fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); float rtf = elapsed_seconds / duration; fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n", From e972b2e05f7a6a42d8e2bc355aa2adc81e344aec Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 8 Oct 2024 20:46:48 +0800 Subject: [PATCH 13/17] change optset to 13 --- scripts/pyannote/segmentation/export-onnx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/pyannote/segmentation/export-onnx.py b/scripts/pyannote/segmentation/export-onnx.py index 7ebcae960..feb241a26 100755 --- a/scripts/pyannote/segmentation/export-onnx.py +++ b/scripts/pyannote/segmentation/export-onnx.py @@ -72,7 +72,7 @@ def main(): model.receptive_field.duration * 16000 ) - opset_version = 15 + opset_version = 13 filename = "model.onnx" torch.onnx.export( From 24ec150d4429bb673256d10cf08b997c2651e8d9 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 8 Oct 2024 20:51:36 +0800 Subject: [PATCH 14/17] use onnxruntime 1.16.3 and onnx 1.15.0 to export pyannote models 
--- .github/workflows/export-pyannote-segmentation-to-onnx.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/export-pyannote-segmentation-to-onnx.yaml b/.github/workflows/export-pyannote-segmentation-to-onnx.yaml index 300aca500..ece0ffa28 100644 --- a/.github/workflows/export-pyannote-segmentation-to-onnx.yaml +++ b/.github/workflows/export-pyannote-segmentation-to-onnx.yaml @@ -29,7 +29,7 @@ jobs: - name: Install pyannote shell: bash run: | - pip install pyannote.audio onnx onnxruntime + pip install pyannote.audio onnx==1.15.0 onnxruntime==1.16.3 - name: Run shell: bash From 613df17b0a863b9b8f7e3036a89dacdc8b6c5a5f Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 8 Oct 2024 21:01:59 +0800 Subject: [PATCH 15/17] first working version --- ...sherpa-onnx-offline-speaker-diarization.cc | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc index 55e6f9a48..5faec2677 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc @@ -20,6 +20,55 @@ int main(int32_t argc, char *argv[]) { Offline/Non-streaming speaker diarization with sherpa-onnx Usage example: +Step 1: Download a speaker segmentation model + +Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models +for a list of available models. The following is an example + + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 + tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 + rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 + +Step 2: Download a speaker embedding extractor model + +Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models +for a list of available models. 
The following is an example + + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx + +Step 3. Download test wave files + +Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models +for a list of available test wave files. The following is an example + + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-two-speakers-zh.wav + +Step 4. Build sherpa-onnx + +Step 5. Run it + + ./bin/sherpa-onnx-offline-speaker-diarization \ + --clustering.num-clusters=2 \ + --segmentation.debug=0 \ + --segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \ + --embedding.model=../3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \ + ./0-two-speakers-zh.wav + +Since we know that there are two speakers in the test wave file, we use +--clustering.num-clusters=2 in the above example. + +If we don't know number of speakers in the given wave file, we can use +the argument --clustering.cluster-threshold. 
The following is an example: + + ./bin/sherpa-onnx-offline-speaker-diarization \ + --clustering.cluster-threshold=0.75 \ + --segmentation.debug=0 \ + --segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \ + --embedding.model=../3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \ + ./0-two-speakers-zh.wav + +A larger threshold leads to few clusters, i.e., few speakers; +a smaller threshold leads to more clusters, i.e., more speakers )usage"; sherpa_onnx::OfflineSpeakerDiarizationConfig config; sherpa_onnx::ParseOptions po(kUsageMessage); From b6dea7e3f9aedcc083bbb9c41f8e9c1c1ff2d01f Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 9 Oct 2024 11:30:41 +0800 Subject: [PATCH 16/17] Add CI tests --- .github/scripts/test-speaker-diarization.sh | 41 ++++++++++++++ .github/workflows/linux.yaml | 11 ++++ .github/workflows/macos.yaml | 11 ++++ .github/workflows/speaker-diarization.yaml | 2 +- .github/workflows/windows-x64.yaml | 11 ++++ .github/workflows/windows-x86.yaml | 11 ++++ scripts/pyannote/segmentation/README.md | 7 +-- .../csrc/offline-speaker-diarization-impl.h | 2 + ...ffline-speaker-diarization-pyannote-impl.h | 8 ++- .../offline-speaker-diarization-result.cc | 56 +++++++++++++++++-- .../csrc/offline-speaker-diarization-result.h | 15 +++-- .../csrc/offline-speaker-diarization.cc | 4 ++ .../csrc/offline-speaker-diarization.h | 3 + ...sherpa-onnx-offline-speaker-diarization.cc | 32 +++++------ 14 files changed, 182 insertions(+), 32 deletions(-) create mode 100755 .github/scripts/test-speaker-diarization.sh diff --git a/.github/scripts/test-speaker-diarization.sh b/.github/scripts/test-speaker-diarization.sh new file mode 100755 index 000000000..6d7b2effd --- /dev/null +++ b/.github/scripts/test-speaker-diarization.sh @@ -0,0 +1,41 @@ +#!/usr/bin/env bash + +set -ex + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') 
(${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +echo "EXE is $EXE" +echo "PATH: $PATH" + +which $EXE + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 +tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 +rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav + +log "specify number of clusters" +$EXE \ + --clustering.num-clusters=4 \ + --segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \ + --embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \ + ./0-four-speakers-zh.wav + +log "specify threshold for clustering" + +$EXE \ + --clustering.cluster-threshold=0.90 \ + --segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \ + --embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \ + ./0-four-speakers-zh.wav + +rm -rf sherpa-onnx-pyannote-* +rm -fv *.onnx +rm -fv *.wav diff --git a/.github/workflows/linux.yaml b/.github/workflows/linux.yaml index 0e1eca099..1d3e8dc7b 100644 --- a/.github/workflows/linux.yaml +++ b/.github/workflows/linux.yaml @@ -18,6 +18,7 @@ on: - '.github/scripts/test-audio-tagging.sh' - '.github/scripts/test-offline-punctuation.sh' - '.github/scripts/test-online-punctuation.sh' + - '.github/scripts/test-speaker-diarization.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -38,6 +39,7 @@ on: - '.github/scripts/test-audio-tagging.sh' - '.github/scripts/test-offline-punctuation.sh' - '.github/scripts/test-online-punctuation.sh' + - '.github/scripts/test-speaker-diarization.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -143,6 +145,15 @@ jobs: name: 
release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }} path: install/* + - name: Test offline speaker diarization + shell: bash + run: | + du -h -d1 . + export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx-offline-speaker-diarization + + .github/scripts/test-speaker-diarization.sh + - name: Test offline transducer shell: bash run: | diff --git a/.github/workflows/macos.yaml b/.github/workflows/macos.yaml index 084531e4a..f3d70f583 100644 --- a/.github/workflows/macos.yaml +++ b/.github/workflows/macos.yaml @@ -18,6 +18,7 @@ on: - '.github/scripts/test-audio-tagging.sh' - '.github/scripts/test-offline-punctuation.sh' - '.github/scripts/test-online-punctuation.sh' + - '.github/scripts/test-speaker-diarization.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -37,6 +38,7 @@ on: - '.github/scripts/test-audio-tagging.sh' - '.github/scripts/test-offline-punctuation.sh' - '.github/scripts/test-online-punctuation.sh' + - '.github/scripts/test-speaker-diarization.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -115,6 +117,15 @@ jobs: otool -L build/bin/sherpa-onnx otool -l build/bin/sherpa-onnx + - name: Test offline speaker diarization + shell: bash + run: | + du -h -d1 . 
+ export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx-offline-speaker-diarization + + .github/scripts/test-speaker-diarization.sh + - name: Test offline transducer shell: bash run: | diff --git a/.github/workflows/speaker-diarization.yaml b/.github/workflows/speaker-diarization.yaml index 0bd6a575c..ab2a4f090 100644 --- a/.github/workflows/speaker-diarization.yaml +++ b/.github/workflows/speaker-diarization.yaml @@ -67,7 +67,7 @@ jobs: curl -SL -O https://huggingface.co/csukuangfj/pyannote-models/resolve/main/segmentation-3.0/pytorch_model.bin test_wavs=( - 0-two-speakers-zh.wav + 0-four-speakers-zh.wav 1-two-speakers-en.wav 2-two-speakers-en.wav 3-two-speakers-en.wav diff --git a/.github/workflows/windows-x64.yaml b/.github/workflows/windows-x64.yaml index 2d2811c31..c67f3e0b5 100644 --- a/.github/workflows/windows-x64.yaml +++ b/.github/workflows/windows-x64.yaml @@ -17,6 +17,7 @@ on: - '.github/scripts/test-audio-tagging.sh' - '.github/scripts/test-offline-punctuation.sh' - '.github/scripts/test-online-punctuation.sh' + - '.github/scripts/test-speaker-diarization.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -34,6 +35,7 @@ on: - '.github/scripts/test-audio-tagging.sh' - '.github/scripts/test-offline-punctuation.sh' - '.github/scripts/test-online-punctuation.sh' + - '.github/scripts/test-speaker-diarization.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -87,6 +89,15 @@ jobs: name: release-windows-x64-${{ matrix.shared_lib }}-${{ matrix.with_tts }} path: build/install/* + - name: Test offline speaker diarization + shell: bash + run: | + du -h -d1 . 
+ export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx-offline-speaker-diarization.exe + + .github/scripts/test-speaker-diarization.sh + - name: Test online punctuation shell: bash run: | diff --git a/.github/workflows/windows-x86.yaml b/.github/workflows/windows-x86.yaml index 316cef626..30394e90e 100644 --- a/.github/workflows/windows-x86.yaml +++ b/.github/workflows/windows-x86.yaml @@ -17,6 +17,7 @@ on: - '.github/scripts/test-audio-tagging.sh' - '.github/scripts/test-offline-punctuation.sh' - '.github/scripts/test-online-punctuation.sh' + - '.github/scripts/test-speaker-diarization.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -34,6 +35,7 @@ on: - '.github/scripts/test-audio-tagging.sh' - '.github/scripts/test-offline-punctuation.sh' - '.github/scripts/test-online-punctuation.sh' + - '.github/scripts/test-speaker-diarization.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -87,6 +89,15 @@ jobs: name: release-windows-x86-${{ matrix.shared_lib }}-${{ matrix.with_tts }} path: build/install/* + - name: Test offline speaker diarization + shell: bash + run: | + du -h -d1 . + export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx-offline-speaker-diarization.exe + + .github/scripts/test-speaker-diarization.sh + - name: Test online punctuation shell: bash run: | diff --git a/scripts/pyannote/segmentation/README.md b/scripts/pyannote/segmentation/README.md index 689b7a9be..a9c5230d1 100644 --- a/scripts/pyannote/segmentation/README.md +++ b/scripts/pyannote/segmentation/README.md @@ -3,12 +3,9 @@ Please download test wave files from https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models -## 0-two-speakers-zh.wav +## 0-four-speakers-zh.wav -This file is from -https://www.modelscope.cn/models/iic/speech_campplus_speaker-diarization_common/file/view/master?fileName=examples%252F2speakers_example.wav&status=0 - -Note that we have renamed it from `2speakers_example.wav` to `0-two-speakers-zh.wav`. 
+It is recorded by @csukuangfj ## 1-two-speakers-en.wav diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-impl.h b/sherpa-onnx/csrc/offline-speaker-diarization-impl.h index 5a4579efb..f7fe39499 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization-impl.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization-impl.h @@ -18,6 +18,8 @@ class OfflineSpeakerDiarizationImpl { virtual ~OfflineSpeakerDiarizationImpl() = default; + virtual int32_t SampleRate() const = 0; + virtual OfflineSpeakerDiarizationResult Process( const float *audio, int32_t n, OfflineSpeakerDiarizationProgressCallback callback = nullptr, diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h index 62af2e6ec..bcd0c93a4 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h @@ -18,7 +18,7 @@ namespace sherpa_onnx { -namespace { +namespace { // NOLINT // copied from https://github.com/k2-fsa/k2/blob/master/k2/csrc/host/util.h#L41 template @@ -64,6 +64,12 @@ class OfflineSpeakerDiarizationPyannoteImpl Init(); } + int32_t SampleRate() const override { + const auto &meta_data = segmentation_model_.GetModelMetaData(); + + return meta_data.sample_rate; + } + OfflineSpeakerDiarizationResult Process( const float *audio, int32_t n, OfflineSpeakerDiarizationProgressCallback callback = nullptr, diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-result.cc b/sherpa-onnx/csrc/offline-speaker-diarization-result.cc index 46399e451..8bf83f5d9 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization-result.cc +++ b/sherpa-onnx/csrc/offline-speaker-diarization-result.cc @@ -4,14 +4,18 @@ #include "sherpa-onnx/csrc/offline-speaker-diarization-result.h" +#include +#include #include +#include +#include #include "sherpa-onnx/csrc/macros.h" namespace sherpa_onnx { OfflineSpeakerDiarizationSegment::OfflineSpeakerDiarizationSegment( - float 
start, float end, int32_t speaker) { + float start, float end, int32_t speaker, const std::string &text /*= {}*/) { if (start > end) { SHERPA_ONNX_LOGE("start %.3f should be less than end %.3f", start, end); SHERPA_ONNX_EXIT(-1); @@ -20,6 +24,7 @@ OfflineSpeakerDiarizationSegment::OfflineSpeakerDiarizationSegment( start_ = start; end_ = end; speaker_ = speaker; + text_ = text; } std::optional @@ -44,10 +49,16 @@ OfflineSpeakerDiarizationSegment::Merge( std::string OfflineSpeakerDiarizationSegment::ToString() const { char s[128]; - int32_t n = snprintf(s, sizeof(s), "%.3f -- %.3f speaker_%02d", start_, end_, - speaker_); + snprintf(s, sizeof(s), "%.3f -- %.3f speaker_%02d", start_, end_, speaker_); - return {&s[0]}; + std::ostringstream os; + os << s; + + if (!text_.empty()) { + os << " " << text_; + } + + return os.str(); } void OfflineSpeakerDiarizationResult::Add( @@ -55,8 +66,45 @@ void OfflineSpeakerDiarizationResult::Add( segments_.push_back(segment); } +int32_t OfflineSpeakerDiarizationResult::NumSpeakers() const { + std::unordered_set count; + for (const auto &s : segments_) { + count.insert(s.Speaker()); + } + + return count.size(); +} + int32_t OfflineSpeakerDiarizationResult::NumSegments() const { return segments_.size(); } +// Return a list of segments sorted by segment.start time +std::vector +OfflineSpeakerDiarizationResult::SortByStartTime() const { + auto ans = segments_; + std::sort(ans.begin(), ans.end(), [](const auto &a, const auto &b) { + return (a.Start() < b.Start()) || + ((a.Start() == b.Start()) && (a.Speaker() < b.Speaker())); + }); + + return ans; +} + +std::vector> +OfflineSpeakerDiarizationResult::SortBySpeaker() const { + auto tmp = segments_; + std::sort(tmp.begin(), tmp.end(), [](const auto &a, const auto &b) { + return (a.Speaker() < b.Speaker()) || + ((a.Speaker() == b.Speaker()) && (a.Start() < b.Start())); + }); + + std::vector> ans(NumSpeakers()); + for (auto &s : tmp) { + ans[s.Speaker()].push_back(std::move(s)); + } + + return 
ans; +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-result.h b/sherpa-onnx/csrc/offline-speaker-diarization-result.h index 1037bf640..e71d054e5 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization-result.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization-result.h @@ -14,21 +14,28 @@ namespace sherpa_onnx { class OfflineSpeakerDiarizationSegment { public: - OfflineSpeakerDiarizationSegment(float start, float end, int32_t speaker); + OfflineSpeakerDiarizationSegment(float start, float end, int32_t speaker, + const std::string &text = {}); // If the gap between the two segments is less than the given gap, then we // merge them and return a new segment. Otherwise, it returns null. std::optional Merge( const OfflineSpeakerDiarizationSegment &other, float gap) const; + float Start() const { return start_; } + float End() const { return end_; } + int32_t Speaker() const { return speaker_; } + const std::string &Text() const { return text_; } float Duration() const { return end_ - start_; } std::string ToString() const; private: - float start_; // in seconds - float end_; // in seconds - int32_t speaker_; // ID of the speaker, starting from 0 + float start_; // in seconds + float end_; // in seconds + int32_t speaker_; // ID of the speaker, starting from 0 + std::string text_; // If not empty, it contains the speech recognition result + // of this segment }; class OfflineSpeakerDiarizationResult { diff --git a/sherpa-onnx/csrc/offline-speaker-diarization.cc b/sherpa-onnx/csrc/offline-speaker-diarization.cc index 8a4a54cc8..aeff9b42d 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization.cc +++ b/sherpa-onnx/csrc/offline-speaker-diarization.cc @@ -65,6 +65,10 @@ OfflineSpeakerDiarization::OfflineSpeakerDiarization( OfflineSpeakerDiarization::~OfflineSpeakerDiarization() = default; +int32_t OfflineSpeakerDiarization::SampleRate() const { + return impl_->SampleRate(); +} + OfflineSpeakerDiarizationResult 
OfflineSpeakerDiarization::Process( const float *audio, int32_t n, OfflineSpeakerDiarizationProgressCallback callback /*= nullptr*/, diff --git a/sherpa-onnx/csrc/offline-speaker-diarization.h b/sherpa-onnx/csrc/offline-speaker-diarization.h index 8c52fa269..ab9a440aa 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization.h @@ -56,6 +56,9 @@ class OfflineSpeakerDiarization { ~OfflineSpeakerDiarization(); + // Expected sample rate of the input audio samples + int32_t SampleRate() const; + OfflineSpeakerDiarizationResult Process( const float *audio, int32_t n, OfflineSpeakerDiarizationProgressCallback callback = nullptr, diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc index 5faec2677..170973114 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc @@ -41,31 +41,29 @@ Step 3. Download test wave files Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models for a list of available test wave files. The following is an example - wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-two-speakers-zh.wav + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav Step 4. Build sherpa-onnx Step 5. 
Run it ./bin/sherpa-onnx-offline-speaker-diarization \ - --clustering.num-clusters=2 \ - --segmentation.debug=0 \ + --clustering.num-clusters=4 \ --segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \ - --embedding.model=../3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \ - ./0-two-speakers-zh.wav + --embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \ + ./0-four-speakers-zh.wav -Since we know that there are two speakers in the test wave file, we use ---clustering.num-clusters=2 in the above example. +Since we know that there are four speakers in the test wave file, we use +--clustering.num-clusters=4 in the above example. If we don't know number of speakers in the given wave file, we can use the argument --clustering.cluster-threshold. The following is an example: ./bin/sherpa-onnx-offline-speaker-diarization \ - --clustering.cluster-threshold=0.75 \ - --segmentation.debug=0 \ + --clustering.cluster-threshold=0.90 \ --segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \ - --embedding.model=../3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \ - ./0-two-speakers-zh.wav + --embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \ + ./0-four-speakers-zh.wav A larger threshold leads to few clusters, i.e., few speakers; a smaller threshold leads to more clusters, i.e., more speakers @@ -103,19 +101,19 @@ a smaller threshold leads to more clusters, i.e., more speakers return -1; } - if (sample_rate != 16000) { - std::cerr << "Expect sample rate 16000. Given: " << sample_rate << "\n"; + if (sample_rate != sd.SampleRate()) { + std::cerr << "Expect sample rate " << sd.SampleRate() + << ". Given: " << sample_rate << "\n"; return -1; } float duration = samples.size() / static_cast(sample_rate); - // sd.Process(samples.data(), samples.size() < 160000 ? 
samples.size() : - // 160000); auto result = - sd.Process(samples.data(), samples.size(), ProgressCallback, nullptr); + sd.Process(samples.data(), samples.size(), ProgressCallback, nullptr) + .SortByStartTime(); - for (const auto &r : result.segments_) { + for (const auto &r : result) { std::cout << r.ToString() << "\n"; } From 1eb1a501f46809f792072edf7db8cccbb5177930 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 9 Oct 2024 11:47:51 +0800 Subject: [PATCH 17/17] fix typos --- sherpa-onnx/csrc/sherpa-onnx-online-punctuation.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sherpa-onnx/csrc/sherpa-onnx-online-punctuation.cc b/sherpa-onnx/csrc/sherpa-onnx-online-punctuation.cc index ea83cfaaf..faca83b98 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-online-punctuation.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-online-punctuation.cc @@ -19,7 +19,7 @@ The input text can contain English words. Usage: Please download the model from: -https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 +https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-online-punct-en-2024-08-06.tar.bz2 ./bin/Release/sherpa-onnx-online-punctuation \ --cnn-bilstm=/path/to/model.onnx \