Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Java API for speaker diarization #1416

Merged
merged 1 commit into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .github/workflows/run-java-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,13 @@ jobs:
make -j4
ls -lh lib

- name: Run java test (speaker diarization)
shell: bash
run: |
cd ./java-api-examples
./run-offline-speaker-diarization.sh
rm -rfv *.onnx *.wav sherpa-onnx-pyannote-*

- name: Run java test (kws)
shell: bash
run: |
Expand Down
99 changes: 99 additions & 0 deletions java-api-examples/OfflineSpeakerDiarizationDemo.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// Copyright 2024 Xiaomi Corporation

// This file shows how to use the sherpa-onnx Java API for speaker diarization.
import com.k2fsa.sherpa.onnx.*;

/**
 * Command-line demo that runs offline (non-streaming) speaker diarization on a
 * wave file and prints one line per detected segment.
 *
 * <p>The model and wave file paths are hard-coded relative to the current
 * working directory; see the comment at the top of {@link #main} for how to
 * download them.
 */
public class OfflineSpeakerDiarizationDemo {
  /**
   * Builds the diarization config, processes the test wave file while printing
   * progress, and prints the resulting segments.
   *
   * @param args unused
   */
  public static void main(String[] args) {
    /* Please use the following commands to download files used in this file
    Step 1: Download a speaker segmentation model

    Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
    for a list of available models. The following is an example

    wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
    tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
    rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2

    Step 2: Download a speaker embedding extractor model

    Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
    for a list of available models. The following is an example

    wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx

    Step 3: Download test wave files

    Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
    for a list of available test wave files. The following is an example

    wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav

    Step 4: Run it
    */

    String segmentationModel = "./sherpa-onnx-pyannote-segmentation-3-0/model.onnx";
    String embeddingModel = "./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx";
    String waveFilename = "./0-four-speakers-zh.wav";

    WaveReader reader = new WaveReader(waveFilename);

    OfflineSpeakerSegmentationPyannoteModelConfig pyannote =
        OfflineSpeakerSegmentationPyannoteModelConfig.builder().setModel(segmentationModel).build();

    OfflineSpeakerSegmentationModelConfig segmentation =
        OfflineSpeakerSegmentationModelConfig.builder()
            .setPyannote(pyannote)
            .setDebug(true)
            .build();

    SpeakerEmbeddingExtractorConfig embedding =
        SpeakerEmbeddingExtractorConfig.builder().setModel(embeddingModel).setDebug(true).build();

    // The test wave file ./0-four-speakers-zh.wav contains four speakers, so
    // we use numClusters=4 here. If you don't know the number of speakers
    // in the test wave file, please set the numClusters to -1 and provide
    // threshold for clustering
    FastClusteringConfig clustering =
        FastClusteringConfig.builder()
            .setNumClusters(4) // set it to -1 if you don't know the actual number
            .setThreshold(0.5f)
            .build();

    OfflineSpeakerDiarizationConfig config =
        OfflineSpeakerDiarizationConfig.builder()
            .setSegmentation(segmentation)
            .setEmbedding(embedding)
            .setClustering(clustering)
            .setMinDurationOn(0.2f)
            .setMinDurationOff(0.5f)
            .build();

    OfflineSpeakerDiarization sd = new OfflineSpeakerDiarization(config);
    // try/finally so the native handle is released on every exit path,
    // including the early return below and any exception during processing.
    try {
      if (sd.getSampleRate() != reader.getSampleRate()) {
        System.out.printf(
            "Expected sample rate: %d, given: %d\n", sd.getSampleRate(), reader.getSampleRate());
        return;
      }

      // OfflineSpeakerDiarizationSegment[] segments = sd.process(reader.getSamples());
      // without callback is also ok

      // or you can use a callback to show the progress
      OfflineSpeakerDiarizationSegment[] segments =
          sd.processWithCallback(
              reader.getSamples(),
              (int numProcessedChunks, int numTotalChunks, long arg) -> {
                // 100.0f forces float arithmetic, so this is not integer division.
                float progress = 100.0f * numProcessedChunks / numTotalChunks;
                System.out.printf("Progress: %.2f%%\n", progress);

                return 0;
              });

      for (OfflineSpeakerDiarizationSegment s : segments) {
        System.out.printf("%.3f -- %.3f speaker_%02d\n", s.getStart(), s.getEnd(), s.getSpeaker());
      }
    } finally {
      sd.release();
    }
  }
}
6 changes: 6 additions & 0 deletions java-api-examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ This directory contains examples for the JAVA API of sherpa-onnx.

# Usage

## Non-streaming speaker diarization

```bash
./run-offline-speaker-diarization.sh
```

## Streaming Speech recognition

```
Expand Down
45 changes: 45 additions & 0 deletions java-api-examples/run-offline-speaker-diarization.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#!/usr/bin/env bash
#
# Runs OfflineSpeakerDiarizationDemo.java.
#
# Must be run from the java-api-examples directory (all paths are relative).
# Each prerequisite is built or downloaded only if it is not already present:
#   1. the JNI shared library under ../build/lib
#   2. ../sherpa-onnx/java-api/build/sherpa-onnx.jar
#   3. the segmentation model, the embedding model and the test wave file

# -e: abort on the first failing command; -x: echo each command.
set -ex

# 1. Build the JNI library (.dylib or .so) if neither exists.
if [[ ! -f ../build/lib/libsherpa-onnx-jni.dylib && ! -f ../build/lib/libsherpa-onnx-jni.so ]]; then
  mkdir -p ../build
  pushd ../build
  cmake \
    -DSHERPA_ONNX_ENABLE_PYTHON=OFF \
    -DSHERPA_ONNX_ENABLE_TESTS=OFF \
    -DSHERPA_ONNX_ENABLE_CHECK=OFF \
    -DBUILD_SHARED_LIBS=ON \
    -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \
    -DSHERPA_ONNX_ENABLE_JNI=ON \
    ..

  make -j4
  ls -lh lib
  popd
fi

# 2. Build the Java API jar if it does not exist.
if [ ! -f ../sherpa-onnx/java-api/build/sherpa-onnx.jar ]; then
  pushd ../sherpa-onnx/java-api
  make
  popd
fi

# 3a. Speaker segmentation model.
if [ ! -f ./sherpa-onnx-pyannote-segmentation-3-0/model.onnx ]; then
  curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
  tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
  rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
fi

# 3b. Speaker embedding extractor model.
if [ ! -f ./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx ]; then
  curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
fi

# 3c. Test wave file.
if [ ! -f ./0-four-speakers-zh.wav ]; then
  curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav
fi

# Run the demo straight from its source file. "$PWD" is quoted so that a
# checkout path containing spaces is not split into several arguments.
java \
  -Djava.library.path="$PWD/../build/lib" \
  -cp ../sherpa-onnx/java-api/build/sherpa-onnx.jar \
  ./OfflineSpeakerDiarizationDemo.java
9 changes: 9 additions & 0 deletions sherpa-onnx/java-api/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,15 @@ java_files += KeywordSpotterConfig.java
java_files += KeywordSpotterResult.java
java_files += KeywordSpotter.java

java_files += OfflineSpeakerSegmentationPyannoteModelConfig.java
java_files += OfflineSpeakerSegmentationModelConfig.java
java_files += FastClusteringConfig.java
java_files += OfflineSpeakerDiarizationConfig.java
java_files += OfflineSpeakerDiarizationSegment.java
java_files += OfflineSpeakerDiarizationCallback.java
java_files += OfflineSpeakerDiarization.java


class_files := $(java_files:%.java=%.class)

java_files := $(addprefix src/$(package_dir)/,$(java_files))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright 2024 Xiaomi Corporation

package com.k2fsa.sherpa.onnx;

/**
 * Immutable clustering settings: a cluster count and a threshold.
 *
 * <p>Instances are created through {@link #builder()}. When nothing is set,
 * {@code numClusters} is {@code -1} and {@code threshold} is {@code 0.5f}.
 */
public class FastClusteringConfig {
  private final int numClusters;
  private final float threshold;

  /** Copies the values currently held by {@code source}. */
  private FastClusteringConfig(Builder source) {
    numClusters = source.numClusters;
    threshold = source.threshold;
  }

  /** Returns a new builder pre-populated with the default values. */
  public static Builder builder() {
    return new Builder();
  }

  /** Returns the configured number of clusters ({@code -1} by default). */
  public int getNumClusters() {
    return this.numClusters;
  }

  /** Returns the configured clustering threshold ({@code 0.5f} by default). */
  public float getThreshold() {
    return this.threshold;
  }

  /** Fluent, mutable builder for {@link FastClusteringConfig}. */
  public static class Builder {
    private int numClusters = -1;
    private float threshold = 0.5f;

    /** Sets the number of clusters and returns this builder for chaining. */
    public Builder setNumClusters(int numClusters) {
      this.numClusters = numClusters;
      return this;
    }

    /** Sets the clustering threshold and returns this builder for chaining. */
    public Builder setThreshold(float threshold) {
      this.threshold = threshold;
      return this;
    }

    /** Creates a config from the values set so far. */
    public FastClusteringConfig build() {
      return new FastClusteringConfig(this);
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Copyright 2024 Xiaomi Corporation

package com.k2fsa.sherpa.onnx;

/**
 * Java wrapper around a native offline speaker diarization object from the
 * {@code sherpa-onnx-jni} library.
 *
 * <p>The native object is referenced through an opaque {@code long} handle.
 * Call {@link #release()} once the instance is no longer needed; after that,
 * every other method throws {@link IllegalStateException} instead of passing
 * a zero handle to native code.
 */
public class OfflineSpeakerDiarization {
  static {
    System.loadLibrary("sherpa-onnx-jni");
  }

  // Opaque native handle; 0 means "no native object".
  private long ptr = 0;

  /**
   * Creates the native object from {@code config}.
   *
   * @param config the diarization configuration handed to native code
   */
  public OfflineSpeakerDiarization(OfflineSpeakerDiarizationConfig config) {
    ptr = newFromFile(config);
  }

  /**
   * Returns the sample rate reported by the native object.
   *
   * @throws IllegalStateException if the native handle is 0
   */
  public int getSampleRate() {
    return getSampleRate(handle());
  }

  /**
   * Updates the configuration of the native object.
   *
   * <p>Only config.clustering is used. All other fields are ignored.
   *
   * @throws IllegalStateException if the native handle is 0
   */
  public void setConfig(OfflineSpeakerDiarizationConfig config) {
    setConfig(handle(), config);
  }

  /**
   * Runs diarization on {@code samples} and returns the resulting segments.
   *
   * @param samples audio samples (NOTE(review): presumably mono floats at
   *     {@link #getSampleRate()} Hz -- confirm against the native side)
   * @throws IllegalStateException if the native handle is 0
   */
  public OfflineSpeakerDiarizationSegment[] process(float[] samples) {
    return process(handle(), samples);
  }

  /**
   * Same as {@link #processWithCallback(float[], OfflineSpeakerDiarizationCallback, long)}
   * with {@code arg} set to 0.
   *
   * @throws IllegalStateException if the native handle is 0
   */
  public OfflineSpeakerDiarizationSegment[] processWithCallback(float[] samples, OfflineSpeakerDiarizationCallback callback) {
    return processWithCallback(handle(), samples, callback, 0);
  }

  /**
   * Runs diarization on {@code samples}, handing {@code callback} and
   * {@code arg} to native code.
   *
   * @param samples audio samples to process
   * @param callback callback passed to native code
   * @param arg opaque value passed to native code alongside the callback
   * @throws IllegalStateException if the native handle is 0
   */
  public OfflineSpeakerDiarizationSegment[] processWithCallback(float[] samples, OfflineSpeakerDiarizationCallback callback, long arg) {
    return processWithCallback(handle(), samples, callback, arg);
  }

  /** Safety net only: releases the native object if the caller forgot to. */
  @Override
  protected void finalize() throws Throwable {
    try {
      release();
    } finally {
      super.finalize();
    }
  }

  /**
   * Deletes the native object. Safe to call more than once; later calls are
   * no-ops. Prefer calling this explicitly over relying on finalization.
   */
  public void release() {
    if (this.ptr == 0) {
      return;
    }
    delete(this.ptr);
    this.ptr = 0;
  }

  /** Returns the native handle, or throws if there is no native object. */
  private long handle() {
    if (this.ptr == 0) {
      throw new IllegalStateException(
          "OfflineSpeakerDiarization has no native object"
              + " (creation failed or release() was already called)");
    }
    return this.ptr;
  }

  // Native method names and signatures are left exactly as they were, since
  // the JNI library is expected to bind to them.
  private native int getSampleRate(long ptr);

  private native void delete(long ptr);

  private native long newFromFile(OfflineSpeakerDiarizationConfig config);

  private native void setConfig(long ptr, OfflineSpeakerDiarizationConfig config);

  private native OfflineSpeakerDiarizationSegment[] process(long ptr, float[] samples);

  private native OfflineSpeakerDiarizationSegment[] processWithCallback(long ptr, float[] samples, OfflineSpeakerDiarizationCallback callback, long arg);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// Copyright 2024 Xiaomi Corporation

package com.k2fsa.sherpa.onnx;

/**
 * Callback that receives chunk-progress counts while speaker diarization is
 * running.
 *
 * <p>Being a functional interface, it can be supplied as a lambda, e.g.
 * {@code (numProcessedChunks, numTotalChunks, arg) -> 0}.
 */
@FunctionalInterface
public interface OfflineSpeakerDiarizationCallback {
  /**
   * Called with the current progress.
   *
   * @param numProcessedChunks number of chunks processed so far
   * @param numTotalChunks total number of chunks to process
   * @param arg opaque value (NOTE(review): presumably the {@code arg} the
   *     caller supplied when starting processing -- confirm against the
   *     native side)
   * @return a status value (NOTE(review): how the native side interprets it
   *     is not visible from this interface -- confirm before relying on it)
   */
  Integer invoke(int numProcessedChunks, int numTotalChunks, long arg);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package com.k2fsa.sherpa.onnx;

/**
 * Immutable configuration for offline speaker diarization.
 *
 * <p>It bundles the segmentation model config, the speaker embedding extractor
 * config, the clustering config and two duration settings
 * ({@code minDurationOn}, {@code minDurationOff}). Instances are created
 * through {@link #builder()}.
 */
public class OfflineSpeakerDiarizationConfig {
  private final OfflineSpeakerSegmentationModelConfig segmentation;
  private final SpeakerEmbeddingExtractorConfig embedding;
  private final FastClusteringConfig clustering;
  private final float minDurationOn;
  private final float minDurationOff;

  /** Copies every value currently held by {@code source}. */
  private OfflineSpeakerDiarizationConfig(Builder source) {
    segmentation = source.segmentation;
    embedding = source.embedding;
    clustering = source.clustering;
    minDurationOn = source.minDurationOn;
    minDurationOff = source.minDurationOff;
  }

  /** Returns a new builder pre-populated with default values. */
  public static Builder builder() {
    return new Builder();
  }

  /** Returns the speaker segmentation model configuration. */
  public OfflineSpeakerSegmentationModelConfig getSegmentation() {
    return this.segmentation;
  }

  /** Returns the speaker embedding extractor configuration. */
  public SpeakerEmbeddingExtractorConfig getEmbedding() {
    return this.embedding;
  }

  /** Returns the clustering configuration. */
  public FastClusteringConfig getClustering() {
    return this.clustering;
  }

  /** Returns the {@code minDurationOn} setting ({@code 0.2f} by default). */
  public float getMinDurationOn() {
    return this.minDurationOn;
  }

  /** Returns the {@code minDurationOff} setting ({@code 0.5f} by default). */
  public float getMinDurationOff() {
    return this.minDurationOff;
  }

  /**
   * Fluent, mutable builder for {@link OfflineSpeakerDiarizationConfig}.
   *
   * <p>Each nested config defaults to what its own builder produces when
   * nothing is set on it.
   */
  public static class Builder {
    private OfflineSpeakerSegmentationModelConfig segmentation =
        OfflineSpeakerSegmentationModelConfig.builder().build();
    private SpeakerEmbeddingExtractorConfig embedding =
        SpeakerEmbeddingExtractorConfig.builder().build();
    private FastClusteringConfig clustering = FastClusteringConfig.builder().build();
    private float minDurationOn = 0.2f;
    private float minDurationOff = 0.5f;

    /** Sets the segmentation model config and returns this builder. */
    public Builder setSegmentation(OfflineSpeakerSegmentationModelConfig segmentation) {
      this.segmentation = segmentation;
      return this;
    }

    /** Sets the embedding extractor config and returns this builder. */
    public Builder setEmbedding(SpeakerEmbeddingExtractorConfig embedding) {
      this.embedding = embedding;
      return this;
    }

    /** Sets the clustering config and returns this builder. */
    public Builder setClustering(FastClusteringConfig clustering) {
      this.clustering = clustering;
      return this;
    }

    /** Sets {@code minDurationOn} and returns this builder. */
    public Builder setMinDurationOn(float minDurationOn) {
      this.minDurationOn = minDurationOn;
      return this;
    }

    /** Sets {@code minDurationOff} and returns this builder. */
    public Builder setMinDurationOff(float minDurationOff) {
      this.minDurationOff = minDurationOff;
      return this;
    }

    /** Creates a config from the values set so far. */
    public OfflineSpeakerDiarizationConfig build() {
      return new OfflineSpeakerDiarizationConfig(this);
    }
  }

}
Loading