From 097d641869611ac23a6c45b7138562428e034a6b Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 13 Nov 2023 12:07:51 +0800 Subject: [PATCH] Resize circular buffer on overflow (#422) --- sherpa-onnx/csrc/circular-buffer.cc | 76 ++++++++++++++++++++- sherpa-onnx/csrc/circular-buffer.h | 2 + sherpa-onnx/csrc/lexicon.cc | 3 +- sherpa-onnx/csrc/lexicon.h | 4 +- sherpa-onnx/csrc/voice-activity-detector.cc | 19 ++++-- 5 files changed, 94 insertions(+), 10 deletions(-) diff --git a/sherpa-onnx/csrc/circular-buffer.cc b/sherpa-onnx/csrc/circular-buffer.cc index b9702ace9..895b5cab1 100644 --- a/sherpa-onnx/csrc/circular-buffer.cc +++ b/sherpa-onnx/csrc/circular-buffer.cc @@ -19,13 +19,83 @@ CircularBuffer::CircularBuffer(int32_t capacity) { buffer_.resize(capacity); } +void CircularBuffer::Resize(int32_t new_capacity) { + int32_t capacity = buffer_.size(); + if (new_capacity <= capacity) { + SHERPA_ONNX_LOGE("new_capacity (%d) <= original capacity (%d). Skip it.", + new_capacity, capacity); + return; + } + + int32_t size = Size(); + if (size == 0) { + buffer_.resize(new_capacity); + return; + } + + std::vector new_buffer(new_capacity); + int32_t start = head_ % capacity; + int32_t dest = head_ % new_capacity; + + if (start + size <= capacity) { + if (dest + size <= new_capacity) { + std::copy(buffer_.begin() + start, buffer_.begin() + start + size, + new_buffer.begin() + dest); + } else { + int32_t part1_size = new_capacity - dest; + + // copy [start, start+part1_size] to new_buffer + std::copy(buffer_.begin() + start, buffer_.begin() + start + part1_size, + new_buffer.begin() + dest); + + // copy [start+part1_size, start+size] to new_buffer + std::copy(buffer_.begin() + start + part1_size, + buffer_.begin() + start + size, new_buffer.begin()); + } + } else { + int32_t part1_size = capacity - start; + int32_t part2_size = size - part1_size; + + // copy [start, start+part1_size] to new_buffer + if (dest + part1_size <= new_capacity) { + std::copy(buffer_.begin() + start, buffer_.begin() + start + part1_size, + new_buffer.begin() + dest); + } else { + int32_t first_part = new_capacity - dest; + int32_t second_part = part1_size - first_part; + std::copy(buffer_.begin() + start, buffer_.begin() + start + first_part, + new_buffer.begin() + dest); + + std::copy(buffer_.begin() + start + first_part, + buffer_.begin() + start + part1_size, new_buffer.begin()); + } + + int32_t new_dest = (dest + part1_size) % new_capacity; + + if (new_dest + part2_size <= new_capacity) { + std::copy(buffer_.begin(), buffer_.begin() + part2_size, + new_buffer.begin() + new_dest); + } else { + int32_t first_part = new_capacity - new_dest; + std::copy(buffer_.begin(), buffer_.begin() + first_part, + new_buffer.begin() + new_dest); + std::copy(buffer_.begin() + first_part, buffer_.begin() + part2_size, + new_buffer.begin()); + } + } + buffer_.swap(new_buffer); +} + void CircularBuffer::Push(const float *p, int32_t n) { int32_t capacity = buffer_.size(); int32_t size = Size(); if (n + size > capacity) { - SHERPA_ONNX_LOGE("Overflow! n: %d, size: %d, n+size: %d, capacity: %d", n, - size, n + size, capacity); - exit(-1); + int32_t new_capacity = std::max(capacity * 2, n + size); + SHERPA_ONNX_LOGE( + "Overflow! n: %d, size: %d, n+size: %d, capacity: %d. Increase " + "capacity to: %d", + n, size, n + size, capacity, new_capacity); + Resize(new_capacity); } int32_t start = tail_ % capacity; diff --git a/sherpa-onnx/csrc/circular-buffer.h b/sherpa-onnx/csrc/circular-buffer.h index 6b0419e39..0f668952d 100644 --- a/sherpa-onnx/csrc/circular-buffer.h +++ b/sherpa-onnx/csrc/circular-buffer.h @@ -47,6 +47,8 @@ class CircularBuffer { tail_ = 0; } + void Resize(int32_t new_capacity); + private: std::vector buffer_; diff --git a/sherpa-onnx/csrc/lexicon.cc b/sherpa-onnx/csrc/lexicon.cc index 5d9ad8ff1..2cf8f8167 100644 --- a/sherpa-onnx/csrc/lexicon.cc +++ b/sherpa-onnx/csrc/lexicon.cc @@ -17,7 +17,8 @@ #include "android/asset_manager_jni.h" #endif -#include +#include +#include // NOLINT #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" diff --git a/sherpa-onnx/csrc/lexicon.h b/sherpa-onnx/csrc/lexicon.h index 4cab72d0d..c79ac11a3 100644 --- a/sherpa-onnx/csrc/lexicon.h +++ b/sherpa-onnx/csrc/lexicon.h @@ -6,8 +6,8 @@ #define SHERPA_ONNX_CSRC_LEXICON_H_ #include -#include -#include +#include +#include // NOLINT #include #include #include diff --git a/sherpa-onnx/csrc/voice-activity-detector.cc b/sherpa-onnx/csrc/voice-activity-detector.cc index 86c0f7e47..af79db58f 100644 --- a/sherpa-onnx/csrc/voice-activity-detector.cc +++ b/sherpa-onnx/csrc/voice-activity-detector.cc @@ -4,6 +4,7 @@ #include "sherpa-onnx/csrc/voice-activity-detector.h" +#include #include #include @@ -30,7 +31,7 @@ class VoiceActivityDetector::Impl { void AcceptWaveform(const float *samples, int32_t n) { int32_t window_size = model_->WindowSize(); - // note n is usally window_size and there is no need to use + // note n is usually window_size and there is no need to use // an extra buffer here last_.insert(last_.end(), samples, samples + n); int32_t k = static_cast(last_.size()) / window_size; @@ -39,7 +40,7 @@ class VoiceActivityDetector::Impl { for (int32_t i = 0; i != k; ++i, p += window_size) { buffer_.Push(p, window_size); - is_speech = model_->IsSpeech(p, window_size); + is_speech = is_speech || model_->IsSpeech(p, window_size); } last_ = std::vector( @@ -48,8 +49,9 @@ class VoiceActivityDetector::Impl { if (is_speech) { if (start_ == -1) { // beginning of speech - start_ = buffer_.Tail() - 2 * model_->WindowSize() - - model_->MinSpeechDurationSamples(); + start_ = std::max(buffer_.Tail() - 2 * model_->WindowSize() - + model_->MinSpeechDurationSamples(), + buffer_.Head()); } } else { // non-speech @@ -68,6 +70,15 @@ class VoiceActivityDetector::Impl { buffer_.Pop(end - buffer_.Head()); } + if (start_ == -1) { + int32_t end = buffer_.Tail() - 2 * model_->WindowSize() - + model_->MinSpeechDurationSamples(); + int32_t n = std::max(0, end - buffer_.Head()); + if (n > 0) { + buffer_.Pop(n); + } + } + start_ = -1; } }