Skip to content

Commit

Permalink
feat(android): create createRealtimeTranscribeJob and update vadSimpl…
Browse files Browse the repository at this point in the history
…e jni methods
  • Loading branch information
jhen0409 committed Dec 8, 2023
1 parent 50f8713 commit 2350ede
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 70 deletions.
66 changes: 28 additions & 38 deletions android/src/main/java/com/rnwhisper/WhisperContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -80,25 +80,8 @@ private void rewind() {
}

private boolean vad(ReadableMap options, short[] shortBuffer, int nSamples, int n) {
boolean isSpeech = true;
if (!isTranscribing && options.hasKey("useVad") && options.getBoolean("useVad")) {
int vadMs = options.hasKey("vadMs") ? options.getInt("vadMs") : 2000;
if (vadMs < 2000) vadMs = 2000;
int sampleSize = (int) (SAMPLE_RATE * vadMs / 1000);
if (nSamples + n > sampleSize) {
int start = nSamples + n - sampleSize;
float[] audioData = new float[sampleSize];
for (int i = 0; i < sampleSize; i++) {
audioData[i] = shortBuffer[i + start] / 32768.0f;
}
float vadThold = options.hasKey("vadThold") ? (float) options.getDouble("vadThold") : 0.6f;
float vadFreqThold = options.hasKey("vadFreqThold") ? (float) options.getDouble("vadFreqThold") : 0.6f;
isSpeech = vadSimple(audioData, sampleSize, vadThold, vadFreqThold);
} else {
isSpeech = false;
}
}
return isSpeech;
if (isTranscribing) return true;
return vadSimple(jobId, shortBuffer, nSamples, n);
}

private void finishRealtimeTranscribe(ReadableMap options, WritableMap result) {
Expand All @@ -112,8 +95,8 @@ private void finishRealtimeTranscribe(ReadableMap options, WritableMap result) {
Log.e(NAME, "Error saving wav file: " + e.getMessage());
}
}

emitTranscribeEvent("@RNWhisper_onRealtimeTranscribeEnd", Arguments.createMap());
removeRealtimeTranscribeJob(jobId, context);
}

public int startRealtimeTranscribe(int jobId, ReadableMap options) {
Expand All @@ -132,6 +115,7 @@ public int startRealtimeTranscribe(int jobId, ReadableMap options) {
rewind();

this.jobId = jobId;
createRealtimeTranscribeJob(jobId, context, options);

int realtimeAudioSec = options.hasKey("realtimeAudioSec") ? options.getInt("realtimeAudioSec") : 0;
final int audioSec = realtimeAudioSec > 0 ? realtimeAudioSec : DEFAULT_MAX_AUDIO_SEC;
Expand Down Expand Up @@ -178,8 +162,7 @@ public void run() {
finishRealtimeTranscribe(options, Arguments.createMap());
} else if (!isTranscribing) {
short[] shortBuffer = shortBufferSlices.get(sliceIndex);
boolean isSpeech = vad(options, shortBuffer, nSamples, 0);
if (!isSpeech) {
if (!vad(options, shortBuffer, nSamples, 0)) {
finishRealtimeTranscribe(options, Arguments.createMap());
break;
}
Expand Down Expand Up @@ -265,7 +248,7 @@ private void fullTranscribeSamples(ReadableMap options, boolean skipCapturingChe
Log.d(NAME, "Start transcribing realtime: " + nSamplesTranscribing);

int timeStart = (int) System.currentTimeMillis();
int code = full(jobId, options, nSamplesBuffer32, nSamplesTranscribing);
int code = fullWithJob(jobId, context, nSamplesBuffer32, nSamplesTranscribing);
int timeEnd = (int) System.currentTimeMillis();
int timeRecording = (int) (nSamplesTranscribing / SAMPLE_RATE * 1000);

Expand Down Expand Up @@ -383,32 +366,30 @@ public WritableMap transcribeInputStream(int jobId, InputStream inputStream, Rea
this.jobId = jobId;
isTranscribing = true;
float[] audioData = AudioUtils.decodeWaveFile(inputStream);
int code = full(jobId, options, audioData, audioData.length);
isTranscribing = false;
this.jobId = -1;
if (code != 0 && code != 999) {
throw new Exception("Failed to transcribe the file. Code: " + code);
}
WritableMap result = getTextSegments(0, getTextSegmentCount(context));
result.putBoolean("isAborted", isStoppedByAction);
return result;
}

private int full(int jobId, ReadableMap options, float[] audioData, int audioDataLen) {
boolean hasProgressCallback = options.hasKey("onProgress") && options.getBoolean("onProgress");
boolean hasNewSegmentsCallback = options.hasKey("onNewSegments") && options.getBoolean("onNewSegments");
return fullTranscribe(
int code = fullWithNewJob(
jobId,
context,
// float[] audio_data,
audioData,
// jint audio_data_len,
audioDataLen,
audioData.length,
// ReadableMap options,
options,
// Callback callback
hasProgressCallback || hasNewSegmentsCallback ? new Callback(this, hasProgressCallback, hasNewSegmentsCallback) : null
);

isTranscribing = false;
this.jobId = -1;
if (code != 0 && code != 999) {
throw new Exception("Failed to transcribe the file. Code: " + code);
}
WritableMap result = getTextSegments(0, getTextSegmentCount(context));
result.putBoolean("isAborted", isStoppedByAction);
return result;
}

private WritableMap getTextSegments(int start, int count) {
Expand Down Expand Up @@ -531,15 +512,24 @@ private static String cpuInfo() {
protected static native long initContext(String modelPath);
protected static native long initContextWithAsset(AssetManager assetManager, String modelPath);
protected static native long initContextWithInputStream(PushbackInputStream inputStream);
protected static native boolean vadSimple(float[] audio_data, int audio_data_len, float vad_thold, float vad_freq_thold);
protected static native int fullTranscribe(

protected static native int fullWithNewJob(
int job_id,
long context,
float[] audio_data,
int audio_data_len,
ReadableMap options,
Callback Callback
);
protected static native void createRealtimeTranscribeJob(int job_id, long context, ReadableMap options);
protected static native void removeRealtimeTranscribeJob(int job_id, long context);
protected static native boolean vadSimple(int job_id, short[] pcm, int n_samples, int n);
protected static native int fullWithJob(
int job_id,
long context,
float[] audio_data,
int audio_data_len
);
protected static native void abortTranscribe(int jobId);
protected static native void abortAllTranscribe();
protected static native int getTextSegmentCount(long context);
Expand Down
10 changes: 5 additions & 5 deletions android/src/main/jni-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

namespace readablemap {

jboolean hasKey(JNIEnv *env, jobject readableMap, const char *key) {
bool hasKey(JNIEnv *env, jobject readableMap, const char *key) {
jclass mapClass = env->GetObjectClass(readableMap);
jmethodID hasKeyMethod = env->GetMethodID(mapClass, "hasKey", "(Ljava/lang/String;)Z");
jstring jKey = env->NewStringUTF(key);
Expand All @@ -13,7 +13,7 @@ jboolean hasKey(JNIEnv *env, jobject readableMap, const char *key) {
return result;
}

jint getInt(JNIEnv *env, jobject readableMap, const char *key, jint defaultValue) {
int getInt(JNIEnv *env, jobject readableMap, const char *key, jint defaultValue) {
if (!hasKey(env, readableMap, key)) {
return defaultValue;
}
Expand All @@ -25,7 +25,7 @@ jint getInt(JNIEnv *env, jobject readableMap, const char *key, jint defaultValue
return result;
}

jboolean getBool(JNIEnv *env, jobject readableMap, const char *key, jboolean defaultValue) {
bool getBool(JNIEnv *env, jobject readableMap, const char *key, jboolean defaultValue) {
if (!hasKey(env, readableMap, key)) {
return defaultValue;
}
Expand All @@ -37,7 +37,7 @@ jboolean getBool(JNIEnv *env, jobject readableMap, const char *key, jboolean def
return result;
}

jlong getLong(JNIEnv *env, jobject readableMap, const char *key, jlong defaultValue) {
long getLong(JNIEnv *env, jobject readableMap, const char *key, jlong defaultValue) {
if (!hasKey(env, readableMap, key)) {
return defaultValue;
}
Expand All @@ -49,7 +49,7 @@ jlong getLong(JNIEnv *env, jobject readableMap, const char *key, jlong defaultVa
return result;
}

jfloat getFloat(JNIEnv *env, jobject readableMap, const char *key, jfloat defaultValue) {
float getFloat(JNIEnv *env, jobject readableMap, const char *key, jfloat defaultValue) {
if (!hasKey(env, readableMap, key)) {
return defaultValue;
}
Expand Down
95 changes: 73 additions & 22 deletions android/src/main/jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,26 +191,6 @@ Java_com_rnwhisper_WhisperContext_initContextWithInputStream(
return reinterpret_cast<jlong>(context);
}

/**
 * JNI entry point for WhisperContext.vadSimple (float-array variant).
 *
 * Copies `audio_data_len` samples out of the Java float array into a
 * std::vector<float> and passes them to rnwhisper::vad_simple together with
 * WHISPER_SAMPLE_RATE, a fixed 1000 argument and the two caller-supplied
 * thresholds.
 *
 * @param env             JNI environment used to access the Java array.
 * @param thiz            Unused receiver.
 * @param audio_data      Java float[] holding the samples to inspect.
 * @param audio_data_len  Number of samples to read from `audio_data`.
 * @param vad_thold       Threshold forwarded to rnwhisper::vad_simple.
 * @param vad_freq_thold  Frequency threshold forwarded to rnwhisper::vad_simple.
 * @return The boolean result of rnwhisper::vad_simple.
 */
JNIEXPORT jboolean JNICALL
Java_com_rnwhisper_WhisperContext_vadSimple(
    JNIEnv *env,
    jobject thiz,
    jfloatArray audio_data,
    jint audio_data_len,
    jfloat vad_thold,
    jfloat vad_freq_thold
) {
    UNUSED(thiz);

    // Native working copy, sized up front exactly as the Java caller requested.
    std::vector<float> samples(audio_data_len);

    jfloat *audio_data_arr = env->GetFloatArrayElements(audio_data, nullptr);
    const jfloat *cursor = audio_data_arr;
    for (float &sample : samples) {
        sample = *cursor++;
    }

    const bool speech_detected = rnwhisper::vad_simple(
        samples, WHISPER_SAMPLE_RATE, 1000, vad_thold, vad_freq_thold, false);

    // JNI_ABORT: the array was only read, so nothing is copied back to Java.
    env->ReleaseFloatArrayElements(audio_data, audio_data_arr, JNI_ABORT);
    return speech_detected;
}

struct whisper_full_params createFullParams(JNIEnv *env, jobject options) {
struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
Expand Down Expand Up @@ -272,7 +252,7 @@ struct callback_context {
};

JNIEXPORT jint JNICALL
Java_com_rnwhisper_WhisperContext_fullTranscribe(
Java_com_rnwhisper_WhisperContext_fullWithNewJob(
JNIEnv *env,
jobject thiz,
jint job_id,
Expand Down Expand Up @@ -333,7 +313,78 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
return code;
}

// TODO: full for realtimeTranscribe with job_id (need create job first)
/**
 * JNI entry point for WhisperContext.createRealtimeTranscribeJob.
 *
 * Builds whisper_full_params from the Java options map, registers a new
 * rnwhisper job under `job_id`, and stores the voice-activity-detection
 * settings read from the same map on that job.
 *
 * Option keys and the defaults used when a key is absent:
 *   - "useVad"       -> false
 *   - "vadMs"        -> 2000
 *   - "vadThold"     -> 0.6f
 *   - "vadFreqThold" -> 100.0f
 *
 * @param env          JNI environment used to read the options map.
 * @param thiz         Unused receiver.
 * @param job_id       Identifier the job is registered under.
 * @param context_ptr  Unused here; kept so the signature matches the Java
 *                     native declaration.
 * @param options      Java ReadableMap with transcription and VAD options.
 */
JNIEXPORT void JNICALL
Java_com_rnwhisper_WhisperContext_createRealtimeTranscribeJob(
    JNIEnv *env,
    jobject thiz,
    jint job_id,
    jlong context_ptr,
    jobject options
) {
    // Mark the parameters this function does not touch, matching the sibling
    // JNI functions in this file.
    UNUSED(thiz);
    UNUSED(context_ptr);

    whisper_full_params params = createFullParams(env, options);
    rnwhisper::job* job = rnwhisper::job_new(job_id, params);
    if (job == nullptr) {
        // NOTE(review): job_new is defined outside this view; guard in case it
        // can fail, so we never dereference a null job below.
        return;
    }

    rnwhisper::vad_params vad;
    vad.use_vad = readablemap::getBool(env, options, "useVad", false);
    vad.vad_ms = readablemap::getInt(env, options, "vadMs", 2000);
    vad.vad_thold = readablemap::getFloat(env, options, "vadThold", 0.6f);
    vad.freq_thold = readablemap::getFloat(env, options, "vadFreqThold", 100.0f);
    job->set_vad_params(vad);
}

/**
 * JNI entry point for WhisperContext.removeRealtimeTranscribeJob.
 *
 * Forwards `job_id` to rnwhisper::job_remove. Every other argument is
 * accepted only so the signature matches the Java native declaration.
 *
 * @param env          Unused JNI environment.
 * @param thiz         Unused receiver.
 * @param job_id       Identifier of the job to remove.
 * @param context_ptr  Unused.
 */
JNIEXPORT void JNICALL
Java_com_rnwhisper_WhisperContext_removeRealtimeTranscribeJob(
    JNIEnv *env,
    jobject thiz,
    jint job_id,
    jlong context_ptr
) {
    // Only the job id is needed for removal.
    UNUSED(context_ptr);
    UNUSED(thiz);
    UNUSED(env);

    rnwhisper::job_remove(job_id);
}

/**
 * JNI entry point for WhisperContext.vadSimple (job-based variant).
 *
 * Looks up the job registered under `job_id` and asks it to run its
 * voice-activity check on the given 16-bit PCM buffer.
 *
 * @param env        JNI environment used to access the Java array.
 * @param thiz       Unused receiver.
 * @param job_id     Identifier of the job whose VAD settings are used.
 * @param pcm        Java short[] holding the recorded PCM samples.
 * @param n_samples  Forwarded to job->vad_simple; presumably the number of
 *                   samples already buffered (TODO confirm against
 *                   rnwhisper::job).
 * @param n          Forwarded to job->vad_simple; presumably the number of
 *                   newly appended samples (TODO confirm).
 * @return The result of job->vad_simple, or JNI_FALSE when the job cannot be
 *         found or the array elements cannot be obtained.
 */
JNIEXPORT jboolean JNICALL
Java_com_rnwhisper_WhisperContext_vadSimple(
    JNIEnv *env,
    jobject thiz,
    jint job_id,
    jshortArray pcm,
    jint n_samples,
    jint n
) {
    UNUSED(thiz);

    // Resolve the job before pinning the array so the early return below has
    // nothing to release.
    // NOTE(review): job_get is defined outside this view; assumed to return
    // nullptr for an unknown or already-removed job id.
    rnwhisper::job* job = rnwhisper::job_get(job_id);
    if (job == nullptr) {
        return JNI_FALSE;
    }

    // GetShortArrayElements returns NULL on failure.
    jshort *pcm_arr = env->GetShortArrayElements(pcm, nullptr);
    if (pcm_arr == nullptr) {
        return JNI_FALSE;
    }

    bool is_speech = job->vad_simple(pcm_arr, n_samples, n);

    // JNI_ABORT: release the buffer without copying anything back to Java.
    env->ReleaseShortArrayElements(pcm, pcm_arr, JNI_ABORT);
    return is_speech ? JNI_TRUE : JNI_FALSE;
}

/**
 * JNI entry point for WhisperContext.fullWithJob.
 *
 * Runs whisper_full on the supplied samples using the parameters stored on
 * the already-created job registered under `job_id`.
 *
 * @param env             JNI environment used to access the Java array.
 * @param thiz            Unused receiver.
 * @param job_id          Identifier of the job whose params are used.
 * @param context_ptr     whisper_context pointer passed from Java as a jlong.
 * @param audio_data      Java float[] with the samples to transcribe.
 * @param audio_data_len  Number of samples passed to whisper_full.
 * @return The whisper_full result code; -999 when the job reports it was
 *         aborted; -1 when the job cannot be found or the array elements
 *         cannot be obtained.
 */
JNIEXPORT jint JNICALL
Java_com_rnwhisper_WhisperContext_fullWithJob(
    JNIEnv *env,
    jobject thiz,
    jint job_id,
    jlong context_ptr,
    jfloatArray audio_data, // TODO: move audio slice to C++
    jint audio_data_len
) {
    UNUSED(thiz);
    struct whisper_context *context = reinterpret_cast<struct whisper_context *>(context_ptr);

    // Resolve the job before pinning the array so the early return below has
    // nothing to release.
    // NOTE(review): job_get is defined outside this view; assumed to return
    // nullptr for an unknown or already-removed job id.
    rnwhisper::job* job = rnwhisper::job_get(job_id);
    if (job == nullptr) {
        return -1;
    }

    // GetFloatArrayElements returns NULL on failure.
    jfloat *audio_data_arr = env->GetFloatArrayElements(audio_data, nullptr);
    if (audio_data_arr == nullptr) {
        return -1;
    }

    int code = whisper_full(context, job->params, audio_data_arr, audio_data_len);
    // whisper_print_timings(context) can be called here when code == 0 to
    // debug performance.

    // JNI_ABORT: the samples were only read, so nothing is copied back.
    env->ReleaseFloatArrayElements(audio_data, audio_data_arr, JNI_ABORT);

    // An abort overrides whatever whisper_full returned.
    // NOTE(review): the Java side shown for transcribeInputStream compares
    // against 999 (positive); confirm which sign is intended.
    if (job->is_aborted()) code = -999;
    return code;
}

JNIEXPORT void JNICALL
Java_com_rnwhisper_WhisperContext_abortTranscribe(
Expand Down
4 changes: 2 additions & 2 deletions cpp/rn-whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ namespace rnwhisper {

struct vad_params {
bool use_vad = false;
float vad_thold = 0.1;
float freq_thold = 0.1;
float vad_thold = 0.6f;
float freq_thold = 100.0f;
int vad_ms = 2000;
int last_ms = 1000;
bool verbose = false;
Expand Down
5 changes: 2 additions & 3 deletions ios/RNWhisperContext.mm
Original file line number Diff line number Diff line change
Expand Up @@ -355,13 +355,12 @@ - (OSStatus)transcribeRealtime:(int)jobId
[self prepareRealtime:options];
self->recordState.job = rnwhisper::job_new(jobId, [self createParams:options jobId:jobId]);

rnwhisper::vad_params vad = {
self->recordState.job->set_vad_params({
.use_vad = options[@"useVad"] != nil ? [options[@"useVad"] boolValue] : false,
.vad_ms = options[@"vadMs"] != nil ? [options[@"vadMs"] intValue] : 2000,
.vad_thold = options[@"vadThold"] != nil ? [options[@"vadThold"] floatValue] : 0.6f,
.freq_thold = options[@"vadFreqThold"] != nil ? [options[@"vadFreqThold"] floatValue] : 100.0f
};
self->recordState.job->set_vad_params(vad);
});

OSStatus status = AudioQueueNewInput(
&self->recordState.dataFormat,
Expand Down

0 comments on commit 2350ede

Please sign in to comment.