Skip to content

Commit

Permalink
begin to add pyannote segmentation model
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Oct 7, 2024
1 parent dd65064 commit 23fd7ec
Show file tree
Hide file tree
Showing 19 changed files with 241 additions and 85 deletions.
2 changes: 1 addition & 1 deletion scripts/pyannote/segmentation/export-onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def main():
model.receptive_field.duration * 16000
)

opset_version = 18
opset_version = 15

filename = "model.onnx"
torch.onnx.export(
Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
offline-speaker-diarization.cc
offline-speaker-segmentation-model-config.cc
offline-speaker-segmentation-pyannote-model-config.cc
offline-speaker-segmentation-pyannote-model.cc
)
endif()

Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/csrc/offline-sense-voice-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <utility>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"

Expand Down
9 changes: 9 additions & 0 deletions sherpa-onnx/csrc/offline-speaker-diarization-impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,20 @@

#include <memory>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h"

namespace sherpa_onnx {

std::unique_ptr<OfflineSpeakerDiarizationImpl>
OfflineSpeakerDiarizationImpl::Create(
const OfflineSpeakerDiarizationConfig &config) {
if (!config.segmentation.pyannote.model.empty()) {
return std::make_unique<OfflineSpeakerDiarizationPyannoteImpl>(config);
}

SHERPA_ONNX_LOGE("Please specify a speaker segmentation model.");

return nullptr;
}

Expand Down
32 changes: 31 additions & 1 deletion sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,33 @@
// sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.cc
// sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_

#include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h"
#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h"

namespace sherpa_onnx {
class OfflineSpeakerDiarizationPyannoteImpl
: public OfflineSpeakerDiarizationImpl {
public:
~OfflineSpeakerDiarizationPyannoteImpl() override = default;

explicit OfflineSpeakerDiarizationPyannoteImpl(
const OfflineSpeakerDiarizationConfig &config)
: config_(config), segmentation_model_(config_.segmentation) {}

OfflineSpeakerDiarizationResult Process(
const float *audio, int32_t n,
OfflineSpeakerDiarizationProgressCallback callback =
nullptr) const override {
return {};
}

private:
OfflineSpeakerDiarizationConfig config_;
OfflineSpeakerSegmentationPyannoteModel segmentation_model_;
};

} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_
5 changes: 4 additions & 1 deletion sherpa-onnx/csrc/offline-speaker-diarization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

#include "sherpa-onnx/csrc/offline-speaker-diarization.h"

#include <string>

#include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h"

namespace sherpa_onnx {
Expand Down Expand Up @@ -39,7 +41,8 @@ std::string OfflineSpeakerDiarizationConfig::ToString() const {
}

OfflineSpeakerDiarization::OfflineSpeakerDiarization(
const OfflineSpeakerDiarizationConfig &config) {}
const OfflineSpeakerDiarizationConfig &config)
: impl_(OfflineSpeakerDiarizationImpl::Create(config)) {}

OfflineSpeakerDiarization::~OfflineSpeakerDiarization() = default;

Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/csrc/offline-speaker-diarization.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include <functional>
#include <memory>
#include <string>

#include "sherpa-onnx/csrc/offline-speaker-diarization-result.h"
#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h"

#include <sstream>
#include <string>

#include "sherpa-onnx/csrc/macros.h"

Expand Down
2 changes: 2 additions & 0 deletions sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_MODEL_CONFIG_H_

#include <string>

#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h"
#include "sherpa-onnx/csrc/parse-options.h"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h"

#include <sstream>
#include <string>

#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ struct OfflineSpeakerSegmentationPyannoteModelConfig {

OfflineSpeakerSegmentationPyannoteModelConfig() = default;

OfflineSpeakerSegmentationPyannoteModelConfig(const std::string &model)
: model(model){};
explicit OfflineSpeakerSegmentationPyannoteModelConfig(
const std::string &model)
: model(model) {}

void Register(ParseOptions *po);
bool Validate() const;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h
//
// Copyright (c) 2024 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_

#include <cstdint>
#include <string>

namespace sherpa_onnx {

// If you are not sure what each field means, please
// have a look of the Python file in the model directory that
// you have downloaded.
struct OfflineSpeakerSegmentationPyannoteModelMetaData {
int32_t sample_rate = 0;
int32_t window_size = 0; // in samples
int32_t window_shift = 0; // in samples
int32_t receptive_field_size = 0; // in samples
int32_t receptive_field_shift = 0; // in samples
int32_t num_speakers = 0;
int32_t powerset_max_classes = 0;
int32_t num_classes = 0;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_
92 changes: 92 additions & 0 deletions sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc
//
// Copyright (c) 2024 Xiaomi Corporation

#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h"

#include <string>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"

namespace sherpa_onnx {

class OfflineSpeakerSegmentationPyannoteModel::Impl {
public:
explicit Impl(const OfflineSpeakerSegmentationModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(config_.pyannote.model);
Init(buf.data(), buf.size());
}

Ort::Value Forward(Ort::Value x) { return Ort::Value(nullptr); }

private:
void Init(void *model_data, size_t model_data_length) {
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
sess_opts_);

GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);

GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);

// get meta data
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
PrintModelMetadata(os, meta_data);
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
}

Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate");
SHERPA_ONNX_READ_META_DATA(meta_data_.window_size, "window_size");

meta_data_.window_shift =
static_cast<int32_t>(0.1 * meta_data_.window_size);

SHERPA_ONNX_READ_META_DATA(meta_data_.receptive_field_size,
"receptive_field_size");
SHERPA_ONNX_READ_META_DATA(meta_data_.receptive_field_shift,
"receptive_field_shift");
SHERPA_ONNX_READ_META_DATA(meta_data_.num_speakers, "num_speakers");
SHERPA_ONNX_READ_META_DATA(meta_data_.powerset_max_classes,
"powerset_max_classes");
SHERPA_ONNX_READ_META_DATA(meta_data_.num_classes, "num_classes");
}

private:
OfflineSpeakerSegmentationModelConfig config_;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;

std::unique_ptr<Ort::Session> sess_;

std::vector<std::string> input_names_;
std::vector<const char *> input_names_ptr_;

std::vector<std::string> output_names_;
std::vector<const char *> output_names_ptr_;

OfflineSpeakerSegmentationPyannoteModelMetaData meta_data_;
};

OfflineSpeakerSegmentationPyannoteModel::
OfflineSpeakerSegmentationPyannoteModel(
const OfflineSpeakerSegmentationModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}

OfflineSpeakerSegmentationPyannoteModel::
~OfflineSpeakerSegmentationPyannoteModel() = default;

Ort::Value OfflineSpeakerSegmentationPyannoteModel::Forward(Ort::Value x) {
return impl_->Forward(std::move(x));
}

} // namespace sherpa_onnx
39 changes: 39 additions & 0 deletions sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_

#include <memory>

#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h"
#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h"

namespace sherpa_onnx {

class OfflineSpeakerSegmentationPyannoteModel {
public:
explicit OfflineSpeakerSegmentationPyannoteModel(
const OfflineSpeakerSegmentationModelConfig &config);

~OfflineSpeakerSegmentationPyannoteModel();

const OfflineSpeakerSegmentationPyannoteModelMetaData &GetMetaData() const;

/**
* @param x A 3-D float tensor of shape (batch_size, 1, num_samples)
* @return Return a float tensor of
* shape (batch_size, num_frames, num_speakers). Note that
* num_speakers here uses powerset encoding.
*/
Ort::Value Forward(Ort::Value x);

private:
class Impl;
std::unique_ptr<Impl> impl_;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_
6 changes: 4 additions & 2 deletions sherpa-onnx/csrc/provider-config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,10 @@ void TensorrtConfig::Register(ParseOptions *po) {

bool TensorrtConfig::Validate() const {
if (trt_max_workspace_size < 0) {
SHERPA_ONNX_LOGE("trt_max_workspace_size: %ld is not valid.",
trt_max_workspace_size);
std::ostringstream os;
os << "trt_max_workspace_size: " << trt_max_workspace_size
<< " is not valid.";
SHERPA_ONNX_LOGE("%s", os.str().c_str());
return false;
}
if (trt_max_partition_iterations < 0) {
Expand Down
42 changes: 2 additions & 40 deletions sherpa-onnx/csrc/session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ static void OrtStatusFailure(OrtStatus *status, const char *s) {
api.ReleaseStatus(status);
}

static Ort::SessionOptions GetSessionOptionsImpl(
Ort::SessionOptions GetSessionOptionsImpl(
int32_t num_threads, const std::string &provider_str,
const ProviderConfig *provider_config = nullptr) {
const ProviderConfig *provider_config /*= nullptr*/) {
Provider p = StringToProvider(provider_str);

Ort::SessionOptions sess_opts;
Expand Down Expand Up @@ -259,10 +259,6 @@ Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config,
&config.provider_config);
}

Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) {
return GetSessionOptionsImpl(config.num_threads, config.provider);
}

Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config) {
return GetSessionOptionsImpl(config.lm_num_threads, config.lm_provider);
}
Expand All @@ -271,38 +267,4 @@ Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config) {
return GetSessionOptionsImpl(config.lm_num_threads, config.lm_provider);
}

Ort::SessionOptions GetSessionOptions(const VadModelConfig &config) {
return GetSessionOptionsImpl(config.num_threads, config.provider);
}

#if SHERPA_ONNX_ENABLE_TTS
Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config) {
return GetSessionOptionsImpl(config.num_threads, config.provider);
}
#endif

Ort::SessionOptions GetSessionOptions(
const SpeakerEmbeddingExtractorConfig &config) {
return GetSessionOptionsImpl(config.num_threads, config.provider);
}

Ort::SessionOptions GetSessionOptions(
const SpokenLanguageIdentificationConfig &config) {
return GetSessionOptionsImpl(config.num_threads, config.provider);
}

Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config) {
return GetSessionOptionsImpl(config.num_threads, config.provider);
}

Ort::SessionOptions GetSessionOptions(
const OfflinePunctuationModelConfig &config) {
return GetSessionOptionsImpl(config.num_threads, config.provider);
}

Ort::SessionOptions GetSessionOptions(
const OnlinePunctuationModelConfig &config) {
return GetSessionOptionsImpl(config.num_threads, config.provider);
}

} // namespace sherpa_onnx
Loading

0 comments on commit 23fd7ec

Please sign in to comment.