From 5d45f03a3204dfd0c2b2deb1a2b05ceea4a90989 Mon Sep 17 00:00:00 2001 From: Jhen Date: Tue, 5 Dec 2023 10:14:40 +0800 Subject: [PATCH 01/19] feat(cpp): create rnwhisper_job struct --- cpp/README.md | 2 +- cpp/rn-whisper.cpp | 49 ++++++++++++++++++++++++++-------------------- cpp/rn-whisper.h | 36 ++++++++++++++++++++++------------ 3 files changed, 52 insertions(+), 35 deletions(-) diff --git a/cpp/README.md b/cpp/README.md index c947f95..a221f9b 100644 --- a/cpp/README.md +++ b/cpp/README.md @@ -1,4 +1,4 @@ # Note -- Only `rn-whisper.h` / `rn-whisper.cpp` are the specific files for this project, others are sync from [whisper.cpp](https://github.com/ggerganov/whisper.cpp). +- Only `rn-whisper.hpp` are the specific files for this project, others are sync from [whisper.cpp](https://github.com/ggerganov/whisper.cpp). - We can update the native source by using the [bootstrap](../scripts/bootstrap.sh) script. diff --git a/cpp/rn-whisper.cpp b/cpp/rn-whisper.cpp index c27a491..5cd31a8 100644 --- a/cpp/rn-whisper.cpp +++ b/cpp/rn-whisper.cpp @@ -3,39 +3,46 @@ #include #include #include "whisper.h" +#include "rn-whisper.h" -extern "C" { +namespace rnwhisper { -std::unordered_map abort_map; +job::~job() { + fprintf(stderr, "%s: job_id: %d\n", __func__, job_id); +} -bool* rn_whisper_assign_abort_map(int job_id) { - abort_map[job_id] = false; - return &abort_map[job_id]; +bool job::is_aborted() { + return aborted; } -void rn_whisper_remove_abort_map(int job_id) { - if (abort_map.find(job_id) != abort_map.end()) { - abort_map.erase(job_id); - } +void job::abort() { + aborted = true; } -void rn_whisper_abort_transcribe(int job_id) { - if (abort_map.find(job_id) != abort_map.end()) { - abort_map[job_id] = true; +std::unordered_map job_map; + +void job_abort_all() { + for (auto it = job_map.begin(); it != job_map.end(); ++it) { + it->second.abort(); } } -bool rn_whisper_transcribe_is_aborted(int job_id) { - if (abort_map.find(job_id) != abort_map.end()) { - return abort_map[job_id]; - } - return false; +job job_new(int job_id) { + job ctx; + ctx.job_id = job_id; + job_map[job_id] = ctx; + return ctx; +} + +void job_remove(int job_id) { + job_map.erase(job_id); } -void rn_whisper_abort_all_transcribe() { - for (auto it = abort_map.begin(); it != abort_map.end(); ++it) { - it->second = true; +job* job_get(int job_id) { + if (job_map.find(job_id) != job_map.end()) { + return &job_map[job_id]; } + return nullptr; } void high_pass_filter(std::vector & data, float cutoff, float sample_rate) { @@ -51,7 +58,7 @@ void high_pass_filter(std::vector & data, float cutoff, float sample_rate } } -bool rn_whisper_vad_simple(std::vector & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) { +bool vad_simple(std::vector & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) { const int n_samples = pcmf32.size(); const int n_samples_last = (sample_rate * last_ms) / 1000; diff --git a/cpp/rn-whisper.h b/cpp/rn-whisper.h index 4f65158..1d45829 100644 --- a/cpp/rn-whisper.h +++ b/cpp/rn-whisper.h @@ -1,17 +1,27 @@ +#ifndef RNWHISPER_H +#define RNWHISPER_H -#ifdef __cplusplus #include -#include -extern "C" { -#endif +#include -bool* rn_whisper_assign_abort_map(int job_id); -void rn_whisper_remove_abort_map(int job_id); -void rn_whisper_abort_transcribe(int job_id); -bool rn_whisper_transcribe_is_aborted(int job_id); -void rn_whisper_abort_all_transcribe(); -bool rn_whisper_vad_simple(std::vector & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose); +namespace rnwhisper { -#ifdef __cplusplus -} -#endif +struct job { + int job_id; + bool aborted = false; + ~job(); + bool is_aborted(); + void abort(); +}; + +void job_abort_all(); +job job_new(int job_id); +void job_remove(int job_id); +job* job_get(int job_id); + +void high_pass_filter(std::vector & data, float cutoff, float sample_rate); +bool vad_simple(std::vector & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose); + +} // namespace rnwhisper + +#endif // RNWHISPER_H \ No newline at end of file From 73ae08c27bb51ec29e53bd710f15ac2b7d6368f5 Mon Sep 17 00:00:00 2001 From: Jhen Date: Tue, 5 Dec 2023 11:28:28 +0800 Subject: [PATCH 02/19] feat(ios): update rn-whisper api --- ios/RNWhisper.mm | 14 +++++++------- ios/RNWhisperContext.mm | 27 ++++++++++++++------------- 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/ios/RNWhisper.mm b/ios/RNWhisper.mm index 6aec9c7..f48a34b 100644 --- a/ios/RNWhisper.mm +++ b/ios/RNWhisper.mm @@ -142,9 +142,9 @@ - (NSArray *)supportedEvents { audioDataCount:count options:options onProgress: ^(int progress) { - if (rn_whisper_transcribe_is_aborted(jobId)) { - return; - } + rnwhisper::job* job = rnwhisper::job_get(jobId); + if (job && job->is_aborted()) return; + dispatch_async(dispatch_get_main_queue(), ^{ [self sendEventWithName:@"@RNWhisper_onTranscribeProgress" body:@{ @@ -156,9 +156,9 @@ - (NSArray *)supportedEvents { }); } onNewSegments: ^(NSDictionary *result) { - if (rn_whisper_transcribe_is_aborted(jobId)) { - return; - } + rnwhisper::job* job = rnwhisper::job_get(jobId); + if (job && job->is_aborted()) return; + dispatch_async(dispatch_get_main_queue(), ^{ [self sendEventWithName:@"@RNWhisper_onTranscribeNewSegments" body:@{ @@ -279,7 +279,7 @@ - (void)invalidate { [context invalidate]; } - rn_whisper_abort_all_transcribe(); // graceful abort + rnwhisper::job_abort_all(); // graceful abort [contexts removeAllObjects]; contexts = nil; diff --git a/ios/RNWhisperContext.mm b/ios/RNWhisperContext.mm index cd4e9dd..c4c1c62 100644 --- a/ios/RNWhisperContext.mm +++ b/ios/RNWhisperContext.mm @@ -169,7 +169,7 @@ bool vad(RNWhisperContextRecordState *state, int16_t* audioBufferI16, int nSampl for (int i = 0; i < sampleSize; i++) { audioBufferF32Vec[i] = (float)audioBufferI16[i + start] / 32768.0f; } - isSpeech = rn_whisper_vad_simple(audioBufferF32Vec, WHISPER_SAMPLE_RATE, 1000, state->vadThold, state->vadFreqThold, false); + isSpeech = rnwhisper::vad_simple(audioBufferF32Vec, WHISPER_SAMPLE_RATE, 1000, state->vadThold, state->vadFreqThold, false); NSLog(@"[RNWhisper] VAD result: %d", isSpeech); } else { isSpeech = false; @@ -476,7 +476,8 @@ - (void)stopAudio { } - (void)stopTranscribe:(int)jobId { - rn_whisper_abort_transcribe(jobId); + rnwhisper::job *job = rnwhisper::job_get(jobId); + if (job) job->abort(); if (self->recordState.isRealtime && self->recordState.isCapturing) { [self stopAudio]; if (!self->recordState.isTranscribing) { @@ -556,17 +557,17 @@ - (struct whisper_full_params)getParams:(NSDictionary *)options jobId:(int)jobId } // abort handler - bool *abort_ptr = rn_whisper_assign_abort_map(jobId); + rnwhisper::job job = rnwhisper::job_new(jobId); params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { - bool is_aborted = *(bool*)user_data; - return !is_aborted; + rnwhisper::job job = *(rnwhisper::job*)user_data; + return !job.is_aborted(); }; - params.encoder_begin_callback_user_data = abort_ptr; + params.encoder_begin_callback_user_data = &job; params.abort_callback = [](void * user_data) { - bool is_aborted = *(bool*)user_data; - return is_aborted; + rnwhisper::job job = *(rnwhisper::job*)user_data; + return job.is_aborted(); }; - params.abort_callback_user_data = abort_ptr; + params.abort_callback_user_data = &job; return params; } @@ -579,10 +580,10 @@ - (int)fullTranscribe:(int)jobId whisper_reset_timings(self->ctx); int code = whisper_full(self->ctx, params, audioData, audioDataCount); - if (rn_whisper_transcribe_is_aborted(jobId)) { - code = -999; - } - rn_whisper_remove_abort_map(jobId); + + rnwhisper::job* job = rnwhisper::job_get(jobId); + if (job && job->is_aborted()) code = -999; + rnwhisper::job_remove(jobId); // if (code == 0) { // whisper_print_timings(self->ctx); // } From 8aebcfcb1adde58b432ccb4bee00d2bc2de9dd9d Mon Sep 17 00:00:00 2001 From: Jhen Date: Tue, 5 Dec 2023 11:35:11 +0800 Subject: [PATCH 03/19] feat(android): update rn-whisper api --- android/src/main/jni.cpp | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp index 7222c49..450ca34 100644 --- a/android/src/main/jni.cpp +++ b/android/src/main/jni.cpp @@ -207,7 +207,7 @@ Java_com_rnwhisper_WhisperContext_vadSimple( for (int i = 0; i < audio_data_len; i++) { samples[i] = audio_data_arr[i]; } - bool is_speech = rn_whisper_vad_simple(samples, WHISPER_SAMPLE_RATE, 1000, vad_thold, vad_freq_thold, false); + bool is_speech = rnwhisper::vad_simple(samples, WHISPER_SAMPLE_RATE, 1000, vad_thold, vad_freq_thold, false); env->ReleaseFloatArrayElements(audio_data, audio_data_arr, JNI_ABORT); return is_speech; } @@ -280,17 +280,17 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe( if (language != nullptr) params.language = env->GetStringUTFChars(language, nullptr); // abort handlers - bool* abort_ptr = rn_whisper_assign_abort_map(job_id); + rnwhisper::job job = rnwhisper::job_new(job_id); params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { - bool is_aborted = *(bool*)user_data; - return !is_aborted; + rnwhisper::job job = *(rnwhisper::job*)user_data; + return !job.is_aborted(); }; - params.encoder_begin_callback_user_data = abort_ptr; + params.encoder_begin_callback_user_data = &job; params.abort_callback = [](void * user_data) { - bool is_aborted = *(bool*)user_data; - return is_aborted; + rnwhisper::job job = *(rnwhisper::job*)user_data; + return job.is_aborted(); }; - params.abort_callback_user_data = abort_ptr; + params.abort_callback_user_data = &job; if (callback_instance != nullptr) { callback_context *cb_ctx = new callback_context; @@ -329,10 +329,9 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe( env->ReleaseFloatArrayElements(audio_data, audio_data_arr, JNI_ABORT); if (language != nullptr) env->ReleaseStringUTFChars(language, params.language); if (prompt != nullptr) env->ReleaseStringUTFChars(prompt, params.initial_prompt); - if (rn_whisper_transcribe_is_aborted(job_id)) { - code = -999; - } - rn_whisper_remove_abort_map(job_id); + + if (job.is_aborted()) code = -999; + rnwhisper::job_remove(job_id); return code; } @@ -343,7 +342,8 @@ Java_com_rnwhisper_WhisperContext_abortTranscribe( jint job_id ) { UNUSED(thiz); - rn_whisper_abort_transcribe(job_id); + rnwhisper::job *job = rnwhisper::job_get(job_id); + if (job) job->abort(); } JNIEXPORT void JNICALL @@ -352,7 +352,7 @@ Java_com_rnwhisper_WhisperContext_abortAllTranscribe( jobject thiz ) { UNUSED(thiz); - rn_whisper_abort_all_transcribe(); + rnwhisper::job_abort_all(); } JNIEXPORT jint JNICALL From 36244f41fb59276f4123a44f34fc95b447814e17 Mon Sep 17 00:00:00 2001 From: Jhen Date: Tue, 5 Dec 2023 12:07:35 +0800 Subject: [PATCH 04/19] fix: user_data should not deref --- android/src/main/jni.cpp | 8 ++++---- ios/RNWhisperContext.mm | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp index 450ca34..81b82e2 100644 --- a/android/src/main/jni.cpp +++ b/android/src/main/jni.cpp @@ -282,13 +282,13 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe( // abort handlers rnwhisper::job job = rnwhisper::job_new(job_id); params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { - rnwhisper::job job = *(rnwhisper::job*)user_data; - return !job.is_aborted(); + rnwhisper::job *job = (rnwhisper::job*)user_data; + return !job->is_aborted(); }; params.encoder_begin_callback_user_data = &job; params.abort_callback = [](void * user_data) { - rnwhisper::job job = *(rnwhisper::job*)user_data; - return job.is_aborted(); + rnwhisper::job *job = (rnwhisper::job*)user_data; + return job->is_aborted(); }; params.abort_callback_user_data = &job; diff --git a/ios/RNWhisperContext.mm b/ios/RNWhisperContext.mm index c4c1c62..8e0606f 100644 --- a/ios/RNWhisperContext.mm +++ b/ios/RNWhisperContext.mm @@ -559,13 +559,13 @@ - (struct whisper_full_params)getParams:(NSDictionary *)options jobId:(int)jobId // abort handler rnwhisper::job job = rnwhisper::job_new(jobId); params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { - rnwhisper::job job = *(rnwhisper::job*)user_data; - return !job.is_aborted(); + rnwhisper::job *job = (rnwhisper::job*)user_data; + return !job->is_aborted(); }; params.encoder_begin_callback_user_data = &job; params.abort_callback = [](void * user_data) { - rnwhisper::job job = *(rnwhisper::job*)user_data; - return job.is_aborted(); + rnwhisper::job *job = (rnwhisper::job*)user_data; + return job->is_aborted(); }; params.abort_callback_user_data = &job; From 5c48c743c80efafa52d560ab0af4247341f3e161 Mon Sep 17 00:00:00 2001 From: Jhen Date: Tue, 5 Dec 2023 12:14:31 +0800 Subject: [PATCH 05/19] chore: revert unnecessary change --- cpp/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/README.md b/cpp/README.md index a221f9b..c947f95 100644 --- a/cpp/README.md +++ b/cpp/README.md @@ -1,4 +1,4 @@ # Note -- Only `rn-whisper.hpp` are the specific files for this project, others are sync from [whisper.cpp](https://github.com/ggerganov/whisper.cpp). +- Only `rn-whisper.h` / `rn-whisper.cpp` are the specific files for this project, others are sync from [whisper.cpp](https://github.com/ggerganov/whisper.cpp). - We can update the native source by using the [bootstrap](../scripts/bootstrap.sh) script. From 8ba887dc2ddcfa75830dbfbe9627aeb769b20949 Mon Sep 17 00:00:00 2001 From: Jhen Date: Wed, 6 Dec 2023 12:10:31 +0800 Subject: [PATCH 06/19] feat(cpp): move abort handler --- android/src/main/jni.cpp | 15 ++------------- cpp/rn-whisper.cpp | 17 +++++++++++++++-- cpp/rn-whisper.h | 4 +++- ios/RNWhisperContext.mm | 16 +--------------- 4 files changed, 21 insertions(+), 31 deletions(-) diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp index 81b82e2..63bb82d 100644 --- a/android/src/main/jni.cpp +++ b/android/src/main/jni.cpp @@ -279,19 +279,6 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe( jstring language = readablemap::getString(env, transcribe_params, "language", nullptr); if (language != nullptr) params.language = env->GetStringUTFChars(language, nullptr); - // abort handlers - rnwhisper::job job = rnwhisper::job_new(job_id); - params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { - rnwhisper::job *job = (rnwhisper::job*)user_data; - return !job->is_aborted(); - }; - params.encoder_begin_callback_user_data = &job; - params.abort_callback = [](void * user_data) { - rnwhisper::job *job = (rnwhisper::job*)user_data; - return job->is_aborted(); - }; - params.abort_callback_user_data = &job; - if (callback_instance != nullptr) { callback_context *cb_ctx = new callback_context; cb_ctx->env = env; @@ -318,6 +305,8 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe( params.new_segment_callback_user_data = cb_ctx; } + rnwhisper::job job = rnwhisper::job_new(job_id, params); + LOGI("About to reset timings"); whisper_reset_timings(context); diff --git a/cpp/rn-whisper.cpp b/cpp/rn-whisper.cpp index 5cd31a8..9e41e4c 100644 --- a/cpp/rn-whisper.cpp +++ b/cpp/rn-whisper.cpp @@ -2,7 +2,6 @@ #include #include #include -#include "whisper.h" #include "rn-whisper.h" namespace rnwhisper { @@ -27,9 +26,23 @@ void job_abort_all() { } } -job job_new(int job_id) { +job job_new(int job_id, struct whisper_full_params params) { job ctx; ctx.job_id = job_id; + ctx.params = params; + + // Abort handler + params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { + job *j = (job*)user_data; + return !j->is_aborted(); + }; + params.encoder_begin_callback_user_data = &ctx; + params.abort_callback = [](void * user_data) { + job *j = (job*)user_data; + return j->is_aborted(); + }; + params.abort_callback_user_data = &ctx; + job_map[job_id] = ctx; return ctx; } diff --git a/cpp/rn-whisper.h b/cpp/rn-whisper.h index 1d45829..855c397 100644 --- a/cpp/rn-whisper.h +++ b/cpp/rn-whisper.h @@ -3,11 +3,13 @@ #include #include +#include "whisper.h" namespace rnwhisper { struct job { int job_id; + whisper_full_params params; bool aborted = false; ~job(); bool is_aborted(); @@ -15,7 +17,7 @@ struct job { }; void job_abort_all(); -job job_new(int job_id); +job job_new(int job_id, struct whisper_full_params params); void job_remove(int job_id); job* job_get(int job_id); diff --git a/ios/RNWhisperContext.mm b/ios/RNWhisperContext.mm index 8e0606f..f8d912d 100644 --- a/ios/RNWhisperContext.mm +++ b/ios/RNWhisperContext.mm @@ -535,7 +535,6 @@ - (struct whisper_full_params)getParams:(NSDictionary *)options jobId:(int)jobId if (options[@"maxContext"] != nil) { params.n_max_text_ctx = [options[@"maxContext"] intValue]; } - if (options[@"offset"] != nil) { params.offset_ms = [options[@"offset"] intValue]; } @@ -551,24 +550,11 @@ - (struct whisper_full_params)getParams:(NSDictionary *)options jobId:(int)jobId if (options[@"temperatureInc"] != nil) { params.temperature_inc = [options[@"temperature_inc"] floatValue]; } - if (options[@"prompt"] != nil) { params.initial_prompt = [options[@"prompt"] UTF8String]; } - // abort handler - rnwhisper::job job = rnwhisper::job_new(jobId); - params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { - rnwhisper::job *job = (rnwhisper::job*)user_data; - return !job->is_aborted(); - }; - params.encoder_begin_callback_user_data = &job; - params.abort_callback = [](void * user_data) { - rnwhisper::job *job = (rnwhisper::job*)user_data; - return job->is_aborted(); - }; - params.abort_callback_user_data = &job; - + rnwhisper::job_new(jobId, params); return params; } From 6fb88ed19aca28be0c6269fb976e5d35d4d85197 Mon Sep 17 00:00:00 2001 From: Jhen Date: Wed, 6 Dec 2023 13:38:53 +0800 Subject: [PATCH 07/19] feat(ios): store job in RNWhisperContext --- android/src/main/jni.cpp | 99 ++++++++++++++++++++++------------------ cpp/rn-whisper.cpp | 4 +- cpp/rn-whisper.h | 2 +- example/src/App.tsx | 3 +- ios/RNWhisperContext.h | 4 +- ios/RNWhisperContext.mm | 45 ++++++++---------- 6 files changed, 82 insertions(+), 75 deletions(-) diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp index 63bb82d..27115d8 100644 --- a/android/src/main/jni.cpp +++ b/android/src/main/jni.cpp @@ -212,28 +212,7 @@ Java_com_rnwhisper_WhisperContext_vadSimple( return is_speech; } -struct callback_context { - JNIEnv *env; - jobject callback_instance; -}; - -JNIEXPORT jint JNICALL -Java_com_rnwhisper_WhisperContext_fullTranscribe( - JNIEnv *env, - jobject thiz, - jint job_id, - jlong context_ptr, - jfloatArray audio_data, - jint audio_data_len, - jobject transcribe_params, - jobject callback_instance -) { - UNUSED(thiz); - struct whisper_context *context = reinterpret_cast(context_ptr); - jfloat *audio_data_arr = env->GetFloatArrayElements(audio_data, nullptr); - - LOGI("About to create params"); - +struct whisper_full_params createFullParams(JNIEnv *env, jobject options) { struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); params.print_realtime = false; @@ -244,40 +223,72 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe( int max_threads = std::thread::hardware_concurrency(); // Use 2 threads by default on 4-core devices, 4 threads on more cores int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads); - int n_threads = readablemap::getInt(env, transcribe_params, "maxThreads", default_n_threads); + int n_threads = readablemap::getInt(env, options, "maxThreads", default_n_threads); params.n_threads = n_threads > 0 ? n_threads : default_n_threads; - params.translate = readablemap::getBool(env, transcribe_params, "translate", false); - params.speed_up = readablemap::getBool(env, transcribe_params, "speedUp", false); - params.token_timestamps = readablemap::getBool(env, transcribe_params, "tokenTimestamps", false); + params.translate = readablemap::getBool(env, options, "translate", false); + params.speed_up = readablemap::getBool(env, options, "speedUp", false); + params.token_timestamps = readablemap::getBool(env, options, "tokenTimestamps", false); params.offset_ms = 0; params.no_context = true; params.single_segment = false; - int beam_size = readablemap::getInt(env, transcribe_params, "beamSize", -1); + int beam_size = readablemap::getInt(env, options, "beamSize", -1); if (beam_size > -1) { params.strategy = WHISPER_SAMPLING_BEAM_SEARCH; params.beam_search.beam_size = beam_size; } - int best_of = readablemap::getInt(env, transcribe_params, "bestOf", -1); + int best_of = readablemap::getInt(env, options, "bestOf", -1); if (best_of > -1) params.greedy.best_of = best_of; - int max_len = readablemap::getInt(env, transcribe_params, "maxLen", -1); + int max_len = readablemap::getInt(env, options, "maxLen", -1); if (max_len > -1) params.max_len = max_len; - int max_context = readablemap::getInt(env, transcribe_params, "maxContext", -1); + int max_context = readablemap::getInt(env, options, "maxContext", -1); if (max_context > -1) params.n_max_text_ctx = max_context; - int offset = readablemap::getInt(env, transcribe_params, "offset", -1); + int offset = readablemap::getInt(env, options, "offset", -1); if (offset > -1) params.offset_ms = offset; - int duration = readablemap::getInt(env, transcribe_params, "duration", -1); + int duration = readablemap::getInt(env, options, "duration", -1); if (duration > -1) params.duration_ms = duration; - int word_thold = readablemap::getInt(env, transcribe_params, "wordThold", -1); + int word_thold = readablemap::getInt(env, options, "wordThold", -1); if (word_thold > -1) params.thold_pt = word_thold; - float temperature = readablemap::getFloat(env, transcribe_params, "temperature", -1); + float temperature = readablemap::getFloat(env, options, "temperature", -1); if (temperature > -1) params.temperature = temperature; - float temperature_inc = readablemap::getFloat(env, transcribe_params, "temperatureInc", -1); + float temperature_inc = readablemap::getFloat(env, options, "temperatureInc", -1); if (temperature_inc > -1) params.temperature_inc = temperature_inc; - jstring prompt = readablemap::getString(env, transcribe_params, "prompt", nullptr); - if (prompt != nullptr) params.initial_prompt = env->GetStringUTFChars(prompt, nullptr); - jstring language = readablemap::getString(env, transcribe_params, "language", nullptr); - if (language != nullptr) params.language = env->GetStringUTFChars(language, nullptr); + jstring prompt = readablemap::getString(env, options, "prompt", nullptr); + if (prompt != nullptr) { + params.initial_prompt = env->GetStringUTFChars(prompt, nullptr); + env->ReleaseStringUTFChars(prompt, params.initial_prompt); + } + jstring language = readablemap::getString(env, options, "language", nullptr); + if (language != nullptr) { + params.language = env->GetStringUTFChars(language, nullptr); + env->ReleaseStringUTFChars(language, params.language); + } + return params; +} + +struct callback_context { + JNIEnv *env; + jobject callback_instance; +}; + +JNIEXPORT jint JNICALL +Java_com_rnwhisper_WhisperContext_fullTranscribe( + JNIEnv *env, + jobject thiz, + jint job_id, + jlong context_ptr, + jfloatArray audio_data, + jint audio_data_len, + jobject options, + jobject callback_instance +) { + UNUSED(thiz); + struct whisper_context *context = reinterpret_cast(context_ptr); + jfloat *audio_data_arr = env->GetFloatArrayElements(audio_data, nullptr); + + LOGI("About to create params"); + + whisper_full_params params = createFullParams(env, options); if (callback_instance != nullptr) { callback_context *cb_ctx = new callback_context; @@ -305,7 +316,7 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe( params.new_segment_callback_user_data = cb_ctx; } - rnwhisper::job job = rnwhisper::job_new(job_id, params); + rnwhisper::job* job = rnwhisper::job_new(job_id, params); LOGI("About to reset timings"); whisper_reset_timings(context); @@ -316,14 +327,14 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe( // whisper_print_timings(context); } env->ReleaseFloatArrayElements(audio_data, audio_data_arr, JNI_ABORT); - if (language != nullptr) env->ReleaseStringUTFChars(language, params.language); - if (prompt != nullptr) env->ReleaseStringUTFChars(prompt, params.initial_prompt); - - if (job.is_aborted()) code = -999; + + if (job->is_aborted()) code = -999; rnwhisper::job_remove(job_id); return code; } +// TODO: full for realtimeTranscribe with job_id (need create job first) + JNIEXPORT void JNICALL Java_com_rnwhisper_WhisperContext_abortTranscribe( JNIEnv *env, diff --git a/cpp/rn-whisper.cpp b/cpp/rn-whisper.cpp index 9e41e4c..27e7759 100644 --- a/cpp/rn-whisper.cpp +++ b/cpp/rn-whisper.cpp @@ -26,7 +26,7 @@ void job_abort_all() { } } -job job_new(int job_id, struct whisper_full_params params) { +job* job_new(int job_id, struct whisper_full_params params) { job ctx; ctx.job_id = job_id; ctx.params = params; @@ -44,7 +44,7 @@ job job_new(int job_id, struct whisper_full_params params) { params.abort_callback_user_data = &ctx; job_map[job_id] = ctx; - return ctx; + return &job_map[job_id]; } void job_remove(int job_id) { diff --git a/cpp/rn-whisper.h b/cpp/rn-whisper.h index 855c397..bfa6c74 100644 --- a/cpp/rn-whisper.h +++ b/cpp/rn-whisper.h @@ -17,7 +17,7 @@ struct job { }; void job_abort_all(); -job job_new(int job_id, struct whisper_full_params params); +job* job_new(int job_id, struct whisper_full_params params); void job_remove(int job_id); job* job_get(int job_id); diff --git a/example/src/App.tsx b/example/src/App.tsx index a56b789..04e7e35 100644 --- a/example/src/App.tsx +++ b/example/src/App.tsx @@ -215,7 +215,8 @@ export default function App() { log('Start transcribing...') const startTime = Date.now() const { stop, promise } = whisperContext.transcribe(sampleFile, { - language: 'en', + language: 'zh', + prompt: 'HELLO WORLD', maxLen: 1, tokenTimestamps: true, onProgress: (cur) => { diff --git a/ios/RNWhisperContext.h b/ios/RNWhisperContext.h index 4d6d4ad..9e458eb 100644 --- a/ios/RNWhisperContext.h +++ b/ios/RNWhisperContext.h @@ -11,10 +11,10 @@ typedef struct { __unsafe_unretained id mSelf; - - int jobId; NSDictionary* options; + struct rnwhisper::job * job; + bool isTranscribing; bool isRealtime; bool isCapturing; diff --git a/ios/RNWhisperContext.mm b/ios/RNWhisperContext.mm index f8d912d..5805079 100644 --- a/ios/RNWhisperContext.mm +++ b/ios/RNWhisperContext.mm @@ -275,7 +275,8 @@ - (void)finishRealtimeTranscribe:(RNWhisperContextRecordState*) state result:(NS audioOutputFile:state->audioOutputPath ]; } - state->transcribeHandler(state->jobId, @"end", result); + state->transcribeHandler(state->job->job_id, @"end", result); + rnwhisper::job_remove(state->job->job_id); } - (void)fullTranscribeSamples:(RNWhisperContextRecordState*) state { @@ -290,8 +291,9 @@ - (void)fullTranscribeSamples:(RNWhisperContextRecordState*) state { audioBufferF32[i] = (float)audioBufferI16[i] / 32768.0f; } CFTimeInterval timeStart = CACurrentMediaTime(); - struct whisper_full_params params = [state->mSelf getParams:state->options jobId:state->jobId]; - int code = [state->mSelf fullTranscribe:state->jobId params:params audioData:audioBufferF32 audioDataCount:state->nSamplesTranscribing]; + + int code = [state->mSelf fullTranscribe:state->job audioData:audioBufferF32 audioDataCount:state->nSamplesTranscribing]; + free(audioBufferF32); CFTimeInterval timeEnd = CACurrentMediaTime(); const float timeRecording = (float) state->nSamplesTranscribing / (float) state->dataFormat.mSampleRate; @@ -340,10 +342,10 @@ - (void)fullTranscribeSamples:(RNWhisperContextRecordState*) state { [state->mSelf finishRealtimeTranscribe:state result:result]; } else if (code == 0) { result[@"isCapturing"] = @(true); - state->transcribeHandler(state->jobId, @"transcribe", result); + state->transcribeHandler(state->job->job_id, @"transcribe", result); } else { result[@"isCapturing"] = @(true); - state->transcribeHandler(state->jobId, @"transcribe", result); + state->transcribeHandler(state->job->job_id, @"transcribe", result); } if (continueNeeded) { @@ -371,8 +373,8 @@ - (OSStatus)transcribeRealtime:(int)jobId onTranscribe:(void (^)(int, NSString *, NSDictionary *))onTranscribe { self->recordState.transcribeHandler = onTranscribe; - self->recordState.jobId = jobId; [self prepareRealtime:options]; + self->recordState.job = rnwhisper::job_new(jobId, [self createParams:options jobId:jobId]); OSStatus status = AudioQueueNewInput( &self->recordState.dataFormat, @@ -413,9 +415,9 @@ - (void)transcribeFile:(int)jobId dispatch_async(dQueue, ^{ self->recordState.isStoppedByAction = false; self->recordState.isTranscribing = true; - self->recordState.jobId = jobId; - whisper_full_params params = [self getParams:options jobId:jobId]; + whisper_full_params params = [self createParams:options jobId:jobId]; + if (options[@"onProgress"] && [options[@"onProgress"] boolValue]) { params.progress_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, int progress, void * user_data) { void (^onProgress)(int) = (__bridge void (^)(int))user_data; @@ -460,8 +462,10 @@ - (void)transcribeFile:(int)jobId }; params.new_segment_callback_user_data = &user_data; } - int code = [self fullTranscribe:jobId params:params audioData:audioData audioDataCount:audioDataCount]; - self->recordState.jobId = -1; + + rnwhisper::job* job = rnwhisper::job_new(jobId, params);; + int code = [self fullTranscribe:job audioData:audioData audioDataCount:audioDataCount]; + rnwhisper::job_remove(jobId); self->recordState.isTranscribing = false; onEnd(code); }); @@ -476,8 +480,7 @@ - (void)stopAudio { } - (void)stopTranscribe:(int)jobId { - rnwhisper::job *job = rnwhisper::job_get(jobId); - if (job) job->abort(); + if (self->recordState.job) self->recordState.job->abort(); if (self->recordState.isRealtime && self->recordState.isCapturing) { [self stopAudio]; if (!self->recordState.isTranscribing) { @@ -491,13 +494,11 @@ - (void)stopTranscribe:(int)jobId { } - (void)stopCurrentTranscribe { - if (!self->recordState.jobId) { - return; - } - [self stopTranscribe:self->recordState.jobId]; + if (self->recordState.job == nullptr) return; + [self stopTranscribe:self->recordState.job->job_id]; } -- (struct whisper_full_params)getParams:(NSDictionary *)options jobId:(int)jobId { +- (struct whisper_full_params)createParams:(NSDictionary *)options jobId:(int)jobId { struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); const int n_threads = options[@"maxThreads"] != nil ? @@ -554,22 +555,16 @@ - (struct whisper_full_params)getParams:(NSDictionary *)options jobId:(int)jobId params.initial_prompt = [options[@"prompt"] UTF8String]; } - rnwhisper::job_new(jobId, params); return params; } -- (int)fullTranscribe:(int)jobId - params:(struct whisper_full_params)params +- (int)fullTranscribe:(rnwhisper::job *)job audioData:(float *)audioData audioDataCount:(int)audioDataCount { whisper_reset_timings(self->ctx); - - int code = whisper_full(self->ctx, params, audioData, audioDataCount); - - rnwhisper::job* job = rnwhisper::job_get(jobId); + int code = whisper_full(self->ctx, job->params, audioData, audioDataCount); if (job && job->is_aborted()) code = -999; - rnwhisper::job_remove(jobId); // if (code == 0) { // whisper_print_timings(self->ctx); // } From 50f87139fc62653763487b55bad9ed2f4bdad3fc Mon Sep 17 00:00:00 2001 From: Jhen Date: Thu, 7 Dec 2023 18:16:04 +0800 Subject: [PATCH 08/19] feat(ios): move vad params --- cpp/rn-whisper.cpp | 126 +++++++++++++++++++++++----------------- cpp/rn-whisper.h | 18 ++++-- ios/RNWhisperContext.h | 5 -- ios/RNWhisperContext.mm | 35 ++++------- 4 files changed, 98 insertions(+), 86 deletions(-) diff --git a/cpp/rn-whisper.cpp b/cpp/rn-whisper.cpp index 27e7759..d60c65d 100644 --- a/cpp/rn-whisper.cpp +++ b/cpp/rn-whisper.cpp @@ -6,58 +6,6 @@ namespace rnwhisper { -job::~job() { - fprintf(stderr, "%s: job_id: %d\n", __func__, job_id); -} - -bool job::is_aborted() { - return aborted; -} - -void job::abort() { - aborted = true; -} - -std::unordered_map job_map; - -void job_abort_all() { - for (auto it = job_map.begin(); it != job_map.end(); ++it) { - it->second.abort(); - } -} - -job* job_new(int job_id, struct whisper_full_params params) { - job ctx; - ctx.job_id = job_id; - ctx.params = params; - - // Abort handler - params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { - job *j = (job*)user_data; - return !j->is_aborted(); - }; - params.encoder_begin_callback_user_data = &ctx; - params.abort_callback = [](void * user_data) { - job *j = (job*)user_data; - return j->is_aborted(); - }; - params.abort_callback_user_data = &ctx; - - job_map[job_id] = ctx; - return &job_map[job_id]; -} - -void job_remove(int job_id) { - job_map.erase(job_id); -} - -job* job_get(int job_id) { - if (job_map.find(job_id) != job_map.end()) { - return &job_map[job_id]; - } - return nullptr; -} - void high_pass_filter(std::vector & data, float cutoff, float sample_rate) { const float rc = 1.0f / (2.0f * M_PI * cutoff); const float dt = 1.0f / sample_rate; @@ -71,7 +19,7 @@ void high_pass_filter(std::vector & data, float cutoff, float sample_rate } } -bool vad_simple(std::vector & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) { +bool vad_simple_impl(std::vector & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) { const int n_samples = pcmf32.size(); const int n_samples_last = (sample_rate * last_ms) / 1000; @@ -109,4 +57,76 @@ bool vad_simple(std::vector & pcmf32, int sample_rate, int last_ms, float return true; } +job::~job() { + fprintf(stderr, "%s: job_id: %d\n", __func__, job_id); +} + +void job::set_vad_params(vad_params params) { + vad = params; + if (vad.vad_ms < 2000) vad.vad_ms = 2000; +} + +bool job::vad_simple(short* pcm, int n_samples, int n) { + if (!vad.use_vad) return true; + + int sample_size = (int) (WHISPER_SAMPLE_RATE * vad.vad_ms / 1000); + if (n_samples + n > sample_size) { + int start = n_samples + n - sample_size; + std::vector pcmf32(sample_size); + for (int i = 0; i < sample_size; i++) { + pcmf32[i] = (float)pcm[i + start] / 32768.0f; + } + return vad_simple_impl(pcmf32, WHISPER_SAMPLE_RATE, vad.last_ms, vad.vad_thold, vad.freq_thold, vad.verbose); + } + return false; +} + +bool job::is_aborted() { + return aborted; +} + +void job::abort() { + aborted = true; +} + +std::unordered_map job_map; + +void job_abort_all() { + for (auto it = job_map.begin(); it != job_map.end(); ++it) { + it->second.abort(); + } +} + +job* job_new(int job_id, struct whisper_full_params params) { + job ctx; + ctx.job_id = job_id; + ctx.params = params; + + // Abort handler + params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { + job *j = (job*)user_data; + return !j->is_aborted(); + }; + params.encoder_begin_callback_user_data = &ctx; + params.abort_callback = [](void * user_data) { + job *j = (job*)user_data; + return j->is_aborted(); + }; + params.abort_callback_user_data = &ctx; + + job_map[job_id] = ctx; + return &job_map[job_id]; +} + +job* job_get(int job_id) { + if (job_map.find(job_id) != job_map.end()) { + return &job_map[job_id]; + } + return nullptr; +} + +void job_remove(int job_id) { + job_map.erase(job_id); +} + } diff --git a/cpp/rn-whisper.h b/cpp/rn-whisper.h index bfa6c74..ef8177b 100644 --- a/cpp/rn-whisper.h +++ b/cpp/rn-whisper.h @@ -7,11 +7,24 @@ namespace rnwhisper { +struct vad_params { + bool use_vad = false; + float vad_thold = 0.1; + float freq_thold = 0.1; + int vad_ms = 2000; + int last_ms = 1000; + bool verbose = false; +}; + struct job { int job_id; - whisper_full_params params; bool aborted = false; + whisper_full_params params; + vad_params vad; // Realtime transcription only + ~job(); + void set_vad_params(vad_params vad); + bool vad_simple(short* pcm, int n_samples, int n); bool is_aborted(); void abort(); }; @@ -21,9 +34,6 @@ job* job_new(int job_id, struct whisper_full_params params); void job_remove(int job_id); job* job_get(int job_id); -void high_pass_filter(std::vector & data, float cutoff, float sample_rate); -bool vad_simple(std::vector & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose); - } // namespace rnwhisper #endif // RNWHISPER_H \ No newline at end of file diff --git a/ios/RNWhisperContext.h b/ios/RNWhisperContext.h index 9e458eb..19b2400 100644 --- a/ios/RNWhisperContext.h +++ b/ios/RNWhisperContext.h @@ -29,11 +29,6 @@ typedef struct { int audioSliceSec; NSString* audioOutputPath; - bool useVad; - int vadMs; - float vadThold; - float vadFreqThold; - AudioQueueRef queue; AudioStreamBasicDescription dataFormat; AudioQueueBufferRef buffers[NUM_BUFFERS]; diff --git a/ios/RNWhisperContext.mm b/ios/RNWhisperContext.mm index 5805079..ea5974e 100644 --- a/ios/RNWhisperContext.mm +++ b/ios/RNWhisperContext.mm @@ -117,13 +117,6 @@ - (void)prepareRealtime:(NSDictionary *)options { self->recordState.audioOutputPath = options[@"audioOutputPath"]; - self->recordState.useVad = options[@"useVad"] != nil ? [options[@"useVad"] boolValue] : false; - self->recordState.vadMs = options[@"vadMs"] != nil ? [options[@"vadMs"] intValue] : 2000; - if (self->recordState.vadMs < 2000) self->recordState.vadMs = 2000; - - self->recordState.vadThold = options[@"vadThold"] != nil ? [options[@"vadThold"] floatValue] : 0.6f; - self->recordState.vadFreqThold = options[@"vadFreqThold"] != nil ? [options[@"vadFreqThold"] floatValue] : 100.0f; - self->recordState.audioSliceSec = audioSliceSec; self->recordState.isUseSlices = audioSliceSec < maxAudioSec; @@ -158,24 +151,10 @@ - (void)freeBufferIfNeeded { } } -bool vad(RNWhisperContextRecordState *state, int16_t* audioBufferI16, int nSamples, int n) +bool vad(RNWhisperContextRecordState *state, short* pcm, int nSamples, int n) { - bool isSpeech = true; - if (!state->isTranscribing && state->useVad) { - int sampleSize = (int) (WHISPER_SAMPLE_RATE * state->vadMs / 1000); - if (nSamples + n > sampleSize) { - int start = nSamples + n - sampleSize; - std::vector audioBufferF32Vec(sampleSize); - for (int i = 0; i < sampleSize; i++) { - audioBufferF32Vec[i] = (float)audioBufferI16[i + start] / 32768.0f; - } - isSpeech = rnwhisper::vad_simple(audioBufferF32Vec, WHISPER_SAMPLE_RATE, 1000, state->vadThold, state->vadFreqThold, false); - NSLog(@"[RNWhisper] VAD result: %d", isSpeech); - } else { - isSpeech = false; - } - } - return isSpeech; + if (state->isTranscribing) return true; + return state->job->vad_simple(pcm, nSamples, n); } void AudioInputCallback(void * inUserData, @@ -376,6 +355,14 @@ - (OSStatus)transcribeRealtime:(int)jobId [self prepareRealtime:options]; self->recordState.job = rnwhisper::job_new(jobId, [self createParams:options jobId:jobId]); + rnwhisper::vad_params vad = { + .use_vad = options[@"useVad"] != nil ? [options[@"useVad"] boolValue] : false, + .vad_ms = options[@"vadMs"] != nil ? [options[@"vadMs"] intValue] : 2000, + .vad_thold = options[@"vadThold"] != nil ? [options[@"vadThold"] floatValue] : 0.6f, + .freq_thold = options[@"vadFreqThold"] != nil ? [options[@"vadFreqThold"] floatValue] : 100.0f + }; + self->recordState.job->set_vad_params(vad); + OSStatus status = AudioQueueNewInput( &self->recordState.dataFormat, AudioInputCallback, From 2350ede563cf2ad199600132ae0bb78fb8323bee Mon Sep 17 00:00:00 2001 From: Jhen Date: Fri, 8 Dec 2023 12:39:19 +0800 Subject: [PATCH 09/19] feat(android): create createRealtimeTranscribeJob and update vadSimple jni methods --- .../java/com/rnwhisper/WhisperContext.java | 66 ++++++------- android/src/main/jni-utils.h | 10 +- android/src/main/jni.cpp | 95 ++++++++++++++----- cpp/rn-whisper.h | 4 +- ios/RNWhisperContext.mm | 5 +- 5 files changed, 110 insertions(+), 70 deletions(-) diff --git a/android/src/main/java/com/rnwhisper/WhisperContext.java b/android/src/main/java/com/rnwhisper/WhisperContext.java index bf1c7ee..06a60f9 100644 --- a/android/src/main/java/com/rnwhisper/WhisperContext.java +++ b/android/src/main/java/com/rnwhisper/WhisperContext.java @@ -80,25 +80,8 @@ private void rewind() { } private boolean vad(ReadableMap options, short[] shortBuffer, int nSamples, int n) { - boolean isSpeech = true; - if (!isTranscribing && options.hasKey("useVad") && options.getBoolean("useVad")) { - int vadMs = options.hasKey("vadMs") ? options.getInt("vadMs") : 2000; - if (vadMs < 2000) vadMs = 2000; - int sampleSize = (int) (SAMPLE_RATE * vadMs / 1000); - if (nSamples + n > sampleSize) { - int start = nSamples + n - sampleSize; - float[] audioData = new float[sampleSize]; - for (int i = 0; i < sampleSize; i++) { - audioData[i] = shortBuffer[i + start] / 32768.0f; - } - float vadThold = options.hasKey("vadThold") ? (float) options.getDouble("vadThold") : 0.6f; - float vadFreqThold = options.hasKey("vadFreqThold") ? (float) options.getDouble("vadFreqThold") : 0.6f; - isSpeech = vadSimple(audioData, sampleSize, vadThold, vadFreqThold); - } else { - isSpeech = false; - } - } - return isSpeech; + if (isTranscribing) return true; + return vadSimple(jobId, shortBuffer, nSamples, n); } private void finishRealtimeTranscribe(ReadableMap options, WritableMap result) { @@ -112,8 +95,8 @@ private void finishRealtimeTranscribe(ReadableMap options, WritableMap result) { Log.e(NAME, "Error saving wav file: " + e.getMessage()); } } - emitTranscribeEvent("@RNWhisper_onRealtimeTranscribeEnd", Arguments.createMap()); + removeRealtimeTranscribeJob(jobId, context); } public int startRealtimeTranscribe(int jobId, ReadableMap options) { @@ -132,6 +115,7 @@ public int startRealtimeTranscribe(int jobId, ReadableMap options) { rewind(); this.jobId = jobId; + createRealtimeTranscribeJob(jobId, context, options); int realtimeAudioSec = options.hasKey("realtimeAudioSec") ? options.getInt("realtimeAudioSec") : 0; final int audioSec = realtimeAudioSec > 0 ? realtimeAudioSec : DEFAULT_MAX_AUDIO_SEC; @@ -178,8 +162,7 @@ public void run() { finishRealtimeTranscribe(options, Arguments.createMap()); } else if (!isTranscribing) { short[] shortBuffer = shortBufferSlices.get(sliceIndex); - boolean isSpeech = vad(options, shortBuffer, nSamples, 0); - if (!isSpeech) { + if (!vad(options, shortBuffer, nSamples, 0)) { finishRealtimeTranscribe(options, Arguments.createMap()); break; } @@ -265,7 +248,7 @@ private void fullTranscribeSamples(ReadableMap options, boolean skipCapturingChe Log.d(NAME, "Start transcribing realtime: " + nSamplesTranscribing); int timeStart = (int) System.currentTimeMillis(); - int code = full(jobId, options, nSamplesBuffer32, nSamplesTranscribing); + int code = fullWithJob(jobId, context, nSamplesBuffer32, nSamplesTranscribing); int timeEnd = (int) System.currentTimeMillis(); int timeRecording = (int) (nSamplesTranscribing / SAMPLE_RATE * 1000); @@ -383,32 +366,30 @@ public WritableMap transcribeInputStream(int jobId, InputStream inputStream, Rea this.jobId = jobId; isTranscribing = true; float[] audioData = AudioUtils.decodeWaveFile(inputStream); - int code = full(jobId, options, audioData, audioData.length); - isTranscribing = false; - this.jobId = -1; - if (code != 0 && code != 999) { - throw new Exception("Failed to transcribe the file. Code: " + code); - } - WritableMap result = getTextSegments(0, getTextSegmentCount(context)); - result.putBoolean("isAborted", isStoppedByAction); - return result; - } - private int full(int jobId, ReadableMap options, float[] audioData, int audioDataLen) { boolean hasProgressCallback = options.hasKey("onProgress") && options.getBoolean("onProgress"); boolean hasNewSegmentsCallback = options.hasKey("onNewSegments") && options.getBoolean("onNewSegments"); - return fullTranscribe( + int code = fullWithNewJob( jobId, context, // float[] audio_data, audioData, // jint audio_data_len, - audioDataLen, + audioData.length, // ReadableMap options, options, // Callback callback hasProgressCallback || hasNewSegmentsCallback ? new Callback(this, hasProgressCallback, hasNewSegmentsCallback) : null ); + + isTranscribing = false; + this.jobId = -1; + if (code != 0 && code != 999) { + throw new Exception("Failed to transcribe the file. Code: " + code); + } + WritableMap result = getTextSegments(0, getTextSegmentCount(context)); + result.putBoolean("isAborted", isStoppedByAction); + return result; } private WritableMap getTextSegments(int start, int count) { @@ -531,8 +512,8 @@ private static String cpuInfo() { protected static native long initContext(String modelPath); protected static native long initContextWithAsset(AssetManager assetManager, String modelPath); protected static native long initContextWithInputStream(PushbackInputStream inputStream); - protected static native boolean vadSimple(float[] audio_data, int audio_data_len, float vad_thold, float vad_freq_thold); - protected static native int fullTranscribe( + + protected static native int fullWithNewJob( int job_id, long context, float[] audio_data, @@ -540,6 +521,15 @@ protected static native int fullTranscribe( ReadableMap options, Callback Callback ); + protected static native void createRealtimeTranscribeJob(int job_id, long context, ReadableMap options); + protected static native void removeRealtimeTranscribeJob(int job_id, long context); + protected static native boolean vadSimple(int job_id, short[] pcm, int n_samples, int n); + protected static native int fullWithJob( + int job_id, + long context, + float[] audio_data, + int audio_data_len + ); protected static native void abortTranscribe(int jobId); protected static native void abortAllTranscribe(); protected static native int getTextSegmentCount(long context); diff --git a/android/src/main/jni-utils.h b/android/src/main/jni-utils.h index 419ce34..f4cf1a9 100644 --- a/android/src/main/jni-utils.h +++ b/android/src/main/jni-utils.h @@ -4,7 +4,7 @@ namespace readablemap { -jboolean hasKey(JNIEnv *env, jobject readableMap, const char *key) { +bool hasKey(JNIEnv *env, jobject readableMap, const char *key) { jclass mapClass = env->GetObjectClass(readableMap); jmethodID hasKeyMethod = env->GetMethodID(mapClass, "hasKey", "(Ljava/lang/String;)Z"); jstring jKey = env->NewStringUTF(key); @@ -13,7 +13,7 @@ jboolean hasKey(JNIEnv *env, jobject readableMap, const char *key) { return result; } -jint getInt(JNIEnv *env, jobject readableMap, const char *key, jint defaultValue) { +int getInt(JNIEnv *env, jobject readableMap, const char *key, jint defaultValue) { if (!hasKey(env, readableMap, key)) { return defaultValue; } @@ -25,7 +25,7 @@ jint getInt(JNIEnv *env, jobject readableMap, const char *key, jint defaultValue return result; } -jboolean getBool(JNIEnv *env, jobject readableMap, const char *key, jboolean defaultValue) { +bool getBool(JNIEnv *env, jobject readableMap, const char *key, jboolean defaultValue) { if (!hasKey(env, readableMap, key)) { return defaultValue; } @@ -37,7 +37,7 @@ jboolean getBool(JNIEnv *env, jobject readableMap, const char *key, jboolean def return result; } -jlong getLong(JNIEnv *env, jobject readableMap, const char *key, jlong defaultValue) { +long getLong(JNIEnv *env, jobject readableMap, const char *key, jlong defaultValue) { if (!hasKey(env, readableMap, key)) { return defaultValue; } @@ -49,7 +49,7 @@ jlong getLong(JNIEnv *env, jobject readableMap, const char *key, jlong defaultVa return result; } -jfloat getFloat(JNIEnv *env, jobject readableMap, const char *key, jfloat defaultValue) { +float getFloat(JNIEnv *env, jobject readableMap, const char *key, jfloat defaultValue) { if (!hasKey(env, readableMap, key)) { return defaultValue; } diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp index 27115d8..5b7ddb6 100644 --- a/android/src/main/jni.cpp +++ b/android/src/main/jni.cpp @@ -191,26 +191,6 @@ Java_com_rnwhisper_WhisperContext_initContextWithInputStream( return reinterpret_cast(context); } -JNIEXPORT jboolean JNICALL -Java_com_rnwhisper_WhisperContext_vadSimple( - JNIEnv *env, - jobject thiz, - jfloatArray audio_data, - jint audio_data_len, - jfloat vad_thold, - jfloat vad_freq_thold -) { - UNUSED(thiz); - - std::vector samples(audio_data_len); - jfloat *audio_data_arr = env->GetFloatArrayElements(audio_data, nullptr); - for (int i = 0; i < audio_data_len; i++) { - samples[i] = audio_data_arr[i]; - } - bool is_speech = rnwhisper::vad_simple(samples, WHISPER_SAMPLE_RATE, 1000, vad_thold, vad_freq_thold, false); - env->ReleaseFloatArrayElements(audio_data, audio_data_arr, JNI_ABORT); - return is_speech; -} struct whisper_full_params createFullParams(JNIEnv *env, jobject options) { struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); @@ -272,7 +252,7 @@ struct callback_context { }; JNIEXPORT jint JNICALL -Java_com_rnwhisper_WhisperContext_fullTranscribe( +Java_com_rnwhisper_WhisperContext_fullWithNewJob( JNIEnv *env, jobject thiz, jint job_id, @@ -333,7 +313,78 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe( return code; } -// TODO: full for realtimeTranscribe with job_id (need create job first) +JNIEXPORT void JNICALL +Java_com_rnwhisper_WhisperContext_createRealtimeTranscribeJob( + JNIEnv *env, + jobject thiz, + jint job_id, + jlong context_ptr, + jobject options +) { + whisper_full_params params = createFullParams(env, options); + rnwhisper::job* job = rnwhisper::job_new(job_id, params); + rnwhisper::vad_params vad; + vad.use_vad = readablemap::getBool(env, options, "useVad", false); + vad.vad_ms = readablemap::getInt(env, options, "vadMs", 2000); + vad.vad_thold = readablemap::getFloat(env, options, "vadThold", 0.6f); + vad.freq_thold = readablemap::getFloat(env, options, "vadFreqThold", 100.0f); + job->set_vad_params(vad); +} + +JNIEXPORT void JNICALL +Java_com_rnwhisper_WhisperContext_removeRealtimeTranscribeJob( + JNIEnv *env, + jobject thiz, + jint job_id, + jlong context_ptr +) { + UNUSED(env); + UNUSED(thiz); + UNUSED(context_ptr); + rnwhisper::job_remove(job_id); +} + +JNIEXPORT jboolean JNICALL +Java_com_rnwhisper_WhisperContext_vadSimple( + JNIEnv *env, + jobject thiz, + jint job_id, + jshortArray pcm, + jint n_samples, + jint n +) { + UNUSED(thiz); + + jshort *pcm_arr = env->GetShortArrayElements(pcm, nullptr); + rnwhisper::job* job = rnwhisper::job_get(job_id); + bool is_speech = job->vad_simple(pcm_arr, n_samples, n); + env->ReleaseShortArrayElements(pcm, pcm_arr, JNI_ABORT); + return is_speech; +} + +JNIEXPORT jint JNICALL +Java_com_rnwhisper_WhisperContext_fullWithJob( + JNIEnv *env, + jobject thiz, + jint job_id, + jlong context_ptr, + jfloatArray audio_data, // TODO: move audio slice to C++ + jint audio_data_len +) { + UNUSED(thiz); + struct whisper_context *context = reinterpret_cast(context_ptr); + jfloat *audio_data_arr = env->GetFloatArrayElements(audio_data, nullptr); + + rnwhisper::job* job = rnwhisper::job_get(job_id); + + int code = whisper_full(context, job->params, audio_data_arr, audio_data_len); + if (code == 0) { + // whisper_print_timings(context); + } + env->ReleaseFloatArrayElements(audio_data, audio_data_arr, JNI_ABORT); + if (job->is_aborted()) code = -999; + return code; +} JNIEXPORT void JNICALL Java_com_rnwhisper_WhisperContext_abortTranscribe( diff --git a/cpp/rn-whisper.h b/cpp/rn-whisper.h index ef8177b..2bfc591 100644 --- a/cpp/rn-whisper.h +++ b/cpp/rn-whisper.h @@ -9,8 +9,8 @@ namespace rnwhisper { struct vad_params { bool use_vad = false; - float vad_thold = 0.1; - float freq_thold = 0.1; + float vad_thold = 0.6f; + float freq_thold = 100.0f; int vad_ms = 2000; int last_ms = 1000; bool verbose = false; diff --git a/ios/RNWhisperContext.mm b/ios/RNWhisperContext.mm index ea5974e..05ab7f9 100644 --- a/ios/RNWhisperContext.mm +++ b/ios/RNWhisperContext.mm @@ -355,13 +355,12 @@ - (OSStatus)transcribeRealtime:(int)jobId [self prepareRealtime:options]; self->recordState.job = rnwhisper::job_new(jobId, [self createParams:options jobId:jobId]); - rnwhisper::vad_params vad = { + self->recordState.job->set_vad_params({ .use_vad = options[@"useVad"] != nil ? [options[@"useVad"] boolValue] : false, .vad_ms = options[@"vadMs"] != nil ? [options[@"vadMs"] intValue] : 2000, .vad_thold = options[@"vadThold"] != nil ? [options[@"vadThold"] floatValue] : 0.6f, .freq_thold = options[@"vadFreqThold"] != nil ? [options[@"vadFreqThold"] floatValue] : 100.0f - }; - self->recordState.job->set_vad_params(vad); + }); OSStatus status = AudioQueueNewInput( &self->recordState.dataFormat, From 54fea101197fb69ecf364af44f4458b5b4737d6d Mon Sep 17 00:00:00 2001 From: Jhen Date: Fri, 8 Dec 2023 14:29:37 +0800 Subject: [PATCH 10/19] feat(cpp): move audio slices --- .../java/com/rnwhisper/WhisperContext.java | 59 ++++------ android/src/main/jni.cpp | 40 +++++-- cpp/rn-whisper.cpp | 38 ++++++- cpp/rn-whisper.h | 16 ++- ios/RNWhisperContext.h | 3 - ios/RNWhisperContext.mm | 104 ++++++------------ 6 files changed, 135 insertions(+), 125 deletions(-) diff --git a/android/src/main/java/com/rnwhisper/WhisperContext.java b/android/src/main/java/com/rnwhisper/WhisperContext.java index 06a60f9..e7b56b4 100644 --- a/android/src/main/java/com/rnwhisper/WhisperContext.java +++ b/android/src/main/java/com/rnwhisper/WhisperContext.java @@ -42,7 +42,6 @@ public class WhisperContext { private AudioRecord recorder = null; private int bufferSize; private int nSamplesTranscribing = 0; - private ArrayList shortBufferSlices; // Remember number of samples in each slice private ArrayList sliceNSamples; // Current buffer slice index @@ -66,7 +65,6 @@ public WhisperContext(int id, ReactApplicationContext reactContext, long context } private void rewind() { - shortBufferSlices = null; sliceNSamples = null; sliceIndex = 0; transcribeSliceIndex = 0; @@ -79,9 +77,9 @@ private void rewind() { fullHandler = null; } - private boolean vad(ReadableMap options, short[] shortBuffer, int nSamples, int n) { + private boolean vad(ReadableMap options, int sliceIndex, int nSamples, int n) { if (isTranscribing) return true; - return vadSimple(jobId, shortBuffer, nSamples, n); + return vadSimple(jobId, sliceIndex, nSamples, n); } private void finishRealtimeTranscribe(ReadableMap options, WritableMap result) { @@ -89,11 +87,12 @@ private void finishRealtimeTranscribe(ReadableMap options, WritableMap result) { if (audioOutputPath != null) { // TODO: Append in real time so we don't need to keep all slices & also reduce memory usage Log.d(NAME, "Begin saving wav file to " + audioOutputPath); - try { - AudioUtils.saveWavFile(AudioUtils.concatShortBuffers(shortBufferSlices), audioOutputPath); - } catch (IOException e) { - Log.e(NAME, "Error saving wav file: " + e.getMessage()); - } + // try { + // // TODO: cpp audio utils + // AudioUtils.saveWavFile(AudioUtils.concatShortBuffers(shortBufferSlices), audioOutputPath); + // } catch (IOException e) { + // Log.e(NAME, "Error saving wav file: " + e.getMessage()); + // } } emitTranscribeEvent("@RNWhisper_onRealtimeTranscribeEnd", Arguments.createMap()); removeRealtimeTranscribeJob(jobId, context); @@ -115,20 +114,17 @@ public int startRealtimeTranscribe(int jobId, ReadableMap options) { rewind(); this.jobId = jobId; - createRealtimeTranscribeJob(jobId, context, options); int realtimeAudioSec = options.hasKey("realtimeAudioSec") ? options.getInt("realtimeAudioSec") : 0; final int audioSec = realtimeAudioSec > 0 ? realtimeAudioSec : DEFAULT_MAX_AUDIO_SEC; - int realtimeAudioSliceSec = options.hasKey("realtimeAudioSliceSec") ? options.getInt("realtimeAudioSliceSec") : 0; final int audioSliceSec = realtimeAudioSliceSec > 0 && realtimeAudioSliceSec < audioSec ? realtimeAudioSliceSec : audioSec; - isUseSlices = audioSliceSec < audioSec; + createRealtimeTranscribeJob(jobId, context, options); + String audioOutputPath = options.hasKey("audioOutputPath") ? options.getString("audioOutputPath") : null; - shortBufferSlices = new ArrayList(); - shortBufferSlices.add(new short[audioSliceSec * SAMPLE_RATE]); sliceNSamples = new ArrayList(); sliceNSamples.add(0); @@ -161,8 +157,7 @@ public void run() { ) { finishRealtimeTranscribe(options, Arguments.createMap()); } else if (!isTranscribing) { - short[] shortBuffer = shortBufferSlices.get(sliceIndex); - if (!vad(options, shortBuffer, nSamples, 0)) { + if (!vad(options, sliceIndex, nSamples, 0)) { finishRealtimeTranscribe(options, Arguments.createMap()); break; } @@ -173,22 +168,16 @@ public void run() { } // Append to buffer - short[] shortBuffer = shortBufferSlices.get(sliceIndex); if (nSamples + n > audioSliceSec * SAMPLE_RATE) { Log.d(NAME, "next slice"); sliceIndex++; nSamples = 0; - shortBuffer = new short[audioSliceSec * SAMPLE_RATE]; - shortBufferSlices.add(shortBuffer); sliceNSamples.add(0); } + putPcmData(buffer, sliceIndex, nSamples, n); - for (int i = 0; i < n; i++) { - shortBuffer[nSamples + i] = buffer[i]; - } - - boolean isSpeech = vad(options, shortBuffer, nSamples, n); + boolean isSpeech = vad(options, sliceIndex, nSamples, n); nSamples += n; sliceNSamples.set(sliceIndex, nSamples); @@ -234,21 +223,14 @@ private void fullTranscribeSamples(ReadableMap options, boolean skipCapturingChe if (!isCapturing && !skipCapturingCheck) return; - short[] shortBuffer = shortBufferSlices.get(transcribeSliceIndex); int nSamples = sliceNSamples.get(transcribeSliceIndex); nSamplesTranscribing = nSamplesOfIndex; - // convert I16 to F32 - float[] nSamplesBuffer32 = new float[nSamplesTranscribing]; - for (int i = 0; i < nSamplesTranscribing; i++) { - nSamplesBuffer32[i] = shortBuffer[i] / 32768.0f; - } - Log.d(NAME, "Start transcribing realtime: " + nSamplesTranscribing); int timeStart = (int) System.currentTimeMillis(); - int code = fullWithJob(jobId, context, nSamplesBuffer32, nSamplesTranscribing); + int code = fullWithJob(jobId, context, transcribeSliceIndex, nSamplesTranscribing); int timeEnd = (int) System.currentTimeMillis(); int timeRecording = (int) (nSamplesTranscribing / SAMPLE_RATE * 1000); @@ -521,14 +503,19 @@ protected static native int fullWithNewJob( ReadableMap options, Callback Callback ); - protected static native void createRealtimeTranscribeJob(int job_id, long context, ReadableMap options); + protected static native void createRealtimeTranscribeJob( + int job_id, + long context, + ReadableMap options + ); protected static native void removeRealtimeTranscribeJob(int job_id, long context); - protected static native boolean vadSimple(int job_id, short[] pcm, int n_samples, int n); + protected static native boolean vadSimple(int job_id, int slice_index, int n_samples, int n); + protected static native void putPcmData(short[] buffer, int slice_index, int n_samples, int n); protected static native int fullWithJob( int job_id, long context, - float[] audio_data, - int audio_data_len + int slice_index, + int n_samples ); protected static native void abortTranscribe(int jobId); protected static native void abortAllTranscribe(); diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp index 5b7ddb6..fe337a3 100644 --- a/android/src/main/jni.cpp +++ b/android/src/main/jni.cpp @@ -328,7 +328,11 @@ Java_com_rnwhisper_WhisperContext_createRealtimeTranscribeJob( vad.vad_ms = readablemap::getInt(env, options, "vadMs", 2000); vad.vad_thold = readablemap::getFloat(env, options, "vadThold", 0.6f); vad.freq_thold = readablemap::getFloat(env, options, "vadFreqThold", 100.0f); - job->set_vad_params(vad); + job->set_realtime_params( + vad, + readablemap::getInt(env, options, "realtimeAudioSec", 0), + readablemap::getInt(env, options, "realtimeAudioSliceSec", 0) + ); } JNIEXPORT void JNICALL @@ -341,6 +345,8 @@ Java_com_rnwhisper_WhisperContext_removeRealtimeTranscribeJob( UNUSED(env); UNUSED(thiz); UNUSED(context_ptr); + rnwhisper::job *job = rnwhisper::job_get(job_id); + job->free_pcm_slices(); rnwhisper::job_remove(job_id); } @@ -349,17 +355,30 @@ Java_com_rnwhisper_WhisperContext_vadSimple( JNIEnv *env, jobject thiz, jint job_id, - jshortArray pcm, + jint slice_index, jint n_samples, jint n ) { UNUSED(thiz); + rnwhisper::job* job = rnwhisper::job_get(job_id); + return job->vad_simple(slice_index, n_samples, n); +} - jshort *pcm_arr = env->GetShortArrayElements(pcm, nullptr); +JNIEXPORT void JNICALL +Java_com_rnwhisper_WhisperContext_putPcmData( + JNIEnv *env, + jobject thiz, + jint job_id, + jshortArray pcm, + jint slice_index, + jint n_samples, + jint n +) { + UNUSED(thiz); rnwhisper::job* job = rnwhisper::job_get(job_id); - bool is_speech = job->vad_simple(pcm_arr, n_samples, n); + jshort *pcm_arr = env->GetShortArrayElements(pcm, nullptr); + job->put_pcm_data(pcm_arr, slice_index, n_samples, n); env->ReleaseShortArrayElements(pcm, pcm_arr, JNI_ABORT); - return is_speech; } JNIEXPORT jint JNICALL @@ -368,20 +387,19 @@ Java_com_rnwhisper_WhisperContext_fullWithJob( jobject thiz, jint job_id, jlong context_ptr, - jfloatArray audio_data, // TODO: move audio slice to C++ - jint audio_data_len + jint slice_index, + jint n_samples ) { UNUSED(thiz); struct whisper_context *context = reinterpret_cast(context_ptr); - jfloat *audio_data_arr = env->GetFloatArrayElements(audio_data, nullptr); rnwhisper::job* job = rnwhisper::job_get(job_id); - - int code = whisper_full(context, job->params, audio_data_arr, audio_data_len); + float* pcmf32 = job->pcm_slice_to_f32(slice_index, n_samples); + int code = whisper_full(context, job->params, pcmf32, n_samples); + free(pcmf32); if (code == 0) { // whisper_print_timings(context); } - env->ReleaseFloatArrayElements(audio_data, audio_data_arr, JNI_ABORT); if (job->is_aborted()) code = -999; return code; } diff --git a/cpp/rn-whisper.cpp b/cpp/rn-whisper.cpp index d60c65d..3493c18 100644 --- a/cpp/rn-whisper.cpp +++ b/cpp/rn-whisper.cpp @@ -4,6 +4,8 @@ #include #include "rn-whisper.h" +#define DEFAULT_MAX_AUDIO_SEC 30; + namespace rnwhisper { void high_pass_filter(std::vector & data, float cutoff, float sample_rate) { @@ -61,14 +63,17 @@ job::~job() { fprintf(stderr, "%s: job_id: %d\n", __func__, job_id); } -void job::set_vad_params(vad_params params) { +void job::set_realtime_params(vad_params params, int sec, int slice_sec) { vad = params; if (vad.vad_ms < 2000) vad.vad_ms = 2000; + audio_sec = sec > 0 ? sec : DEFAULT_MAX_AUDIO_SEC; + audio_slice_sec = slice_sec > 0 && slice_sec < audio_sec ? slice_sec : audio_sec; } -bool job::vad_simple(short* pcm, int n_samples, int n) { +bool job::vad_simple(int slice_index, int n_samples, int n) { if (!vad.use_vad) return true; + short* pcm = pcm_slices[slice_index]; int sample_size = (int) (WHISPER_SAMPLE_RATE * vad.vad_ms / 1000); if (n_samples + n > sample_size) { int start = n_samples + n - sample_size; @@ -81,6 +86,35 @@ bool job::vad_simple(short* pcm, int n_samples, int n) { return false; } +void job::put_pcm_data(short* data, int slice_index, int n_samples, int n) { + if (pcm_slices.size() == slice_index) { + int n_slices = (int) (WHISPER_SAMPLE_RATE * audio_slice_sec); + pcm_slices.push_back(new short[n_slices]); + } + short* pcm = pcm_slices[slice_index]; + for (int i = 0; i < n; i++) { + pcm[i + n_samples] = data[i]; + } +} + +float* job::pcm_slice_to_f32(int slice_index, int size) { + if (pcm_slices.size() > slice_index) { + float* pcmf32 = new float[size]; + for (int i = 0; i < size; i++) { + pcmf32[i] = (float)pcm_slices[slice_index][i] / 32768.0f; + } + return pcmf32; + } + return nullptr; +} + +void job::free_pcm_slices() { + for (size_t i = 0; i < pcm_slices.size(); i++) { + delete[] pcm_slices[i]; + } + pcm_slices.clear(); +} + bool job::is_aborted() { return aborted; } diff --git a/cpp/rn-whisper.h b/cpp/rn-whisper.h index 2bfc591..149c7c7 100644 --- a/cpp/rn-whisper.h +++ b/cpp/rn-whisper.h @@ -20,13 +20,21 @@ struct job { int job_id; bool aborted = false; whisper_full_params params; - vad_params vad; // Realtime transcription only - + ~job(); - void set_vad_params(vad_params vad); - bool vad_simple(short* pcm, int n_samples, int n); bool is_aborted(); void abort(); + + // Realtime transcription only: + vad_params vad; + int audio_sec = 0; + int audio_slice_sec = 0; + std::vector pcm_slices; + void set_realtime_params(vad_params vad, int audio_sec, int audio_slice_sec); + bool vad_simple(int slice_index, int n_samples, int n); + void put_pcm_data(short* pcm, int slice_index, int n_samples, int n); + float* pcm_slice_to_f32(int slice_index, int size); + void free_pcm_slices(); }; void job_abort_all(); diff --git a/ios/RNWhisperContext.h b/ios/RNWhisperContext.h index 19b2400..e16d188 100644 --- a/ios/RNWhisperContext.h +++ b/ios/RNWhisperContext.h @@ -19,14 +19,11 @@ typedef struct { bool isRealtime; bool isCapturing; bool isStoppedByAction; - int maxAudioSec; int nSamplesTranscribing; - NSMutableArray *shortBufferSlices; NSMutableArray *sliceNSamples; bool isUseSlices; int sliceIndex; int transcribeSliceIndex; - int audioSliceSec; NSString* audioOutputPath; AudioQueueRef queue; diff --git a/ios/RNWhisperContext.mm b/ios/RNWhisperContext.mm index 05ab7f9..ba095fc 100644 --- a/ios/RNWhisperContext.mm +++ b/ios/RNWhisperContext.mm @@ -95,7 +95,7 @@ - (dispatch_queue_t)getDispatchQueue { return self->dQueue; } -- (void)prepareRealtime:(NSDictionary *)options { +- (void)prepareRealtime:(int)jobId options:(NSDictionary *)options { self->recordState.options = options; self->recordState.dataFormat.mSampleRate = WHISPER_SAMPLE_RATE; // 16000 @@ -108,53 +108,40 @@ - (void)prepareRealtime:(NSDictionary *)options { self->recordState.dataFormat.mReserved = 0; self->recordState.dataFormat.mFormatFlags = kLinearPCMFormatFlagIsSignedInteger; - int maxAudioSecOpt = options[@"realtimeAudioSec"] != nil ? [options[@"realtimeAudioSec"] intValue] : 0; - int maxAudioSec = maxAudioSecOpt > 0 ? maxAudioSecOpt : DEFAULT_MAX_AUDIO_SEC; - self->recordState.maxAudioSec = maxAudioSec; - - int realtimeAudioSliceSec = options[@"realtimeAudioSliceSec"] != nil ? [options[@"realtimeAudioSliceSec"] intValue] : 0; - int audioSliceSec = realtimeAudioSliceSec > 0 && realtimeAudioSliceSec < maxAudioSec ? realtimeAudioSliceSec : maxAudioSec; + self->recordState.isRealtime = true; + self->recordState.isTranscribing = false; + self->recordState.isCapturing = false; + self->recordState.isStoppedByAction = false; self->recordState.audioOutputPath = options[@"audioOutputPath"]; - self->recordState.audioSliceSec = audioSliceSec; - self->recordState.isUseSlices = audioSliceSec < maxAudioSec; - self->recordState.sliceIndex = 0; self->recordState.transcribeSliceIndex = 0; self->recordState.nSamplesTranscribing = 0; - [self freeBufferIfNeeded]; - self->recordState.shortBufferSlices = [NSMutableArray new]; - - int16_t *audioBufferI16 = (int16_t *) malloc(audioSliceSec * WHISPER_SAMPLE_RATE * sizeof(int16_t)); - [self->recordState.shortBufferSlices addObject:[NSValue valueWithPointer:audioBufferI16]]; - self->recordState.sliceNSamples = [NSMutableArray new]; [self->recordState.sliceNSamples addObject:[NSNumber numberWithInt:0]]; - self->recordState.isRealtime = true; - self->recordState.isTranscribing = false; - self->recordState.isCapturing = false; - self->recordState.isStoppedByAction = false; + self->recordState.job = rnwhisper::job_new(jobId, [self createParams:options jobId:jobId]); + self->recordState.job->set_realtime_params( + { + .use_vad = options[@"useVad"] != nil ? [options[@"useVad"] boolValue] : false, + .vad_ms = options[@"vadMs"] != nil ? [options[@"vadMs"] intValue] : 2000, + .vad_thold = options[@"vadThold"] != nil ? [options[@"vadThold"] floatValue] : 0.6f, + .freq_thold = options[@"vadFreqThold"] != nil ? [options[@"vadFreqThold"] floatValue] : 100.0f + }, + options[@"realtimeAudioSec"] != nil ? [options[@"realtimeAudioSec"] intValue] : 0, + options[@"realtimeAudioSliceSec"] != nil ? [options[@"realtimeAudioSliceSec"] intValue] : 0 + ); + self->recordState.isUseSlices = self->recordState.job->audio_slice_sec < self->recordState.job->audio_sec; self->recordState.mSelf = self; } -- (void)freeBufferIfNeeded { - if (self->recordState.shortBufferSlices != nil) { - for (int i = 0; i < [self->recordState.shortBufferSlices count]; i++) { - int16_t *audioBufferI16 = (int16_t *) [self->recordState.shortBufferSlices[i] pointerValue]; - free(audioBufferI16); - } - self->recordState.shortBufferSlices = nil; - } -} - -bool vad(RNWhisperContextRecordState *state, short* pcm, int nSamples, int n) +bool vad(RNWhisperContextRecordState *state, int sliceIndex, int nSamples, int n) { if (state->isTranscribing) return true; - return state->job->vad_simple(pcm, nSamples, n); + return state->job->vad_simple(sliceIndex, nSamples, n); } void AudioInputCallback(void * inUserData, @@ -183,7 +170,7 @@ void AudioInputCallback(void * inUserData, int nSamples = [state->sliceNSamples[state->sliceIndex] intValue]; - if (totalNSamples + n > state->maxAudioSec * WHISPER_SAMPLE_RATE) { + if (totalNSamples + n > state->job->audio_sec * WHISPER_SAMPLE_RATE) { NSLog(@"[RNWhisper] Audio buffer is full, stop capturing"); state->isCapturing = false; [state->mSelf stopAudio]; @@ -197,8 +184,7 @@ void AudioInputCallback(void * inUserData, !state->isTranscribing && nSamples != state->nSamplesTranscribing ) { - int16_t* audioBufferI16 = (int16_t*) [state->shortBufferSlices[state->sliceIndex] pointerValue]; - if (!vad(state, audioBufferI16, nSamples, 0)) { + if (!vad(state, state->sliceIndex, nSamples, 0)) { [state->mSelf finishRealtimeTranscribe:state result:@{}]; return; } @@ -210,25 +196,19 @@ void AudioInputCallback(void * inUserData, return; } - int audioSliceSec = state->audioSliceSec; - if (nSamples + n > audioSliceSec * WHISPER_SAMPLE_RATE) { + if (nSamples + n > state->job->audio_slice_sec * WHISPER_SAMPLE_RATE) { // next slice state->sliceIndex++; nSamples = 0; - int16_t* audioBufferI16 = (int16_t*) malloc(audioSliceSec * WHISPER_SAMPLE_RATE * sizeof(int16_t)); - [state->shortBufferSlices addObject:[NSValue valueWithPointer:audioBufferI16]]; [state->sliceNSamples addObject:[NSNumber numberWithInt:0]]; } // Append to buffer NSLog(@"[RNWhisper] Slice %d has %d samples", state->sliceIndex, nSamples); - int16_t* audioBufferI16 = (int16_t*) [state->shortBufferSlices[state->sliceIndex] pointerValue]; - for (int i = 0; i < n; i++) { - audioBufferI16[nSamples + i] = ((short*)inBuffer->mAudioData)[i]; - } + state->job->put_pcm_data((short*) inBuffer->mAudioData, state->sliceIndex, nSamples, n); - bool isSpeech = vad(state, audioBufferI16, nSamples, n); + bool isSpeech = vad(state, state->sliceIndex, nSamples, n); nSamples += n; state->sliceNSamples[state->sliceIndex] = [NSNumber numberWithInt:nSamples]; @@ -248,13 +228,14 @@ - (void)finishRealtimeTranscribe:(RNWhisperContextRecordState*) state result:(NS // Save wav if needed if (state->audioOutputPath != nil) { // TODO: Append in real time so we don't need to keep all slices & also reduce memory usage - [RNWhisperAudioUtils - saveWavFile:[RNWhisperAudioUtils concatShortBuffers:state->shortBufferSlices - sliceNSamples:state->sliceNSamples] - audioOutputFile:state->audioOutputPath - ]; + // [RNWhisperAudioUtils + // saveWavFile:[RNWhisperAudioUtils concatShortBuffers:state->shortBufferSlices + // sliceNSamples:state->sliceNSamples] + // audioOutputFile:state->audioOutputPath + // ]; } state->transcribeHandler(state->job->job_id, @"end", result); + state->job->free_pcm_slices(); rnwhisper::job_remove(state->job->job_id); } @@ -263,17 +244,11 @@ - (void)fullTranscribeSamples:(RNWhisperContextRecordState*) state { state->nSamplesTranscribing = nSamplesOfIndex; NSLog(@"[RNWhisper] Transcribing %d samples", state->nSamplesTranscribing); - int16_t* audioBufferI16 = (int16_t*) [state->shortBufferSlices[state->transcribeSliceIndex] pointerValue]; - float* audioBufferF32 = (float*) malloc(state->nSamplesTranscribing * sizeof(float)); - // convert I16 to F32 - for (int i = 0; i < state->nSamplesTranscribing; i++) { - audioBufferF32[i] = (float)audioBufferI16[i] / 32768.0f; - } - CFTimeInterval timeStart = CACurrentMediaTime(); - - int code = [state->mSelf fullTranscribe:state->job audioData:audioBufferF32 audioDataCount:state->nSamplesTranscribing]; + float* pcmf32 = state->job->pcm_slice_to_f32(state->transcribeSliceIndex, state->nSamplesTranscribing); - free(audioBufferF32); + CFTimeInterval timeStart = CACurrentMediaTime(); + int code = [state->mSelf fullTranscribe:state->job audioData:pcmf32 audioDataCount:state->nSamplesTranscribing]; + free(pcmf32); CFTimeInterval timeEnd = CACurrentMediaTime(); const float timeRecording = (float) state->nSamplesTranscribing / (float) state->dataFormat.mSampleRate; @@ -352,15 +327,7 @@ - (OSStatus)transcribeRealtime:(int)jobId onTranscribe:(void (^)(int, NSString *, NSDictionary *))onTranscribe { self->recordState.transcribeHandler = onTranscribe; - [self prepareRealtime:options]; - self->recordState.job = rnwhisper::job_new(jobId, [self createParams:options jobId:jobId]); - - self->recordState.job->set_vad_params({ - .use_vad = options[@"useVad"] != nil ? [options[@"useVad"] boolValue] : false, - .vad_ms = options[@"vadMs"] != nil ? [options[@"vadMs"] intValue] : 2000, - .vad_thold = options[@"vadThold"] != nil ? [options[@"vadThold"] floatValue] : 0.6f, - .freq_thold = options[@"vadFreqThold"] != nil ? [options[@"vadFreqThold"] floatValue] : 100.0f - }); + [self prepareRealtime:jobId options:options]; OSStatus status = AudioQueueNewInput( &self->recordState.dataFormat, @@ -584,7 +551,6 @@ - (NSMutableDictionary *)getTextSegments { - (void)invalidate { [self stopCurrentTranscribe]; whisper_free(self->ctx); - [self freeBufferIfNeeded]; } @end From 6f956864537d5753dea78e6a0286fbe7bd19c57c Mon Sep 17 00:00:00 2001 From: Jhen Date: Fri, 8 Dec 2023 17:56:18 +0800 Subject: [PATCH 11/19] feat(cpp): move audio utils & save audio --- android/src/main/CMakeLists.txt | 1 + .../main/java/com/rnwhisper/AudioUtils.java | 80 ------------------- .../java/com/rnwhisper/WhisperContext.java | 17 +--- android/src/main/jni.cpp | 27 ++++++- cpp/rn-audioutils.cpp | 65 +++++++++++++++ cpp/rn-audioutils.h | 14 ++++ cpp/rn-whisper.cpp | 8 +- cpp/rn-whisper.h | 4 +- ios/RNWhisperAudioUtils.h | 2 - ios/RNWhisperAudioUtils.m | 56 ------------- ios/RNWhisperContext.h | 2 +- ios/RNWhisperContext.mm | 35 ++++---- 12 files changed, 133 insertions(+), 178 deletions(-) create mode 100644 cpp/rn-audioutils.cpp create mode 100644 cpp/rn-audioutils.h diff --git a/android/src/main/CMakeLists.txt b/android/src/main/CMakeLists.txt index febe580..0e26f28 100644 --- a/android/src/main/CMakeLists.txt +++ b/android/src/main/CMakeLists.txt @@ -12,6 +12,7 @@ set( ${RNWHISPER_LIB_DIR}/ggml-backend.c ${RNWHISPER_LIB_DIR}/ggml-quants.c ${RNWHISPER_LIB_DIR}/whisper.cpp + ${RNWHISPER_LIB_DIR}/rn-audioutils.cpp ${RNWHISPER_LIB_DIR}/rn-whisper.cpp ${CMAKE_SOURCE_DIR}/jni.cpp ) diff --git a/android/src/main/java/com/rnwhisper/AudioUtils.java b/android/src/main/java/com/rnwhisper/AudioUtils.java index 4498a79..b6c614d 100644 --- a/android/src/main/java/com/rnwhisper/AudioUtils.java +++ b/android/src/main/java/com/rnwhisper/AudioUtils.java @@ -2,14 +2,10 @@ import android.util.Log; -import java.util.ArrayList; -import java.lang.StringBuilder; import java.io.IOException; import java.io.FileReader; import java.io.ByteArrayOutputStream; import java.io.File; -import java.io.FileOutputStream; -import java.io.DataOutputStream; import java.io.IOException; import java.io.InputStream; import java.nio.ByteBuffer; @@ -19,82 +15,6 @@ public class AudioUtils { private static final String NAME = "RNWhisperAudioUtils"; - private static final int SAMPLE_RATE = 16000; - - private static byte[] shortToByte(short[] shortInts) { - int j = 0; - int length = shortInts.length; - byte[] byteData = new byte[length * 2]; - for (int i = 0; i < length; i++) { - byteData[j++] = (byte) (shortInts[i] >>> 8); - byteData[j++] = (byte) (shortInts[i] >>> 0); - } - return byteData; - } - - public static byte[] concatShortBuffers(ArrayList buffers) { - int totalLength = 0; - for (int i = 0; i < buffers.size(); i++) { - totalLength += buffers.get(i).length; - } - byte[] result = new byte[totalLength * 2]; - int offset = 0; - for (int i = 0; i < buffers.size(); i++) { - byte[] bytes = shortToByte(buffers.get(i)); - System.arraycopy(bytes, 0, result, offset, bytes.length); - offset += bytes.length; - } - - return result; - } - - private static byte[] removeTrailingZeros(byte[] audioData) { - int i = audioData.length - 1; - while (i >= 0 && audioData[i] == 0) { - --i; - } - byte[] newData = new byte[i + 1]; - System.arraycopy(audioData, 0, newData, 0, i + 1); - return newData; - } - - public static void saveWavFile(byte[] rawData, String audioOutputFile) throws IOException { - Log.d(NAME, "call saveWavFile"); - rawData = removeTrailingZeros(rawData); - DataOutputStream output = null; - try { - output = new DataOutputStream(new FileOutputStream(audioOutputFile)); - // WAVE header - // see http://ccrma.stanford.edu/courses/422/projects/WaveFormat/ - output.writeBytes("RIFF"); // chunk id - output.writeInt(Integer.reverseBytes(36 + rawData.length)); // chunk size - output.writeBytes("WAVE"); // format - output.writeBytes("fmt "); // subchunk 1 id - output.writeInt(Integer.reverseBytes(16)); // subchunk 1 size - output.writeShort(Short.reverseBytes((short) 1)); // audio format (1 = PCM) - output.writeShort(Short.reverseBytes((short) 1)); // number of channels - output.writeInt(Integer.reverseBytes(SAMPLE_RATE)); // sample rate - output.writeInt(Integer.reverseBytes(SAMPLE_RATE * 2)); // byte rate - output.writeShort(Short.reverseBytes((short) 2)); // block align - output.writeShort(Short.reverseBytes((short) 16)); // bits per sample - output.writeBytes("data"); // subchunk 2 id - output.writeInt(Integer.reverseBytes(rawData.length)); // subchunk 2 size - // Audio data (conversion big endian -> little endian) - short[] shorts = new short[rawData.length / 2]; - ByteBuffer.wrap(rawData).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer().get(shorts); - ByteBuffer bytes = ByteBuffer.allocate(shorts.length * 2); - for (short s : shorts) { - bytes.putShort(s); - } - Log.d(NAME, "writing audio file: " + audioOutputFile); - output.write(bytes.array()); - } finally { - if (output != null) { - output.close(); - } - } - } - public static float[] decodeWaveFile(InputStream inputStream) throws IOException { ByteArrayOutputStream baos = new ByteArrayOutputStream(); byte[] buffer = new byte[1024]; diff --git a/android/src/main/java/com/rnwhisper/WhisperContext.java b/android/src/main/java/com/rnwhisper/WhisperContext.java index e7b56b4..fdf9265 100644 --- a/android/src/main/java/com/rnwhisper/WhisperContext.java +++ b/android/src/main/java/com/rnwhisper/WhisperContext.java @@ -83,19 +83,8 @@ private boolean vad(ReadableMap options, int sliceIndex, int nSamples, int n) { } private void finishRealtimeTranscribe(ReadableMap options, WritableMap result) { - String audioOutputPath = options.hasKey("audioOutputPath") ? options.getString("audioOutputPath") : null; - if (audioOutputPath != null) { - // TODO: Append in real time so we don't need to keep all slices & also reduce memory usage - Log.d(NAME, "Begin saving wav file to " + audioOutputPath); - // try { - // // TODO: cpp audio utils - // AudioUtils.saveWavFile(AudioUtils.concatShortBuffers(shortBufferSlices), audioOutputPath); - // } catch (IOException e) { - // Log.e(NAME, "Error saving wav file: " + e.getMessage()); - // } - } emitTranscribeEvent("@RNWhisper_onRealtimeTranscribeEnd", Arguments.createMap()); - removeRealtimeTranscribeJob(jobId, context); + finishRealtimeTranscribeJob(jobId, context, sliceNSamples.stream().mapToInt(i -> i).toArray()); } public int startRealtimeTranscribe(int jobId, ReadableMap options) { @@ -123,8 +112,6 @@ public int startRealtimeTranscribe(int jobId, ReadableMap options) { createRealtimeTranscribeJob(jobId, context, options); - String audioOutputPath = options.hasKey("audioOutputPath") ? options.getString("audioOutputPath") : null; - sliceNSamples = new ArrayList(); sliceNSamples.add(0); @@ -508,7 +495,7 @@ protected static native void createRealtimeTranscribeJob( long context, ReadableMap options ); - protected static native void removeRealtimeTranscribeJob(int job_id, long context); + protected static native void finishRealtimeTranscribeJob(int job_id, long context, int[] sliceNSamples); protected static native boolean vadSimple(int job_id, int slice_index, int n_samples, int n); protected static native void putPcmData(short[] buffer, int slice_index, int n_samples, int n); protected static native int fullWithJob( diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp index fe337a3..b2fe8aa 100644 --- a/android/src/main/jni.cpp +++ b/android/src/main/jni.cpp @@ -328,24 +328,45 @@ Java_com_rnwhisper_WhisperContext_createRealtimeTranscribeJob( vad.vad_ms = readablemap::getInt(env, options, "vadMs", 2000); vad.vad_thold = readablemap::getFloat(env, options, "vadThold", 0.6f); vad.freq_thold = readablemap::getFloat(env, options, "vadFreqThold", 100.0f); + + jstring audio_output_path = readablemap::getString(env, options, "audioOutputPath", nullptr); + std::string *audio_output_path_str = nullptr; + if (audio_output_path != nullptr) { + audio_output_path_str = new std::string(env->GetStringUTFChars(audio_output_path, nullptr)); + env->ReleaseStringUTFChars(audio_output_path, audio_output_path_str->c_str()); + } job->set_realtime_params( vad, readablemap::getInt(env, options, "realtimeAudioSec", 0), - readablemap::getInt(env, options, "realtimeAudioSliceSec", 0) + readablemap::getInt(env, options, "realtimeAudioSliceSec", 0), + audio_output_path_str ); } JNIEXPORT void JNICALL -Java_com_rnwhisper_WhisperContext_removeRealtimeTranscribeJob( +Java_com_rnwhisper_WhisperContext_finishRealtimeTranscribeJob( JNIEnv *env, jobject thiz, jint job_id, - jlong context_ptr + jlong context_ptr, + jintArray slice_n_samples ) { UNUSED(env); UNUSED(thiz); UNUSED(context_ptr); + rnwhisper::job *job = rnwhisper::job_get(job_id); + if (job->audio_output_path != nullptr) { + std::vector slice_n_samples_vec; + jint *slice_n_samples_arr = env->GetIntArrayElements(slice_n_samples, nullptr); + slice_n_samples_vec = std::vector(slice_n_samples_arr, slice_n_samples_arr + env->GetArrayLength(slice_n_samples)); + env->ReleaseIntArrayElements(slice_n_samples, slice_n_samples_arr, JNI_ABORT); + + rnaudioutils::save_wav_file( + rnaudioutils::concat_short_buffers(job->pcm_slices, slice_n_samples_vec), + *job->audio_output_path + ); + } job->free_pcm_slices(); rnwhisper::job_remove(job_id); } diff --git a/cpp/rn-audioutils.cpp b/cpp/rn-audioutils.cpp new file mode 100644 index 0000000..e2b554e --- /dev/null +++ b/cpp/rn-audioutils.cpp @@ -0,0 +1,65 @@ +#include "rn-audioutils.h" + +namespace rnaudioutils { + +std::vector concat_short_buffers(const std::vector& buffers, const std::vector& slice_n_samples) { + std::vector output_data; + + for (size_t i = 0; i < buffers.size(); i++) { + int size = slice_n_samples[i]; // Number of shorts + short* slice = buffers[i]; + + // Copy each short as two bytes + for (int j = 0; j < size; j++) { + output_data.push_back(static_cast(slice[j] & 0xFF)); // Lower byte + output_data.push_back(static_cast((slice[j] >> 8) & 0xFF)); // Higher byte + } + } + + return output_data; +} + +std::vector remove_trailing_zeros(const std::vector& audio_data) { + auto last = std::find_if(audio_data.rbegin(), audio_data.rend(), [](uint8_t byte) { return byte != 0; }); + return std::vector(audio_data.begin(), last.base()); +} + +void save_wav_file(const std::vector& raw, const std::string& file) { + std::vector data = remove_trailing_zeros(raw); + + std::ofstream output(file, std::ios::binary); + + if (!output.is_open()) { + std::cerr << "Failed to open file for writing: " << file << std::endl; + return; + } + + // WAVE header + output.write("RIFF", 4); + int32_t chunk_size = 36 + static_cast(data.size()); + output.write(reinterpret_cast(&chunk_size), sizeof(chunk_size)); + output.write("WAVE", 4); + output.write("fmt ", 4); + int32_t sub_chunk_size = 16; + output.write(reinterpret_cast(&sub_chunk_size), sizeof(sub_chunk_size)); + short audio_format = 1; + output.write(reinterpret_cast(&audio_format), sizeof(audio_format)); + short num_channels = 1; + output.write(reinterpret_cast(&num_channels), sizeof(num_channels)); + int32_t sample_rate = WHISPER_SAMPLE_RATE; + output.write(reinterpret_cast(&sample_rate), sizeof(sample_rate)); + int32_t byte_rate = WHISPER_SAMPLE_RATE * 2; + output.write(reinterpret_cast(&byte_rate), sizeof(byte_rate)); + short block_align = 2; + output.write(reinterpret_cast(&block_align), sizeof(block_align)); + short bits_per_sample = 16; + output.write(reinterpret_cast(&bits_per_sample), sizeof(bits_per_sample)); + output.write("data", 4); + int32_t sub_chunk2_size = static_cast(data.size()); + output.write(reinterpret_cast(&sub_chunk2_size), sizeof(sub_chunk2_size)); + output.write(reinterpret_cast(data.data()), data.size()); + + output.close(); +} + +} // namespace rnaudioutils diff --git a/cpp/rn-audioutils.h b/cpp/rn-audioutils.h new file mode 100644 index 0000000..9e49976 --- /dev/null +++ b/cpp/rn-audioutils.h @@ -0,0 +1,14 @@ +#include +#include +#include +#include +#include +#include +#include "whisper.h" + +namespace rnaudioutils { + +std::vector concat_short_buffers(const std::vector& buffers, const std::vector& slice_n_samples); +void save_wav_file(const std::vector& raw, const std::string& file); + +} // namespace rnaudioutils diff --git a/cpp/rn-whisper.cpp b/cpp/rn-whisper.cpp index 3493c18..4357c69 100644 --- a/cpp/rn-whisper.cpp +++ b/cpp/rn-whisper.cpp @@ -63,11 +63,17 @@ job::~job() { fprintf(stderr, "%s: job_id: %d\n", __func__, job_id); } -void job::set_realtime_params(vad_params params, int sec, int slice_sec) { +void job::set_realtime_params( + vad_params params, + int sec, + int slice_sec, + std::string* output_path +) { vad = params; if (vad.vad_ms < 2000) vad.vad_ms = 2000; audio_sec = sec > 0 ? sec : DEFAULT_MAX_AUDIO_SEC; audio_slice_sec = slice_sec > 0 && slice_sec < audio_sec ? slice_sec : audio_sec; + audio_output_path = output_path; } bool job::vad_simple(int slice_index, int n_samples, int n) { diff --git a/cpp/rn-whisper.h b/cpp/rn-whisper.h index 149c7c7..9a2bc40 100644 --- a/cpp/rn-whisper.h +++ b/cpp/rn-whisper.h @@ -4,6 +4,7 @@ #include #include #include "whisper.h" +#include "rn-audioutils.h" namespace rnwhisper { @@ -29,8 +30,9 @@ struct job { vad_params vad; int audio_sec = 0; int audio_slice_sec = 0; + std::string* audio_output_path = nullptr; std::vector pcm_slices; - void set_realtime_params(vad_params vad, int audio_sec, int audio_slice_sec); + void set_realtime_params(vad_params vad, int audio_sec, int audio_slice_sec, std::string* audio_output_path); bool vad_simple(int slice_index, int n_samples, int n); void put_pcm_data(short* pcm, int slice_index, int n_samples, int n); float* pcm_slice_to_f32(int slice_index, int size); diff --git a/ios/RNWhisperAudioUtils.h b/ios/RNWhisperAudioUtils.h index 628fa4f..a37581d 100644 --- a/ios/RNWhisperAudioUtils.h +++ b/ios/RNWhisperAudioUtils.h @@ -2,8 +2,6 @@ @interface RNWhisperAudioUtils : NSObject -+ (NSData *)concatShortBuffers:(NSMutableArray *)buffers sliceNSamples:(NSMutableArray *)sliceNSamples; -+ (void)saveWavFile:(NSData *)rawData audioOutputFile:(NSString *)audioOutputFile; + (float *)decodeWaveFile:(NSString*)filePath count:(int *)count; @end diff --git a/ios/RNWhisperAudioUtils.m b/ios/RNWhisperAudioUtils.m index a9ed994..334740f 100644 --- a/ios/RNWhisperAudioUtils.m +++ b/ios/RNWhisperAudioUtils.m @@ -3,62 +3,6 @@ @implementation RNWhisperAudioUtils -+ (NSData *)concatShortBuffers:(NSMutableArray *)buffers sliceNSamples:(NSMutableArray *)sliceNSamples { - NSMutableData *outputData = [NSMutableData data]; - for (int i = 0; i < buffers.count; i++) { - int size = [sliceNSamples objectAtIndex:i].intValue; - NSValue *buffer = [buffers objectAtIndex:i]; - short *bufferPtr = buffer.pointerValue; - [outputData appendBytes:bufferPtr length:size * sizeof(short)]; - } - return outputData; -} - -+ (void)saveWavFile:(NSData *)rawData audioOutputFile:(NSString *)audioOutputFile { - NSMutableData *outputData = [NSMutableData data]; - - // WAVE header - [outputData appendData:[@"RIFF" dataUsingEncoding:NSUTF8StringEncoding]]; // chunk id - int chunkSize = CFSwapInt32HostToLittle(36 + rawData.length); - [outputData appendBytes:&chunkSize length:sizeof(chunkSize)]; - [outputData appendData:[@"WAVE" dataUsingEncoding:NSUTF8StringEncoding]]; // format - [outputData appendData:[@"fmt " dataUsingEncoding:NSUTF8StringEncoding]]; // subchunk 1 id - - int subchunk1Size = CFSwapInt32HostToLittle(16); - [outputData appendBytes:&subchunk1Size length:sizeof(subchunk1Size)]; - - short audioFormat = CFSwapInt16HostToLittle(1); // PCM - [outputData appendBytes:&audioFormat length:sizeof(audioFormat)]; - - short numChannels = CFSwapInt16HostToLittle(1); // mono - [outputData appendBytes:&numChannels length:sizeof(numChannels)]; - - int sampleRate = CFSwapInt32HostToLittle(WHISPER_SAMPLE_RATE); - [outputData appendBytes:&sampleRate length:sizeof(sampleRate)]; - - // (bitDepth * sampleRate * channels) >> 3 - int byteRate = CFSwapInt32HostToLittle(WHISPER_SAMPLE_RATE * 1 * 16 / 8); - [outputData appendBytes:&byteRate length:sizeof(byteRate)]; - - // (bitDepth * channels) >> 3 - short blockAlign = CFSwapInt16HostToLittle(16 / 8); - [outputData appendBytes:&blockAlign length:sizeof(blockAlign)]; - - // bitDepth - short bitsPerSample = CFSwapInt16HostToLittle(16); - [outputData appendBytes:&bitsPerSample length:sizeof(bitsPerSample)]; - - [outputData appendData:[@"data" dataUsingEncoding:NSUTF8StringEncoding]]; // subchunk 2 id - int subchunk2Size = CFSwapInt32HostToLittle((int)rawData.length); - [outputData appendBytes:&subchunk2Size length:sizeof(subchunk2Size)]; - - // Audio data - [outputData appendData:rawData]; - - // Save to file - [outputData writeToFile:audioOutputFile atomically:YES]; -} - + (float *)decodeWaveFile:(NSString*)filePath count:(int *)count { NSURL *url = [NSURL fileURLWithPath:filePath]; NSData *fileData = [NSData dataWithContentsOfURL:url]; diff --git a/ios/RNWhisperContext.h b/ios/RNWhisperContext.h index e16d188..a029dfd 100644 --- a/ios/RNWhisperContext.h +++ b/ios/RNWhisperContext.h @@ -20,7 +20,7 @@ typedef struct { bool isCapturing; bool isStoppedByAction; int nSamplesTranscribing; - NSMutableArray *sliceNSamples; + std::vector sliceNSamples; bool isUseSlices; int sliceIndex; int transcribeSliceIndex; diff --git a/ios/RNWhisperContext.mm b/ios/RNWhisperContext.mm index ba095fc..65206e4 100644 --- a/ios/RNWhisperContext.mm +++ b/ios/RNWhisperContext.mm @@ -1,5 +1,4 @@ #import "RNWhisperContext.h" -#import "RNWhisperAudioUtils.h" #import #include @@ -113,16 +112,14 @@ - (void)prepareRealtime:(int)jobId options:(NSDictionary *)options { self->recordState.isCapturing = false; self->recordState.isStoppedByAction = false; - self->recordState.audioOutputPath = options[@"audioOutputPath"]; - self->recordState.sliceIndex = 0; self->recordState.transcribeSliceIndex = 0; self->recordState.nSamplesTranscribing = 0; - self->recordState.sliceNSamples = [NSMutableArray new]; - [self->recordState.sliceNSamples addObject:[NSNumber numberWithInt:0]]; + self->recordState.sliceNSamples.push_back(0); self->recordState.job = rnwhisper::job_new(jobId, [self createParams:options jobId:jobId]); + std::string audio_output_path = options[@"audioOutputPath"] != nil ? [options[@"audioOutputPath"] UTF8String] : ""; self->recordState.job->set_realtime_params( { .use_vad = options[@"useVad"] != nil ? [options[@"useVad"] boolValue] : false, @@ -131,7 +128,8 @@ - (void)prepareRealtime:(int)jobId options:(NSDictionary *)options { .freq_thold = options[@"vadFreqThold"] != nil ? [options[@"vadFreqThold"] floatValue] : 100.0f }, options[@"realtimeAudioSec"] != nil ? [options[@"realtimeAudioSec"] intValue] : 0, - options[@"realtimeAudioSliceSec"] != nil ? [options[@"realtimeAudioSliceSec"] intValue] : 0 + options[@"realtimeAudioSliceSec"] != nil ? [options[@"realtimeAudioSliceSec"] intValue] : 0, + options[@"audioOutputPath"] != nil ? &audio_output_path : nullptr ); self->recordState.isUseSlices = self->recordState.job->audio_slice_sec < self->recordState.job->audio_sec; @@ -162,13 +160,13 @@ void AudioInputCallback(void * inUserData, } int totalNSamples = 0; - for (int i = 0; i < [state->sliceNSamples count]; i++) { - totalNSamples += [[state->sliceNSamples objectAtIndex:i] intValue]; + for (int i = 0; i < state->sliceNSamples.size(); i++) { + totalNSamples += state->sliceNSamples[i]; } const int n = inBuffer->mAudioDataByteSize / 2; - int nSamples = [state->sliceNSamples[state->sliceIndex] intValue]; + int nSamples = state->sliceNSamples[state->sliceIndex]; if (totalNSamples + n > state->job->audio_sec * WHISPER_SAMPLE_RATE) { NSLog(@"[RNWhisper] Audio buffer is full, stop capturing"); @@ -200,7 +198,7 @@ void AudioInputCallback(void * inUserData, // next slice state->sliceIndex++; nSamples = 0; - [state->sliceNSamples addObject:[NSNumber numberWithInt:0]]; + state->sliceNSamples.push_back(0); } // Append to buffer @@ -210,7 +208,7 @@ void AudioInputCallback(void * inUserData, bool isSpeech = vad(state, state->sliceIndex, nSamples, n); nSamples += n; - state->sliceNSamples[state->sliceIndex] = [NSNumber numberWithInt:nSamples]; + state->sliceNSamples[state->sliceIndex] = nSamples; AudioQueueEnqueueBuffer(state->queue, inBuffer, 0, NULL); @@ -226,13 +224,12 @@ void AudioInputCallback(void * inUserData, - (void)finishRealtimeTranscribe:(RNWhisperContextRecordState*) state result:(NSDictionary*)result { // Save wav if needed - if (state->audioOutputPath != nil) { + if (state->job->audio_output_path != nullptr) { // TODO: Append in real time so we don't need to keep all slices & also reduce memory usage - // [RNWhisperAudioUtils - // saveWavFile:[RNWhisperAudioUtils concatShortBuffers:state->shortBufferSlices - // sliceNSamples:state->sliceNSamples] - // audioOutputFile:state->audioOutputPath - // ]; + rnaudioutils::save_wav_file( + rnaudioutils::concat_short_buffers(state->job->pcm_slices, state->sliceNSamples), + *state->job->audio_output_path + ); } state->transcribeHandler(state->job->job_id, @"end", result); state->job->free_pcm_slices(); @@ -240,7 +237,7 @@ - (void)finishRealtimeTranscribe:(RNWhisperContextRecordState*) state result:(NS } - (void)fullTranscribeSamples:(RNWhisperContextRecordState*) state { - int nSamplesOfIndex = [[state->sliceNSamples objectAtIndex:state->transcribeSliceIndex] intValue]; + int nSamplesOfIndex = state->sliceNSamples[state->transcribeSliceIndex]; state->nSamplesTranscribing = nSamplesOfIndex; NSLog(@"[RNWhisper] Transcribing %d samples", state->nSamplesTranscribing); @@ -268,7 +265,7 @@ - (void)fullTranscribeSamples:(RNWhisperContextRecordState*) state { result[@"error"] = [NSString stringWithFormat:@"Transcribe failed with code %d", code]; } - nSamplesOfIndex = [[state->sliceNSamples objectAtIndex:state->transcribeSliceIndex] intValue]; + nSamplesOfIndex = state->sliceNSamples[state->transcribeSliceIndex]; bool isStopped = state->isStoppedByAction || ( !state->isCapturing && From ced37c636b936a5b52739380b9e47c945eff116e Mon Sep 17 00:00:00 2001 From: Jhen Date: Fri, 8 Dec 2023 18:04:35 +0800 Subject: [PATCH 12/19] feat(docs): update --- cpp/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/README.md b/cpp/README.md index c947f95..c0efae8 100644 --- a/cpp/README.md +++ b/cpp/README.md @@ -1,4 +1,4 @@ # Note -- Only `rn-whisper.h` / `rn-whisper.cpp` are the specific files for this project, others are sync from [whisper.cpp](https://github.com/ggerganov/whisper.cpp). +- Only `rn-*` are the specific files for this project, others are sync from [whisper.cpp](https://github.com/ggerganov/whisper.cpp). - We can update the native source by using the [bootstrap](../scripts/bootstrap.sh) script. From fa97384b7c04d7b5b6a1f42b55ee52184bda0351 Mon Sep 17 00:00:00 2001 From: Jhen Date: Fri, 8 Dec 2023 18:04:51 +0800 Subject: [PATCH 13/19] feat(android): cleanup unnecessary arguments --- .../java/com/rnwhisper/WhisperContext.java | 44 +++++++++---------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/android/src/main/java/com/rnwhisper/WhisperContext.java b/android/src/main/java/com/rnwhisper/WhisperContext.java index fdf9265..e757835 100644 --- a/android/src/main/java/com/rnwhisper/WhisperContext.java +++ b/android/src/main/java/com/rnwhisper/WhisperContext.java @@ -77,12 +77,12 @@ private void rewind() { fullHandler = null; } - private boolean vad(ReadableMap options, int sliceIndex, int nSamples, int n) { + private boolean vad(int sliceIndex, int nSamples, int n) { if (isTranscribing) return true; return vadSimple(jobId, sliceIndex, nSamples, n); } - private void finishRealtimeTranscribe(ReadableMap options, WritableMap result) { + private void finishRealtimeTranscribe(WritableMap result) { emitTranscribeEvent("@RNWhisper_onRealtimeTranscribeEnd", Arguments.createMap()); finishRealtimeTranscribeJob(jobId, context, sliceNSamples.stream().mapToInt(i -> i).toArray()); } @@ -142,14 +142,14 @@ public void run() { nSamples == nSamplesTranscribing && sliceIndex == transcribeSliceIndex ) { - finishRealtimeTranscribe(options, Arguments.createMap()); + finishRealtimeTranscribe(Arguments.createMap()); } else if (!isTranscribing) { - if (!vad(options, sliceIndex, nSamples, 0)) { - finishRealtimeTranscribe(options, Arguments.createMap()); + if (!vad(sliceIndex, nSamples, 0)) { + finishRealtimeTranscribe(Arguments.createMap()); break; } isTranscribing = true; - fullTranscribeSamples(options, true); + fullTranscribeSamples(true); } break; } @@ -164,7 +164,7 @@ public void run() { } putPcmData(buffer, sliceIndex, nSamples, n); - boolean isSpeech = vad(options, sliceIndex, nSamples, n); + boolean isSpeech = vad(sliceIndex, nSamples, n); nSamples += n; sliceNSamples.set(sliceIndex, nSamples); @@ -176,7 +176,7 @@ public void run() { fullHandler = new Thread(new Runnable() { @Override public void run() { - fullTranscribeSamples(options, false); + fullTranscribeSamples(false); } }); fullHandler.start(); @@ -187,7 +187,7 @@ public void run() { } if (!isTranscribing) { - finishRealtimeTranscribe(options, Arguments.createMap()); + finishRealtimeTranscribe(Arguments.createMap()); } if (fullHandler != null) { fullHandler.join(); // Wait for full transcribe to finish @@ -205,15 +205,12 @@ public void run() { return state; } - private void fullTranscribeSamples(ReadableMap options, boolean skipCapturingCheck) { + private void fullTranscribeSamples(boolean skipCapturingCheck) { int nSamplesOfIndex = sliceNSamples.get(transcribeSliceIndex); if (!isCapturing && !skipCapturingCheck) return; - int nSamples = sliceNSamples.get(transcribeSliceIndex); - nSamplesTranscribing = nSamplesOfIndex; - Log.d(NAME, "Start transcribing realtime: " + nSamplesTranscribing); int timeStart = (int) System.currentTimeMillis(); @@ -254,7 +251,7 @@ private void fullTranscribeSamples(ReadableMap options, boolean skipCapturingChe if (isStopped && !continueNeeded) { payload.putBoolean("isCapturing", false); payload.putBoolean("isStoppedByAction", isStoppedByAction); - finishRealtimeTranscribe(options, payload); + finishRealtimeTranscribe(payload); } else if (code == 0) { payload.putBoolean("isCapturing", true); emitTranscribeEvent("@RNWhisper_onRealtimeTranscribe", payload); @@ -265,7 +262,7 @@ private void fullTranscribeSamples(ReadableMap options, boolean skipCapturingChe if (continueNeeded) { // If no more capturing, continue transcribing until all slices are transcribed - fullTranscribeSamples(options, true); + fullTranscribeSamples(true); } else if (isStopped) { // No next, cleanup rewind(); @@ -477,10 +474,11 @@ private static String cpuInfo() { } } - + // JNI methods protected static native long initContext(String modelPath); protected static native long initContextWithAsset(AssetManager assetManager, String modelPath); protected static native long initContextWithInputStream(PushbackInputStream inputStream); + protected static native void freeContext(long contextPtr); protected static native int fullWithNewJob( int job_id, @@ -490,6 +488,13 @@ protected static native int fullWithNewJob( ReadableMap options, Callback Callback ); + protected static native void abortTranscribe(int jobId); + protected static native void abortAllTranscribe(); + protected static native int getTextSegmentCount(long context); + protected static native String getTextSegment(long context, int index); + protected static native int getTextSegmentT0(long context, int index); + protected static native int getTextSegmentT1(long context, int index); + protected static native void createRealtimeTranscribeJob( int job_id, long context, @@ -504,11 +509,4 @@ protected static native int fullWithJob( int slice_index, int n_samples ); - protected static native void abortTranscribe(int jobId); - protected static native void abortAllTranscribe(); - protected static native int getTextSegmentCount(long context); - protected static native String getTextSegment(long context, int index); - protected static native int getTextSegmentT0(long context, int index); - protected static native int getTextSegmentT1(long context, int index); - protected static native void freeContext(long contextPtr); } From 7edc11f46b299e5edebd1a6348fcf9b74182c3b8 Mon Sep 17 00:00:00 2001 From: Jhen Date: Fri, 8 Dec 2023 18:45:53 +0800 Subject: [PATCH 14/19] feat(android): keep todo --- android/src/main/jni.cpp | 1 + ios/RNWhisperContext.mm | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp index b2fe8aa..a114936 100644 --- a/android/src/main/jni.cpp +++ b/android/src/main/jni.cpp @@ -362,6 +362,7 @@ Java_com_rnwhisper_WhisperContext_finishRealtimeTranscribeJob( slice_n_samples_vec = std::vector(slice_n_samples_arr, slice_n_samples_arr + env->GetArrayLength(slice_n_samples)); env->ReleaseIntArrayElements(slice_n_samples, slice_n_samples_arr, JNI_ABORT); + // TODO: Append in real time so we don't need to keep all slices & also reduce memory usage rnaudioutils::save_wav_file( rnaudioutils::concat_short_buffers(job->pcm_slices, slice_n_samples_vec), *job->audio_output_path diff --git a/ios/RNWhisperContext.mm b/ios/RNWhisperContext.mm index 65206e4..8243438 100644 --- a/ios/RNWhisperContext.mm +++ b/ios/RNWhisperContext.mm @@ -201,7 +201,6 @@ void AudioInputCallback(void * inUserData, state->sliceNSamples.push_back(0); } - // Append to buffer NSLog(@"[RNWhisper] Slice %d has %d samples", state->sliceIndex, nSamples); state->job->put_pcm_data((short*) inBuffer->mAudioData, state->sliceIndex, nSamples, n); From 759bcd5ab09c97df517cc04b478ffeb699f4e113 Mon Sep 17 00:00:00 2001 From: Jhen Date: Sat, 9 Dec 2023 08:28:21 +0800 Subject: [PATCH 15/19] fix(cpp): store job pointer instead --- android/src/main/jni.cpp | 1 - cpp/rn-whisper.cpp | 44 +++++++++++++++++++++------------------- cpp/rn-whisper.h | 1 - ios/RNWhisperContext.mm | 1 - 4 files changed, 23 insertions(+), 24 deletions(-) diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp index a114936..665d17e 100644 --- a/android/src/main/jni.cpp +++ b/android/src/main/jni.cpp @@ -368,7 +368,6 @@ Java_com_rnwhisper_WhisperContext_finishRealtimeTranscribeJob( *job->audio_output_path ); } - job->free_pcm_slices(); rnwhisper::job_remove(job_id); } diff --git a/cpp/rn-whisper.cpp b/cpp/rn-whisper.cpp index 4357c69..8c0cc54 100644 --- a/cpp/rn-whisper.cpp +++ b/cpp/rn-whisper.cpp @@ -59,10 +59,6 @@ bool vad_simple_impl(std::vector & pcmf32, int sample_rate, int last_ms, return true; } -job::~job() { - fprintf(stderr, "%s: job_id: %d\n", __func__, job_id); -} - void job::set_realtime_params( vad_params params, int sec, @@ -114,13 +110,6 @@ float* job::pcm_slice_to_f32(int slice_index, int size) { return nullptr; } -void job::free_pcm_slices() { - for (size_t i = 0; i < pcm_slices.size(); i++) { - delete[] pcm_slices[i]; - } - pcm_slices.clear(); -} - bool job::is_aborted() { return aborted; } @@ -129,43 +118,56 @@ void job::abort() { aborted = true; } -std::unordered_map job_map; +job::~job() { + fprintf(stderr, "%s: job_id: %d\n", __func__, job_id); + + for (size_t i = 0; i < pcm_slices.size(); i++) { + delete[] pcm_slices[i]; + } + pcm_slices.clear(); +} + +std::unordered_map job_map; void job_abort_all() { for (auto it = job_map.begin(); it != job_map.end(); ++it) { - it->second.abort(); + it->second->abort(); } } job* job_new(int job_id, struct whisper_full_params params) { - job ctx; - ctx.job_id = job_id; - ctx.params = params; + job* ctx = new job(); + ctx->job_id = job_id; + ctx->params = params; + + job_map[job_id] = ctx; // Abort handler params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { job *j = (job*)user_data; return !j->is_aborted(); }; - params.encoder_begin_callback_user_data = &ctx; + params.encoder_begin_callback_user_data = job_map[job_id]; params.abort_callback = [](void * user_data) { job *j = (job*)user_data; return j->is_aborted(); }; - params.abort_callback_user_data = &ctx; + params.abort_callback_user_data = job_map[job_id]; - job_map[job_id] = ctx; - return &job_map[job_id]; + return job_map[job_id]; } job* job_get(int job_id) { if (job_map.find(job_id) != job_map.end()) { - return &job_map[job_id]; + return job_map[job_id]; } return nullptr; } void job_remove(int job_id) { + if (job_map.find(job_id) != job_map.end()) { + delete job_map[job_id]; + } job_map.erase(job_id); } diff --git a/cpp/rn-whisper.h b/cpp/rn-whisper.h index 9a2bc40..306278a 100644 --- a/cpp/rn-whisper.h +++ b/cpp/rn-whisper.h @@ -36,7 +36,6 @@ struct job { bool vad_simple(int slice_index, int n_samples, int n); void put_pcm_data(short* pcm, int slice_index, int n_samples, int n); float* pcm_slice_to_f32(int slice_index, int size); - void free_pcm_slices(); }; void job_abort_all(); diff --git a/ios/RNWhisperContext.mm b/ios/RNWhisperContext.mm index 8243438..6e4e269 100644 --- a/ios/RNWhisperContext.mm +++ b/ios/RNWhisperContext.mm @@ -231,7 +231,6 @@ - (void)finishRealtimeTranscribe:(RNWhisperContextRecordState*) state result:(NS ); } state->transcribeHandler(state->job->job_id, @"end", result); - state->job->free_pcm_slices(); rnwhisper::job_remove(state->job->job_id); } From b3aacbbeddc1fed26e82f2bcd001f53e0b681be0 Mon Sep 17 00:00:00 2001 From: Jhen Date: Sat, 9 Dec 2023 08:37:17 +0800 Subject: [PATCH 16/19] feat(cpp): add custom log for easy debug in android --- android/src/main/CMakeLists.txt | 4 ++++ cpp/rn-audioutils.cpp | 3 ++- cpp/rn-whisper-log.h | 11 +++++++++++ cpp/rn-whisper.cpp | 4 ++-- cpp/rn-whisper.h | 1 + 5 files changed, 20 insertions(+), 3 deletions(-) create mode 100644 cpp/rn-whisper-log.h diff --git a/android/src/main/CMakeLists.txt b/android/src/main/CMakeLists.txt index 0e26f28..d7583aa 100644 --- a/android/src/main/CMakeLists.txt +++ b/android/src/main/CMakeLists.txt @@ -34,6 +34,10 @@ function(build_library target_name) target_compile_options(${target_name} PRIVATE -mfpu=neon-vfpv4) endif () + if (${CMAKE_BUILD_TYPE} STREQUAL "Debug") + target_compile_options(${target_name} PRIVATE -DRNWHISPER_ANDROID_ENABLE_LOGGING) + endif () + # NOTE: If you want to debug the native code, you can uncomment if and endif # if (NOT ${CMAKE_BUILD_TYPE} STREQUAL "Debug") diff --git a/cpp/rn-audioutils.cpp b/cpp/rn-audioutils.cpp index e2b554e..ae7d3e4 100644 --- a/cpp/rn-audioutils.cpp +++ b/cpp/rn-audioutils.cpp @@ -1,4 +1,5 @@ #include "rn-audioutils.h" +#include "rn-whisper-log.h" namespace rnaudioutils { @@ -30,7 +31,7 @@ void save_wav_file(const std::vector& raw, const std::string& file) { std::ofstream output(file, std::ios::binary); if (!output.is_open()) { - std::cerr << "Failed to open file for writing: " << file << std::endl; + RNWHISPER_LOG_ERROR("Failed to open file for writing: %s\n", file.c_str()); return; } diff --git a/cpp/rn-whisper-log.h b/cpp/rn-whisper-log.h new file mode 100644 index 0000000..f05e758 --- /dev/null +++ b/cpp/rn-whisper-log.h @@ -0,0 +1,11 @@ +#if defined(__ANDROID__) && defined(RNWHISPER_ANDROID_ENABLE_LOGGING) +#include +#define RNWHISPER_ANDROID_TAG "RNWHISPER_LOG_ANDROID" +#define RNWHISPER_LOG_INFO(...) __android_log_print(ANDROID_LOG_INFO , WHISPER_ANDROID_TAG, __VA_ARGS__) +#define RNWHISPER_LOG_WARN(...) __android_log_print(ANDROID_LOG_WARN , WHISPER_ANDROID_TAG, __VA_ARGS__) +#define RNWHISPER_LOG_ERROR(...) __android_log_print(ANDROID_LOG_ERROR, WHISPER_ANDROID_TAG, __VA_ARGS__) +#else +#define RNWHISPER_LOG_INFO(...) fprintf(stderr, __VA_ARGS__) +#define RNWHISPER_LOG_WARN(...) fprintf(stderr, __VA_ARGS__) +#define RNWHISPER_LOG_ERROR(...) fprintf(stderr, __VA_ARGS__) +#endif // __ANDROID__ diff --git a/cpp/rn-whisper.cpp b/cpp/rn-whisper.cpp index 8c0cc54..e1855c8 100644 --- a/cpp/rn-whisper.cpp +++ b/cpp/rn-whisper.cpp @@ -49,7 +49,7 @@ bool vad_simple_impl(std::vector & pcmf32, int sample_rate, int last_ms, energy_last /= n_samples_last; if (verbose) { - fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold); + RNWHISPER_LOG_INFO("%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold); } if (energy_last > vad_thold*energy_all) { @@ -119,7 +119,7 @@ void job::abort() { } job::~job() { - fprintf(stderr, "%s: job_id: %d\n", __func__, job_id); + RNWHISPER_LOG_INFO("rnwhisper::job::%s: job_id: %d\n", __func__, job_id); for (size_t i = 0; i < pcm_slices.size(); i++) { delete[] pcm_slices[i]; diff --git a/cpp/rn-whisper.h b/cpp/rn-whisper.h index 306278a..7516a19 100644 --- a/cpp/rn-whisper.h +++ b/cpp/rn-whisper.h @@ -4,6 +4,7 @@ #include #include #include "whisper.h" +#include "rn-whisper-log.h" #include "rn-audioutils.h" namespace rnwhisper { From 3a78b10adcedbcb00987d727c631cda4067cc291 Mon Sep 17 00:00:00 2001 From: Jhen Date: Sat, 9 Dec 2023 09:22:53 +0800 Subject: [PATCH 17/19] fix(android): build --- android/src/main/java/com/rnwhisper/WhisperContext.java | 4 ++-- cpp/rn-whisper-log.h | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/android/src/main/java/com/rnwhisper/WhisperContext.java b/android/src/main/java/com/rnwhisper/WhisperContext.java index e757835..256a148 100644 --- a/android/src/main/java/com/rnwhisper/WhisperContext.java +++ b/android/src/main/java/com/rnwhisper/WhisperContext.java @@ -162,7 +162,7 @@ public void run() { nSamples = 0; sliceNSamples.add(0); } - putPcmData(buffer, sliceIndex, nSamples, n); + putPcmData(jobId, buffer, sliceIndex, nSamples, n); boolean isSpeech = vad(sliceIndex, nSamples, n); @@ -502,7 +502,7 @@ protected static native void createRealtimeTranscribeJob( ); protected static native void finishRealtimeTranscribeJob(int job_id, long context, int[] sliceNSamples); protected static native boolean vadSimple(int job_id, int slice_index, int n_samples, int n); - protected static native void putPcmData(short[] buffer, int slice_index, int n_samples, int n); + protected static native void putPcmData(int job_id, short[] buffer, int slice_index, int n_samples, int n); protected static native int fullWithJob( int job_id, long context, diff --git a/cpp/rn-whisper-log.h b/cpp/rn-whisper-log.h index f05e758..61858f2 100644 --- a/cpp/rn-whisper-log.h +++ b/cpp/rn-whisper-log.h @@ -1,9 +1,9 @@ #if defined(__ANDROID__) && defined(RNWHISPER_ANDROID_ENABLE_LOGGING) #include #define RNWHISPER_ANDROID_TAG "RNWHISPER_LOG_ANDROID" -#define RNWHISPER_LOG_INFO(...) __android_log_print(ANDROID_LOG_INFO , WHISPER_ANDROID_TAG, __VA_ARGS__) -#define RNWHISPER_LOG_WARN(...) __android_log_print(ANDROID_LOG_WARN , WHISPER_ANDROID_TAG, __VA_ARGS__) -#define RNWHISPER_LOG_ERROR(...) __android_log_print(ANDROID_LOG_ERROR, WHISPER_ANDROID_TAG, __VA_ARGS__) +#define RNWHISPER_LOG_INFO(...) __android_log_print(ANDROID_LOG_INFO , RNWHISPER_ANDROID_TAG, __VA_ARGS__) +#define RNWHISPER_LOG_WARN(...) __android_log_print(ANDROID_LOG_WARN , RNWHISPER_ANDROID_TAG, __VA_ARGS__) +#define RNWHISPER_LOG_ERROR(...) __android_log_print(ANDROID_LOG_ERROR, RNWHISPER_ANDROID_TAG, __VA_ARGS__) #else #define RNWHISPER_LOG_INFO(...) fprintf(stderr, __VA_ARGS__) #define RNWHISPER_LOG_WARN(...) fprintf(stderr, __VA_ARGS__) From 31e09aa9f1ae306e3e0eab3c99d8fb0191443b18 Mon Sep 17 00:00:00 2001 From: Jhen Date: Sat, 9 Dec 2023 10:29:24 +0800 Subject: [PATCH 18/19] fix(cpp): str params should not be released early --- android/src/main/jni.cpp | 13 +++++++------ cpp/rn-audioutils.cpp | 2 ++ cpp/rn-whisper.cpp | 2 +- cpp/rn-whisper.h | 4 ++-- ios/RNWhisperContext.mm | 5 ++--- 5 files changed, 14 insertions(+), 12 deletions(-) diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp index 665d17e..360eb3d 100644 --- a/android/src/main/jni.cpp +++ b/android/src/main/jni.cpp @@ -236,12 +236,12 @@ struct whisper_full_params createFullParams(JNIEnv *env, jobject options) { jstring prompt = readablemap::getString(env, options, "prompt", nullptr); if (prompt != nullptr) { params.initial_prompt = env->GetStringUTFChars(prompt, nullptr); - env->ReleaseStringUTFChars(prompt, params.initial_prompt); + env->DeleteLocalRef(prompt); } jstring language = readablemap::getString(env, options, "language", nullptr); if (language != nullptr) { params.language = env->GetStringUTFChars(language, nullptr); - env->ReleaseStringUTFChars(language, params.language); + env->DeleteLocalRef(language); } return params; } @@ -330,10 +330,10 @@ Java_com_rnwhisper_WhisperContext_createRealtimeTranscribeJob( vad.freq_thold = readablemap::getFloat(env, options, "vadFreqThold", 100.0f); jstring audio_output_path = readablemap::getString(env, options, "audioOutputPath", nullptr); - std::string *audio_output_path_str = nullptr; + const char* audio_output_path_str = nullptr; if (audio_output_path != nullptr) { - audio_output_path_str = new std::string(env->GetStringUTFChars(audio_output_path, nullptr)); - env->ReleaseStringUTFChars(audio_output_path, audio_output_path_str->c_str()); + audio_output_path_str = env->GetStringUTFChars(audio_output_path, nullptr); + env->DeleteLocalRef(audio_output_path); } job->set_realtime_params( vad, @@ -357,6 +357,7 @@ Java_com_rnwhisper_WhisperContext_finishRealtimeTranscribeJob( rnwhisper::job *job = rnwhisper::job_get(job_id); if (job->audio_output_path != nullptr) { + RNWHISPER_LOG_INFO("job->params.language: %s\n", job->params.language); std::vector slice_n_samples_vec; jint *slice_n_samples_arr = env->GetIntArrayElements(slice_n_samples, nullptr); slice_n_samples_vec = std::vector(slice_n_samples_arr, slice_n_samples_arr + env->GetArrayLength(slice_n_samples)); @@ -365,7 +366,7 @@ Java_com_rnwhisper_WhisperContext_finishRealtimeTranscribeJob( // TODO: Append in real time so we don't need to keep all slices & also reduce memory usage rnaudioutils::save_wav_file( rnaudioutils::concat_short_buffers(job->pcm_slices, slice_n_samples_vec), - *job->audio_output_path + job->audio_output_path ); } rnwhisper::job_remove(job_id); diff --git a/cpp/rn-audioutils.cpp b/cpp/rn-audioutils.cpp index ae7d3e4..292a704 100644 --- a/cpp/rn-audioutils.cpp +++ b/cpp/rn-audioutils.cpp @@ -61,6 +61,8 @@ void save_wav_file(const std::vector& raw, const std::string& file) { output.write(reinterpret_cast(data.data()), data.size()); output.close(); + + RNWHISPER_LOG_INFO("Saved audio file: %s\n", file.c_str()); } } // namespace rnaudioutils diff --git a/cpp/rn-whisper.cpp b/cpp/rn-whisper.cpp index e1855c8..31c549b 100644 --- a/cpp/rn-whisper.cpp +++ b/cpp/rn-whisper.cpp @@ -63,7 +63,7 @@ void job::set_realtime_params( vad_params params, int sec, int slice_sec, - std::string* output_path + const char* output_path ) { vad = params; if (vad.vad_ms < 2000) vad.vad_ms = 2000; diff --git a/cpp/rn-whisper.h b/cpp/rn-whisper.h index 7516a19..5daa90c 100644 --- a/cpp/rn-whisper.h +++ b/cpp/rn-whisper.h @@ -31,9 +31,9 @@ struct job { vad_params vad; int audio_sec = 0; int audio_slice_sec = 0; - std::string* audio_output_path = nullptr; + const char* audio_output_path = nullptr; std::vector pcm_slices; - void set_realtime_params(vad_params vad, int audio_sec, int audio_slice_sec, std::string* audio_output_path); + void set_realtime_params(vad_params vad, int sec, int slice_sec, const char* output_path); bool vad_simple(int slice_index, int n_samples, int n); void put_pcm_data(short* pcm, int slice_index, int n_samples, int n); float* pcm_slice_to_f32(int slice_index, int size); diff --git a/ios/RNWhisperContext.mm b/ios/RNWhisperContext.mm index 6e4e269..d7ab52e 100644 --- a/ios/RNWhisperContext.mm +++ b/ios/RNWhisperContext.mm @@ -119,7 +119,6 @@ - (void)prepareRealtime:(int)jobId options:(NSDictionary *)options { self->recordState.sliceNSamples.push_back(0); self->recordState.job = rnwhisper::job_new(jobId, [self createParams:options jobId:jobId]); - std::string audio_output_path = options[@"audioOutputPath"] != nil ? [options[@"audioOutputPath"] UTF8String] : ""; self->recordState.job->set_realtime_params( { .use_vad = options[@"useVad"] != nil ? [options[@"useVad"] boolValue] : false, @@ -129,7 +128,7 @@ - (void)prepareRealtime:(int)jobId options:(NSDictionary *)options { }, options[@"realtimeAudioSec"] != nil ? [options[@"realtimeAudioSec"] intValue] : 0, options[@"realtimeAudioSliceSec"] != nil ? [options[@"realtimeAudioSliceSec"] intValue] : 0, - options[@"audioOutputPath"] != nil ? &audio_output_path : nullptr + options[@"audioOutputPath"] != nil ? [options[@"audioOutputPath"] UTF8String] : nullptr ); self->recordState.isUseSlices = self->recordState.job->audio_slice_sec < self->recordState.job->audio_sec; @@ -227,7 +226,7 @@ - (void)finishRealtimeTranscribe:(RNWhisperContextRecordState*) state result:(NS // TODO: Append in real time so we don't need to keep all slices & also reduce memory usage rnaudioutils::save_wav_file( rnaudioutils::concat_short_buffers(state->job->pcm_slices, state->sliceNSamples), - *state->job->audio_output_path + state->job->audio_output_path ); } state->transcribeHandler(state->job->job_id, @"end", result); From f07361801a1a654a1ed06d1f1e768d5bf37a869c Mon Sep 17 00:00:00 2001 From: Jhen Date: Sat, 9 Dec 2023 10:34:24 +0800 Subject: [PATCH 19/19] fix(example): revert some unnecessary change --- example/src/App.tsx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/example/src/App.tsx b/example/src/App.tsx index 83de489..df09f8d 100644 --- a/example/src/App.tsx +++ b/example/src/App.tsx @@ -225,13 +225,13 @@ export default function App() { log('Start transcribing...') const startTime = Date.now() const { stop, promise } = whisperContext.transcribe(sampleFile, { - language: 'zh', - prompt: 'HELLO WORLD', maxLen: 1, tokenTimestamps: true, onProgress: (cur) => { log(`Transcribing progress: ${cur}%`) }, + language: 'en', + // prompt: 'HELLO WORLD', // onNewSegments: (segments) => { // console.log('New segments:', segments) // },