diff --git a/android/src/main/CMakeLists.txt b/android/src/main/CMakeLists.txt index febe580..d7583aa 100644 --- a/android/src/main/CMakeLists.txt +++ b/android/src/main/CMakeLists.txt @@ -12,6 +12,7 @@ set( ${RNWHISPER_LIB_DIR}/ggml-backend.c ${RNWHISPER_LIB_DIR}/ggml-quants.c ${RNWHISPER_LIB_DIR}/whisper.cpp + ${RNWHISPER_LIB_DIR}/rn-audioutils.cpp ${RNWHISPER_LIB_DIR}/rn-whisper.cpp ${CMAKE_SOURCE_DIR}/jni.cpp ) @@ -33,6 +34,10 @@ function(build_library target_name) target_compile_options(${target_name} PRIVATE -mfpu=neon-vfpv4) endif () + if (${CMAKE_BUILD_TYPE} STREQUAL "Debug") + target_compile_options(${target_name} PRIVATE -DRNWHISPER_ANDROID_ENABLE_LOGGING) + endif () + # NOTE: If you want to debug the native code, you can uncomment if and endif # if (NOT ${CMAKE_BUILD_TYPE} STREQUAL "Debug") diff --git a/android/src/main/java/com/rnwhisper/AudioUtils.java b/android/src/main/java/com/rnwhisper/AudioUtils.java index 4498a79..b6c614d 100644 --- a/android/src/main/java/com/rnwhisper/AudioUtils.java +++ b/android/src/main/java/com/rnwhisper/AudioUtils.java @@ -2,14 +2,10 @@ import android.util.Log; -import java.util.ArrayList; -import java.lang.StringBuilder; import java.io.IOException; import java.io.FileReader; import java.io.ByteArrayOutputStream; import java.io.File; -import java.io.FileOutputStream; -import java.io.DataOutputStream; import java.io.IOException; import java.io.InputStream; import java.nio.ByteBuffer; @@ -19,82 +15,6 @@ public class AudioUtils { private static final String NAME = "RNWhisperAudioUtils"; - private static final int SAMPLE_RATE = 16000; - - private static byte[] shortToByte(short[] shortInts) { - int j = 0; - int length = shortInts.length; - byte[] byteData = new byte[length * 2]; - for (int i = 0; i < length; i++) { - byteData[j++] = (byte) (shortInts[i] >>> 8); - byteData[j++] = (byte) (shortInts[i] >>> 0); - } - return byteData; - } - - public static byte[] concatShortBuffers(ArrayList buffers) { - int totalLength = 0; - for 
(int i = 0; i < buffers.size(); i++) { - totalLength += buffers.get(i).length; - } - byte[] result = new byte[totalLength * 2]; - int offset = 0; - for (int i = 0; i < buffers.size(); i++) { - byte[] bytes = shortToByte(buffers.get(i)); - System.arraycopy(bytes, 0, result, offset, bytes.length); - offset += bytes.length; - } - - return result; - } - - private static byte[] removeTrailingZeros(byte[] audioData) { - int i = audioData.length - 1; - while (i >= 0 && audioData[i] == 0) { - --i; - } - byte[] newData = new byte[i + 1]; - System.arraycopy(audioData, 0, newData, 0, i + 1); - return newData; - } - - public static void saveWavFile(byte[] rawData, String audioOutputFile) throws IOException { - Log.d(NAME, "call saveWavFile"); - rawData = removeTrailingZeros(rawData); - DataOutputStream output = null; - try { - output = new DataOutputStream(new FileOutputStream(audioOutputFile)); - // WAVE header - // see http://ccrma.stanford.edu/courses/422/projects/WaveFormat/ - output.writeBytes("RIFF"); // chunk id - output.writeInt(Integer.reverseBytes(36 + rawData.length)); // chunk size - output.writeBytes("WAVE"); // format - output.writeBytes("fmt "); // subchunk 1 id - output.writeInt(Integer.reverseBytes(16)); // subchunk 1 size - output.writeShort(Short.reverseBytes((short) 1)); // audio format (1 = PCM) - output.writeShort(Short.reverseBytes((short) 1)); // number of channels - output.writeInt(Integer.reverseBytes(SAMPLE_RATE)); // sample rate - output.writeInt(Integer.reverseBytes(SAMPLE_RATE * 2)); // byte rate - output.writeShort(Short.reverseBytes((short) 2)); // block align - output.writeShort(Short.reverseBytes((short) 16)); // bits per sample - output.writeBytes("data"); // subchunk 2 id - output.writeInt(Integer.reverseBytes(rawData.length)); // subchunk 2 size - // Audio data (conversion big endian -> little endian) - short[] shorts = new short[rawData.length / 2]; - ByteBuffer.wrap(rawData).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer().get(shorts); - 
ByteBuffer bytes = ByteBuffer.allocate(shorts.length * 2); - for (short s : shorts) { - bytes.putShort(s); - } - Log.d(NAME, "writing audio file: " + audioOutputFile); - output.write(bytes.array()); - } finally { - if (output != null) { - output.close(); - } - } - } - public static float[] decodeWaveFile(InputStream inputStream) throws IOException { ByteArrayOutputStream baos = new ByteArrayOutputStream(); byte[] buffer = new byte[1024]; diff --git a/android/src/main/java/com/rnwhisper/WhisperContext.java b/android/src/main/java/com/rnwhisper/WhisperContext.java index bf1c7ee..256a148 100644 --- a/android/src/main/java/com/rnwhisper/WhisperContext.java +++ b/android/src/main/java/com/rnwhisper/WhisperContext.java @@ -42,7 +42,6 @@ public class WhisperContext { private AudioRecord recorder = null; private int bufferSize; private int nSamplesTranscribing = 0; - private ArrayList shortBufferSlices; // Remember number of samples in each slice private ArrayList sliceNSamples; // Current buffer slice index @@ -66,7 +65,6 @@ public WhisperContext(int id, ReactApplicationContext reactContext, long context } private void rewind() { - shortBufferSlices = null; sliceNSamples = null; sliceIndex = 0; transcribeSliceIndex = 0; @@ -79,41 +77,14 @@ private void rewind() { fullHandler = null; } - private boolean vad(ReadableMap options, short[] shortBuffer, int nSamples, int n) { - boolean isSpeech = true; - if (!isTranscribing && options.hasKey("useVad") && options.getBoolean("useVad")) { - int vadMs = options.hasKey("vadMs") ? options.getInt("vadMs") : 2000; - if (vadMs < 2000) vadMs = 2000; - int sampleSize = (int) (SAMPLE_RATE * vadMs / 1000); - if (nSamples + n > sampleSize) { - int start = nSamples + n - sampleSize; - float[] audioData = new float[sampleSize]; - for (int i = 0; i < sampleSize; i++) { - audioData[i] = shortBuffer[i + start] / 32768.0f; - } - float vadThold = options.hasKey("vadThold") ? 
(float) options.getDouble("vadThold") : 0.6f; - float vadFreqThold = options.hasKey("vadFreqThold") ? (float) options.getDouble("vadFreqThold") : 0.6f; - isSpeech = vadSimple(audioData, sampleSize, vadThold, vadFreqThold); - } else { - isSpeech = false; - } - } - return isSpeech; + private boolean vad(int sliceIndex, int nSamples, int n) { + if (isTranscribing) return true; + return vadSimple(jobId, sliceIndex, nSamples, n); } - private void finishRealtimeTranscribe(ReadableMap options, WritableMap result) { - String audioOutputPath = options.hasKey("audioOutputPath") ? options.getString("audioOutputPath") : null; - if (audioOutputPath != null) { - // TODO: Append in real time so we don't need to keep all slices & also reduce memory usage - Log.d(NAME, "Begin saving wav file to " + audioOutputPath); - try { - AudioUtils.saveWavFile(AudioUtils.concatShortBuffers(shortBufferSlices), audioOutputPath); - } catch (IOException e) { - Log.e(NAME, "Error saving wav file: " + e.getMessage()); - } - } - + private void finishRealtimeTranscribe(WritableMap result) { emitTranscribeEvent("@RNWhisper_onRealtimeTranscribeEnd", Arguments.createMap()); + finishRealtimeTranscribeJob(jobId, context, sliceNSamples.stream().mapToInt(i -> i).toArray()); } public int startRealtimeTranscribe(int jobId, ReadableMap options) { @@ -135,16 +106,12 @@ public int startRealtimeTranscribe(int jobId, ReadableMap options) { int realtimeAudioSec = options.hasKey("realtimeAudioSec") ? options.getInt("realtimeAudioSec") : 0; final int audioSec = realtimeAudioSec > 0 ? realtimeAudioSec : DEFAULT_MAX_AUDIO_SEC; - int realtimeAudioSliceSec = options.hasKey("realtimeAudioSliceSec") ? options.getInt("realtimeAudioSliceSec") : 0; final int audioSliceSec = realtimeAudioSliceSec > 0 && realtimeAudioSliceSec < audioSec ? realtimeAudioSliceSec : audioSec; - isUseSlices = audioSliceSec < audioSec; - String audioOutputPath = options.hasKey("audioOutputPath") ? 
options.getString("audioOutputPath") : null; + createRealtimeTranscribeJob(jobId, context, options); - shortBufferSlices = new ArrayList(); - shortBufferSlices.add(new short[audioSliceSec * SAMPLE_RATE]); sliceNSamples = new ArrayList(); sliceNSamples.add(0); @@ -175,37 +142,29 @@ public void run() { nSamples == nSamplesTranscribing && sliceIndex == transcribeSliceIndex ) { - finishRealtimeTranscribe(options, Arguments.createMap()); + finishRealtimeTranscribe(Arguments.createMap()); } else if (!isTranscribing) { - short[] shortBuffer = shortBufferSlices.get(sliceIndex); - boolean isSpeech = vad(options, shortBuffer, nSamples, 0); - if (!isSpeech) { - finishRealtimeTranscribe(options, Arguments.createMap()); + if (!vad(sliceIndex, nSamples, 0)) { + finishRealtimeTranscribe(Arguments.createMap()); break; } isTranscribing = true; - fullTranscribeSamples(options, true); + fullTranscribeSamples(true); } break; } // Append to buffer - short[] shortBuffer = shortBufferSlices.get(sliceIndex); if (nSamples + n > audioSliceSec * SAMPLE_RATE) { Log.d(NAME, "next slice"); sliceIndex++; nSamples = 0; - shortBuffer = new short[audioSliceSec * SAMPLE_RATE]; - shortBufferSlices.add(shortBuffer); sliceNSamples.add(0); } + putPcmData(jobId, buffer, sliceIndex, nSamples, n); - for (int i = 0; i < n; i++) { - shortBuffer[nSamples + i] = buffer[i]; - } - - boolean isSpeech = vad(options, shortBuffer, nSamples, n); + boolean isSpeech = vad(sliceIndex, nSamples, n); nSamples += n; sliceNSamples.set(sliceIndex, nSamples); @@ -217,7 +176,7 @@ public void run() { fullHandler = new Thread(new Runnable() { @Override public void run() { - fullTranscribeSamples(options, false); + fullTranscribeSamples(false); } }); fullHandler.start(); @@ -228,7 +187,7 @@ public void run() { } if (!isTranscribing) { - finishRealtimeTranscribe(options, Arguments.createMap()); + finishRealtimeTranscribe(Arguments.createMap()); } if (fullHandler != null) { fullHandler.join(); // Wait for full transcribe to finish 
@@ -246,26 +205,16 @@ public void run() { return state; } - private void fullTranscribeSamples(ReadableMap options, boolean skipCapturingCheck) { + private void fullTranscribeSamples(boolean skipCapturingCheck) { int nSamplesOfIndex = sliceNSamples.get(transcribeSliceIndex); if (!isCapturing && !skipCapturingCheck) return; - short[] shortBuffer = shortBufferSlices.get(transcribeSliceIndex); - int nSamples = sliceNSamples.get(transcribeSliceIndex); - nSamplesTranscribing = nSamplesOfIndex; - - // convert I16 to F32 - float[] nSamplesBuffer32 = new float[nSamplesTranscribing]; - for (int i = 0; i < nSamplesTranscribing; i++) { - nSamplesBuffer32[i] = shortBuffer[i] / 32768.0f; - } - Log.d(NAME, "Start transcribing realtime: " + nSamplesTranscribing); int timeStart = (int) System.currentTimeMillis(); - int code = full(jobId, options, nSamplesBuffer32, nSamplesTranscribing); + int code = fullWithJob(jobId, context, transcribeSliceIndex, nSamplesTranscribing); int timeEnd = (int) System.currentTimeMillis(); int timeRecording = (int) (nSamplesTranscribing / SAMPLE_RATE * 1000); @@ -302,7 +251,7 @@ private void fullTranscribeSamples(ReadableMap options, boolean skipCapturingChe if (isStopped && !continueNeeded) { payload.putBoolean("isCapturing", false); payload.putBoolean("isStoppedByAction", isStoppedByAction); - finishRealtimeTranscribe(options, payload); + finishRealtimeTranscribe(payload); } else if (code == 0) { payload.putBoolean("isCapturing", true); emitTranscribeEvent("@RNWhisper_onRealtimeTranscribe", payload); @@ -313,7 +262,7 @@ private void fullTranscribeSamples(ReadableMap options, boolean skipCapturingChe if (continueNeeded) { // If no more capturing, continue transcribing until all slices are transcribed - fullTranscribeSamples(options, true); + fullTranscribeSamples(true); } else if (isStopped) { // No next, cleanup rewind(); @@ -383,32 +332,30 @@ public WritableMap transcribeInputStream(int jobId, InputStream inputStream, Rea this.jobId = jobId; 
isTranscribing = true; float[] audioData = AudioUtils.decodeWaveFile(inputStream); - int code = full(jobId, options, audioData, audioData.length); - isTranscribing = false; - this.jobId = -1; - if (code != 0 && code != 999) { - throw new Exception("Failed to transcribe the file. Code: " + code); - } - WritableMap result = getTextSegments(0, getTextSegmentCount(context)); - result.putBoolean("isAborted", isStoppedByAction); - return result; - } - private int full(int jobId, ReadableMap options, float[] audioData, int audioDataLen) { boolean hasProgressCallback = options.hasKey("onProgress") && options.getBoolean("onProgress"); boolean hasNewSegmentsCallback = options.hasKey("onNewSegments") && options.getBoolean("onNewSegments"); - return fullTranscribe( + int code = fullWithNewJob( jobId, context, // float[] audio_data, audioData, // jint audio_data_len, - audioDataLen, + audioData.length, // ReadableMap options, options, // Callback callback hasProgressCallback || hasNewSegmentsCallback ? new Callback(this, hasProgressCallback, hasNewSegmentsCallback) : null ); + + isTranscribing = false; + this.jobId = -1; + if (code != 0 && code != 999) { + throw new Exception("Failed to transcribe the file. 
Code: " + code); + } + WritableMap result = getTextSegments(0, getTextSegmentCount(context)); + result.putBoolean("isAborted", isStoppedByAction); + return result; } private WritableMap getTextSegments(int start, int count) { @@ -527,12 +474,13 @@ private static String cpuInfo() { } } - + // JNI methods protected static native long initContext(String modelPath); protected static native long initContextWithAsset(AssetManager assetManager, String modelPath); protected static native long initContextWithInputStream(PushbackInputStream inputStream); - protected static native boolean vadSimple(float[] audio_data, int audio_data_len, float vad_thold, float vad_freq_thold); - protected static native int fullTranscribe( + protected static native void freeContext(long contextPtr); + + protected static native int fullWithNewJob( int job_id, long context, float[] audio_data, @@ -546,5 +494,19 @@ protected static native int fullTranscribe( protected static native String getTextSegment(long context, int index); protected static native int getTextSegmentT0(long context, int index); protected static native int getTextSegmentT1(long context, int index); - protected static native void freeContext(long contextPtr); + + protected static native void createRealtimeTranscribeJob( + int job_id, + long context, + ReadableMap options + ); + protected static native void finishRealtimeTranscribeJob(int job_id, long context, int[] sliceNSamples); + protected static native boolean vadSimple(int job_id, int slice_index, int n_samples, int n); + protected static native void putPcmData(int job_id, short[] buffer, int slice_index, int n_samples, int n); + protected static native int fullWithJob( + int job_id, + long context, + int slice_index, + int n_samples + ); } diff --git a/android/src/main/jni-utils.h b/android/src/main/jni-utils.h index 419ce34..f4cf1a9 100644 --- a/android/src/main/jni-utils.h +++ b/android/src/main/jni-utils.h @@ -4,7 +4,7 @@ namespace readablemap { -jboolean hasKey(JNIEnv 
*env, jobject readableMap, const char *key) { +bool hasKey(JNIEnv *env, jobject readableMap, const char *key) { jclass mapClass = env->GetObjectClass(readableMap); jmethodID hasKeyMethod = env->GetMethodID(mapClass, "hasKey", "(Ljava/lang/String;)Z"); jstring jKey = env->NewStringUTF(key); @@ -13,7 +13,7 @@ jboolean hasKey(JNIEnv *env, jobject readableMap, const char *key) { return result; } -jint getInt(JNIEnv *env, jobject readableMap, const char *key, jint defaultValue) { +int getInt(JNIEnv *env, jobject readableMap, const char *key, jint defaultValue) { if (!hasKey(env, readableMap, key)) { return defaultValue; } @@ -25,7 +25,7 @@ jint getInt(JNIEnv *env, jobject readableMap, const char *key, jint defaultValue return result; } -jboolean getBool(JNIEnv *env, jobject readableMap, const char *key, jboolean defaultValue) { +bool getBool(JNIEnv *env, jobject readableMap, const char *key, jboolean defaultValue) { if (!hasKey(env, readableMap, key)) { return defaultValue; } @@ -37,7 +37,7 @@ jboolean getBool(JNIEnv *env, jobject readableMap, const char *key, jboolean def return result; } -jlong getLong(JNIEnv *env, jobject readableMap, const char *key, jlong defaultValue) { +long getLong(JNIEnv *env, jobject readableMap, const char *key, jlong defaultValue) { if (!hasKey(env, readableMap, key)) { return defaultValue; } @@ -49,7 +49,7 @@ jlong getLong(JNIEnv *env, jobject readableMap, const char *key, jlong defaultVa return result; } -jfloat getFloat(JNIEnv *env, jobject readableMap, const char *key, jfloat defaultValue) { +float getFloat(JNIEnv *env, jobject readableMap, const char *key, jfloat defaultValue) { if (!hasKey(env, readableMap, key)) { return defaultValue; } diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp index 7222c49..360eb3d 100644 --- a/android/src/main/jni.cpp +++ b/android/src/main/jni.cpp @@ -191,49 +191,8 @@ Java_com_rnwhisper_WhisperContext_initContextWithInputStream( return reinterpret_cast(context); } -JNIEXPORT jboolean JNICALL 
-Java_com_rnwhisper_WhisperContext_vadSimple( - JNIEnv *env, - jobject thiz, - jfloatArray audio_data, - jint audio_data_len, - jfloat vad_thold, - jfloat vad_freq_thold -) { - UNUSED(thiz); - - std::vector samples(audio_data_len); - jfloat *audio_data_arr = env->GetFloatArrayElements(audio_data, nullptr); - for (int i = 0; i < audio_data_len; i++) { - samples[i] = audio_data_arr[i]; - } - bool is_speech = rn_whisper_vad_simple(samples, WHISPER_SAMPLE_RATE, 1000, vad_thold, vad_freq_thold, false); - env->ReleaseFloatArrayElements(audio_data, audio_data_arr, JNI_ABORT); - return is_speech; -} - -struct callback_context { - JNIEnv *env; - jobject callback_instance; -}; - -JNIEXPORT jint JNICALL -Java_com_rnwhisper_WhisperContext_fullTranscribe( - JNIEnv *env, - jobject thiz, - jint job_id, - jlong context_ptr, - jfloatArray audio_data, - jint audio_data_len, - jobject transcribe_params, - jobject callback_instance -) { - UNUSED(thiz); - struct whisper_context *context = reinterpret_cast(context_ptr); - jfloat *audio_data_arr = env->GetFloatArrayElements(audio_data, nullptr); - - LOGI("About to create params"); +struct whisper_full_params createFullParams(JNIEnv *env, jobject options) { struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); params.print_realtime = false; @@ -244,53 +203,72 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe( int max_threads = std::thread::hardware_concurrency(); // Use 2 threads by default on 4-core devices, 4 threads on more cores int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads); - int n_threads = readablemap::getInt(env, transcribe_params, "maxThreads", default_n_threads); + int n_threads = readablemap::getInt(env, options, "maxThreads", default_n_threads); params.n_threads = n_threads > 0 ? 
n_threads : default_n_threads; - params.translate = readablemap::getBool(env, transcribe_params, "translate", false); - params.speed_up = readablemap::getBool(env, transcribe_params, "speedUp", false); - params.token_timestamps = readablemap::getBool(env, transcribe_params, "tokenTimestamps", false); + params.translate = readablemap::getBool(env, options, "translate", false); + params.speed_up = readablemap::getBool(env, options, "speedUp", false); + params.token_timestamps = readablemap::getBool(env, options, "tokenTimestamps", false); params.offset_ms = 0; params.no_context = true; params.single_segment = false; - int beam_size = readablemap::getInt(env, transcribe_params, "beamSize", -1); + int beam_size = readablemap::getInt(env, options, "beamSize", -1); if (beam_size > -1) { params.strategy = WHISPER_SAMPLING_BEAM_SEARCH; params.beam_search.beam_size = beam_size; } - int best_of = readablemap::getInt(env, transcribe_params, "bestOf", -1); + int best_of = readablemap::getInt(env, options, "bestOf", -1); if (best_of > -1) params.greedy.best_of = best_of; - int max_len = readablemap::getInt(env, transcribe_params, "maxLen", -1); + int max_len = readablemap::getInt(env, options, "maxLen", -1); if (max_len > -1) params.max_len = max_len; - int max_context = readablemap::getInt(env, transcribe_params, "maxContext", -1); + int max_context = readablemap::getInt(env, options, "maxContext", -1); if (max_context > -1) params.n_max_text_ctx = max_context; - int offset = readablemap::getInt(env, transcribe_params, "offset", -1); + int offset = readablemap::getInt(env, options, "offset", -1); if (offset > -1) params.offset_ms = offset; - int duration = readablemap::getInt(env, transcribe_params, "duration", -1); + int duration = readablemap::getInt(env, options, "duration", -1); if (duration > -1) params.duration_ms = duration; - int word_thold = readablemap::getInt(env, transcribe_params, "wordThold", -1); + int word_thold = readablemap::getInt(env, options, "wordThold", 
-1); if (word_thold > -1) params.thold_pt = word_thold; - float temperature = readablemap::getFloat(env, transcribe_params, "temperature", -1); + float temperature = readablemap::getFloat(env, options, "temperature", -1); if (temperature > -1) params.temperature = temperature; - float temperature_inc = readablemap::getFloat(env, transcribe_params, "temperatureInc", -1); + float temperature_inc = readablemap::getFloat(env, options, "temperatureInc", -1); if (temperature_inc > -1) params.temperature_inc = temperature_inc; - jstring prompt = readablemap::getString(env, transcribe_params, "prompt", nullptr); - if (prompt != nullptr) params.initial_prompt = env->GetStringUTFChars(prompt, nullptr); - jstring language = readablemap::getString(env, transcribe_params, "language", nullptr); - if (language != nullptr) params.language = env->GetStringUTFChars(language, nullptr); - - // abort handlers - bool* abort_ptr = rn_whisper_assign_abort_map(job_id); - params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { - bool is_aborted = *(bool*)user_data; - return !is_aborted; - }; - params.encoder_begin_callback_user_data = abort_ptr; - params.abort_callback = [](void * user_data) { - bool is_aborted = *(bool*)user_data; - return is_aborted; - }; - params.abort_callback_user_data = abort_ptr; + jstring prompt = readablemap::getString(env, options, "prompt", nullptr); + if (prompt != nullptr) { + params.initial_prompt = env->GetStringUTFChars(prompt, nullptr); + env->DeleteLocalRef(prompt); + } + jstring language = readablemap::getString(env, options, "language", nullptr); + if (language != nullptr) { + params.language = env->GetStringUTFChars(language, nullptr); + env->DeleteLocalRef(language); + } + return params; +} + +struct callback_context { + JNIEnv *env; + jobject callback_instance; +}; + +JNIEXPORT jint JNICALL +Java_com_rnwhisper_WhisperContext_fullWithNewJob( + JNIEnv *env, + jobject thiz, + jint job_id, 
+ jlong context_ptr, + jfloatArray audio_data, + jint audio_data_len, + jobject options, + jobject callback_instance +) { + UNUSED(thiz); + struct whisper_context *context = reinterpret_cast(context_ptr); + jfloat *audio_data_arr = env->GetFloatArrayElements(audio_data, nullptr); + + LOGI("About to create params"); + + whisper_full_params params = createFullParams(env, options); if (callback_instance != nullptr) { callback_context *cb_ctx = new callback_context; @@ -318,6 +296,8 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe( params.new_segment_callback_user_data = cb_ctx; } + rnwhisper::job* job = rnwhisper::job_new(job_id, params); + LOGI("About to reset timings"); whisper_reset_timings(context); @@ -327,12 +307,122 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe( // whisper_print_timings(context); } env->ReleaseFloatArrayElements(audio_data, audio_data_arr, JNI_ABORT); - if (language != nullptr) env->ReleaseStringUTFChars(language, params.language); - if (prompt != nullptr) env->ReleaseStringUTFChars(prompt, params.initial_prompt); - if (rn_whisper_transcribe_is_aborted(job_id)) { - code = -999; + + if (job->is_aborted()) code = -999; + rnwhisper::job_remove(job_id); + return code; +} + +JNIEXPORT void JNICALL +Java_com_rnwhisper_WhisperContext_createRealtimeTranscribeJob( + JNIEnv *env, + jobject thiz, + jint job_id, + jlong context_ptr, + jobject options +) { + whisper_full_params params = createFullParams(env, options); + rnwhisper::job* job = rnwhisper::job_new(job_id, params); + rnwhisper::vad_params vad; + vad.use_vad = readablemap::getBool(env, options, "useVad", false); + vad.vad_ms = readablemap::getInt(env, options, "vadMs", 2000); + vad.vad_thold = readablemap::getFloat(env, options, "vadThold", 0.6f); + vad.freq_thold = readablemap::getFloat(env, options, "vadFreqThold", 100.0f); + + jstring audio_output_path = readablemap::getString(env, options, "audioOutputPath", nullptr); + const char* audio_output_path_str = nullptr; + if 
(audio_output_path != nullptr) { + audio_output_path_str = env->GetStringUTFChars(audio_output_path, nullptr); + env->DeleteLocalRef(audio_output_path); + } + job->set_realtime_params( + vad, + readablemap::getInt(env, options, "realtimeAudioSec", 0), + readablemap::getInt(env, options, "realtimeAudioSliceSec", 0), + audio_output_path_str + ); +} + +JNIEXPORT void JNICALL +Java_com_rnwhisper_WhisperContext_finishRealtimeTranscribeJob( + JNIEnv *env, + jobject thiz, + jint job_id, + jlong context_ptr, + jintArray slice_n_samples +) { + UNUSED(env); + UNUSED(thiz); + UNUSED(context_ptr); + + rnwhisper::job *job = rnwhisper::job_get(job_id); + if (job->audio_output_path != nullptr) { + RNWHISPER_LOG_INFO("job->params.language: %s\n", job->params.language); + std::vector slice_n_samples_vec; + jint *slice_n_samples_arr = env->GetIntArrayElements(slice_n_samples, nullptr); + slice_n_samples_vec = std::vector(slice_n_samples_arr, slice_n_samples_arr + env->GetArrayLength(slice_n_samples)); + env->ReleaseIntArrayElements(slice_n_samples, slice_n_samples_arr, JNI_ABORT); + + // TODO: Append in real time so we don't need to keep all slices & also reduce memory usage + rnaudioutils::save_wav_file( + rnaudioutils::concat_short_buffers(job->pcm_slices, slice_n_samples_vec), + job->audio_output_path + ); + } + rnwhisper::job_remove(job_id); +} + +JNIEXPORT jboolean JNICALL +Java_com_rnwhisper_WhisperContext_vadSimple( + JNIEnv *env, + jobject thiz, + jint job_id, + jint slice_index, + jint n_samples, + jint n +) { + UNUSED(thiz); + rnwhisper::job* job = rnwhisper::job_get(job_id); + return job->vad_simple(slice_index, n_samples, n); +} + +JNIEXPORT void JNICALL +Java_com_rnwhisper_WhisperContext_putPcmData( + JNIEnv *env, + jobject thiz, + jint job_id, + jshortArray pcm, + jint slice_index, + jint n_samples, + jint n +) { + UNUSED(thiz); + rnwhisper::job* job = rnwhisper::job_get(job_id); + jshort *pcm_arr = env->GetShortArrayElements(pcm, nullptr); + job->put_pcm_data(pcm_arr, 
slice_index, n_samples, n); + env->ReleaseShortArrayElements(pcm, pcm_arr, JNI_ABORT); +} + +JNIEXPORT jint JNICALL +Java_com_rnwhisper_WhisperContext_fullWithJob( + JNIEnv *env, + jobject thiz, + jint job_id, + jlong context_ptr, + jint slice_index, + jint n_samples +) { + UNUSED(thiz); + struct whisper_context *context = reinterpret_cast(context_ptr); + + rnwhisper::job* job = rnwhisper::job_get(job_id); + float* pcmf32 = job->pcm_slice_to_f32(slice_index, n_samples); + int code = whisper_full(context, job->params, pcmf32, n_samples); + free(pcmf32); + if (code == 0) { + // whisper_print_timings(context); } - rn_whisper_remove_abort_map(job_id); + if (job->is_aborted()) code = -999; return code; } @@ -343,7 +433,8 @@ Java_com_rnwhisper_WhisperContext_abortTranscribe( jint job_id ) { UNUSED(thiz); - rn_whisper_abort_transcribe(job_id); + rnwhisper::job *job = rnwhisper::job_get(job_id); + if (job) job->abort(); } JNIEXPORT void JNICALL @@ -352,7 +443,7 @@ Java_com_rnwhisper_WhisperContext_abortAllTranscribe( jobject thiz ) { UNUSED(thiz); - rn_whisper_abort_all_transcribe(); + rnwhisper::job_abort_all(); } JNIEXPORT jint JNICALL diff --git a/cpp/README.md b/cpp/README.md index c947f95..c0efae8 100644 --- a/cpp/README.md +++ b/cpp/README.md @@ -1,4 +1,4 @@ # Note -- Only `rn-whisper.h` / `rn-whisper.cpp` are the specific files for this project, others are sync from [whisper.cpp](https://github.com/ggerganov/whisper.cpp). +- Only `rn-*` are the specific files for this project, others are sync from [whisper.cpp](https://github.com/ggerganov/whisper.cpp). - We can update the native source by using the [bootstrap](../scripts/bootstrap.sh) script. 
diff --git a/cpp/rn-audioutils.cpp b/cpp/rn-audioutils.cpp new file mode 100644 index 0000000..292a704 --- /dev/null +++ b/cpp/rn-audioutils.cpp @@ -0,0 +1,68 @@ +#include "rn-audioutils.h" +#include "rn-whisper-log.h" + +namespace rnaudioutils { + +std::vector concat_short_buffers(const std::vector& buffers, const std::vector& slice_n_samples) { + std::vector output_data; + + for (size_t i = 0; i < buffers.size(); i++) { + int size = slice_n_samples[i]; // Number of shorts + short* slice = buffers[i]; + + // Copy each short as two bytes + for (int j = 0; j < size; j++) { + output_data.push_back(static_cast(slice[j] & 0xFF)); // Lower byte + output_data.push_back(static_cast((slice[j] >> 8) & 0xFF)); // Higher byte + } + } + + return output_data; +} + +std::vector remove_trailing_zeros(const std::vector& audio_data) { + auto last = std::find_if(audio_data.rbegin(), audio_data.rend(), [](uint8_t byte) { return byte != 0; }); + return std::vector(audio_data.begin(), last.base()); +} + +void save_wav_file(const std::vector& raw, const std::string& file) { + std::vector data = remove_trailing_zeros(raw); + + std::ofstream output(file, std::ios::binary); + + if (!output.is_open()) { + RNWHISPER_LOG_ERROR("Failed to open file for writing: %s\n", file.c_str()); + return; + } + + // WAVE header + output.write("RIFF", 4); + int32_t chunk_size = 36 + static_cast(data.size()); + output.write(reinterpret_cast(&chunk_size), sizeof(chunk_size)); + output.write("WAVE", 4); + output.write("fmt ", 4); + int32_t sub_chunk_size = 16; + output.write(reinterpret_cast(&sub_chunk_size), sizeof(sub_chunk_size)); + short audio_format = 1; + output.write(reinterpret_cast(&audio_format), sizeof(audio_format)); + short num_channels = 1; + output.write(reinterpret_cast(&num_channels), sizeof(num_channels)); + int32_t sample_rate = WHISPER_SAMPLE_RATE; + output.write(reinterpret_cast(&sample_rate), sizeof(sample_rate)); + int32_t byte_rate = WHISPER_SAMPLE_RATE * 2; + 
output.write(reinterpret_cast(&byte_rate), sizeof(byte_rate)); + short block_align = 2; + output.write(reinterpret_cast(&block_align), sizeof(block_align)); + short bits_per_sample = 16; + output.write(reinterpret_cast(&bits_per_sample), sizeof(bits_per_sample)); + output.write("data", 4); + int32_t sub_chunk2_size = static_cast(data.size()); + output.write(reinterpret_cast(&sub_chunk2_size), sizeof(sub_chunk2_size)); + output.write(reinterpret_cast(data.data()), data.size()); + + output.close(); + + RNWHISPER_LOG_INFO("Saved audio file: %s\n", file.c_str()); +} + +} // namespace rnaudioutils diff --git a/cpp/rn-audioutils.h b/cpp/rn-audioutils.h new file mode 100644 index 0000000..9e49976 --- /dev/null +++ b/cpp/rn-audioutils.h @@ -0,0 +1,14 @@ +#include +#include +#include +#include +#include +#include +#include "whisper.h" + +namespace rnaudioutils { + +std::vector concat_short_buffers(const std::vector& buffers, const std::vector& slice_n_samples); +void save_wav_file(const std::vector& raw, const std::string& file); + +} // namespace rnaudioutils diff --git a/cpp/rn-whisper-log.h b/cpp/rn-whisper-log.h new file mode 100644 index 0000000..61858f2 --- /dev/null +++ b/cpp/rn-whisper-log.h @@ -0,0 +1,11 @@ +#if defined(__ANDROID__) && defined(RNWHISPER_ANDROID_ENABLE_LOGGING) +#include +#define RNWHISPER_ANDROID_TAG "RNWHISPER_LOG_ANDROID" +#define RNWHISPER_LOG_INFO(...) __android_log_print(ANDROID_LOG_INFO , RNWHISPER_ANDROID_TAG, __VA_ARGS__) +#define RNWHISPER_LOG_WARN(...) __android_log_print(ANDROID_LOG_WARN , RNWHISPER_ANDROID_TAG, __VA_ARGS__) +#define RNWHISPER_LOG_ERROR(...) __android_log_print(ANDROID_LOG_ERROR, RNWHISPER_ANDROID_TAG, __VA_ARGS__) +#else +#define RNWHISPER_LOG_INFO(...) fprintf(stderr, __VA_ARGS__) +#define RNWHISPER_LOG_WARN(...) fprintf(stderr, __VA_ARGS__) +#define RNWHISPER_LOG_ERROR(...) 
fprintf(stderr, __VA_ARGS__) +#endif // __ANDROID__ diff --git a/cpp/rn-whisper.cpp b/cpp/rn-whisper.cpp index c27a491..31c549b 100644 --- a/cpp/rn-whisper.cpp +++ b/cpp/rn-whisper.cpp @@ -2,41 +2,11 @@ #include <math.h> #include <vector> #include <unordered_map> -#include "whisper.h" +#include "rn-whisper.h" -extern "C" { +#define DEFAULT_MAX_AUDIO_SEC 30 -std::unordered_map<int, bool> abort_map; - -bool* rn_whisper_assign_abort_map(int job_id) { - abort_map[job_id] = false; - return &abort_map[job_id]; -} - -void rn_whisper_remove_abort_map(int job_id) { - if (abort_map.find(job_id) != abort_map.end()) { - abort_map.erase(job_id); - } -} - -void rn_whisper_abort_transcribe(int job_id) { - if (abort_map.find(job_id) != abort_map.end()) { - abort_map[job_id] = true; - } -} - -bool rn_whisper_transcribe_is_aborted(int job_id) { - if (abort_map.find(job_id) != abort_map.end()) { - return abort_map[job_id]; - } - return false; -} - -void rn_whisper_abort_all_transcribe() { - for (auto it = abort_map.begin(); it != abort_map.end(); ++it) { - it->second = true; - } -} +namespace rnwhisper { void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) { const float rc = 1.0f / (2.0f * M_PI * cutoff); @@ -51,7 +21,7 @@ } -bool rn_whisper_vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) { +bool vad_simple_impl(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) { const int n_samples = pcmf32.size(); const int n_samples_last = (sample_rate * last_ms) / 1000; @@ -79,7 +49,7 @@ bool rn_whisper_vad_simple(std::vector<float> & pcmf32, int sample_rate, int las energy_last /= n_samples_last; if (verbose) { - fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold); + RNWHISPER_LOG_INFO("%s: energy_all: %f, energy_last: %f, vad_thold: %f, 
freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold); } if (energy_last > vad_thold*energy_all) { @@ -89,4 +59,116 @@ bool rn_whisper_vad_simple(std::vector<float> & pcmf32, int sample_rate, int las return true; } +void job::set_realtime_params( + vad_params params, + int sec, + int slice_sec, + const char* output_path +) { + vad = params; + if (vad.vad_ms < 2000) vad.vad_ms = 2000; + audio_sec = sec > 0 ? sec : DEFAULT_MAX_AUDIO_SEC; + audio_slice_sec = slice_sec > 0 && slice_sec < audio_sec ? slice_sec : audio_sec; + audio_output_path = output_path; +} + +bool job::vad_simple(int slice_index, int n_samples, int n) { + if (!vad.use_vad) return true; + + short* pcm = pcm_slices[slice_index]; + int sample_size = (int) (WHISPER_SAMPLE_RATE * vad.vad_ms / 1000); + if (n_samples + n > sample_size) { + int start = n_samples + n - sample_size; + std::vector<float> pcmf32(sample_size); + for (int i = 0; i < sample_size; i++) { + pcmf32[i] = (float)pcm[i + start] / 32768.0f; + } + return vad_simple_impl(pcmf32, WHISPER_SAMPLE_RATE, vad.last_ms, vad.vad_thold, vad.freq_thold, vad.verbose); + } + return false; +} + +void job::put_pcm_data(short* data, int slice_index, int n_samples, int n) { + if (pcm_slices.size() == slice_index) { + int n_slices = (int) (WHISPER_SAMPLE_RATE * audio_slice_sec); + pcm_slices.push_back(new short[n_slices]); + } + short* pcm = pcm_slices[slice_index]; + for (int i = 0; i < n; i++) { + pcm[i + n_samples] = data[i]; + } +} + +float* job::pcm_slice_to_f32(int slice_index, int size) { + if (pcm_slices.size() > slice_index) { + float* pcmf32 = new float[size]; + for (int i = 0; i < size; i++) { + pcmf32[i] = (float)pcm_slices[slice_index][i] / 32768.0f; + } + return pcmf32; + } + return nullptr; +} + +bool job::is_aborted() { + return aborted; +} + +void job::abort() { + aborted = true; +} + +job::~job() { + RNWHISPER_LOG_INFO("rnwhisper::job::%s: job_id: %d\n", __func__, job_id); + + for (size_t i = 0; i < pcm_slices.size(); i++) { + 
delete[] pcm_slices[i]; + } + pcm_slices.clear(); +} + +std::unordered_map<int, job*> job_map; + +void job_abort_all() { + for (auto it = job_map.begin(); it != job_map.end(); ++it) { + it->second->abort(); + } +} + +job* job_new(int job_id, struct whisper_full_params params) { + job* ctx = new job(); + ctx->job_id = job_id; + ctx->params = params; + + job_map[job_id] = ctx; + + // Abort handler + params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { + job *j = (job*)user_data; + return !j->is_aborted(); + }; + params.encoder_begin_callback_user_data = job_map[job_id]; + params.abort_callback = [](void * user_data) { + job *j = (job*)user_data; + return j->is_aborted(); + }; + params.abort_callback_user_data = job_map[job_id]; + + return job_map[job_id]; +} + +job* job_get(int job_id) { + if (job_map.find(job_id) != job_map.end()) { + return job_map[job_id]; + } + return nullptr; +} + +void job_remove(int job_id) { + if (job_map.find(job_id) != job_map.end()) { + delete job_map[job_id]; + } + job_map.erase(job_id); +} + } diff --git a/cpp/rn-whisper.h b/cpp/rn-whisper.h index 4f65158..5daa90c 100644 --- a/cpp/rn-whisper.h +++ b/cpp/rn-whisper.h @@ -1,17 +1,49 @@ +#ifndef RNWHISPER_H +#define RNWHISPER_H -#ifdef __cplusplus #include <vector> -#include -extern "C" { -#endif - -bool* rn_whisper_assign_abort_map(int job_id); -void rn_whisper_remove_abort_map(int job_id); -void rn_whisper_abort_transcribe(int job_id); -bool rn_whisper_transcribe_is_aborted(int job_id); -void rn_whisper_abort_all_transcribe(); -bool rn_whisper_vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose); - -#ifdef __cplusplus -} -#endif +#include +#include "whisper.h" +#include "rn-whisper-log.h" +#include "rn-audioutils.h" + +namespace rnwhisper { + +struct vad_params { + bool use_vad = false; + float vad_thold = 0.6f; + float freq_thold = 100.0f; + int vad_ms = 2000; + int last_ms = 
1000; + bool verbose = false; +}; + +struct job { + int job_id; + bool aborted = false; + whisper_full_params params; + + ~job(); + bool is_aborted(); + void abort(); + + // Realtime transcription only: + vad_params vad; + int audio_sec = 0; + int audio_slice_sec = 0; + const char* audio_output_path = nullptr; + std::vector<short*> pcm_slices; + void set_realtime_params(vad_params vad, int sec, int slice_sec, const char* output_path); + bool vad_simple(int slice_index, int n_samples, int n); + void put_pcm_data(short* pcm, int slice_index, int n_samples, int n); + float* pcm_slice_to_f32(int slice_index, int size); +}; + +void job_abort_all(); +job* job_new(int job_id, struct whisper_full_params params); +void job_remove(int job_id); +job* job_get(int job_id); + +} // namespace rnwhisper + +#endif // RNWHISPER_H \ No newline at end of file diff --git a/example/src/App.tsx b/example/src/App.tsx index a9224bf..df09f8d 100644 --- a/example/src/App.tsx +++ b/example/src/App.tsx @@ -225,12 +225,13 @@ export default function App() { log('Start transcribing...') const startTime = Date.now() const { stop, promise } = whisperContext.transcribe(sampleFile, { - language: 'en', maxLen: 1, tokenTimestamps: true, onProgress: (cur) => { log(`Transcribing progress: ${cur}%`) }, + language: 'en', + // prompt: 'HELLO WORLD', // onNewSegments: (segments) => { // console.log('New segments:', segments) // }, diff --git a/ios/RNWhisper.mm b/ios/RNWhisper.mm index 6aec9c7..f48a34b 100644 --- a/ios/RNWhisper.mm +++ b/ios/RNWhisper.mm @@ -142,9 +142,9 @@ - (NSArray *)supportedEvents { audioDataCount:count options:options onProgress: ^(int progress) { - if (rn_whisper_transcribe_is_aborted(jobId)) { - return; - } + rnwhisper::job* job = rnwhisper::job_get(jobId); + if (job && job->is_aborted()) return; + dispatch_async(dispatch_get_main_queue(), ^{ [self sendEventWithName:@"@RNWhisper_onTranscribeProgress" body:@{ @@ -156,9 +156,9 @@ - (NSArray *)supportedEvents { }); } onNewSegments: 
^(NSDictionary *result) { - if (rn_whisper_transcribe_is_aborted(jobId)) { - return; - } + rnwhisper::job* job = rnwhisper::job_get(jobId); + if (job && job->is_aborted()) return; + dispatch_async(dispatch_get_main_queue(), ^{ [self sendEventWithName:@"@RNWhisper_onTranscribeNewSegments" body:@{ @@ -279,7 +279,7 @@ - (void)invalidate { [context invalidate]; } - rn_whisper_abort_all_transcribe(); // graceful abort + rnwhisper::job_abort_all(); // graceful abort [contexts removeAllObjects]; contexts = nil; diff --git a/ios/RNWhisperAudioUtils.h b/ios/RNWhisperAudioUtils.h index 628fa4f..a37581d 100644 --- a/ios/RNWhisperAudioUtils.h +++ b/ios/RNWhisperAudioUtils.h @@ -2,8 +2,6 @@ @interface RNWhisperAudioUtils : NSObject -+ (NSData *)concatShortBuffers:(NSMutableArray<NSValue *> *)buffers sliceNSamples:(NSMutableArray<NSNumber *> *)sliceNSamples; -+ (void)saveWavFile:(NSData *)rawData audioOutputFile:(NSString *)audioOutputFile; + (float *)decodeWaveFile:(NSString*)filePath count:(int *)count; @end diff --git a/ios/RNWhisperAudioUtils.m b/ios/RNWhisperAudioUtils.m index a9ed994..334740f 100644 --- a/ios/RNWhisperAudioUtils.m +++ b/ios/RNWhisperAudioUtils.m @@ -3,62 +3,6 @@ @implementation RNWhisperAudioUtils -+ (NSData *)concatShortBuffers:(NSMutableArray<NSValue *> *)buffers sliceNSamples:(NSMutableArray<NSNumber *> *)sliceNSamples { - NSMutableData *outputData = [NSMutableData data]; - for (int i = 0; i < buffers.count; i++) { - int size = [sliceNSamples objectAtIndex:i].intValue; - NSValue *buffer = [buffers objectAtIndex:i]; - short *bufferPtr = buffer.pointerValue; - [outputData appendBytes:bufferPtr length:size * sizeof(short)]; - } - return outputData; -} - -+ (void)saveWavFile:(NSData *)rawData audioOutputFile:(NSString *)audioOutputFile { - NSMutableData *outputData = [NSMutableData data]; - - // WAVE header - [outputData appendData:[@"RIFF" dataUsingEncoding:NSUTF8StringEncoding]]; // chunk id - int chunkSize = CFSwapInt32HostToLittle(36 + rawData.length); - [outputData appendBytes:&chunkSize 
length:sizeof(chunkSize)]; - [outputData appendData:[@"WAVE" dataUsingEncoding:NSUTF8StringEncoding]]; // format - [outputData appendData:[@"fmt " dataUsingEncoding:NSUTF8StringEncoding]]; // subchunk 1 id - - int subchunk1Size = CFSwapInt32HostToLittle(16); - [outputData appendBytes:&subchunk1Size length:sizeof(subchunk1Size)]; - - short audioFormat = CFSwapInt16HostToLittle(1); // PCM - [outputData appendBytes:&audioFormat length:sizeof(audioFormat)]; - - short numChannels = CFSwapInt16HostToLittle(1); // mono - [outputData appendBytes:&numChannels length:sizeof(numChannels)]; - - int sampleRate = CFSwapInt32HostToLittle(WHISPER_SAMPLE_RATE); - [outputData appendBytes:&sampleRate length:sizeof(sampleRate)]; - - // (bitDepth * sampleRate * channels) >> 3 - int byteRate = CFSwapInt32HostToLittle(WHISPER_SAMPLE_RATE * 1 * 16 / 8); - [outputData appendBytes:&byteRate length:sizeof(byteRate)]; - - // (bitDepth * channels) >> 3 - short blockAlign = CFSwapInt16HostToLittle(16 / 8); - [outputData appendBytes:&blockAlign length:sizeof(blockAlign)]; - - // bitDepth - short bitsPerSample = CFSwapInt16HostToLittle(16); - [outputData appendBytes:&bitsPerSample length:sizeof(bitsPerSample)]; - - [outputData appendData:[@"data" dataUsingEncoding:NSUTF8StringEncoding]]; // subchunk 2 id - int subchunk2Size = CFSwapInt32HostToLittle((int)rawData.length); - [outputData appendBytes:&subchunk2Size length:sizeof(subchunk2Size)]; - - // Audio data - [outputData appendData:rawData]; - - // Save to file - [outputData writeToFile:audioOutputFile atomically:YES]; -} - + (float *)decodeWaveFile:(NSString*)filePath count:(int *)count { NSURL *url = [NSURL fileURLWithPath:filePath]; NSData *fileData = [NSData dataWithContentsOfURL:url]; diff --git a/ios/RNWhisperContext.h b/ios/RNWhisperContext.h index 4d6d4ad..a029dfd 100644 --- a/ios/RNWhisperContext.h +++ b/ios/RNWhisperContext.h @@ -11,29 +11,21 @@ typedef struct { __unsafe_unretained id mSelf; - - int jobId; NSDictionary* options; + 
struct rnwhisper::job * job; + bool isTranscribing; bool isRealtime; bool isCapturing; bool isStoppedByAction; - int maxAudioSec; int nSamplesTranscribing; - NSMutableArray<NSValue *> *shortBufferSlices; - NSMutableArray<NSNumber *> *sliceNSamples; + std::vector<int> sliceNSamples; bool isUseSlices; int sliceIndex; int transcribeSliceIndex; - int audioSliceSec; NSString* audioOutputPath; - bool useVad; - int vadMs; - float vadThold; - float vadFreqThold; - AudioQueueRef queue; AudioStreamBasicDescription dataFormat; AudioQueueBufferRef buffers[NUM_BUFFERS]; diff --git a/ios/RNWhisperContext.mm b/ios/RNWhisperContext.mm index cd4e9dd..d7ab52e 100644 --- a/ios/RNWhisperContext.mm +++ b/ios/RNWhisperContext.mm @@ -1,5 +1,4 @@ #import "RNWhisperContext.h" -#import "RNWhisperAudioUtils.h" #import #include @@ -95,7 +94,7 @@ - (dispatch_queue_t)getDispatchQueue { return self->dQueue; } -- (void)prepareRealtime:(NSDictionary *)options { +- (void)prepareRealtime:(int)jobId options:(NSDictionary *)options { self->recordState.options = options; self->recordState.dataFormat.mSampleRate = WHISPER_SAMPLE_RATE; // 16000 @@ -108,74 +107,38 @@ - (void)prepareRealtime:(NSDictionary *)options { self->recordState.dataFormat.mReserved = 0; self->recordState.dataFormat.mFormatFlags = kLinearPCMFormatFlagIsSignedInteger; - int maxAudioSecOpt = options[@"realtimeAudioSec"] != nil ? [options[@"realtimeAudioSec"] intValue] : 0; - int maxAudioSec = maxAudioSecOpt > 0 ? maxAudioSecOpt : DEFAULT_MAX_AUDIO_SEC; - self->recordState.maxAudioSec = maxAudioSec; - - int realtimeAudioSliceSec = options[@"realtimeAudioSliceSec"] != nil ? [options[@"realtimeAudioSliceSec"] intValue] : 0; - int audioSliceSec = realtimeAudioSliceSec > 0 && realtimeAudioSliceSec < maxAudioSec ? realtimeAudioSliceSec : maxAudioSec; - - self->recordState.audioOutputPath = options[@"audioOutputPath"]; - - self->recordState.useVad = options[@"useVad"] != nil ? [options[@"useVad"] boolValue] : false; - self->recordState.vadMs = options[@"vadMs"] != nil ? 
[options[@"vadMs"] intValue] : 2000; - if (self->recordState.vadMs < 2000) self->recordState.vadMs = 2000; - - self->recordState.vadThold = options[@"vadThold"] != nil ? [options[@"vadThold"] floatValue] : 0.6f; - self->recordState.vadFreqThold = options[@"vadFreqThold"] != nil ? [options[@"vadFreqThold"] floatValue] : 100.0f; - - self->recordState.audioSliceSec = audioSliceSec; - self->recordState.isUseSlices = audioSliceSec < maxAudioSec; + self->recordState.isRealtime = true; + self->recordState.isTranscribing = false; + self->recordState.isCapturing = false; + self->recordState.isStoppedByAction = false; self->recordState.sliceIndex = 0; self->recordState.transcribeSliceIndex = 0; self->recordState.nSamplesTranscribing = 0; - [self freeBufferIfNeeded]; - self->recordState.shortBufferSlices = [NSMutableArray new]; - - int16_t *audioBufferI16 = (int16_t *) malloc(audioSliceSec * WHISPER_SAMPLE_RATE * sizeof(int16_t)); - [self->recordState.shortBufferSlices addObject:[NSValue valueWithPointer:audioBufferI16]]; - - self->recordState.sliceNSamples = [NSMutableArray new]; - [self->recordState.sliceNSamples addObject:[NSNumber numberWithInt:0]]; - - self->recordState.isRealtime = true; - self->recordState.isTranscribing = false; - self->recordState.isCapturing = false; - self->recordState.isStoppedByAction = false; + self->recordState.sliceNSamples.push_back(0); + + self->recordState.job = rnwhisper::job_new(jobId, [self createParams:options jobId:jobId]); + self->recordState.job->set_realtime_params( + { + .use_vad = options[@"useVad"] != nil ? [options[@"useVad"] boolValue] : false, + .vad_ms = options[@"vadMs"] != nil ? [options[@"vadMs"] intValue] : 2000, + .vad_thold = options[@"vadThold"] != nil ? [options[@"vadThold"] floatValue] : 0.6f, + .freq_thold = options[@"vadFreqThold"] != nil ? [options[@"vadFreqThold"] floatValue] : 100.0f + }, + options[@"realtimeAudioSec"] != nil ? 
[options[@"realtimeAudioSec"] intValue] : 0, + options[@"realtimeAudioSliceSec"] != nil ? [options[@"realtimeAudioSliceSec"] intValue] : 0, + options[@"audioOutputPath"] != nil ? [options[@"audioOutputPath"] UTF8String] : nullptr + ); + self->recordState.isUseSlices = self->recordState.job->audio_slice_sec < self->recordState.job->audio_sec; self->recordState.mSelf = self; } -- (void)freeBufferIfNeeded { - if (self->recordState.shortBufferSlices != nil) { - for (int i = 0; i < [self->recordState.shortBufferSlices count]; i++) { - int16_t *audioBufferI16 = (int16_t *) [self->recordState.shortBufferSlices[i] pointerValue]; - free(audioBufferI16); - } - self->recordState.shortBufferSlices = nil; - } -} - -bool vad(RNWhisperContextRecordState *state, int16_t* audioBufferI16, int nSamples, int n) +bool vad(RNWhisperContextRecordState *state, int sliceIndex, int nSamples, int n) { - bool isSpeech = true; - if (!state->isTranscribing && state->useVad) { - int sampleSize = (int) (WHISPER_SAMPLE_RATE * state->vadMs / 1000); - if (nSamples + n > sampleSize) { - int start = nSamples + n - sampleSize; - std::vector<float> audioBufferF32Vec(sampleSize); - for (int i = 0; i < sampleSize; i++) { - audioBufferF32Vec[i] = (float)audioBufferI16[i + start] / 32768.0f; - } - isSpeech = rn_whisper_vad_simple(audioBufferF32Vec, WHISPER_SAMPLE_RATE, 1000, state->vadThold, state->vadFreqThold, false); - NSLog(@"[RNWhisper] VAD result: %d", isSpeech); - } else { - isSpeech = false; - } - } - return isSpeech; + if (state->isTranscribing) return true; + return state->job->vad_simple(sliceIndex, nSamples, n); } void AudioInputCallback(void * inUserData, @@ -196,15 +159,15 @@ void AudioInputCallback(void * inUserData, } int totalNSamples = 0; - for (int i = 0; i < [state->sliceNSamples count]; i++) { - totalNSamples += [[state->sliceNSamples objectAtIndex:i] intValue]; + for (int i = 0; i < state->sliceNSamples.size(); i++) { + totalNSamples += state->sliceNSamples[i]; } const int n = 
inBuffer->mAudioDataByteSize / 2; - int nSamples = [state->sliceNSamples[state->sliceIndex] intValue]; + int nSamples = state->sliceNSamples[state->sliceIndex]; - if (totalNSamples + n > state->maxAudioSec * WHISPER_SAMPLE_RATE) { + if (totalNSamples + n > state->job->audio_sec * WHISPER_SAMPLE_RATE) { NSLog(@"[RNWhisper] Audio buffer is full, stop capturing"); state->isCapturing = false; [state->mSelf stopAudio]; @@ -218,8 +181,7 @@ void AudioInputCallback(void * inUserData, !state->isTranscribing && nSamples != state->nSamplesTranscribing ) { - int16_t* audioBufferI16 = (int16_t*) [state->shortBufferSlices[state->sliceIndex] pointerValue]; - if (!vad(state, audioBufferI16, nSamples, 0)) { + if (!vad(state, state->sliceIndex, nSamples, 0)) { [state->mSelf finishRealtimeTranscribe:state result:@{}]; return; } @@ -231,27 +193,20 @@ void AudioInputCallback(void * inUserData, return; } - int audioSliceSec = state->audioSliceSec; - if (nSamples + n > audioSliceSec * WHISPER_SAMPLE_RATE) { + if (nSamples + n > state->job->audio_slice_sec * WHISPER_SAMPLE_RATE) { // next slice state->sliceIndex++; nSamples = 0; - int16_t* audioBufferI16 = (int16_t*) malloc(audioSliceSec * WHISPER_SAMPLE_RATE * sizeof(int16_t)); - [state->shortBufferSlices addObject:[NSValue valueWithPointer:audioBufferI16]]; - [state->sliceNSamples addObject:[NSNumber numberWithInt:0]]; + state->sliceNSamples.push_back(0); } - // Append to buffer NSLog(@"[RNWhisper] Slice %d has %d samples", state->sliceIndex, nSamples); - int16_t* audioBufferI16 = (int16_t*) [state->shortBufferSlices[state->sliceIndex] pointerValue]; - for (int i = 0; i < n; i++) { - audioBufferI16[nSamples + i] = ((short*)inBuffer->mAudioData)[i]; - } + state->job->put_pcm_data((short*) inBuffer->mAudioData, state->sliceIndex, nSamples, n); - bool isSpeech = vad(state, audioBufferI16, nSamples, n); + bool isSpeech = vad(state, state->sliceIndex, nSamples, n); nSamples += n; - state->sliceNSamples[state->sliceIndex] = [NSNumber 
numberWithInt:nSamples]; + state->sliceNSamples[state->sliceIndex] = nSamples; AudioQueueEnqueueBuffer(state->queue, inBuffer, 0, NULL); @@ -267,32 +222,27 @@ void AudioInputCallback(void * inUserData, - (void)finishRealtimeTranscribe:(RNWhisperContextRecordState*) state result:(NSDictionary*)result { // Save wav if needed - if (state->audioOutputPath != nil) { + if (state->job->audio_output_path != nullptr) { // TODO: Append in real time so we don't need to keep all slices & also reduce memory usage - [RNWhisperAudioUtils - saveWavFile:[RNWhisperAudioUtils concatShortBuffers:state->shortBufferSlices - sliceNSamples:state->sliceNSamples] - audioOutputFile:state->audioOutputPath - ]; + rnaudioutils::save_wav_file( + rnaudioutils::concat_short_buffers(state->job->pcm_slices, state->sliceNSamples), + state->job->audio_output_path + ); } - state->transcribeHandler(state->jobId, @"end", result); + state->transcribeHandler(state->job->job_id, @"end", result); + rnwhisper::job_remove(state->job->job_id); } - (void)fullTranscribeSamples:(RNWhisperContextRecordState*) state { - int nSamplesOfIndex = [[state->sliceNSamples objectAtIndex:state->transcribeSliceIndex] intValue]; + int nSamplesOfIndex = state->sliceNSamples[state->transcribeSliceIndex]; state->nSamplesTranscribing = nSamplesOfIndex; NSLog(@"[RNWhisper] Transcribing %d samples", state->nSamplesTranscribing); - int16_t* audioBufferI16 = (int16_t*) [state->shortBufferSlices[state->transcribeSliceIndex] pointerValue]; - float* audioBufferF32 = (float*) malloc(state->nSamplesTranscribing * sizeof(float)); - // convert I16 to F32 - for (int i = 0; i < state->nSamplesTranscribing; i++) { - audioBufferF32[i] = (float)audioBufferI16[i] / 32768.0f; - } + float* pcmf32 = state->job->pcm_slice_to_f32(state->transcribeSliceIndex, state->nSamplesTranscribing); + CFTimeInterval timeStart = CACurrentMediaTime(); - struct whisper_full_params params = [state->mSelf getParams:state->options jobId:state->jobId]; - int code = 
[state->mSelf fullTranscribe:state->jobId params:params audioData:audioBufferF32 audioDataCount:state->nSamplesTranscribing]; - free(audioBufferF32); + int code = [state->mSelf fullTranscribe:state->job audioData:pcmf32 audioDataCount:state->nSamplesTranscribing]; + free(pcmf32); CFTimeInterval timeEnd = CACurrentMediaTime(); const float timeRecording = (float) state->nSamplesTranscribing / (float) state->dataFormat.mSampleRate; @@ -312,7 +262,7 @@ - (void)fullTranscribeSamples:(RNWhisperContextRecordState*) state { result[@"error"] = [NSString stringWithFormat:@"Transcribe failed with code %d", code]; } - nSamplesOfIndex = [[state->sliceNSamples objectAtIndex:state->transcribeSliceIndex] intValue]; + nSamplesOfIndex = state->sliceNSamples[state->transcribeSliceIndex]; bool isStopped = state->isStoppedByAction || ( !state->isCapturing && @@ -340,10 +290,10 @@ - (void)fullTranscribeSamples:(RNWhisperContextRecordState*) state { [state->mSelf finishRealtimeTranscribe:state result:result]; } else if (code == 0) { result[@"isCapturing"] = @(true); - state->transcribeHandler(state->jobId, @"transcribe", result); + state->transcribeHandler(state->job->job_id, @"transcribe", result); } else { result[@"isCapturing"] = @(true); - state->transcribeHandler(state->jobId, @"transcribe", result); + state->transcribeHandler(state->job->job_id, @"transcribe", result); } if (continueNeeded) { @@ -371,8 +321,7 @@ - (OSStatus)transcribeRealtime:(int)jobId onTranscribe:(void (^)(int, NSString *, NSDictionary *))onTranscribe { self->recordState.transcribeHandler = onTranscribe; - self->recordState.jobId = jobId; - [self prepareRealtime:options]; + [self prepareRealtime:jobId options:options]; OSStatus status = AudioQueueNewInput( &self->recordState.dataFormat, @@ -413,9 +362,9 @@ - (void)transcribeFile:(int)jobId dispatch_async(dQueue, ^{ self->recordState.isStoppedByAction = false; self->recordState.isTranscribing = true; - self->recordState.jobId = jobId; - whisper_full_params params 
= [self getParams:options jobId:jobId]; + whisper_full_params params = [self createParams:options jobId:jobId]; + if (options[@"onProgress"] && [options[@"onProgress"] boolValue]) { params.progress_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, int progress, void * user_data) { void (^onProgress)(int) = (__bridge void (^)(int))user_data; @@ -460,8 +409,10 @@ - (void)transcribeFile:(int)jobId }; params.new_segment_callback_user_data = &user_data; } - int code = [self fullTranscribe:jobId params:params audioData:audioData audioDataCount:audioDataCount]; - self->recordState.jobId = -1; + + rnwhisper::job* job = rnwhisper::job_new(jobId, params);; + int code = [self fullTranscribe:job audioData:audioData audioDataCount:audioDataCount]; + rnwhisper::job_remove(jobId); self->recordState.isTranscribing = false; onEnd(code); }); @@ -476,7 +427,7 @@ - (void)stopAudio { } - (void)stopTranscribe:(int)jobId { - rn_whisper_abort_transcribe(jobId); + if (self->recordState.job) self->recordState.job->abort(); if (self->recordState.isRealtime && self->recordState.isCapturing) { [self stopAudio]; if (!self->recordState.isTranscribing) { @@ -490,13 +441,11 @@ - (void)stopTranscribe:(int)jobId { } - (void)stopCurrentTranscribe { - if (!self->recordState.jobId) { - return; - } - [self stopTranscribe:self->recordState.jobId]; + if (self->recordState.job == nullptr) return; + [self stopTranscribe:self->recordState.job->job_id]; } -- (struct whisper_full_params)getParams:(NSDictionary *)options jobId:(int)jobId { +- (struct whisper_full_params)createParams:(NSDictionary *)options jobId:(int)jobId { struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); const int n_threads = options[@"maxThreads"] != nil ? 
@@ -534,7 +483,6 @@ - (struct whisper_full_params)getParams:(NSDictionary *)options jobId:(int)jobId if (options[@"maxContext"] != nil) { params.n_max_text_ctx = [options[@"maxContext"] intValue]; } - if (options[@"offset"] != nil) { params.offset_ms = [options[@"offset"] intValue]; } @@ -550,39 +498,20 @@ - (struct whisper_full_params)getParams:(NSDictionary *)options jobId:(int)jobId if (options[@"temperatureInc"] != nil) { params.temperature_inc = [options[@"temperature_inc"] floatValue]; } - if (options[@"prompt"] != nil) { params.initial_prompt = [options[@"prompt"] UTF8String]; } - // abort handler - bool *abort_ptr = rn_whisper_assign_abort_map(jobId); - params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { - bool is_aborted = *(bool*)user_data; - return !is_aborted; - }; - params.encoder_begin_callback_user_data = abort_ptr; - params.abort_callback = [](void * user_data) { - bool is_aborted = *(bool*)user_data; - return is_aborted; - }; - params.abort_callback_user_data = abort_ptr; - return params; } -- (int)fullTranscribe:(int)jobId - params:(struct whisper_full_params)params +- (int)fullTranscribe:(rnwhisper::job *)job audioData:(float *)audioData audioDataCount:(int)audioDataCount { whisper_reset_timings(self->ctx); - - int code = whisper_full(self->ctx, params, audioData, audioDataCount); - if (rn_whisper_transcribe_is_aborted(jobId)) { - code = -999; - } - rn_whisper_remove_abort_map(jobId); + int code = whisper_full(self->ctx, job->params, audioData, audioDataCount); + if (job && job->is_aborted()) code = -999; // if (code == 0) { // whisper_print_timings(self->ctx); // } @@ -616,7 +545,6 @@ - (NSMutableDictionary *)getTextSegments { - (void)invalidate { [self stopCurrentTranscribe]; whisper_free(self->ctx); - [self freeBufferIfNeeded]; } @end