diff --git a/ios/RNLlama.mm b/ios/RNLlama.mm
index 2d591fe7..862e647c 100644
--- a/ios/RNLlama.mm
+++ b/ios/RNLlama.mm
@@ -56,6 +56,64 @@ @implementation RNLlama
     });
 }
 
+// Loads a llama session file into the context registered under contextId.
+// Resolves with the count returned by -[RNLlamaContext loadSession:]; rejects when
+// the context is unknown, is currently predicting, or loading throws.
+RCT_EXPORT_METHOD(loadSession:(double)contextId
+                 withFilePath:(NSString *)filePath
+                 withResolver:(RCTPromiseResolveBlock)resolve
+                 withRejecter:(RCTPromiseRejectBlock)reject)
+{
+    RNLlamaContext *target = llamaContexts[@(contextId)];
+    if (!target) {
+        reject(@"llama_error", @"Context not found", nil);
+        return;
+    }
+    if ([target isPredicting]) {
+        reject(@"llama_error", @"Context is busy", nil);
+        return;
+    }
+    // The load itself is dispatched asynchronously onto llamaDQueue.
+    dispatch_async(llamaDQueue, ^{
+        @try {
+            @autoreleasepool {
+                resolve(@([target loadSession:filePath]));
+            }
+        } @catch (NSException *error) {
+            reject(@"llama_cpp_error", error.reason, nil);
+        }
+    });
+}
+
+// Saves the session of the context registered under contextId to filePath.
+// Resolves with the count returned by -[RNLlamaContext saveSession:]; rejects when
+// the context is unknown, is currently predicting, or saving throws.
+RCT_EXPORT_METHOD(saveSession:(double)contextId
+                 withFilePath:(NSString *)filePath
+                 withResolver:(RCTPromiseResolveBlock)resolve
+                 withRejecter:(RCTPromiseRejectBlock)reject)
+{
+    RNLlamaContext *target = llamaContexts[@(contextId)];
+    if (!target) {
+        reject(@"llama_error", @"Context not found", nil);
+        return;
+    }
+    if ([target isPredicting]) {
+        reject(@"llama_error", @"Context is busy", nil);
+        return;
+    }
+    // The save itself is dispatched asynchronously onto llamaDQueue.
+    dispatch_async(llamaDQueue, ^{
+        @try {
+            @autoreleasepool {
+                resolve(@([target saveSession:filePath]));
+            }
+        } @catch (NSException *error) {
+            reject(@"llama_cpp_error", error.reason, nil);
+        }
+    });
+}
+
 - (NSArray *)supportedEvents {
   return@[
     @"@RNLlama_onToken",
diff --git a/ios/RNLlamaContext.h b/ios/RNLlamaContext.h
index 09461798..497ca89b 100644
--- a/ios/RNLlamaContext.h
+++ b/ios/RNLlamaContext.h
@@ -22,6 +22,10 @@
 - (NSArray *)tokenize:(NSString *)text;
 - (NSString *)detokenize:(NSArray *)tokens;
 - (NSArray *)embedding:(NSString *)text;
+/// Loads the llama session file at `path`; returns the token count read from it.
+- (int)loadSession:(NSString *)path;
+/// Saves the llama session to `path`; returns the number of tokens written.
+- (int)saveSession:(NSString *)path;
 
 - (void)invalidate;
 
diff
--git a/ios/RNLlamaContext.mm b/ios/RNLlamaContext.mm
index 74d33bb2..f86017b5 100644
--- a/ios/RNLlamaContext.mm
+++ b/ios/RNLlamaContext.mm
@@ -337,6 +337,42 @@ - (NSArray *)embedding:(NSString *)text {
     return embeddingResult;
 }
 
+// Restores llama state from the session file at `path`.
+// Returns the number of tokens recorded in the file; throws LlamaException when
+// the path is empty or llama.cpp cannot load the file.
+- (int)loadSession:(NSString *)path {
+    if (path.length == 0) {
+        @throw [NSException exceptionWithName:@"LlamaException" reason:@"Session path is empty" userInfo:nil];
+    }
+    // llama_load_session_file fills a caller-owned buffer and refuses files that hold
+    // more tokens than n_token_capacity. A default-constructed vector owns no storage,
+    // so size the buffer to the context window before loading.
+    std::vector<llama_token> session_tokens(llama_n_ctx(llama->ctx));
+    size_t n_token_count_out = 0;
+    if (!llama_load_session_file(llama->ctx, [path UTF8String], session_tokens.data(), session_tokens.size(), &n_token_count_out)) {
+        @throw [NSException exceptionWithName:@"LlamaException" reason:@"Failed to load session" userInfo:nil];
+    }
+    session_tokens.resize(n_token_count_out);
+    // saveSession persists llama->embd as the session's token list, so restore it to the same place.
+    llama->embd = session_tokens;
+    return (int)n_token_count_out;
+}
+
+// Writes the current llama state, together with the tokens in llama->embd, to `path`.
+// Returns the number of tokens written; throws LlamaException when the path is
+// empty or llama.cpp cannot save the file.
+- (int)saveSession:(NSString *)path {
+    if (path.length == 0) {
+        @throw [NSException exceptionWithName:@"LlamaException" reason:@"Session path is empty" userInfo:nil];
+    }
+    // Read the tokens in place; the save call only reads them, so no copy is needed.
+    const std::vector<llama_token> &session_tokens = llama->embd;
+    if (!llama_save_session_file(llama->ctx, [path UTF8String], session_tokens.data(), session_tokens.size())) {
+        @throw [NSException exceptionWithName:@"LlamaException" reason:@"Failed to save session" userInfo:nil];
+    }
+    return (int)session_tokens.size();
+}
+
 - (void)invalidate {
     if (llama->grammar != nullptr) {
         llama_grammar_free(llama->grammar);
diff --git a/src/NativeRNLlama.ts b/src/NativeRNLlama.ts
index f33a4e3c..674b2e0c 100644
--- a/src/NativeRNLlama.ts
+++ b/src/NativeRNLlama.ts
@@ -110,6 +110,8 @@ export interface Spec extends TurboModule {
   setContextLimit(limit: number): Promise;
   initContext(params: NativeContextParams): Promise;
 
+  loadSession(contextId: number, filepath: string): Promise<number>;
+  saveSession(contextId: number, filepath: string): Promise<number>;
   completion(contextId: number, params: NativeCompletionParams): Promise;
   stopCompletion(contextId: number): Promise;
   tokenize(contextId: number, text: string): Promise;