Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add size option for saveSession #28

Merged
merged 2 commits into from
Oct 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions android/src/main/java/com/rnllama/LlamaContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ public WritableMap loadSession(String path) {
return result;
}

public int saveSession(String path) {
return saveSession(this.context, path);
public int saveSession(String path, int size) {
return saveSession(this.context, path, size);
}

public WritableMap completion(ReadableMap params) {
Expand Down Expand Up @@ -286,7 +286,8 @@ protected static native WritableMap loadSession(
);
protected static native int saveSession(
long contextPtr,
String path
String path,
int size
);
protected static native WritableMap doCompletion(
long context_ptr,
Expand Down
4 changes: 2 additions & 2 deletions android/src/main/java/com/rnllama/RNLlama.java
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ protected void onPostExecute(WritableMap result) {
tasks.put(task, "loadSession-" + contextId);
}

public void saveSession(double id, final String path, Promise promise) {
public void saveSession(double id, final String path, double size, Promise promise) {
final int contextId = (int) id;
AsyncTask task = new AsyncTask<Void, Void, Integer>() {
private Exception exception;
Expand All @@ -124,7 +124,7 @@ protected Integer doInBackground(Void... voids) {
if (context == null) {
throw new Exception("Context not found");
}
Integer count = context.saveSession(path);
Integer count = context.saveSession(path, (int) size);
return count;
} catch (Exception e) {
exception = e;
Expand Down
7 changes: 5 additions & 2 deletions android/src/main/jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,15 +222,18 @@ Java_com_rnllama_LlamaContext_saveSession(
JNIEnv *env,
jobject thiz,
jlong context_ptr,
jstring path
jstring path,
jint size
) {
UNUSED(thiz);
auto llama = context_map[(long) context_ptr];

const char *path_chars = env->GetStringUTFChars(path, nullptr);

std::vector<llama_token> session_tokens = llama->embd;
if (!llama_save_session_file(llama->ctx, path_chars, session_tokens.data(), session_tokens.size())) {
int default_size = session_tokens.size();
int save_size = size > 0 && size <= default_size ? size : default_size;
if (!llama_save_session_file(llama->ctx, path_chars, session_tokens.data(), save_size)) {
env->ReleaseStringUTFChars(path, path_chars);
return -1;
}
Expand Down
4 changes: 2 additions & 2 deletions android/src/newarch/java/com/rnllama/RNLlamaModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ public void loadSession(double id, String path, Promise promise) {
}

@ReactMethod
public void saveSession(double id, String path, Promise promise) {
rnllama.saveSession(id, path, promise);
public void saveSession(double id, String path, double size, Promise promise) {
rnllama.saveSession(id, path, size, promise);
}

@ReactMethod
Expand Down
4 changes: 2 additions & 2 deletions android/src/oldarch/java/com/rnllama/RNLlamaModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ public void loadSession(double id, String path, Promise promise) {
}

@ReactMethod
public void saveSession(double id, String path, Promise promise) {
rnllama.saveSession(id, path, promise);
public void saveSession(double id, String path, int size, Promise promise) {
rnllama.saveSession(id, path, size, promise);
}

@ReactMethod
Expand Down
3 changes: 2 additions & 1 deletion ios/RNLlama.mm
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ @implementation RNLlama

RCT_EXPORT_METHOD(saveSession:(double)contextId
withFilePath:(NSString *)filePath
withSize:(double)size
withResolver:(RCTPromiseResolveBlock)resolve
withRejecter:(RCTPromiseRejectBlock)reject)
{
Expand All @@ -98,7 +99,7 @@ @implementation RNLlama
dispatch_async(dispatch_get_main_queue(), ^{ // TODO: Fix for use in llamaDQue
@try {
@autoreleasepool {
int count = [context saveSession:filePath];
int count = [context saveSession:filePath size:(int)size];
resolve(@(count));
}
} @catch (NSException *exception) {
Expand Down
2 changes: 1 addition & 1 deletion ios/RNLlamaContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
- (NSString *)detokenize:(NSArray *)tokens;
- (NSArray *)embedding:(NSString *)text;
- (NSDictionary *)loadSession:(NSString *)path;
- (int)saveSession:(NSString *)path;
- (int)saveSession:(NSString *)path size:(int)size;

- (void)invalidate;

Expand Down
6 changes: 4 additions & 2 deletions ios/RNLlamaContext.mm
Original file line number Diff line number Diff line change
Expand Up @@ -354,9 +354,11 @@ - (NSDictionary *)loadSession:(NSString *)path {
};
}

- (int)saveSession:(NSString *)path {
- (int)saveSession:(NSString *)path size:(int)size {
std::vector<llama_token> session_tokens = llama->embd;
if (!llama_save_session_file(llama->ctx, [path UTF8String], session_tokens.data(), session_tokens.size())) {
int default_size = session_tokens.size();
int save_size = size > 0 && size <= default_size ? size : default_size;
if (!llama_save_session_file(llama->ctx, [path UTF8String], session_tokens.data(), save_size)) {
@throw [NSException exceptionWithName:@"LlamaException" reason:@"Failed to save session" userInfo:nil];
}
return session_tokens.size();
Expand Down
2 changes: 1 addition & 1 deletion src/NativeRNLlama.ts
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ export interface Spec extends TurboModule {
initContext(params: NativeContextParams): Promise<NativeLlamaContext>;

loadSession(contextId: number, filepath: string): Promise<NativeSessionLoadResult>;
saveSession(contextId: number, filepath: string): Promise<number>;
saveSession(contextId: number, filepath: string, size: number): Promise<number>;
completion(contextId: number, params: NativeCompletionParams): Promise<NativeCompletionResult>;
stopCompletion(contextId: number): Promise<void>;
tokenize(contextId: number, text: string): Promise<NativeTokenizeResult>;
Expand Down
4 changes: 2 additions & 2 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ export class LlamaContext {
/**
* Save current cached prompt & completion state to a file.
*/
async saveSession(filepath: string): Promise<number> {
return RNLlama.saveSession(this.id, filepath)
async saveSession(filepath: string, options?: { tokenSize: number }): Promise<number> {
return RNLlama.saveSession(this.id, filepath, options?.tokenSize || -1)
}

async completion(
Expand Down
Loading