From 53967ca713805a8d3da1f6154894786e01a97614 Mon Sep 17 00:00:00 2001 From: Te993 <3923106166@qq.com> Date: Thu, 5 Dec 2024 11:57:59 +0800 Subject: [PATCH] add NexaAudioInference --- android/gradle.properties | 2 +- android/llama/src/main/cpp/CMakeLists.txt | 72 ++++---- android/llama/src/main/cpp/audio-android.cpp | 128 ++++++++++++-- .../main/java/com/nexa/NexaAudioInference.kt | 166 ++++-------------- 4 files changed, 187 insertions(+), 181 deletions(-) diff --git a/android/gradle.properties b/android/gradle.properties index 99d0efd3..9e3fc850 100644 --- a/android/gradle.properties +++ b/android/gradle.properties @@ -6,7 +6,7 @@ # http://www.gradle.org/docs/current/userguide/build_environment.html # Specifies the JVM arguments used for the daemon process. # The setting is particularly useful for tweaking memory settings. -org.gradle.jvmargs=-Xms1024m -Xmx4g -XX:MaxDirectMemorySize=3g -Dsun.nio.MaxDirectMemorySize=3g -Dfile.encoding=UTF-8 -Dorg.gradle.parallel=true -Dorg.gradle.workers.max=4 +org.gradle.jvmargs=-Xms2048m -Xmx5g -XX:MaxDirectMemorySize=4g -Dsun.nio.MaxDirectMemorySize=4g -Dfile.encoding=UTF-8 -Dorg.gradle.parallel=true -Dorg.gradle.workers.max=4 # When configured, Gradle will run in incubating parallel mode. # This option should only be used with decoupled projects. 
More details, visit # http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects diff --git a/android/llama/src/main/cpp/CMakeLists.txt b/android/llama/src/main/cpp/CMakeLists.txt index d90e1184..0278d9d2 100644 --- a/android/llama/src/main/cpp/CMakeLists.txt +++ b/android/llama/src/main/cpp/CMakeLists.txt @@ -11,52 +11,52 @@ FetchContent_Declare( GIT_TAG v3.11.3 ) -set(SOURCE_BASE_DIR ~/Desktop/nexa-repo/llama.cpp) +# set(SOURCE_BASE_DIR /nexa-ai/llama.cpp) ##### from local ##### +# FetchContent_Declare( +# llama +# SOURCE_DIR ${SOURCE_BASE_DIR} +# ) +# FetchContent_Declare( +# llava +# SOURCE_DIR ${SOURCE_BASE_DIR}/examples/llava +# ) +# FetchContent_Declare( +# omni_vlm +# SOURCE_DIR ${SOURCE_BASE_DIR}/examples/omni-vlm +# ) +# FetchContent_Declare( +# omni_audio +# SOURCE_DIR ${SOURCE_BASE_DIR}/examples/nexa-omni-audio +# ) + +##### from remote ##### + FetchContent_Declare( - llama - SOURCE_DIR ${SOURCE_BASE_DIR} + llama + GIT_REPOSITORY https://github.com/NexaAI/llama.cpp.git + GIT_TAG master ) + FetchContent_Declare( - llava - SOURCE_DIR ${SOURCE_BASE_DIR}/examples/llava + llava + GIT_REPOSITORY https://github.com/NexaAI/llama.cpp.git + GIT_TAG master + SOURCE_SUBDIR examples/llava ) FetchContent_Declare( - omni_vlm - SOURCE_DIR ${SOURCE_BASE_DIR}/examples/omni-vlm + omni_vlm + GIT_REPOSITORY https://github.com/NexaAI/llama.cpp.git + GIT_TAG master + SOURCE_SUBDIR examples/omni-vlm ) FetchContent_Declare( - omni_audio - SOURCE_DIR ${SOURCE_BASE_DIR}/examples/nexa-omni-audio + omni_audio + GIT_REPOSITORY https://github.com/NexaAI/llama.cpp.git + GIT_TAG T/dev + SOURCE_SUBDIR examples/nexa-omni-audio ) -##### from remote ##### -# -#FetchContent_Declare( -# llama -# GIT_REPOSITORY https://github.com/NexaAI/llama.cpp.git -# GIT_TAG master -#) -# -#FetchContent_Declare( -# llava -# GIT_REPOSITORY https://github.com/NexaAI/llama.cpp.git -# GIT_TAG master -# SOURCE_SUBDIR examples/llava -#) -#FetchContent_Declare( -# omni_vlm -# 
GIT_REPOSITORY https://github.com/NexaAI/llama.cpp.git -# GIT_TAG master -# SOURCE_SUBDIR examples/omni-vlm -#) -#FetchContent_Declare( -# omni_audio -# GIT_REPOSITORY https://github.com/NexaAI/llama.cpp.git -# GIT_TAG T/dev -# SOURCE_SUBDIR examples/nexa-omni-audio -#) - FetchContent_MakeAvailable(json llama llava omni_vlm omni_audio) add_library(${CMAKE_PROJECT_NAME} SHARED diff --git a/android/llama/src/main/cpp/audio-android.cpp b/android/llama/src/main/cpp/audio-android.cpp index 260cfb35..307766d7 100644 --- a/android/llama/src/main/cpp/audio-android.cpp +++ b/android/llama/src/main/cpp/audio-android.cpp @@ -19,8 +19,6 @@ extern bool is_valid_utf8(const char* str); extern std::string jstring2str(JNIEnv* env, jstring jstr); - -// 用于捕获输出的函数 void redirect_output_to_logcat(const char* tag, int fd) { char buffer[1024]; while (true) { @@ -31,7 +29,6 @@ void redirect_output_to_logcat(const char* tag, int fd) { } } -// 初始化重定向 void setup_redirect_stdout_stderr() { int stdout_pipe[2]; int stderr_pipe[2]; @@ -55,23 +52,122 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void* reserved) { return JNI_VERSION_1_6; } - extern "C" JNIEXPORT jlong JNICALL -Java_com_nexa_NexaAudioInference_init_1params(JNIEnv *env, jobject /* this */) { - const char* argv = "-t 1"; - char* nc_argv = const_cast<char*>(argv); - omni_context_params* ctx_params = new omni_context_params(); - omni_context_params_parse(argc, argv, ctx_params) +Java_com_nexa_NexaAudioInference_init_1ctx_1params(JNIEnv *env, jobject /* this */, jstring jmodel, jstring jprojector, jstring jaudio) { + const char* model = env->GetStringUTFChars(jmodel, nullptr); + const char* projector = env->GetStringUTFChars(jprojector, nullptr); + const char* audio = env->GetStringUTFChars(jaudio, nullptr); + const char* argv[] = {"-t", "1"}; + int argc = 1; + omni_context_params* ctx_params = new omni_context_params(omni_context_default_params()); + omni_context_params_parse(argc, const_cast<char**>(argv), *ctx_params); + ctx_params->model = 
model; + ctx_params->mmproj = projector; + ctx_params->file = audio; return reinterpret_cast<jlong>(ctx_params); } extern "C" JNIEXPORT jlong JNICALL -Java_com_nexa_NexaAudioInference_init_1params(JNIEnv *env, jobject /* this */) { - const char* argv = "-t 1"; - char* nc_argv = const_cast<char*>(argv); - omni_context_params* ctx_params = new omni_context_params(); - omni_context_params_parse(argc, argv, ctx_params) +Java_com_nexa_NexaAudioInference_init_1ctx(JNIEnv *env, jobject /* this */, jlong jctx_params) { + auto* ctx_params = reinterpret_cast<omni_context_params*>(jctx_params); + std::cout << ctx_params->n_gpu_layers << std::endl; + std::cout << ctx_params->model << std::endl; + std::cout << ctx_params->mmproj << std::endl; + std::cout << ctx_params->file << std::endl; + omni_context *ctx_omni = omni_init_context(*ctx_params); + return reinterpret_cast<jlong>(ctx_omni); +} - return reinterpret_cast<jlong>(ctx_params); -} \ No newline at end of file +extern "C" JNIEXPORT void JNICALL +Java_com_nexa_NexaAudioInference_free_1ctx(JNIEnv *env, jobject /* this */, jlong jctx_omni) { + auto* ctx_omni = reinterpret_cast<omni_context*>(jctx_omni); + omni_free(ctx_omni); +} + +extern "C" JNIEXPORT jlong JNICALL +Java_com_nexa_NexaAudioInference_init_1npast(JNIEnv *env, jobject /* this */) { + int* n_past = new int(0); + return reinterpret_cast<jlong>(n_past); +} + +extern "C" JNIEXPORT jlong JNICALL +Java_com_nexa_NexaAudioInference_init_1params(JNIEnv *env, jobject /* this */, jlong jctx_params) { + auto* ctx_params = reinterpret_cast<omni_context_params*>(jctx_params); + + if (ctx_params == nullptr) { + std::cerr << "Error: jctx_params is null!" << std::endl; + return 0; // Return 0 (null) if the context parameter is invalid. + } + + // Step 2: Call the function to extract omni_params from ctx_params. 
+ omni_params extracted_params; + try { + extracted_params = get_omni_params_from_context_params(*ctx_params); + } catch (const std::exception& e) { + std::cerr << "Error in get_omni_params_from_context_params: " << e.what() << std::endl; + return 0; // Return 0 (null) if an exception is thrown during the extraction. + } + + // Step 3: Allocate memory for omni_params and ensure it's successful. + omni_params* all_params = nullptr; + try { + all_params = new omni_params(extracted_params); + } catch (const std::bad_alloc& e) { + std::cerr << "Error: Failed to allocate memory for omni_params: " << e.what() << std::endl; + return 0; // Return 0 (null) if memory allocation fails. + } + + std::cout << " fname_inp size: " << all_params->whisper.fname_inp.size() << std::endl; + + // Step 4: Return the pointer to the newly allocated omni_params object. + std::cout << "all_params address: " << all_params << std::endl; + return reinterpret_cast<jlong>(all_params); +} + + +//val sampler = init_sampler(ctxPointer, allParamsPointer, prompt, audioPath, npastPointer) +extern "C" JNIEXPORT jlong JNICALL +Java_com_nexa_NexaAudioInference_init_1sampler(JNIEnv *env, jobject /* this */, jlong jctx_omni, jlong jctx_params, jstring jprompt, jstring jaudio_path, jlong jnpast) { + auto* n_past = reinterpret_cast<int*>(jnpast); + if (n_past == nullptr) { + std::cout << "n_past is null!" 
<< std::endl; + } + const char* prompt = env->GetStringUTFChars(jprompt, nullptr); + auto* all_params = reinterpret_cast<omni_params*>(jctx_params); + auto* ctx_omni = reinterpret_cast<omni_context*>(jctx_omni); + + ggml_tensor *audio_embed = omni_process_audio(ctx_omni, *all_params); + std::string system_prompt, user_prompt; + system_prompt = "<start_of_turn>user\nAudio 1: <|audio_bos|>"; + user_prompt = "<|audio_eos|>\n" + std::string(prompt) + "<end_of_turn>\n<start_of_turn>model\n"; + + eval_string(ctx_omni->ctx_llama, system_prompt.c_str(), all_params->gpt.n_batch, n_past, true); + omni_eval_audio_embed(ctx_omni->ctx_llama, audio_embed, all_params->gpt.n_batch, n_past); + eval_string(ctx_omni->ctx_llama, user_prompt.c_str(), all_params->gpt.n_batch, n_past, false); + + struct common_sampler * ctx_sampling = common_sampler_init(ctx_omni->model, all_params->gpt.sparams); + + return reinterpret_cast<jlong>(ctx_sampling); +} + + +extern "C" JNIEXPORT jstring JNICALL +Java_com_nexa_NexaAudioInference_sampler(JNIEnv *env, jobject /* this */, jlong jctx_omni, jlong jsampler, jlong jnpast) { + auto* ctx_omni = reinterpret_cast<omni_context*>(jctx_omni); + auto* sampler = reinterpret_cast<common_sampler*>(jsampler); + auto* n_past = reinterpret_cast<int*>(jnpast); + + const char * tmp = sample(sampler, ctx_omni->ctx_llama, n_past); + + jstring new_token = nullptr; + new_token = env->NewStringUTF(tmp); + return new_token; +} + + +extern "C" JNIEXPORT void JNICALL +Java_com_nexa_NexaAudioInference_free_1sampler(JNIEnv *env, jobject /* this */, jlong jsampler) { + auto* sampler = reinterpret_cast<common_sampler*>(jsampler); + common_sampler_free(sampler); +} diff --git a/android/llama/src/main/java/com/nexa/NexaAudioInference.kt b/android/llama/src/main/java/com/nexa/NexaAudioInference.kt index 8d27c81a..5d0a68c9 100644 --- a/android/llama/src/main/java/com/nexa/NexaAudioInference.kt +++ b/android/llama/src/main/java/com/nexa/NexaAudioInference.kt @@ -15,47 +15,23 @@ class NexaAudioInference( private var topP: Float = 0.95f ) { init { - System.loadLibrary("llama-android") + 
System.loadLibrary("audio-android") } - private var paramsPointer: Long = 0 - private var modelPointer: Long = 0 - private var llavaCtxPointer: Long = 0 - private var embedImagePointer: Long = 0 - private var samplerPointer: Long = 0 - private var nPastPointer: Long = 0 + private var ctxParamsPointer: Long = 0 + private var ctxPointer: Long = 0 private var generatedTokenNum: Int = 0 private var generatedText: String = "" private var isModelLoaded: Boolean = false - private var cachedTokenPointer: Long = 0 - private external fun init_params(modelPath: String, mmprojPath: String): Long - - private external fun update_params(params: Long, temperature: Float, topK: Int, topP: Float) - - private external fun load_model(params: Long): Long - - private external fun free_model(model: Long) - - private external fun llava_init_context(params: Long, model: Long): Long - - private external fun llava_ctx_free(ctx: Long) - - private external fun load_image(ctx: Long, params: Long, imagepath: String): Long - - private external fun llava_image_embed_free(llava_image_embed: Long) - - private external fun llava_eval(ctx: Long, params: Long, llava_image_embed: Long, prompt: String): Long - - private external fun llava_sampler_init(ctx: Long, params: Long): Long - - private external fun llava_sample(ctx: Long, sampler: Long, n_past: Long, cached_tokens: Long): String - - private external fun cached_token_init(): Long - - private external fun cached_token_free(cached_tokens: Long) - - private external fun llava_sample_free(sampler: Long) + private external fun init_ctx_params( model: String, project: String, audio_path:String): Long + private external fun init_ctx(ctxParamsPointer: Long): Long + private external fun free_ctx(ctxPointer: Long) + private external fun init_npast():Long + private external fun init_params(ctxParamsPointer: Long): Long + private external fun init_sampler(ctxPointer:Long, omniParamsPointer: Long, prompt: String, audioPath: String, npastPointer: Long): Long + 
private external fun sampler(ctxOmniPointer :Long , samplerPointer: Long, npastPointer: Long): String + private external fun free_sampler(samplerPointer: Long) @Synchronized fun loadModel() { @@ -63,51 +39,26 @@ class NexaAudioInference( throw RuntimeException("Model is already loaded.") } try { - paramsPointer = init_params(modelPath, mmprojPath = projectorPath) - modelPointer = load_model(paramsPointer) + val audiuo_path = "/storage/emulated/0/Android/data/ai.nexa.app_java/files/jfk.wav" + ctxParamsPointer = init_ctx_params(modelPath, projectorPath, audiuo_path) + ctxPointer = init_ctx(ctxParamsPointer) isModelLoaded = true } catch (e: Exception) { println(e) } catch (e: UnsatisfiedLinkError) { throw RuntimeException("Native method not found: ${e.message}") + }catch (e:Error){ + println(e) } } fun dispose() { - if(paramsPointer!=0L){ - paramsPointer = 0; + if(ctxParamsPointer!=0L){ + ctxParamsPointer = 0; } - if (modelPointer != 0L) { - free_model(modelPointer) - modelPointer = 0; - } - } - - private fun updateParams( - stopWords: List? = null, - temperature: Float? = null, - maxNewTokens: Int? = null, - topK: Int? = null, - topP: Float? 
= null - ) { - if(stopWords != null){ - this.stopWords = stopWords - } - if (temperature != null) { - this.temperature = temperature - } - if (maxNewTokens != null) { - this.maxNewTokens = maxNewTokens - } - if (topK != null) { - this.topK = topK; - } - if (topP != null) { - this.topP = topP - } - - if(paramsPointer != 0L) { - update_params(paramsPointer, this.temperature, this.topK, this.topP) + if (ctxPointer != 0L) { + free_ctx(ctxPointer) + ctxPointer = 0; } } @@ -137,68 +88,27 @@ class NexaAudioInference( if(!isModelLoaded){ throw RuntimeException("Model is not loaded.") } - - // Reset generation state at the start resetGeneration() - updateParams(stopWords, temperature, maxNewTokens, topK, topP) - -// val thread = Thread { + val imagePathToUse = imagePath ?: this@NexaAudioInference.imagePath - val imagePathToUse = imagePath ?: this@NexaAudioInference.imagePath - llavaCtxPointer = llava_init_context(paramsPointer, modelPointer) - embedImagePointer = load_image(llavaCtxPointer, paramsPointer, imagePathToUse) - nPastPointer = llava_eval(llavaCtxPointer, paramsPointer, embedImagePointer, prompt) - samplerPointer = llava_sampler_init(llavaCtxPointer, paramsPointer) - cachedTokenPointer = cached_token_init() + val audiuo_path = "/storage/emulated/0/Android/data/ai.nexa.app_java/files/jfk.wav" + val npastPointer = init_npast() + val allParamsPointer = init_params(ctxParamsPointer) + val sampler = init_sampler(ctxPointer, allParamsPointer, prompt, audiuo_path, npastPointer) - try { - while (true) { - val sampledText = llava_sample(llavaCtxPointer, samplerPointer, nPastPointer, cachedTokenPointer) - generatedTokenNum += 1 - generatedText += sampledText - if(shouldStop()){ - break - } - emit(sampledText) - print(sampledText) + try { + while (true) { + val sampledText = sampler(ctxPointer, sampler, npastPointer) + generatedTokenNum += 1 + generatedText += sampledText + if(shouldStop()){ + break } - } finally { - // Clean up resources and reset generation state - 
cleanupResources() - resetGeneration() + emit(sampledText) } - - println("This is a new thread!") - // Your thread logic here -// } -// thread.start() - }.flowOn(Dispatchers.IO) - - private fun cleanupResources() { - if(cachedTokenPointer != 0L){ - cached_token_free(cachedTokenPointer) - cachedTokenPointer = 0 - } - - if (samplerPointer != 0L) { - llava_sample_free(samplerPointer) - samplerPointer = 0 - } - - if (embedImagePointer != 0L) { - try { - llava_image_embed_free(embedImagePointer) - embedImagePointer = 0 - } catch (e: Exception) { - println(e) - } catch (e: Error) { - throw RuntimeException("Native method not found: ${e.message}") - } - } - - if (llavaCtxPointer != 0L) { - llava_ctx_free(llavaCtxPointer) - llavaCtxPointer = 0 + } finally { + resetGeneration() + free_sampler(sampler) } - } + }.flowOn(Dispatchers.IO) }