Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ttft/tps/decoding speed in android demo app #298

Merged
merged 1 commit into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 138 additions & 113 deletions android/app-java/src/main/java/ai/nexa/app_java/LlamaBridge.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import android.content.Context;
import com.nexa.NexaOmniVlmInference;
import com.nexa.NexaVlmInference;
import android.util.Log;

import java.io.IOException;
Expand Down Expand Up @@ -42,8 +43,11 @@ public class LlamaBridge {

/**
 * Callback contract for streaming inference results from the model.
 * Implementations receive lifecycle events in order: onStart, zero or more
 * onToken calls, then exactly one of onComplete or onError.
 */
public interface InferenceCallback {
/** Invoked once when inference begins, before any tokens are produced. */
void onStart();

/** Invoked for each generated token as it is streamed from the model. */
void onToken(String token);

/** Invoked once when generation finishes; receives the concatenated response. */
void onComplete(String fullResponse);

/** Invoked once if inference fails; receives a human-readable error message. */
void onError(String error);
}

Expand Down Expand Up @@ -74,79 +78,85 @@ public void loadModel() {

// Create with default values for optional parameters
nexaVlmInference = new NexaOmniVlmInference(
modelPath, // modelPath
projectorPath, // projectorPath
"", // imagePath (empty string as default)
new ArrayList<>(Arrays.asList("</s>")), // stopWords (empty list)
DEFAULT_TEMPERATURE, // temperature
DEFAULT_MAX_TOKENS, // maxNewTokens
DEFAULT_TOP_K, // topK
DEFAULT_TOP_P // topP
modelPath, // modelPath
projectorPath, // projectorPath
"", // imagePath (empty string as default)
new ArrayList<>(Arrays.asList("</s>")), // stopWords (empty list)
DEFAULT_TEMPERATURE, // temperature
DEFAULT_MAX_TOKENS, // maxNewTokens
DEFAULT_TOP_K, // topK
DEFAULT_TOP_P // topP
);
nexaVlmInference.loadModel();
isModelLoaded = true;

Log.d(TAG, "Model loaded successfully.");
// messageHandler.addMessage(new MessageModal("Model loaded successfully", "assistant", null));
// messageHandler.addMessage(new MessageModal("Model loaded successfully",
// "assistant", null));
} catch (Exception e) {
Log.e(TAG, "Failed to load model", e);
messageHandler.addMessage(new MessageModal("Error loading model: " + e.getMessage(), "assistant", null));
messageHandler
.addMessage(new MessageModal("Error loading model: " + e.getMessage(), "assistant", null));
}
});
}

// public void processMessage(String message, String imageUri, InferenceCallback callback) {
// if (!isModelLoaded) {
// callback.onError("Model not loaded yet");
// return;
// }
//
// try {
// // Add user message first
// MessageModal userMessage = new MessageModal(message, "user", imageUri);
// messageHandler.addMessage(userMessage);
//
// // Create an initial empty assistant message
// MessageModal assistantMessage = new MessageModal("", "assistant", null);
// messageHandler.addMessage(assistantMessage);
//
// // Convert image URI to absolute path
// String imageAbsolutePath = imagePathHelper.getPathFromUri(imageUri);
//
// Flow<String> flow = nexaVlmInference.createCompletionStream(
// message,
// imageAbsolutePath,
// new ArrayList<>(),
// DEFAULT_TEMPERATURE,
// DEFAULT_MAX_TOKENS,
// DEFAULT_TOP_K,
// DEFAULT_TOP_P
// );
//
// if (flow != null) {
// CoroutineScope scope = CoroutineScopeKt.CoroutineScope(Dispatchers.getMain());
//
// Job job = FlowKt.launchIn(
// FlowKt.onEach(flow, new Function2<String, Continuation<? super Unit>, Object>() {
// @Override
// public Object invoke(String token, Continuation<? super Unit> continuation) {
// messageHandler.updateLastAssistantMessage(token);
// callback.onToken(token);
// return Unit.INSTANCE;
// }
// }),
// scope
// );
// } else {
// messageHandler.finalizeLastAssistantMessage("Error: Failed to create completion stream");
// callback.onError("Failed to create completion stream");
// }
// } catch (Exception e) {
// Log.e(TAG, "Error processing message", e);
// messageHandler.finalizeLastAssistantMessage("Error: " + e.getMessage());
// callback.onError(e.getMessage());
// }
// }
// public void processMessage(String message, String imageUri, InferenceCallback
// callback) {
// if (!isModelLoaded) {
// callback.onError("Model not loaded yet");
// return;
// }
//
// try {
// // Add user message first
// MessageModal userMessage = new MessageModal(message, "user", imageUri);
// messageHandler.addMessage(userMessage);
//
// // Create an initial empty assistant message
// MessageModal assistantMessage = new MessageModal("", "assistant", null);
// messageHandler.addMessage(assistantMessage);
//
// // Convert image URI to absolute path
// String imageAbsolutePath = imagePathHelper.getPathFromUri(imageUri);
//
// Flow<String> flow = nexaVlmInference.createCompletionStream(
// message,
// imageAbsolutePath,
// new ArrayList<>(),
// DEFAULT_TEMPERATURE,
// DEFAULT_MAX_TOKENS,
// DEFAULT_TOP_K,
// DEFAULT_TOP_P
// );
//
// if (flow != null) {
// CoroutineScope scope =
// CoroutineScopeKt.CoroutineScope(Dispatchers.getMain());
//
// Job job = FlowKt.launchIn(
// FlowKt.onEach(flow, new Function2<String, Continuation<? super Unit>,
// Object>() {
// @Override
// public Object invoke(String token, Continuation<? super Unit> continuation) {
// messageHandler.updateLastAssistantMessage(token);
// callback.onToken(token);
// return Unit.INSTANCE;
// }
// }),
// scope
// );
// } else {
// messageHandler.finalizeLastAssistantMessage("Error: Failed to create
// completion stream");
// callback.onError("Failed to create completion stream");
// }
// } catch (Exception e) {
// Log.e(TAG, "Error processing message", e);
// messageHandler.finalizeLastAssistantMessage("Error: " + e.getMessage());
// callback.onError(e.getMessage());
// }
// }

public void processMessage(String message, String imageUri, InferenceCallback callback) {
if (!isModelLoaded) {
Expand All @@ -167,15 +177,18 @@ public void processMessage(String message, String imageUri, InferenceCallback ca
messageHandler.addMessage(assistantMessage);

try {
final long startTime = System.currentTimeMillis();
final long[] firstTokenTime = { 0 };
final int[] tokenCount = { 0 };

Flow<String> flow = nexaVlmInference.createCompletionStream(
message,
imagePath,
new ArrayList<>(Arrays.asList("</s>")),
DEFAULT_TEMPERATURE,
DEFAULT_MAX_TOKENS,
DEFAULT_TOP_K,
DEFAULT_TOP_P
);
DEFAULT_TOP_P);

callback.onStart();
StringBuilder fullResponse = new StringBuilder();
Expand All @@ -188,15 +201,28 @@ public void processMessage(String message, String imageUri, InferenceCallback ca
flow.collect(new FlowCollector<String>() {
@Override
public Object emit(String token, Continuation<? super Unit> continuation) {
if (tokenCount[0] == 0) {
firstTokenTime[0] = System.currentTimeMillis() - startTime;
}
tokenCount[0]++;
fullResponse.append(token);
callback.onToken(token);
return Unit.INSTANCE;
}
}, continuation);

long totalTime = System.currentTimeMillis() - startTime;
double tokensPerSecond = tokenCount[0] / (totalTime / 1000.0);
long decodingTime = totalTime - firstTokenTime[0];
double decodingSpeed = (tokenCount[0] - 1) / (decodingTime / 1000.0);
assistantMessage.setTtft(firstTokenTime[0]);
assistantMessage.setTps(tokensPerSecond);
assistantMessage.setDecodingSpeed(decodingSpeed);
assistantMessage.setTotalTokens(tokenCount[0]);

callback.onComplete(fullResponse.toString());
return Unit.INSTANCE;
}
);
});

collectJob.invokeOnCompletion(new Function1<Throwable, Unit>() {
@Override
Expand All @@ -218,53 +244,52 @@ public void cleanup() {
flowHelper.cancel();
}

// public void processMessageWithParams(
// String message,
// String imageUri,
// float temperature,
// int maxTokens,
// int topK,
// float topP,
// InferenceCallback callback) {
//
// if (!isModelLoaded) {
// callback.onError("Model not loaded yet");
// return;
// }
//
// executor.execute(() -> {
// StringBuilder fullResponse = new StringBuilder();
// try {
// callback.onStart();
//
// Flow<String> completionStream = nexaVlmInference.createCompletionStream(
// message,
// imageUri,
// new ArrayList<>(),
// temperature,
// maxTokens,
// topK,
// topP
// );
//
// completionStream.collect(new FlowCollector<String>() {
// @Override
// public Object emit(String value, Continuation<? super Unit> continuation) {
// fullResponse.append(value);
// callback.onToken(value);
// return Unit.INSTANCE;
// }
// });
//
// callback.onComplete(fullResponse.toString());
//
// } catch (Exception e) {
// Log.e(TAG, "Inference failed", e);
// callback.onError(e.getMessage());
// }
// });
// }

// public void processMessageWithParams(
// String message,
// String imageUri,
// float temperature,
// int maxTokens,
// int topK,
// float topP,
// InferenceCallback callback) {
//
// if (!isModelLoaded) {
// callback.onError("Model not loaded yet");
// return;
// }
//
// executor.execute(() -> {
// StringBuilder fullResponse = new StringBuilder();
// try {
// callback.onStart();
//
// Flow<String> completionStream = nexaVlmInference.createCompletionStream(
// message,
// imageUri,
// new ArrayList<>(),
// temperature,
// maxTokens,
// topK,
// topP
// );
//
// completionStream.collect(new FlowCollector<String>() {
// @Override
// public Object emit(String value, Continuation<? super Unit> continuation) {
// fullResponse.append(value);
// callback.onToken(value);
// return Unit.INSTANCE;
// }
// });
//
// callback.onComplete(fullResponse.toString());
//
// } catch (Exception e) {
// Log.e(TAG, "Inference failed", e);
// callback.onError(e.getMessage());
// }
// });
// }

public void shutdown() {
if (nexaVlmInference != null) {
Expand Down
44 changes: 40 additions & 4 deletions android/app-java/src/main/java/ai/nexa/app_java/MessageModal.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,24 @@

public class MessageModal {


private String message;
private String sender;

private String imageUri;
private long ttft;
private double tps;
private double decodingSpeed;
private int totalTokens;

/**
 * Creates a chat message.
 *
 * @param message  the message body text
 * @param sender   role of the author (e.g. "user" or "assistant")
 * @param imageUri optional URI of an attached image; may be null
 */
public MessageModal(String message, String sender, String imageUri) {
    // Performance metrics start zeroed; they are populated via the
    // setters once inference has finished for assistant messages.
    this.ttft = 0L;
    this.tps = 0.0;
    this.decodingSpeed = 0.0;
    this.totalTokens = 0;

    this.message = message;
    this.sender = sender;
    this.imageUri = imageUri;
}


/** @return the message body text. */
public String getMessage() {
    return this.message;
}
Expand All @@ -38,5 +43,36 @@ public String getImageUri() {
public void setImageUri(String imageUri) {
this.imageUri = imageUri;
}
}

/** @return time-to-first-token in milliseconds (0 if not yet measured). */
public long getTtft() {
    return this.ttft;
}

/** @param ttft time-to-first-token in milliseconds. */
public void setTtft(long ttft) {
    this.ttft = ttft;
}

/** @return overall throughput in tokens per second (0.0 if not measured). */
public double getTps() {
    return this.tps;
}

/** @param tps overall throughput in tokens per second. */
public void setTps(double tps) {
    this.tps = tps;
}

/** @return decoding speed in tokens per second, excluding the first token (0.0 if not measured). */
public double getDecodingSpeed() {
    return this.decodingSpeed;
}

/** @param decodingSpeed decoding speed in tokens per second. */
public void setDecodingSpeed(double decodingSpeed) {
    this.decodingSpeed = decodingSpeed;
}

/** @return total number of tokens generated for this message (0 if not measured). */
public int getTotalTokens() {
    return this.totalTokens;
}

/** @param totalTokens total number of tokens generated for this message. */
public void setTotalTokens(int totalTokens) {
    this.totalTokens = totalTokens;
}
}
Loading