Skip to content

Commit

Permalink
feat: sync llama.cpp
Browse the repository at this point in the history
  • Loading branch information
jhen0409 committed Oct 13, 2023
1 parent f832580 commit 6174659
Show file tree
Hide file tree
Showing 13 changed files with 424 additions and 285 deletions.
1 change: 1 addition & 0 deletions android/src/main/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ set(
${RNLLAMA_LIB_DIR}/k_quants.c
${RNLLAMA_LIB_DIR}/common.cpp
${RNLLAMA_LIB_DIR}/grammar-parser.cpp
${RNLLAMA_LIB_DIR}/sampling.cpp
${RNLLAMA_LIB_DIR}/llama.cpp
${RNLLAMA_LIB_DIR}/rn-llama.hpp
${CMAKE_SOURCE_DIR}/jni.cpp
Expand Down
56 changes: 29 additions & 27 deletions android/src/main/jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,40 +299,33 @@ Java_com_rnllama_LlamaContext_doCompletion(

llama->params.prompt = env->GetStringUTFChars(prompt, nullptr);
llama->params.grammar = env->GetStringUTFChars(grammar, nullptr);
llama->params.temp = temperature;

int max_threads = std::thread::hardware_concurrency();
// Use 2 threads by default on 4-core devices, 4 threads on more cores
int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
llama->params.n_threads = n_threads > 0 ? n_threads : default_n_threads;

llama->params.n_predict = n_predict;
llama->params.n_probs = n_probs;
llama->params.repeat_last_n = repeat_last_n;
llama->params.repeat_penalty = repeat_penalty;
llama->params.presence_penalty = presence_penalty;
llama->params.frequency_penalty = frequency_penalty;
llama->params.mirostat = mirostat;
llama->params.mirostat_tau = mirostat_tau;
llama->params.mirostat_eta = mirostat_eta;
llama->params.top_k = top_k;
llama->params.top_p = top_p;
llama->params.tfs_z = tfs_z;
llama->params.typical_p = typical_p;
llama->params.ignore_eos = ignore_eos;

llama->params.antiprompt.clear();
int stop_len = env->GetArrayLength(stop);
for (int i = 0; i < stop_len; i++) {
jstring stop_str = (jstring) env->GetObjectArrayElement(stop, i);
const char *stop_chars = env->GetStringUTFChars(stop_str, nullptr);
llama->params.antiprompt.push_back(stop_chars);
env->ReleaseStringUTFChars(stop_str, stop_chars);
}

llama->params.logit_bias.clear();
auto & sparams = llama->params.sampling_params;
sparams.temp = temperature;
sparams.repeat_last_n = repeat_last_n;
sparams.repeat_penalty = repeat_penalty;
sparams.presence_penalty = presence_penalty;
sparams.frequency_penalty = frequency_penalty;
sparams.mirostat = mirostat;
sparams.mirostat_tau = mirostat_tau;
sparams.mirostat_eta = mirostat_eta;
sparams.top_k = top_k;
sparams.top_p = top_p;
sparams.tfs_z = tfs_z;
sparams.typical_p = typical_p;
sparams.n_probs = n_probs;

sparams.logit_bias.clear();
if (ignore_eos) {
llama->params.logit_bias[llama_token_eos(llama->ctx)] = -INFINITY;
sparams.logit_bias[llama_token_eos(llama->ctx)] = -INFINITY;
}

const int n_vocab = llama_n_vocab(llama_get_model(llama->ctx));
Expand All @@ -346,9 +339,9 @@ Java_com_rnllama_LlamaContext_doCompletion(
llama_token tok = static_cast<llama_token>(doubleArray[0]);
if (tok >= 0 && tok < n_vocab) {
if (doubleArray[1] != 0) { // If the second element is not false (0)
llama->params.logit_bias[tok] = doubleArray[1];
sparams.logit_bias[tok] = doubleArray[1];
} else {
llama->params.logit_bias[tok] = -INFINITY;
sparams.logit_bias[tok] = -INFINITY;
}
}

Expand All @@ -357,6 +350,15 @@ Java_com_rnllama_LlamaContext_doCompletion(
env->DeleteLocalRef(el);
}

llama->params.antiprompt.clear();
int stop_len = env->GetArrayLength(stop);
for (int i = 0; i < stop_len; i++) {
jstring stop_str = (jstring) env->GetObjectArrayElement(stop, i);
const char *stop_chars = env->GetStringUTFChars(stop_str, nullptr);
llama->params.antiprompt.push_back(stop_chars);
env->ReleaseStringUTFChars(stop_str, stop_chars);
}

if (!llama->loadGrammar()) {
auto result = createWriteableMap(env);
putString(env, result, "error", "Failed to load grammar");
Expand Down Expand Up @@ -408,7 +410,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
auto tokenResult = createWriteableMap(env);
putString(env, tokenResult, "token", to_send.c_str());

if (llama->params.n_probs > 0) {
if (llama->params.sampling_params.n_probs > 0) {
const std::vector<llama_token> to_send_toks = llama_tokenize(llama->ctx, to_send, false);
size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size());
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size());
Expand Down
4 changes: 2 additions & 2 deletions cpp/build-info.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#ifndef BUILD_INFO_H
#define BUILD_INFO_H

#define BUILD_NUMBER 1364
#define BUILD_COMMIT "9f6ede1"
#define BUILD_NUMBER 1378
#define BUILD_COMMIT "1e0e873"
#define BUILD_COMPILER ""
#define BUILD_TARGET "unknown"

Expand Down
Loading

0 comments on commit 6174659

Please sign in to comment.