Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(android): refactor fullTranscribe JNI method #165

Merged
merged 1 commit into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 3 additions & 47 deletions android/src/main/java/com/rnwhisper/WhisperContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -404,38 +404,8 @@ private int full(int jobId, ReadableMap options, float[] audioData, int audioDat
audioData,
// jint audio_data_len,
audioDataLen,
// jint n_threads,
options.hasKey("maxThreads") ? options.getInt("maxThreads") : -1,
// jint max_context,
options.hasKey("maxContext") ? options.getInt("maxContext") : -1,

// jint word_thold,
options.hasKey("wordThold") ? options.getInt("wordThold") : -1,
// jint max_len,
options.hasKey("maxLen") ? options.getInt("maxLen") : -1,
// jboolean token_timestamps,
options.hasKey("tokenTimestamps") ? options.getBoolean("tokenTimestamps") : false,

// jint offset,
options.hasKey("offset") ? options.getInt("offset") : -1,
// jint duration,
options.hasKey("duration") ? options.getInt("duration") : -1,
// jfloat temperature,
options.hasKey("temperature") ? (float) options.getDouble("temperature") : -1.0f,
// jfloat temperature_inc,
options.hasKey("temperatureInc") ? (float) options.getDouble("temperatureInc") : -1.0f,
// jint beam_size,
options.hasKey("beamSize") ? options.getInt("beamSize") : -1,
// jint best_of,
options.hasKey("bestOf") ? options.getInt("bestOf") : -1,
// jboolean speed_up,
options.hasKey("speedUp") ? options.getBoolean("speedUp") : false,
// jboolean translate,
options.hasKey("translate") ? options.getBoolean("translate") : false,
// jstring language,
options.hasKey("language") ? options.getString("language") : "auto",
// jstring prompt
options.hasKey("prompt") ? options.getString("prompt") : null,
// ReadableMap options,
options,
// Callback callback
hasProgressCallback || hasNewSegmentsCallback ? new Callback(this, hasProgressCallback, hasNewSegmentsCallback) : null
);
Expand Down Expand Up @@ -567,21 +537,7 @@ protected static native int fullTranscribe(
long context,
float[] audio_data,
int audio_data_len,
int n_threads,
int max_context,
int word_thold,
int max_len,
boolean token_timestamps,
int offset,
int duration,
float temperature,
float temperature_inc,
int beam_size,
int best_of,
boolean speed_up,
boolean translate,
String language,
String prompt,
ReadableMap options,
Callback Callback
);
protected static native void abortTranscribe(int jobId);
Expand Down
76 changes: 76 additions & 0 deletions android/src/main/jni-utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#include <jni.h>

// ReadableMap utils

namespace readablemap {

jboolean hasKey(JNIEnv *env, jobject readableMap, const char *key) {
jclass mapClass = env->GetObjectClass(readableMap);
jmethodID hasKeyMethod = env->GetMethodID(mapClass, "hasKey", "(Ljava/lang/String;)Z");
jstring jKey = env->NewStringUTF(key);
jboolean result = env->CallBooleanMethod(readableMap, hasKeyMethod, jKey);
env->DeleteLocalRef(jKey);
return result;
}

jint getInt(JNIEnv *env, jobject readableMap, const char *key, jint defaultValue) {
if (!hasKey(env, readableMap, key)) {
return defaultValue;
}
jclass mapClass = env->GetObjectClass(readableMap);
jmethodID getIntMethod = env->GetMethodID(mapClass, "getInt", "(Ljava/lang/String;)I");
jstring jKey = env->NewStringUTF(key);
jint result = env->CallIntMethod(readableMap, getIntMethod, jKey);
env->DeleteLocalRef(jKey);
return result;
}

jboolean getBool(JNIEnv *env, jobject readableMap, const char *key, jboolean defaultValue) {
if (!hasKey(env, readableMap, key)) {
return defaultValue;
}
jclass mapClass = env->GetObjectClass(readableMap);
jmethodID getBoolMethod = env->GetMethodID(mapClass, "getBoolean", "(Ljava/lang/String;)Z");
jstring jKey = env->NewStringUTF(key);
jboolean result = env->CallBooleanMethod(readableMap, getBoolMethod, jKey);
env->DeleteLocalRef(jKey);
return result;
}

jlong getLong(JNIEnv *env, jobject readableMap, const char *key, jlong defaultValue) {
if (!hasKey(env, readableMap, key)) {
return defaultValue;
}
jclass mapClass = env->GetObjectClass(readableMap);
jmethodID getLongMethod = env->GetMethodID(mapClass, "getLong", "(Ljava/lang/String;)J");
jstring jKey = env->NewStringUTF(key);
jlong result = env->CallLongMethod(readableMap, getLongMethod, jKey);
env->DeleteLocalRef(jKey);
return result;
}

jfloat getFloat(JNIEnv *env, jobject readableMap, const char *key, jfloat defaultValue) {
if (!hasKey(env, readableMap, key)) {
return defaultValue;
}
jclass mapClass = env->GetObjectClass(readableMap);
jmethodID getFloatMethod = env->GetMethodID(mapClass, "getDouble", "(Ljava/lang/String;)D");
jstring jKey = env->NewStringUTF(key);
jfloat result = env->CallDoubleMethod(readableMap, getFloatMethod, jKey);
env->DeleteLocalRef(jKey);
return result;
}

jstring getString(JNIEnv *env, jobject readableMap, const char *key, jstring defaultValue) {
if (!hasKey(env, readableMap, key)) {
return defaultValue;
}
jclass mapClass = env->GetObjectClass(readableMap);
jmethodID getStringMethod = env->GetMethodID(mapClass, "getString", "(Ljava/lang/String;)Ljava/lang/String;");
jstring jKey = env->NewStringUTF(key);
jstring result = (jstring) env->CallObjectMethod(readableMap, getStringMethod, jKey);
env->DeleteLocalRef(jKey);
return result;
}

}
93 changes: 36 additions & 57 deletions android/src/main/jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "whisper.h"
#include "rn-whisper.h"
#include "ggml.h"
#include "jni-utils.h"

#define UNUSED(x) (void)(x)
#define TAG "JNI"
Expand Down Expand Up @@ -219,82 +220,59 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
jlong context_ptr,
jfloatArray audio_data,
jint audio_data_len,
jint n_threads,
jint max_context,
int word_thold,
int max_len,
jboolean token_timestamps,
jint offset,
jint duration,
jfloat temperature,
jfloat temperature_inc,
jint beam_size,
jint best_of,
jboolean speed_up,
jboolean translate,
jstring language,
jstring prompt,
jobject transcribe_params,
jobject callback_instance
) {
UNUSED(thiz);
struct whisper_context *context = reinterpret_cast<struct whisper_context *>(context_ptr);
jfloat *audio_data_arr = env->GetFloatArrayElements(audio_data, nullptr);

int max_threads = std::thread::hardware_concurrency();
// Use 2 threads by default on 4-core devices, 4 threads on more cores
int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);

LOGI("About to create params");

struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

if (beam_size > -1) {
params.strategy = WHISPER_SAMPLING_BEAM_SEARCH;
params.beam_search.beam_size = beam_size;
}

params.print_realtime = false;
params.print_progress = false;
params.print_timestamps = false;
params.print_special = false;
params.translate = translate;
const char *language_chars = env->GetStringUTFChars(language, nullptr);
params.language = language_chars;

int max_threads = std::thread::hardware_concurrency();
// Use 2 threads by default on 4-core devices, 4 threads on more cores
int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
int n_threads = readablemap::getInt(env, transcribe_params, "maxThreads", default_n_threads);
params.n_threads = n_threads > 0 ? n_threads : default_n_threads;
params.speed_up = speed_up;
params.translate = readablemap::getBool(env, transcribe_params, "translate", false);
params.speed_up = readablemap::getBool(env, transcribe_params, "speedUp", false);
params.token_timestamps = readablemap::getBool(env, transcribe_params, "tokenTimestamps", false);
params.offset_ms = 0;
params.no_context = true;
params.single_segment = false;

if (max_len > -1) {
params.max_len = max_len;
}
params.token_timestamps = token_timestamps;

if (best_of > -1) {
params.greedy.best_of = best_of;
}
if (max_context > -1) {
params.n_max_text_ctx = max_context;
}
if (offset > -1) {
params.offset_ms = offset;
}
if (duration > -1) {
params.duration_ms = duration;
}
if (word_thold > -1) {
params.thold_pt = word_thold;
}
if (temperature > -1) {
params.temperature = temperature;
}
if (temperature_inc > -1) {
params.temperature_inc = temperature_inc;
}
if (prompt != nullptr) {
params.initial_prompt = env->GetStringUTFChars(prompt, nullptr);
int beam_size = readablemap::getInt(env, transcribe_params, "beamSize", -1);
if (beam_size > -1) {
params.strategy = WHISPER_SAMPLING_BEAM_SEARCH;
params.beam_search.beam_size = beam_size;
}
int best_of = readablemap::getInt(env, transcribe_params, "bestOf", -1);
if (best_of > -1) params.greedy.best_of = best_of;
int max_len = readablemap::getInt(env, transcribe_params, "maxLen", -1);
if (max_len > -1) params.max_len = max_len;
int max_context = readablemap::getInt(env, transcribe_params, "maxContext", -1);
if (max_context > -1) params.n_max_text_ctx = max_context;
int offset = readablemap::getInt(env, transcribe_params, "offset", -1);
if (offset > -1) params.offset_ms = offset;
int duration = readablemap::getInt(env, transcribe_params, "duration", -1);
if (duration > -1) params.duration_ms = duration;
int word_thold = readablemap::getInt(env, transcribe_params, "wordThold", -1);
if (word_thold > -1) params.thold_pt = word_thold;
float temperature = readablemap::getFloat(env, transcribe_params, "temperature", -1);
if (temperature > -1) params.temperature = temperature;
float temperature_inc = readablemap::getFloat(env, transcribe_params, "temperatureInc", -1);
if (temperature_inc > -1) params.temperature_inc = temperature_inc;
jstring prompt = readablemap::getString(env, transcribe_params, "prompt", nullptr);
if (prompt != nullptr) params.initial_prompt = env->GetStringUTFChars(prompt, nullptr);
jstring language = readablemap::getString(env, transcribe_params, "language", nullptr);
if (language != nullptr) params.language = env->GetStringUTFChars(language, nullptr);

// abort handlers
bool* abort_ptr = rn_whisper_assign_abort_map(job_id);
Expand Down Expand Up @@ -344,7 +322,8 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
// whisper_print_timings(context);
}
env->ReleaseFloatArrayElements(audio_data, audio_data_arr, JNI_ABORT);
env->ReleaseStringUTFChars(language, language_chars);
if (language != nullptr) env->ReleaseStringUTFChars(language, params.language);
if (prompt != nullptr) env->ReleaseStringUTFChars(prompt, params.initial_prompt);
if (rn_whisper_transcribe_is_aborted(job_id)) {
code = -999;
}
Expand Down
Loading