Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

C++ API for speaker diarization #1396

Merged
merged 17 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions .github/scripts/test-speaker-diarization.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#!/usr/bin/env bash

set -ex

log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

echo "EXE is $EXE"
echo "PATH: $PATH"

which $EXE

curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2

curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx

curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav

log "specify number of clusters"
$EXE \
--clustering.num-clusters=4 \
--segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \
--embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \
./0-four-speakers-zh.wav

log "specify threshold for clustering"

$EXE \
--clustering.cluster-threshold=0.90 \
--segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \
--embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \
./0-four-speakers-zh.wav

rm -rf sherpa-onnx-pyannote-*
rm -fv *.onnx
rm -fv *.wav
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
- name: Install pyannote
shell: bash
run: |
pip install pyannote.audio onnx onnxruntime
pip install pyannote.audio onnx==1.15.0 onnxruntime==1.16.3

- name: Run
shell: bash
Expand Down
11 changes: 11 additions & 0 deletions .github/workflows/linux.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ on:
- '.github/scripts/test-audio-tagging.sh'
- '.github/scripts/test-offline-punctuation.sh'
- '.github/scripts/test-online-punctuation.sh'
- '.github/scripts/test-speaker-diarization.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
Expand All @@ -38,6 +39,7 @@ on:
- '.github/scripts/test-audio-tagging.sh'
- '.github/scripts/test-offline-punctuation.sh'
- '.github/scripts/test-online-punctuation.sh'
- '.github/scripts/test-speaker-diarization.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
Expand Down Expand Up @@ -143,6 +145,15 @@ jobs:
name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
path: install/*

- name: Test offline speaker diarization
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline-speaker-diarization

.github/scripts/test-speaker-diarization.sh

- name: Test offline transducer
shell: bash
run: |
Expand Down
11 changes: 11 additions & 0 deletions .github/workflows/macos.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ on:
- '.github/scripts/test-audio-tagging.sh'
- '.github/scripts/test-offline-punctuation.sh'
- '.github/scripts/test-online-punctuation.sh'
- '.github/scripts/test-speaker-diarization.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
Expand All @@ -37,6 +38,7 @@ on:
- '.github/scripts/test-audio-tagging.sh'
- '.github/scripts/test-offline-punctuation.sh'
- '.github/scripts/test-online-punctuation.sh'
- '.github/scripts/test-speaker-diarization.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
Expand Down Expand Up @@ -115,6 +117,15 @@ jobs:
otool -L build/bin/sherpa-onnx
otool -l build/bin/sherpa-onnx

- name: Test offline speaker diarization
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline-speaker-diarization

.github/scripts/test-speaker-diarization.sh

- name: Test offline transducer
shell: bash
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/speaker-diarization.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ jobs:
curl -SL -O https://huggingface.co/csukuangfj/pyannote-models/resolve/main/segmentation-3.0/pytorch_model.bin

test_wavs=(
0-two-speakers-zh.wav
0-four-speakers-zh.wav
1-two-speakers-en.wav
2-two-speakers-en.wav
3-two-speakers-en.wav
Expand Down
11 changes: 11 additions & 0 deletions .github/workflows/windows-x64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ on:
- '.github/scripts/test-audio-tagging.sh'
- '.github/scripts/test-offline-punctuation.sh'
- '.github/scripts/test-online-punctuation.sh'
- '.github/scripts/test-speaker-diarization.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
Expand All @@ -34,6 +35,7 @@ on:
- '.github/scripts/test-audio-tagging.sh'
- '.github/scripts/test-offline-punctuation.sh'
- '.github/scripts/test-online-punctuation.sh'
- '.github/scripts/test-speaker-diarization.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
Expand Down Expand Up @@ -87,6 +89,15 @@ jobs:
name: release-windows-x64-${{ matrix.shared_lib }}-${{ matrix.with_tts }}
path: build/install/*

- name: Test offline speaker diarization
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline-speaker-diarization.exe

.github/scripts/test-speaker-diarization.sh

- name: Test online punctuation
shell: bash
run: |
Expand Down
11 changes: 11 additions & 0 deletions .github/workflows/windows-x86.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ on:
- '.github/scripts/test-audio-tagging.sh'
- '.github/scripts/test-offline-punctuation.sh'
- '.github/scripts/test-online-punctuation.sh'
- '.github/scripts/test-speaker-diarization.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
Expand All @@ -34,6 +35,7 @@ on:
- '.github/scripts/test-audio-tagging.sh'
- '.github/scripts/test-offline-punctuation.sh'
- '.github/scripts/test-online-punctuation.sh'
- '.github/scripts/test-speaker-diarization.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
Expand Down Expand Up @@ -87,6 +89,15 @@ jobs:
name: release-windows-x86-${{ matrix.shared_lib }}-${{ matrix.with_tts }}
path: build/install/*

- name: Test offline speaker diarization
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline-speaker-diarization.exe

.github/scripts/test-speaker-diarization.sh

- name: Test online punctuation
shell: bash
run: |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) {
fprintf(stderr, "Memory error\n");
return -1;
}
size_t read_bytes = fread(*buffer_out, 1, size, file);
size_t read_bytes = fread((void *)*buffer_out, 1, size, file);
if (read_bytes != size) {
printf("Errors occured in reading the file %s\n", filename);
free((void *)*buffer_out);
Expand Down
2 changes: 1 addition & 1 deletion c-api-examples/streaming-ctc-buffered-tokens-c-api.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) {
fprintf(stderr, "Memory error\n");
return -1;
}
size_t read_bytes = fread(*buffer_out, 1, size, file);
size_t read_bytes = fread((void *)*buffer_out, 1, size, file);
if (read_bytes != size) {
printf("Errors occured in reading the file %s\n", filename);
free((void *)*buffer_out);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) {
fprintf(stderr, "Memory error\n");
return -1;
}
size_t read_bytes = fread(*buffer_out, 1, size, file);
size_t read_bytes = fread((void *)*buffer_out, 1, size, file);
if (read_bytes != size) {
printf("Errors occured in reading the file %s\n", filename);
free((void *)*buffer_out);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) {
fprintf(stderr, "Memory error\n");
return -1;
}
size_t read_bytes = fread(*buffer_out, 1, size, file);
size_t read_bytes = fread((void *)*buffer_out, 1, size, file);
if (read_bytes != size) {
printf("Errors occured in reading the file %s\n", filename);
free((void *)*buffer_out);
Expand Down
1 change: 1 addition & 0 deletions cmake/cmake_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def get_binaries():
"sherpa-onnx-offline-audio-tagging",
"sherpa-onnx-offline-language-identification",
"sherpa-onnx-offline-punctuation",
"sherpa-onnx-offline-speaker-diarization",
"sherpa-onnx-offline-tts",
"sherpa-onnx-offline-tts-play",
"sherpa-onnx-offline-websocket-server",
Expand Down
9 changes: 3 additions & 6 deletions scripts/pyannote/segmentation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,9 @@
Please download test wave files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models

## 0-two-speakers-zh.wav
## 0-four-speakers-zh.wav

This file is from
https://www.modelscope.cn/models/iic/speech_campplus_speaker-diarization_common/file/view/master?fileName=examples%252F2speakers_example.wav&status=0

Note that we have renamed it from `2speakers_example.wav` to `0-two-speakers-zh.wav`.
It is recorded by @csukuangfj

## 1-two-speakers-en.wav

Expand Down Expand Up @@ -40,5 +37,5 @@ commands to convert it to `3-two-speakers-en.wav`


```bash
sox ML16091-Audio.mp3 3-two-speakers-en.wav
sox ML16091-Audio.mp3 -r 16k 3-two-speakers-en.wav
```
2 changes: 1 addition & 1 deletion scripts/pyannote/segmentation/export-onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def main():
model.receptive_field.duration * 16000
)

opset_version = 18
opset_version = 13

filename = "model.onnx"
torch.onnx.export(
Expand Down
16 changes: 16 additions & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,12 @@ if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
list(APPEND sources
fast-clustering-config.cc
fast-clustering.cc
offline-speaker-diarization-impl.cc
offline-speaker-diarization-result.cc
offline-speaker-diarization.cc
offline-speaker-segmentation-model-config.cc
offline-speaker-segmentation-pyannote-model-config.cc
offline-speaker-segmentation-pyannote-model.cc
)
endif()

Expand Down Expand Up @@ -260,6 +266,10 @@ if(SHERPA_ONNX_ENABLE_BINARY)
add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
endif()

if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
add_executable(sherpa-onnx-offline-speaker-diarization sherpa-onnx-offline-speaker-diarization.cc)
endif()

set(main_exes
sherpa-onnx
sherpa-onnx-keyword-spotter
Expand All @@ -276,6 +286,12 @@ if(SHERPA_ONNX_ENABLE_BINARY)
)
endif()

if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
list(APPEND main_exes
sherpa-onnx-offline-speaker-diarization
)
endif()

foreach(exe IN LISTS main_exes)
target_link_libraries(${exe} sherpa-onnx-core)
endforeach()
Expand Down
22 changes: 10 additions & 12 deletions sherpa-onnx/csrc/fast-clustering-config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,16 @@ std::string FastClusteringConfig::ToString() const {
}

void FastClusteringConfig::Register(ParseOptions *po) {
std::string prefix = "ctc";
ParseOptions p(prefix, po);

p.Register("num-clusters", &num_clusters,
"Number of cluster. If greater than 0, then --cluster-thresold is "
"ignored. Please provide it if you know the actual number of "
"clusters in advance.");

p.Register("cluster-threshold", &threshold,
"If --num-clusters is not specified, then it specifies the "
"distance threshold for clustering. smaller value -> more "
"clusters. larger value -> fewer clusters");
po->Register(
"num-clusters", &num_clusters,
"Number of cluster. If greater than 0, then cluster threshold is "
"ignored. Please provide it if you know the actual number of "
"clusters in advance.");

po->Register("cluster-threshold", &threshold,
"If num_clusters is not specified, then it specifies the "
"distance threshold for clustering. smaller value -> more "
"clusters. larger value -> fewer clusters");
}

bool FastClusteringConfig::Validate() const {
Expand Down
3 changes: 3 additions & 0 deletions sherpa-onnx/csrc/macros.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#ifndef SHERPA_ONNX_CSRC_MACROS_H_
#define SHERPA_ONNX_CSRC_MACROS_H_
#include <stdio.h>
#include <stdlib.h>

#if __ANDROID_API__ >= 8
#include "android/log.h"
Expand Down Expand Up @@ -169,4 +170,6 @@
} \
} while (0)

#define SHERPA_ONNX_EXIT(code) exit(code)

#endif // SHERPA_ONNX_CSRC_MACROS_H_
1 change: 1 addition & 0 deletions sherpa-onnx/csrc/offline-sense-voice-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <utility>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"

Expand Down
26 changes: 26 additions & 0 deletions sherpa-onnx/csrc/offline-speaker-diarization-impl.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// sherpa-onnx/csrc/offline-speaker-diarization-impl.cc
//
// Copyright (c) 2024 Xiaomi Corporation

#include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h"

#include <memory>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h"

namespace sherpa_onnx {

std::unique_ptr<OfflineSpeakerDiarizationImpl>
OfflineSpeakerDiarizationImpl::Create(
const OfflineSpeakerDiarizationConfig &config) {
if (!config.segmentation.pyannote.model.empty()) {
return std::make_unique<OfflineSpeakerDiarizationPyannoteImpl>(config);
}

SHERPA_ONNX_LOGE("Please specify a speaker segmentation model.");

return nullptr;
}

} // namespace sherpa_onnx
Loading
Loading