From 7e5bd6d87a5775851c9c2b3d252df860e6505b42 Mon Sep 17 00:00:00 2001 From: JoyboyBrian Date: Tue, 3 Dec 2024 13:58:06 -0800 Subject: [PATCH] add ttft/tps/decoding speed in android demo app --- .../java/ai/nexa/app_java/LlamaBridge.java | 251 ++++++++++-------- .../java/ai/nexa/app_java/MessageModal.java | 44 ++- .../ai/nexa/app_java/MessageRVAdapter.java | 21 +- .../ai/nexa/app_java/VlmModelManager.java | 9 +- .../app-java/src/main/res/layout/bot_msg.xml | 13 +- android/build.gradle.kts | 4 +- .../gradle/wrapper/gradle-wrapper.properties | 2 +- 7 files changed, 218 insertions(+), 126 deletions(-) diff --git a/android/app-java/src/main/java/ai/nexa/app_java/LlamaBridge.java b/android/app-java/src/main/java/ai/nexa/app_java/LlamaBridge.java index f398c28d..365beed0 100644 --- a/android/app-java/src/main/java/ai/nexa/app_java/LlamaBridge.java +++ b/android/app-java/src/main/java/ai/nexa/app_java/LlamaBridge.java @@ -2,6 +2,7 @@ import android.content.Context; import com.nexa.NexaOmniVlmInference; +import com.nexa.NexaVlmInference; import android.util.Log; import java.io.IOException; @@ -42,8 +43,11 @@ public class LlamaBridge { public interface InferenceCallback { void onStart(); + void onToken(String token); + void onComplete(String fullResponse); + void onError(String error); } @@ -74,79 +78,85 @@ public void loadModel() { // Create with default values for optional parameters nexaVlmInference = new NexaOmniVlmInference( - modelPath, // modelPath - projectorPath, // projectorPath - "", // imagePath (empty string as default) - new ArrayList<>(Arrays.asList("")), // stopWords (empty list) - DEFAULT_TEMPERATURE, // temperature - DEFAULT_MAX_TOKENS, // maxNewTokens - DEFAULT_TOP_K, // topK - DEFAULT_TOP_P // topP + modelPath, // modelPath + projectorPath, // projectorPath + "", // imagePath (empty string as default) + new ArrayList<>(Arrays.asList("")), // stopWords (empty list) + DEFAULT_TEMPERATURE, // temperature + DEFAULT_MAX_TOKENS, // maxNewTokens + DEFAULT_TOP_K, // topK + DEFAULT_TOP_P // topP ); nexaVlmInference.loadModel(); isModelLoaded = true; Log.d(TAG, "Model loaded successfully."); -// messageHandler.addMessage(new MessageModal("Model loaded successfully", "assistant", null)); + // messageHandler.addMessage(new MessageModal("Model loaded successfully", + // "assistant", null)); } catch (Exception e) { Log.e(TAG, "Failed to load model", e); - messageHandler.addMessage(new MessageModal("Error loading model: " + e.getMessage(), "assistant", null)); + messageHandler + .addMessage(new MessageModal("Error loading model: " + e.getMessage(), "assistant", null)); } }); } -// public void processMessage(String message, String imageUri, InferenceCallback callback) { -// if (!isModelLoaded) { -// callback.onError("Model not loaded yet"); -// return; -// } -// -// try { -// // Add user message first -// MessageModal userMessage = new MessageModal(message, "user", imageUri); -// messageHandler.addMessage(userMessage); -// -// // Create an initial empty assistant message -// MessageModal assistantMessage = new MessageModal("", "assistant", null); -// messageHandler.addMessage(assistantMessage); -// -// // Convert image URI to absolute path -// String imageAbsolutePath = imagePathHelper.getPathFromUri(imageUri); -// -// Flow flow = nexaVlmInference.createCompletionStream( -// message, -// imageAbsolutePath, -// new ArrayList<>(), -// DEFAULT_TEMPERATURE, -// DEFAULT_MAX_TOKENS, -// DEFAULT_TOP_K, -// DEFAULT_TOP_P -// ); -// -// if (flow != null) { -// CoroutineScope scope = CoroutineScopeKt.CoroutineScope(Dispatchers.getMain()); -// -// Job job = FlowKt.launchIn( -// FlowKt.onEach(flow, new Function2, Object>() { -// @Override -// public Object invoke(String token, Continuation continuation) { -// messageHandler.updateLastAssistantMessage(token); -// callback.onToken(token); -// return Unit.INSTANCE; -// } -// }), -// scope -// ); -// } else { -// messageHandler.finalizeLastAssistantMessage("Error: Failed to create completion stream"); -// callback.onError("Failed to create completion stream"); -// } -// } catch (Exception e) { -// Log.e(TAG, "Error processing message", e); -// messageHandler.finalizeLastAssistantMessage("Error: " + e.getMessage()); -// callback.onError(e.getMessage()); -// } -// } + // public void processMessage(String message, String imageUri, InferenceCallback + // callback) { + // if (!isModelLoaded) { + // callback.onError("Model not loaded yet"); + // return; + // } + // + // try { + // // Add user message first + // MessageModal userMessage = new MessageModal(message, "user", imageUri); + // messageHandler.addMessage(userMessage); + // + // // Create an initial empty assistant message + // MessageModal assistantMessage = new MessageModal("", "assistant", null); + // messageHandler.addMessage(assistantMessage); + // + // // Convert image URI to absolute path + // String imageAbsolutePath = imagePathHelper.getPathFromUri(imageUri); + // + // Flow flow = nexaVlmInference.createCompletionStream( + // message, + // imageAbsolutePath, + // new ArrayList<>(), + // DEFAULT_TEMPERATURE, + // DEFAULT_MAX_TOKENS, + // DEFAULT_TOP_K, + // DEFAULT_TOP_P + // ); + // + // if (flow != null) { + // CoroutineScope scope = + // CoroutineScopeKt.CoroutineScope(Dispatchers.getMain()); + // + // Job job = FlowKt.launchIn( + // FlowKt.onEach(flow, new Function2, + // Object>() { + // @Override + // public Object invoke(String token, Continuation continuation) { + // messageHandler.updateLastAssistantMessage(token); + // callback.onToken(token); + // return Unit.INSTANCE; + // } + // }), + // scope + // ); + // } else { + // messageHandler.finalizeLastAssistantMessage("Error: Failed to create + // completion stream"); + // callback.onError("Failed to create completion stream"); + // } + // } catch (Exception e) { + // Log.e(TAG, "Error processing message", e); + // messageHandler.finalizeLastAssistantMessage("Error: " + e.getMessage()); + // callback.onError(e.getMessage()); + // } + // } public void processMessage(String message, String imageUri, InferenceCallback callback) { if (!isModelLoaded) { @@ -167,6 +177,10 @@ public void processMessage(String message, String imageUri, InferenceCallback ca messageHandler.addMessage(assistantMessage); try { + final long startTime = System.currentTimeMillis(); + final long[] firstTokenTime = { 0 }; + final int[] tokenCount = { 0 }; + Flow flow = nexaVlmInference.createCompletionStream( message, imagePath, @@ -174,8 +188,7 @@ public void processMessage(String message, String imageUri, InferenceCallback ca DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS, DEFAULT_TOP_K, - DEFAULT_TOP_P - ); + DEFAULT_TOP_P); callback.onStart(); StringBuilder fullResponse = new StringBuilder(); @@ -188,15 +201,28 @@ public void processMessage(String message, String imageUri, InferenceCallback ca flow.collect(new FlowCollector() { @Override public Object emit(String token, Continuation continuation) { + if (tokenCount[0] == 0) { + firstTokenTime[0] = System.currentTimeMillis() - startTime; + } + tokenCount[0]++; fullResponse.append(token); callback.onToken(token); return Unit.INSTANCE; } }, continuation); + + long totalTime = System.currentTimeMillis() - startTime; + double tokensPerSecond = tokenCount[0] / (totalTime / 1000.0); + long decodingTime = totalTime - firstTokenTime[0]; + double decodingSpeed = (tokenCount[0] - 1) / (decodingTime / 1000.0); + assistantMessage.setTtft(firstTokenTime[0]); + assistantMessage.setTps(tokensPerSecond); + assistantMessage.setDecodingSpeed(decodingSpeed); + assistantMessage.setTotalTokens(tokenCount[0]); + callback.onComplete(fullResponse.toString()); return Unit.INSTANCE; - } - ); + }); collectJob.invokeOnCompletion(new Function1() { @Override @@ -218,53 +244,52 @@ public void cleanup() { flowHelper.cancel(); } -// public void processMessageWithParams( -// String message, -// String imageUri, -// float temperature, -// int maxTokens, -// int topK, -// float topP, -// InferenceCallback callback) { -// -// if (!isModelLoaded) { -// callback.onError("Model not loaded yet"); -// return; -// } -// -// executor.execute(() -> { -// StringBuilder fullResponse = new StringBuilder(); -// try { -// callback.onStart(); -// -// Flow completionStream = nexaVlmInference.createCompletionStream( -// message, -// imageUri, -// new ArrayList<>(), -// temperature, -// maxTokens, -// topK, -// topP -// ); -// -// completionStream.collect(new FlowCollector() { -// @Override -// public Object emit(String value, Continuation continuation) { -// fullResponse.append(value); -// callback.onToken(value); -// return Unit.INSTANCE; -// } -// }); -// -// callback.onComplete(fullResponse.toString()); -// -// } catch (Exception e) { -// Log.e(TAG, "Inference failed", e); -// callback.onError(e.getMessage()); -// } -// }); -// } - + // public void processMessageWithParams( + // String message, + // String imageUri, + // float temperature, + // int maxTokens, + // int topK, + // float topP, + // InferenceCallback callback) { + // + // if (!isModelLoaded) { + // callback.onError("Model not loaded yet"); + // return; + // } + // + // executor.execute(() -> { + // StringBuilder fullResponse = new StringBuilder(); + // try { + // callback.onStart(); + // + // Flow completionStream = nexaVlmInference.createCompletionStream( + // message, + // imageUri, + // new ArrayList<>(), + // temperature, + // maxTokens, + // topK, + // topP + // ); + // + // completionStream.collect(new FlowCollector() { + // @Override + // public Object emit(String value, Continuation continuation) { + // fullResponse.append(value); + // callback.onToken(value); + // return Unit.INSTANCE; + // } + // }); + // + // callback.onComplete(fullResponse.toString()); + // + // } catch (Exception e) { + // Log.e(TAG, "Inference failed", e); + // callback.onError(e.getMessage()); + // } + // }); + // } public void shutdown() { if (nexaVlmInference != null) { diff --git a/android/app-java/src/main/java/ai/nexa/app_java/MessageModal.java b/android/app-java/src/main/java/ai/nexa/app_java/MessageModal.java index 1e60921b..dab423c6 100644 --- a/android/app-java/src/main/java/ai/nexa/app_java/MessageModal.java +++ b/android/app-java/src/main/java/ai/nexa/app_java/MessageModal.java @@ -2,19 +2,24 @@ public class MessageModal { - private String message; private String sender; - private String imageUri; + private long ttft; + private double tps; + private double decodingSpeed; + private int totalTokens; public MessageModal(String message, String sender, String imageUri) { this.message = message; this.sender = sender; this.imageUri = imageUri; + this.ttft = 0; + this.tps = 0.0; + this.decodingSpeed = 0.0; + this.totalTokens = 0; } - public String getMessage() { return message; } @@ -38,5 +43,36 @@ public String getImageUri() { public void setImageUri(String imageUri) { this.imageUri = imageUri; } -} + public long getTtft() { + return ttft; + } + + public void setTtft(long ttft) { + this.ttft = ttft; + } + + public double getTps() { + return tps; + } + + public void setTps(double tps) { + this.tps = tps; + } + + public double getDecodingSpeed() { + return decodingSpeed; + } + + public void setDecodingSpeed(double decodingSpeed) { + this.decodingSpeed = decodingSpeed; + } + + public int getTotalTokens() { + return totalTokens; + } + + public void setTotalTokens(int totalTokens) { + this.totalTokens = totalTokens; + } +} diff --git a/android/app-java/src/main/java/ai/nexa/app_java/MessageRVAdapter.java b/android/app-java/src/main/java/ai/nexa/app_java/MessageRVAdapter.java index 90977681..0ce8cec6 100644 --- a/android/app-java/src/main/java/ai/nexa/app_java/MessageRVAdapter.java +++ b/android/app-java/src/main/java/ai/nexa/app_java/MessageRVAdapter.java @@ -58,7 +58,24 @@ public void onBindViewHolder(@NonNull RecyclerView.ViewHolder holder, int positi } break; case "bot": - ((BotViewHolder) holder).botTV.setText(modal.getMessage()); + BotViewHolder botViewHolder = (BotViewHolder) holder; + botViewHolder.botTV.setText(modal.getMessage()); + if (modal.getTtft() > 0) { + double ttftInSeconds = modal.getTtft() / 1000.0; + String metrics = String.format( + "Total Tokens: %d\n" + + "TTFT: %.2fs\n" + + "TPS: %.2f tokens/s\n" + + "Decoding: %.2f tokens/s", + modal.getTotalTokens(), + ttftInSeconds, + modal.getTps(), + modal.getDecodingSpeed()); + botViewHolder.metricsTV.setText(metrics); + botViewHolder.metricsTV.setVisibility(View.VISIBLE); + } else { + botViewHolder.metricsTV.setVisibility(View.GONE); + } break; } } @@ -93,10 +110,12 @@ public UserViewHolder(@NonNull View itemView) { public static class BotViewHolder extends RecyclerView.ViewHolder { TextView botTV; + TextView metricsTV; public BotViewHolder(@NonNull View itemView) { super(itemView); botTV = itemView.findViewById(R.id.idTVBot); + metricsTV = itemView.findViewById(R.id.idTVMetrics); } } } diff --git a/android/app-java/src/main/java/ai/nexa/app_java/VlmModelManager.java b/android/app-java/src/main/java/ai/nexa/app_java/VlmModelManager.java index 74c82ad5..1bb36bab 100644 --- a/android/app-java/src/main/java/ai/nexa/app_java/VlmModelManager.java +++ b/android/app-java/src/main/java/ai/nexa/app_java/VlmModelManager.java @@ -10,11 +10,14 @@ public class VlmModelManager { private static final String TAG = "LlamaBridge"; private static final String MODELS_DIR = "models"; -// private static final String MODEL_TEXT_FILENAME = "nanollava-text-model-q4_0.gguf"; -// private static final String MODEL_MMPROJ_FILENAME = "nanollava-mmproj-f16.gguf"; + // For nanollava +// private static final String MODEL_TEXT_FILENAME = "nanollava-model-q8_0.gguf"; +// private static final String MODEL_MMPROJ_FILENAME = "nanollava-projector-fp16.gguf"; +// + // For Omnivision private static final String MODEL_TEXT_FILENAME = "model-q8_0.gguf"; - private static final String MODEL_MMPROJ_FILENAME = "projector-q8_0.gguf"; + private static final String MODEL_MMPROJ_FILENAME = "projector-fp16.gguf"; // private static final String MODEL_TEXT_FILENAME = "nano-vlm-instruct-llm-F16.gguf"; // private static final String MODEL_MMPROJ_FILENAME = "nano-vlm-instruct-mmproj-F16.gguf"; diff --git a/android/app-java/src/main/res/layout/bot_msg.xml b/android/app-java/src/main/res/layout/bot_msg.xml index 5ee58e1d..1a6c3b05 100644 --- a/android/app-java/src/main/res/layout/bot_msg.xml +++ b/android/app-java/src/main/res/layout/bot_msg.xml @@ -15,7 +15,7 @@ android:layout_height="wrap_content" android:minHeight="35dp" android:background="@color/black" - android:orientation="horizontal" + android:orientation="vertical" android:padding="5dp"> + + - + \ No newline at end of file diff --git a/android/build.gradle.kts b/android/build.gradle.kts index 53401400..a715e2fc 100644 --- a/android/build.gradle.kts +++ b/android/build.gradle.kts @@ -1,8 +1,8 @@ // Top-level build file where you can add configuration options common to all sub-projects/modules. plugins { - id("com.android.application") version "8.2.0" apply false + id("com.android.application") version "8.7.2" apply false id("org.jetbrains.kotlin.android") version "1.9.0" apply false - id("com.android.library") version "8.2.0" apply false + id("com.android.library") version "8.7.2" apply false } buildscript { val kotlin_version by extra("1.9.20") diff --git a/android/gradle/wrapper/gradle-wrapper.properties b/android/gradle/wrapper/gradle-wrapper.properties index a3958c14..7ed9ed7a 100644 --- a/android/gradle/wrapper/gradle-wrapper.properties +++ b/android/gradle/wrapper/gradle-wrapper.properties @@ -1,6 +1,6 @@ #Thu Dec 21 14:31:09 AEDT 2023 distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-8.2-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-8.9-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists