-
Notifications
You must be signed in to change notification settings - Fork 445
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
begin to add pyannote segmentation model
- Loading branch information
1 parent
dd65064
commit 23fd7ec
Showing
19 changed files
with
241 additions
and
85 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
32 changes: 31 additions & 1 deletion
32
sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,33 @@ | ||
// sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.cc | ||
// sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h | ||
// | ||
// Copyright (c) 2024 Xiaomi Corporation | ||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_ | ||
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_ | ||
|
||
#include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h" | ||
#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h" | ||
|
||
namespace sherpa_onnx { | ||
class OfflineSpeakerDiarizationPyannoteImpl | ||
: public OfflineSpeakerDiarizationImpl { | ||
public: | ||
~OfflineSpeakerDiarizationPyannoteImpl() override = default; | ||
|
||
explicit OfflineSpeakerDiarizationPyannoteImpl( | ||
const OfflineSpeakerDiarizationConfig &config) | ||
: config_(config), segmentation_model_(config_.segmentation) {} | ||
|
||
OfflineSpeakerDiarizationResult Process( | ||
const float *audio, int32_t n, | ||
OfflineSpeakerDiarizationProgressCallback callback = | ||
nullptr) const override { | ||
return {}; | ||
} | ||
|
||
private: | ||
OfflineSpeakerDiarizationConfig config_; | ||
OfflineSpeakerSegmentationPyannoteModel segmentation_model_; | ||
}; | ||
|
||
} // namespace sherpa_onnx | ||
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
29 changes: 29 additions & 0 deletions
29
sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h | ||
// | ||
// Copyright (c) 2024 Xiaomi Corporation | ||
|
||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_ | ||
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_ | ||
|
||
#include <cstdint> | ||
#include <string> | ||
|
||
namespace sherpa_onnx { | ||
|
||
// Meta data attached to a pyannote speaker-segmentation model.
//
// If you are not sure what each field means, please
// have a look at the Python file in the model directory that
// you have downloaded.
//
// Every field defaults to 0 until it is filled in by whoever loads the model.
struct OfflineSpeakerSegmentationPyannoteModelMetaData {
  int32_t sample_rate = 0;
  int32_t window_size = 0;            // in samples
  int32_t window_shift = 0;           // in samples
  int32_t receptive_field_size = 0;   // in samples
  int32_t receptive_field_shift = 0;  // in samples
  int32_t num_speakers = 0;
  // NOTE(review): presumably parameters of the powerset encoding used by the
  // model output -- confirm against the export script.
  int32_t powerset_max_classes = 0;
  int32_t num_classes = 0;
};
|
||
} // namespace sherpa_onnx | ||
|
||
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_ |
92 changes: 92 additions & 0 deletions
92
sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc | ||
// | ||
// Copyright (c) 2024 Xiaomi Corporation | ||
|
||
#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h" | ||
|
||
#include <string> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#include "sherpa-onnx/csrc/onnx-utils.h" | ||
#include "sherpa-onnx/csrc/session.h" | ||
|
||
namespace sherpa_onnx { | ||
|
||
class OfflineSpeakerSegmentationPyannoteModel::Impl { | ||
public: | ||
explicit Impl(const OfflineSpeakerSegmentationModelConfig &config) | ||
: config_(config), | ||
env_(ORT_LOGGING_LEVEL_ERROR), | ||
sess_opts_(GetSessionOptions(config)), | ||
allocator_{} { | ||
auto buf = ReadFile(config_.pyannote.model); | ||
Init(buf.data(), buf.size()); | ||
} | ||
|
||
Ort::Value Forward(Ort::Value x) { return Ort::Value(nullptr); } | ||
|
||
private: | ||
void Init(void *model_data, size_t model_data_length) { | ||
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length, | ||
sess_opts_); | ||
|
||
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); | ||
|
||
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); | ||
|
||
// get meta data | ||
Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); | ||
if (config_.debug) { | ||
std::ostringstream os; | ||
PrintModelMetadata(os, meta_data); | ||
SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); | ||
} | ||
|
||
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below | ||
SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate"); | ||
SHERPA_ONNX_READ_META_DATA(meta_data_.window_size, "window_size"); | ||
|
||
meta_data_.window_shift = | ||
static_cast<int32_t>(0.1 * meta_data_.window_size); | ||
|
||
SHERPA_ONNX_READ_META_DATA(meta_data_.receptive_field_size, | ||
"receptive_field_size"); | ||
SHERPA_ONNX_READ_META_DATA(meta_data_.receptive_field_shift, | ||
"receptive_field_shift"); | ||
SHERPA_ONNX_READ_META_DATA(meta_data_.num_speakers, "num_speakers"); | ||
SHERPA_ONNX_READ_META_DATA(meta_data_.powerset_max_classes, | ||
"powerset_max_classes"); | ||
SHERPA_ONNX_READ_META_DATA(meta_data_.num_classes, "num_classes"); | ||
} | ||
|
||
private: | ||
OfflineSpeakerSegmentationModelConfig config_; | ||
Ort::Env env_; | ||
Ort::SessionOptions sess_opts_; | ||
Ort::AllocatorWithDefaultOptions allocator_; | ||
|
||
std::unique_ptr<Ort::Session> sess_; | ||
|
||
std::vector<std::string> input_names_; | ||
std::vector<const char *> input_names_ptr_; | ||
|
||
std::vector<std::string> output_names_; | ||
std::vector<const char *> output_names_ptr_; | ||
|
||
OfflineSpeakerSegmentationPyannoteModelMetaData meta_data_; | ||
}; | ||
|
||
OfflineSpeakerSegmentationPyannoteModel:: | ||
OfflineSpeakerSegmentationPyannoteModel( | ||
const OfflineSpeakerSegmentationModelConfig &config) | ||
: impl_(std::make_unique<Impl>(config)) {} | ||
|
||
OfflineSpeakerSegmentationPyannoteModel:: | ||
~OfflineSpeakerSegmentationPyannoteModel() = default; | ||
|
||
Ort::Value OfflineSpeakerSegmentationPyannoteModel::Forward(Ort::Value x) { | ||
return impl_->Forward(std::move(x)); | ||
} | ||
|
||
} // namespace sherpa_onnx |
39 changes: 39 additions & 0 deletions
39
sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h | ||
// | ||
// Copyright (c) 2024 Xiaomi Corporation | ||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_ | ||
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_ | ||
|
||
#include <memory> | ||
|
||
#include "onnxruntime_cxx_api.h" // NOLINT | ||
#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h" | ||
#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h" | ||
|
||
namespace sherpa_onnx { | ||
|
||
class OfflineSpeakerSegmentationPyannoteModel { | ||
public: | ||
explicit OfflineSpeakerSegmentationPyannoteModel( | ||
const OfflineSpeakerSegmentationModelConfig &config); | ||
|
||
~OfflineSpeakerSegmentationPyannoteModel(); | ||
|
||
const OfflineSpeakerSegmentationPyannoteModelMetaData &GetMetaData() const; | ||
|
||
/** | ||
* @param x A 3-D float tensor of shape (batch_size, 1, num_samples) | ||
* @return Return a float tensor of | ||
* shape (batch_size, num_frames, num_speakers). Note that | ||
* num_speakers here uses powerset encoding. | ||
*/ | ||
Ort::Value Forward(Ort::Value x); | ||
|
||
private: | ||
class Impl; | ||
std::unique_ptr<Impl> impl_; | ||
}; | ||
|
||
} // namespace sherpa_onnx | ||
|
||
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.