Skip to content

Commit

Permalink
feat(android): refactor fullTranscribe jni method
Browse files Browse the repository at this point in the history
  • Loading branch information
jhen0409 committed Dec 4, 2023
1 parent f1b290b commit c5beca9
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 104 deletions.
50 changes: 3 additions & 47 deletions android/src/main/java/com/rnwhisper/WhisperContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -404,38 +404,8 @@ private int full(int jobId, ReadableMap options, float[] audioData, int audioDat
audioData,
// jint audio_data_len,
audioDataLen,
// jint n_threads,
options.hasKey("maxThreads") ? options.getInt("maxThreads") : -1,
// jint max_context,
options.hasKey("maxContext") ? options.getInt("maxContext") : -1,

// jint word_thold,
options.hasKey("wordThold") ? options.getInt("wordThold") : -1,
// jint max_len,
options.hasKey("maxLen") ? options.getInt("maxLen") : -1,
// jboolean token_timestamps,
options.hasKey("tokenTimestamps") ? options.getBoolean("tokenTimestamps") : false,

// jint offset,
options.hasKey("offset") ? options.getInt("offset") : -1,
// jint duration,
options.hasKey("duration") ? options.getInt("duration") : -1,
// jfloat temperature,
options.hasKey("temperature") ? (float) options.getDouble("temperature") : -1.0f,
// jfloat temperature_inc,
options.hasKey("temperatureInc") ? (float) options.getDouble("temperatureInc") : -1.0f,
// jint beam_size,
options.hasKey("beamSize") ? options.getInt("beamSize") : -1,
// jint best_of,
options.hasKey("bestOf") ? options.getInt("bestOf") : -1,
// jboolean speed_up,
options.hasKey("speedUp") ? options.getBoolean("speedUp") : false,
// jboolean translate,
options.hasKey("translate") ? options.getBoolean("translate") : false,
// jstring language,
options.hasKey("language") ? options.getString("language") : "auto",
// jstring prompt
options.hasKey("prompt") ? options.getString("prompt") : null,
// ReadableMap options,
options,
// Callback callback
hasProgressCallback || hasNewSegmentsCallback ? new Callback(this, hasProgressCallback, hasNewSegmentsCallback) : null
);
Expand Down Expand Up @@ -567,21 +537,7 @@ protected static native int fullTranscribe(
long context,
float[] audio_data,
int audio_data_len,
int n_threads,
int max_context,
int word_thold,
int max_len,
boolean token_timestamps,
int offset,
int duration,
float temperature,
float temperature_inc,
int beam_size,
int best_of,
boolean speed_up,
boolean translate,
String language,
String prompt,
ReadableMap options,
Callback Callback
);
protected static native void abortTranscribe(int jobId);
Expand Down
76 changes: 76 additions & 0 deletions android/src/main/jni-utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#include <jni.h>

// ReadableMap utils

namespace readablemap {

jboolean hasKey(JNIEnv *env, jobject readableMap, const char *key) {
jclass mapClass = env->GetObjectClass(readableMap);
jmethodID hasKeyMethod = env->GetMethodID(mapClass, "hasKey", "(Ljava/lang/String;)Z");
jstring jKey = env->NewStringUTF(key);
jboolean result = env->CallBooleanMethod(readableMap, hasKeyMethod, jKey);
env->DeleteLocalRef(jKey);
return result;
}

jint getInt(JNIEnv *env, jobject readableMap, const char *key, jint defaultValue) {
if (!hasKey(env, readableMap, key)) {
return defaultValue;
}
jclass mapClass = env->GetObjectClass(readableMap);
jmethodID getIntMethod = env->GetMethodID(mapClass, "getInt", "(Ljava/lang/String;)I");
jstring jKey = env->NewStringUTF(key);
jint result = env->CallIntMethod(readableMap, getIntMethod, jKey);
env->DeleteLocalRef(jKey);
return result;
}

jboolean getBool(JNIEnv *env, jobject readableMap, const char *key, jboolean defaultValue) {
if (!hasKey(env, readableMap, key)) {
return defaultValue;
}
jclass mapClass = env->GetObjectClass(readableMap);
jmethodID getBoolMethod = env->GetMethodID(mapClass, "getBoolean", "(Ljava/lang/String;)Z");
jstring jKey = env->NewStringUTF(key);
jboolean result = env->CallBooleanMethod(readableMap, getBoolMethod, jKey);
env->DeleteLocalRef(jKey);
return result;
}

jlong getLong(JNIEnv *env, jobject readableMap, const char *key, jlong defaultValue) {
if (!hasKey(env, readableMap, key)) {
return defaultValue;
}
jclass mapClass = env->GetObjectClass(readableMap);
jmethodID getLongMethod = env->GetMethodID(mapClass, "getLong", "(Ljava/lang/String;)J");
jstring jKey = env->NewStringUTF(key);
jlong result = env->CallLongMethod(readableMap, getLongMethod, jKey);
env->DeleteLocalRef(jKey);
return result;
}

jfloat getFloat(JNIEnv *env, jobject readableMap, const char *key, jfloat defaultValue) {
if (!hasKey(env, readableMap, key)) {
return defaultValue;
}
jclass mapClass = env->GetObjectClass(readableMap);
jmethodID getFloatMethod = env->GetMethodID(mapClass, "getDouble", "(Ljava/lang/String;)D");
jstring jKey = env->NewStringUTF(key);
jfloat result = env->CallDoubleMethod(readableMap, getFloatMethod, jKey);
env->DeleteLocalRef(jKey);
return result;
}

jstring getString(JNIEnv *env, jobject readableMap, const char *key, jstring defaultValue) {
if (!hasKey(env, readableMap, key)) {
return defaultValue;
}
jclass mapClass = env->GetObjectClass(readableMap);
jmethodID getStringMethod = env->GetMethodID(mapClass, "getString", "(Ljava/lang/String;)Ljava/lang/String;");
jstring jKey = env->NewStringUTF(key);
jstring result = (jstring) env->CallObjectMethod(readableMap, getStringMethod, jKey);
env->DeleteLocalRef(jKey);
return result;
}

}
93 changes: 36 additions & 57 deletions android/src/main/jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "whisper.h"
#include "rn-whisper.h"
#include "ggml.h"
#include "jni-utils.h"

#define UNUSED(x) (void)(x)
#define TAG "JNI"
Expand Down Expand Up @@ -219,82 +220,59 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
jlong context_ptr,
jfloatArray audio_data,
jint audio_data_len,
jint n_threads,
jint max_context,
int word_thold,
int max_len,
jboolean token_timestamps,
jint offset,
jint duration,
jfloat temperature,
jfloat temperature_inc,
jint beam_size,
jint best_of,
jboolean speed_up,
jboolean translate,
jstring language,
jstring prompt,
jobject transcribe_params,
jobject callback_instance
) {
UNUSED(thiz);
struct whisper_context *context = reinterpret_cast<struct whisper_context *>(context_ptr);
jfloat *audio_data_arr = env->GetFloatArrayElements(audio_data, nullptr);

int max_threads = std::thread::hardware_concurrency();
// Use 2 threads by default on 4-core devices, 4 threads on more cores
int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);

LOGI("About to create params");

struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

if (beam_size > -1) {
params.strategy = WHISPER_SAMPLING_BEAM_SEARCH;
params.beam_search.beam_size = beam_size;
}

params.print_realtime = false;
params.print_progress = false;
params.print_timestamps = false;
params.print_special = false;
params.translate = translate;
const char *language_chars = env->GetStringUTFChars(language, nullptr);
params.language = language_chars;

int max_threads = std::thread::hardware_concurrency();
// Use 2 threads by default on 4-core devices, 4 threads on more cores
int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
int n_threads = readablemap::getInt(env, transcribe_params, "maxThreads", default_n_threads);
params.n_threads = n_threads > 0 ? n_threads : default_n_threads;
params.speed_up = speed_up;
params.translate = readablemap::getBool(env, transcribe_params, "translate", false);
params.speed_up = readablemap::getBool(env, transcribe_params, "speedUp", false);
params.token_timestamps = readablemap::getBool(env, transcribe_params, "tokenTimestamps", false);
params.offset_ms = 0;
params.no_context = true;
params.single_segment = false;

if (max_len > -1) {
params.max_len = max_len;
}
params.token_timestamps = token_timestamps;

if (best_of > -1) {
params.greedy.best_of = best_of;
}
if (max_context > -1) {
params.n_max_text_ctx = max_context;
}
if (offset > -1) {
params.offset_ms = offset;
}
if (duration > -1) {
params.duration_ms = duration;
}
if (word_thold > -1) {
params.thold_pt = word_thold;
}
if (temperature > -1) {
params.temperature = temperature;
}
if (temperature_inc > -1) {
params.temperature_inc = temperature_inc;
}
if (prompt != nullptr) {
params.initial_prompt = env->GetStringUTFChars(prompt, nullptr);
int beam_size = readablemap::getInt(env, transcribe_params, "beamSize", -1);
if (beam_size > -1) {
params.strategy = WHISPER_SAMPLING_BEAM_SEARCH;
params.beam_search.beam_size = beam_size;
}
int best_of = readablemap::getInt(env, transcribe_params, "bestOf", -1);
if (best_of > -1) params.greedy.best_of = best_of;
int max_len = readablemap::getInt(env, transcribe_params, "maxLen", -1);
if (max_len > -1) params.max_len = max_len;
int max_context = readablemap::getInt(env, transcribe_params, "maxContext", -1);
if (max_context > -1) params.n_max_text_ctx = max_context;
int offset = readablemap::getInt(env, transcribe_params, "offset", -1);
if (offset > -1) params.offset_ms = offset;
int duration = readablemap::getInt(env, transcribe_params, "duration", -1);
if (duration > -1) params.duration_ms = duration;
int word_thold = readablemap::getInt(env, transcribe_params, "wordThold", -1);
if (word_thold > -1) params.thold_pt = word_thold;
float temperature = readablemap::getFloat(env, transcribe_params, "temperature", -1);
if (temperature > -1) params.temperature = temperature;
float temperature_inc = readablemap::getFloat(env, transcribe_params, "temperatureInc", -1);
if (temperature_inc > -1) params.temperature_inc = temperature_inc;
jstring prompt = readablemap::getString(env, transcribe_params, "prompt", nullptr);
if (prompt != nullptr) params.initial_prompt = env->GetStringUTFChars(prompt, nullptr);
jstring language = readablemap::getString(env, transcribe_params, "language", nullptr);
if (language != nullptr) params.language = env->GetStringUTFChars(language, nullptr);

// abort handlers
bool* abort_ptr = rn_whisper_assign_abort_map(job_id);
Expand Down Expand Up @@ -344,7 +322,8 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
// whisper_print_timings(context);
}
env->ReleaseFloatArrayElements(audio_data, audio_data_arr, JNI_ABORT);
env->ReleaseStringUTFChars(language, language_chars);
if (language != nullptr) env->ReleaseStringUTFChars(language, params.language);
if (prompt != nullptr) env->ReleaseStringUTFChars(prompt, params.initial_prompt);
if (rn_whisper_transcribe_is_aborted(job_id)) {
code = -999;
}
Expand Down

0 comments on commit c5beca9

Please sign in to comment.