From 5439c7772e718dcd9d9c8bbe210e46c6d4e21d2e Mon Sep 17 00:00:00 2001
From: Te993 <3923106166@qq.com>
Date: Wed, 4 Dec 2024 01:10:58 +0800
Subject: [PATCH 1/4] init android audio

---
 android/llama/src/main/cpp/CMakeLists.txt    | 87 +++++++++++++-------
 android/llama/src/main/cpp/audio-android.cpp | 56 +++++++++++++
 2 files changed, 114 insertions(+), 29 deletions(-)
 create mode 100644 android/llama/src/main/cpp/audio-android.cpp

diff --git a/android/llama/src/main/cpp/CMakeLists.txt b/android/llama/src/main/cpp/CMakeLists.txt
index 00798324..30c4fe00 100644
--- a/android/llama/src/main/cpp/CMakeLists.txt
+++ b/android/llama/src/main/cpp/CMakeLists.txt
@@ -1,54 +1,71 @@
-# Sets the minimum CMake version required for this project.
+# Set the minimum CMake version and build type
 cmake_minimum_required(VERSION 3.22.1)
-set(CMAKE_BUILD_TYPE Release)
 
-# Declares the project name
+set(CMAKE_BUILD_TYPE Release)
+
 project("llama-android")
 
-# Enable FetchContent module
 include(FetchContent)
 
+set(SOURCE_BASE_DIR /Users/liute/Desktop/nexa-repo/llama.cpp)
+
 FetchContent_Declare(
         json
         GIT_REPOSITORY https://github.com/nlohmann/json
         GIT_TAG v3.11.3
 )
-FetchContent_MakeAvailable(json)
 
-# Declare llama.cpp repository
+##### from local #####
 FetchContent_Declare(
         llama
-        GIT_REPOSITORY https://github.com/NexaAI/llama.cpp.git
-        GIT_TAG master
-        # SOURCE_SUBDIR llama.cpp_74d73dc
+        SOURCE_DIR ${SOURCE_BASE_DIR}
 )
-
-# Declare llama.cpp repository
 FetchContent_Declare(
         llava
-        GIT_REPOSITORY https://github.com/NexaAI/llama.cpp.git
-        GIT_TAG master
-        SOURCE_SUBDIR examples/llava
+        SOURCE_DIR ${SOURCE_BASE_DIR}/examples/llava
 )
-
 FetchContent_Declare(
         omni_vlm
-        GIT_REPOSITORY https://github.com/NexaAI/llama.cpp.git
-        GIT_TAG master
-        SOURCE_SUBDIR examples/omni-vlm
+        SOURCE_DIR ${SOURCE_BASE_DIR}/examples/omni-vlm
+)
+FetchContent_Declare(
+        omni_audio
+        SOURCE_DIR ${SOURCE_BASE_DIR}/examples/nexa-omni-audio
 )
 
-# Make the content available
-FetchContent_MakeAvailable(llama llava omni_vlm)
+##### from remote #####
+#
+#FetchContent_Declare(
+#        llama
+#        GIT_REPOSITORY https://github.com/NexaAI/llama.cpp.git
+#        GIT_TAG master
+#)
+#
+#FetchContent_Declare(
+#        llava
+#        GIT_REPOSITORY https://github.com/NexaAI/llama.cpp.git
+#        GIT_TAG master
+#        SOURCE_SUBDIR examples/llava
+#)
+#FetchContent_Declare(
+#        omni_vlm
+#        GIT_REPOSITORY https://github.com/NexaAI/llama.cpp.git
+#        GIT_TAG master
+#        SOURCE_SUBDIR examples/omni-vlm
+#)
+#FetchContent_Declare(
+#        omni_audio
+#        GIT_REPOSITORY https://github.com/NexaAI/llama.cpp.git
+#        GIT_TAG T/dev
+#        SOURCE_SUBDIR examples/nexa-omni-audio
+#)
+
+FetchContent_MakeAvailable(json llama llava omni_vlm omni_audio)
 
-# Create the main library
 add_library(${CMAKE_PROJECT_NAME} SHARED
         llama-android.cpp
         common.cpp
         llava-android.cpp
 )
-
-
-# Link the required libraries
 target_link_libraries(${CMAKE_PROJECT_NAME}
         nlohmann_json
         llama
@@ -59,14 +76,11 @@ target_link_libraries(${CMAKE_PROJECT_NAME}
 )
 
+##### vision #####
 add_library(omni-android SHARED
-        llama-android.cpp
         common.cpp
         omni-android.cpp
 )
-
-
-# Link the required libraries
 target_link_libraries(omni-android
         nlohmann_json
         llama
@@ -76,3 +90,18 @@ target_link_libraries(omni-android
         omni_vlm
 )
+
+##### audio #####
+add_library(audio-android SHARED
+        audio-android.cpp
+        common.cpp
+)
+target_link_libraries(audio-android
+        nlohmann_json
+        llama
+        common
+        omni_audio
+        android
+        log
+)
+
diff --git a/android/llama/src/main/cpp/audio-android.cpp b/android/llama/src/main/cpp/audio-android.cpp
new file mode 100644
index 00000000..9dd6197b
--- /dev/null
+++ b/android/llama/src/main/cpp/audio-android.cpp
@@ -0,0 +1,56 @@
+#include <android/log.h>
+#include <jni.h>
+#include <iomanip>
+#include <math.h>
+#include <string>
+#include <unistd.h>
+#include "omni.cpp"
+#include <iostream>
+#include <thread>
+#include <cstring>
+#include <vector>
+#include <sstream>
+
+#define TAG "audio-android.cpp"
+#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
+#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
+
+extern bool is_valid_utf8(const char* str);
+
+extern std::string jstring2str(JNIEnv* env, jstring jstr);
+
+
+// Function used to capture output
+void redirect_output_to_logcat(const char* tag, int fd) {
+    char buffer[1024];
+    while (true) {
+        ssize_t count = read(fd, buffer, sizeof(buffer) - 1);
+        if (count <= 0) break;
+        buffer[count] = '\0';
+        __android_log_print(ANDROID_LOG_DEBUG, tag, "%s", buffer);
+    }
+}
+
+// Initialize the redirection
+void setup_redirect_stdout_stderr() {
+    int stdout_pipe[2];
+    int stderr_pipe[2];
+
+    pipe(stdout_pipe);
+    pipe(stderr_pipe);
+
+    // Redirect stdout
+    dup2(stdout_pipe[1], STDOUT_FILENO);
+    close(stdout_pipe[1]);
+    std::thread(redirect_output_to_logcat, "STDOUT", stdout_pipe[0]).detach();
+
+    // Redirect stderr
+    dup2(stderr_pipe[1], STDERR_FILENO);
+    close(stderr_pipe[1]);
+    std::thread(redirect_output_to_logcat, "STDERR", stderr_pipe[0]).detach();
+}
+
+JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void* reserved) {
+    setup_redirect_stdout_stderr();
+    return JNI_VERSION_1_6;
+}

From 6ab4c4f9574ede83674cb6c5b5964e826d7028c8 Mon Sep 17 00:00:00 2001
From: Te993 <3923106166@qq.com>
Date: Wed, 4 Dec 2024 15:36:17 +0800
Subject: [PATCH 2/4] update

---
 android/llama/src/main/cpp/CMakeLists.txt | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/android/llama/src/main/cpp/CMakeLists.txt b/android/llama/src/main/cpp/CMakeLists.txt
index 30c4fe00..d90e1184 100644
--- a/android/llama/src/main/cpp/CMakeLists.txt
+++ b/android/llama/src/main/cpp/CMakeLists.txt
@@ -1,4 +1,3 @@
-# Set the minimum CMake version and build type
 cmake_minimum_required(VERSION 3.22.1)
 
 set(CMAKE_BUILD_TYPE Release)
@@ -6,14 +5,13 @@ project("llama-android")
 
 include(FetchContent)
 
-set(SOURCE_BASE_DIR /Users/liute/Desktop/nexa-repo/llama.cpp)
-
 FetchContent_Declare(
         json
         GIT_REPOSITORY https://github.com/nlohmann/json
         GIT_TAG v3.11.3
 )
 
+set(SOURCE_BASE_DIR ~/Desktop/nexa-repo/llama.cpp)
 ##### from local #####
 FetchContent_Declare(
         llama

From 7b6c8ed2e13cc8ef79e5cd36b79b1837a45abcd5 Mon Sep 17 00:00:00 2001
From: Te993 <3923106166@qq.com>
Date: Wed, 4 Dec 2024 17:09:41 +0800
Subject: [PATCH 3/4] update

---
 android/llama/src/main/cpp/audio-android.cpp |  12 ++
 .../main/java/com/nexa/NexaAudioInference.kt | 204 ++++++++++++++++++
 2 files changed, 216 insertions(+)
 create mode 100644 android/llama/src/main/java/com/nexa/NexaAudioInference.kt

diff --git a/android/llama/src/main/cpp/audio-android.cpp b/android/llama/src/main/cpp/audio-android.cpp
index 9dd6197b..260cfb35 100644
--- a/android/llama/src/main/cpp/audio-android.cpp
+++ b/android/llama/src/main/cpp/audio-android.cpp
@@ -54,3 +54,15 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void* reserved) {
     setup_redirect_stdout_stderr();
     return JNI_VERSION_1_6;
 }
+
+
+extern "C" JNIEXPORT jlong JNICALL
+Java_com_nexa_NexaAudioInference_init_1params(JNIEnv *env, jobject /* this */) {
+    // argc must match the number of entries in argv.
+    const char* argv[] = {"-t", "1"};
+    int argc = 2;
+    omni_context_params* ctx_params = new omni_context_params();
+    omni_context_params_parse(argc, const_cast<char**>(argv), *ctx_params);
+
+    return reinterpret_cast<jlong>(ctx_params);
+}
\ No newline at end of file
diff --git a/android/llama/src/main/java/com/nexa/NexaAudioInference.kt b/android/llama/src/main/java/com/nexa/NexaAudioInference.kt
new file mode 100644
index 00000000..8d27c81a
--- /dev/null
+++ b/android/llama/src/main/java/com/nexa/NexaAudioInference.kt
@@ -0,0 +1,204 @@
+package com.nexa
+import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.flow
+import kotlinx.coroutines.flow.flowOn
+
+class NexaAudioInference(
+    private val modelPath: String,
+    private val projectorPath: String,
+    private var imagePath: String,
+    private var stopWords: List<String> = emptyList(),
+    private var temperature: Float = 0.8f,
+    private var maxNewTokens: Int = 64,
+    private var topK: Int = 40,
+    private var topP: Float = 0.95f
+) {
+    init {
+        System.loadLibrary("llama-android")
+    }
+
+    private var paramsPointer: Long = 0
+    private var modelPointer: Long = 0
+    private var llavaCtxPointer: Long = 0
+    private var embedImagePointer: Long = 0
+    private var samplerPointer: Long = 0
+    private var nPastPointer: Long = 0
+    private var generatedTokenNum: Int = 0
+    private var generatedText: String = ""
+    private var isModelLoaded: Boolean = false
+    private var cachedTokenPointer: Long = 0
+
+    private external fun init_params(modelPath: String, mmprojPath: String): Long
+
+    private external fun update_params(params: Long, temperature: Float, topK: Int, topP: Float)
+
+    private external fun load_model(params: Long): Long
+
+    private external fun free_model(model: Long)
+
+    private external fun llava_init_context(params: Long, model: Long): Long
+
+    private external fun llava_ctx_free(ctx: Long)
+
+    private external fun load_image(ctx: Long, params: Long, imagepath: String): Long
+
+    private external fun llava_image_embed_free(llava_image_embed: Long)
+
+    private external fun llava_eval(ctx: Long, params: Long, llava_image_embed: Long, prompt: String): Long
+
+    private external fun llava_sampler_init(ctx: Long, params: Long): Long
+
+    private external fun llava_sample(ctx: Long, sampler: Long, n_past: Long, cached_tokens: Long): String
+
+    private external fun cached_token_init(): Long
+
+    private external fun cached_token_free(cached_tokens: Long)
+
+    private external fun llava_sample_free(sampler: Long)
+
+    @Synchronized
+    fun loadModel() {
+        if(isModelLoaded){
+            throw RuntimeException("Model is already loaded.")
+        }
+        try {
+            paramsPointer = init_params(modelPath, mmprojPath = projectorPath)
+            modelPointer = load_model(paramsPointer)
+            isModelLoaded = true
+        } catch (e: Exception) {
+            println(e)
+        } catch (e: UnsatisfiedLinkError) {
+            throw RuntimeException("Native method not found: ${e.message}")
+        }
+    }
+
+    fun dispose() {
+        if(paramsPointer!=0L){
+            paramsPointer = 0;
+        }
+        if (modelPointer != 0L) {
+            free_model(modelPointer)
+            modelPointer = 0;
+        }
+    }
+
+    private fun updateParams(
+        stopWords: List<String>? = null,
+        temperature: Float? = null,
+        maxNewTokens: Int? = null,
+        topK: Int? = null,
+        topP: Float? = null
+    ) {
+        if(stopWords != null){
+            this.stopWords = stopWords
+        }
+        if (temperature != null) {
+            this.temperature = temperature
+        }
+        if (maxNewTokens != null) {
+            this.maxNewTokens = maxNewTokens
+        }
+        if (topK != null) {
+            this.topK = topK;
+        }
+        if (topP != null) {
+            this.topP = topP
+        }
+
+        if(paramsPointer != 0L) {
+            update_params(paramsPointer, this.temperature, this.topK, this.topP)
+        }
+    }
+
+    private fun shouldStop(): Boolean {
+        if(this.generatedTokenNum >= this.maxNewTokens){
+            return true
+        }
+
+        return stopWords.any { generatedText.contains(it, ignoreCase = true) }
+    }
+
+    private fun resetGeneration() {
+        generatedTokenNum = 0
+        generatedText = ""
+    }
+
+    @Synchronized
+    fun createCompletionStream(
+        prompt: String,
+        imagePath: String? = null,
+        stopWords: List<String>? = null,
+        temperature: Float? = null,
+        maxNewTokens: Int? = null,
+        topK: Int? = null,
+        topP: Float? = null
+    ): Flow<String> = flow {
+        if(!isModelLoaded){
+            throw RuntimeException("Model is not loaded.")
+        }
+
+        // Reset generation state at the start
+        resetGeneration()
+        updateParams(stopWords, temperature, maxNewTokens, topK, topP)
+
+//        val thread = Thread {
+
+        val imagePathToUse = imagePath ?: this@NexaAudioInference.imagePath
+        llavaCtxPointer = llava_init_context(paramsPointer, modelPointer)
+        embedImagePointer = load_image(llavaCtxPointer, paramsPointer, imagePathToUse)
+        nPastPointer = llava_eval(llavaCtxPointer, paramsPointer, embedImagePointer, prompt)
+        samplerPointer = llava_sampler_init(llavaCtxPointer, paramsPointer)
+        cachedTokenPointer = cached_token_init()
+
+        try {
+            while (true) {
+                val sampledText = llava_sample(llavaCtxPointer, samplerPointer, nPastPointer, cachedTokenPointer)
+                generatedTokenNum += 1
+                generatedText += sampledText
+                if(shouldStop()){
+                    break
+                }
+                emit(sampledText)
+                print(sampledText)
+            }
+        } finally {
+            // Clean up resources and reset generation state
+            cleanupResources()
+            resetGeneration()
+        }
+
+        println("This is a new thread!")
+        // Your thread logic here
+//        }
+//        thread.start()
+    }.flowOn(Dispatchers.IO)
+
+    private fun cleanupResources() {
+        if(cachedTokenPointer != 0L){
+            cached_token_free(cachedTokenPointer)
+            cachedTokenPointer = 0
+        }
+
+        if (samplerPointer != 0L) {
+            llava_sample_free(samplerPointer)
+            samplerPointer = 0
+        }
+
+        if (embedImagePointer != 0L) {
+            try {
+                llava_image_embed_free(embedImagePointer)
+                embedImagePointer = 0
+            } catch (e: Exception) {
+                println(e)
+            } catch (e: Error) {
+                throw RuntimeException("Native method not found: ${e.message}")
+            }
+        }
+
+        if (llavaCtxPointer != 0L) {
+            llava_ctx_free(llavaCtxPointer)
+            llavaCtxPointer = 0
+        }
+    }
+}

From 53967ca713805a8d3da1f6154894786e01a97614 Mon Sep 17 00:00:00 2001
From: Te993 <3923106166@qq.com>
Date: Thu, 5 Dec 2024 11:57:59 +0800
Subject: [PATCH 4/4] add NexaAudioInference

---
 android/gradle.properties                    |   2 +-
 android/llama/src/main/cpp/CMakeLists.txt    |  72 ++++----
 android/llama/src/main/cpp/audio-android.cpp | 128 ++++++++++++--
 .../main/java/com/nexa/NexaAudioInference.kt | 166 ++++--------------
 4 files changed, 187 insertions(+), 181 deletions(-)

diff --git a/android/gradle.properties b/android/gradle.properties
index 99d0efd3..9e3fc850 100644
--- a/android/gradle.properties
+++ b/android/gradle.properties
@@ -6,7 +6,7 @@
 # http://www.gradle.org/docs/current/userguide/build_environment.html
 # Specifies the JVM arguments used for the daemon process.
 # The setting is particularly useful for tweaking memory settings.
-org.gradle.jvmargs=-Xms1024m -Xmx4g -XX:MaxDirectMemorySize=3g -Dsun.nio.MaxDirectMemorySize=3g -Dfile.encoding=UTF-8 -Dorg.gradle.parallel=true -Dorg.gradle.workers.max=4
+org.gradle.jvmargs=-Xms2048m -Xmx5g -XX:MaxDirectMemorySize=4g -Dsun.nio.MaxDirectMemorySize=4g -Dfile.encoding=UTF-8 -Dorg.gradle.parallel=true -Dorg.gradle.workers.max=4
 # When configured, Gradle will run in incubating parallel mode.
 # This option should only be used with decoupled projects. More details, visit
 # http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects

diff --git a/android/llama/src/main/cpp/CMakeLists.txt b/android/llama/src/main/cpp/CMakeLists.txt
index d90e1184..0278d9d2 100644
--- a/android/llama/src/main/cpp/CMakeLists.txt
+++ b/android/llama/src/main/cpp/CMakeLists.txt
@@ -11,52 +11,52 @@ FetchContent_Declare(
         GIT_TAG v3.11.3
 )
 
-set(SOURCE_BASE_DIR ~/Desktop/nexa-repo/llama.cpp)
+# set(SOURCE_BASE_DIR /nexa-ai/llama.cpp)
 ##### from local #####
+# FetchContent_Declare(
+#        llama
+#        SOURCE_DIR ${SOURCE_BASE_DIR}
+# )
+# FetchContent_Declare(
+#        llava
+#        SOURCE_DIR ${SOURCE_BASE_DIR}/examples/llava
+# )
+# FetchContent_Declare(
+#        omni_vlm
+#        SOURCE_DIR ${SOURCE_BASE_DIR}/examples/omni-vlm
+# )
+# FetchContent_Declare(
+#        omni_audio
+#        SOURCE_DIR ${SOURCE_BASE_DIR}/examples/nexa-omni-audio
+# )
+
+##### from remote #####
+
 FetchContent_Declare(
         llama
-        SOURCE_DIR ${SOURCE_BASE_DIR}
+        GIT_REPOSITORY https://github.com/NexaAI/llama.cpp.git
+        GIT_TAG master
 )
+
 FetchContent_Declare(
         llava
-        SOURCE_DIR ${SOURCE_BASE_DIR}/examples/llava
+        GIT_REPOSITORY https://github.com/NexaAI/llama.cpp.git
+        GIT_TAG master
+        SOURCE_SUBDIR examples/llava
 )
 FetchContent_Declare(
         omni_vlm
-        SOURCE_DIR ${SOURCE_BASE_DIR}/examples/omni-vlm
+        GIT_REPOSITORY https://github.com/NexaAI/llama.cpp.git
+        GIT_TAG master
+        SOURCE_SUBDIR examples/omni-vlm
 )
 FetchContent_Declare(
         omni_audio
-        SOURCE_DIR ${SOURCE_BASE_DIR}/examples/nexa-omni-audio
+        GIT_REPOSITORY https://github.com/NexaAI/llama.cpp.git
+        GIT_TAG T/dev
+        SOURCE_SUBDIR examples/nexa-omni-audio
 )
 
-##### from remote #####
-#
-#FetchContent_Declare(
-#        llama
-#        GIT_REPOSITORY https://github.com/NexaAI/llama.cpp.git
-#        GIT_TAG master
-#)
-#
-#FetchContent_Declare(
-#        llava
-#        GIT_REPOSITORY https://github.com/NexaAI/llama.cpp.git
-#        GIT_TAG master
-#        SOURCE_SUBDIR examples/llava
-#)
-#FetchContent_Declare(
-#        omni_vlm
-#        GIT_REPOSITORY https://github.com/NexaAI/llama.cpp.git
-#        GIT_TAG master
-#        SOURCE_SUBDIR examples/omni-vlm
-#)
-#FetchContent_Declare(
-#        omni_audio
-#        GIT_REPOSITORY https://github.com/NexaAI/llama.cpp.git
-#        GIT_TAG T/dev
-#        SOURCE_SUBDIR examples/nexa-omni-audio
-#)
-
 FetchContent_MakeAvailable(json llama llava omni_vlm omni_audio)
 
 add_library(${CMAKE_PROJECT_NAME} SHARED

diff --git a/android/llama/src/main/cpp/audio-android.cpp b/android/llama/src/main/cpp/audio-android.cpp
index 260cfb35..307766d7 100644
--- a/android/llama/src/main/cpp/audio-android.cpp
+++ b/android/llama/src/main/cpp/audio-android.cpp
@@ -19,8 +19,6 @@ extern bool is_valid_utf8(const char* str);
 
 extern std::string jstring2str(JNIEnv* env, jstring jstr);
 
-
-// Function used to capture output
 void redirect_output_to_logcat(const char* tag, int fd) {
     char buffer[1024];
     while (true) {
@@ -31,7 +29,6 @@ void redirect_output_to_logcat(const char* tag, int fd) {
     }
 }
 
-// Initialize the redirection
 void setup_redirect_stdout_stderr() {
     int stdout_pipe[2];
     int stderr_pipe[2];
@@ -55,14 +52,123 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void* reserved) {
     return JNI_VERSION_1_6;
 }
 
-
-extern "C" JNIEXPORT jlong JNICALL
-Java_com_nexa_NexaAudioInference_init_1params(JNIEnv *env, jobject /* this */) {
-    // argc must match the number of entries in argv.
-    const char* argv[] = {"-t", "1"};
-    int argc = 2;
-    omni_context_params* ctx_params = new omni_context_params();
-    omni_context_params_parse(argc, const_cast<char**>(argv), *ctx_params);
-
-    return reinterpret_cast<jlong>(ctx_params);
-}
\ No newline at end of file
+extern "C" JNIEXPORT jlong JNICALL
+Java_com_nexa_NexaAudioInference_init_1ctx_1params(JNIEnv *env, jobject /* this */, jstring jmodel, jstring jprojector, jstring jaudio) {
+    // These JNI strings are intentionally not released here: ctx_params keeps the raw pointers.
+    const char* model = env->GetStringUTFChars(jmodel, nullptr);
+    const char* projector = env->GetStringUTFChars(jprojector, nullptr);
+    const char* audio = env->GetStringUTFChars(jaudio, nullptr);
+
+    // argc must match the number of entries in argv.
+    const char* argv[] = {"-t", "1"};
+    int argc = 2;
+    omni_context_params* ctx_params = new omni_context_params(omni_context_default_params());
+    omni_context_params_parse(argc, const_cast<char**>(argv), *ctx_params);
+    ctx_params->model = model;
+    ctx_params->mmproj = projector;
+    ctx_params->file = audio;
+
+    return reinterpret_cast<jlong>(ctx_params);
+}
+
+extern "C" JNIEXPORT jlong JNICALL
+Java_com_nexa_NexaAudioInference_init_1ctx(JNIEnv *env, jobject /* this */, jlong jctx_params) {
+    auto* ctx_params = reinterpret_cast<omni_context_params*>(jctx_params);
+    std::cout << ctx_params->n_gpu_layers << std::endl;
+    std::cout << ctx_params->model << std::endl;
+    std::cout << ctx_params->mmproj << std::endl;
+    std::cout << ctx_params->file << std::endl;
+    omni_context *ctx_omni = omni_init_context(*ctx_params);
+    return reinterpret_cast<jlong>(ctx_omni);
+}
+
+extern "C" JNIEXPORT void JNICALL
+Java_com_nexa_NexaAudioInference_free_1ctx(JNIEnv *env, jobject /* this */, jlong jctx_omni) {
+    auto* ctx_omni = reinterpret_cast<omni_context*>(jctx_omni);
+    omni_free(ctx_omni);
+}
+
+extern "C" JNIEXPORT jlong JNICALL
+Java_com_nexa_NexaAudioInference_init_1npast(JNIEnv *env, jobject /* this */) {
+    int* n_past = new int(0);
+    return reinterpret_cast<jlong>(n_past);
+}
+
+extern "C" JNIEXPORT jlong JNICALL
+Java_com_nexa_NexaAudioInference_init_1params(JNIEnv *env, jobject /* this */, jlong jctx_params) {
+    auto* ctx_params = reinterpret_cast<omni_context_params*>(jctx_params);
+
+    // Step 1: Validate the incoming context parameters.
+    if (ctx_params == nullptr) {
+        std::cerr << "Error: jctx_params is null!" << std::endl;
+        return 0; // Return 0 (null) if the context parameter is invalid.
+    }
+
+    // Step 2: Call the function to extract omni_params from ctx_params.
+    omni_params extracted_params;
+    try {
+        extracted_params = get_omni_params_from_context_params(*ctx_params);
+    } catch (const std::exception& e) {
+        std::cerr << "Error in get_omni_params_from_context_params: " << e.what() << std::endl;
+        return 0; // Return 0 (null) if an exception is thrown during the extraction.
+    }
+
+    // Step 3: Allocate memory for omni_params and ensure it's successful.
+    omni_params* all_params = nullptr;
+    try {
+        all_params = new omni_params(extracted_params);
+    } catch (const std::bad_alloc& e) {
+        std::cerr << "Error: Failed to allocate memory for omni_params: " << e.what() << std::endl;
+        return 0; // Return 0 (null) if memory allocation fails.
+    }
+
+    std::cout << " fname_inp size: " << all_params->whisper.fname_inp.size() << std::endl;
+
+    // Step 4: Return the pointer to the newly allocated omni_params object.
+    std::cout << "all_params address: " << all_params << std::endl;
+    return reinterpret_cast<jlong>(all_params);
+}
+
+
+// Kotlin side: init_sampler(ctxPointer, allParamsPointer, prompt, audioPath, npastPointer)
+extern "C" JNIEXPORT jlong JNICALL
+Java_com_nexa_NexaAudioInference_init_1sampler(JNIEnv *env, jobject /* this */, jlong jctx_omni, jlong jctx_params, jstring jprompt, jstring jaudio_path, jlong jnpast) {
+    auto* n_past = reinterpret_cast<int*>(jnpast);
+    if (n_past == nullptr) {
+        std::cout << "n_past is null!" << std::endl;
+    }
+    const char* prompt = env->GetStringUTFChars(jprompt, nullptr);
+    auto* all_params = reinterpret_cast<omni_params*>(jctx_params);
+    auto* ctx_omni = reinterpret_cast<omni_context*>(jctx_omni);
+
+    ggml_tensor *audio_embed = omni_process_audio(ctx_omni, *all_params);
+    std::string system_prompt, user_prompt;
+    system_prompt = "<start_of_turn>user\nAudio 1: <|audio_bos|>";
+    user_prompt = "<|audio_eos|>\n" + std::string(prompt) + "<end_of_turn>\n<start_of_turn>model\n";
+    env->ReleaseStringUTFChars(jprompt, prompt);
+
+    eval_string(ctx_omni->ctx_llama, system_prompt.c_str(), all_params->gpt.n_batch, n_past, true);
+    omni_eval_audio_embed(ctx_omni->ctx_llama, audio_embed, all_params->gpt.n_batch, n_past);
+    eval_string(ctx_omni->ctx_llama, user_prompt.c_str(), all_params->gpt.n_batch, n_past, false);
+
+    struct common_sampler * ctx_sampling = common_sampler_init(ctx_omni->model, all_params->gpt.sparams);
+
+    return reinterpret_cast<jlong>(ctx_sampling);
+}
+
+
+extern "C" JNIEXPORT jstring JNICALL
+Java_com_nexa_NexaAudioInference_sampler(JNIEnv *env, jobject /* this */, jlong jctx_omni, jlong jsampler, jlong jnpast) {
+    auto* ctx_omni = reinterpret_cast<omni_context*>(jctx_omni);
+    auto* sampler = reinterpret_cast<common_sampler*>(jsampler);
+    auto* n_past = reinterpret_cast<int*>(jnpast);
+
+    const char * tmp = sample(sampler, ctx_omni->ctx_llama, n_past);
+
+    return env->NewStringUTF(tmp);
+}
+
+
+extern "C" JNIEXPORT void JNICALL
+Java_com_nexa_NexaAudioInference_free_1sampler(JNIEnv *env, jobject /* this */, jlong jsampler) {
+    auto* sampler = reinterpret_cast<common_sampler*>(jsampler);
+    common_sampler_free(sampler);
+}

diff --git a/android/llama/src/main/java/com/nexa/NexaAudioInference.kt b/android/llama/src/main/java/com/nexa/NexaAudioInference.kt
index 8d27c81a..5d0a68c9 100644
--- a/android/llama/src/main/java/com/nexa/NexaAudioInference.kt
+++ b/android/llama/src/main/java/com/nexa/NexaAudioInference.kt
@@ -15,47 +15,23 @@ class NexaAudioInference(
     private var topP: Float = 0.95f
 ) {
     init {
-        System.loadLibrary("llama-android")
+        System.loadLibrary("audio-android")
     }
 
-    private var paramsPointer: Long = 0
-    private var modelPointer: Long = 0
-    private var llavaCtxPointer: Long = 0
-    private var embedImagePointer: Long = 0
-    private var samplerPointer: Long = 0
-    private var nPastPointer: Long = 0
+    private var ctxParamsPointer: Long = 0
+    private var ctxPointer: Long = 0
     private var generatedTokenNum: Int = 0
     private var generatedText: String = ""
     private var isModelLoaded: Boolean = false
-    private var cachedTokenPointer: Long = 0
 
-    private external fun init_params(modelPath: String, mmprojPath: String): Long
-
-    private external fun update_params(params: Long, temperature: Float, topK: Int, topP: Float)
-
-    private external fun load_model(params: Long): Long
-
-    private external fun free_model(model: Long)
-
-    private external fun llava_init_context(params: Long, model: Long): Long
-
-    private external fun llava_ctx_free(ctx: Long)
-
-    private external fun load_image(ctx: Long, params: Long, imagepath: String): Long
-
-    private external fun llava_image_embed_free(llava_image_embed: Long)
-
-    private external fun llava_eval(ctx: Long, params: Long, llava_image_embed: Long, prompt: String): Long
-
-    private external fun llava_sampler_init(ctx: Long, params: Long): Long
-
-    private external fun llava_sample(ctx: Long, sampler: Long, n_past: Long, cached_tokens: Long): String
-
-    private external fun cached_token_init(): Long
-
-    private external fun cached_token_free(cached_tokens: Long)
-
-    private external fun llava_sample_free(sampler: Long)
+    private external fun init_ctx_params(model: String, projector: String, audioPath: String): Long
+    private external fun init_ctx(ctxParamsPointer: Long): Long
+    private external fun free_ctx(ctxPointer: Long)
+    private external fun init_npast(): Long
+    private external fun init_params(ctxParamsPointer: Long): Long
+    private external fun init_sampler(ctxPointer: Long, omniParamsPointer: Long, prompt: String, audioPath: String, npastPointer: Long): Long
+    private external fun sampler(ctxOmniPointer: Long, samplerPointer: Long, npastPointer: Long): String
+    private external fun free_sampler(samplerPointer: Long)
 
     @Synchronized
     fun loadModel() {
@@ -63,51 +39,26 @@ class NexaAudioInference(
             throw RuntimeException("Model is already loaded.")
         }
         try {
-            paramsPointer = init_params(modelPath, mmprojPath = projectorPath)
-            modelPointer = load_model(paramsPointer)
+            // NOTE: the audio file path is currently hard-coded for testing.
+            val audioPath = "/storage/emulated/0/Android/data/ai.nexa.app_java/files/jfk.wav"
+            ctxParamsPointer = init_ctx_params(modelPath, projectorPath, audioPath)
+            ctxPointer = init_ctx(ctxParamsPointer)
             isModelLoaded = true
         } catch (e: Exception) {
            println(e)
         } catch (e: UnsatisfiedLinkError) {
             throw RuntimeException("Native method not found: ${e.message}")
+        } catch (e: Error) {
+            println(e)
         }
     }
 
     fun dispose() {
-        if(paramsPointer!=0L){
-            paramsPointer = 0;
+        if(ctxParamsPointer!=0L){
+            ctxParamsPointer = 0;
         }
-        if (modelPointer != 0L) {
-            free_model(modelPointer)
-            modelPointer = 0;
-        }
-    }
-
-    private fun updateParams(
-        stopWords: List<String>? = null,
-        temperature: Float? = null,
-        maxNewTokens: Int? = null,
-        topK: Int? = null,
-        topP: Float? = null
-    ) {
-        if(stopWords != null){
-            this.stopWords = stopWords
-        }
-        if (temperature != null) {
-            this.temperature = temperature
-        }
-        if (maxNewTokens != null) {
-            this.maxNewTokens = maxNewTokens
-        }
-        if (topK != null) {
-            this.topK = topK;
-        }
-        if (topP != null) {
-            this.topP = topP
-        }
-
-        if(paramsPointer != 0L) {
-            update_params(paramsPointer, this.temperature, this.topK, this.topP)
+        if (ctxPointer != 0L) {
+            free_ctx(ctxPointer)
+            ctxPointer = 0;
         }
     }
 
@@ -137,68 +88,27 @@ class NexaAudioInference(
         if(!isModelLoaded){
             throw RuntimeException("Model is not loaded.")
         }
-
-        // Reset generation state at the start
-        resetGeneration()
-        updateParams(stopWords, temperature, maxNewTokens, topK, topP)
-
-//        val thread = Thread {
-
-        val imagePathToUse = imagePath ?: this@NexaAudioInference.imagePath
-        llavaCtxPointer = llava_init_context(paramsPointer, modelPointer)
-        embedImagePointer = load_image(llavaCtxPointer, paramsPointer, imagePathToUse)
-        nPastPointer = llava_eval(llavaCtxPointer, paramsPointer, embedImagePointer, prompt)
-        samplerPointer = llava_sampler_init(llavaCtxPointer, paramsPointer)
-        cachedTokenPointer = cached_token_init()
-
-        try {
-            while (true) {
-                val sampledText = llava_sample(llavaCtxPointer, samplerPointer, nPastPointer, cachedTokenPointer)
-                generatedTokenNum += 1
-                generatedText += sampledText
-                if(shouldStop()){
-                    break
-                }
-                emit(sampledText)
-                print(sampledText)
-            }
-        } finally {
-            // Clean up resources and reset generation state
-            cleanupResources()
-            resetGeneration()
-        }
-
-        println("This is a new thread!")
-        // Your thread logic here
-//        }
-//        thread.start()
-    }.flowOn(Dispatchers.IO)
-
-    private fun cleanupResources() {
-        if(cachedTokenPointer != 0L){
-            cached_token_free(cachedTokenPointer)
-            cachedTokenPointer = 0
-        }
-
-        if (samplerPointer != 0L) {
-            llava_sample_free(samplerPointer)
-            samplerPointer = 0
-        }
-
-        if (embedImagePointer != 0L) {
-            try {
-                llava_image_embed_free(embedImagePointer)
-                embedImagePointer = 0
-            } catch (e: Exception) {
-                println(e)
-            } catch (e: Error) {
-                throw RuntimeException("Native method not found: ${e.message}")
-            }
-        }
-
-        if (llavaCtxPointer != 0L) {
-            llava_ctx_free(llavaCtxPointer)
-            llavaCtxPointer = 0
-        }
-    }
+        resetGeneration()
+        val imagePathToUse = imagePath ?: this@NexaAudioInference.imagePath
+
+        // NOTE: the audio path is currently hard-coded for testing; see loadModel().
+        val audioPath = "/storage/emulated/0/Android/data/ai.nexa.app_java/files/jfk.wav"
+        val npastPointer = init_npast()
+        val allParamsPointer = init_params(ctxParamsPointer)
+        // Named samplerPointer so the local val does not shadow the external fun sampler().
+        val samplerPointer = init_sampler(ctxPointer, allParamsPointer, prompt, audioPath, npastPointer)
+
+        try {
+            while (true) {
+                val sampledText = sampler(ctxPointer, samplerPointer, npastPointer)
+                generatedTokenNum += 1
+                generatedText += sampledText
+                if(shouldStop()){
+                    break
+                }
+                emit(sampledText)
+            }
+        } finally {
+            resetGeneration()
+            free_sampler(samplerPointer)
+        }
+    }.flowOn(Dispatchers.IO)
 }
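
Usage note (outside the patches above): a minimal sketch of how the NexaAudioInference class introduced in PATCH 3/4 and reworked in PATCH 4/4 is expected to be driven from application code. Everything below is illustrative: the runAudioDemo wrapper, the file paths, and the prompt are placeholders, not part of this patch set; only the NexaAudioInference API itself comes from the patches.

    import kotlinx.coroutines.CoroutineScope
    import kotlinx.coroutines.Dispatchers
    import kotlinx.coroutines.launch

    // Hypothetical caller; the model/projector paths are placeholders for
    // files already pushed to the device.
    fun runAudioDemo() {
        val inference = NexaAudioInference(
            modelPath = "/data/local/tmp/model.gguf",         // placeholder
            projectorPath = "/data/local/tmp/projector.gguf", // placeholder
            imagePath = ""                                    // unused on the audio path
        )
        inference.loadModel()

        CoroutineScope(Dispatchers.Main).launch {
            inference.createCompletionStream(prompt = "Transcribe the audio.")
                .collect { token -> print(token) } // tokens are emitted as they are sampled
            inference.dispose()
        }
    }

Note that with PATCH 4/4 applied the audio file is read from the hard-coded path inside loadModel()/createCompletionStream(), so the prompt is effectively the only per-call input.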