Skip to content

Commit

Permalink
C++ API for speaker diarization (#1396)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Oct 9, 2024
1 parent 70165cb commit 59407ed
Show file tree
Hide file tree
Showing 39 changed files with 1,652 additions and 108 deletions.
41 changes: 41 additions & 0 deletions .github/scripts/test-speaker-diarization.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#!/usr/bin/env bash

set -ex

log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

echo "EXE is $EXE"
echo "PATH: $PATH"

which $EXE

curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2

curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx

curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav

log "specify number of clusters"
$EXE \
--clustering.num-clusters=4 \
--segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \
--embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \
./0-four-speakers-zh.wav

log "specify threshold for clustering"

$EXE \
--clustering.cluster-threshold=0.90 \
--segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \
--embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \
./0-four-speakers-zh.wav

rm -rf sherpa-onnx-pyannote-*
rm -fv *.onnx
rm -fv *.wav
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
- name: Install pyannote
shell: bash
run: |
pip install pyannote.audio onnx onnxruntime
pip install pyannote.audio onnx==1.15.0 onnxruntime==1.16.3
- name: Run
shell: bash
Expand Down
11 changes: 11 additions & 0 deletions .github/workflows/linux.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ on:
- '.github/scripts/test-audio-tagging.sh'
- '.github/scripts/test-offline-punctuation.sh'
- '.github/scripts/test-online-punctuation.sh'
- '.github/scripts/test-speaker-diarization.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
Expand All @@ -38,6 +39,7 @@ on:
- '.github/scripts/test-audio-tagging.sh'
- '.github/scripts/test-offline-punctuation.sh'
- '.github/scripts/test-online-punctuation.sh'
- '.github/scripts/test-speaker-diarization.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
Expand Down Expand Up @@ -143,6 +145,15 @@ jobs:
name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
path: install/*

- name: Test offline speaker diarization
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline-speaker-diarization
.github/scripts/test-speaker-diarization.sh
- name: Test offline transducer
shell: bash
run: |
Expand Down
11 changes: 11 additions & 0 deletions .github/workflows/macos.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ on:
- '.github/scripts/test-audio-tagging.sh'
- '.github/scripts/test-offline-punctuation.sh'
- '.github/scripts/test-online-punctuation.sh'
- '.github/scripts/test-speaker-diarization.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
Expand All @@ -37,6 +38,7 @@ on:
- '.github/scripts/test-audio-tagging.sh'
- '.github/scripts/test-offline-punctuation.sh'
- '.github/scripts/test-online-punctuation.sh'
- '.github/scripts/test-speaker-diarization.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
Expand Down Expand Up @@ -115,6 +117,15 @@ jobs:
otool -L build/bin/sherpa-onnx
otool -l build/bin/sherpa-onnx
- name: Test offline speaker diarization
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline-speaker-diarization
.github/scripts/test-speaker-diarization.sh
- name: Test offline transducer
shell: bash
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/speaker-diarization.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ jobs:
curl -SL -O https://huggingface.co/csukuangfj/pyannote-models/resolve/main/segmentation-3.0/pytorch_model.bin
test_wavs=(
0-two-speakers-zh.wav
0-four-speakers-zh.wav
1-two-speakers-en.wav
2-two-speakers-en.wav
3-two-speakers-en.wav
Expand Down
11 changes: 11 additions & 0 deletions .github/workflows/windows-x64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ on:
- '.github/scripts/test-audio-tagging.sh'
- '.github/scripts/test-offline-punctuation.sh'
- '.github/scripts/test-online-punctuation.sh'
- '.github/scripts/test-speaker-diarization.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
Expand All @@ -34,6 +35,7 @@ on:
- '.github/scripts/test-audio-tagging.sh'
- '.github/scripts/test-offline-punctuation.sh'
- '.github/scripts/test-online-punctuation.sh'
- '.github/scripts/test-speaker-diarization.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
Expand Down Expand Up @@ -87,6 +89,15 @@ jobs:
name: release-windows-x64-${{ matrix.shared_lib }}-${{ matrix.with_tts }}
path: build/install/*

- name: Test offline speaker diarization
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline-speaker-diarization.exe
.github/scripts/test-speaker-diarization.sh
- name: Test online punctuation
shell: bash
run: |
Expand Down
11 changes: 11 additions & 0 deletions .github/workflows/windows-x86.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ on:
- '.github/scripts/test-audio-tagging.sh'
- '.github/scripts/test-offline-punctuation.sh'
- '.github/scripts/test-online-punctuation.sh'
- '.github/scripts/test-speaker-diarization.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
Expand All @@ -34,6 +35,7 @@ on:
- '.github/scripts/test-audio-tagging.sh'
- '.github/scripts/test-offline-punctuation.sh'
- '.github/scripts/test-online-punctuation.sh'
- '.github/scripts/test-speaker-diarization.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
Expand Down Expand Up @@ -87,6 +89,15 @@ jobs:
name: release-windows-x86-${{ matrix.shared_lib }}-${{ matrix.with_tts }}
path: build/install/*

- name: Test offline speaker diarization
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline-speaker-diarization.exe
.github/scripts/test-speaker-diarization.sh
- name: Test online punctuation
shell: bash
run: |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) {
fprintf(stderr, "Memory error\n");
return -1;
}
size_t read_bytes = fread(*buffer_out, 1, size, file);
size_t read_bytes = fread((void *)*buffer_out, 1, size, file);
if (read_bytes != size) {
printf("Errors occured in reading the file %s\n", filename);
free((void *)*buffer_out);
Expand Down
2 changes: 1 addition & 1 deletion c-api-examples/streaming-ctc-buffered-tokens-c-api.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) {
fprintf(stderr, "Memory error\n");
return -1;
}
size_t read_bytes = fread(*buffer_out, 1, size, file);
size_t read_bytes = fread((void *)*buffer_out, 1, size, file);
if (read_bytes != size) {
printf("Errors occured in reading the file %s\n", filename);
free((void *)*buffer_out);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) {
fprintf(stderr, "Memory error\n");
return -1;
}
size_t read_bytes = fread(*buffer_out, 1, size, file);
size_t read_bytes = fread((void *)*buffer_out, 1, size, file);
if (read_bytes != size) {
printf("Errors occured in reading the file %s\n", filename);
free((void *)*buffer_out);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) {
fprintf(stderr, "Memory error\n");
return -1;
}
size_t read_bytes = fread(*buffer_out, 1, size, file);
size_t read_bytes = fread((void *)*buffer_out, 1, size, file);
if (read_bytes != size) {
printf("Errors occured in reading the file %s\n", filename);
free((void *)*buffer_out);
Expand Down
1 change: 1 addition & 0 deletions cmake/cmake_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def get_binaries():
"sherpa-onnx-offline-audio-tagging",
"sherpa-onnx-offline-language-identification",
"sherpa-onnx-offline-punctuation",
"sherpa-onnx-offline-speaker-diarization",
"sherpa-onnx-offline-tts",
"sherpa-onnx-offline-tts-play",
"sherpa-onnx-offline-websocket-server",
Expand Down
9 changes: 3 additions & 6 deletions scripts/pyannote/segmentation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,9 @@
Please download test wave files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models

## 0-two-speakers-zh.wav
## 0-four-speakers-zh.wav

This file is from
https://www.modelscope.cn/models/iic/speech_campplus_speaker-diarization_common/file/view/master?fileName=examples%252F2speakers_example.wav&status=0

Note that we have renamed it from `2speakers_example.wav` to `0-two-speakers-zh.wav`.
It is recorded by @csukuangfj

## 1-two-speakers-en.wav

Expand Down Expand Up @@ -40,5 +37,5 @@ commands to convert it to `3-two-speakers-en.wav`


```bash
sox ML16091-Audio.mp3 3-two-speakers-en.wav
sox ML16091-Audio.mp3 -r 16k 3-two-speakers-en.wav
```
2 changes: 1 addition & 1 deletion scripts/pyannote/segmentation/export-onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def main():
model.receptive_field.duration * 16000
)

opset_version = 18
opset_version = 13

filename = "model.onnx"
torch.onnx.export(
Expand Down
16 changes: 16 additions & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,12 @@ if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
list(APPEND sources
fast-clustering-config.cc
fast-clustering.cc
offline-speaker-diarization-impl.cc
offline-speaker-diarization-result.cc
offline-speaker-diarization.cc
offline-speaker-segmentation-model-config.cc
offline-speaker-segmentation-pyannote-model-config.cc
offline-speaker-segmentation-pyannote-model.cc
)
endif()

Expand Down Expand Up @@ -260,6 +266,10 @@ if(SHERPA_ONNX_ENABLE_BINARY)
add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
endif()

if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
add_executable(sherpa-onnx-offline-speaker-diarization sherpa-onnx-offline-speaker-diarization.cc)
endif()

set(main_exes
sherpa-onnx
sherpa-onnx-keyword-spotter
Expand All @@ -276,6 +286,12 @@ if(SHERPA_ONNX_ENABLE_BINARY)
)
endif()

if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
list(APPEND main_exes
sherpa-onnx-offline-speaker-diarization
)
endif()

foreach(exe IN LISTS main_exes)
target_link_libraries(${exe} sherpa-onnx-core)
endforeach()
Expand Down
22 changes: 10 additions & 12 deletions sherpa-onnx/csrc/fast-clustering-config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,16 @@ std::string FastClusteringConfig::ToString() const {
}

void FastClusteringConfig::Register(ParseOptions *po) {
std::string prefix = "ctc";
ParseOptions p(prefix, po);

p.Register("num-clusters", &num_clusters,
"Number of cluster. If greater than 0, then --cluster-thresold is "
"ignored. Please provide it if you know the actual number of "
"clusters in advance.");

p.Register("cluster-threshold", &threshold,
"If --num-clusters is not specified, then it specifies the "
"distance threshold for clustering. smaller value -> more "
"clusters. larger value -> fewer clusters");
po->Register(
"num-clusters", &num_clusters,
"Number of cluster. If greater than 0, then cluster threshold is "
"ignored. Please provide it if you know the actual number of "
"clusters in advance.");

po->Register("cluster-threshold", &threshold,
"If num_clusters is not specified, then it specifies the "
"distance threshold for clustering. smaller value -> more "
"clusters. larger value -> fewer clusters");
}

bool FastClusteringConfig::Validate() const {
Expand Down
3 changes: 3 additions & 0 deletions sherpa-onnx/csrc/macros.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#ifndef SHERPA_ONNX_CSRC_MACROS_H_
#define SHERPA_ONNX_CSRC_MACROS_H_
#include <stdio.h>
#include <stdlib.h>

#if __ANDROID_API__ >= 8
#include "android/log.h"
Expand Down Expand Up @@ -169,4 +170,6 @@
} \
} while (0)

#define SHERPA_ONNX_EXIT(code) exit(code)

#endif // SHERPA_ONNX_CSRC_MACROS_H_
1 change: 1 addition & 0 deletions sherpa-onnx/csrc/offline-sense-voice-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <utility>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"

Expand Down
26 changes: 26 additions & 0 deletions sherpa-onnx/csrc/offline-speaker-diarization-impl.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// sherpa-onnx/csrc/offline-speaker-diarization-impl.cc
//
// Copyright (c) 2024 Xiaomi Corporation

#include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h"

#include <memory>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h"

namespace sherpa_onnx {

std::unique_ptr<OfflineSpeakerDiarizationImpl>
OfflineSpeakerDiarizationImpl::Create(
const OfflineSpeakerDiarizationConfig &config) {
if (!config.segmentation.pyannote.model.empty()) {
return std::make_unique<OfflineSpeakerDiarizationPyannoteImpl>(config);
}

SHERPA_ONNX_LOGE("Please specify a speaker segmentation model.");

return nullptr;
}

} // namespace sherpa_onnx
Loading

0 comments on commit 59407ed

Please sign in to comment.