diff --git a/android/src/main/java/com/rnwhisper/WhisperContext.java b/android/src/main/java/com/rnwhisper/WhisperContext.java index cd8889b..7440b31 100644 --- a/android/src/main/java/com/rnwhisper/WhisperContext.java +++ b/android/src/main/java/com/rnwhisper/WhisperContext.java @@ -346,7 +346,7 @@ private void fullTranscribeSamples(ReadableMap options, boolean skipCapturingChe payload.putInt("sliceIndex", transcribeSliceIndex); if (code == 0) { - payload.putMap("data", getTextSegments()); + payload.putMap("data", getTextSegments(0, getTextSegmentCount(context))); } else { payload.putString("error", "Transcribe failed with code " + code); } @@ -406,16 +406,41 @@ private void emitProgress(int progress) { eventEmitter.emit("@RNWhisper_onTranscribeProgress", event); } - private static class ProgressCallback { + private void emitNewSegments(WritableMap result) { + WritableMap event = Arguments.createMap(); + event.putInt("contextId", WhisperContext.this.id); + event.putInt("jobId", jobId); + event.putMap("result", result); + eventEmitter.emit("@RNWhisper_onTranscribeNewSegments", event); + } + + private static class Callback { WhisperContext context; + boolean emitProgressNeeded = false; + boolean emitNewSegmentsNeeded = false; + int totalNNew = 0; - public ProgressCallback(WhisperContext context) { + public Callback(WhisperContext context, boolean emitProgressNeeded, boolean emitNewSegmentsNeeded) { this.context = context; + this.emitProgressNeeded = emitProgressNeeded; + this.emitNewSegmentsNeeded = emitNewSegmentsNeeded; } void onProgress(int progress) { + if (!emitProgressNeeded) return; context.emitProgress(progress); } + + void onNewSegments(int nNew) { + Log.d(NAME, "onNewSegments: " + nNew); + totalNNew += nNew; + if (!emitNewSegmentsNeeded) return; + + WritableMap result = context.getTextSegments(totalNNew - nNew, totalNNew); + result.putInt("nNew", nNew); + result.putInt("totalNNew", totalNNew); + context.emitNewSegments(result); + } } public WritableMap transcribeInputStream(int jobId, InputStream inputStream, ReadableMap options) throws IOException, Exception { @@ -433,12 +458,14 @@ public WritableMap transcribeInputStream(int jobId, InputStream inputStream, Rea if (code != 0) { throw new Exception("Failed to transcribe the file. Code: " + code); } - WritableMap result = getTextSegments(); + WritableMap result = getTextSegments(0, getTextSegmentCount(context)); result.putBoolean("isAborted", isStoppedByAction); return result; } private int full(int jobId, ReadableMap options, float[] audioData, int audioDataLen) { + boolean hasProgressCallback = options.hasKey("onProgress") && options.getBoolean("onProgress"); + boolean hasNewSegmentsCallback = options.hasKey("onNewSegments") && options.getBoolean("onNewSegments"); return fullTranscribe( jobId, context, @@ -478,13 +505,12 @@ private int full(int jobId, ReadableMap options, float[] audioData, int audioDat options.hasKey("language") ? options.getString("language") : "auto", // jstring prompt options.hasKey("prompt") ? options.getString("prompt") : null, - // ProgressCallback progressCallback - options.hasKey("onProgress") && options.getBoolean("onProgress") ? new ProgressCallback(this) : null + // Callback callback + hasProgressCallback || hasNewSegmentsCallback ? new Callback(this, hasProgressCallback, hasNewSegmentsCallback) : null ); } - private WritableMap getTextSegments() { - Integer count = getTextSegmentCount(context); + private WritableMap getTextSegments(int start, int count) { StringBuilder builder = new StringBuilder(); WritableMap data = Arguments.createMap(); @@ -647,7 +673,7 @@ protected static native int fullTranscribe( boolean translate, String language, String prompt, - ProgressCallback progressCallback + Callback Callback ); protected static native void abortTranscribe(int jobId); protected static native void abortAllTranscribe(); diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp index 20e6f02..bece224 100644 --- a/android/src/main/jni.cpp +++ b/android/src/main/jni.cpp @@ -206,9 +206,9 @@ Java_com_rnwhisper_WhisperContext_vadSimple( return is_speech; } -struct progress_callback_context { +struct callback_context { JNIEnv *env; - jobject progress_callback_instance; + jobject callback_instance; }; JNIEXPORT jint JNICALL @@ -234,7 +234,7 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe( jboolean translate, jstring language, jstring prompt, - jobject progress_callback_instance + jobject callback_instance ) { UNUSED(thiz); struct whisper_context *context = reinterpret_cast(context_ptr); @@ -302,19 +302,30 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe( }; params.encoder_begin_callback_user_data = rn_whisper_assign_abort_map(job_id); - if (progress_callback_instance != nullptr) { + if (callback_instance != nullptr) { + callback_context *cb_ctx = new callback_context; + cb_ctx->env = env; + cb_ctx->callback_instance = env->NewGlobalRef(callback_instance); + params.progress_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, int progress, void * user_data) { - progress_callback_context *cb_ctx = (progress_callback_context *)user_data; + callback_context *cb_ctx = (callback_context *)user_data; JNIEnv *env = cb_ctx->env; - jobject progress_callback_instance = cb_ctx->progress_callback_instance; - jclass progress_callback_class = env->GetObjectClass(progress_callback_instance); - jmethodID onProgress = env->GetMethodID(progress_callback_class, "onProgress", "(I)V"); - env->CallVoidMethod(progress_callback_instance, onProgress, progress); + jobject callback_instance = cb_ctx->callback_instance; + jclass callback_class = env->GetObjectClass(callback_instance); + jmethodID onProgress = env->GetMethodID(callback_class, "onProgress", "(I)V"); + env->CallVoidMethod(callback_instance, onProgress, progress); }; - progress_callback_context *cb_ctx = new progress_callback_context; - cb_ctx->env = env; - cb_ctx->progress_callback_instance = env->NewGlobalRef(progress_callback_instance); params.progress_callback_user_data = cb_ctx; + + params.new_segment_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, int n_new, void * user_data) { + callback_context *cb_ctx = (callback_context *)user_data; + JNIEnv *env = cb_ctx->env; + jobject callback_instance = cb_ctx->callback_instance; + jclass callback_class = env->GetObjectClass(callback_instance); + jmethodID onNewSegments = env->GetMethodID(callback_class, "onNewSegments", "(I)V"); + env->CallVoidMethod(callback_instance, onNewSegments, n_new); + }; + params.new_segment_callback_user_data = cb_ctx; } LOGI("About to reset timings");