Skip to content

Commit

Permalink
feat(ios): initial work for simple VAD
Browse files Browse the repository at this point in the history
  • Loading branch information
jhen0409 committed Sep 21, 2023
1 parent 9a1e026 commit cdc7e44
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 7 deletions.
52 changes: 52 additions & 0 deletions cpp/rn-whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,56 @@ void rn_whisper_abort_all_transcribe() {
}
}

// Applies a first-order (single-pole RC) high-pass filter to `data` in place.
//
//   data        - samples to filter; overwritten with the filtered signal.
//                 data[0] is kept as-is and seeds the recurrence.
//   cutoff      - cutoff frequency in Hz.
//   sample_rate - sampling rate of `data` in Hz.
//
// The recurrence is y[i] = alpha * (y[i-1] + x[i] - x[i-1]) with
// alpha = RC / (RC + dt), RC = 1 / (2*pi*cutoff) and dt = 1 / sample_rate.
//
// An empty buffer, or a non-positive cutoff / sample rate, leaves `data`
// untouched (there is nothing meaningful to filter with).
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
    if (data.empty() || cutoff <= 0.0f || sample_rate <= 0.0f) {
        return;
    }

    const float rc = 1.0f / (2.0f * M_PI * cutoff);
    const float dt = 1.0f / sample_rate;
    // High-pass coefficient. dt / (rc + dt) would be the low-pass coefficient.
    const float alpha = rc / (rc + dt);

    float y = data[0];
    // The raw previous input has to be remembered separately: data[i - 1] has
    // already been overwritten with filtered output by the time it is needed.
    float x_prev = data[0];

    for (size_t i = 1; i < data.size(); i++) {
        const float x = data[i];
        y = alpha * (y + x - x_prev);
        x_prev = x;
        data[i] = y;
    }
}

// Simple energy-based voice activity check over a PCM buffer.
//
//   pcmf32      - float PCM samples. NOTE: modified in place when
//                 freq_thold > 0 (the high-pass filter overwrites the buffer).
//   sample_rate - sampling rate of pcmf32 in Hz.
//   last_ms     - length, in milliseconds, of the trailing window that is
//                 compared against the whole buffer.
//   vad_thold   - ratio threshold applied to the overall mean energy.
//   freq_thold  - high-pass cutoff in Hz; values <= 0 skip the filter.
//   verbose     - when true, diagnostics are written to stderr.
//
// "Energy" here is the mean absolute sample value. Returns false when the
// trailing window's energy is greater than vad_thold * (energy of the whole
// buffer), and true otherwise. Also returns false when the trailing window is
// empty or does not fit strictly inside the buffer, since there is not enough
// data to compare.
//
// NOTE(review): the result is true when the tail is *quieter* than the
// threshold; confirm callers interpret the return value accordingly.
bool rn_whisper_vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
    // Use 64-bit arithmetic so sample_rate * last_ms cannot overflow int and
    // the buffer size is not narrowed.
    const long long n_samples = static_cast<long long>(pcmf32.size());
    const long long n_samples_last = (static_cast<long long>(sample_rate) * last_ms) / 1000;

    // A non-positive window would make the average below divide by zero (or
    // by a negative count); a window covering the whole buffer leaves nothing
    // to compare against.
    if (n_samples_last <= 0 || n_samples_last >= n_samples) {
        if (verbose) {
            fprintf(stderr, "%s: not enough samples - assume no speech\n", __func__);
        }
        return false;
    }

    if (freq_thold > 0.0f) {
        high_pass_filter(pcmf32, freq_thold, sample_rate);
    }

    float energy_all = 0.0f;
    float energy_last = 0.0f;

    const long long last_begin = n_samples - n_samples_last;
    for (long long i = 0; i < n_samples; i++) {
        const float magnitude = fabsf(pcmf32[i]);
        energy_all += magnitude;

        if (i >= last_begin) {
            energy_last += magnitude;
        }
    }

    energy_all /= n_samples;
    energy_last /= n_samples_last;

    if (verbose) {
        fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
    }

    // Written as a negated '>' so NaN inputs keep yielding true, as before.
    return !(energy_last > vad_thold * energy_all);
}

}
3 changes: 2 additions & 1 deletion cpp/rn-whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ void rn_whisper_remove_abort_map(int job_id);
void rn_whisper_abort_transcribe(int job_id);
bool rn_whisper_transcribe_is_aborted(int job_id);
void rn_whisper_abort_all_transcribe();
bool rn_whisper_vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose);

#ifdef __cplusplus
}
#endif
#endif
12 changes: 6 additions & 6 deletions example/ios/Podfile.lock
Original file line number Diff line number Diff line change
Expand Up @@ -750,16 +750,16 @@ PODS:
- React-perflogger (= 0.71.11)
- RNFS (2.20.0):
- React-Core
- RNZipArchive (6.0.9):
- RNZipArchive (6.1.0):
- React-Core
- RNZipArchive/Core (= 6.0.9)
- RNZipArchive/Core (= 6.1.0)
- SSZipArchive (~> 2.2)
- RNZipArchive/Core (6.0.9):
- RNZipArchive/Core (6.1.0):
- React-Core
- SSZipArchive (~> 2.2)
- SocketRocket (0.6.0)
- SSZipArchive (2.4.3)
- whisper-rn (0.3.5):
- whisper-rn (0.3.6):
- RCT-Folly
- RCTRequired
- RCTTypeSafety
Expand Down Expand Up @@ -994,10 +994,10 @@ SPEC CHECKSUMS:
React-runtimeexecutor: 4817d63dbc9d658f8dc0ec56bd9b83ce531129f0
ReactCommon: 08723d2ed328c5cbcb0de168f231bc7bae7f8aa1
RNFS: 4ac0f0ea233904cb798630b3c077808c06931688
RNZipArchive: 68a0c6db4b1c103f846f1559622050df254a3ade
RNZipArchive: ef9451b849c45a29509bf44e65b788829ab07801
SocketRocket: fccef3f9c5cedea1353a9ef6ada904fde10d6608
SSZipArchive: fe6a26b2a54d5a0890f2567b5cc6de5caa600aef
whisper-rn: 6f293154b175fee138a994fa00d0f414fb1f44e9
whisper-rn: e80c0482f6a632faafd601f98f10da0255c1e1ec
Yoga: f7decafdc5e8c125e6fa0da38a687e35238420fa
YogaKit: f782866e155069a2cca2517aafea43200b01fd5a

Expand Down
2 changes: 2 additions & 0 deletions example/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,8 @@ export default function App() {
realtimeAudioSec: 60,
// Slice audio into 25 (or < 30) sec chunks for better performance
realtimeAudioSliceSec: 25,
// Use VAD (Voice Activity Detection) to decide when to start transcribing
useVad: true,
})
setStopTranscribe({ stop })
subscribe((evt) => {
Expand Down
18 changes: 18 additions & 0 deletions ios/RNWhisperContext.mm
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#import "RNWhisperContext.h"
#include <vector>

#define NUM_BYTES_PER_BUFFER 16 * 1024

Expand Down Expand Up @@ -142,11 +143,28 @@ void AudioInputCallback(void * inUserData,
for (int i = 0; i < n; i++) {
audioBufferI16[nSamples + i] = ((short*)inBuffer->mAudioData)[i];
}

bool isSpeech = true;
if (state->options[@"useVad"]) {
if (nSamples + n > WHISPER_SAMPLE_RATE * 2) {
int start = nSamples + n - WHISPER_SAMPLE_RATE * 2;
std::vector<float> audioBufferF32Vec(WHISPER_SAMPLE_RATE * 2);
for (int i = 0; i < WHISPER_SAMPLE_RATE * 2; i++) {
audioBufferF32Vec[i] = (float)audioBufferI16[i + start] / 32768.0f;
}
isSpeech = rn_whisper_vad_simple(audioBufferF32Vec, WHISPER_SAMPLE_RATE, 1000, 0.6f, 100.0f, false);
NSLog(@"[RNWhisper] VAD result: %d", isSpeech);
} else {
isSpeech = false;
}
}
nSamples += n;
state->sliceNSamples[state->sliceIndex] = [NSNumber numberWithInt:nSamples];

AudioQueueEnqueueBuffer(state->queue, inBuffer, 0, NULL);

if (!isSpeech) return;

if (!state->isTranscribing) {
state->isTranscribing = true;
dispatch_async([state->mSelf getDispatchQueue], ^{
Expand Down
6 changes: 6 additions & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ export type TranscribeRealtimeOptions = TranscribeOptions & {
* (Default: Equal to `realtimeMaxAudioSec`)
*/
realtimeAudioSliceSec?: number
/**
* Start transcribing during recording only when the audio volume is greater than the threshold, using VAD (Voice Activity Detection).
* The first VAD check is triggered after 2 seconds of recording.
* (Default: false)
*/
useVad?: boolean
}

export type TranscribeRealtimeEvent = {
Expand Down

0 comments on commit cdc7e44

Please sign in to comment.