feat: sync llama.cpp
jhen0409 committed Oct 19, 2023
1 parent 3a07270 commit a8e9555
Showing 19 changed files with 933 additions and 327 deletions.
4 changes: 4 additions & 0 deletions android/src/main/jni.cpp
@@ -572,6 +572,10 @@ Java_com_rnllama_LlamaContext_freeContext(
     if (llama->ctx) {
         llama_free(llama->ctx);
     }
+    if (llama->ctx_sampling != nullptr)
+    {
+        llama_sampling_free(llama->ctx_sampling);
+    }
     context_map.erase((long) llama->ctx);
 }
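The hunk above pairs the sampling context's creation with an explicit llama_sampling_free before the native context is forgotten, closing a leak when a context is freed after sampling has started. A minimal sketch of the resulting teardown order (the rn_llama holder below is hypothetical, for illustration; only ctx, ctx_sampling, llama_free, and llama_sampling_free come from the diff):

// Hypothetical holder struct, for illustration only.
struct rn_llama {
    llama_context          * ctx          = nullptr;
    llama_sampling_context * ctx_sampling = nullptr;  // created lazily on first completion
};

void free_native_context(rn_llama * llama) {
    if (llama->ctx) {
        llama_free(llama->ctx);                       // release the inference context first
    }
    if (llama->ctx_sampling != nullptr) {
        llama_sampling_free(llama->ctx_sampling);     // then the sampler state (new in this sync)
    }
}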

4 changes: 2 additions & 2 deletions cpp/build-info.h
@@ -1,8 +1,8 @@
 #ifndef BUILD_INFO_H
 #define BUILD_INFO_H

-#define BUILD_NUMBER 1378
-#define BUILD_COMMIT "1e0e873"
+#define BUILD_NUMBER 1399
+#define BUILD_COMMIT "004797f"
 #define BUILD_COMPILER ""
 #define BUILD_TARGET "unknown"

33 changes: 28 additions & 5 deletions cpp/common.cpp
@@ -820,6 +820,27 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_params
     return cparams;
 }

+void llama_batch_clear(struct llama_batch & batch) {
+    batch.n_tokens = 0;
+}
+
+void llama_batch_add(
+    struct llama_batch & batch,
+    llama_token id,
+    llama_pos pos,
+    const std::vector<llama_seq_id> & seq_ids,
+    bool logits) {
+    batch.token [batch.n_tokens] = id;
+    batch.pos [batch.n_tokens] = pos;
+    batch.n_seq_id[batch.n_tokens] = seq_ids.size();
+    for (size_t i = 0; i < seq_ids.size(); ++i) {
+        batch.seq_id[batch.n_tokens][i] = seq_ids[i];
+    }
+    batch.logits [batch.n_tokens] = logits;
+
+    batch.n_tokens++;
+}
+
 std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
     auto mparams = llama_model_params_from_gpt_params(params);

@@ -879,21 +900,23 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
 std::vector<llama_token> llama_tokenize(
     const struct llama_context * ctx,
     const std::string & text,
-    bool add_bos) {
-    return llama_tokenize(llama_get_model(ctx), text, add_bos);
+    bool add_bos,
+    bool special) {
+    return llama_tokenize(llama_get_model(ctx), text, add_bos, special);
 }

 std::vector<llama_token> llama_tokenize(
     const struct llama_model * model,
     const std::string & text,
-    bool add_bos) {
+    bool add_bos,
+    bool special) {
     // upper limit for the number of tokens
     int n_tokens = text.length() + add_bos;
     std::vector<llama_token> result(n_tokens);
-    n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
+    n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special);
     if (n_tokens < 0) {
         result.resize(-n_tokens);
-        int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
+        int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special);
         LM_GGML_ASSERT(check == -n_tokens);
     } else {
         result.resize(n_tokens);
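The two helpers added above centralize the llama_batch bookkeeping that callers previously did field by field. A minimal usage sketch, assuming the llama_batch_init, llama_decode, and llama_batch_free entry points from the same llama.cpp revision (error handling omitted):

#include <vector>
#include "common.h"
#include "llama.h"

// Feed a tokenized prompt to the model as a single batch on sequence 0.
void decode_prompt(llama_context * ctx, const std::vector<llama_token> & prompt) {
    llama_batch batch = llama_batch_init(512, 0, 1);  // capacity: 512 tokens, 1 seq id per token

    llama_batch_clear(batch);                         // resets n_tokens to 0
    for (size_t i = 0; i < prompt.size(); ++i) {
        // place token i at position i in sequence 0; request logits only for the last token
        llama_batch_add(batch, prompt[i], (llama_pos) i, { 0 }, i == prompt.size() - 1);
    }

    llama_decode(ctx, batch);                         // evaluate the whole batch at once
    llama_batch_free(batch);
}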
22 changes: 19 additions & 3 deletions cpp/common.h
@@ -70,6 +70,7 @@ struct gpt_params {
     std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
     std::string logdir = ""; // directory in which to save YAML log files

+    // TODO: avoid tuple, use struct
     std::vector<std::tuple<std::string, float>> lora_adapter; // lora adapter path with user defined scale
     std::string lora_base = ""; // base model path for the lora adapter

@@ -124,10 +125,23 @@ void process_escapes(std::string& input);
 // Model utils
 //

+// TODO: avoid tuple, use struct
 std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params);
-struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params);
+
+struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params);
 struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);

+// Batch utils
+
+void llama_batch_clear(struct llama_batch & batch);
+
+void llama_batch_add(
+    struct llama_batch & batch,
+    llama_token id,
+    llama_pos pos,
+    const std::vector<llama_seq_id> & seq_ids,
+    bool logits);
+
 //
 // Vocab utils
 //
@@ -137,12 +151,14 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
 std::vector<llama_token> llama_tokenize(
     const struct llama_context * ctx,
     const std::string & text,
-    bool add_bos);
+    bool add_bos,
+    bool special = false);

 std::vector<llama_token> llama_tokenize(
     const struct llama_model * model,
     const std::string & text,
-    bool add_bos);
+    bool add_bos,
+    bool special = false);

 // tokenizes a token into a piece
 // should work similar to Python's `tokenizer.id_to_piece`
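Both tokenizer declarations gain a special parameter that defaults to false, so existing call sites compile unchanged; passing true lets special/control tokens written in the input text be parsed into their token ids instead of being tokenized as plain text. A short sketch of the difference (the "<s>" input and the size comparison are illustrative, assuming a LLaMA-style vocab):

#include <string>
#include <vector>
#include "common.h"

void tokenize_example(const llama_model * model) {
    const std::string text = "<s> Hello";

    // special = false (the default): "<s>" is split into ordinary text tokens.
    std::vector<llama_token> plain  = llama_tokenize(model, text, /*add_bos=*/false);

    // special = true: "<s>" is recognized as the BOS control token.
    std::vector<llama_token> parsed = llama_tokenize(model, text, /*add_bos=*/false, /*special=*/true);

    // For this input, parsed is expected to be shorter than plain.
}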
47 changes: 44 additions & 3 deletions cpp/ggml-metal.m
@@ -73,6 +73,8 @@
     LM_GGML_METAL_DECL_KERNEL(get_rows_f16);
     LM_GGML_METAL_DECL_KERNEL(get_rows_q4_0);
     LM_GGML_METAL_DECL_KERNEL(get_rows_q4_1);
+    LM_GGML_METAL_DECL_KERNEL(get_rows_q5_0);
+    LM_GGML_METAL_DECL_KERNEL(get_rows_q5_1);
     LM_GGML_METAL_DECL_KERNEL(get_rows_q8_0);
     LM_GGML_METAL_DECL_KERNEL(get_rows_q2_K);
     LM_GGML_METAL_DECL_KERNEL(get_rows_q3_K);
@@ -87,6 +89,8 @@
     LM_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
     LM_GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
     LM_GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
+    LM_GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32);
+    LM_GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32);
     LM_GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
     LM_GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
     LM_GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
@@ -97,6 +101,8 @@
     LM_GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
     LM_GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
     LM_GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
+    LM_GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32);
+    LM_GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32);
     LM_GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
     LM_GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
     LM_GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
@@ -254,6 +260,8 @@ static void lm_ggml_metal_log(enum lm_ggml_log_level level, const char* format, ...)
     LM_GGML_METAL_ADD_KERNEL(get_rows_f16);
     LM_GGML_METAL_ADD_KERNEL(get_rows_q4_0);
     LM_GGML_METAL_ADD_KERNEL(get_rows_q4_1);
+    LM_GGML_METAL_ADD_KERNEL(get_rows_q5_0);
+    LM_GGML_METAL_ADD_KERNEL(get_rows_q5_1);
     LM_GGML_METAL_ADD_KERNEL(get_rows_q8_0);
     LM_GGML_METAL_ADD_KERNEL(get_rows_q2_K);
     LM_GGML_METAL_ADD_KERNEL(get_rows_q3_K);
@@ -268,6 +276,8 @@ static void lm_ggml_metal_log(enum lm_ggml_log_level level, const char* format, ...)
     LM_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
     LM_GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
     LM_GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
+    LM_GGML_METAL_ADD_KERNEL(mul_mv_q5_0_f32);
+    LM_GGML_METAL_ADD_KERNEL(mul_mv_q5_1_f32);
     LM_GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
     LM_GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
     LM_GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
@@ -278,8 +288,10 @@ static void lm_ggml_metal_log(enum lm_ggml_log_level level, const char* format, ...)
     LM_GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
     LM_GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
     LM_GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
-    LM_GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
     LM_GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
+    LM_GGML_METAL_ADD_KERNEL(mul_mm_q5_0_f32);
+    LM_GGML_METAL_ADD_KERNEL(mul_mm_q5_1_f32);
+    LM_GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
     LM_GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
     LM_GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
     LM_GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
@@ -346,6 +358,8 @@ void lm_ggml_metal_free(struct lm_ggml_metal_context * ctx) {
     LM_GGML_METAL_DEL_KERNEL(get_rows_f16);
     LM_GGML_METAL_DEL_KERNEL(get_rows_q4_0);
     LM_GGML_METAL_DEL_KERNEL(get_rows_q4_1);
+    LM_GGML_METAL_DEL_KERNEL(get_rows_q5_0);
+    LM_GGML_METAL_DEL_KERNEL(get_rows_q5_1);
     LM_GGML_METAL_DEL_KERNEL(get_rows_q8_0);
     LM_GGML_METAL_DEL_KERNEL(get_rows_q2_K);
     LM_GGML_METAL_DEL_KERNEL(get_rows_q3_K);
@@ -360,6 +374,8 @@ void lm_ggml_metal_free(struct lm_ggml_metal_context * ctx) {
     LM_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
     LM_GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
     LM_GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
+    LM_GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32);
+    LM_GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32);
     LM_GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
     LM_GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
     LM_GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
@@ -370,8 +386,10 @@ void lm_ggml_metal_free(struct lm_ggml_metal_context * ctx) {
     LM_GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
     LM_GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
     LM_GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
-    LM_GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
     LM_GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
+    LM_GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32);
+    LM_GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32);
+    LM_GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
     LM_GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
     LM_GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
     LM_GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
@@ -1052,6 +1070,8 @@ void lm_ggml_metal_graph_compute(
                 case LM_GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32];  break;
                 case LM_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
                 case LM_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
+                case LM_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break;
+                case LM_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break;
                 case LM_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
                 case LM_GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
                 case LM_GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
@@ -1121,6 +1141,24 @@ void lm_ggml_metal_graph_compute(
                             nth1 = 8;
                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
                         } break;
+                    case LM_GGML_TYPE_Q5_0:
+                        {
+                            LM_GGML_ASSERT(ne02 == 1);
+                            LM_GGML_ASSERT(ne12 == 1);
+
+                            nth0 = 8;
+                            nth1 = 8;
+                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
+                        } break;
+                    case LM_GGML_TYPE_Q5_1:
+                        {
+                            LM_GGML_ASSERT(ne02 == 1);
+                            LM_GGML_ASSERT(ne12 == 1);
+
+                            nth0 = 8;
+                            nth1 = 8;
+                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
+                        } break;
                     case LM_GGML_TYPE_Q8_0:
                         {
                             LM_GGML_ASSERT(ne02 == 1);
@@ -1201,7 +1239,8 @@ void lm_ggml_metal_graph_compute(
                 [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
                 [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];

-                if (src0t == LM_GGML_TYPE_Q4_0 || src0t == LM_GGML_TYPE_Q4_1 || src0t == LM_GGML_TYPE_Q8_0 ||
+                if (src0t == LM_GGML_TYPE_Q4_0 || src0t == LM_GGML_TYPE_Q4_1 ||
+                    src0t == LM_GGML_TYPE_Q5_0 || src0t == LM_GGML_TYPE_Q5_1 || src0t == LM_GGML_TYPE_Q8_0 ||
                     src0t == LM_GGML_TYPE_Q2_K) { // || src0t == LM_GGML_TYPE_Q4_K) {
                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                 }
@@ -1233,6 +1272,8 @@ void lm_ggml_metal_graph_compute(
                 case LM_GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16];  break;
                 case LM_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
                 case LM_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
+                case LM_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break;
+                case LM_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break;
                 case LM_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
                 case LM_GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
                 case LM_GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
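The Metal additions above wire the two 5-bit block formats into the get_rows, mat-vec (mul_mv), and simdgroup mat-mat (mul_mm) paths. For reference, these are the standard ggml block definitions the new kernels decode (restated here, not taken from this diff), each packing 32 weights per block:

Q5_0: x_i = d * (q_i - 16), with q_i in [0, 31] and one fp16 scale d per block
Q5_1: x_i = d * q_i + m, with q_i in [0, 31], an fp16 scale d, and an fp16 offset m per block

In both formats the low 4 bits of each quant are stored in a packed qs array and the fifth bit in a separate 32-bit qh mask, which is why these types need dedicated kernels rather than reusing the Q4 ones.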
(the remaining 14 changed files in this commit are not shown)
