diff --git a/.github/workflows/build-aar.yml b/.github/workflows/build-aar.yml
index 1fcd3453..2c5d5941 100644
--- a/.github/workflows/build-aar.yml
+++ b/.github/workflows/build-aar.yml
@@ -34,7 +34,7 @@ jobs:
       - name: Navigate to android Directory and Build AAR
         run: |
           echo "Navigating to the example directory..."
-          cd android/llama.android
+          cd android
           echo "Starting Gradle build process in $(pwd)..."
           ./gradlew assembleRelease --stacktrace --info
         shell: bash
@@ -42,7 +42,7 @@ jobs:
       - name: Rename and upload AAR
         run: |
           echo "Navigating to the android directory to find AAR output..."
-          cd android/llama.android
+          cd android
           mkdir -p ../artifacts
           ls -ld ../artifacts || echo "Artifacts directory does not exist."
           AAR_PATH=$(find ./llama/build/outputs/aar -type f -name "*.aar" | head -n 1)
diff --git a/android/llama.android/.gitignore b/android/.gitignore
similarity index 98%
rename from android/llama.android/.gitignore
rename to android/.gitignore
index 347e252e..ddca75e5 100644
--- a/android/llama.android/.gitignore
+++ b/android/.gitignore
@@ -2,6 +2,8 @@
 .gradle/
 build/
 
+.idea
+
 # Local configuration file (sdk path, etc)
 local.properties
diff --git a/android/llama.android/README.md b/android/README.md
similarity index 90%
rename from android/llama.android/README.md
rename to android/README.md
index aa91234c..fb3dd4d3 100644
--- a/android/llama.android/README.md
+++ b/android/README.md
@@ -51,4 +51,11 @@ Open the [android test project](./app-java) folder in Android Studio and run the
 
 ## Download Models
 
-You can download models from the [Nexa AI ModelHub](https://nexa.ai/models).
\ No newline at end of file
+You can download models from the [Nexa AI ModelHub](https://nexa.ai/models).
+
+## How to estimate power usage
+
+Reset the on-device battery statistics before a run, then dump them afterwards and inspect the report:
+
+- `adb shell dumpsys batterystats --reset`
+- `adb shell dumpsys batterystats > batterystats.txt`
\ No newline at end of file
diff --git a/android/llama.android/app-java/.gitignore b/android/app-java/.gitignore
similarity index 100%
rename from android/llama.android/app-java/.gitignore
rename to android/app-java/.gitignore
diff --git a/android/llama.android/app-java/build.gradle b/android/app-java/build.gradle
similarity index 100%
rename from android/llama.android/app-java/build.gradle
rename to android/app-java/build.gradle
diff --git a/android/llama.android/app-java/proguard-rules.pro b/android/app-java/proguard-rules.pro
similarity index 100%
rename from android/llama.android/app-java/proguard-rules.pro
rename to android/app-java/proguard-rules.pro
diff --git a/android/llama.android/app-java/src/androidTest/java/ai/nexa/app_java/ExampleInstrumentedTest.java b/android/app-java/src/androidTest/java/ai/nexa/app_java/ExampleInstrumentedTest.java
similarity index 100%
rename from android/llama.android/app-java/src/androidTest/java/ai/nexa/app_java/ExampleInstrumentedTest.java
rename to android/app-java/src/androidTest/java/ai/nexa/app_java/ExampleInstrumentedTest.java
diff --git a/android/llama.android/app-java/src/main/AndroidManifest.xml b/android/app-java/src/main/AndroidManifest.xml
similarity index 92%
rename from android/llama.android/app-java/src/main/AndroidManifest.xml
rename to android/app-java/src/main/AndroidManifest.xml
index 8aaea0a2..ec4d330c 100644
--- a/android/llama.android/app-java/src/main/AndroidManifest.xml
+++ b/android/app-java/src/main/AndroidManifest.xml
@@ -1,6 +1,6 @@
 [XML markup lost in extraction]
@@ -12,6 +12,9 @@
 [XML markup lost in extraction]
diff --git a/android/llama.android/llama/src/main/cpp/llama-android.cpp b/android/llama.android/llama/src/main/cpp/llama-android.cpp
deleted file mode 100644
--- a/android/llama.android/llama/src/main/cpp/llama-android.cpp
+++ /dev/null
-#include <android/log.h>
-#include <jni.h>
-#include <iomanip>
-#include <math.h>
-#include <string>
-#include <unistd.h>
-#include "llama.h"
-#include "common.h"
-#include "llava.h"
-
-// Write C++ code here.
-//
-// Do not forget to dynamically load the C++ library into your application.
-//
-// For instance,
-//
-// In MainActivity.java:
-//    static {
-//       System.loadLibrary("llama-android");
-//    }
-//
-// Or, in MainActivity.kt:
-//    companion object {
-//      init {
-//         System.loadLibrary("llama-android")
-//      }
-//    }
-
-#define TAG "llama-android.cpp"
-#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
-#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
-
-jclass la_int_var;
-jmethodID la_int_var_value;
-jmethodID la_int_var_inc;
-
-std::string cached_token_chars;
-
-extern bool is_valid_utf8(const char* str);
-
-static void log_callback(ggml_log_level level, const char * fmt, void * data) {
-    if (level == GGML_LOG_LEVEL_ERROR)     __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
-    else if (level == GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
-    else if (level == GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data);
-    else __android_log_print(ANDROID_LOG_DEFAULT, TAG, fmt, data);
-}
-
-extern "C"
-JNIEXPORT jlong JNICALL
-Java_com_nexa_LLamaAndroid_load_1model(JNIEnv *env, jobject, jstring filename) {
-    llama_model_params model_params = llama_model_default_params();
-
-    auto path_to_model = env->GetStringUTFChars(filename, 0);
-    LOGi("Loading model from %s", path_to_model);
-
-    auto model = llama_load_model_from_file(path_to_model, model_params);
-    env->ReleaseStringUTFChars(filename, path_to_model);
-
-    if (!model) {
-        LOGe("load_model() failed");
-        env->ThrowNew(env->FindClass("java/lang/IllegalStateException"), "load_model() failed");
-        return 0;
-    }
-
-    return reinterpret_cast<jlong>(model);
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_com_nexa_LLamaAndroid_free_1model(JNIEnv *, jobject, jlong model) {
-    llama_free_model(reinterpret_cast<llama_model *>(model));
-}
-
-extern "C"
-JNIEXPORT jlong JNICALL
-Java_com_nexa_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmodel) {
-    auto model = reinterpret_cast<llama_model *>(jmodel);
-
-    if (!model) {
-        LOGe("new_context(): model cannot be null");
-        env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), "Model cannot be null");
-        return 0;
-    }
-
-    int n_threads = std::max(1, std::min(8, (int) sysconf(_SC_NPROCESSORS_ONLN) - 2));
-    LOGi("Using %d threads", n_threads);
-
-    llama_context_params ctx_params = llama_context_default_params();
-    ctx_params.seed  = 1234;
-    ctx_params.n_ctx = 2048;
-    ctx_params.n_threads       = n_threads;
-    ctx_params.n_threads_batch = n_threads;
-
-    llama_context * context = llama_new_context_with_model(model, ctx_params);
-
-    if (!context) {
-        LOGe("llama_new_context_with_model() returned null)");
-        env->ThrowNew(env->FindClass("java/lang/IllegalStateException"),
-                      "llama_new_context_with_model() returned null)");
-        return 0;
-    }
-
-    return reinterpret_cast<jlong>(context);
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_com_nexa_LLamaAndroid_free_1context(JNIEnv *, jobject, jlong context) {
-    llama_free(reinterpret_cast<llama_context *>(context));
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_com_nexa_LLamaAndroid_backend_1free(JNIEnv *, jobject) {
-    llama_backend_free();
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_com_nexa_LLamaAndroid_log_1to_1android(JNIEnv *, jobject) {
-    llama_log_set(log_callback, NULL);
-}
-
-extern "C"
-JNIEXPORT jstring JNICALL
-Java_com_nexa_LLamaAndroid_bench_1model(
-        JNIEnv *env,
-        jobject,
-        jlong context_pointer,
-        jlong model_pointer,
-        jlong batch_pointer,
-        jint pp,
-        jint tg,
-        jint pl,
-        jint nr
-    ) {
-    auto pp_avg = 0.0;
-    auto tg_avg = 0.0;
-    auto pp_std = 0.0;
-    auto tg_std = 0.0;
-
-    const auto context = reinterpret_cast<llama_context *>(context_pointer);
-    const auto model = reinterpret_cast<llama_model *>(model_pointer);
-    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
-
-    const int n_ctx = llama_n_ctx(context);
-
-    LOGi("n_ctx = %d", n_ctx);
-
-    int i, j;
-    int nri;
-    for (nri = 0; nri < nr; nri++) {
-        LOGi("Benchmark prompt processing (pp)");
-
-        llama_batch_clear(*batch);
-
-        const int n_tokens = pp;
-        for (i = 0; i < n_tokens; i++) {
-            llama_batch_add(*batch, 0, i, { 0 }, false);
-        }
-
-        batch->logits[batch->n_tokens - 1] = true;
-        llama_kv_cache_clear(context);
-
-        const auto t_pp_start = ggml_time_us();
-        if (llama_decode(context, *batch) != 0) {
-            LOGi("llama_decode() failed during prompt processing");
-        }
-        const auto t_pp_end = ggml_time_us();
-
-        // bench text generation
-
-        LOGi("Benchmark text generation (tg)");
-
-        llama_kv_cache_clear(context);
-        const auto t_tg_start = ggml_time_us();
-        for (i = 0; i < tg; i++) {
-
-            llama_batch_clear(*batch);
-            for (j = 0; j < pl; j++) {
-                llama_batch_add(*batch, 0, i, { j }, true);
-            }
-
-            LOGi("llama_decode() text generation: %d", i);
-            if (llama_decode(context, *batch) != 0) {
-                LOGi("llama_decode() failed during text generation");
-            }
-        }
-
-        const auto t_tg_end = ggml_time_us();
-
-        llama_kv_cache_clear(context);
-
-        const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
-        const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
-
-        const auto speed_pp = double(pp) / t_pp;
-        const auto speed_tg = double(pl * tg) / t_tg;
-
-        pp_avg += speed_pp;
-        tg_avg += speed_tg;
-
-        pp_std += speed_pp * speed_pp;
-        tg_std += speed_tg * speed_tg;
-
-        LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg);
-    }
-
-    pp_avg /= double(nr);
-    tg_avg /= double(nr);
-
-    if (nr > 1) {
-        pp_std = sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1));
-        tg_std = sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1));
-    } else {
-        pp_std = 0;
-        tg_std = 0;
-    }
-
-    char model_desc[128];
-    llama_model_desc(model, model_desc, sizeof(model_desc));
-
-    const auto model_size = double(llama_model_size(model)) / 1024.0 / 1024.0 / 1024.0;
-    const auto model_n_params = double(llama_model_n_params(model)) / 1e9;
-
-    const auto backend = "(Android)"; // TODO: What should this be?
-
-    std::stringstream result;
-    result << std::setprecision(2);
-    result << "| model | size | params | backend | test | t/s |\n";
-    result << "| --- | --- | --- | --- | --- | --- |\n";
-    result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n";
-    result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n";
-
-    return env->NewStringUTF(result.str().c_str());
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_com_nexa_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
-    llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer));
-}
-
-extern "C"
-JNIEXPORT jlong JNICALL
-Java_com_nexa_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
-
-    // Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
-
-    llama_batch *batch = new llama_batch {
-        0,
-        nullptr,
-        nullptr,
-        nullptr,
-        nullptr,
-        nullptr,
-        nullptr,
-        0,
-        0,
-        0,
-    };
-
-    if (embd) {
-        batch->embd = (float *) malloc(sizeof(float) * n_tokens * embd);
-    } else {
-        batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
-    }
-
-    batch->pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens);
-    batch->n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens);
-    batch->seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
-    for (int i = 0; i < n_tokens; ++i) {
-        batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
-    }
-    batch->logits   = (int8_t *)        malloc(sizeof(int8_t)         * n_tokens);
-
-    return reinterpret_cast<jlong>(batch);
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_com_nexa_LLamaAndroid_backend_1init(JNIEnv *, jobject) {
-    llama_backend_init();
-}
-
-extern "C"
-JNIEXPORT jstring JNICALL
-Java_com_nexa_LLamaAndroid_system_1info(JNIEnv *env, jobject) {
-    return env->NewStringUTF(llama_print_system_info());
-}
-
-extern "C"
-JNIEXPORT jint JNICALL
-Java_com_nexa_LLamaAndroid_completion_1init(
-        JNIEnv *env,
-        jobject,
-        jlong context_pointer,
-        jlong batch_pointer,
-        jstring jtext,
-        jint n_len
-    ) {
-
-    cached_token_chars.clear();
-
-    const auto text = env->GetStringUTFChars(jtext, 0);
-    const auto context = reinterpret_cast<llama_context *>(context_pointer);
-    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
-
-    const auto tokens_list = llama_tokenize(context, text, 1);
-
-    auto n_ctx = llama_n_ctx(context);
-    auto n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
-
-    LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", n_len, n_ctx, n_kv_req);
-
-    if (n_kv_req > n_ctx) {
-        LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough");
-    }
-
-    for (auto id : tokens_list) {
-        LOGi("%s", llama_token_to_piece(context, id).c_str());
-    }
-
-    llama_batch_clear(*batch);
-
-    // evaluate the initial prompt
-    for (auto i = 0; i < tokens_list.size(); i++) {
-        llama_batch_add(*batch, tokens_list[i], i, { 0 }, false);
-    }
-
-    // llama_decode will output logits only for the last token of the prompt
-    batch->logits[batch->n_tokens - 1] = true;
-
-    if (llama_decode(context, *batch) != 0) {
-        LOGe("llama_decode() failed");
-    }
-
-    env->ReleaseStringUTFChars(jtext, text);
-
-    return batch->n_tokens;
-}
-
-extern "C"
-JNIEXPORT jstring JNICALL
-Java_com_nexa_LLamaAndroid_completion_1loop(
-        JNIEnv * env,
-        jobject,
-        jlong context_pointer,
-        jlong batch_pointer,
-        jint n_len,
-        jobject intvar_ncur
-) {
-    const auto context = reinterpret_cast<llama_context *>(context_pointer);
-    const auto batch   = reinterpret_cast<llama_batch *>(batch_pointer);
-    const auto model = llama_get_model(context);
-
-    if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur);
-    if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I");
-    if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V");
-
-    auto n_vocab = llama_n_vocab(model);
-    auto logits = llama_get_logits_ith(context, batch->n_tokens - 1);
-
-    std::vector<llama_token_data> candidates;
-    candidates.reserve(n_vocab);
-
-    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
-    }
-
-    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-
-    // sample the most likely token
-    const auto new_token_id = llama_sample_token_greedy(context, &candidates_p);
-
-    const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
-    if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
-        return nullptr;
-    }
-
-    auto new_token_chars = llama_token_to_piece(context, new_token_id);
-    cached_token_chars += new_token_chars;
-
-    jstring new_token = nullptr;
-    if (is_valid_utf8(cached_token_chars.c_str())) {
-        new_token = env->NewStringUTF(cached_token_chars.c_str());
-        LOGi("cached: %s, new_token_chars: `%s`, id: %d", cached_token_chars.c_str(), new_token_chars.c_str(), new_token_id);
-        cached_token_chars.clear();
-    } else {
-        new_token = env->NewStringUTF("");
-    }
-
-    llama_batch_clear(*batch);
-    llama_batch_add(*batch, new_token_id, n_cur, { 0 }, true);
-
-    env->CallVoidMethod(intvar_ncur, la_int_var_inc);
-
-    if (llama_decode(context, *batch) != 0) {
-        LOGe("llama_decode() returned null");
-    }
-
-    return new_token;
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_com_nexa_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
-    llama_kv_cache_clear(reinterpret_cast<llama_context *>(context));
-}
diff --git a/android/llama.android/llama/.gitignore b/android/llama/.gitignore
similarity index 100%
rename from android/llama.android/llama/.gitignore
rename to android/llama/.gitignore
diff --git a/android/llama.android/llama/build.gradle.kts b/android/llama/build.gradle.kts
similarity index 100%
rename from android/llama.android/llama/build.gradle.kts
rename to android/llama/build.gradle.kts
diff --git a/android/llama.android/llama/consumer-rules.pro b/android/llama/consumer-rules.pro
similarity index 100%
rename from android/llama.android/llama/consumer-rules.pro
rename to android/llama/consumer-rules.pro
diff --git a/android/llama.android/llama/proguard-rules.pro b/android/llama/proguard-rules.pro
similarity index 100%
rename from android/llama.android/llama/proguard-rules.pro
rename to android/llama/proguard-rules.pro
diff --git a/android/llama.android/llama/src/androidTest/java/android/llama/nexa/ExampleInstrumentedTest.kt b/android/llama/src/androidTest/java/android/llama/nexa/ExampleInstrumentedTest.kt
similarity index 100%
rename from android/llama.android/llama/src/androidTest/java/android/llama/nexa/ExampleInstrumentedTest.kt
rename to android/llama/src/androidTest/java/android/llama/nexa/ExampleInstrumentedTest.kt
diff --git a/android/llama.android/llama/src/main/AndroidManifest.xml b/android/llama/src/main/AndroidManifest.xml
similarity index 100%
rename from android/llama.android/llama/src/main/AndroidManifest.xml
rename to android/llama/src/main/AndroidManifest.xml
diff --git a/android/llama.android/llama/src/main/cpp/CMakeLists.txt b/android/llama/src/main/cpp/CMakeLists.txt
similarity index 57%
rename from android/llama.android/llama/src/main/cpp/CMakeLists.txt
rename to android/llama/src/main/cpp/CMakeLists.txt
index 78b4e9a1..00798324 100644
--- a/android/llama.android/llama/src/main/cpp/CMakeLists.txt
+++ b/android/llama/src/main/cpp/CMakeLists.txt
@@ -17,28 +17,34 @@ FetchContent_MakeAvailable(json)
 # Declare llama.cpp repository
 FetchContent_Declare(
         llama
-#       GIT_REPOSITORY https://github.com/ggerganov/llama.cpp
-        GIT_REPOSITORY https://github.com/NexaAI/llama.cpp
+        GIT_REPOSITORY https://github.com/NexaAI/llama.cpp.git
         GIT_TAG master
+        # SOURCE_SUBDIR llama.cpp_74d73dc
 )
 
-# Declare llava repository (if needed)
+# Declare llava repository
 FetchContent_Declare(
         llava
-#       GIT_REPOSITORY https://github.com/ggerganov/llama.cpp
-        GIT_REPOSITORY https://github.com/NexaAI/llama.cpp
+        GIT_REPOSITORY https://github.com/NexaAI/llama.cpp.git
         GIT_TAG master
         SOURCE_SUBDIR examples/llava
 )
 
+FetchContent_Declare(
+        omni_vlm
+        GIT_REPOSITORY https://github.com/NexaAI/llama.cpp.git
+        GIT_TAG master
+        SOURCE_SUBDIR examples/omni-vlm
+)
+
 # Make the content available
-FetchContent_MakeAvailable(llama llava)
+FetchContent_MakeAvailable(llama llava omni_vlm)
 
 # Create the main library
 add_library(${CMAKE_PROJECT_NAME} SHARED
         llama-android.cpp
-        llava-android.cpp
         common.cpp
+        llava-android.cpp
 )
@@ -50,4 +56,23 @@ target_link_libraries(${CMAKE_PROJECT_NAME}
         android
         log
         llava
-)
\ No newline at end of file
+)
+
+
+add_library(omni-android SHARED
+        llama-android.cpp
+        common.cpp
+        omni-android.cpp
+)
+
+
+# Link the required libraries
+target_link_libraries(omni-android
+        nlohmann_json
+        llama
+        common
+        android
+        log
+        omni_vlm
+)
+
diff --git a/android/llama/src/main/cpp/common.cpp b/android/llama/src/main/cpp/common.cpp
new file mode 100644
index 00000000..7152cd9a
--- /dev/null
+++ b/android/llama/src/main/cpp/common.cpp
@@ -0,0 +1,51 @@
+#include <android/log.h>
+#include <jni.h>
+#include <cstring>
+#include <string>
+#include <thread>
+#include <unistd.h>
+
+std::string jstring2str(JNIEnv* env, jstring jstr) {
+    if (!jstr) {
+        return "";
+    }
+    const char* str = env->GetStringUTFChars(jstr, nullptr);
+    if (!str) {
+        return "";
+    }
+    std::string ret(str);
+    env->ReleaseStringUTFChars(jstr, str);
+    return ret;
+}
+
+bool is_valid_utf8(const char * string) {
+    if (!string) {
+        return true;
+    }
+
+    const unsigned char * bytes = (const unsigned char *)string;
+    int num;
+
+    while (*bytes != 0x00) {
+        if ((*bytes & 0x80) == 0x00) {
+            num = 1;
+        } else if ((*bytes & 0xE0) == 0xC0) {
+            num = 2;
+        } else if ((*bytes & 0xF0) == 0xE0) {
+            num = 3;
+        } else if ((*bytes & 0xF8) == 0xF0) {
+            num = 4;
+        } else {
+            return false;
+        }
+
+        bytes += 1;
+        for (int i = 1; i < num; ++i) {
+            if ((*bytes & 0xC0) != 0x80) {
+                return false;
+            }
+            bytes += 1;
+        }
+    }
+
+    return true;
+}
diff --git a/android/llama/src/main/cpp/llama-android.cpp b/android/llama/src/main/cpp/llama-android.cpp
new file mode 100644
index 00000000..2a96930d
--- /dev/null
+++ b/android/llama/src/main/cpp/llama-android.cpp
@@ -0,0 +1,409 @@
+//#include <android/log.h>
+//#include <jni.h>
+//#include <iomanip>
+//#include <math.h>
+//#include <string>
+//#include <unistd.h>
+//#include "llama.h"
+//#include "common.h"
+//#include "llava.h"
+//
+//// Write C++ code here.
+////
+//// Do not forget to dynamically load the C++ library into your application.
+////
+//// For instance,
+////
+//// In MainActivity.java:
+////    static {
+////       System.loadLibrary("llama-android");
+////    }
+////
+//// Or, in MainActivity.kt:
+////    companion object {
+////      init {
+////         System.loadLibrary("llama-android")
+////      }
+////    }
+//
+//#define TAG "llama-android.cpp"
+//#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
+//#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
+//
+//jclass la_int_var;
+//jmethodID la_int_var_value;
+//jmethodID la_int_var_inc;
+//
+//std::string cached_token_chars;
+//
+//extern bool is_valid_utf8(const char* str);
+//
+//static void log_callback(ggml_log_level level, const char * fmt, void * data) {
+//    if (level == GGML_LOG_LEVEL_ERROR)     __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
+//    else if (level == GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
+//    else if (level == GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data);
+//    else __android_log_print(ANDROID_LOG_DEFAULT, TAG, fmt, data);
+//}
+//
+//extern "C"
+//JNIEXPORT jlong JNICALL
+//Java_com_nexa_LLamaAndroid_load_1model(JNIEnv *env, jobject, jstring filename) {
+//    llama_model_params model_params = llama_model_default_params();
+//
+//    auto path_to_model = env->GetStringUTFChars(filename, 0);
+//    LOGi("Loading model from %s", path_to_model);
+//
+//    auto model = llama_load_model_from_file(path_to_model, model_params);
+//    env->ReleaseStringUTFChars(filename, path_to_model);
+//
+//    if (!model) {
+//        LOGe("load_model() failed");
+//        env->ThrowNew(env->FindClass("java/lang/IllegalStateException"), "load_model() failed");
+//        return 0;
+//    }
+//
+//    return reinterpret_cast<jlong>(model);
+//}
+//
+//extern "C"
+//JNIEXPORT void JNICALL
+//Java_com_nexa_LLamaAndroid_free_1model(JNIEnv *, jobject, jlong model) {
+//    llama_free_model(reinterpret_cast<llama_model *>(model));
+//}
+//
+//extern "C"
+//JNIEXPORT jlong JNICALL
+//Java_com_nexa_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmodel) {
+//    auto model = reinterpret_cast<llama_model *>(jmodel);
+//
+//    if (!model) {
+//        LOGe("new_context(): model cannot be null");
+//        env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), "Model cannot be null");
+//        return 0;
+//    }
+//
+//    int n_threads = std::max(1, std::min(8, (int) sysconf(_SC_NPROCESSORS_ONLN) - 2));
+//    LOGi("Using %d threads", n_threads);
+//
+//    llama_context_params ctx_params = llama_context_default_params();
+////    ctx_params.seed  = 1234;
+////    ctx_params.n_ctx = 2048;
+//    ctx_params.n_threads       = n_threads;
+//    ctx_params.n_threads_batch = n_threads;
+//
+//    llama_context * context = llama_new_context_with_model(model, ctx_params);
+//
+//    if (!context) {
+//        LOGe("llama_new_context_with_model() returned null)");
+//        env->ThrowNew(env->FindClass("java/lang/IllegalStateException"),
+//                      "llama_new_context_with_model() returned null)");
+//        return 0;
+//    }
+//
+//    return reinterpret_cast<jlong>(context);
+//}
+//
+//extern "C"
+//JNIEXPORT void JNICALL
+//Java_com_nexa_LLamaAndroid_free_1context(JNIEnv *, jobject, jlong context) {
+//    llama_free(reinterpret_cast<llama_context *>(context));
+//}
+//
+//extern "C"
+//JNIEXPORT void JNICALL
+//Java_com_nexa_LLamaAndroid_backend_1free(JNIEnv *, jobject) {
+//    llama_backend_free();
+//}
+//
+//extern "C"
+//JNIEXPORT void JNICALL
+//Java_com_nexa_LLamaAndroid_log_1to_1android(JNIEnv *, jobject) {
+//    llama_log_set(log_callback, NULL);
+//}
+//
+//extern "C"
+//JNIEXPORT jstring JNICALL
+//Java_com_nexa_LLamaAndroid_bench_1model(
+//        JNIEnv *env,
+//        jobject,
+//        jlong context_pointer,
+//        jlong model_pointer,
+//        jlong batch_pointer,
+//        jint pp,
+//        jint tg,
+//        jint pl,
+//        jint nr
+//    ) {
+//    auto pp_avg = 0.0;
+//    auto tg_avg = 0.0;
+//    auto pp_std = 0.0;
+//    auto tg_std = 0.0;
+//
+//    const auto context = reinterpret_cast<llama_context *>(context_pointer);
+//    const auto model = reinterpret_cast<llama_model *>(model_pointer);
+//    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
+//
+//    const int n_ctx = llama_n_ctx(context);
+//
+//    LOGi("n_ctx = %d", n_ctx);
+//
+//    int i, j;
+//    int nri;
+//    for (nri = 0; nri < nr; nri++) {
+//        LOGi("Benchmark prompt processing (pp)");
+//
+//        llama_batch_clear(*batch);
+//
+//        const int n_tokens = pp;
+//        for (i = 0; i < n_tokens; i++) {
+//            llama_batch_add(*batch, 0, i, { 0 }, false);
+//        }
+//
+//        batch->logits[batch->n_tokens - 1] = true;
+//        llama_kv_cache_clear(context);
+//
+//        const auto t_pp_start = ggml_time_us();
+//        if (llama_decode(context, *batch) != 0) {
+//            LOGi("llama_decode() failed during prompt processing");
+//        }
+//        const auto t_pp_end = ggml_time_us();
+//
+//        // bench text generation
+//
+//        LOGi("Benchmark text generation (tg)");
+//
+//        llama_kv_cache_clear(context);
+//        const auto t_tg_start = ggml_time_us();
+//        for (i = 0; i < tg; i++) {
+//
+//            llama_batch_clear(*batch);
+//            for (j = 0; j < pl; j++) {
+//                llama_batch_add(*batch, 0, i, { j }, true);
+//            }
+//
+//            LOGi("llama_decode() text generation: %d", i);
+//            if (llama_decode(context, *batch) != 0) {
+//                LOGi("llama_decode() failed during text generation");
+//            }
+//        }
+//
+//        const auto t_tg_end = ggml_time_us();
+//
+//        llama_kv_cache_clear(context);
+//
+//        const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
+//        const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
+//
+//        const auto speed_pp = double(pp) / t_pp;
+//        const auto speed_tg = double(pl * tg) / t_tg;
+//
+//        pp_avg += speed_pp;
+//        tg_avg += speed_tg;
+//
+//        pp_std += speed_pp * speed_pp;
+//        tg_std += speed_tg * speed_tg;
+//
+//        LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg);
+//    }
+//
+//    pp_avg /= double(nr);
+//    tg_avg /= double(nr);
+//
+//    if (nr > 1) {
+//        pp_std = sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1));
+//        tg_std = sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1));
+//    } else {
+//        pp_std = 0;
+//        tg_std = 0;
+//    }
+//
+//    char model_desc[128];
+//    llama_model_desc(model, model_desc, sizeof(model_desc));
+//
+//    const auto model_size = double(llama_model_size(model)) / 1024.0 / 1024.0 / 1024.0;
+//    const auto model_n_params = double(llama_model_n_params(model)) / 1e9;
+//
+//    const auto backend = "(Android)"; // TODO: What should this be?
+//
+//    std::stringstream result;
+//    result << std::setprecision(2);
+//    result << "| model | size | params | backend | test | t/s |\n";
+//    result << "| --- | --- | --- | --- | --- | --- |\n";
+//    result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n";
+//    result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n";
+//
+//    return env->NewStringUTF(result.str().c_str());
+//}
+//
+//extern "C"
+//JNIEXPORT void JNICALL
+//Java_com_nexa_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
+//    llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer));
+//}
+//
+//extern "C"
+//JNIEXPORT jlong JNICALL
+//Java_com_nexa_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
+//
+//    // Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
+//
+//    llama_batch *batch = new llama_batch {
+//        0,
+//        nullptr,
+//        nullptr,
+//        nullptr,
+//        nullptr,
+//        nullptr,
+//        nullptr,
+//        0,
+//        0,
+//        0,
+//    };
+//
+//    if (embd) {
+//        batch->embd = (float *) malloc(sizeof(float) * n_tokens * embd);
+//    } else {
+//        batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
+//    }
+//
+//    batch->pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens);
+//    batch->n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens);
+//    batch->seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
+//    for (int i = 0; i < n_tokens; ++i) {
+//        batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
+//    }
+//    batch->logits   = (int8_t *)        malloc(sizeof(int8_t)         * n_tokens);
+//
+//    return reinterpret_cast<jlong>(batch);
+//}
+//
+//extern "C"
+//JNIEXPORT void JNICALL
+//Java_com_nexa_LLamaAndroid_backend_1init(JNIEnv *, jobject) {
+//    llama_backend_init();
+//}
+//
+//extern "C"
+//JNIEXPORT jstring JNICALL
+//Java_com_nexa_LLamaAndroid_system_1info(JNIEnv *env, jobject) {
+//    return env->NewStringUTF(llama_print_system_info());
+//}
+//
+//extern "C"
+//JNIEXPORT jint JNICALL
+//Java_com_nexa_LLamaAndroid_completion_1init(
+//        JNIEnv *env,
+//        jobject,
+//        jlong context_pointer,
+//        jlong batch_pointer,
+//        jstring jtext,
+//        jint n_len
+//    ) {
+//
+//    cached_token_chars.clear();
+//
+//    const auto text = env->GetStringUTFChars(jtext, 0);
+//    const auto context = reinterpret_cast<llama_context *>(context_pointer);
+//    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
+//
+//    const auto tokens_list = llama_tokenize(context, text, 1);
+//
+//    auto n_ctx = llama_n_ctx(context);
+//    auto n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
+//
+//    LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", n_len, n_ctx, n_kv_req);
+//
+//    if (n_kv_req > n_ctx) {
+//        LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough");
+//    }
+//
+//    for (auto id : tokens_list) {
+//        LOGi("%s", llama_token_to_piece(context, id).c_str());
+//    }
+//
+//    llama_batch_clear(*batch);
+//
+//    // evaluate the initial prompt
+//    for (auto i = 0; i < tokens_list.size(); i++) {
+//        llama_batch_add(*batch, tokens_list[i], i, { 0 }, false);
+//    }
+//
+//    // llama_decode will output logits only for the last token of the prompt
+//    batch->logits[batch->n_tokens - 1] = true;
+//
+//    if (llama_decode(context, *batch) != 0) {
+//        LOGe("llama_decode() failed");
+//    }
+//
+//    env->ReleaseStringUTFChars(jtext, text);
+//
+//    return batch->n_tokens;
+//}
+//
+//extern "C"
+//JNIEXPORT jstring JNICALL
+//Java_com_nexa_LLamaAndroid_completion_1loop(
+//        JNIEnv * env,
+//        jobject,
+//        jlong context_pointer,
+//        jlong batch_pointer,
+//        jint n_len,
+//        jobject intvar_ncur
+//) {
+//    const auto context = reinterpret_cast<llama_context *>(context_pointer);
+//    const auto batch   = reinterpret_cast<llama_batch *>(batch_pointer);
+//    const auto model = llama_get_model(context);
+//
+//    if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur);
+//    if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I");
+//    if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V");
+//
+//    auto n_vocab = llama_n_vocab(model);
+//    auto logits = llama_get_logits_ith(context, batch->n_tokens - 1);
+//
+//    std::vector<llama_token_data> candidates;
+//    candidates.reserve(n_vocab);
+//
+//    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+//        candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+//    }
+//
+//    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+//
+//    // sample the most likely token
+//    const auto new_token_id = llama_sample_token_greedy(context, &candidates_p);
+//
+//    const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
+//    if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
+//        return nullptr;
+//    }
+//
+//    auto new_token_chars = llama_token_to_piece(context, new_token_id);
+//    cached_token_chars += new_token_chars;
+//
+//    jstring new_token = nullptr;
+//    if (is_valid_utf8(cached_token_chars.c_str())) {
+//        new_token = env->NewStringUTF(cached_token_chars.c_str());
+//        LOGi("cached: %s, new_token_chars: `%s`, id: %d", cached_token_chars.c_str(), new_token_chars.c_str(), new_token_id);
+//        cached_token_chars.clear();
+//    } else {
+//        new_token = env->NewStringUTF("");
+//    }
+//
+//    llama_batch_clear(*batch);
+//    llama_batch_add(*batch, new_token_id, n_cur, { 0 }, true);
+//
+//    env->CallVoidMethod(intvar_ncur, la_int_var_inc);
+//
+//    if (llama_decode(context, *batch) != 0) {
+//        LOGe("llama_decode() returned null");
+//    }
+//
+//    return new_token;
+//}
+//
+//extern "C"
+//JNIEXPORT void JNICALL
+//Java_com_nexa_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
+//    llama_kv_cache_clear(reinterpret_cast<llama_context *>(context));
+//}
diff --git a/android/llama.android/llama/src/main/cpp/llava-android.cpp b/android/llama/src/main/cpp/llava-android.cpp
similarity index 86%
rename from android/llama.android/llama/src/main/cpp/llava-android.cpp
rename to android/llama/src/main/cpp/llava-android.cpp
index f2ce542d..6a1cbff2 100644
--- a/android/llama.android/llama/src/main/cpp/llava-android.cpp
+++ b/android/llama/src/main/cpp/llava-android.cpp
@@ -5,9 +5,13 @@
 #include <string>
 #include <unistd.h>
 #include "llama.h"
-#include "common.h"
+//#include "common.h"
 #include "llava-cli.cpp"
 #include <vector>
+#include <thread>
+#include <cstdio>
+#include <cstdlib>
+#include <iostream>
 
 #define TAG "llava-android.cpp"
 #define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
@@ -15,22 +19,45 @@
 
 extern bool is_valid_utf8(const char* str);
 
-std::string jstring2str(JNIEnv* env, jstring jstr) {
-    if (!jstr) {
-        return "";
-    }
-    const char* str = env->GetStringUTFChars(jstr, nullptr);
-    if (!str) {
-        return "";
+extern std::string jstring2str(JNIEnv* env, jstring jstr);
+
+
+// Function used to capture native output and forward it to logcat
+void redirect_output_to_logcat(const char* tag, int fd) {
+    char buffer[1024];
+    while (true) {
+        ssize_t count = read(fd, buffer, sizeof(buffer) - 1);
+        if (count <= 0) break;
+        buffer[count] = '\0';
+        __android_log_print(ANDROID_LOG_DEBUG, tag, "%s", buffer);
     }
-    std::string ret(str);
-    env->ReleaseStringUTFChars(jstr, str);
-    return ret;
 }
 
-#include <cstdio>
-#include <cstdlib>
-#include <cstring>
+// Initialize the stdout/stderr redirection
+void setup_redirect_stdout_stderr() {
+    int stdout_pipe[2];
+    int stderr_pipe[2];
+
+    pipe(stdout_pipe);
+    pipe(stderr_pipe);
+
+    // Redirect stdout
+    dup2(stdout_pipe[1], STDOUT_FILENO);
+    close(stdout_pipe[1]);
+    std::thread(redirect_output_to_logcat, "STDOUT", stdout_pipe[0]).detach();
+
+    // Redirect stderr
+    dup2(stderr_pipe[1], STDERR_FILENO);
+    close(stderr_pipe[1]);
+    std::thread(redirect_output_to_logcat, "STDERR", stderr_pipe[0]).detach();
+}
+
+JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void* reserved) {
+    setup_redirect_stdout_stderr();
+    return JNI_VERSION_1_6;
+}
+
 
 // Helper function to throw a Java exception from JNI
 void throwJavaException(JNIEnv* env, const char* className, const std::string& message) {
@@ -65,10 +92,10 @@ Java_com_nexa_NexaVlmInference_init_1params(JNIEnv *env, jobject /* this */, jst
     const char* model_chars = env->GetStringUTFChars(jmodel, nullptr);
     const char* mmproj_chars = env->GetStringUTFChars(jmmproj, nullptr);
 
-    const char* argv = "omni-wrapper-py";
+    const char* argv = "-t 1";
     char* nc_argv = const_cast<char*>(argv);
-    gpt_params* params = new gpt_params();
-    gpt_params_parse(1, &nc_argv, *params);
+    common_params* params = new common_params();
+    common_params_parse(0, &nc_argv, *params, LLAMA_EXAMPLE_LLAVA, print_usage);
 
     params->model = std::string(model_chars);
     params->mmproj = std::string(mmproj_chars);
@@ -94,7 +121,7 @@ Java_com_nexa_NexaVlmInference_init_1params(JNIEnv *env, jobject /* this */, jst
 extern "C" JNIEXPORT jlong JNICALL
 Java_com_nexa_NexaVlmInference_load_1model(JNIEnv *env, jobject /* this */, jlong jparams) {
     try {
-        const auto params = reinterpret_cast<gpt_params*>(jparams);
+        const auto params = reinterpret_cast<common_params*>(jparams);
 
         auto* model = llava_init(params);
         if (model == nullptr) {
@@ -122,10 +149,12 @@ Java_com_nexa_NexaVlmInference_update_1params(JNIEnv *env, jobject /* this */, j
     int32_t top_k = (int32_t) jtopK;
     float top_p = (float) jtopP;
    float temp = (float) jtemp;
-    const auto params = reinterpret_cast<gpt_params*>(jparams);
+    const auto params = reinterpret_cast<common_params*>(jparams);
     params->sparams.top_k = top_k;
     params->sparams.top_p = top_p;
     params->sparams.temp = temp;
+
+
 }
 
 extern "C" JNIEXPORT void JNICALL
@@ -139,7 +168,7 @@ Java_com_nexa_NexaVlmInference_free_1model(JNIEnv *env, jobject /* this */, jlon
 extern "C" JNIEXPORT jlong JNICALL
 Java_com_nexa_NexaVlmInference_llava_1init_1context(JNIEnv *env, jobject /* this */, jlong jparams, jlong jmodel) {
     try {
-        const auto params = reinterpret_cast<gpt_params*>(jparams);
+        const auto params = reinterpret_cast<common_params*>(jparams);
         const auto llava_model = reinterpret_cast<llama_model *>(jmodel);
 
         auto* ctx_llava = llava_init_context(params, llava_model);
@@ -217,7 +246,7 @@ Java_com_nexa_NexaVlmInference_llava_1image_1embed_1free(JNIEnv *env, jobject /*
 extern "C" JNIEXPORT jlong JNICALL
 Java_com_nexa_NexaVlmInference_load_1image(JNIEnv *env, jobject /* this */, jlong llava_ctx_pointer, jlong jparams, jstring imagePath) {
     try {
-        auto* params = reinterpret_cast<gpt_params*>(jparams);
+        auto* params = reinterpret_cast<common_params*>(jparams);
         auto* ctx_llava = reinterpret_cast<llava_context *>(llava_ctx_pointer);
 
         std::string image_str = jstring2str(env, imagePath);
@@ -245,7 +274,7 @@ extern "C" JNIEXPORT jlong JNICALL
 Java_com_nexa_NexaVlmInference_llava_1eval(JNIEnv *env, jobject /* this */, jlong llava_ctx_pointer, jlong jparams, jlong llava_image_embed_pointer, jstring jprompt) {
     try {
-        auto* params = reinterpret_cast<gpt_params*>(jparams);
+        auto* params = reinterpret_cast<common_params*>(jparams);
         auto* image_embed = reinterpret_cast<llava_image_embed *>(llava_image_embed_pointer);
         auto* ctx_llava = reinterpret_cast<llava_context *>(llava_ctx_pointer);
 
@@ -282,9 +311,9 @@ extern "C" JNIEXPORT jlong JNICALL
 Java_com_nexa_NexaVlmInference_llava_1sampler_1init(JNIEnv *env, jobject /* this */, jlong llava_ctx_pointer, jlong jparams) {
     try {
-        auto* params = reinterpret_cast<gpt_params*>(jparams);
+        auto* params = reinterpret_cast<common_params*>(jparams);
         auto* ctx_llava = reinterpret_cast<llava_context *>(llava_ctx_pointer);
-        struct llama_sampling_context * smpl = llama_sampling_init(params->sparams);
+        struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sparams);
 
         if (smpl == nullptr) {
             throwJavaException(env, "java/lang/RuntimeException", "Failed to initialize llava ctx");
@@ -308,7 +337,7 @@ Java_com_nexa_NexaVlmInference_llava_1sampler_1init(JNIEnv *env, jobject /* this
 extern "C" JNIEXPORT jstring JNICALL
 Java_com_nexa_NexaVlmInference_llava_1sample(JNIEnv *env, jobject /* this */, jlong llava_ctx_pointer, jlong sampler, jlong jnpast, jlong jcached_tokens) {
-    auto* smpl = reinterpret_cast<llama_sampling_context *>(sampler);
+    auto* smpl = reinterpret_cast<common_sampler *>(sampler);
     auto* ctx_llava = reinterpret_cast<llava_context *>(llava_ctx_pointer);
     auto* cached_tokens = reinterpret_cast<std::string *>(jcached_tokens);
     auto* n_past = reinterpret_cast<int *>(jnpast);
@@ -327,8 +356,8 @@ Java_com_nexa_NexaVlmInference_llava_1sample(JNIEnv *env, jobject /* this */, jl
 extern "C" JNIEXPORT void JNICALL
 Java_com_nexa_NexaVlmInference_llava_1sample_1free(JNIEnv *env, jobject /* this */, jlong sampler) {
-    auto* smpl = reinterpret_cast<llama_sampling_context *>(sampler);
-    llama_sampling_free(smpl);
+    auto* smpl = reinterpret_cast<common_sampler *>(sampler);
+    common_sampler_free(smpl);
 }
diff --git a/android/llama/src/main/cpp/omni-android.cpp b/android/llama/src/main/cpp/omni-android.cpp
new file mode 100644
index 00000000..d534fe45
--- /dev/null
+++ b/android/llama/src/main/cpp/omni-android.cpp
@@ -0,0 +1,139 @@
+#include <android/log.h>
+#include <jni.h>
+#include <iomanip>
+#include <math.h>
+#include <string>
+#include <unistd.h>
+#include "llama.h"
+#include "omni-vlm-wrapper.cpp"
+//#include "omni-vlm-cli.cpp"
+#include <thread>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <iostream>
+
+#define TAG "omni-android.cpp"
+#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
+#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
+
+extern bool is_valid_utf8(const char* str);
+
+extern std::string jstring2str(JNIEnv* env, jstring jstr);
+
+
+// Function used to capture native output and forward it to logcat
+void redirect_output_to_logcat(const char* tag, int fd) {
+    char buffer[1024];
+    while (true) {
+        ssize_t count = read(fd, buffer, sizeof(buffer) - 1);
+        if (count <= 0) break;
+        buffer[count] = '\0';
+        __android_log_print(ANDROID_LOG_DEBUG, tag, "%s", buffer);
+    }
+}
+
+// Initialize the stdout/stderr redirection
+void setup_redirect_stdout_stderr() {
+    int stdout_pipe[2];
+    int stderr_pipe[2];
+
+    pipe(stdout_pipe);
+    pipe(stderr_pipe);
+
+    // Redirect stdout
+    dup2(stdout_pipe[1], STDOUT_FILENO);
+    close(stdout_pipe[1]);
+    std::thread(redirect_output_to_logcat, "STDOUT", stdout_pipe[0]).detach();
+
+    // Redirect stderr
+    dup2(stderr_pipe[1], STDERR_FILENO);
+    close(stderr_pipe[1]);
+    std::thread(redirect_output_to_logcat, "STDERR", stderr_pipe[0]).detach();
+}
+
+JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void* reserved) {
+    setup_redirect_stdout_stderr();
+    return JNI_VERSION_1_6;
+}
+
+extern "C" JNIEXPORT void JNICALL
+Java_com_nexa_NexaOmniVlmInference_init(JNIEnv *env, jobject /* this */, jstring jmodel, jstring jmmproj, jstring jtype) {
+    const char* model_chars = env->GetStringUTFChars(jmodel, nullptr);
+    const char* mmproj_chars = env->GetStringUTFChars(jmmproj, nullptr);
+    const char* type = env->GetStringUTFChars(jtype, nullptr);
+
+    omnivlm_init(model_chars, mmproj_chars, type);
+}
+
+
+extern "C" JNIEXPORT jlong JNICALL
+Java_com_nexa_NexaOmniVlmInference_image_1embed(JNIEnv *env, jobject /* this */, jstring jprompt, jstring jimage) {
+    const char* prompt = env->GetStringUTFChars(jprompt, nullptr);
+    const char* imag_path = env->GetStringUTFChars(jimage, nullptr);
+
+    ctx_omnivlm = omnivlm_init_context(&params, model);
+    std::string image = imag_path;
+    params.prompt = prompt;
+    params.prompt = "<|im_start|>system\nYou are Nano-Omni-VLM, created by Nexa AI. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n\n<|vision_start|><|image_pad|><|vision_end|>" + params.prompt + "<|im_end|>";
+    auto * image_embed = load_image(ctx_omnivlm, &params, image);
+
+    return reinterpret_cast<jlong>(image_embed);
+}
+
+extern "C" JNIEXPORT jlong JNICALL
+Java_com_nexa_NexaOmniVlmInference_sampler_1init(JNIEnv *env, jobject /* this */, jstring jprompt, jstring jimage, jlong jnpast, jlong jimage_embed) {
+    auto* n_past = reinterpret_cast<int *>(jnpast);
+
+    auto* image_embed = reinterpret_cast<llava_image_embed *>(jimage_embed);
+
+    if (image_embed == nullptr) {
+        std::cout << "image_embed is null!" << std::endl;
+    }
+
+    size_t image_pos = params.prompt.find("<|image_pad|>");
+    std::string system_prompt, user_prompt;
+
+    system_prompt = params.prompt.substr(0, image_pos);
+    user_prompt = params.prompt.substr(image_pos + std::string("<|image_pad|>").length());
+
+    params.sparams.top_k = 1;
+    params.sparams.top_p = 1.0f;
+
+    eval_string(ctx_omnivlm->ctx_llama, system_prompt.c_str(), params.n_batch, n_past, true);
+    omnivlm_eval_image_embed(ctx_omnivlm->ctx_llama, image_embed, params.n_batch, n_past);
+    eval_string(ctx_omnivlm->ctx_llama, user_prompt.c_str(), params.n_batch, n_past, false);
+
+    struct common_sampler * smpl = common_sampler_init(ctx_omnivlm->model, params.sparams);
+
+    return reinterpret_cast<jlong>(smpl);
+}
+
+extern "C" JNIEXPORT jlong JNICALL
+Java_com_nexa_NexaOmniVlmInference_npast_1init(JNIEnv *env, jobject /* this */) {
+    int* n_past = new int(0);
+    return reinterpret_cast<jlong>(n_past);
+}
+
+
+extern "C" JNIEXPORT jstring JNICALL
+Java_com_nexa_NexaOmniVlmInference_inference(JNIEnv *env, jobject /* this */, jlong jnpast, jlong jsampler) {
+    auto* n_past = reinterpret_cast<int *>(jnpast);
+    auto * sampler = reinterpret_cast<common_sampler *>(jsampler);
+
+    const char * tmp = sample(sampler, ctx_omnivlm->ctx_llama, n_past);
+
+    jstring new_token = nullptr;
+    new_token = env->NewStringUTF(tmp);
+    return new_token;
+}
+
+
+extern "C" JNIEXPORT void JNICALL
+Java_com_nexa_NexaOmniVlmInference_sampler_1free(JNIEnv *env, jobject /* this */, jlong jsampler) {
+    struct common_sampler * sampler = reinterpret_cast<common_sampler *>(jsampler);
+    common_sampler_free(sampler);
+}
+
+extern "C" JNIEXPORT void JNICALL
+Java_com_nexa_NexaOmniVlmInference_free(JNIEnv *env, jobject /* this */) {
+    omnivlm_free();
+}
diff --git a/android/llama.android/llama/src/main/java/com/nexa/LLamaAndroid.kt b/android/llama/src/main/java/com/nexa/LLamaAndroid.kt
similarity index 100%
rename from android/llama.android/llama/src/main/java/com/nexa/LLamaAndroid.kt
rename to android/llama/src/main/java/com/nexa/LLamaAndroid.kt
diff --git a/android/llama/src/main/java/com/nexa/NexaOmniVlmInference.kt b/android/llama/src/main/java/com/nexa/NexaOmniVlmInference.kt
new file mode 100644
index 00000000..86de9fd1
--- /dev/null
+++ b/android/llama/src/main/java/com/nexa/NexaOmniVlmInference.kt
@@ -0,0 +1,135 @@
+package com.nexa
+
+import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.flow
+import kotlinx.coroutines.flow.flowOn
+
+class NexaOmniVlmInference(
+    private val modelPath: String,
+    private val projectorPath: String,
+    private var imagePath: String,
+    private var stopWords: List<String> = emptyList(),
+    private var temperature: Float = 0.8f,
+    private var maxNewTokens: Int = 64,
+    private var topK: Int = 40,
+    private var topP: Float = 0.95f
+) {
+    init {
+        System.loadLibrary("omni-android")
+    }
+
+    private var embedImagePointer: Long = 0
+    private var samplerPointer: Long = 0
+    private var nPastPointer: Long = 0
+    private var generatedTokenNum: Int = 0
+    private var generatedText: String = ""
+    private var isModelLoaded: Boolean = false
+    private var type: String = "vlm-81-ocr"
+
+    private external fun init(model: String, proj: String, type: String)
+    private external fun sampler_free(sampler: Long)
+    private external fun free()
+
+    private external fun image_embed(prompt: String, image: String): Long
+    private external fun sampler_init(prompt: String, image: String, npast: Long, jimage_embed: Long): Long
+    private external fun inference(npast: Long, sampler: Long): String
+    private external fun npast_init(): Long
+
+    @Synchronized
+    fun loadModel() {
+        if (isModelLoaded) {
+            throw RuntimeException("Model is already loaded.")
+        }
+        try {
+            init(modelPath, projectorPath, type)
+            isModelLoaded = true
+        } catch (e: Exception) {
+            println(e)
+        } catch (e: UnsatisfiedLinkError) {
+            throw RuntimeException("Native method not found: ${e.message}")
+        }
+    }
+
+    fun dispose() {
+        free()
+    }
+
+    private fun updateParams(
+        stopWords: List<String>? = null,
+        temperature: Float? = null,
+        maxNewTokens: Int? = null,
+        topK: Int? = null,
+        topP: Float? = null
+    ) {
+        if (stopWords != null) {
+            this.stopWords = stopWords
+        }
+        if (temperature != null) {
+            this.temperature = temperature
+        }
+        if (maxNewTokens != null) {
+            this.maxNewTokens = maxNewTokens
+        }
+        if (topK != null) {
+            this.topK = topK
+        }
+        if (topP != null) {
+            this.topP = topP
+        }
+    }
+
+    private fun shouldStop(): Boolean {
+        if (this.generatedTokenNum >= this.maxNewTokens) {
+            return true
+        }
+
+        return stopWords.any { generatedText.contains(it, ignoreCase = true) }
+    }
+
+    private fun resetGeneration() {
+        generatedTokenNum = 0
+        generatedText = ""
+    }
+
+    @Synchronized
+    fun createCompletionStream(
+        prompt: String,
+        imagePath: String? = null,
+        stopWords: List<String>? = null,
+        temperature: Float? = null,
+        maxNewTokens: Int? = null,
+        topK: Int? = null,
+        topP: Float? = null
+    ): Flow<String> = flow {
+        if (!isModelLoaded) {
+            throw RuntimeException("Model is not loaded.")
+        }
+
+        // Reset generation state at the start
+        resetGeneration()
+        updateParams(stopWords, temperature, maxNewTokens, topK, topP)
+
+        val imagePathToUse = imagePath ?: this@NexaOmniVlmInference.imagePath
+        nPastPointer = npast_init()
+        embedImagePointer = image_embed(prompt, imagePathToUse)
+        samplerPointer = sampler_init(prompt, imagePathToUse, nPastPointer, embedImagePointer)
+
+        try {
+            while (true) {
+                val sampledText = inference(nPastPointer, samplerPointer)
+                generatedTokenNum += 1
+                generatedText += sampledText
+                if (shouldStop()) {
+                    break
+                }
+                emit(sampledText)
+            }
+        } finally {
+            // Clean up resources and reset generation state
+            resetGeneration()
+            sampler_free(samplerPointer)
+        }
+    }.flowOn(Dispatchers.IO)
+}
diff --git a/android/llama.android/llama/src/main/java/com/nexa/NexaVlmInference.kt b/android/llama/src/main/java/com/nexa/NexaVlmInference.kt
similarity index 81%
rename from android/llama.android/llama/src/main/java/com/nexa/NexaVlmInference.kt
rename to android/llama/src/main/java/com/nexa/NexaVlmInference.kt
index 82a617de..9497f6ca 100644
--- a/android/llama.android/llama/src/main/java/com/nexa/NexaVlmInference.kt
+++ b/android/llama/src/main/java/com/nexa/NexaVlmInference.kt
@@ -142,28 +142,34 @@ class NexaVlmInference(
         resetGeneration()
         updateParams(stopWords, temperature, maxNewTokens, topK, topP)
 
-        val imagePathToUse = imagePath ?: this@NexaVlmInference.imagePath
-        llavaCtxPointer = llava_init_context(paramsPointer, modelPointer)
-        embedImagePointer = load_image(llavaCtxPointer, paramsPointer, imagePathToUse)
-        nPastPointer = llava_eval(llavaCtxPointer, paramsPointer, embedImagePointer, prompt)
-        samplerPointer = llava_sampler_init(llavaCtxPointer, paramsPointer)
-        cachedTokenPointer = cached_token_init()
+//        val thread = Thread {
 
-        try {
-            while (true) {
-                val sampledText = llava_sample(llavaCtxPointer, samplerPointer, nPastPointer, cachedTokenPointer)
-                generatedTokenNum += 1
-                generatedText += sampledText
-                if(shouldStop()){
-                    break
+            val imagePathToUse = imagePath ?: this@NexaVlmInference.imagePath
+            llavaCtxPointer = llava_init_context(paramsPointer, modelPointer)
+            embedImagePointer = load_image(llavaCtxPointer, paramsPointer, imagePathToUse)
+            nPastPointer = llava_eval(llavaCtxPointer, paramsPointer, embedImagePointer, prompt)
+            samplerPointer = llava_sampler_init(llavaCtxPointer, paramsPointer)
+            cachedTokenPointer = cached_token_init()
+
+            try {
+                while (true) {
+                    val sampledText = llava_sample(llavaCtxPointer, samplerPointer, nPastPointer, cachedTokenPointer)
+                    generatedTokenNum += 1
+                    generatedText += sampledText
+                    if (shouldStop()) {
+                        break
+                    }
+                    emit(sampledText)
+                    print(sampledText)
                 }
-                emit(sampledText)
+            } finally {
+                // Clean up resources and reset generation state
+                cleanupResources()
+                resetGeneration()
             }
-        } finally {
-            // Clean up resources and reset generation state
-            cleanupResources()
-            resetGeneration()
-        }
+
+//        }
+//        thread.start()
     }.flowOn(Dispatchers.IO)
 
     private fun cleanupResources() {
diff --git a/android/llama.android/llama/src/test/java/android/llama/nexa/ExampleUnitTest.kt b/android/llama/src/test/java/android/llama/nexa/ExampleUnitTest.kt
similarity index 100%
rename from android/llama.android/llama/src/test/java/android/llama/nexa/ExampleUnitTest.kt
rename to android/llama/src/test/java/android/llama/nexa/ExampleUnitTest.kt
diff --git a/android/llama.android/settings.gradle.kts b/android/settings.gradle.kts
similarity index 100%
rename from android/llama.android/settings.gradle.kts
rename to android/settings.gradle.kts
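
---

Reviewer note: the sketch below shows one way the `NexaOmniVlmInference` wrapper introduced in this diff could be driven end to end from Kotlin. It is illustrative only, not part of the change itself; the on-device file paths are assumptions, and `runBlocking` stands in for whatever coroutine scope the host app already has.

```kotlin
import kotlinx.coroutines.runBlocking

// Minimal usage sketch (assumed paths; push your own GGUF model, projector,
// and test image to the device first, e.g. with `adb push`).
fun main() = runBlocking {
    val vlm = NexaOmniVlmInference(
        modelPath = "/data/local/tmp/model.gguf",         // hypothetical path
        projectorPath = "/data/local/tmp/projector.gguf", // hypothetical path
        imagePath = "/data/local/tmp/test.jpg"            // hypothetical path
    )
    vlm.loadModel()
    try {
        // createCompletionStream returns a cold Flow<String>; collecting it runs
        // image embedding, prompt evaluation, and per-token sampling on Dispatchers.IO.
        vlm.createCompletionStream(prompt = "Describe this image.", maxNewTokens = 64)
            .collect { token -> print(token) }
    } finally {
        vlm.dispose() // frees the native omni-vlm context
    }
}
```

Because the JNI layer keeps its state (`ctx_omnivlm`, `params`) in globals and the Kotlin methods are `@Synchronized`, one instance should be driven by one collector at a time; the battery-stats commands added to the README can bracket a run like this to estimate its power cost.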