Skip to content

Commit

Permalink
feat: support multiple lora files & dynamic apply / remove lora (#92)
Browse files Browse the repository at this point in the history
* feat: support multi lora params

* feat: lora apply / remove / list for initialized context

* feat(ts): add methods

* fix(ts): lora list path

* fix: remove removePrevious

* fix: use llama->applyLoraAdapters on init

* feat(example): add lora comments

* fix(android): push map

* fix(example): getLoadedLoraAdapters usage

* fix(cpp): apply empty list instead of expose new fn

* fix(ios): removeLoraAdapters

* feat: check context is predicting
  • Loading branch information
jhen0409 authored Nov 21, 2024
1 parent f408643 commit 9b6ea9f
Show file tree
Hide file tree
Showing 18 changed files with 608 additions and 44 deletions.
1 change: 1 addition & 0 deletions android/src/main/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ set(
${RNLLAMA_LIB_DIR}/sgemm.cpp
${RNLLAMA_LIB_DIR}/ggml-aarch64.c
${RNLLAMA_LIB_DIR}/rn-llama.hpp
${CMAKE_SOURCE_DIR}/jni-utils.h
${CMAKE_SOURCE_DIR}/jni.cpp
)

Expand Down
24 changes: 23 additions & 1 deletion android/src/main/java/com/rnllama/LlamaContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa
params.hasKey("lora") ? params.getString("lora") : "",
// float lora_scaled,
params.hasKey("lora_scaled") ? (float) params.getDouble("lora_scaled") : 1.0f,
// ReadableArray lora_adapters,
params.hasKey("lora_list") ? params.getArray("lora_list") : null,
// float rope_freq_base,
params.hasKey("rope_freq_base") ? (float) params.getDouble("rope_freq_base") : 0.0f,
// float rope_freq_scale
Expand Down Expand Up @@ -301,6 +303,22 @@ public String bench(int pp, int tg, int pl, int nr) {
return bench(this.context, pp, tg, pl, nr);
}

public int applyLoraAdapters(ReadableArray loraAdapters) {
int result = applyLoraAdapters(this.context, loraAdapters);
if (result != 0) {
throw new IllegalStateException("Failed to apply lora adapters");
}
return result;
}

public void removeLoraAdapters() {
removeLoraAdapters(this.context);
}

public WritableArray getLoadedLoraAdapters() {
return getLoadedLoraAdapters(this.context);
}

public void release() {
freeContext(context);
}
Expand Down Expand Up @@ -406,6 +424,7 @@ protected static native long initContext(
boolean vocab_only,
String lora,
float lora_scaled,
ReadableArray lora_list,
float rope_freq_base,
float rope_freq_scale,
int pooling_type,
Expand Down Expand Up @@ -457,7 +476,7 @@ protected static native WritableMap doCompletion(
double[][] logit_bias,
float dry_multiplier,
float dry_base,
int dry_allowed_length,
int dry_allowed_length,
int dry_penalty_last_n,
String[] dry_sequence_breakers,
PartialCompletionCallback partial_completion_callback
Expand All @@ -473,6 +492,9 @@ protected static native WritableMap embedding(
int embd_normalize
);
protected static native String bench(long contextPtr, int pp, int tg, int pl, int nr);
protected static native int applyLoraAdapters(long contextPtr, ReadableArray loraAdapters);
protected static native void removeLoraAdapters(long contextPtr);
protected static native WritableArray getLoadedLoraAdapters(long contextPtr);
protected static native void freeContext(long contextPtr);
protected static native void logToAndroid();
}
98 changes: 98 additions & 0 deletions android/src/main/java/com/rnllama/RNLlama.java
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,104 @@ protected void onPostExecute(String result) {
tasks.put(task, "bench-" + contextId);
}

public void applyLoraAdapters(double id, final ReadableArray loraAdapters, final Promise promise) {
final int contextId = (int) id;
AsyncTask task = new AsyncTask<Void, Void, Void>() {
private Exception exception;

@Override
protected Void doInBackground(Void... voids) {
try {
LlamaContext context = contexts.get(contextId);
if (context == null) {
throw new Exception("Context not found");
}
if (context.isPredicting()) {
throw new Exception("Context is busy");
}
context.applyLoraAdapters(loraAdapters);
} catch (Exception e) {
exception = e;
}
return null;
}

@Override
protected void onPostExecute(Void result) {
if (exception != null) {
promise.reject(exception);
return;
}
}
}.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
tasks.put(task, "applyLoraAdapters-" + contextId);
}

public void removeLoraAdapters(double id, final Promise promise) {
final int contextId = (int) id;
AsyncTask task = new AsyncTask<Void, Void, Void>() {
private Exception exception;

@Override
protected Void doInBackground(Void... voids) {
try {
LlamaContext context = contexts.get(contextId);
if (context == null) {
throw new Exception("Context not found");
}
if (context.isPredicting()) {
throw new Exception("Context is busy");
}
context.removeLoraAdapters();
} catch (Exception e) {
exception = e;
}
return null;
}

@Override
protected void onPostExecute(Void result) {
if (exception != null) {
promise.reject(exception);
return;
}
promise.resolve(null);
}
}.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
tasks.put(task, "removeLoraAdapters-" + contextId);
}

public void getLoadedLoraAdapters(double id, final Promise promise) {
final int contextId = (int) id;
AsyncTask task = new AsyncTask<Void, Void, ReadableArray>() {
private Exception exception;

@Override
protected ReadableArray doInBackground(Void... voids) {
try {
LlamaContext context = contexts.get(contextId);
if (context == null) {
throw new Exception("Context not found");
}
return context.getLoadedLoraAdapters();
} catch (Exception e) {
exception = e;
}
return null;
}

@Override
protected void onPostExecute(ReadableArray result) {
if (exception != null) {
promise.reject(exception);
return;
}
promise.resolve(result);
}
}.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
tasks.put(task, "getLoadedLoraAdapters-" + contextId);
}

public void releaseContext(double id, Promise promise) {
final int contextId = (int) id;
AsyncTask task = new AsyncTask<Void, Void, Void>() {
Expand Down
94 changes: 94 additions & 0 deletions android/src/main/jni-utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#include <jni.h>

// ReadableMap utils

namespace readablearray {

int size(JNIEnv *env, jobject readableArray) {
jclass arrayClass = env->GetObjectClass(readableArray);
jmethodID sizeMethod = env->GetMethodID(arrayClass, "size", "()I");
return env->CallIntMethod(readableArray, sizeMethod);
}

jobject getMap(JNIEnv *env, jobject readableArray, int index) {
jclass arrayClass = env->GetObjectClass(readableArray);
jmethodID getMapMethod = env->GetMethodID(arrayClass, "getMap", "(I)Lcom/facebook/react/bridge/ReadableMap;");
return env->CallObjectMethod(readableArray, getMapMethod, index);
}

// Other methods not used yet

}

namespace readablemap {

bool hasKey(JNIEnv *env, jobject readableMap, const char *key) {
jclass mapClass = env->GetObjectClass(readableMap);
jmethodID hasKeyMethod = env->GetMethodID(mapClass, "hasKey", "(Ljava/lang/String;)Z");
jstring jKey = env->NewStringUTF(key);
jboolean result = env->CallBooleanMethod(readableMap, hasKeyMethod, jKey);
env->DeleteLocalRef(jKey);
return result;
}

int getInt(JNIEnv *env, jobject readableMap, const char *key, jint defaultValue) {
if (!hasKey(env, readableMap, key)) {
return defaultValue;
}
jclass mapClass = env->GetObjectClass(readableMap);
jmethodID getIntMethod = env->GetMethodID(mapClass, "getInt", "(Ljava/lang/String;)I");
jstring jKey = env->NewStringUTF(key);
jint result = env->CallIntMethod(readableMap, getIntMethod, jKey);
env->DeleteLocalRef(jKey);
return result;
}

bool getBool(JNIEnv *env, jobject readableMap, const char *key, jboolean defaultValue) {
if (!hasKey(env, readableMap, key)) {
return defaultValue;
}
jclass mapClass = env->GetObjectClass(readableMap);
jmethodID getBoolMethod = env->GetMethodID(mapClass, "getBoolean", "(Ljava/lang/String;)Z");
jstring jKey = env->NewStringUTF(key);
jboolean result = env->CallBooleanMethod(readableMap, getBoolMethod, jKey);
env->DeleteLocalRef(jKey);
return result;
}

long getLong(JNIEnv *env, jobject readableMap, const char *key, jlong defaultValue) {
if (!hasKey(env, readableMap, key)) {
return defaultValue;
}
jclass mapClass = env->GetObjectClass(readableMap);
jmethodID getLongMethod = env->GetMethodID(mapClass, "getLong", "(Ljava/lang/String;)J");
jstring jKey = env->NewStringUTF(key);
jlong result = env->CallLongMethod(readableMap, getLongMethod, jKey);
env->DeleteLocalRef(jKey);
return result;
}

float getFloat(JNIEnv *env, jobject readableMap, const char *key, jfloat defaultValue) {
if (!hasKey(env, readableMap, key)) {
return defaultValue;
}
jclass mapClass = env->GetObjectClass(readableMap);
jmethodID getFloatMethod = env->GetMethodID(mapClass, "getDouble", "(Ljava/lang/String;)D");
jstring jKey = env->NewStringUTF(key);
jfloat result = env->CallDoubleMethod(readableMap, getFloatMethod, jKey);
env->DeleteLocalRef(jKey);
return result;
}

jstring getString(JNIEnv *env, jobject readableMap, const char *key, jstring defaultValue) {
if (!hasKey(env, readableMap, key)) {
return defaultValue;
}
jclass mapClass = env->GetObjectClass(readableMap);
jmethodID getStringMethod = env->GetMethodID(mapClass, "getString", "(Ljava/lang/String;)Ljava/lang/String;");
jstring jKey = env->NewStringUTF(key);
jstring result = (jstring) env->CallObjectMethod(readableMap, getStringMethod, jKey);
env->DeleteLocalRef(jKey);
return result;
}

}
Loading

0 comments on commit 9b6ea9f

Please sign in to comment.