Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement VAD on realtime transcription #129

Merged
merged 7 commits into from
Sep 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions android/src/main/java/com/rnwhisper/WhisperContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,27 @@ private void saveWavFile(byte[] rawData, String audioOutputFile) throws IOExcept
}
}

/**
 * Runs simple energy-based voice activity detection over the tail of the
 * current audio slice.
 *
 * @param options      transcribe options (useVad, vadMs, vadThold, vadFreqThold)
 * @param shortBuffer  the current slice's 16-bit PCM samples
 * @param nSamples     number of samples already stored in the slice
 * @param n            number of samples about to be appended (0 when checking
 *                     an already-filled slice)
 * @return true when transcription should proceed (VAD disabled, or the native
 *         VAD reports activity); false when VAD is enabled and either not
 *         enough audio has been collected yet or no activity was detected.
 */
private boolean vad(ReadableMap options, short[] shortBuffer, int nSamples, int n) {
  boolean isSpeech = true;
  if (!isTranscribing && options.hasKey("useVad") && options.getBoolean("useVad")) {
    // VAD window length; vadMs is truncated to whole seconds (default: 2s)
    int vadSec = options.hasKey("vadMs") ? options.getInt("vadMs") / 1000 : 2;
    int sampleSize = vadSec * SAMPLE_RATE;
    if (nSamples + n > sampleSize) {
      int start = nSamples + n - sampleSize;
      float[] audioData = new float[sampleSize];
      for (int i = 0; i < sampleSize; i++) {
        // Normalize int16 PCM to [-1, 1) float
        audioData[i] = shortBuffer[i + start] / 32768.0f;
      }
      float vadThold = options.hasKey("vadThold") ? (float) options.getDouble("vadThold") : 0.6f;
      // Fix: default must match the iOS implementation and the documented
      // default (100.0 Hz); 0.6 looked like a copy/paste of vadThold and is
      // far below any useful high-pass cutoff.
      float vadFreqThold = options.hasKey("vadFreqThold") ? (float) options.getDouble("vadFreqThold") : 100.0f;
      isSpeech = vadSimple(audioData, sampleSize, vadThold, vadFreqThold);
    } else {
      // Not enough audio collected yet for a full VAD window.
      isSpeech = false;
    }
  }
  return isSpeech;
}

public int startRealtimeTranscribe(int jobId, ReadableMap options) {
if (isCapturing || isTranscribing) {
return -100;
Expand Down Expand Up @@ -223,6 +244,12 @@ public void run() {
) {
emitTranscribeEvent("@RNWhisper_onRealtimeTranscribeEnd", Arguments.createMap());
} else if (!isTranscribing) {
short[] shortBuffer = shortBufferSlices.get(sliceIndex);
boolean isSpeech = vad(options, shortBuffer, nSamples, 0);
if (!isSpeech) {
emitTranscribeEvent("@RNWhisper_onRealtimeTranscribeEnd", Arguments.createMap());
break;
}
isTranscribing = true;
fullTranscribeSamples(options, true);
}
Expand All @@ -244,9 +271,14 @@ public void run() {
for (int i = 0; i < n; i++) {
shortBuffer[nSamples + i] = buffer[i];
}

boolean isSpeech = vad(options, shortBuffer, nSamples, n);

nSamples += n;
sliceNSamples.set(sliceIndex, nSamples);

if (!isSpeech) continue;

if (!isTranscribing && nSamples > SAMPLE_RATE / 2) {
isTranscribing = true;
fullHandler = new Thread(new Runnable() {
Expand Down Expand Up @@ -593,6 +625,7 @@ private static String cpuInfo() {
protected static native long initContext(String modelPath);
protected static native long initContextWithAsset(AssetManager assetManager, String modelPath);
protected static native long initContextWithInputStream(PushbackInputStream inputStream);
protected static native boolean vadSimple(float[] audio_data, int audio_data_len, float vad_thold, float vad_freq_thold);
protected static native int fullTranscribe(
int job_id,
long context,
Expand Down
22 changes: 22 additions & 0 deletions android/src/main/jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <sys/sysinfo.h>
#include <string>
#include <thread>
#include <vector>
#include "whisper.h"
#include "rn-whisper.h"
#include "ggml.h"
Expand Down Expand Up @@ -184,6 +185,27 @@ Java_com_rnwhisper_WhisperContext_initContextWithInputStream(
return reinterpret_cast<jlong>(context);
}

// JNI bridge for the simple energy-based VAD: copies the Java float array
// into a native buffer and forwards it to rn_whisper_vad_simple.
JNIEXPORT jboolean JNICALL
Java_com_rnwhisper_WhisperContext_vadSimple(
    JNIEnv *env,
    jobject thiz,
    jfloatArray audio_data,
    jint audio_data_len,
    jfloat vad_thold,
    jfloat vad_freq_thold
) {
    UNUSED(thiz);

    jfloat *audio_data_arr = env->GetFloatArrayElements(audio_data, nullptr);
    // Copy into a std::vector via the range constructor; the VAD works on
    // this copy, so the JNI elements can be released immediately.
    std::vector<float> samples(audio_data_arr, audio_data_arr + audio_data_len);
    // JNI_ABORT: the Java array was never modified, so skip the write-back.
    env->ReleaseFloatArrayElements(audio_data, audio_data_arr, JNI_ABORT);

    return rn_whisper_vad_simple(samples, WHISPER_SAMPLE_RATE, 1000, vad_thold, vad_freq_thold, false);
}

struct progress_callback_context {
JNIEnv *env;
jobject progress_callback_instance;
Expand Down
51 changes: 51 additions & 0 deletions cpp/rn-whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,55 @@ void rn_whisper_abort_all_transcribe() {
}
}

// First-order (RC) high-pass filter, applied in place:
//   y[i] = alpha * (y[i-1] + x[i] - x[i-1]),  alpha = rc / (rc + dt)
// Used to strip low-frequency rumble before the energy-based VAD.
//
//   data        - audio samples; overwritten with the filtered signal
//   cutoff      - cutoff frequency in Hz
//   sample_rate - sampling rate in Hz
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
    if (data.empty()) {
        // guard: data[0] below would be UB on an empty buffer
        return;
    }

    const float rc = 1.0f / (2.0f * M_PI * cutoff);
    const float dt = 1.0f / sample_rate;
    // Fix: the high-pass coefficient is rc / (rc + dt); the previous
    // dt / (rc + dt) is the low-pass smoothing coefficient.
    const float alpha = rc / (rc + dt);

    float y = data[0];
    float x_prev = data[0];

    for (size_t i = 1; i < data.size(); i++) {
        // Fix: keep the *unfiltered* previous sample. Writing data[i] in
        // place clobbered x[i-1] before the next iteration read it, which
        // degenerated the old loop into data[i] = alpha * x[i] (pure
        // attenuation, no filtering at all).
        const float x = data[i];
        y = alpha * (y + x - x_prev);
        x_prev = x;
        data[i] = y;
    }
}

// Simple energy-based voice activity detection over the trailing `last_ms`
// window of `pcmf32` (mirrors whisper.cpp's vad_simple).
//
// NOTE: when freq_thold > 0 the high-pass filter is applied to `pcmf32`
// IN PLACE, so the caller's buffer is modified.
//
// Returns false when either (a) fewer than `last_ms` worth of samples are
// available, or (b) the trailing window's mean energy exceeds
// vad_thold * mean energy of the whole buffer; true otherwise (i.e. the
// tail is comparatively quiet).
// NOTE(review): call sites assign this result to `isSpeech`, which reads as
// inverted relative to the energy comparison above — confirm the intended
// semantics against the realtime-transcribe callers.
bool rn_whisper_vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
  const int n_samples = pcmf32.size();
  const int n_samples_last = (sample_rate * last_ms) / 1000;

  if (n_samples_last >= n_samples) {
      // not enough samples - assume no speech
      return false;
  }

  if (freq_thold > 0.0f) {
      // strip low-frequency content before measuring energy (mutates pcmf32)
      high_pass_filter(pcmf32, freq_thold, sample_rate);
  }

  float energy_all = 0.0f;
  float energy_last = 0.0f;

  // accumulate mean absolute amplitude over the full buffer and over the
  // trailing `last_ms` window
  for (int i = 0; i < n_samples; i++) {
      energy_all += fabsf(pcmf32[i]);

      if (i >= n_samples - n_samples_last) {
          energy_last += fabsf(pcmf32[i]);
      }
  }

  energy_all /= n_samples;
  energy_last /= n_samples_last;

  if (verbose) {
      fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
  }

  // trailing window louder than vad_thold * overall average -> false
  if (energy_last > vad_thold*energy_all) {
      return false;
  }

  return true;
}

}
3 changes: 2 additions & 1 deletion cpp/rn-whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ void rn_whisper_remove_abort_map(int job_id);
void rn_whisper_abort_transcribe(int job_id);
bool rn_whisper_transcribe_is_aborted(int job_id);
void rn_whisper_abort_all_transcribe();
bool rn_whisper_vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose);

#ifdef __cplusplus
}
#endif
#endif
12 changes: 6 additions & 6 deletions example/ios/Podfile.lock
Original file line number Diff line number Diff line change
Expand Up @@ -750,16 +750,16 @@ PODS:
- React-perflogger (= 0.71.11)
- RNFS (2.20.0):
- React-Core
- RNZipArchive (6.0.9):
- RNZipArchive (6.1.0):
- React-Core
- RNZipArchive/Core (= 6.0.9)
- RNZipArchive/Core (= 6.1.0)
- SSZipArchive (~> 2.2)
- RNZipArchive/Core (6.0.9):
- RNZipArchive/Core (6.1.0):
- React-Core
- SSZipArchive (~> 2.2)
- SocketRocket (0.6.0)
- SSZipArchive (2.4.3)
- whisper-rn (0.3.5):
- whisper-rn (0.3.6):
- RCT-Folly
- RCTRequired
- RCTTypeSafety
Expand Down Expand Up @@ -994,10 +994,10 @@ SPEC CHECKSUMS:
React-runtimeexecutor: 4817d63dbc9d658f8dc0ec56bd9b83ce531129f0
ReactCommon: 08723d2ed328c5cbcb0de168f231bc7bae7f8aa1
RNFS: 4ac0f0ea233904cb798630b3c077808c06931688
RNZipArchive: 68a0c6db4b1c103f846f1559622050df254a3ade
RNZipArchive: ef9451b849c45a29509bf44e65b788829ab07801
SocketRocket: fccef3f9c5cedea1353a9ef6ada904fde10d6608
SSZipArchive: fe6a26b2a54d5a0890f2567b5cc6de5caa600aef
whisper-rn: 6f293154b175fee138a994fa00d0f414fb1f44e9
whisper-rn: e80c0482f6a632faafd601f98f10da0255c1e1ec
Yoga: f7decafdc5e8c125e6fa0da38a687e35238420fa
YogaKit: f782866e155069a2cca2517aafea43200b01fd5a

Expand Down
2 changes: 2 additions & 0 deletions example/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,8 @@ export default function App() {
realtimeAudioSec: 60,
// Slice audio into 25 (or < 30) sec chunks for better performance
realtimeAudioSliceSec: 25,
// Voice Activity Detection - Start transcribing when speech is detected
// useVad: true,
})
setStopTranscribe({ stop })
subscribe((evt) => {
Expand Down
33 changes: 33 additions & 0 deletions ios/RNWhisperContext.mm
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#import "RNWhisperContext.h"
#include <vector>

#define NUM_BYTES_PER_BUFFER 16 * 1024

Expand Down Expand Up @@ -77,6 +78,29 @@ - (void)freeBufferIfNeeded {
}
}

// Runs simple energy-based VAD over the tail of the current audio slice.
// Returns true when transcription should proceed (VAD disabled, or the
// native VAD reports activity); false when VAD is enabled and either not
// enough audio has been collected yet or no activity was detected.
//   state          - recording state carrying the transcribe options
//   audioBufferI16 - the current slice's 16-bit PCM samples
//   nSamples       - samples already stored in the slice
//   n              - samples about to be appended (0 for a filled slice)
bool vad(RNWhisperContextRecordState *state, int16_t* audioBufferI16, int nSamples, int n)
{
    bool isSpeech = true;
    // Fix: test the option's boolean value, not mere key presence — @(NO) is
    // a non-nil object, so `useVad: false` would otherwise still enable VAD.
    // Messaging nil returns NO, so an absent key still disables it.
    if (!state->isTranscribing && [state->options[@"useVad"] boolValue]) {
        // VAD window length; vadMs is truncated to whole seconds (default: 2s)
        int vadSec = state->options[@"vadMs"] != nil ? [state->options[@"vadMs"] intValue] / 1000 : 2;
        int sampleSize = vadSec * WHISPER_SAMPLE_RATE;
        if (nSamples + n > sampleSize) {
            int start = nSamples + n - sampleSize;
            std::vector<float> audioBufferF32Vec(sampleSize);
            for (int i = 0; i < sampleSize; i++) {
                // Normalize int16 PCM to [-1, 1) float
                audioBufferF32Vec[i] = (float)audioBufferI16[i + start] / 32768.0f;
            }
            float vadThold = state->options[@"vadThold"] != nil ? [state->options[@"vadThold"] floatValue] : 0.6f;
            float vadFreqThold = state->options[@"vadFreqThold"] != nil ? [state->options[@"vadFreqThold"] floatValue] : 100.0f;
            isSpeech = rn_whisper_vad_simple(audioBufferF32Vec, WHISPER_SAMPLE_RATE, 1000, vadThold, vadFreqThold, false);
#if DEBUG
            // Debug-only: keep release builds free of per-buffer logging.
            NSLog(@"[RNWhisper] VAD result: %d", isSpeech);
#endif
        } else {
            // Not enough audio collected yet for a full VAD window.
            isSpeech = false;
        }
    }
    return isSpeech;
}

void AudioInputCallback(void * inUserData,
AudioQueueRef inAQ,
AudioQueueBufferRef inBuffer,
Expand Down Expand Up @@ -117,6 +141,11 @@ void AudioInputCallback(void * inUserData,
!state->isTranscribing &&
nSamples != state->nSamplesTranscribing
) {
int16_t* audioBufferI16 = (int16_t*) [state->shortBufferSlices[state->sliceIndex] pointerValue];
if (!vad(state, audioBufferI16, nSamples, 0)) {
state->transcribeHandler(state->jobId, @"end", @{});
return;
}
state->isTranscribing = true;
dispatch_async([state->mSelf getDispatchQueue], ^{
[state->mSelf fullTranscribeSamples:state];
Expand All @@ -142,11 +171,15 @@ void AudioInputCallback(void * inUserData,
for (int i = 0; i < n; i++) {
audioBufferI16[nSamples + i] = ((short*)inBuffer->mAudioData)[i];
}

bool isSpeech = vad(state, audioBufferI16, nSamples, n);
nSamples += n;
state->sliceNSamples[state->sliceIndex] = [NSNumber numberWithInt:nSamples];

AudioQueueEnqueueBuffer(state->queue, inBuffer, 0, NULL);

if (!isSpeech) return;

if (!state->isTranscribing) {
state->isTranscribing = true;
dispatch_async([state->mSelf getDispatchQueue], ^{
Expand Down
18 changes: 18 additions & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,24 @@ export type TranscribeRealtimeOptions = TranscribeOptions & {
* (Default: Undefined)
*/
audioOutputPath?: string
/**
* Start transcribing while recording when speech is detected, using VAD (Voice Activity Detection).
* The first VAD check is triggered after 2 seconds of recording.
* (Default: false)
*/
useVad?: boolean
/**
* Length (in ms) of the most recent audio used for VAD. (Default: 2000)
*/
vadMs?: number
/**
* VAD threshold. (Default: 0.6)
*/
vadThold?: number
/**
* Frequency to apply High-pass filter in VAD. (Default: 100.0)
*/
vadFreqThold?: number
}

export type TranscribeRealtimeEvent = {
Expand Down