Add an opt-in progress callback for file transcription.

Native code emits "@RNWhisper_onTranscribeProgress" events carrying contextId, jobId and
progress; src/index.ts forwards matching events to the JS `onProgress` option and removes
the listener once the job settles or is stopped. On iOS, transcribeFile now runs on a
global dispatch queue and abortTranscribe resolves/rejects a promise.

diff --git a/android/src/main/java/com/rnwhisper/WhisperContext.java b/android/src/main/java/com/rnwhisper/WhisperContext.java
index 22d046f..d9bce87 100644
--- a/android/src/main/java/com/rnwhisper/WhisperContext.java
+++ b/android/src/main/java/com/rnwhisper/WhisperContext.java
@@ -282,6 +282,26 @@ private void emitTranscribeEvent(final String eventName, final WritableMap paylo
     eventEmitter.emit(eventName, event);
   }
 
+  private void emitProgress(int progress) {
+    WritableMap event = Arguments.createMap();
+    event.putInt("contextId", WhisperContext.this.id);
+    event.putInt("jobId", jobId);
+    event.putInt("progress", progress);
+    eventEmitter.emit("@RNWhisper_onTranscribeProgress", event);
+  }
+
+  private static class ProgressCallback {
+    WhisperContext context;
+
+    public ProgressCallback(WhisperContext context) {
+      this.context = context;
+    }
+
+    void onProgress(int progress) {
+      context.emitProgress(progress);
+    }
+  }
+
   public WritableMap transcribeInputStream(int jobId, InputStream inputStream, ReadableMap options) throws IOException, Exception {
     this.jobId = jobId;
     isTranscribing = true;
@@ -334,7 +354,9 @@ private int full(int jobId, ReadableMap options, float[] audioData, int audioDat
       // jstring language,
       options.hasKey("language") ? options.getString("language") : "auto",
       // jstring prompt
-      options.hasKey("prompt") ? options.getString("prompt") : null
+      options.hasKey("prompt") ? options.getString("prompt") : null,
+      // ProgressCallback progressCallback
+      options.hasKey("onProgress") && options.getBoolean("onProgress") ? new ProgressCallback(this) : null
     );
   }
 
@@ -469,6 +491,7 @@ private static String cpuInfo() {
     }
   }
 
+
   protected static native long initContext(String modelPath);
   protected static native long initContextWithAsset(AssetManager assetManager, String modelPath);
   protected static native long initContextWithInputStream(PushbackInputStream inputStream);
@@ -491,7 +514,8 @@ protected static native int fullTranscribe(
     boolean speed_up,
     boolean translate,
     String language,
-    String prompt
+    String prompt,
+    ProgressCallback progressCallback
   );
   protected static native void abortTranscribe(int jobId);
   protected static native void abortAllTranscribe();
diff --git a/android/src/main/jni/whisper/jni.cpp b/android/src/main/jni/whisper/jni.cpp
index c0de842..3b97417 100644
--- a/android/src/main/jni/whisper/jni.cpp
+++ b/android/src/main/jni/whisper/jni.cpp
@@ -184,6 +184,11 @@ Java_com_rnwhisper_WhisperContext_initContextWithInputStream(
     return reinterpret_cast<jlong>(context);
 }
 
+struct progress_callback_context {
+    JNIEnv *env;
+    jobject progress_callback_instance;
+};
+
 JNIEXPORT jint JNICALL
 Java_com_rnwhisper_WhisperContext_fullTranscribe(
     JNIEnv *env,
@@ -206,7 +211,8 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
     jboolean speed_up,
     jboolean translate,
     jstring language,
-    jstring prompt
+    jstring prompt,
+    jobject progress_callback_instance
 ) {
     UNUSED(thiz);
     struct whisper_context *context = reinterpret_cast<struct whisper_context *>(context_ptr);
@@ -274,6 +280,21 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
     };
     params.encoder_begin_callback_user_data = rn_whisper_assign_abort_map(job_id);
 
+    // Kept on this native frame rather than the heap: it stays valid for the rest of this
+    // JNI call (where the transcription is expected to run), so nothing has to be freed.
+    progress_callback_context progress_ctx = { env, progress_callback_instance };
+    if (progress_callback_instance != nullptr) {
+        params.progress_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, int progress, void * user_data) {
+            progress_callback_context *cb_ctx = (progress_callback_context *)user_data;
+            JNIEnv *env = cb_ctx->env;
+            jobject progress_callback_instance = cb_ctx->progress_callback_instance;
+            jclass progress_callback_class = env->GetObjectClass(progress_callback_instance);
+            jmethodID onProgress = env->GetMethodID(progress_callback_class, "onProgress", "(I)V");
+            env->CallVoidMethod(progress_callback_instance, onProgress, progress);
+        };
+        params.progress_callback_user_data = &progress_ctx;
+    }
+
     LOGI("About to reset timings");
     whisper_reset_timings(context);
 
diff --git a/cpp/rn-whisper.cpp b/cpp/rn-whisper.cpp
index aeff36a..fcd3133 100644
--- a/cpp/rn-whisper.cpp
+++ b/cpp/rn-whisper.cpp
@@ -25,6 +25,13 @@ void rn_whisper_abort_transcribe(int job_id) {
   }
 }
 
+bool rn_whisper_transcribe_is_aborted(int job_id) {
+  if (abort_map.find(job_id) != abort_map.end()) {
+    return abort_map[job_id];
+  }
+  return false;
+}
+
 void rn_whisper_abort_all_transcribe() {
   for (auto it = abort_map.begin(); it != abort_map.end(); ++it) {
     it->second = true;
diff --git a/cpp/rn-whisper.h b/cpp/rn-whisper.h
index 70680d5..4fd2c1b 100644
--- a/cpp/rn-whisper.h
+++ b/cpp/rn-whisper.h
@@ -8,6 +8,7 @@ extern "C" {
 bool* rn_whisper_assign_abort_map(int job_id);
 void rn_whisper_remove_abort_map(int job_id);
 void rn_whisper_abort_transcribe(int job_id);
+bool rn_whisper_transcribe_is_aborted(int job_id);
 void rn_whisper_abort_all_transcribe();
 
 #ifdef __cplusplus
diff --git a/example/src/App.js b/example/src/App.js
index 2adabf0..9a11e04 100644
--- a/example/src/App.js
+++ b/example/src/App.js
@@ -207,15 +207,20 @@ export default function App() {
               log('Start transcribing...')
               const startTime = Date.now()
               const {
-                // stop,
+                stop,
                 promise,
               } = whisperContext.transcribe(sampleFile, {
                 language: 'en',
                 maxLen: 1,
                 tokenTimestamps: true,
+                onProgress: cur => {
+                  log(`Transcribing progress: ${cur}%`)
+                }
               })
+              setStopTranscribe({ stop })
               const { result, segments } = await promise
               const endTime = Date.now()
+              setStopTranscribe(null)
               setTranscibeResult(
                 `Transcribed result: ${result}\n` +
                   `Transcribed in ${endTime - startTime}ms in ${mode} mode` +
diff --git a/ios/RNWhisper.mm b/ios/RNWhisper.mm
index 4fb4798..aa34524 100644
--- a/ios/RNWhisper.mm
+++ b/ios/RNWhisper.mm
@@ -80,6 +80,14 @@ - (NSDictionary *)constantsToExport
     resolve([NSNumber numberWithInt:contextId]);
 }
 
+- (NSArray *)supportedEvents {
+  return@[
+    @"@RNWhisper_onTranscribeProgress",
+    @"@RNWhisper_onRealtimeTranscribe",
+    @"@RNWhisper_onRealtimeTranscribeEnd",
+  ];
+}
+
 RCT_REMAP_METHOD(transcribeFile,
                  withContextId:(int)contextId
                  withJobId:(int)jobId
@@ -114,21 +122,34 @@ - (NSDictionary *)constantsToExport
         reject(@"whisper_error", @"Invalid file", nil);
         return;
     }
-    int code = [context transcribeFile:jobId audioData:waveFile audioDataCount:count options:options];
-    if (code != 0) {
+    dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT, 0), ^{
+        int code = [context transcribeFile:jobId
+            audioData:waveFile
+            audioDataCount:count
+            options:options
+            onProgress: ^(int progress) {
+                if (rn_whisper_transcribe_is_aborted(jobId)) {
+                    return;
+                }
+                dispatch_async(dispatch_get_main_queue(), ^{
+                    [self sendEventWithName:@"@RNWhisper_onTranscribeProgress"
+                        body:@{
+                            @"contextId": [NSNumber numberWithInt:contextId],
+                            @"jobId": [NSNumber numberWithInt:jobId],
+                            @"progress": [NSNumber numberWithInt:progress]
+                        }
+                    ];
+                });
+            }
+        ];
+        if (code != 0) {
+            free(waveFile);
+            reject(@"whisper_cpp_error", [NSString stringWithFormat:@"Failed to transcribe the file. Code: %d", code], nil);
+            return;
+        }
         free(waveFile);
-        reject(@"whisper_cpp_error", [NSString stringWithFormat:@"Failed to transcribe the file. Code: %d", code], nil);
-        return;
-    }
-    free(waveFile);
-    resolve([context getTextSegments]);
-}
-
-- (NSArray *)supportedEvents {
-  return@[
-    @"@RNWhisper_onRealtimeTranscribe",
-    @"@RNWhisper_onRealtimeTranscribeEnd",
-  ];
+        resolve([context getTextSegments]);
+    });
 }
 
 RCT_REMAP_METHOD(startRealtimeTranscribe,
@@ -176,12 +197,20 @@ - (NSArray *)supportedEvents {
     }
     reject(@"whisper_error", [NSString stringWithFormat:@"Failed to start realtime transcribe. Status: %d", status], nil);
 }
+
 RCT_REMAP_METHOD(abortTranscribe,
                  withContextId:(int)contextId
-                 withJobId:(int)jobId)
+                 withJobId:(int)jobId
+                 withResolver:(RCTPromiseResolveBlock)resolve
+                 withRejecter:(RCTPromiseRejectBlock)reject)
 {
     RNWhisperContext *context = contexts[[NSNumber numberWithInt:contextId]];
+    if (context == nil) {
+        reject(@"whisper_error", @"Context not found", nil);
+        return;
+    }
     [context stopTranscribe:jobId];
+    resolve(nil);
 }
 
 RCT_REMAP_METHOD(releaseContext,
diff --git a/ios/RNWhisperContext.h b/ios/RNWhisperContext.h
index 7119fe5..9930eb5 100644
--- a/ios/RNWhisperContext.h
+++ b/ios/RNWhisperContext.h
@@ -48,7 +48,8 @@ typedef struct {
 - (int)transcribeFile:(int)jobId
     audioData:(float *)audioData
     audioDataCount:(int)audioDataCount
-    options:(NSDictionary *)options;
+    options:(NSDictionary *)options
+    onProgress:(void (^)(int))onProgress;
 - (void)stopTranscribe:(int)jobId;
 - (bool)isCapturing;
 - (bool)isTranscribing;
diff --git a/ios/RNWhisperContext.mm b/ios/RNWhisperContext.mm
index c11c267..5f4249a 100644
--- a/ios/RNWhisperContext.mm
+++ b/ios/RNWhisperContext.mm
@@ -263,10 +263,11 @@ - (int)transcribeFile:(int)jobId
     audioData:(float *)audioData
     audioDataCount:(int)audioDataCount
     options:(NSDictionary *)options
+    onProgress:(void (^)(int))onProgress
 {
     self->recordState.isTranscribing = true;
     self->recordState.jobId = jobId;
-    int code = [self fullTranscribe:jobId audioData:audioData audioDataCount:audioDataCount options:options];
+    int code = [self fullTranscribeWithProgress:onProgress jobId:jobId audioData:audioData audioDataCount:audioDataCount options:options];
     self->recordState.jobId = -1;
     self->recordState.isTranscribing = false;
     return code;
@@ -297,7 +298,7 @@ - (void)stopCurrentTranscribe {
     [self stopTranscribe:self->recordState.jobId];
 }
 
-- (int)fullTranscribe:(int)jobId audioData:(float *)audioData audioDataCount:(int)audioDataCount options:(NSDictionary *)options {
+- (struct whisper_full_params)getParams:(NSDictionary *)options jobId:(int)jobId {
     struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
 
     const int n_threads = options[@"maxThreads"] != nil ?
@@ -362,6 +363,39 @@ - (int)fullTranscribe:(int)jobId audioData:(float *)audioData audioDataCount:(in
     };
     params.encoder_begin_callback_user_data = rn_whisper_assign_abort_map(jobId);
 
+    return params;
+}
+
+- (int)fullTranscribeWithProgress:(void (^)(int))onProgress
+    jobId:(int)jobId
+    audioData:(float *)audioData
+    audioDataCount:(int)audioDataCount
+    options:(NSDictionary *)options
+{
+    struct whisper_full_params params = [self getParams:options jobId:jobId];
+    if (options[@"onProgress"] && [options[@"onProgress"] boolValue]) {
+        params.progress_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, int progress, void * user_data) {
+            void (^onProgress)(int) = (__bridge void (^)(int))user_data;
+            onProgress(progress);
+        };
+        params.progress_callback_user_data = (__bridge void *)(onProgress);
+    }
+    whisper_reset_timings(self->ctx);
+
+    int code = whisper_full(self->ctx, params, audioData, audioDataCount);
+    rn_whisper_remove_abort_map(jobId);
+    // if (code == 0) {
+    //     whisper_print_timings(self->ctx);
+    // }
+    return code;
+}
+
+- (int)fullTranscribe:(int)jobId
+    audioData:(float *)audioData
+    audioDataCount:(int)audioDataCount
+    options:(NSDictionary *)options
+{
+    struct whisper_full_params params = [self getParams:options jobId:jobId];
     whisper_reset_timings(self->ctx);
 
     int code = whisper_full(self->ctx, params, audioData, audioDataCount);
diff --git a/src/NativeRNWhisper.ts b/src/NativeRNWhisper.ts
index 1a4a699..db51d35 100644
--- a/src/NativeRNWhisper.ts
+++ b/src/NativeRNWhisper.ts
@@ -31,6 +31,8 @@ export type TranscribeOptions = {
   speedUp?: boolean,
   /** Initial Prompt */
   prompt?: string,
+  /** Register onProgress event for transcribe file */
+  onProgress?: boolean
 }
 
 export type TranscribeResult = {
diff --git a/src/index.ts b/src/index.ts
index 34e3b10..88e2453 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -24,9 +24,26 @@ if (Platform.OS === 'android') {
 
 export type { TranscribeOptions, TranscribeResult }
 
+
+const EVENT_ON_TRANSCRIBE_PROGRESS = '@RNWhisper_onTranscribeProgress'
+
 const EVENT_ON_REALTIME_TRANSCRIBE = '@RNWhisper_onRealtimeTranscribe'
 const EVENT_ON_REALTIME_TRANSCRIBE_END = '@RNWhisper_onRealtimeTranscribeEnd'
 
+export type TranscribeFileOptions = TranscribeOptions & {
+  /**
+   * Progress callback, the progress is between 0 and 100
+   */
+  onProgress?: (progress: number) => void
+}
+
+export type TranscribeProgressNativeEvent = {
+  contextId: number
+  jobId: number
+  progress: number
+}
+
+// NOTE: codegen missing TSIntersectionType support so we dont put it into the native spec
 export type TranscribeRealtimeOptions = TranscribeOptions & {
   /**
    * Realtime record max duration in seconds.
@@ -91,7 +108,7 @@ export class WhisperContext {
   /** Transcribe audio file */
   transcribe(
     filePath: string | number,
-    options: TranscribeOptions = {},
+    options: TranscribeFileOptions = {},
   ): {
     /** Stop the transcribe */
     stop: () => void
@@ -113,9 +130,34 @@ export class WhisperContext {
     }
     if (path.startsWith('file://')) path = path.slice(7)
     const jobId: number = Math.floor(Math.random() * 10000)
+
+    const { onProgress, ...rest } = options
+    let progressListener: any
+    if (onProgress) {
+      progressListener = EventEmitter.addListener(
+        EVENT_ON_TRANSCRIBE_PROGRESS,
+        (evt: TranscribeProgressNativeEvent) => {
+          const { contextId, progress } = evt
+          if (contextId !== this.id || evt.jobId !== jobId) return
+          onProgress(progress)
+        },
+      )
+    }
     return {
-      stop: () => RNWhisper.abortTranscribe(this.id, jobId),
-      promise: RNWhisper.transcribeFile(this.id, jobId, path, options),
+      stop: async () => {
+        await RNWhisper.abortTranscribe(this.id, jobId)
+        progressListener?.remove()
+      },
+      promise: RNWhisper.transcribeFile(this.id, jobId, path, {
+        ...rest,
+        onProgress: !!onProgress
+      }).then((result) => {
+        progressListener?.remove()
+        return result
+      }).catch((e) => {
+        progressListener?.remove()
+        throw e
+      }),
     }
   }
 