From 2d412b1190778bc35f337ef1feeb12292b5c9f92 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 11 Oct 2024 14:41:53 +0800 Subject: [PATCH] Kotlin API for speaker diarization (#1415) --- .../OfflineSpeakerDiarization.kt | 1 + kotlin-api-examples/run.sh | 31 +++ .../test_offline_speaker_diarization.kt | 53 +++++ .../csrc/offline-speaker-diarization-result.h | 2 +- sherpa-onnx/jni/CMakeLists.txt | 6 + .../jni/offline-speaker-diarization.cc | 219 ++++++++++++++++++ .../kotlin-api/OfflineSpeakerDiarization.kt | 101 ++++++++ 7 files changed, 412 insertions(+), 1 deletion(-) create mode 120000 kotlin-api-examples/OfflineSpeakerDiarization.kt create mode 100644 kotlin-api-examples/test_offline_speaker_diarization.kt create mode 100644 sherpa-onnx/jni/offline-speaker-diarization.cc create mode 100644 sherpa-onnx/kotlin-api/OfflineSpeakerDiarization.kt diff --git a/kotlin-api-examples/OfflineSpeakerDiarization.kt b/kotlin-api-examples/OfflineSpeakerDiarization.kt new file mode 120000 index 000000000..870612b4c --- /dev/null +++ b/kotlin-api-examples/OfflineSpeakerDiarization.kt @@ -0,0 +1 @@ +../sherpa-onnx/kotlin-api/OfflineSpeakerDiarization.kt \ No newline at end of file diff --git a/kotlin-api-examples/run.sh b/kotlin-api-examples/run.sh index 23e86886e..50e7816f1 100755 --- a/kotlin-api-examples/run.sh +++ b/kotlin-api-examples/run.sh @@ -285,6 +285,37 @@ function testPunctuation() { java -Djava.library.path=../build/lib -jar $out_filename } +function testOfflineSpeakerDiarization() { + if [ ! -f ./sherpa-onnx-pyannote-segmentation-3-0/model.onnx ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 + tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 + rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 + fi + + if [ ! -f ./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx + fi + + if [ ! -f ./0-four-speakers-zh.wav ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav + fi + + out_filename=test_offline_speaker_diarization.jar + kotlinc-jvm -include-runtime -d $out_filename \ + test_offline_speaker_diarization.kt \ + OfflineSpeakerDiarization.kt \ + Speaker.kt \ + OnlineStream.kt \ + WaveReader.kt \ + faked-asset-manager.kt \ + faked-log.kt + + ls -lh $out_filename + + java -Djava.library.path=../build/lib -jar $out_filename +} + +testOfflineSpeakerDiarization testSpeakerEmbeddingExtractor testOnlineAsr testTts diff --git a/kotlin-api-examples/test_offline_speaker_diarization.kt b/kotlin-api-examples/test_offline_speaker_diarization.kt new file mode 100644 index 000000000..96c33f062 --- /dev/null +++ b/kotlin-api-examples/test_offline_speaker_diarization.kt @@ -0,0 +1,53 @@ +package com.k2fsa.sherpa.onnx + +fun main() { + testOfflineSpeakerDiarization() +} + +fun callback(numProcessedChunks: Int, numTotalChunks: Int, arg: Long): Int { + val progress = numProcessedChunks.toFloat() / numTotalChunks * 100 + val s = "%.2f".format(progress) + println("Progress: ${s}%"); + + return 0 +} + +fun testOfflineSpeakerDiarization() { + var config = OfflineSpeakerDiarizationConfig( + segmentation=OfflineSpeakerSegmentationModelConfig( + pyannote=OfflineSpeakerSegmentationPyannoteModelConfig("./sherpa-onnx-pyannote-segmentation-3-0/model.onnx"), + ), + embedding=SpeakerEmbeddingExtractorConfig( + model="./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx", + ), + + // The test wave file ./0-four-speakers-zh.wav contains four speakers, so + // we use numClusters=4 here. If you don't know the number of speakers + // in the test wave file, please set the threshold like below. + // + // clustering=FastClusteringConfig(threshold=0.5), + // + // WARNING: You need to tune threshold by yourself. + // A larger threshold leads to fewer clusters, i.e., few speakers. + // A smaller threshold leads to more clusters, i.e., more speakers. + // + clustering=FastClusteringConfig(numClusters=4), + ) + + val sd = OfflineSpeakerDiarization(config=config) + + val waveData = WaveReader.readWave( + filename = "./0-four-speakers-zh.wav", + ) + + if (sd.sampleRate() != waveData.sampleRate) { + println("Expected sample rate: ${sd.sampleRate()}, given: ${waveData.sampleRate}") + return + } + + // val segments = sd.process(waveData.samples) // this one is also ok + val segments = sd.processWithCallback(waveData.samples, callback=::callback) + for (segment in segments) { + println("${segment.start} -- ${segment.end} speaker_${segment.speaker}") + } +} diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-result.h b/sherpa-onnx/csrc/offline-speaker-diarization-result.h index 5fb144f5c..6298a87c7 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization-result.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization-result.h @@ -58,7 +58,7 @@ class OfflineSpeakerDiarizationResult { std::vector> SortBySpeaker() const; - public: + private: std::vector segments_; }; diff --git a/sherpa-onnx/jni/CMakeLists.txt b/sherpa-onnx/jni/CMakeLists.txt index 998379084..23544c177 100644 --- a/sherpa-onnx/jni/CMakeLists.txt +++ b/sherpa-onnx/jni/CMakeLists.txt @@ -33,6 +33,12 @@ if(SHERPA_ONNX_ENABLE_TTS) ) endif() +if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION) + list(APPEND sources + offline-speaker-diarization.cc + ) +endif() + add_library(sherpa-onnx-jni ${sources}) target_compile_definitions(sherpa-onnx-jni PRIVATE SHERPA_ONNX_BUILD_SHARED_LIBS=1) diff --git a/sherpa-onnx/jni/offline-speaker-diarization.cc b/sherpa-onnx/jni/offline-speaker-diarization.cc new file mode 100644 index 000000000..a0eef8b9c --- /dev/null +++ b/sherpa-onnx/jni/offline-speaker-diarization.cc @@ -0,0 +1,219 @@ +// sherpa-onnx/jni/offline-speaker-diarization.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-speaker-diarization.h" + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/jni/common.h" + +namespace sherpa_onnx { + +static OfflineSpeakerDiarizationConfig GetOfflineSpeakerDiarizationConfig( + JNIEnv *env, jobject config) { + OfflineSpeakerDiarizationConfig ans; + + jclass cls = env->GetObjectClass(config); + jfieldID fid; + + //---------- segmentation ---------- + fid = env->GetFieldID( + cls, "segmentation", + "Lcom/k2fsa/sherpa/onnx/OfflineSpeakerSegmentationModelConfig;"); + jobject segmentation_config = env->GetObjectField(config, fid); + jclass segmentation_config_cls = env->GetObjectClass(segmentation_config); + + fid = env->GetFieldID( + segmentation_config_cls, "pyannote", + "Lcom/k2fsa/sherpa/onnx/OfflineSpeakerSegmentationPyannoteModelConfig;"); + jobject pyannote_config = env->GetObjectField(segmentation_config, fid); + jclass pyannote_config_cls = env->GetObjectClass(pyannote_config); + + fid = env->GetFieldID(pyannote_config_cls, "model", "Ljava/lang/String;"); + jstring s = (jstring)env->GetObjectField(pyannote_config, fid); + const char *p = env->GetStringUTFChars(s, nullptr); + ans.segmentation.pyannote.model = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(segmentation_config_cls, "numThreads", "I"); + ans.segmentation.num_threads = env->GetIntField(segmentation_config, fid); + + fid = env->GetFieldID(segmentation_config_cls, "debug", "Z"); + ans.segmentation.debug = env->GetBooleanField(segmentation_config, fid); + + fid = env->GetFieldID(segmentation_config_cls, "provider", + "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(segmentation_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.segmentation.provider = p; + env->ReleaseStringUTFChars(s, p); + + //---------- embedding ---------- + fid = env->GetFieldID( + cls, "embedding", + "Lcom/k2fsa/sherpa/onnx/SpeakerEmbeddingExtractorConfig;"); + jobject embedding_config = env->GetObjectField(config, fid); + jclass embedding_config_cls = env->GetObjectClass(embedding_config); + + fid = env->GetFieldID(embedding_config_cls, "model", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(embedding_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.embedding.model = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(embedding_config_cls, "numThreads", "I"); + ans.embedding.num_threads = env->GetIntField(embedding_config, fid); + + fid = env->GetFieldID(embedding_config_cls, "debug", "Z"); + ans.embedding.debug = env->GetBooleanField(embedding_config, fid); + + fid = env->GetFieldID(embedding_config_cls, "provider", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(embedding_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.embedding.provider = p; + env->ReleaseStringUTFChars(s, p); + + //---------- clustering ---------- + fid = env->GetFieldID(cls, "clustering", + "Lcom/k2fsa/sherpa/onnx/FastClusteringConfig;"); + jobject clustering_config = env->GetObjectField(config, fid); + jclass clustering_config_cls = env->GetObjectClass(clustering_config); + + fid = env->GetFieldID(clustering_config_cls, "numClusters", "I"); + ans.clustering.num_clusters = env->GetIntField(clustering_config, fid); + + fid = env->GetFieldID(clustering_config_cls, "threshold", "F"); + ans.clustering.threshold = env->GetFloatField(clustering_config, fid); + + // its own fields + fid = env->GetFieldID(cls, "minDurationOn", "F"); + ans.min_duration_on = env->GetFloatField(config, fid); + + fid = env->GetFieldID(cls, "minDurationOff", "F"); + ans.min_duration_off = env->GetFloatField(config, fid); + + return ans; +} + +} // namespace sherpa_onnx + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL +Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_newFromAsset( + JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) { + return 0; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL +Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_newFromFile( + JNIEnv *env, jobject /*obj*/, jobject _config) { + auto config = sherpa_onnx::GetOfflineSpeakerDiarizationConfig(env, _config); + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); + + if (!config.Validate()) { + SHERPA_ONNX_LOGE("Errors found in config!"); + return 0; + } + + auto sd = new sherpa_onnx::OfflineSpeakerDiarization(config); + + return (jlong)sd; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL +Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_setConfig( + JNIEnv *env, jobject /*obj*/, jlong ptr, jobject _config) { + auto config = sherpa_onnx::GetOfflineSpeakerDiarizationConfig(env, _config); + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); + + auto sd = reinterpret_cast(ptr); + sd->SetConfig(config); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL +Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_delete(JNIEnv * /*env*/, + jobject /*obj*/, + jlong ptr) { + delete reinterpret_cast(ptr); +} + +static jobjectArray ProcessImpl( + JNIEnv *env, + const std::vector + &segments) { + jclass cls = + env->FindClass("com/k2fsa/sherpa/onnx/OfflineSpeakerDiarizationSegment"); + + jobjectArray obj_arr = + (jobjectArray)env->NewObjectArray(segments.size(), cls, nullptr); + + jmethodID constructor = env->GetMethodID(cls, "", "(FFI)V"); + + for (int32_t i = 0; i != segments.size(); ++i) { + const auto &s = segments[i]; + jobject segment = + env->NewObject(cls, constructor, s.Start(), s.End(), s.Speaker()); + env->SetObjectArrayElement(obj_arr, i, segment); + } + + return obj_arr; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jobjectArray JNICALL +Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_process( + JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples) { + auto sd = reinterpret_cast(ptr); + + jfloat *p = env->GetFloatArrayElements(samples, nullptr); + jsize n = env->GetArrayLength(samples); + auto segments = sd->Process(p, n).SortByStartTime(); + env->ReleaseFloatArrayElements(samples, p, JNI_ABORT); + + return ProcessImpl(env, segments); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jobjectArray JNICALL +Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_processWithCallback( + JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples, + jobject callback, jlong arg) { + std::function callback_wrapper = + [env, callback](int32_t num_processed_chunks, int32_t num_total_chunks, + void *data) -> int { + jclass cls = env->GetObjectClass(callback); + + jmethodID mid = env->GetMethodID(cls, "invoke", "(IIJ)Ljava/lang/Integer;"); + if (mid == nullptr) { + SHERPA_ONNX_LOGE("Failed to get the callback. Ignore it."); + return 0; + } + + jobject ret = env->CallObjectMethod(callback, mid, num_processed_chunks, + num_total_chunks, (jlong)data); + jclass jklass = env->GetObjectClass(ret); + jmethodID int_value_mid = env->GetMethodID(jklass, "intValue", "()I"); + return env->CallIntMethod(ret, int_value_mid); + }; + + auto sd = reinterpret_cast(ptr); + + jfloat *p = env->GetFloatArrayElements(samples, nullptr); + jsize n = env->GetArrayLength(samples); + auto segments = + sd->Process(p, n, callback_wrapper, (void *)arg).SortByStartTime(); + env->ReleaseFloatArrayElements(samples, p, JNI_ABORT); + + return ProcessImpl(env, segments); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jint JNICALL +Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_getSampleRate( + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) { + return reinterpret_cast(ptr) + ->SampleRate(); +} diff --git a/sherpa-onnx/kotlin-api/OfflineSpeakerDiarization.kt b/sherpa-onnx/kotlin-api/OfflineSpeakerDiarization.kt new file mode 100644 index 000000000..de0a9dffd --- /dev/null +++ b/sherpa-onnx/kotlin-api/OfflineSpeakerDiarization.kt @@ -0,0 +1,101 @@ +package com.k2fsa.sherpa.onnx + +import android.content.res.AssetManager + +data class OfflineSpeakerSegmentationPyannoteModelConfig( + var model: String, +) + +data class OfflineSpeakerSegmentationModelConfig( + var pyannote: OfflineSpeakerSegmentationPyannoteModelConfig, + var numThreads: Int = 1, + var debug: Boolean = false, + var provider: String = "cpu", +) + +data class FastClusteringConfig( + var numClusters: Int = -1, + var threshold: Float = 0.5f, +) + +data class OfflineSpeakerDiarizationConfig( + var segmentation: OfflineSpeakerSegmentationModelConfig, + var embedding: SpeakerEmbeddingExtractorConfig, + var clustering: FastClusteringConfig, + var minDurationOn: Float = 0.2f, + var minDurationOff: Float = 0.5f, +) + +data class OfflineSpeakerDiarizationSegment( + val start: Float, // in seconds + val end: Float, // in seconds + val speaker: Int, // ID of the speaker; count from 0 +) + +class OfflineSpeakerDiarization( + assetManager: AssetManager? = null, + config: OfflineSpeakerDiarizationConfig, +) { + private var ptr: Long + + init { + ptr = if (assetManager != null) { + newFromAsset(assetManager, config) + } else { + newFromFile(config) + } + } + + protected fun finalize() { + if (ptr != 0L) { + delete(ptr) + ptr = 0 + } + } + + fun release() = finalize() + + // Only config.clustering is used. All other fields in config + // are ignored + fun setConfig(config: OfflineSpeakerDiarizationConfig) = setConfig(ptr, config) + + fun sampleRate() = getSampleRate(ptr) + + fun process(samples: FloatArray) = process(ptr, samples) + + fun processWithCallback( + samples: FloatArray, + callback: (numProcessedChunks: Int, numTotalChunks: Int, arg: Long) -> Int, + arg: Long = 0, + ) = processWithCallback(ptr, samples, callback, arg) + + private external fun delete(ptr: Long) + + private external fun newFromAsset( + assetManager: AssetManager, + config: OfflineSpeakerDiarizationConfig, + ): Long + + private external fun newFromFile( + config: OfflineSpeakerDiarizationConfig, + ): Long + + private external fun setConfig(ptr: Long, config: OfflineSpeakerDiarizationConfig) + + private external fun getSampleRate(ptr: Long): Int + + private external fun process(ptr: Long, samples: FloatArray): Array + + private external fun processWithCallback( + ptr: Long, + samples: FloatArray, + callback: (numProcessedChunks: Int, numTotalChunks: Int, arg: Long) -> Int, + arg: Long, + ): Array + + companion object { + init { + System.loadLibrary("sherpa-onnx-jni") + } + } +}