feat: implement loadSession & saveSession methods (#23)
* feat(ios): implement loadSession & saveSession methods

* feat(example): add commands

* feat(android): implement loadSession / saveSession

* docs: update

* feat: put loaded session tokens to rnllama context & return tokens

* feat: sync llama.cpp for load/save session fix

* feat: update from llama.cpp

* fix(ios): use main queue to load/save session

* fix: resize to n_ctx before load session
jhen0409 authored Oct 4, 2023
1 parent 43e036e commit 68fe09c
Showing 29 changed files with 2,072 additions and 243 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -100,6 +100,8 @@ The binding's design is inspired by [server.cpp](https://github.com/ggerganov/lla
- `/detokenize`: `context.detokenize(tokens)`
- `/embedding`: `context.embedding(content)`
- Other methods
- `context.loadSession(path)`
- `context.saveSession(path)`
- `context.stopCompletion()`
- `context.release()`

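For reference, a minimal sketch of the new session API from the JavaScript side. It assumes the library's `initLlama` entry point and uses hypothetical file paths; the result shape (`tokens_loaded`, `prompt`) and the numeric return of `saveSession` follow the native implementations in this commit.

```ts
import { initLlama } from 'llama.rn'

// Hypothetical paths, for illustration only.
const modelPath = '/path/to/model.gguf'
const sessionPath = '/path/to/session.bin'

const context = await initLlama({ model: modelPath })

// Restore a previously saved session: resolves with the number of
// tokens loaded and the detokenized prompt text.
const { tokens_loaded, prompt } = await context.loadSession(sessionPath)

// ... run completions against the restored state ...

// Persist the current token state; resolves with the token count written.
const tokensSaved = await context.saveSession(sessionPath)

await context.release()
```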
20 changes: 20 additions & 0 deletions android/src/main/java/com/rnllama/LlamaContext.java
@@ -95,6 +95,18 @@ void onPartialCompletion(WritableMap tokenResult) {
}
}

// Load a previously saved session; throws if the native layer reports an error.
public WritableMap loadSession(String path) {
WritableMap result = loadSession(this.context, path);
if (result.hasKey("error")) {
throw new IllegalStateException(result.getString("error"));
}
return result;
}

// Persist the current session; returns the number of tokens written.
public int saveSession(String path) {
return saveSession(this.context, path);
}

public WritableMap completion(ReadableMap params) {
if (!params.hasKey("prompt")) {
throw new IllegalArgumentException("Missing required parameter: prompt");
@@ -228,6 +240,14 @@ protected static native long initContext(
float rope_freq_base,
float rope_freq_scale
);
protected static native WritableMap loadSession(
long contextPtr,
String path
);
protected static native int saveSession(
long contextPtr,
String path
);
protected static native WritableMap doCompletion(
long context_ptr,
String prompt,
66 changes: 66 additions & 0 deletions android/src/main/java/com/rnllama/RNLlama.java
@@ -79,6 +79,72 @@ protected void onPostExecute(WritableMap result) {
tasks.put(task, "initContext");
}

public void loadSession(double id, final String path, Promise promise) {
final int contextId = (int) id;
AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
private Exception exception;

@Override
protected WritableMap doInBackground(Void... voids) {
try {
LlamaContext context = contexts.get(contextId);
if (context == null) {
throw new Exception("Context not found");
}
WritableMap result = context.loadSession(path);
return result;
} catch (Exception e) {
exception = e;
}
return null;
}

@Override
protected void onPostExecute(WritableMap result) {
if (exception != null) {
promise.reject(exception);
return;
}
promise.resolve(result);
tasks.remove(this);
}
}.execute();
tasks.put(task, "loadSession-" + contextId);
}

public void saveSession(double id, final String path, Promise promise) {
final int contextId = (int) id;
AsyncTask task = new AsyncTask<Void, Void, Integer>() {
private Exception exception;

@Override
protected Integer doInBackground(Void... voids) {
try {
LlamaContext context = contexts.get(contextId);
if (context == null) {
throw new Exception("Context not found");
}
Integer count = context.saveSession(path);
return count;
} catch (Exception e) {
exception = e;
}
return -1;
}

@Override
protected void onPostExecute(Integer result) {
if (exception != null) {
promise.reject(exception);
return;
}
promise.resolve(result);
tasks.remove(this);
}
}.execute();
tasks.put(task, "saveSession-" + contextId);
}

public void completion(double id, final ReadableMap params, final Promise promise) {
final int contextId = (int) id;
AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
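Both bridge methods follow the same pattern: do the work on a background `AsyncTask`, look the context up by id, and settle the `Promise` in `onPostExecute`. From JavaScript the calls land roughly as below (a sketch; the `RNLlama` native-module name is an assumption based on the module classes in this commit).

```ts
import { NativeModules } from 'react-native'

const { RNLlama } = NativeModules

// Direct bridge call, normally wrapped by the library's context class.
// Rejects if the Java layer threw (e.g. "Context not found").
function restoreSession(contextId: number, path: string) {
  return RNLlama.loadSession(contextId, path)
}

// Resolves with the number of tokens written, or -1 if the native save
// failed (the Java layer does not turn -1 into a rejection).
function persistSession(contextId: number, path: string) {
  return RNLlama.saveSession(contextId, path)
}
```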
51 changes: 51 additions & 0 deletions android/src/main/jni.cpp
@@ -188,6 +188,57 @@ Java_com_rnllama_LlamaContext_initContext(
return reinterpret_cast<jlong>(llama->ctx);
}

JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_loadSession(
JNIEnv *env,
jobject thiz,
jlong context_ptr,
jstring path
) {
UNUSED(thiz);
auto llama = context_map[(long) context_ptr];
const char *path_chars = env->GetStringUTFChars(path, nullptr);

auto result = createWriteableMap(env);
size_t n_token_count_out = 0;
// Size the token buffer to hold a full context before loading the session.
llama->embd.resize(llama->params.n_ctx);
if (!llama_load_session_file(llama->ctx, path_chars, llama->embd.data(), llama->embd.capacity(), &n_token_count_out)) {
env->ReleaseStringUTFChars(path, path_chars);

putString(env, result, "error", "Failed to load session");
return reinterpret_cast<jobject>(result);
}
llama->embd.resize(n_token_count_out);
env->ReleaseStringUTFChars(path, path_chars);

// Detokenize the restored tokens so the caller gets the session's prompt text back.
const std::string text = rnllama::tokens_to_str(llama->ctx, llama->embd.cbegin(), llama->embd.cend());
putInt(env, result, "tokens_loaded", n_token_count_out);
putString(env, result, "prompt", text.c_str());
return reinterpret_cast<jobject>(result);
}

JNIEXPORT jint JNICALL
Java_com_rnllama_LlamaContext_saveSession(
JNIEnv *env,
jobject thiz,
jlong context_ptr,
jstring path
) {
UNUSED(thiz);
auto llama = context_map[(long) context_ptr];

const char *path_chars = env->GetStringUTFChars(path, nullptr);

// Copy the tokens the context has evaluated so far; these define the session.
std::vector<llama_token> session_tokens = llama->embd;
if (!llama_save_session_file(llama->ctx, path_chars, session_tokens.data(), session_tokens.size())) {
env->ReleaseStringUTFChars(path, path_chars);
return -1;
}

env->ReleaseStringUTFChars(path, path_chars);
return session_tokens.size();
}
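Together, the two JNI functions give round-trip semantics: `saveSession` writes out the tokens the context has evaluated so far (the llama.cpp session file also carries the model state, so generation can resume without re-evaluating), and `loadSession` restores them, capped at `n_ctx` because the buffer is sized to the context length before loading, then hands back the detokenized prompt. A sketch of the intended flow, reusing the hypothetical names from the README example above (`longSystemPrompt` is likewise hypothetical):

```ts
// First run: evaluate an expensive prompt once, then cache the state.
await context.completion({ prompt: longSystemPrompt })
const tokensSaved = await context.saveSession(sessionPath) // -1 on failure

// Later run: restore instead of re-evaluating; `prompt` is the
// detokenized text of the restored tokens.
const { tokens_loaded, prompt } = await context.loadSession(sessionPath)
console.log(`restored ${tokens_loaded} tokens`)
```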

static inline jobject tokenProbsToMap(
JNIEnv *env,
rnllama::llama_rn_context *llama,
10 changes: 10 additions & 0 deletions android/src/newarch/java/com/rnllama/RNLlamaModule.java
@@ -42,6 +42,16 @@ public void initContext(final ReadableMap params, final Promise promise) {
rnllama.initContext(params, promise);
}

@ReactMethod
public void loadSession(double id, String path, Promise promise) {
rnllama.loadSession(id, path, promise);
}

@ReactMethod
public void saveSession(double id, String path, Promise promise) {
rnllama.saveSession(id, path, promise);
}

@ReactMethod
public void completion(double id, final ReadableMap params, final Promise promise) {
rnllama.completion(id, params, promise);
10 changes: 10 additions & 0 deletions android/src/oldarch/java/com/rnllama/RNLlamaModule.java
@@ -43,6 +43,16 @@ public void initContext(final ReadableMap params, final Promise promise) {
rnllama.initContext(params, promise);
}

@ReactMethod
public void loadSession(double id, String path, Promise promise) {
rnllama.loadSession(id, path, promise);
}

@ReactMethod
public void saveSession(double id, String path, Promise promise) {
rnllama.saveSession(id, path, promise);
}

@ReactMethod
public void completion(double id, final ReadableMap params, final Promise promise) {
rnllama.completion(id, params, promise);
4 changes: 2 additions & 2 deletions cpp/build-info.h
@@ -1,8 +1,8 @@
#ifndef BUILD_INFO_H
#define BUILD_INFO_H

#define BUILD_NUMBER 1299
#define BUILD_COMMIT "f5ef5cf"
#define BUILD_NUMBER 1317
#define BUILD_COMMIT "79f34ab"
#define BUILD_COMPILER ""
#define BUILD_TARGET "unknown"

3 changes: 3 additions & 0 deletions cpp/common.cpp
@@ -389,6 +389,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.interactive_first = true;
} else if (arg == "-ins" || arg == "--instruct") {
params.instruct = true;
} else if (arg == "--infill") {
params.infill = true;
} else if (arg == "--multiline-input") {
params.multiline_input = true;
} else if (arg == "--simple-io") {
@@ -921,6 +923,7 @@ std::string llama_detokenize_bpe(llama_context * ctx, const std::vector<llama_to
result += piece;
}

// NOTE: the original tokenizer decodes bytes after collecting the pieces.
return result;
}

1 change: 1 addition & 0 deletions cpp/common.h
@@ -120,6 +120,7 @@ struct gpt_params {
bool use_mlock = false; // use mlock to keep model in memory
bool numa = false; // attempt optimizations that help on some NUMA systems
bool verbose_prompt = false; // print prompt tokens before generation
bool infill = false; // use infill mode
};

bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
9 changes: 4 additions & 5 deletions cpp/ggml-metal.m
@@ -1213,12 +1213,9 @@ void lm_ggml_metal_graph_compute(
float max_bias;
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));

if (__builtin_popcount(n_head) != 1) {
LM_GGML_ASSERT(false && "only power-of-two n_head implemented");
}

const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);

[encoder setComputePipelineState:ctx->pipeline_alibi_f32];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1239,7 +1236,9 @@ void lm_ggml_metal_graph_compute(
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&m0 length:sizeof( float) atIndex:18];
[encoder setBytes:&m1 length:sizeof( float) atIndex:19];
[encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];

[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
11 changes: 9 additions & 2 deletions cpp/ggml-metal.metal
@@ -830,7 +830,9 @@ kernel void kernel_alibi_f32(
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
constant float & m0,
constant float & m1,
constant int & n_heads_log2_floor,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
@@ -846,7 +848,12 @@ kernel void kernel_alibi_f32(
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);

device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
float m_k = pow(m0, i2 + 1);
float m_k;
if (i2 < n_heads_log2_floor) {
m_k = pow(m0, i2 + 1);
} else {
m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1);
}
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
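For context, the Metal changes above (synced from upstream llama.cpp) drop the power-of-two `n_head` restriction for ALiBi. Reconstructed from the kernel code, head $k$ (0-based) receives the slope

$$
m_k =
\begin{cases}
m_0^{\,k+1} & k < n' \\
m_1^{\,2(k - n') + 1} & k \ge n'
\end{cases}
\qquad
n' = 2^{\lfloor \log_2 n_{\text{head}} \rfloor},\quad
m_0 = 2^{-\text{max\_bias}/n'},\quad
m_1 = 2^{-(\text{max\_bias}/2)/n'}.
$$

For example, with `n_head = 12` and `max_bias = 8`: `n' = 8`, so heads 0 through 7 get slopes $2^{-1}, 2^{-2}, \ldots, 2^{-8}$, and heads 8 through 11 get $2^{-0.5}, 2^{-1.5}, 2^{-2.5}, 2^{-3.5}$.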
