Skip to content

Commit

Permalink
feat(android): implement onNewSegments event
Browse files Browse the repository at this point in the history
  • Loading branch information
jhen0409 committed Sep 30, 2023
1 parent 5432799 commit fd141d0
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 21 deletions.
44 changes: 35 additions & 9 deletions android/src/main/java/com/rnwhisper/WhisperContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ private void fullTranscribeSamples(ReadableMap options, boolean skipCapturingChe
payload.putInt("sliceIndex", transcribeSliceIndex);

if (code == 0) {
payload.putMap("data", getTextSegments());
payload.putMap("data", getTextSegments(0, getTextSegmentCount(context)));
} else {
payload.putString("error", "Transcribe failed with code " + code);
}
Expand Down Expand Up @@ -406,16 +406,41 @@ private void emitProgress(int progress) {
eventEmitter.emit("@RNWhisper_onTranscribeProgress", event);
}

private static class ProgressCallback {
private void emitNewSegments(WritableMap result) {
WritableMap event = Arguments.createMap();
event.putInt("contextId", WhisperContext.this.id);
event.putInt("jobId", jobId);
event.putMap("result", result);
eventEmitter.emit("@RNWhisper_onTranscribeNewSegments", event);
}

private static class Callback {
WhisperContext context;
boolean emitProgressNeeded = false;
boolean emitNewSegmentsNeeded = false;
int totalNNew = 0;

public ProgressCallback(WhisperContext context) {
public Callback(WhisperContext context, boolean emitProgressNeeded, boolean emitNewSegmentsNeeded) {
this.context = context;
this.emitProgressNeeded = emitProgressNeeded;
this.emitNewSegmentsNeeded = emitNewSegmentsNeeded;
}

void onProgress(int progress) {
if (!emitProgressNeeded) return;
context.emitProgress(progress);
}

void onNewSegments(int nNew) {
Log.d(NAME, "onNewSegments: " + nNew);
totalNNew += nNew;
if (!emitNewSegmentsNeeded) return;

WritableMap result = context.getTextSegments(totalNNew - nNew, totalNNew);
result.putInt("nNew", nNew);
result.putInt("totalNNew", totalNNew);
context.emitNewSegments(result);
}
}

public WritableMap transcribeInputStream(int jobId, InputStream inputStream, ReadableMap options) throws IOException, Exception {
Expand All @@ -433,12 +458,14 @@ public WritableMap transcribeInputStream(int jobId, InputStream inputStream, Rea
if (code != 0) {
throw new Exception("Failed to transcribe the file. Code: " + code);
}
WritableMap result = getTextSegments();
WritableMap result = getTextSegments(0, getTextSegmentCount(context));
result.putBoolean("isAborted", isStoppedByAction);
return result;
}

private int full(int jobId, ReadableMap options, float[] audioData, int audioDataLen) {
boolean hasProgressCallback = options.hasKey("onProgress") && options.getBoolean("onProgress");
boolean hasNewSegmentsCallback = options.hasKey("onNewSegments") && options.getBoolean("onNewSegments");
return fullTranscribe(
jobId,
context,
Expand Down Expand Up @@ -478,13 +505,12 @@ private int full(int jobId, ReadableMap options, float[] audioData, int audioDat
options.hasKey("language") ? options.getString("language") : "auto",
// jstring prompt
options.hasKey("prompt") ? options.getString("prompt") : null,
// ProgressCallback progressCallback
options.hasKey("onProgress") && options.getBoolean("onProgress") ? new ProgressCallback(this) : null
// Callback callback
hasProgressCallback || hasNewSegmentsCallback ? new Callback(this, hasProgressCallback, hasNewSegmentsCallback) : null
);
}

private WritableMap getTextSegments() {
Integer count = getTextSegmentCount(context);
private WritableMap getTextSegments(int start, int count) {
StringBuilder builder = new StringBuilder();

WritableMap data = Arguments.createMap();
Expand Down Expand Up @@ -647,7 +673,7 @@ protected static native int fullTranscribe(
boolean translate,
String language,
String prompt,
ProgressCallback progressCallback
Callback Callback
);
protected static native void abortTranscribe(int jobId);
protected static native void abortAllTranscribe();
Expand Down
35 changes: 23 additions & 12 deletions android/src/main/jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,9 @@ Java_com_rnwhisper_WhisperContext_vadSimple(
return is_speech;
}

struct progress_callback_context {
struct callback_context {
JNIEnv *env;
jobject progress_callback_instance;
jobject callback_instance;
};

JNIEXPORT jint JNICALL
Expand All @@ -234,7 +234,7 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
jboolean translate,
jstring language,
jstring prompt,
jobject progress_callback_instance
jobject callback_instance
) {
UNUSED(thiz);
struct whisper_context *context = reinterpret_cast<struct whisper_context *>(context_ptr);
Expand Down Expand Up @@ -302,19 +302,30 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
};
params.encoder_begin_callback_user_data = rn_whisper_assign_abort_map(job_id);

if (progress_callback_instance != nullptr) {
if (callback_instance != nullptr) {
callback_context *cb_ctx = new callback_context;
cb_ctx->env = env;
cb_ctx->callback_instance = env->NewGlobalRef(callback_instance);

params.progress_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, int progress, void * user_data) {
progress_callback_context *cb_ctx = (progress_callback_context *)user_data;
callback_context *cb_ctx = (callback_context *)user_data;
JNIEnv *env = cb_ctx->env;
jobject progress_callback_instance = cb_ctx->progress_callback_instance;
jclass progress_callback_class = env->GetObjectClass(progress_callback_instance);
jmethodID onProgress = env->GetMethodID(progress_callback_class, "onProgress", "(I)V");
env->CallVoidMethod(progress_callback_instance, onProgress, progress);
jobject callback_instance = cb_ctx->callback_instance;
jclass callback_class = env->GetObjectClass(callback_instance);
jmethodID onProgress = env->GetMethodID(callback_class, "onProgress", "(I)V");
env->CallVoidMethod(callback_instance, onProgress, progress);
};
progress_callback_context *cb_ctx = new progress_callback_context;
cb_ctx->env = env;
cb_ctx->progress_callback_instance = env->NewGlobalRef(progress_callback_instance);
params.progress_callback_user_data = cb_ctx;

params.new_segment_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, int n_new, void * user_data) {
callback_context *cb_ctx = (callback_context *)user_data;
JNIEnv *env = cb_ctx->env;
jobject callback_instance = cb_ctx->callback_instance;
jclass callback_class = env->GetObjectClass(callback_instance);
jmethodID onNewSegments = env->GetMethodID(callback_class, "onNewSegments", "(I)V");
env->CallVoidMethod(callback_instance, onNewSegments, n_new);
};
params.new_segment_callback_user_data = cb_ctx;
}

LOGI("About to reset timings");
Expand Down

0 comments on commit fd141d0

Please sign in to comment.