From 2350ede563cf2ad199600132ae0bb78fb8323bee Mon Sep 17 00:00:00 2001 From: Jhen Date: Fri, 8 Dec 2023 12:39:19 +0800 Subject: [PATCH] feat(android): create createRealtimeTranscribeJob and update vadSimple jni methods --- .../java/com/rnwhisper/WhisperContext.java | 66 ++++++------- android/src/main/jni-utils.h | 10 +- android/src/main/jni.cpp | 95 ++++++++++++++----- cpp/rn-whisper.h | 4 +- ios/RNWhisperContext.mm | 5 +- 5 files changed, 110 insertions(+), 70 deletions(-) diff --git a/android/src/main/java/com/rnwhisper/WhisperContext.java b/android/src/main/java/com/rnwhisper/WhisperContext.java index bf1c7ee..06a60f9 100644 --- a/android/src/main/java/com/rnwhisper/WhisperContext.java +++ b/android/src/main/java/com/rnwhisper/WhisperContext.java @@ -80,25 +80,8 @@ private void rewind() { } private boolean vad(ReadableMap options, short[] shortBuffer, int nSamples, int n) { - boolean isSpeech = true; - if (!isTranscribing && options.hasKey("useVad") && options.getBoolean("useVad")) { - int vadMs = options.hasKey("vadMs") ? options.getInt("vadMs") : 2000; - if (vadMs < 2000) vadMs = 2000; - int sampleSize = (int) (SAMPLE_RATE * vadMs / 1000); - if (nSamples + n > sampleSize) { - int start = nSamples + n - sampleSize; - float[] audioData = new float[sampleSize]; - for (int i = 0; i < sampleSize; i++) { - audioData[i] = shortBuffer[i + start] / 32768.0f; - } - float vadThold = options.hasKey("vadThold") ? (float) options.getDouble("vadThold") : 0.6f; - float vadFreqThold = options.hasKey("vadFreqThold") ? (float) options.getDouble("vadFreqThold") : 0.6f; - isSpeech = vadSimple(audioData, sampleSize, vadThold, vadFreqThold); - } else { - isSpeech = false; - } - } - return isSpeech; + if (isTranscribing) return true; + return vadSimple(jobId, shortBuffer, nSamples, n); } private void finishRealtimeTranscribe(ReadableMap options, WritableMap result) { @@ -112,8 +95,8 @@ private void finishRealtimeTranscribe(ReadableMap options, WritableMap result) { Log.e(NAME, "Error saving wav file: " + e.getMessage()); } } - emitTranscribeEvent("@RNWhisper_onRealtimeTranscribeEnd", Arguments.createMap()); + removeRealtimeTranscribeJob(jobId, context); } public int startRealtimeTranscribe(int jobId, ReadableMap options) { @@ -132,6 +115,7 @@ public int startRealtimeTranscribe(int jobId, ReadableMap options) { rewind(); this.jobId = jobId; + createRealtimeTranscribeJob(jobId, context, options); int realtimeAudioSec = options.hasKey("realtimeAudioSec") ? options.getInt("realtimeAudioSec") : 0; final int audioSec = realtimeAudioSec > 0 ? realtimeAudioSec : DEFAULT_MAX_AUDIO_SEC; @@ -178,8 +162,7 @@ public void run() { finishRealtimeTranscribe(options, Arguments.createMap()); } else if (!isTranscribing) { short[] shortBuffer = shortBufferSlices.get(sliceIndex); - boolean isSpeech = vad(options, shortBuffer, nSamples, 0); - if (!isSpeech) { + if (!vad(options, shortBuffer, nSamples, 0)) { finishRealtimeTranscribe(options, Arguments.createMap()); break; } @@ -265,7 +248,7 @@ private void fullTranscribeSamples(ReadableMap options, boolean skipCapturingChe Log.d(NAME, "Start transcribing realtime: " + nSamplesTranscribing); int timeStart = (int) System.currentTimeMillis(); - int code = full(jobId, options, nSamplesBuffer32, nSamplesTranscribing); + int code = fullWithJob(jobId, context, nSamplesBuffer32, nSamplesTranscribing); int timeEnd = (int) System.currentTimeMillis(); int timeRecording = (int) (nSamplesTranscribing / SAMPLE_RATE * 1000); @@ -383,32 +366,30 @@ public WritableMap transcribeInputStream(int jobId, InputStream inputStream, Rea this.jobId = jobId; isTranscribing = true; float[] audioData = AudioUtils.decodeWaveFile(inputStream); - int code = full(jobId, options, audioData, audioData.length); - isTranscribing = false; - this.jobId = -1; - if (code != 0 && code != 999) { - throw new Exception("Failed to transcribe the file. Code: " + code); - } - WritableMap result = getTextSegments(0, getTextSegmentCount(context)); - result.putBoolean("isAborted", isStoppedByAction); - return result; - } - private int full(int jobId, ReadableMap options, float[] audioData, int audioDataLen) { boolean hasProgressCallback = options.hasKey("onProgress") && options.getBoolean("onProgress"); boolean hasNewSegmentsCallback = options.hasKey("onNewSegments") && options.getBoolean("onNewSegments"); - return fullTranscribe( + int code = fullWithNewJob( jobId, context, // float[] audio_data, audioData, // jint audio_data_len, - audioDataLen, + audioData.length, // ReadableMap options, options, // Callback callback hasProgressCallback || hasNewSegmentsCallback ? new Callback(this, hasProgressCallback, hasNewSegmentsCallback) : null ); + + isTranscribing = false; + this.jobId = -1; + if (code != 0 && code != 999) { + throw new Exception("Failed to transcribe the file. Code: " + code); + } + WritableMap result = getTextSegments(0, getTextSegmentCount(context)); + result.putBoolean("isAborted", isStoppedByAction); + return result; } private WritableMap getTextSegments(int start, int count) { @@ -531,8 +512,8 @@ private static String cpuInfo() { protected static native long initContext(String modelPath); protected static native long initContextWithAsset(AssetManager assetManager, String modelPath); protected static native long initContextWithInputStream(PushbackInputStream inputStream); - protected static native boolean vadSimple(float[] audio_data, int audio_data_len, float vad_thold, float vad_freq_thold); - protected static native int fullTranscribe( + + protected static native int fullWithNewJob( int job_id, long context, float[] audio_data, @@ -540,6 +521,15 @@ protected static native int fullTranscribe( ReadableMap options, Callback Callback ); + protected static native void createRealtimeTranscribeJob(int job_id, long context, ReadableMap options); + protected static native void removeRealtimeTranscribeJob(int job_id, long context); + protected static native boolean vadSimple(int job_id, short[] pcm, int n_samples, int n); + protected static native int fullWithJob( + int job_id, + long context, + float[] audio_data, + int audio_data_len + ); protected static native void abortTranscribe(int jobId); protected static native void abortAllTranscribe(); protected static native int getTextSegmentCount(long context); diff --git a/android/src/main/jni-utils.h b/android/src/main/jni-utils.h index 419ce34..f4cf1a9 100644 --- a/android/src/main/jni-utils.h +++ b/android/src/main/jni-utils.h @@ -4,7 +4,7 @@ namespace readablemap { -jboolean hasKey(JNIEnv *env, jobject readableMap, const char *key) { +bool hasKey(JNIEnv *env, jobject readableMap, const char *key) { jclass mapClass = env->GetObjectClass(readableMap); jmethodID hasKeyMethod = env->GetMethodID(mapClass, "hasKey", "(Ljava/lang/String;)Z"); jstring jKey = env->NewStringUTF(key); @@ -13,7 +13,7 @@ jboolean hasKey(JNIEnv *env, jobject readableMap, const char *key) { return result; } -jint getInt(JNIEnv *env, jobject readableMap, const char *key, jint defaultValue) { +int getInt(JNIEnv *env, jobject readableMap, const char *key, jint defaultValue) { if (!hasKey(env, readableMap, key)) { return defaultValue; } @@ -25,7 +25,7 @@ jint getInt(JNIEnv *env, jobject readableMap, const char *key, jint defaultValue return result; } -jboolean getBool(JNIEnv *env, jobject readableMap, const char *key, jboolean defaultValue) { +bool getBool(JNIEnv *env, jobject readableMap, const char *key, jboolean defaultValue) { if (!hasKey(env, readableMap, key)) { return defaultValue; } @@ -37,7 +37,7 @@ jboolean getBool(JNIEnv *env, jobject readableMap, const char *key, jboolean def return result; } -jlong getLong(JNIEnv *env, jobject readableMap, const char *key, jlong defaultValue) { +long getLong(JNIEnv *env, jobject readableMap, const char *key, jlong defaultValue) { if (!hasKey(env, readableMap, key)) { return defaultValue; } @@ -49,7 +49,7 @@ jlong getLong(JNIEnv *env, jobject readableMap, const char *key, jlong defaultVa return result; } -jfloat getFloat(JNIEnv *env, jobject readableMap, const char *key, jfloat defaultValue) { +float getFloat(JNIEnv *env, jobject readableMap, const char *key, jfloat defaultValue) { if (!hasKey(env, readableMap, key)) { return defaultValue; } diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp index 27115d8..5b7ddb6 100644 --- a/android/src/main/jni.cpp +++ b/android/src/main/jni.cpp @@ -191,26 +191,6 @@ Java_com_rnwhisper_WhisperContext_initContextWithInputStream( return reinterpret_cast(context); } -JNIEXPORT jboolean JNICALL -Java_com_rnwhisper_WhisperContext_vadSimple( - JNIEnv *env, - jobject thiz, - jfloatArray audio_data, - jint audio_data_len, - jfloat vad_thold, - jfloat vad_freq_thold -) { - UNUSED(thiz); - - std::vector samples(audio_data_len); - jfloat *audio_data_arr = env->GetFloatArrayElements(audio_data, nullptr); - for (int i = 0; i < audio_data_len; i++) { - samples[i] = audio_data_arr[i]; - } - bool is_speech = rnwhisper::vad_simple(samples, WHISPER_SAMPLE_RATE, 1000, vad_thold, vad_freq_thold, false); - env->ReleaseFloatArrayElements(audio_data, audio_data_arr, JNI_ABORT); - return is_speech; -} struct whisper_full_params createFullParams(JNIEnv *env, jobject options) { struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); @@ -272,7 +252,7 @@ struct callback_context { }; JNIEXPORT jint JNICALL -Java_com_rnwhisper_WhisperContext_fullTranscribe( +Java_com_rnwhisper_WhisperContext_fullWithNewJob( JNIEnv *env, jobject thiz, jint job_id, @@ -333,7 +313,78 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe( return code; } -// TODO: full for realtimeTranscribe with job_id (need create job first) +JNIEXPORT void JNICALL +Java_com_rnwhisper_WhisperContext_createRealtimeTranscribeJob( + JNIEnv *env, + jobject thiz, + jint job_id, + jlong context_ptr, + jobject options +) { + whisper_full_params params = createFullParams(env, options); + rnwhisper::job* job = rnwhisper::job_new(job_id, params); + rnwhisper::vad_params vad; + vad.use_vad = readablemap::getBool(env, options, "useVad", false); + vad.vad_ms = readablemap::getInt(env, options, "vadMs", 2000); + vad.vad_thold = readablemap::getFloat(env, options, "vadThold", 0.6f); + vad.freq_thold = readablemap::getFloat(env, options, "vadFreqThold", 100.0f); + job->set_vad_params(vad); +} + +JNIEXPORT void JNICALL +Java_com_rnwhisper_WhisperContext_removeRealtimeTranscribeJob( + JNIEnv *env, + jobject thiz, + jint job_id, + jlong context_ptr +) { + UNUSED(env); + UNUSED(thiz); + UNUSED(context_ptr); + rnwhisper::job_remove(job_id); +} + +JNIEXPORT jboolean JNICALL +Java_com_rnwhisper_WhisperContext_vadSimple( + JNIEnv *env, + jobject thiz, + jint job_id, + jshortArray pcm, + jint n_samples, + jint n +) { + UNUSED(thiz); + + jshort *pcm_arr = env->GetShortArrayElements(pcm, nullptr); + rnwhisper::job* job = rnwhisper::job_get(job_id); + bool is_speech = job->vad_simple(pcm_arr, n_samples, n); + env->ReleaseShortArrayElements(pcm, pcm_arr, JNI_ABORT); + return is_speech; +} + +JNIEXPORT jint JNICALL +Java_com_rnwhisper_WhisperContext_fullWithJob( + JNIEnv *env, + jobject thiz, + jint job_id, + jlong context_ptr, + jfloatArray audio_data, // TODO: move audio slice to C++ + jint audio_data_len +) { + UNUSED(thiz); + struct whisper_context *context = reinterpret_cast(context_ptr); + jfloat *audio_data_arr = env->GetFloatArrayElements(audio_data, nullptr); + + rnwhisper::job* job = rnwhisper::job_get(job_id); + + int code = whisper_full(context, job->params, audio_data_arr, audio_data_len); + if (code == 0) { + // whisper_print_timings(context); + } + env->ReleaseFloatArrayElements(audio_data, audio_data_arr, JNI_ABORT); + if (job->is_aborted()) code = -999; + return code; +} JNIEXPORT void JNICALL Java_com_rnwhisper_WhisperContext_abortTranscribe( diff --git a/cpp/rn-whisper.h b/cpp/rn-whisper.h index ef8177b..2bfc591 100644 --- a/cpp/rn-whisper.h +++ b/cpp/rn-whisper.h @@ -9,8 +9,8 @@ namespace rnwhisper { struct vad_params { bool use_vad = false; - float vad_thold = 0.1; - float freq_thold = 0.1; + float vad_thold = 0.6f; + float freq_thold = 100.0f; int vad_ms = 2000; int last_ms = 1000; bool verbose = false; diff --git a/ios/RNWhisperContext.mm b/ios/RNWhisperContext.mm index ea5974e..05ab7f9 100644 --- a/ios/RNWhisperContext.mm +++ b/ios/RNWhisperContext.mm @@ -355,13 +355,12 @@ - (OSStatus)transcribeRealtime:(int)jobId [self prepareRealtime:options]; self->recordState.job = rnwhisper::job_new(jobId, [self createParams:options jobId:jobId]); - rnwhisper::vad_params vad = { + self->recordState.job->set_vad_params({ .use_vad = options[@"useVad"] != nil ? [options[@"useVad"] boolValue] : false, .vad_ms = options[@"vadMs"] != nil ? [options[@"vadMs"] intValue] : 2000, .vad_thold = options[@"vadThold"] != nil ? [options[@"vadThold"] floatValue] : 0.6f, .freq_thold = options[@"vadFreqThold"] != nil ? [options[@"vadFreqThold"] floatValue] : 100.0f - }; - self->recordState.job->set_vad_params(vad); + }); OSStatus status = AudioQueueNewInput( &self->recordState.dataFormat,