diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp index 63bb82d..27115d8 100644 --- a/android/src/main/jni.cpp +++ b/android/src/main/jni.cpp @@ -212,28 +212,7 @@ Java_com_rnwhisper_WhisperContext_vadSimple( return is_speech; } -struct callback_context { - JNIEnv *env; - jobject callback_instance; -}; - -JNIEXPORT jint JNICALL -Java_com_rnwhisper_WhisperContext_fullTranscribe( - JNIEnv *env, - jobject thiz, - jint job_id, - jlong context_ptr, - jfloatArray audio_data, - jint audio_data_len, - jobject transcribe_params, - jobject callback_instance -) { - UNUSED(thiz); - struct whisper_context *context = reinterpret_cast(context_ptr); - jfloat *audio_data_arr = env->GetFloatArrayElements(audio_data, nullptr); - - LOGI("About to create params"); - +struct whisper_full_params createFullParams(JNIEnv *env, jobject options) { struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); params.print_realtime = false; @@ -244,40 +223,72 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe( int max_threads = std::thread::hardware_concurrency(); // Use 2 threads by default on 4-core devices, 4 threads on more cores int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads); - int n_threads = readablemap::getInt(env, transcribe_params, "maxThreads", default_n_threads); + int n_threads = readablemap::getInt(env, options, "maxThreads", default_n_threads); params.n_threads = n_threads > 0 ? n_threads : default_n_threads; - params.translate = readablemap::getBool(env, transcribe_params, "translate", false); - params.speed_up = readablemap::getBool(env, transcribe_params, "speedUp", false); - params.token_timestamps = readablemap::getBool(env, transcribe_params, "tokenTimestamps", false); + params.translate = readablemap::getBool(env, options, "translate", false); + params.speed_up = readablemap::getBool(env, options, "speedUp", false); + params.token_timestamps = readablemap::getBool(env, options, "tokenTimestamps", false); params.offset_ms = 0; params.no_context = true; params.single_segment = false; - int beam_size = readablemap::getInt(env, transcribe_params, "beamSize", -1); + int beam_size = readablemap::getInt(env, options, "beamSize", -1); if (beam_size > -1) { params.strategy = WHISPER_SAMPLING_BEAM_SEARCH; params.beam_search.beam_size = beam_size; } - int best_of = readablemap::getInt(env, transcribe_params, "bestOf", -1); + int best_of = readablemap::getInt(env, options, "bestOf", -1); if (best_of > -1) params.greedy.best_of = best_of; - int max_len = readablemap::getInt(env, transcribe_params, "maxLen", -1); + int max_len = readablemap::getInt(env, options, "maxLen", -1); if (max_len > -1) params.max_len = max_len; - int max_context = readablemap::getInt(env, transcribe_params, "maxContext", -1); + int max_context = readablemap::getInt(env, options, "maxContext", -1); if (max_context > -1) params.n_max_text_ctx = max_context; - int offset = readablemap::getInt(env, transcribe_params, "offset", -1); + int offset = readablemap::getInt(env, options, "offset", -1); if (offset > -1) params.offset_ms = offset; - int duration = readablemap::getInt(env, transcribe_params, "duration", -1); + int duration = readablemap::getInt(env, options, "duration", -1); if (duration > -1) params.duration_ms = duration; - int word_thold = readablemap::getInt(env, transcribe_params, "wordThold", -1); + int word_thold = readablemap::getInt(env, options, "wordThold", -1); if (word_thold > -1) params.thold_pt = word_thold; - float temperature = readablemap::getFloat(env, transcribe_params, "temperature", -1); + float temperature = readablemap::getFloat(env, options, "temperature", -1); if (temperature > -1) params.temperature = temperature; - float temperature_inc = readablemap::getFloat(env, transcribe_params, "temperatureInc", -1); + float temperature_inc = readablemap::getFloat(env, options, "temperatureInc", -1); if (temperature_inc > -1) params.temperature_inc = temperature_inc; - jstring prompt = readablemap::getString(env, transcribe_params, "prompt", nullptr); - if (prompt != nullptr) params.initial_prompt = env->GetStringUTFChars(prompt, nullptr); - jstring language = readablemap::getString(env, transcribe_params, "language", nullptr); - if (language != nullptr) params.language = env->GetStringUTFChars(language, nullptr); + jstring prompt = readablemap::getString(env, options, "prompt", nullptr); + if (prompt != nullptr) { + params.initial_prompt = env->GetStringUTFChars(prompt, nullptr); + env->ReleaseStringUTFChars(prompt, params.initial_prompt); + } + jstring language = readablemap::getString(env, options, "language", nullptr); + if (language != nullptr) { + params.language = env->GetStringUTFChars(language, nullptr); + env->ReleaseStringUTFChars(language, params.language); + } + return params; +} + +struct callback_context { + JNIEnv *env; + jobject callback_instance; +}; + +JNIEXPORT jint JNICALL +Java_com_rnwhisper_WhisperContext_fullTranscribe( + JNIEnv *env, + jobject thiz, + jint job_id, + jlong context_ptr, + jfloatArray audio_data, + jint audio_data_len, + jobject options, + jobject callback_instance +) { + UNUSED(thiz); + struct whisper_context *context = reinterpret_cast(context_ptr); + jfloat *audio_data_arr = env->GetFloatArrayElements(audio_data, nullptr); + + LOGI("About to create params"); + + whisper_full_params params = createFullParams(env, options); if (callback_instance != nullptr) { callback_context *cb_ctx = new callback_context; @@ -305,7 +316,7 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe( params.new_segment_callback_user_data = cb_ctx; } - rnwhisper::job job = rnwhisper::job_new(job_id, params); + rnwhisper::job* job = rnwhisper::job_new(job_id, params); LOGI("About to reset timings"); whisper_reset_timings(context); @@ -316,14 +327,14 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe( // whisper_print_timings(context); } env->ReleaseFloatArrayElements(audio_data, audio_data_arr, JNI_ABORT); - if (language != nullptr) env->ReleaseStringUTFChars(language, params.language); - if (prompt != nullptr) env->ReleaseStringUTFChars(prompt, params.initial_prompt); - - if (job.is_aborted()) code = -999; + + if (job->is_aborted()) code = -999; rnwhisper::job_remove(job_id); return code; } +// TODO: full for realtimeTranscribe with job_id (need create job first) + JNIEXPORT void JNICALL Java_com_rnwhisper_WhisperContext_abortTranscribe( JNIEnv *env, diff --git a/cpp/rn-whisper.cpp b/cpp/rn-whisper.cpp index 9e41e4c..27e7759 100644 --- a/cpp/rn-whisper.cpp +++ b/cpp/rn-whisper.cpp @@ -26,7 +26,7 @@ void job_abort_all() { } } -job job_new(int job_id, struct whisper_full_params params) { +job* job_new(int job_id, struct whisper_full_params params) { job ctx; ctx.job_id = job_id; ctx.params = params; @@ -44,7 +44,7 @@ job job_new(int job_id, struct whisper_full_params params) { params.abort_callback_user_data = &ctx; job_map[job_id] = ctx; - return ctx; + return &job_map[job_id]; } void job_remove(int job_id) { diff --git a/cpp/rn-whisper.h b/cpp/rn-whisper.h index 855c397..bfa6c74 100644 --- a/cpp/rn-whisper.h +++ b/cpp/rn-whisper.h @@ -17,7 +17,7 @@ struct job { }; void job_abort_all(); -job job_new(int job_id, struct whisper_full_params params); +job* job_new(int job_id, struct whisper_full_params params); void job_remove(int job_id); job* job_get(int job_id); diff --git a/example/src/App.tsx b/example/src/App.tsx index a56b789..04e7e35 100644 --- a/example/src/App.tsx +++ b/example/src/App.tsx @@ -215,7 +215,8 @@ export default function App() { log('Start transcribing...') const startTime = Date.now() const { stop, promise } = whisperContext.transcribe(sampleFile, { - language: 'en', + language: 'zh', + prompt: 'HELLO WORLD', maxLen: 1, tokenTimestamps: true, onProgress: (cur) => { diff --git a/ios/RNWhisperContext.h b/ios/RNWhisperContext.h index 4d6d4ad..9e458eb 100644 --- a/ios/RNWhisperContext.h +++ b/ios/RNWhisperContext.h @@ -11,10 +11,10 @@ typedef struct { __unsafe_unretained id mSelf; - - int jobId; NSDictionary* options; + struct rnwhisper::job * job; + bool isTranscribing; bool isRealtime; bool isCapturing; diff --git a/ios/RNWhisperContext.mm b/ios/RNWhisperContext.mm index f8d912d..5805079 100644 --- a/ios/RNWhisperContext.mm +++ b/ios/RNWhisperContext.mm @@ -275,7 +275,8 @@ - (void)finishRealtimeTranscribe:(RNWhisperContextRecordState*) state result:(NS audioOutputFile:state->audioOutputPath ]; } - state->transcribeHandler(state->jobId, @"end", result); + state->transcribeHandler(state->job->job_id, @"end", result); + rnwhisper::job_remove(state->job->job_id); } - (void)fullTranscribeSamples:(RNWhisperContextRecordState*) state { @@ -290,8 +291,9 @@ - (void)fullTranscribeSamples:(RNWhisperContextRecordState*) state { audioBufferF32[i] = (float)audioBufferI16[i] / 32768.0f; } CFTimeInterval timeStart = CACurrentMediaTime(); - struct whisper_full_params params = [state->mSelf getParams:state->options jobId:state->jobId]; - int code = [state->mSelf fullTranscribe:state->jobId params:params audioData:audioBufferF32 audioDataCount:state->nSamplesTranscribing]; + + int code = [state->mSelf fullTranscribe:state->job audioData:audioBufferF32 audioDataCount:state->nSamplesTranscribing]; + free(audioBufferF32); CFTimeInterval timeEnd = CACurrentMediaTime(); const float timeRecording = (float) state->nSamplesTranscribing / (float) state->dataFormat.mSampleRate; @@ -340,10 +342,10 @@ - (void)fullTranscribeSamples:(RNWhisperContextRecordState*) state { [state->mSelf finishRealtimeTranscribe:state result:result]; } else if (code == 0) { result[@"isCapturing"] = @(true); - state->transcribeHandler(state->jobId, @"transcribe", result); + state->transcribeHandler(state->job->job_id, @"transcribe", result); } else { result[@"isCapturing"] = @(true); - state->transcribeHandler(state->jobId, @"transcribe", result); + state->transcribeHandler(state->job->job_id, @"transcribe", result); } if (continueNeeded) { @@ -371,8 +373,8 @@ - (OSStatus)transcribeRealtime:(int)jobId onTranscribe:(void (^)(int, NSString *, NSDictionary *))onTranscribe { self->recordState.transcribeHandler = onTranscribe; - self->recordState.jobId = jobId; [self prepareRealtime:options]; + self->recordState.job = rnwhisper::job_new(jobId, [self createParams:options jobId:jobId]); OSStatus status = AudioQueueNewInput( &self->recordState.dataFormat, @@ -413,9 +415,9 @@ - (void)transcribeFile:(int)jobId dispatch_async(dQueue, ^{ self->recordState.isStoppedByAction = false; self->recordState.isTranscribing = true; - self->recordState.jobId = jobId; - whisper_full_params params = [self getParams:options jobId:jobId]; + whisper_full_params params = [self createParams:options jobId:jobId]; + if (options[@"onProgress"] && [options[@"onProgress"] boolValue]) { params.progress_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, int progress, void * user_data) { void (^onProgress)(int) = (__bridge void (^)(int))user_data; @@ -460,8 +462,10 @@ - (void)transcribeFile:(int)jobId }; params.new_segment_callback_user_data = &user_data; } - int code = [self fullTranscribe:jobId params:params audioData:audioData audioDataCount:audioDataCount]; - self->recordState.jobId = -1; + + rnwhisper::job* job = rnwhisper::job_new(jobId, params);; + int code = [self fullTranscribe:job audioData:audioData audioDataCount:audioDataCount]; + rnwhisper::job_remove(jobId); self->recordState.isTranscribing = false; onEnd(code); }); @@ -476,8 +480,7 @@ - (void)stopAudio { } - (void)stopTranscribe:(int)jobId { - rnwhisper::job *job = rnwhisper::job_get(jobId); - if (job) job->abort(); + if (self->recordState.job) self->recordState.job->abort(); if (self->recordState.isRealtime && self->recordState.isCapturing) { [self stopAudio]; if (!self->recordState.isTranscribing) { @@ -491,13 +494,11 @@ - (void)stopTranscribe:(int)jobId { } - (void)stopCurrentTranscribe { - if (!self->recordState.jobId) { - return; - } - [self stopTranscribe:self->recordState.jobId]; + if (self->recordState.job == nullptr) return; + [self stopTranscribe:self->recordState.job->job_id]; } -- (struct whisper_full_params)getParams:(NSDictionary *)options jobId:(int)jobId { +- (struct whisper_full_params)createParams:(NSDictionary *)options jobId:(int)jobId { struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); const int n_threads = options[@"maxThreads"] != nil ? @@ -554,22 +555,16 @@ - (struct whisper_full_params)getParams:(NSDictionary *)options jobId:(int)jobId params.initial_prompt = [options[@"prompt"] UTF8String]; } - rnwhisper::job_new(jobId, params); return params; } -- (int)fullTranscribe:(int)jobId - params:(struct whisper_full_params)params +- (int)fullTranscribe:(rnwhisper::job *)job audioData:(float *)audioData audioDataCount:(int)audioDataCount { whisper_reset_timings(self->ctx); - - int code = whisper_full(self->ctx, params, audioData, audioDataCount); - - rnwhisper::job* job = rnwhisper::job_get(jobId); + int code = whisper_full(self->ctx, job->params, audioData, audioDataCount); if (job && job->is_aborted()) code = -999; - rnwhisper::job_remove(jobId); // if (code == 0) { // whisper_print_timings(self->ctx); // }