From 0b18ccfbb2ec51b3986c0b6bcef6ea9911ccc1e0 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Tue, 30 Jan 2024 11:21:43 +0800
Subject: [PATCH] C++ API demo for speaker identification with portaudio. (#561)

---
 cmake/cmake_extension.py                       |   1 +
 setup.py                                       |   1 +
 sherpa-onnx/csrc/CMakeLists.txt                |   8 +-
 ...crophone-offline-speaker-identification.cc  | 310 ++++++++++++++++++
 .../csrc/speaker-embedding-extractor.cc        |   2 +-
 5 files changed, 320 insertions(+), 2 deletions(-)
 create mode 100644 sherpa-onnx/csrc/sherpa-onnx-microphone-offline-speaker-identification.cc

diff --git a/cmake/cmake_extension.py b/cmake/cmake_extension.py
index e62706e61..638bcd414 100644
--- a/cmake/cmake_extension.py
+++ b/cmake/cmake_extension.py
@@ -155,6 +155,7 @@ def build_extension(self, ext: setuptools.extension.Extension):
        binaries += ["sherpa-onnx-offline"]
        binaries += ["sherpa-onnx-microphone"]
        binaries += ["sherpa-onnx-microphone-offline"]
+       binaries += ["sherpa-onnx-microphone-offline-speaker-identification"]
        binaries += ["sherpa-onnx-online-websocket-server"]
        binaries += ["sherpa-onnx-offline-websocket-server"]
        binaries += ["sherpa-onnx-online-websocket-client"]
diff --git a/setup.py b/setup.py
index 42ec86f14..cc8199d3c 100644
--- a/setup.py
+++ b/setup.py
@@ -48,6 +48,7 @@ def get_binaries_to_install():
    binaries += ["sherpa-onnx-offline"]
    binaries += ["sherpa-onnx-microphone"]
    binaries += ["sherpa-onnx-microphone-offline"]
+   binaries += ["sherpa-onnx-microphone-offline-speaker-identification"]
    binaries += ["sherpa-onnx-online-websocket-server"]
    binaries += ["sherpa-onnx-offline-websocket-server"]
    binaries += ["sherpa-onnx-online-websocket-client"]
diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt
index 033174d93..766b4608e 100644
--- a/sherpa-onnx/csrc/CMakeLists.txt
+++ b/sherpa-onnx/csrc/CMakeLists.txt
@@ -287,6 +287,11 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO)
    microphone.cc
  )

+  add_executable(sherpa-onnx-microphone-offline-speaker-identification
+    sherpa-onnx-microphone-offline-speaker-identification.cc
+    microphone.cc
+  )
+
  if(BUILD_SHARED_LIBS)
    set(PA_LIB portaudio)
  else()
@@ -294,9 +299,10 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO)
  endif()

  set(exes
-   sherpa-onnx-offline-tts-play
    sherpa-onnx-microphone
    sherpa-onnx-microphone-offline
+   sherpa-onnx-microphone-offline-speaker-identification
+   sherpa-onnx-offline-tts-play
    sherpa-onnx-vad-microphone
    sherpa-onnx-vad-microphone-offline-asr
  )
diff --git a/sherpa-onnx/csrc/sherpa-onnx-microphone-offline-speaker-identification.cc b/sherpa-onnx/csrc/sherpa-onnx-microphone-offline-speaker-identification.cc
new file mode 100644
index 000000000..f525f5223
--- /dev/null
+++ b/sherpa-onnx/csrc/sherpa-onnx-microphone-offline-speaker-identification.cc
@@ -0,0 +1,310 @@
// sherpa-onnx/csrc/sherpa-onnx-microphone-offline-speaker-identification.cc
//
// Copyright (c) 2024 Xiaomi Corporation

#include <signal.h>
#include <stdio.h>
#include <stdlib.h>

#include <algorithm>
#include <fstream>
#include <mutex>   // NOLINT
#include <sstream>
#include <thread>  // NOLINT

#include "portaudio.h"  // NOLINT
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/microphone.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
#include "sherpa-onnx/csrc/wave-reader.h"

enum class State {
  kIdle,
  kRecording,
  kComputing,
};

State state = State::kIdle;

// true to stop the program and exit
bool stop = false;

std::vector<float> samples;
std::mutex samples_mutex;

static void DetectKeyPress() {
  SHERPA_ONNX_LOGE("\nPress Enter to start");
  int32_t key;
  while (!stop && (key = getchar())) {
    if (key != 0x0a) {
      continue;
    }

    switch (state) {
      case State::kIdle:
        SHERPA_ONNX_LOGE("\nStart recording. Press Enter to stop recording");
        state = State::kRecording;
        {
          std::lock_guard<std::mutex> lock(samples_mutex);
          samples.clear();
        }
        break;
      case State::kRecording:
        SHERPA_ONNX_LOGE("\nStop recording. Computing ...");
        state = State::kComputing;
        break;
      case State::kComputing:
        break;
    }
  }
}

static int32_t RecordCallback(const void *input_buffer,
                              void * /*output_buffer*/,
                              unsigned long frames_per_buffer,  // NOLINT
                              const PaStreamCallbackTimeInfo * /*time_info*/,
                              PaStreamCallbackFlags /*status_flags*/,
                              void *user_data) {
  std::lock_guard<std::mutex> lock(samples_mutex);

  auto p = reinterpret_cast<const float *>(input_buffer);
  samples.insert(samples.end(), p, p + frames_per_buffer);

  return stop ? paComplete : paContinue;
}

static void Handler(int32_t sig) {
  stop = true;
  fprintf(stderr, "\nCaught Ctrl + C. Press Enter to exit\n");
}

static std::vector<std::vector<float>> ComputeEmbeddings(
    const std::vector<std::string> &filenames,
    sherpa_onnx::SpeakerEmbeddingExtractor *extractor) {
  std::vector<std::vector<float>> embedding_list;
  embedding_list.reserve(filenames.size());

  for (const auto &f : filenames) {
    int32_t sampling_rate = -1;

    bool is_ok = false;
    const std::vector<float> samples =
        sherpa_onnx::ReadWave(f, &sampling_rate, &is_ok);

    if (!is_ok) {
      fprintf(stderr, "Failed to read %s\n", f.c_str());
      exit(-1);
    }

    auto s = extractor->CreateStream();
    s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
    s->InputFinished();
    auto embedding = extractor->Compute(s.get());
    embedding_list.push_back(embedding);
  }

  return embedding_list;
}

static std::unordered_map<std::string, std::vector<std::string>>
ReadSpeakerFile(const std::string &filename) {
  std::unordered_map<std::string, std::vector<std::string>> ans;

  std::ifstream is(filename);
  if (!is) {
    fprintf(stderr, "Failed to open %s\n", filename.c_str());
    exit(-1);
  }

  std::string line;
  std::string name;
  std::string path;

  while (std::getline(is, line)) {
    std::istringstream iss(line);
    name.clear();
    path.clear();

    iss >> name >> path;
    if (!iss || !iss.eof() || name.empty() || path.empty()) {
      fprintf(stderr, "Invalid line: %s\n", line.c_str());
      exit(-1);
    }

    ans[name].push_back(path);
  }

  return ans;
}

int32_t main(int32_t argc, char *argv[]) {
  signal(SIGINT, Handler);

  const char *kUsageMessage = R"usage(
This program shows how to use non-streaming speaker identification.

Usage:

(1) Prepare a text file containing speaker-related files.

Each line in the text file contains two columns. The first column is the
speaker name, while the second column contains the wave file of the speaker.

If the text file contains multiple wave files for the same speaker, then the
embeddings of these files are averaged.

An example text file is given below:

    foo /path/to/a.wav
    bar /path/to/b.wav
    foo /path/to/c.wav
    foobar /path/to/d.wav

Each wave file should contain only a single channel; the sample format
should be int16_t; the sample rate can be arbitrary.

(2) Download a model for computing speaker embeddings

Please visit
https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
to download a model. An example is given below:

    wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/wespeaker_zh_cnceleb_resnet34.onnx

Note that `zh` means Chinese, while `en` means English.

(3) Run it!

  ./bin/sherpa-onnx-microphone-offline-speaker-identification \
    --model=/path/to/your-model.onnx \
    --speaker-file=/path/to/speaker.txt
)usage";

  sherpa_onnx::ParseOptions po(kUsageMessage);
  float threshold = 0.5;
  std::string speaker_file;

  po.Register("threshold", &threshold,
              "Threshold for comparing embedding scores.");

  po.Register("speaker-file", &speaker_file, "Path to speaker.txt");

  sherpa_onnx::SpeakerEmbeddingExtractorConfig config;
  config.Register(&po);

  po.Read(argc, argv);
  if (po.NumArgs() != 0) {
    fprintf(stderr,
            "This program does not support any positional arguments.\n");
    po.PrintUsage();
    exit(EXIT_FAILURE);
  }

  fprintf(stderr, "%s\n", config.ToString().c_str());

  if (!config.Validate()) {
    fprintf(stderr,
            "Errors in config! Please use --help to view the usage.\n");
    return -1;
  }

  SHERPA_ONNX_LOGE("\nCreating extractor ...");
  sherpa_onnx::SpeakerEmbeddingExtractor extractor(config);
  SHERPA_ONNX_LOGE("\nextractor created!");

  sherpa_onnx::SpeakerEmbeddingManager manager(extractor.Dim());

  auto name2files = ReadSpeakerFile(speaker_file);
  for (const auto &p : name2files) {
    SHERPA_ONNX_LOGE("\nProcessing speaker %s", p.first.c_str());
    auto embedding_list = ComputeEmbeddings(p.second, &extractor);
    manager.Add(p.first, embedding_list);
  }

  sherpa_onnx::Microphone mic;

  PaDeviceIndex num_devices = Pa_GetDeviceCount();
  fprintf(stderr, "Num devices: %d\n", num_devices);

  PaStreamParameters param;

  param.device = Pa_GetDefaultInputDevice();
  if (param.device == paNoDevice) {
    fprintf(stderr, "No default input device found\n");
    exit(EXIT_FAILURE);
  }
  fprintf(stderr, "Use default device: %d\n", param.device);

  const PaDeviceInfo *info = Pa_GetDeviceInfo(param.device);
  fprintf(stderr, "  Name: %s\n", info->name);
  fprintf(stderr, "  Max input channels: %d\n", info->maxInputChannels);

  param.channelCount = 1;
  param.sampleFormat = paFloat32;

  param.suggestedLatency = info->defaultLowInputLatency;
  param.hostApiSpecificStreamInfo = nullptr;
  float sample_rate = 16000;

  PaStream *stream;
  PaError err =
      Pa_OpenStream(&stream, &param, nullptr, /* &outputParameters, */
                    sample_rate,
                    0,          // frames per buffer
                    paClipOff,  // we won't output out of range samples
                                // so don't bother clipping them
                    RecordCallback, nullptr);
  if (err != paNoError) {
    fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err));
    exit(EXIT_FAILURE);
  }

  err = Pa_StartStream(stream);
  fprintf(stderr, "Started\n");

  if (err != paNoError) {
    fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err));
    exit(EXIT_FAILURE);
  }

  std::thread t(DetectKeyPress);
  while (!stop) {
    switch (state) {
      case State::kIdle:
        break;
      case State::kRecording:
        break;
      case State::kComputing: {
        std::vector<float> buf;
        {
          std::lock_guard<std::mutex> lock(samples_mutex);
          buf = std::move(samples);
        }

        auto s = extractor.CreateStream();
        s->AcceptWaveform(sample_rate, buf.data(), buf.size());
        s->InputFinished();
        auto embedding = extractor.Compute(s.get());
        auto name = manager.Search(embedding.data(), threshold);

        if (name.empty()) {
          name = "--Unknown--";
        }

        SHERPA_ONNX_LOGE("\nDone!\nDetected speaker is: %s", name.c_str());

        state = State::kIdle;
        SHERPA_ONNX_LOGE("\nPress Enter to start");
        break;
      }
    }

    Pa_Sleep(20);  // sleep for 20ms
  }
  t.join();

  err = Pa_CloseStream(stream);
  if (err != paNoError) {
    fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err));
exit(EXIT_FAILURE); + } + + return 0; +} diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor.cc b/sherpa-onnx/csrc/speaker-embedding-extractor.cc index e4e04f621..35bfc297b 100644 --- a/sherpa-onnx/csrc/speaker-embedding-extractor.cc +++ b/sherpa-onnx/csrc/speaker-embedding-extractor.cc @@ -26,7 +26,7 @@ void SpeakerEmbeddingExtractorConfig::Register(ParseOptions *po) { bool SpeakerEmbeddingExtractorConfig::Validate() const { if (model.empty()) { - SHERPA_ONNX_LOGE("Please provide --speaker-embedding-model"); + SHERPA_ONNX_LOGE("Please provide --model"); return false; }
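
Editor's note: if you want to exercise the same C++ API added by this patch without a
microphone, the sketch below drives it from wave files instead. It is not part of the
patch; it only reuses calls that already appear in the demo above (ReadWave, CreateStream,
AcceptWaveform, InputFinished, Compute, SpeakerEmbeddingManager::Add/Search). The wave-file
paths are placeholders, and `config.model` is assumed to be the field behind the `--model`
option that Validate() checks.

// speaker-id-from-files.cc
//
// Minimal sketch (NOT part of the patch): enroll one speaker from wave files,
// then identify the speaker of a test wave file.
// Assumptions: file paths below are placeholders; `config.model` corresponds
// to the --model option registered by SpeakerEmbeddingExtractorConfig.

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <string>
#include <vector>

#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
#include "sherpa-onnx/csrc/wave-reader.h"

// Compute one embedding from a single wave file, mirroring ComputeEmbeddings()
// in the demo above.
static std::vector<float> EmbeddingFromWave(
    const std::string &filename,
    sherpa_onnx::SpeakerEmbeddingExtractor *extractor) {
  int32_t sampling_rate = -1;
  bool is_ok = false;
  std::vector<float> samples =
      sherpa_onnx::ReadWave(filename, &sampling_rate, &is_ok);
  if (!is_ok) {
    fprintf(stderr, "Failed to read %s\n", filename.c_str());
    exit(-1);
  }

  auto s = extractor->CreateStream();
  s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
  s->InputFinished();
  return extractor->Compute(s.get());
}

int main() {
  sherpa_onnx::SpeakerEmbeddingExtractorConfig config;
  config.model = "./wespeaker_zh_cnceleb_resnet34.onnx";  // placeholder path
  if (!config.Validate()) {
    fprintf(stderr, "Invalid config\n");
    return -1;
  }

  sherpa_onnx::SpeakerEmbeddingExtractor extractor(config);
  sherpa_onnx::SpeakerEmbeddingManager manager(extractor.Dim());

  // Enroll speaker "foo" from several recordings; the manager averages them,
  // as described in the demo's usage message.
  std::vector<std::vector<float>> foo_embeddings;
  foo_embeddings.push_back(EmbeddingFromWave("./foo-1.wav", &extractor));
  foo_embeddings.push_back(EmbeddingFromWave("./foo-2.wav", &extractor));
  manager.Add("foo", foo_embeddings);

  // Identify the speaker of a test recording.
  std::vector<float> test = EmbeddingFromWave("./test.wav", &extractor);
  float threshold = 0.5;
  auto name = manager.Search(test.data(), threshold);
  if (name.empty()) {
    name = "--Unknown--";
  }
  fprintf(stderr, "Detected speaker: %s\n", name.c_str());

  return 0;
}

Linking would presumably follow the same pattern as the new CMake target in this patch,
minus portaudio and microphone.cc, since no audio capture is involved.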