Skip to content

Commit

Permalink
Merge pull request #298 from NexaAI/brian/perf-upgrade
Browse files Browse the repository at this point in the history
add ttft/tps/decoding speed in android demo app
  • Loading branch information
zhiyuan8 authored Dec 4, 2024
2 parents 2c04ef6 + 7e5bd6d commit 7230e2a
Show file tree
Hide file tree
Showing 7 changed files with 218 additions and 126 deletions.
251 changes: 138 additions & 113 deletions android/app-java/src/main/java/ai/nexa/app_java/LlamaBridge.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import android.content.Context;
import com.nexa.NexaOmniVlmInference;
import com.nexa.NexaVlmInference;
import android.util.Log;

import java.io.IOException;
Expand Down Expand Up @@ -42,8 +43,11 @@ public class LlamaBridge {

/**
 * Callback for streaming inference results.
 * NOTE(review): invocations appear to come from a background executor —
 * implementations that touch UI should post to the main thread; confirm.
 */
public interface InferenceCallback {
/** Invoked once, before token streaming begins. */
void onStart();

/** Invoked for each generated token as it streams in. */
void onToken(String token);

/** Invoked after streaming finishes, with the full concatenated response. */
void onComplete(String fullResponse);

/** Invoked when inference fails; {@code error} is a human-readable message. */
void onError(String error);
}

Expand Down Expand Up @@ -74,79 +78,85 @@ public void loadModel() {

// Create with default values for optional parameters
nexaVlmInference = new NexaOmniVlmInference(
modelPath, // modelPath
projectorPath, // projectorPath
"", // imagePath (empty string as default)
new ArrayList<>(Arrays.asList("</s>")), // stopWords (empty list)
DEFAULT_TEMPERATURE, // temperature
DEFAULT_MAX_TOKENS, // maxNewTokens
DEFAULT_TOP_K, // topK
DEFAULT_TOP_P // topP
modelPath, // modelPath
projectorPath, // projectorPath
"", // imagePath (empty string as default)
new ArrayList<>(Arrays.asList("</s>")), // stopWords (empty list)
DEFAULT_TEMPERATURE, // temperature
DEFAULT_MAX_TOKENS, // maxNewTokens
DEFAULT_TOP_K, // topK
DEFAULT_TOP_P // topP
);
nexaVlmInference.loadModel();
isModelLoaded = true;

Log.d(TAG, "Model loaded successfully.");
// messageHandler.addMessage(new MessageModal("Model loaded successfully", "assistant", null));
// messageHandler.addMessage(new MessageModal("Model loaded successfully",
// "assistant", null));
} catch (Exception e) {
Log.e(TAG, "Failed to load model", e);
messageHandler.addMessage(new MessageModal("Error loading model: " + e.getMessage(), "assistant", null));
messageHandler
.addMessage(new MessageModal("Error loading model: " + e.getMessage(), "assistant", null));
}
});
}

// public void processMessage(String message, String imageUri, InferenceCallback callback) {
// if (!isModelLoaded) {
// callback.onError("Model not loaded yet");
// return;
// }
//
// try {
// // Add user message first
// MessageModal userMessage = new MessageModal(message, "user", imageUri);
// messageHandler.addMessage(userMessage);
//
// // Create an initial empty assistant message
// MessageModal assistantMessage = new MessageModal("", "assistant", null);
// messageHandler.addMessage(assistantMessage);
//
// // Convert image URI to absolute path
// String imageAbsolutePath = imagePathHelper.getPathFromUri(imageUri);
//
// Flow<String> flow = nexaVlmInference.createCompletionStream(
// message,
// imageAbsolutePath,
// new ArrayList<>(),
// DEFAULT_TEMPERATURE,
// DEFAULT_MAX_TOKENS,
// DEFAULT_TOP_K,
// DEFAULT_TOP_P
// );
//
// if (flow != null) {
// CoroutineScope scope = CoroutineScopeKt.CoroutineScope(Dispatchers.getMain());
//
// Job job = FlowKt.launchIn(
// FlowKt.onEach(flow, new Function2<String, Continuation<? super Unit>, Object>() {
// @Override
// public Object invoke(String token, Continuation<? super Unit> continuation) {
// messageHandler.updateLastAssistantMessage(token);
// callback.onToken(token);
// return Unit.INSTANCE;
// }
// }),
// scope
// );
// } else {
// messageHandler.finalizeLastAssistantMessage("Error: Failed to create completion stream");
// callback.onError("Failed to create completion stream");
// }
// } catch (Exception e) {
// Log.e(TAG, "Error processing message", e);
// messageHandler.finalizeLastAssistantMessage("Error: " + e.getMessage());
// callback.onError(e.getMessage());
// }
// }
// public void processMessage(String message, String imageUri, InferenceCallback
// callback) {
// if (!isModelLoaded) {
// callback.onError("Model not loaded yet");
// return;
// }
//
// try {
// // Add user message first
// MessageModal userMessage = new MessageModal(message, "user", imageUri);
// messageHandler.addMessage(userMessage);
//
// // Create an initial empty assistant message
// MessageModal assistantMessage = new MessageModal("", "assistant", null);
// messageHandler.addMessage(assistantMessage);
//
// // Convert image URI to absolute path
// String imageAbsolutePath = imagePathHelper.getPathFromUri(imageUri);
//
// Flow<String> flow = nexaVlmInference.createCompletionStream(
// message,
// imageAbsolutePath,
// new ArrayList<>(),
// DEFAULT_TEMPERATURE,
// DEFAULT_MAX_TOKENS,
// DEFAULT_TOP_K,
// DEFAULT_TOP_P
// );
//
// if (flow != null) {
// CoroutineScope scope =
// CoroutineScopeKt.CoroutineScope(Dispatchers.getMain());
//
// Job job = FlowKt.launchIn(
// FlowKt.onEach(flow, new Function2<String, Continuation<? super Unit>,
// Object>() {
// @Override
// public Object invoke(String token, Continuation<? super Unit> continuation) {
// messageHandler.updateLastAssistantMessage(token);
// callback.onToken(token);
// return Unit.INSTANCE;
// }
// }),
// scope
// );
// } else {
// messageHandler.finalizeLastAssistantMessage("Error: Failed to create
// completion stream");
// callback.onError("Failed to create completion stream");
// }
// } catch (Exception e) {
// Log.e(TAG, "Error processing message", e);
// messageHandler.finalizeLastAssistantMessage("Error: " + e.getMessage());
// callback.onError(e.getMessage());
// }
// }

public void processMessage(String message, String imageUri, InferenceCallback callback) {
if (!isModelLoaded) {
Expand All @@ -167,15 +177,18 @@ public void processMessage(String message, String imageUri, InferenceCallback ca
messageHandler.addMessage(assistantMessage);

try {
final long startTime = System.currentTimeMillis();
final long[] firstTokenTime = { 0 };
final int[] tokenCount = { 0 };

Flow<String> flow = nexaVlmInference.createCompletionStream(
message,
imagePath,
new ArrayList<>(Arrays.asList("</s>")),
DEFAULT_TEMPERATURE,
DEFAULT_MAX_TOKENS,
DEFAULT_TOP_K,
DEFAULT_TOP_P
);
DEFAULT_TOP_P);

callback.onStart();
StringBuilder fullResponse = new StringBuilder();
Expand All @@ -188,15 +201,28 @@ public void processMessage(String message, String imageUri, InferenceCallback ca
flow.collect(new FlowCollector<String>() {
@Override
public Object emit(String token, Continuation<? super Unit> continuation) {
if (tokenCount[0] == 0) {
firstTokenTime[0] = System.currentTimeMillis() - startTime;
}
tokenCount[0]++;
fullResponse.append(token);
callback.onToken(token);
return Unit.INSTANCE;
}
}, continuation);

long totalTime = System.currentTimeMillis() - startTime;
double tokensPerSecond = tokenCount[0] / (totalTime / 1000.0);
long decodingTime = totalTime - firstTokenTime[0];
double decodingSpeed = (tokenCount[0] - 1) / (decodingTime / 1000.0);
assistantMessage.setTtft(firstTokenTime[0]);
assistantMessage.setTps(tokensPerSecond);
assistantMessage.setDecodingSpeed(decodingSpeed);
assistantMessage.setTotalTokens(tokenCount[0]);

callback.onComplete(fullResponse.toString());
return Unit.INSTANCE;
}
);
});

collectJob.invokeOnCompletion(new Function1<Throwable, Unit>() {
@Override
Expand All @@ -218,53 +244,52 @@ public void cleanup() {
flowHelper.cancel();
}

// public void processMessageWithParams(
// String message,
// String imageUri,
// float temperature,
// int maxTokens,
// int topK,
// float topP,
// InferenceCallback callback) {
//
// if (!isModelLoaded) {
// callback.onError("Model not loaded yet");
// return;
// }
//
// executor.execute(() -> {
// StringBuilder fullResponse = new StringBuilder();
// try {
// callback.onStart();
//
// Flow<String> completionStream = nexaVlmInference.createCompletionStream(
// message,
// imageUri,
// new ArrayList<>(),
// temperature,
// maxTokens,
// topK,
// topP
// );
//
// completionStream.collect(new FlowCollector<String>() {
// @Override
// public Object emit(String value, Continuation<? super Unit> continuation) {
// fullResponse.append(value);
// callback.onToken(value);
// return Unit.INSTANCE;
// }
// });
//
// callback.onComplete(fullResponse.toString());
//
// } catch (Exception e) {
// Log.e(TAG, "Inference failed", e);
// callback.onError(e.getMessage());
// }
// });
// }

// public void processMessageWithParams(
// String message,
// String imageUri,
// float temperature,
// int maxTokens,
// int topK,
// float topP,
// InferenceCallback callback) {
//
// if (!isModelLoaded) {
// callback.onError("Model not loaded yet");
// return;
// }
//
// executor.execute(() -> {
// StringBuilder fullResponse = new StringBuilder();
// try {
// callback.onStart();
//
// Flow<String> completionStream = nexaVlmInference.createCompletionStream(
// message,
// imageUri,
// new ArrayList<>(),
// temperature,
// maxTokens,
// topK,
// topP
// );
//
// completionStream.collect(new FlowCollector<String>() {
// @Override
// public Object emit(String value, Continuation<? super Unit> continuation) {
// fullResponse.append(value);
// callback.onToken(value);
// return Unit.INSTANCE;
// }
// });
//
// callback.onComplete(fullResponse.toString());
//
// } catch (Exception e) {
// Log.e(TAG, "Inference failed", e);
// callback.onError(e.getMessage());
// }
// });
// }

public void shutdown() {
if (nexaVlmInference != null) {
Expand Down
44 changes: 40 additions & 4 deletions android/app-java/src/main/java/ai/nexa/app_java/MessageModal.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,24 @@

public class MessageModal {


// Text content of the chat message.
private String message;
// Role of the originator, e.g. "user" or "assistant".
private String sender;

// URI of an attached image; may be null when the message has no image.
private String imageUri;
// Time to first token, in milliseconds (0 until measured).
private long ttft;
// Overall throughput in tokens per second, first token included.
private double tps;
// Decoding speed in tokens per second, excluding the time to first token.
private double decodingSpeed;
// Total number of tokens generated for this message.
private int totalTokens;

/**
 * Creates a chat message.
 *
 * @param message  text content of the message
 * @param sender   originator role (e.g. "user" or "assistant")
 * @param imageUri URI of an attached image; may be null
 */
public MessageModal(String message, String sender, String imageUri) {
    this.message = message;
    this.sender = sender;
    this.imageUri = imageUri;
    // Performance metrics start zeroed; they are filled in once inference
    // completes (zero means "not measured yet").
    this.ttft = 0L;
    this.tps = 0.0d;
    this.decodingSpeed = 0.0d;
    this.totalTokens = 0;
}


/** Returns the text content of this message. */
public String getMessage() {
return message;
}
Expand All @@ -38,5 +43,36 @@ public String getImageUri() {
/** Sets the URI of the image attached to this message (may be null). */
public void setImageUri(String imageUri) {
this.imageUri = imageUri;
}
}

/** Returns the time to first token, in milliseconds (0 if not measured). */
public long getTtft() {
return ttft;
}

/** Sets the time to first token, in milliseconds. */
public void setTtft(long ttft) {
this.ttft = ttft;
}

/** Returns the overall throughput, in tokens per second. */
public double getTps() {
return tps;
}

/** Sets the overall throughput, in tokens per second. */
public void setTps(double tps) {
this.tps = tps;
}

/** Returns the decoding speed in tokens/s, excluding time to first token. */
public double getDecodingSpeed() {
return decodingSpeed;
}

/** Sets the decoding speed, in tokens per second. */
public void setDecodingSpeed(double decodingSpeed) {
this.decodingSpeed = decodingSpeed;
}

/** Returns the total number of tokens generated for this message. */
public int getTotalTokens() {
return totalTokens;
}

/** Sets the total number of tokens generated for this message. */
public void setTotalTokens(int totalTokens) {
this.totalTokens = totalTokens;
}
}
Loading

0 comments on commit 7230e2a

Please sign in to comment.