diff --git a/android/src/main/CMakeLists.txt b/android/src/main/CMakeLists.txt index ed77fa8..640dc4c 100644 --- a/android/src/main/CMakeLists.txt +++ b/android/src/main/CMakeLists.txt @@ -33,6 +33,7 @@ set( ${RNLLAMA_LIB_DIR}/sgemm.cpp ${RNLLAMA_LIB_DIR}/ggml-aarch64.c ${RNLLAMA_LIB_DIR}/rn-llama.hpp + ${CMAKE_SOURCE_DIR}/jni-utils.h ${CMAKE_SOURCE_DIR}/jni.cpp ) diff --git a/android/src/main/java/com/rnllama/LlamaContext.java b/android/src/main/java/com/rnllama/LlamaContext.java index b369bb1..9c14a3a 100644 --- a/android/src/main/java/com/rnllama/LlamaContext.java +++ b/android/src/main/java/com/rnllama/LlamaContext.java @@ -70,6 +70,8 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa params.hasKey("lora") ? params.getString("lora") : "", // float lora_scaled, params.hasKey("lora_scaled") ? (float) params.getDouble("lora_scaled") : 1.0f, + // ReadableArray lora_adapters, + params.hasKey("lora_list") ? params.getArray("lora_list") : null, // float rope_freq_base, params.hasKey("rope_freq_base") ? (float) params.getDouble("rope_freq_base") : 0.0f, // float rope_freq_scale @@ -301,6 +303,22 @@ public String bench(int pp, int tg, int pl, int nr) { return bench(this.context, pp, tg, pl, nr); } + public int applyLoraAdapters(ReadableArray loraAdapters) { + int result = applyLoraAdapters(this.context, loraAdapters); + if (result != 0) { + throw new IllegalStateException("Failed to apply lora adapters"); + } + return result; + } + + public void removeLoraAdapters() { + removeLoraAdapters(this.context); + } + + public WritableArray getLoadedLoraAdapters() { + return getLoadedLoraAdapters(this.context); + } + public void release() { freeContext(context); } @@ -406,6 +424,7 @@ protected static native long initContext( boolean vocab_only, String lora, float lora_scaled, + ReadableArray lora_list, float rope_freq_base, float rope_freq_scale, int pooling_type, @@ -457,7 +476,7 @@ protected static native WritableMap doCompletion( double[][] logit_bias, float dry_multiplier, float dry_base, - int dry_allowed_length, + int dry_allowed_length, int dry_penalty_last_n, String[] dry_sequence_breakers, PartialCompletionCallback partial_completion_callback @@ -473,6 +492,9 @@ protected static native WritableMap embedding( int embd_normalize ); protected static native String bench(long contextPtr, int pp, int tg, int pl, int nr); + protected static native int applyLoraAdapters(long contextPtr, ReadableArray loraAdapters); + protected static native void removeLoraAdapters(long contextPtr); + protected static native WritableArray getLoadedLoraAdapters(long contextPtr); protected static native void freeContext(long contextPtr); protected static native void logToAndroid(); } diff --git a/android/src/main/java/com/rnllama/RNLlama.java b/android/src/main/java/com/rnllama/RNLlama.java index aa19731..ab98a92 100644 --- a/android/src/main/java/com/rnllama/RNLlama.java +++ b/android/src/main/java/com/rnllama/RNLlama.java @@ -413,6 +413,104 @@ protected void onPostExecute(String result) { tasks.put(task, "bench-" + contextId); } + public void applyLoraAdapters(double id, final ReadableArray loraAdapters, final Promise promise) { + final int contextId = (int) id; + AsyncTask task = new AsyncTask() { + private Exception exception; + + @Override + protected Void doInBackground(Void... voids) { + try { + LlamaContext context = contexts.get(contextId); + if (context == null) { + throw new Exception("Context not found"); + } + if (context.isPredicting()) { + throw new Exception("Context is busy"); + } + context.applyLoraAdapters(loraAdapters); + } catch (Exception e) { + exception = e; + } + return null; + } + + @Override + protected void onPostExecute(Void result) { + if (exception != null) { + promise.reject(exception); + return; + } + } + }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR); + tasks.put(task, "applyLoraAdapters-" + contextId); + } + + public void removeLoraAdapters(double id, final Promise promise) { + final int contextId = (int) id; + AsyncTask task = new AsyncTask() { + private Exception exception; + + @Override + protected Void doInBackground(Void... voids) { + try { + LlamaContext context = contexts.get(contextId); + if (context == null) { + throw new Exception("Context not found"); + } + if (context.isPredicting()) { + throw new Exception("Context is busy"); + } + context.removeLoraAdapters(); + } catch (Exception e) { + exception = e; + } + return null; + } + + @Override + protected void onPostExecute(Void result) { + if (exception != null) { + promise.reject(exception); + return; + } + promise.resolve(null); + } + }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR); + tasks.put(task, "removeLoraAdapters-" + contextId); + } + + public void getLoadedLoraAdapters(double id, final Promise promise) { + final int contextId = (int) id; + AsyncTask task = new AsyncTask() { + private Exception exception; + + @Override + protected ReadableArray doInBackground(Void... voids) { + try { + LlamaContext context = contexts.get(contextId); + if (context == null) { + throw new Exception("Context not found"); + } + return context.getLoadedLoraAdapters(); + } catch (Exception e) { + exception = e; + } + return null; + } + + @Override + protected void onPostExecute(ReadableArray result) { + if (exception != null) { + promise.reject(exception); + return; + } + promise.resolve(result); + } + }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR); + tasks.put(task, "getLoadedLoraAdapters-" + contextId); + } + public void releaseContext(double id, Promise promise) { final int contextId = (int) id; AsyncTask task = new AsyncTask() { diff --git a/android/src/main/jni-utils.h b/android/src/main/jni-utils.h new file mode 100644 index 0000000..39bde3e --- /dev/null +++ b/android/src/main/jni-utils.h @@ -0,0 +1,94 @@ +#include + +// ReadableMap utils + +namespace readablearray { + +int size(JNIEnv *env, jobject readableArray) { + jclass arrayClass = env->GetObjectClass(readableArray); + jmethodID sizeMethod = env->GetMethodID(arrayClass, "size", "()I"); + return env->CallIntMethod(readableArray, sizeMethod); +} + +jobject getMap(JNIEnv *env, jobject readableArray, int index) { + jclass arrayClass = env->GetObjectClass(readableArray); + jmethodID getMapMethod = env->GetMethodID(arrayClass, "getMap", "(I)Lcom/facebook/react/bridge/ReadableMap;"); + return env->CallObjectMethod(readableArray, getMapMethod, index); +} + +// Other methods not used yet + +} + +namespace readablemap { + +bool hasKey(JNIEnv *env, jobject readableMap, const char *key) { + jclass mapClass = env->GetObjectClass(readableMap); + jmethodID hasKeyMethod = env->GetMethodID(mapClass, "hasKey", "(Ljava/lang/String;)Z"); + jstring jKey = env->NewStringUTF(key); + jboolean result = env->CallBooleanMethod(readableMap, hasKeyMethod, jKey); + env->DeleteLocalRef(jKey); + return result; +} + +int getInt(JNIEnv *env, jobject readableMap, const char *key, jint defaultValue) { + if (!hasKey(env, readableMap, key)) { + return defaultValue; + } + jclass mapClass = env->GetObjectClass(readableMap); + jmethodID getIntMethod = env->GetMethodID(mapClass, "getInt", "(Ljava/lang/String;)I"); + jstring jKey = env->NewStringUTF(key); + jint result = env->CallIntMethod(readableMap, getIntMethod, jKey); + env->DeleteLocalRef(jKey); + return result; +} + +bool getBool(JNIEnv *env, jobject readableMap, const char *key, jboolean defaultValue) { + if (!hasKey(env, readableMap, key)) { + return defaultValue; + } + jclass mapClass = env->GetObjectClass(readableMap); + jmethodID getBoolMethod = env->GetMethodID(mapClass, "getBoolean", "(Ljava/lang/String;)Z"); + jstring jKey = env->NewStringUTF(key); + jboolean result = env->CallBooleanMethod(readableMap, getBoolMethod, jKey); + env->DeleteLocalRef(jKey); + return result; +} + +long getLong(JNIEnv *env, jobject readableMap, const char *key, jlong defaultValue) { + if (!hasKey(env, readableMap, key)) { + return defaultValue; + } + jclass mapClass = env->GetObjectClass(readableMap); + jmethodID getLongMethod = env->GetMethodID(mapClass, "getLong", "(Ljava/lang/String;)J"); + jstring jKey = env->NewStringUTF(key); + jlong result = env->CallLongMethod(readableMap, getLongMethod, jKey); + env->DeleteLocalRef(jKey); + return result; +} + +float getFloat(JNIEnv *env, jobject readableMap, const char *key, jfloat defaultValue) { + if (!hasKey(env, readableMap, key)) { + return defaultValue; + } + jclass mapClass = env->GetObjectClass(readableMap); + jmethodID getFloatMethod = env->GetMethodID(mapClass, "getDouble", "(Ljava/lang/String;)D"); + jstring jKey = env->NewStringUTF(key); + jfloat result = env->CallDoubleMethod(readableMap, getFloatMethod, jKey); + env->DeleteLocalRef(jKey); + return result; +} + +jstring getString(JNIEnv *env, jobject readableMap, const char *key, jstring defaultValue) { + if (!hasKey(env, readableMap, key)) { + return defaultValue; + } + jclass mapClass = env->GetObjectClass(readableMap); + jmethodID getStringMethod = env->GetMethodID(mapClass, "getString", "(Ljava/lang/String;)Ljava/lang/String;"); + jstring jKey = env->NewStringUTF(key); + jstring result = (jstring) env->CallObjectMethod(readableMap, getStringMethod, jKey); + env->DeleteLocalRef(jKey); + return result; +} + +} diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp index 1ec3d19..d261657 100644 --- a/android/src/main/jni.cpp +++ b/android/src/main/jni.cpp @@ -12,6 +12,7 @@ #include "llama-impl.h" #include "ggml.h" #include "rn-llama.hpp" +#include "jni-utils.h" #define UNUSED(x) (void)(x) #define TAG "RNLLAMA_ANDROID_JNI" @@ -127,7 +128,7 @@ static inline void pushString(JNIEnv *env, jobject arr, const char *value) { // Method to push WritableMap into WritableArray static inline void pushMap(JNIEnv *env, jobject arr, jobject value) { jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray"); - jmethodID pushMapMethod = env->GetMethodID(mapClass, "pushMap", "(Lcom/facebook/react/bridge/WritableMap;)V"); + jmethodID pushMapMethod = env->GetMethodID(mapClass, "pushMap", "(Lcom/facebook/react/bridge/ReadableMap;)V"); env->CallVoidMethod(arr, pushMapMethod, value); } @@ -235,6 +236,7 @@ Java_com_rnllama_LlamaContext_initContext( jboolean vocab_only, jstring lora_str, jfloat lora_scaled, + jobject lora_list, jfloat rope_freq_base, jfloat rope_freq_scale, jint pooling_type, @@ -284,11 +286,6 @@ Java_com_rnllama_LlamaContext_initContext( defaultParams.use_mlock = use_mlock; defaultParams.use_mmap = use_mmap; - const char *lora_chars = env->GetStringUTFChars(lora_str, nullptr); - if (lora_chars != nullptr && lora_chars[0] != '\0') { - defaultParams.lora_adapters.push_back({lora_chars, lora_scaled}); - } - defaultParams.rope_freq_base = rope_freq_base; defaultParams.rope_freq_scale = rope_freq_scale; @@ -322,20 +319,52 @@ Java_com_rnllama_LlamaContext_initContext( bool is_model_loaded = llama->loadModel(defaultParams); env->ReleaseStringUTFChars(model_path_str, model_path_chars); - env->ReleaseStringUTFChars(lora_str, lora_chars); env->ReleaseStringUTFChars(cache_type_k, cache_type_k_chars); env->ReleaseStringUTFChars(cache_type_v, cache_type_v_chars); LOGI("[RNLlama] is_model_loaded %s", (is_model_loaded ? "true" : "false")); if (is_model_loaded) { - if (embedding && llama_model_has_encoder(llama->model) && llama_model_has_decoder(llama->model)) { - LOGI("[RNLlama] computing embeddings in encoder-decoder models is not supported"); - llama_free(llama->ctx); - return -1; - } - context_map[(long) llama->ctx] = llama; + if (embedding && llama_model_has_encoder(llama->model) && llama_model_has_decoder(llama->model)) { + LOGI("[RNLlama] computing embeddings in encoder-decoder models is not supported"); + llama_free(llama->ctx); + return -1; + } + context_map[(long) llama->ctx] = llama; } else { + llama_free(llama->ctx); + } + + std::vector lora_adapters; + const char *lora_chars = env->GetStringUTFChars(lora_str, nullptr); + if (lora_chars != nullptr && lora_chars[0] != '\0') { + common_lora_adapter_info la; + la.path = lora_chars; + la.scale = lora_scaled; + lora_adapters.push_back(la); + } + + if (lora_list != nullptr) { + // lora_adapters: ReadableArray + int lora_list_size = readablearray::size(env, lora_list); + for (int i = 0; i < lora_list_size; i++) { + jobject lora_adapter = readablearray::getMap(env, lora_list, i); + jstring path = readablemap::getString(env, lora_adapter, "path", nullptr); + if (path != nullptr) { + const char *path_chars = env->GetStringUTFChars(path, nullptr); + common_lora_adapter_info la; + la.path = path_chars; + la.scale = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f); + lora_adapters.push_back(la); + env->ReleaseStringUTFChars(path, path_chars); + } + } + } + env->ReleaseStringUTFChars(lora_str, lora_chars); + int result = llama->applyLoraAdapters(lora_adapters); + if (result != 0) { + LOGI("[RNLlama] Failed to apply lora adapters"); llama_free(llama->ctx); + return -1; } return reinterpret_cast(llama->ctx); @@ -537,7 +566,7 @@ Java_com_rnllama_LlamaContext_doCompletion( jobjectArray logit_bias, jfloat dry_multiplier, jfloat dry_base, - jint dry_allowed_length, + jint dry_allowed_length, jint dry_penalty_last_n, jobjectArray dry_sequence_breakers, jobject partial_completion_callback @@ -876,6 +905,56 @@ Java_com_rnllama_LlamaContext_bench( return env->NewStringUTF(result.c_str()); } +JNIEXPORT jint JNICALL +Java_com_rnllama_LlamaContext_applyLoraAdapters( + JNIEnv *env, jobject thiz, jlong context_ptr, jobjectArray loraAdapters) { + UNUSED(thiz); + auto llama = context_map[(long) context_ptr]; + + // lora_adapters: ReadableArray + std::vector lora_adapters; + int lora_adapters_size = readablearray::size(env, loraAdapters); + for (int i = 0; i < lora_adapters_size; i++) { + jobject lora_adapter = readablearray::getMap(env, loraAdapters, i); + jstring path = readablemap::getString(env, lora_adapter, "path", nullptr); + if (path != nullptr) { + const char *path_chars = env->GetStringUTFChars(path, nullptr); + env->ReleaseStringUTFChars(path, path_chars); + float scaled = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f); + common_lora_adapter_info la; + la.path = path_chars; + la.scale = scaled; + lora_adapters.push_back(la); + } + } + return llama->applyLoraAdapters(lora_adapters); +} + +JNIEXPORT void JNICALL +Java_com_rnllama_LlamaContext_removeLoraAdapters( + JNIEnv *env, jobject thiz, jlong context_ptr) { + UNUSED(env); + UNUSED(thiz); + auto llama = context_map[(long) context_ptr]; + llama->removeLoraAdapters(); +} + +JNIEXPORT jobject JNICALL +Java_com_rnllama_LlamaContext_getLoadedLoraAdapters( + JNIEnv *env, jobject thiz, jlong context_ptr) { + UNUSED(thiz); + auto llama = context_map[(long) context_ptr]; + auto loaded_lora_adapters = llama->getLoadedLoraAdapters(); + auto result = createWritableArray(env); + for (common_lora_adapter_container &la : loaded_lora_adapters) { + auto map = createWriteableMap(env); + putString(env, map, "path", la.path.c_str()); + putDouble(env, map, "scaled", la.scale); + pushMap(env, result, map); + } + return result; +} + JNIEXPORT void JNICALL Java_com_rnllama_LlamaContext_freeContext( JNIEnv *env, jobject thiz, jlong context_ptr) { diff --git a/android/src/newarch/java/com/rnllama/RNLlamaModule.java b/android/src/newarch/java/com/rnllama/RNLlamaModule.java index 19077c8..0c154d4 100644 --- a/android/src/newarch/java/com/rnllama/RNLlamaModule.java +++ b/android/src/newarch/java/com/rnllama/RNLlamaModule.java @@ -92,6 +92,21 @@ public void bench(double id, final double pp, final double tg, final double pl, rnllama.bench(id, pp, tg, pl, nr, promise); } + @ReactMethod + public void applyLoraAdapters(double id, final ReadableArray loraAdapters, final Promise promise) { + rnllama.applyLoraAdapters(id, loraAdapters, promise); + } + + @ReactMethod + public void removeLoraAdapters(double id, final Promise promise) { + rnllama.removeLoraAdapters(id, promise); + } + + @ReactMethod + public void getLoadedLoraAdapters(double id, final Promise promise) { + rnllama.getLoadedLoraAdapters(id, promise); + } + @ReactMethod public void releaseContext(double id, Promise promise) { rnllama.releaseContext(id, promise); diff --git a/android/src/oldarch/java/com/rnllama/RNLlamaModule.java b/android/src/oldarch/java/com/rnllama/RNLlamaModule.java index a96bf3a..6da8e8f 100644 --- a/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +++ b/android/src/oldarch/java/com/rnllama/RNLlamaModule.java @@ -93,6 +93,21 @@ public void bench(double id, final double pp, final double tg, final double pl, rnllama.bench(id, pp, tg, pl, nr, promise); } + @ReactMethod + public void applyLoraAdapters(double id, final ReadableArray loraAdapters, final Promise promise) { + rnllama.applyLoraAdapters(id, loraAdapters); + } + + @ReactMethod + public void removeLoraAdapters(double id, final Promise promise) { + rnllama.removeLoraAdapters(id, promise); + } + + @ReactMethod + public void getLoadedLoraAdapters(double id, final Promise promise) { + rnllama.getLoadedLoraAdapters(id, promise); + } + @ReactMethod public void releaseContext(double id, Promise promise) { rnllama.releaseContext(id, promise); diff --git a/cpp/rn-llama.hpp b/cpp/rn-llama.hpp index aa6f9bd..c175744 100644 --- a/cpp/rn-llama.hpp +++ b/cpp/rn-llama.hpp @@ -229,6 +229,8 @@ struct llama_rn_context std::string stopping_word; bool incomplete = false; + std::vector lora_adapters; + ~llama_rn_context() { if (ctx) @@ -723,6 +725,35 @@ struct llama_rn_context std::to_string(tg_std) + std::string("]"); } + + int applyLoraAdapters(std::vector lora_adapters) { + this->lora_adapters.clear(); + auto containers = std::vector(); + for (auto & la : lora_adapters) { + common_lora_adapter_container loaded_la; + loaded_la.path = la.path; + loaded_la.scale = la.scale; + loaded_la.adapter = llama_lora_adapter_init(model, la.path.c_str()); + if (loaded_la.adapter == nullptr) { + LOG_ERROR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str()); + return -1; + } + + this->lora_adapters.push_back(loaded_la); + containers.push_back(loaded_la); + } + common_lora_adapters_apply(ctx, containers); + return 0; + } + + void removeLoraAdapters() { + this->lora_adapters.clear(); + common_lora_adapters_apply(ctx, this->lora_adapters); // apply empty list + } + + std::vector getLoadedLoraAdapters() { + return this->lora_adapters; + } }; } diff --git a/example/ios/.xcode.env.local b/example/ios/.xcode.env.local index 3f04e1e..1fa3749 100644 --- a/example/ios/.xcode.env.local +++ b/example/ios/.xcode.env.local @@ -1 +1 @@ -export NODE_BINARY=/var/folders/4z/1d45cfts3936kdm7v9jl349r0000gn/T/yarn--1731819230881-0.9134550796855203/node +export NODE_BINARY=/var/folders/g8/v75_3l3n23g909mshlzdj4wh0000gn/T/yarn--1731985865125-0.724061577974688/node diff --git a/example/ios/Podfile.lock b/example/ios/Podfile.lock index fa10b28..6b776f8 100644 --- a/example/ios/Podfile.lock +++ b/example/ios/Podfile.lock @@ -8,7 +8,7 @@ PODS: - hermes-engine/Pre-built (= 0.72.3) - hermes-engine/Pre-built (0.72.3) - libevent (2.1.12) - - llama-rn (0.4.0): + - llama-rn (0.4.1): - RCT-Folly - RCTRequired - RCTTypeSafety @@ -1261,7 +1261,7 @@ SPEC CHECKSUMS: glog: 04b94705f318337d7ead9e6d17c019bd9b1f6b1b hermes-engine: 10fbd3f62405c41ea07e71973ea61e1878d07322 libevent: 4049cae6c81cdb3654a443be001fb9bdceff7913 - llama-rn: d935a3e23a8c1bb15ca58578af852c16d608bcaa + llama-rn: 763672c81a2903020663ad432f2051357e1f20ba RCT-Folly: 424b8c9a7a0b9ab2886ffe9c3b041ef628fd4fb1 RCTRequired: a2faf4bad4e438ca37b2040cb8f7799baa065c18 RCTTypeSafety: cb09f3e4747b6d18331a15eb05271de7441ca0b3 diff --git a/example/src/App.tsx b/example/src/App.tsx index 2cb4c59..3849753 100644 --- a/example/src/App.tsx +++ b/example/src/App.tsx @@ -95,7 +95,7 @@ export default function App() { const handleInitContext = async ( file: DocumentPickerResponse, - loraFile?: DocumentPickerResponse, + loraFile: DocumentPickerResponse | null, ) => { await handleReleaseContext() await getModelInfo(file.uri) @@ -108,7 +108,7 @@ export default function App() { n_gpu_layers: Platform.OS === 'ios' ? 99 : 0, // > 0: enable GPU // embedding: true, - lora: loraFile?.uri, + lora_list: loraFile ? [{ path: loraFile.uri, scaled: 1.0 }] : undefined, // Or lora: loraFile?.uri, }, (progress) => { setMessages((msgs) => { @@ -179,6 +179,15 @@ export default function App() { return file } + const pickLora = async () => { + let loraFile + const loraRes = await DocumentPicker.pick({ + type: Platform.OS === 'ios' ? 'public.data' : 'application/octet-stream', + }).catch((e) => console.log('No lora file picked, error: ', e.message)) + if (loraRes?.[0]) loraFile = await copyFileIfNeeded('lora', loraRes[0]) + return loraFile + } + const handlePickModel = async () => { const modelRes = await DocumentPicker.pick({ type: Platform.OS === 'ios' ? 'public.data' : 'application/octet-stream', @@ -186,12 +195,10 @@ export default function App() { if (!modelRes?.[0]) return const modelFile = await copyFileIfNeeded('model', modelRes?.[0]) - let loraFile + let loraFile: any = null // Example: Apply lora adapter (Currently only select one lora file) (Uncomment to use) - // const loraRes = await DocumentPicker.pick({ - // type: Platform.OS === 'ios' ? 'public.data' : 'application/octet-stream', - // }).catch(e => console.log('No lora file picked, error: ', e.message)) - // if (loraRes?.[0]) loraFile = await copyFileIfNeeded('lora', loraRes[0]) + // loraFile = await pickLora() + loraFile = null handleInitContext(modelFile, loraFile) } @@ -278,6 +285,31 @@ export default function App() { addSystemMessage(`Session load failed: ${e.message}`) }) return + case '/lora': + pickLora() + .then((loraFile) => { + if (loraFile) + context.applyLoraAdapters([{ path: loraFile.uri }]) + }) + .then(() => context.getLoadedLoraAdapters()) + .then((loraList) => + addSystemMessage( + `Loaded lora adapters: ${JSON.stringify(loraList)}`, + ), + ) + return + case '/remove-lora': + context.removeLoraAdapters().then(() => { + addSystemMessage('Lora adapters removed!') + }) + return + case '/lora-list': + context.getLoadedLoraAdapters().then((loraList) => { + addSystemMessage( + `Loaded lora adapters: ${JSON.stringify(loraList)}`, + ) + }) + return } } const textMessage: MessageType.Text = { @@ -417,7 +449,7 @@ export default function App() { dry_base: 1.75, dry_allowed_length: 2, dry_penalty_last_n: -1, - dry_sequence_breakers: ["\n", ":", "\"", "*"], + dry_sequence_breakers: ['\n', ':', '"', '*'], mirostat: 0, mirostat_tau: 5, mirostat_eta: 0.1, diff --git a/ios/RNLlama.mm b/ios/RNLlama.mm index 0f61f67..9c6b848 100644 --- a/ios/RNLlama.mm +++ b/ios/RNLlama.mm @@ -271,6 +271,53 @@ - (NSArray *)supportedEvents { } } +RCT_EXPORT_METHOD(applyLoraAdapters:(double)contextId + withLoraAdapters:(NSArray *)loraAdapters + withResolver:(RCTPromiseResolveBlock)resolve + withRejecter:(RCTPromiseRejectBlock)reject) +{ + RNLlamaContext *context = llamaContexts[[NSNumber numberWithDouble:contextId]]; + if (context == nil) { + reject(@"llama_error", @"Context not found", nil); + return; + } + if ([context isPredicting]) { + reject(@"llama_error", @"Context is busy", nil); + return; + } + [context applyLoraAdapters:loraAdapters]; + resolve(nil); +} + +RCT_EXPORT_METHOD(removeLoraAdapters:(double)contextId + withResolver:(RCTPromiseResolveBlock)resolve + withRejecter:(RCTPromiseRejectBlock)reject) +{ + RNLlamaContext *context = llamaContexts[[NSNumber numberWithDouble:contextId]]; + if (context == nil) { + reject(@"llama_error", @"Context not found", nil); + return; + } + if ([context isPredicting]) { + reject(@"llama_error", @"Context is busy", nil); + return; + } + [context removeLoraAdapters]; + resolve(nil); +} + +RCT_EXPORT_METHOD(getLoadedLoraAdapters:(double)contextId + withResolver:(RCTPromiseResolveBlock)resolve + withRejecter:(RCTPromiseRejectBlock)reject) +{ + RNLlamaContext *context = llamaContexts[[NSNumber numberWithDouble:contextId]]; + if (context == nil) { + reject(@"llama_error", @"Context not found", nil); + return; + } + resolve([context getLoadedLoraAdapters]); +} + RCT_EXPORT_METHOD(releaseContext:(double)contextId withResolver:(RCTPromiseResolveBlock)resolve withRejecter:(RCTPromiseRejectBlock)reject) diff --git a/ios/RNLlamaContext.h b/ios/RNLlamaContext.h index 52c4e92..82bcccd 100644 --- a/ios/RNLlamaContext.h +++ b/ios/RNLlamaContext.h @@ -33,7 +33,9 @@ - (NSDictionary *)loadSession:(NSString *)path; - (int)saveSession:(NSString *)path size:(int)size; - (NSString *)bench:(int)pp tg:(int)tg pl:(int)pl nr:(int)nr; - +- (void)applyLoraAdapters:(NSArray *)loraAdapters; +- (void)removeLoraAdapters; +- (NSArray *)getLoadedLoraAdapters; - (void)invalidate; @end diff --git a/ios/RNLlamaContext.mm b/ios/RNLlamaContext.mm index c2cb593..8d36488 100644 --- a/ios/RNLlamaContext.mm +++ b/ios/RNLlamaContext.mm @@ -110,12 +110,6 @@ + (instancetype)initWithParams:(NSDictionary *)params onProgress:(void (^)(unsig } } - if (params[@"lora"]) { - float lora_scaled = 1.0f; - if (params[@"lora_scaled"]) lora_scaled = [params[@"lora_scaled"] floatValue]; - defaultParams.lora_adapters.push_back({[params[@"lora"] UTF8String], lora_scaled}); - } - if (params[@"rope_freq_base"]) defaultParams.rope_freq_base = [params[@"rope_freq_base"] floatValue]; if (params[@"rope_freq_scale"]) defaultParams.rope_freq_scale = [params[@"rope_freq_scale"] floatValue]; @@ -130,6 +124,7 @@ + (instancetype)initWithParams:(NSDictionary *)params onProgress:(void (^)(unsig const int defaultNThreads = nThreads == 4 ? 2 : MIN(4, maxThreads); defaultParams.cpuparams.n_threads = nThreads > 0 ? nThreads : defaultNThreads; + RNLlamaContext *context = [[RNLlamaContext alloc] init]; context->llama = new rnllama::llama_rn_context(); context->llama->is_load_interrupted = false; @@ -159,6 +154,34 @@ + (instancetype)initWithParams:(NSDictionary *)params onProgress:(void (^)(unsig @throw [NSException exceptionWithName:@"LlamaException" reason:@"Embedding is not supported in encoder-decoder models" userInfo:nil]; } + std::vector lora_adapters; + if (params[@"lora"]) { + common_lora_adapter_info la; + la.path = [params[@"lora"] UTF8String]; + la.scale = 1.0f; + if (params[@"lora_scaled"]) la.scale = [params[@"lora_scaled"] floatValue]; + lora_adapters.push_back(la); + } + if (params[@"lora_list"] && [params[@"lora_list"] isKindOfClass:[NSArray class]]) { + NSArray *lora_list = params[@"lora_list"]; + for (NSDictionary *lora_adapter in lora_list) { + NSString *path = lora_adapter[@"path"]; + if (!path) continue; + float scale = [lora_adapter[@"scaled"] floatValue]; + common_lora_adapter_info la; + la.path = [path UTF8String]; + la.scale = scale; + lora_adapters.push_back(la); + } + } + if (lora_adapters.size() > 0) { + int result = context->llama->applyLoraAdapters(lora_adapters); + if (result != 0) { + delete context->llama; + @throw [NSException exceptionWithName:@"LlamaException" reason:@"Failed to apply lora adapters" userInfo:nil]; + } + } + context->is_metal_enabled = isMetalEnabled; context->reason_no_metal = reasonNoMetal; @@ -538,6 +561,36 @@ - (NSString *)bench:(int)pp tg:(int)tg pl:(int)pl nr:(int)nr { return [NSString stringWithUTF8String:llama->bench(pp, tg, pl, nr).c_str()]; } +- (void)applyLoraAdapters:(NSArray *)loraAdapters { + std::vector lora_adapters; + for (NSDictionary *loraAdapter in loraAdapters) { + common_lora_adapter_info la; + la.path = [loraAdapter[@"path"] UTF8String]; + la.scale = [loraAdapter[@"scaled"] doubleValue]; + lora_adapters.push_back(la); + } + int result = llama->applyLoraAdapters(lora_adapters); + if (result != 0) { + @throw [NSException exceptionWithName:@"LlamaException" reason:@"Failed to apply lora adapters" userInfo:nil]; + } +} + +- (void)removeLoraAdapters { + llama->removeLoraAdapters(); +} + +- (NSArray *)getLoadedLoraAdapters { + std::vector loaded_lora_adapters = llama->getLoadedLoraAdapters(); + NSMutableArray *result = [[NSMutableArray alloc] init]; + for (common_lora_adapter_container &la : loaded_lora_adapters) { + [result addObject:@{ + @"path": [NSString stringWithUTF8String:la.path.c_str()], + @"scale": @(la.scale) + }]; + } + return result; +} + - (void)invalidate { delete llama; // llama_backend_free(); diff --git a/scripts/common.cpp.patch b/scripts/common.cpp.patch index 1f4c63c..4cc23b7 100644 --- a/scripts/common.cpp.patch +++ b/scripts/common.cpp.patch @@ -1,5 +1,5 @@ ---- common.cpp.orig 2024-11-17 12:52:58 -+++ common.cpp 2024-11-17 12:48:35 +--- common.cpp.orig 2024-11-21 10:21:53 ++++ common.cpp 2024-11-21 10:22:56 @@ -4,10 +4,6 @@ #include "common.h" diff --git a/scripts/common.h.patch b/scripts/common.h.patch index 9a2b365..26b4253 100644 --- a/scripts/common.h.patch +++ b/scripts/common.h.patch @@ -1,5 +1,5 @@ ---- common.h.orig 2024-11-17 11:56:40 -+++ common.h 2024-11-17 11:56:41 +--- common.h.orig 2024-11-21 10:21:53 ++++ common.h 2024-11-21 10:23:00 @@ -41,6 +41,17 @@ struct common_control_vector_load_info; diff --git a/src/NativeRNLlama.ts b/src/NativeRNLlama.ts index 5278f39..e69f37f 100644 --- a/src/NativeRNLlama.ts +++ b/src/NativeRNLlama.ts @@ -34,8 +34,18 @@ export type NativeContextParams = { use_mmap?: boolean vocab_only?: boolean - lora?: string // lora_adaptor + /** + * Single LoRA adapter path + */ + lora?: string + /** + * Single LoRA adapter scale + */ lora_scaled?: number + /** + * LoRA adapter list + */ + lora_list?: Array<{ path: string; scaled?: number }> rope_freq_base?: number rope_freq_scale?: number @@ -237,7 +247,10 @@ export interface Spec extends TurboModule { setContextLimit(limit: number): Promise modelInfo(path: string, skip?: string[]): Promise - initContext(contextId: number, params: NativeContextParams): Promise + initContext( + contextId: number, + params: NativeContextParams, + ): Promise getFormattedChat( contextId: number, @@ -273,6 +286,15 @@ export interface Spec extends TurboModule { nr: number, ): Promise + applyLoraAdapters( + contextId: number, + loraAdapters: Array<{ path: string; scaled?: number }>, + ): Promise + removeLoraAdapters(contextId: number): Promise + getLoadedLoraAdapters( + contextId: number, + ): Promise> + releaseContext(contextId: number): Promise releaseAllContexts(): Promise diff --git a/src/index.ts b/src/index.ts index d85d929..57fa208 100644 --- a/src/index.ts +++ b/src/index.ts @@ -14,7 +14,10 @@ import type { NativeCompletionTokenProbItem, NativeCompletionResultTimings, } from './NativeRNLlama' -import type { SchemaGrammarConverterPropOrder, SchemaGrammarConverterBuiltinRule } from './grammar' +import type { + SchemaGrammarConverterPropOrder, + SchemaGrammarConverterBuiltinRule, +} from './grammar' import { SchemaGrammarConverter, convertJsonSchemaToGrammar } from './grammar' import type { RNLlamaMessagePart, RNLlamaOAICompatibleMessage } from './chat' import { formatChat } from './chat' @@ -63,10 +66,26 @@ type TokenNativeEvent = { export type ContextParams = Omit< NativeContextParams, - 'cache_type_k' | 'cache_type_v' | 'pooling_type' + 'cache_type_k' | 'cache_type_v' | 'pooling_type' > & { - cache_type_k?: 'f16' | 'f32' | 'q8_0' | 'q4_0' | 'q4_1' | 'iq4_nl' | 'q5_0' | 'q5_1' - cache_type_v?: 'f16' | 'f32' | 'q8_0' | 'q4_0' | 'q4_1' | 'iq4_nl' | 'q5_0' | 'q5_1' + cache_type_k?: + | 'f16' + | 'f32' + | 'q8_0' + | 'q4_0' + | 'q4_1' + | 'iq4_nl' + | 'q5_0' + | 'q5_1' + cache_type_v?: + | 'f16' + | 'f32' + | 'q8_0' + | 'q4_0' + | 'q4_1' + | 'iq4_nl' + | 'q5_0' + | 'q5_1' pooling_type?: 'none' | 'mean' | 'cls' | 'last' | 'rank' } @@ -145,7 +164,10 @@ export class LlamaContext { let finalPrompt = params.prompt if (params.messages) { // messages always win - finalPrompt = await this.getFormattedChat(params.messages, params.chatTemplate) + finalPrompt = await this.getFormattedChat( + params.messages, + params.chatTemplate, + ) } let tokenListener: any = @@ -214,6 +236,28 @@ export class LlamaContext { } } + async applyLoraAdapters( + loraList: Array<{ path: string; scaled?: number }> + ): Promise { + let loraAdapters: Array<{ path: string; scaled?: number }> = [] + if (loraList) + loraAdapters = loraList.map((l) => ({ + path: l.path.replace(/file:\/\//, ''), + scaled: l.scaled, + })) + return RNLlama.applyLoraAdapters(this.id, loraAdapters) + } + + async removeLoraAdapters(): Promise { + return RNLlama.removeLoraAdapters(this.id) + } + + async getLoadedLoraAdapters(): Promise< + Array<{ path: string; scaled?: number }> + > { + return RNLlama.getLoadedLoraAdapters(this.id) + } + async release(): Promise { return RNLlama.releaseContext(this.id) } @@ -254,6 +298,7 @@ export async function initLlama( is_model_asset: isModelAsset, pooling_type: poolingType, lora, + lora_list: loraList, ...rest }: ContextParams, onProgress?: (progress: number) => void, @@ -264,6 +309,13 @@ export async function initLlama( let loraPath = lora if (loraPath?.startsWith('file://')) loraPath = loraPath.slice(7) + let loraAdapters: Array<{ path: string; scaled?: number }> = [] + if (loraList) + loraAdapters = loraList.map((l) => ({ + path: l.path.replace(/file:\/\//, ''), + scaled: l.scaled, + })) + const contextId = contextIdCounter + contextIdRandom() contextIdCounter += 1 @@ -289,6 +341,7 @@ export async function initLlama( use_progress_callback: !!onProgress, pooling_type: poolType, lora: loraPath, + lora_list: loraAdapters, ...rest, }).catch((err: any) => { removeProgressListener?.remove()