Skip to content

Commit

Permalink
feat(ios): implement loadSession & saveSession methods
Browse files Browse the repository at this point in the history
  • Loading branch information
jhen0409 committed Oct 2, 2023
1 parent 8da7244 commit 8521221
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 0 deletions.
52 changes: 52 additions & 0 deletions ios/RNLlama.mm
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,58 @@ @implementation RNLlama
});
}

RCT_EXPORT_METHOD(loadSession:(double)contextId
                 withFilePath:(NSString *)filePath
                 withResolver:(RCTPromiseResolveBlock)resolve
                 withRejecter:(RCTPromiseRejectBlock)reject)
{
    // Restore a previously saved llama session file into the given context,
    // resolving the promise with the token count reported by the context.
    RNLlamaContext *llamaContext = llamaContexts[@(contextId)];
    if (llamaContext == nil) {
        reject(@"llama_error", @"Context not found", nil);
        return;
    }
    if ([llamaContext isPredicting]) {
        reject(@"llama_error", @"Context is busy", nil);
        return;
    }
    // Session I/O can be slow; run it on the serial llama queue so it never
    // blocks the JS thread and never races an in-flight completion.
    dispatch_async(llamaDQueue, ^{
        @try {
            @autoreleasepool {
                resolve(@([llamaContext loadSession:filePath]));
            }
        } @catch (NSException *exception) {
            reject(@"llama_cpp_error", exception.reason, nil);
        }
    });
}

RCT_EXPORT_METHOD(saveSession:(double)contextId
                 withFilePath:(NSString *)filePath
                 withResolver:(RCTPromiseResolveBlock)resolve
                 withRejecter:(RCTPromiseRejectBlock)reject)
{
    // Persist the given context's session state to disk, resolving the
    // promise with the number of tokens written.
    RNLlamaContext *llamaContext = llamaContexts[@(contextId)];
    if (llamaContext == nil) {
        reject(@"llama_error", @"Context not found", nil);
        return;
    }
    if ([llamaContext isPredicting]) {
        reject(@"llama_error", @"Context is busy", nil);
        return;
    }
    // Serialize on the llama work queue so the save never overlaps a
    // completion that is mutating the same context.
    dispatch_async(llamaDQueue, ^{
        @try {
            @autoreleasepool {
                resolve(@([llamaContext saveSession:filePath]));
            }
        } @catch (NSException *exception) {
            reject(@"llama_cpp_error", exception.reason, nil);
        }
    });
}

- (NSArray *)supportedEvents {
return@[
@"@RNLlama_onToken",
Expand Down
2 changes: 2 additions & 0 deletions ios/RNLlamaContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
- (NSArray *)tokenize:(NSString *)text;
- (NSString *)detokenize:(NSArray *)tokens;
- (NSArray *)embedding:(NSString *)text;
- (int)loadSession:(NSString *)path;
- (int)saveSession:(NSString *)path;

- (void)invalidate;

Expand Down
18 changes: 18 additions & 0 deletions ios/RNLlamaContext.mm
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,24 @@ - (NSArray *)embedding:(NSString *)text {
return embeddingResult;
}

// Loads a llama session (KV cache + token history) from `path`.
// Returns the number of tokens restored; throws LlamaException on failure.
- (int)loadSession:(NSString *)path {
    std::vector<llama_token> session_tokens;
    // BUG FIX: llama_load_session_file can write at most `n_token_capacity`
    // tokens into the buffer. The previous code passed the capacity of a
    // default-constructed (empty) vector — i.e. 0 — so no tokens could ever
    // be restored, and writing through .data() of an empty vector is UB.
    // Size the buffer to the context length up front instead.
    session_tokens.resize(llama_n_ctx(llama->ctx));
    size_t n_token_count_out = 0;
    if (!llama_load_session_file(llama->ctx, [path UTF8String], session_tokens.data(), session_tokens.size(), &n_token_count_out)) {
        @throw [NSException exceptionWithName:@"LlamaException" reason:@"Failed to load session" userInfo:nil];
    }
    // Trim to the number of tokens actually loaded.
    session_tokens.resize(n_token_count_out);
    // NOTE(review): the restored tokens are discarded here — they are not
    // copied into llama->embd, so a subsequent completion may not reuse the
    // loaded KV state. Confirm whether the caller is expected to re-prompt.
    return (int)n_token_count_out;
}

// Writes the context's current token history (llama->embd) and KV state
// to `path`. Returns the number of tokens saved; throws LlamaException on
// failure.
- (int)saveSession:(NSString *)path {
    // Reference the evaluated tokens directly; llama_save_session_file only
    // reads from the buffer.
    const std::vector<llama_token> &tokens = llama->embd;
    bool saved = llama_save_session_file(llama->ctx, [path UTF8String], tokens.data(), tokens.size());
    if (!saved) {
        @throw [NSException exceptionWithName:@"LlamaException" reason:@"Failed to save session" userInfo:nil];
    }
    return (int)tokens.size();
}

- (void)invalidate {
if (llama->grammar != nullptr) {
llama_grammar_free(llama->grammar);
Expand Down
2 changes: 2 additions & 0 deletions src/NativeRNLlama.ts
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ export interface Spec extends TurboModule {
setContextLimit(limit: number): Promise<void>;
initContext(params: NativeContextParams): Promise<NativeLlamaContext>;

loadSession(contextId: number, filepath: string): Promise<number>;
saveSession(contextId: number, filepath: string): Promise<number>;
completion(contextId: number, params: NativeCompletionParams): Promise<NativeCompletionResult>;
stopCompletion(contextId: number): Promise<void>;
tokenize(contextId: number, text: string): Promise<NativeTokenizeResult>;
Expand Down

0 comments on commit 8521221

Please sign in to comment.