diff --git a/android/src/main/java/com/rnllama/LlamaContext.java b/android/src/main/java/com/rnllama/LlamaContext.java index ba9b6938..414d610e 100644 --- a/android/src/main/java/com/rnllama/LlamaContext.java +++ b/android/src/main/java/com/rnllama/LlamaContext.java @@ -103,8 +103,8 @@ public WritableMap loadSession(String path) { return result; } - public int saveSession(String path) { - return saveSession(this.context, path); + public int saveSession(String path, int size) { + return saveSession(this.context, path, size); } public WritableMap completion(ReadableMap params) { @@ -286,7 +286,8 @@ protected static native WritableMap loadSession( ); protected static native int saveSession( long contextPtr, - String path + String path, + int size ); protected static native WritableMap doCompletion( long context_ptr, diff --git a/android/src/main/java/com/rnllama/RNLlama.java b/android/src/main/java/com/rnllama/RNLlama.java index 74723db4..eb423dd0 100644 --- a/android/src/main/java/com/rnllama/RNLlama.java +++ b/android/src/main/java/com/rnllama/RNLlama.java @@ -112,7 +112,7 @@ protected void onPostExecute(WritableMap result) { tasks.put(task, "loadSession-" + contextId); } - public void saveSession(double id, final String path, Promise promise) { + public void saveSession(double id, final String path, double size, Promise promise) { final int contextId = (int) id; AsyncTask task = new AsyncTask() { private Exception exception; @@ -124,7 +124,7 @@ protected Integer doInBackground(Void... 
voids) { if (context == null) { throw new Exception("Context not found"); } - Integer count = context.saveSession(path); + Integer count = context.saveSession(path, (int) size); return count; } catch (Exception e) { exception = e; diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp index 7d24d9f1..5d38120a 100644 --- a/android/src/main/jni.cpp +++ b/android/src/main/jni.cpp @@ -222,7 +222,8 @@ Java_com_rnllama_LlamaContext_saveSession( JNIEnv *env, jobject thiz, jlong context_ptr, - jstring path + jstring path, + jint size ) { UNUSED(thiz); auto llama = context_map[(long) context_ptr]; @@ -230,7 +231,9 @@ Java_com_rnllama_LlamaContext_saveSession( const char *path_chars = env->GetStringUTFChars(path, nullptr); std::vector session_tokens = llama->embd; - if (!llama_save_session_file(llama->ctx, path_chars, session_tokens.data(), session_tokens.size())) { + int default_size = session_tokens.size(); + int save_size = size > 0 && size <= default_size ? size : default_size; + if (!llama_save_session_file(llama->ctx, path_chars, session_tokens.data(), save_size)) { env->ReleaseStringUTFChars(path, path_chars); return -1; } diff --git a/android/src/newarch/java/com/rnllama/RNLlamaModule.java b/android/src/newarch/java/com/rnllama/RNLlamaModule.java index 55bf27e5..38c5a1c6 100644 --- a/android/src/newarch/java/com/rnllama/RNLlamaModule.java +++ b/android/src/newarch/java/com/rnllama/RNLlamaModule.java @@ -48,8 +48,8 @@ public void loadSession(double id, String path, Promise promise) { } @ReactMethod - public void saveSession(double id, String path, Promise promise) { - rnllama.saveSession(id, path, promise); + public void saveSession(double id, String path, double size, Promise promise) { + rnllama.saveSession(id, path, size, promise); } @ReactMethod diff --git a/android/src/oldarch/java/com/rnllama/RNLlamaModule.java b/android/src/oldarch/java/com/rnllama/RNLlamaModule.java index 82ee8277..4b34e2a5 100644 --- 
a/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +++ b/android/src/oldarch/java/com/rnllama/RNLlamaModule.java @@ -49,8 +49,8 @@ public void loadSession(double id, String path, Promise promise) { } @ReactMethod - public void saveSession(double id, String path, Promise promise) { - rnllama.saveSession(id, path, promise); + public void saveSession(double id, String path, double size, Promise promise) { + rnllama.saveSession(id, path, size, promise); } @ReactMethod diff --git a/ios/RNLlama.mm b/ios/RNLlama.mm index 04089e88..8d98fe70 100644 --- a/ios/RNLlama.mm +++ b/ios/RNLlama.mm @@ -83,6 +83,7 @@ @implementation RNLlama RCT_EXPORT_METHOD(saveSession:(double)contextId withFilePath:(NSString *)filePath + withSize:(double)size withResolver:(RCTPromiseResolveBlock)resolve withRejecter:(RCTPromiseRejectBlock)reject) { @@ -98,7 +99,7 @@ @implementation RNLlama dispatch_async(dispatch_get_main_queue(), ^{ // TODO: Fix for use in llamaDQue @try { @autoreleasepool { - int count = [context saveSession:filePath]; + int count = [context saveSession:filePath size:(int)size]; resolve(@(count)); } } @catch (NSException *exception) { diff --git a/ios/RNLlamaContext.h b/ios/RNLlamaContext.h index 93ff2eed..88a77d34 100644 --- a/ios/RNLlamaContext.h +++ b/ios/RNLlamaContext.h @@ -23,7 +23,7 @@ - (NSString *)detokenize:(NSArray *)tokens; - (NSArray *)embedding:(NSString *)text; - (NSDictionary *)loadSession:(NSString *)path; -- (int)saveSession:(NSString *)path; +- (int)saveSession:(NSString *)path size:(int)size; - (void)invalidate; diff --git a/ios/RNLlamaContext.mm b/ios/RNLlamaContext.mm index ee5d1037..0f33570c 100644 --- a/ios/RNLlamaContext.mm +++ b/ios/RNLlamaContext.mm @@ -354,9 +354,11 @@ - (NSDictionary *)loadSession:(NSString *)path { }; } -- (int)saveSession:(NSString *)path { +- (int)saveSession:(NSString *)path size:(int)size { std::vector<llama_token> session_tokens = llama->embd; - if (!llama_save_session_file(llama->ctx, [path UTF8String], session_tokens.data(),
session_tokens.size())) { + int default_size = session_tokens.size(); + int save_size = size > 0 && size <= default_size ? size : default_size; + if (!llama_save_session_file(llama->ctx, [path UTF8String], session_tokens.data(), save_size)) { @throw [NSException exceptionWithName:@"LlamaException" reason:@"Failed to save session" userInfo:nil]; } - return session_tokens.size(); + return save_size; diff --git a/src/NativeRNLlama.ts b/src/NativeRNLlama.ts index 6554d6ee..67a949b3 100644 --- a/src/NativeRNLlama.ts +++ b/src/NativeRNLlama.ts @@ -116,7 +116,7 @@ export interface Spec extends TurboModule { initContext(params: NativeContextParams): Promise; loadSession(contextId: number, filepath: string): Promise; - saveSession(contextId: number, filepath: string): Promise<number>; + saveSession(contextId: number, filepath: string, size: number): Promise<number>; completion(contextId: number, params: NativeCompletionParams): Promise; stopCompletion(contextId: number): Promise; tokenize(contextId: number, text: string): Promise; diff --git a/src/index.ts b/src/index.ts index ba517d57..3150fe1c 100644 --- a/src/index.ts +++ b/src/index.ts @@ -67,8 +67,8 @@ export class LlamaContext { /** * Save current cached prompt & completion state to a file. */ - async saveSession(filepath: string): Promise<number> { - return RNLlama.saveSession(this.id, filepath) + async saveSession(filepath: string, options?: { tokenSize: number }): Promise<number> { - return RNLlama.saveSession(this.id, filepath, options?.tokenSize || -1) } async completion(