Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Android JNI support for speaker diarization #1421

Merged
merged 1 commit into from
Oct 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions sherpa-onnx/csrc/offline-speaker-diarization-impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,18 @@ OfflineSpeakerDiarizationImpl::Create(
return nullptr;
}

#if __ANDROID_API__ >= 9
std::unique_ptr<OfflineSpeakerDiarizationImpl>
OfflineSpeakerDiarizationImpl::Create(
AAssetManager *mgr, const OfflineSpeakerDiarizationConfig &config) {
if (!config.segmentation.pyannote.model.empty()) {
return std::make_unique<OfflineSpeakerDiarizationPyannoteImpl>(mgr, config);
}

SHERPA_ONNX_LOGE("Please specify a speaker segmentation model.");

return nullptr;
}
#endif

} // namespace sherpa_onnx
10 changes: 10 additions & 0 deletions sherpa-onnx/csrc/offline-speaker-diarization-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
#include <functional>
#include <memory>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
namespace sherpa_onnx {

Expand All @@ -16,6 +21,11 @@ class OfflineSpeakerDiarizationImpl {
static std::unique_ptr<OfflineSpeakerDiarizationImpl> Create(
const OfflineSpeakerDiarizationConfig &config);

#if __ANDROID_API__ >= 9
static std::unique_ptr<OfflineSpeakerDiarizationImpl> Create(
AAssetManager *mgr, const OfflineSpeakerDiarizationConfig &config);
#endif

virtual ~OfflineSpeakerDiarizationImpl() = default;

virtual int32_t SampleRate() const = 0;
Expand Down
16 changes: 16 additions & 0 deletions sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "Eigen/Dense"
#include "sherpa-onnx/csrc/fast-clustering.h"
#include "sherpa-onnx/csrc/math.h"
Expand Down Expand Up @@ -65,6 +70,17 @@ class OfflineSpeakerDiarizationPyannoteImpl
Init();
}

#if __ANDROID_API__ >= 9
OfflineSpeakerDiarizationPyannoteImpl(
AAssetManager *mgr, const OfflineSpeakerDiarizationConfig &config)
: config_(config),
segmentation_model_(mgr, config_.segmentation),
embedding_extractor_(mgr, config_.embedding),
clustering_(std::make_unique<FastClustering>(config_.clustering)) {
Init();
}
#endif

int32_t SampleRate() const override {
const auto &meta_data = segmentation_model_.GetModelMetaData();

Expand Down
6 changes: 6 additions & 0 deletions sherpa-onnx/csrc/offline-speaker-diarization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ OfflineSpeakerDiarization::OfflineSpeakerDiarization(
const OfflineSpeakerDiarizationConfig &config)
: impl_(OfflineSpeakerDiarizationImpl::Create(config)) {}

#if __ANDROID_API__ >= 9
OfflineSpeakerDiarization::OfflineSpeakerDiarization(
AAssetManager *mgr, const OfflineSpeakerDiarizationConfig &config)
: impl_(OfflineSpeakerDiarizationImpl::Create(mgr, config)) {}
#endif

OfflineSpeakerDiarization::~OfflineSpeakerDiarization() = default;

int32_t OfflineSpeakerDiarization::SampleRate() const {
Expand Down
10 changes: 10 additions & 0 deletions sherpa-onnx/csrc/offline-speaker-diarization.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
#include <memory>
#include <string>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "sherpa-onnx/csrc/fast-clustering-config.h"
#include "sherpa-onnx/csrc/offline-speaker-diarization-result.h"
#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h"
Expand Down Expand Up @@ -57,6 +62,11 @@ class OfflineSpeakerDiarization {
explicit OfflineSpeakerDiarization(
const OfflineSpeakerDiarizationConfig &config);

#if __ANDROID_API__ >= 9
OfflineSpeakerDiarization(AAssetManager *mgr,
const OfflineSpeakerDiarizationConfig &config);
#endif

~OfflineSpeakerDiarization();

// Expected sample rate of the input audio samples
Expand Down
18 changes: 18 additions & 0 deletions sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,17 @@ class OfflineSpeakerSegmentationPyannoteModel::Impl {
Init(buf.data(), buf.size());
}

#if __ANDROID_API__ >= 9
Impl(AAssetManager *mgr, const OfflineSpeakerSegmentationModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(mgr, config_.pyannote.model);
Init(buf.data(), buf.size());
}
#endif

const OfflineSpeakerSegmentationPyannoteModelMetaData &GetModelMetaData()
const {
return meta_data_;
Expand Down Expand Up @@ -92,6 +103,13 @@ OfflineSpeakerSegmentationPyannoteModel::
const OfflineSpeakerSegmentationModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}

#if __ANDROID_API__ >= 9
OfflineSpeakerSegmentationPyannoteModel::
OfflineSpeakerSegmentationPyannoteModel(
AAssetManager *mgr, const OfflineSpeakerSegmentationModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif

OfflineSpeakerSegmentationPyannoteModel::
~OfflineSpeakerSegmentationPyannoteModel() = default;

Expand Down
10 changes: 10 additions & 0 deletions sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@

#include <memory>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h"
#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h"
Expand All @@ -17,6 +22,11 @@ class OfflineSpeakerSegmentationPyannoteModel {
explicit OfflineSpeakerSegmentationPyannoteModel(
const OfflineSpeakerSegmentationModelConfig &config);

#if __ANDROID_API__ >= 9
OfflineSpeakerSegmentationPyannoteModel(
AAssetManager *mgr, const OfflineSpeakerSegmentationModelConfig &config);
#endif

~OfflineSpeakerSegmentationPyannoteModel();

const OfflineSpeakerSegmentationPyannoteModelMetaData &GetModelMetaData()
Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/sherpa-onnx-vad-microphone-offline-asr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ to download models for offline ASR.
}

while (!vad->Empty()) {
auto &segment = vad->Front();
const auto &segment = vad->Front();
auto s = recognizer.CreateStream();
s->AcceptWaveform(sample_rate, segment.samples.data(),
segment.samples.size());
Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/jni/audio-tagging.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_AudioTagging_newFromAsset(
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
return 0;
}
#endif

Expand Down
2 changes: 2 additions & 0 deletions sherpa-onnx/jni/keyword-spotter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,12 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_newFromAsset(
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
return 0;
}
#endif
auto config = sherpa_onnx::GetKwsConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());

auto kws = new sherpa_onnx::KeywordSpotter(
#if __ANDROID_API__ >= 9
mgr,
Expand Down
2 changes: 2 additions & 0 deletions sherpa-onnx/jni/offline-punctuation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,12 @@ Java_com_k2fsa_sherpa_onnx_OfflinePunctuation_newFromAsset(
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
return 0;
}
#endif
auto config = sherpa_onnx::GetOfflinePunctuationConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());

auto model = new sherpa_onnx::OfflinePunctuation(
#if __ANDROID_API__ >= 9
mgr,
Expand Down
2 changes: 2 additions & 0 deletions sherpa-onnx/jni/offline-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,12 @@ Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_newFromAsset(JNIEnv *env,
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
return 0;
}
#endif
auto config = sherpa_onnx::GetOfflineConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());

auto model = new sherpa_onnx::OfflineRecognizer(
#if __ANDROID_API__ >= 9
mgr,
Expand Down
19 changes: 18 additions & 1 deletion sherpa-onnx/jni/offline-speaker-diarization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,24 @@ SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_newFromAsset(
JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
return 0;
#if __ANDROID_API__ >= 9
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
return 0;
}
#endif

auto config = sherpa_onnx::GetOfflineSpeakerDiarizationConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());

auto sd = new sherpa_onnx::OfflineSpeakerDiarization(
#if __ANDROID_API__ >= 9
mgr,
#endif
config);

return (jlong)sd;
}

SHERPA_ONNX_EXTERN_C
Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/jni/offline-tts.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_newFromAsset(
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
return 0;
}
#endif
auto config = sherpa_onnx::GetOfflineTtsConfig(env, _config);
Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/jni/online-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_newFromAsset(JNIEnv *env,
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
return 0;
}
#endif
auto config = sherpa_onnx::GetConfig(env, _config);
Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/jni/speaker-embedding-extractor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_newFromAsset(
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
return 0;
}
#endif
auto config = sherpa_onnx::GetSpeakerEmbeddingExtractorConfig(env, _config);
Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/jni/spoken-language-identification.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ Java_com_k2fsa_sherpa_onnx_SpokenLanguageIdentification_newFromAsset(
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
return 0;
}
#endif

Expand Down
2 changes: 2 additions & 0 deletions sherpa-onnx/jni/voice-activity-detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,12 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_Vad_newFromAsset(
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
return 0;
}
#endif
auto config = sherpa_onnx::GetVadModelConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());

auto model = new sherpa_onnx::VoiceActivityDetector(
#if __ANDROID_API__ >= 9
mgr,
Expand Down
Loading