Skip to content

Commit

Permalink
Add vad clear api for better performance
Browse files Browse the repository at this point in the history
  • Loading branch information
yujinqiu committed Oct 16, 2023
1 parent 0df0a73 commit 6f10d1f
Show file tree
Hide file tree
Showing 8 changed files with 44 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,9 @@ class MainActivity : AppCompatActivity() {
val samples = FloatArray(ret) { buffer[it] / 32768.0f }

vad.acceptWaveform(samples)
while(!vad.empty()) {vad.pop();}

val isSpeechDetected = vad.isSpeechDetected()
vad.clear()

runOnUiThread {
onVad(isSpeechDetected)
Expand Down
5 changes: 5 additions & 0 deletions sherpa-onnx/c-api/c-api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,11 @@ SHERPA_ONNX_API void SherpaOnnxVoiceActivityDetectorPop(
p->impl->Pop();
}

SHERPA_ONNX_API void SherpaOnnxVoiceActivityDetectorClear(
SherpaOnnxVoiceActivityDetector *p) {
p->impl->Clear();
}

SHERPA_ONNX_API const SherpaOnnxSpeechSegment *
SherpaOnnxVoiceActivityDetectorFront(SherpaOnnxVoiceActivityDetector *p) {
const sherpa_onnx::SpeechSegment &segment = p->impl->Front();
Expand Down
4 changes: 4 additions & 0 deletions sherpa-onnx/c-api/c-api.h
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,10 @@ SherpaOnnxVoiceActivityDetectorDetected(SherpaOnnxVoiceActivityDetector *p);
SHERPA_ONNX_API void SherpaOnnxVoiceActivityDetectorPop(
SherpaOnnxVoiceActivityDetector *p);

// Clear current speech segments.
SHERPA_ONNX_API void SherpaOnnxVoiceActivityDetectorClear(
SherpaOnnxVoiceActivityDetector *p);

// Return the first speech segment.
// The user has to use SherpaOnnxDestroySpeechSegment() to free the returned
// pointer to avoid memory leak.
Expand Down
2 changes: 2 additions & 0 deletions sherpa-onnx/csrc/voice-activity-detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ class VoiceActivityDetector::Impl {

void Pop() { segments_.pop(); }

void Clear() { std::queue<SpeechSegment>().swap(segments_); }

const SpeechSegment &Front() const { return segments_.front(); }

void Reset() {
Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/csrc/voice-activity-detector.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class VoiceActivityDetector {
void AcceptWaveform(const float *samples, int32_t n);
bool Empty() const;
void Pop();
void Clear();
const SpeechSegment &Front() const;

bool IsSpeechDetected() const;
Expand Down
10 changes: 10 additions & 0 deletions sherpa-onnx/jni/jni.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ class SherpaOnnxVad {

void Pop() { vad_.Pop(); }

void Clear() { vad_.Clear();}

const SpeechSegment &Front() const { return vad_.Front(); }

bool IsSpeechDetected() const { return vad_.IsSpeechDetected(); }
Expand Down Expand Up @@ -556,6 +558,14 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_pop(JNIEnv *env,
model->Pop();
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_clear(JNIEnv *env,
jobject /*obj*/,
jlong ptr) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxVad *>(ptr);
model->Clear();
}

// see
// https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables
static jobject NewInteger(JNIEnv *env, int32_t value) {
Expand Down
4 changes: 4 additions & 0 deletions swift-api-examples/SherpaOnnx.swift
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,10 @@ class SherpaOnnxVoiceActivityDetectorWrapper {
SherpaOnnxVoiceActivityDetectorPop(vad)
}

func clear() {
SherpaOnnxVoiceActivityDetectorClear(vad)
}

func front() -> SherpaOnnxSpeechSegmentWrapper {
let p: UnsafePointer<SherpaOnnxSpeechSegment>? = SherpaOnnxVoiceActivityDetectorFront(vad)
return SherpaOnnxSpeechSegmentWrapper(p: p)
Expand Down
35 changes: 17 additions & 18 deletions swift-api-examples/generate-subtitles.swift
Original file line number Diff line number Diff line change
Expand Up @@ -174,32 +174,31 @@ func run() {

var segments: [SpeechSegment] = []

while array.count > windowSize {
// todo(fangjun): avoid extra copies here
vad.acceptWaveform(samples: [Float](array[0..<windowSize]))
array = [Float](array[windowSize..<array.count])

while !vad.isEmpty() {
let s = vad.front()
vad.pop()
let result = recognizer.decode(samples: s.samples)
for offset in stride(from: 0, to: array.count, by: windowSize) {
let end = min(offset + windowSize, array.count)
vad.acceptWaveform(samples: [Float](array[offset ..< end]))
}

segments.append(
SpeechSegment(
start: Float(s.start) / Float(sampleRate),
duration: Float(s.samples.count) / Float(sampleRate),
text: result.text))
var index: Int = 0
while !vad.isEmpty() {
let s = vad.front()
vad.pop()
let result = recognizer.decode(samples: s.samples)

print(segments.last!)
segments.append(
SpeechSegment(
start: Float(s.start) / Float(sampleRate),
duration: Float(s.samples.count) / Float(sampleRate),
text: result.text))

}
print(segments.last!)
}

let srt = zip(segments.indices, segments).map { (index, element) in
let srt: String = zip(segments.indices, segments).map { (index, element) in
return "\(index+1)\n\(element)"
}.joined(separator: "\n\n")

let srtFilename = filePath.stringByDeletingPathExtension + ".srt"
let srtFilename: String = filePath.stringByDeletingPathExtension + ".srt"
do {
try srt.write(to: srtFilename.fileURL, atomically: true, encoding: .utf8)
} catch {
Expand Down

0 comments on commit 6f10d1f

Please sign in to comment.