From a8e9555d75168e21df6d71c95ea9191fb26976b5 Mon Sep 17 00:00:00 2001 From: jhen Date: Thu, 19 Oct 2023 10:27:15 +0800 Subject: [PATCH] feat: sync llama.cpp --- android/src/main/jni.cpp | 4 + cpp/build-info.h | 4 +- cpp/common.cpp | 33 +++- cpp/common.h | 22 ++- cpp/ggml-metal.m | 47 ++++- cpp/ggml-metal.metal | 163 +++++++++++++++- cpp/ggml.c | 34 ++++ cpp/ggml.h | 3 + cpp/k_quants.c | 30 ++- cpp/llama.cpp | 411 +++++++++++++++++++++++++++++++++------ cpp/llama.h | 30 +-- cpp/log.h | 101 +++++++--- cpp/rn-llama.hpp | 62 ++---- cpp/sampling.cpp | 211 +++++++++++--------- cpp/sampling.h | 87 ++++----- example/ios/Podfile.lock | 4 +- ios/RNLlamaContext.mm | 4 - llama.cpp | 2 +- scripts/llama.cpp.patch | 8 +- 19 files changed, 933 insertions(+), 327 deletions(-) diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp index fa42043d..7d24d9f1 100644 --- a/android/src/main/jni.cpp +++ b/android/src/main/jni.cpp @@ -572,6 +572,10 @@ Java_com_rnllama_LlamaContext_freeContext( if (llama->ctx) { llama_free(llama->ctx); } + if (llama->ctx_sampling != nullptr) + { + llama_sampling_free(llama->ctx_sampling); + } context_map.erase((long) llama->ctx); } diff --git a/cpp/build-info.h b/cpp/build-info.h index 93f808c2..d7613432 100644 --- a/cpp/build-info.h +++ b/cpp/build-info.h @@ -1,8 +1,8 @@ #ifndef BUILD_INFO_H #define BUILD_INFO_H -#define BUILD_NUMBER 1378 -#define BUILD_COMMIT "1e0e873" +#define BUILD_NUMBER 1399 +#define BUILD_COMMIT "004797f" #define BUILD_COMPILER "" #define BUILD_TARGET "unknown" diff --git a/cpp/common.cpp b/cpp/common.cpp index 077cc959..ce26523c 100644 --- a/cpp/common.cpp +++ b/cpp/common.cpp @@ -820,6 +820,27 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param return cparams; } +void llama_batch_clear(struct llama_batch & batch) { + batch.n_tokens = 0; +} + +void llama_batch_add( + struct llama_batch & batch, + llama_token id, + llama_pos pos, + const std::vector & seq_ids, + bool logits) { + batch.token [batch.n_tokens] = id; + batch.pos [batch.n_tokens] = pos, + batch.n_seq_id[batch.n_tokens] = seq_ids.size(); + for (size_t i = 0; i < seq_ids.size(); ++i) { + batch.seq_id[batch.n_tokens][i] = seq_ids[i]; + } + batch.logits [batch.n_tokens] = logits; + + batch.n_tokens++; +} + std::tuple llama_init_from_gpt_params(gpt_params & params) { auto mparams = llama_model_params_from_gpt_params(params); @@ -879,21 +900,23 @@ std::tuple llama_init_from_gpt_par std::vector llama_tokenize( const struct llama_context * ctx, const std::string & text, - bool add_bos) { - return llama_tokenize(llama_get_model(ctx), text, add_bos); + bool add_bos, + bool special) { + return llama_tokenize(llama_get_model(ctx), text, add_bos, special); } std::vector llama_tokenize( const struct llama_model * model, const std::string & text, - bool add_bos) { + bool add_bos, + bool special) { // upper limit for the number of tokens int n_tokens = text.length() + add_bos; std::vector result(n_tokens); - n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos); + n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special); if (n_tokens < 0) { result.resize(-n_tokens); - int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos); + int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special); LM_GGML_ASSERT(check == -n_tokens); } else { result.resize(n_tokens); diff --git a/cpp/common.h 
b/cpp/common.h index 36fd4416..65d3d20c 100644 --- a/cpp/common.h +++ b/cpp/common.h @@ -70,6 +70,7 @@ struct gpt_params { std::vector antiprompt; // string upon seeing which more user input is prompted std::string logdir = ""; // directory in which to save YAML log files + // TODO: avoid tuple, use struct std::vector> lora_adapter; // lora adapter path with user defined scale std::string lora_base = ""; // base model path for the lora adapter @@ -124,10 +125,23 @@ void process_escapes(std::string& input); // Model utils // +// TODO: avoid tuplue, use struct std::tuple llama_init_from_gpt_params(gpt_params & params); -struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params); + +struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params); struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params); +// Batch utils + +void llama_batch_clear(struct llama_batch & batch); + +void llama_batch_add( + struct llama_batch & batch, + llama_token id, + llama_pos pos, + const std::vector & seq_ids, + bool logits); + // // Vocab utils // @@ -137,12 +151,14 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param std::vector llama_tokenize( const struct llama_context * ctx, const std::string & text, - bool add_bos); + bool add_bos, + bool special = false); std::vector llama_tokenize( const struct llama_model * model, const std::string & text, - bool add_bos); + bool add_bos, + bool special = false); // tokenizes a token into a piece // should work similar to Python's `tokenizer.id_to_piece` diff --git a/cpp/ggml-metal.m b/cpp/ggml-metal.m index 9d4a5531..b1336b9f 100644 --- a/cpp/ggml-metal.m +++ b/cpp/ggml-metal.m @@ -73,6 +73,8 @@ LM_GGML_METAL_DECL_KERNEL(get_rows_f16); LM_GGML_METAL_DECL_KERNEL(get_rows_q4_0); LM_GGML_METAL_DECL_KERNEL(get_rows_q4_1); + LM_GGML_METAL_DECL_KERNEL(get_rows_q5_0); + LM_GGML_METAL_DECL_KERNEL(get_rows_q5_1); LM_GGML_METAL_DECL_KERNEL(get_rows_q8_0); LM_GGML_METAL_DECL_KERNEL(get_rows_q2_K); LM_GGML_METAL_DECL_KERNEL(get_rows_q3_K); @@ -87,6 +89,8 @@ LM_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4); LM_GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32); LM_GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32); + LM_GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32); + LM_GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32); LM_GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32); LM_GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32); LM_GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32); @@ -97,6 +101,8 @@ LM_GGML_METAL_DECL_KERNEL(mul_mm_f16_f32); LM_GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32); LM_GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32); + LM_GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32); + LM_GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32); LM_GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32); LM_GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32); LM_GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32); @@ -254,6 +260,8 @@ static void lm_ggml_metal_log(enum lm_ggml_log_level level, const char* format, LM_GGML_METAL_ADD_KERNEL(get_rows_f16); LM_GGML_METAL_ADD_KERNEL(get_rows_q4_0); LM_GGML_METAL_ADD_KERNEL(get_rows_q4_1); + LM_GGML_METAL_ADD_KERNEL(get_rows_q5_0); + LM_GGML_METAL_ADD_KERNEL(get_rows_q5_1); LM_GGML_METAL_ADD_KERNEL(get_rows_q8_0); LM_GGML_METAL_ADD_KERNEL(get_rows_q2_K); LM_GGML_METAL_ADD_KERNEL(get_rows_q3_K); @@ -268,6 +276,8 @@ static void lm_ggml_metal_log(enum lm_ggml_log_level level, const char* format, LM_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4); LM_GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32); LM_GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32); + 
LM_GGML_METAL_ADD_KERNEL(mul_mv_q5_0_f32); + LM_GGML_METAL_ADD_KERNEL(mul_mv_q5_1_f32); LM_GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32); LM_GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32); LM_GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32); @@ -278,8 +288,10 @@ static void lm_ggml_metal_log(enum lm_ggml_log_level level, const char* format, LM_GGML_METAL_ADD_KERNEL(mul_mm_f32_f32); LM_GGML_METAL_ADD_KERNEL(mul_mm_f16_f32); LM_GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32); - LM_GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32); LM_GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32); + LM_GGML_METAL_ADD_KERNEL(mul_mm_q5_0_f32); + LM_GGML_METAL_ADD_KERNEL(mul_mm_q5_1_f32); + LM_GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32); LM_GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32); LM_GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32); LM_GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32); @@ -346,6 +358,8 @@ void lm_ggml_metal_free(struct lm_ggml_metal_context * ctx) { LM_GGML_METAL_DEL_KERNEL(get_rows_f16); LM_GGML_METAL_DEL_KERNEL(get_rows_q4_0); LM_GGML_METAL_DEL_KERNEL(get_rows_q4_1); + LM_GGML_METAL_DEL_KERNEL(get_rows_q5_0); + LM_GGML_METAL_DEL_KERNEL(get_rows_q5_1); LM_GGML_METAL_DEL_KERNEL(get_rows_q8_0); LM_GGML_METAL_DEL_KERNEL(get_rows_q2_K); LM_GGML_METAL_DEL_KERNEL(get_rows_q3_K); @@ -360,6 +374,8 @@ void lm_ggml_metal_free(struct lm_ggml_metal_context * ctx) { LM_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4); LM_GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32); LM_GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32); + LM_GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32); + LM_GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32); LM_GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32); LM_GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32); LM_GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32); @@ -370,8 +386,10 @@ void lm_ggml_metal_free(struct lm_ggml_metal_context * ctx) { LM_GGML_METAL_DEL_KERNEL(mul_mm_f32_f32); LM_GGML_METAL_DEL_KERNEL(mul_mm_f16_f32); LM_GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32); - LM_GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32); LM_GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32); + LM_GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32); + LM_GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32); + LM_GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32); LM_GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32); LM_GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32); LM_GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32); @@ -1052,6 +1070,8 @@ void lm_ggml_metal_graph_compute( case LM_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break; case LM_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break; case LM_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break; + case LM_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break; + case LM_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break; case LM_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break; case LM_GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break; case LM_GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break; @@ -1121,6 +1141,24 @@ void lm_ggml_metal_graph_compute( nth1 = 8; [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32]; } break; + case LM_GGML_TYPE_Q5_0: + { + LM_GGML_ASSERT(ne02 == 1); + LM_GGML_ASSERT(ne12 == 1); + + nth0 = 8; + nth1 = 8; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32]; + } break; + case LM_GGML_TYPE_Q5_1: + { + LM_GGML_ASSERT(ne02 == 1); + LM_GGML_ASSERT(ne12 == 1); + + nth0 = 8; + nth1 = 8; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32]; + } break; 
case LM_GGML_TYPE_Q8_0: { LM_GGML_ASSERT(ne02 == 1); @@ -1201,7 +1239,8 @@ void lm_ggml_metal_graph_compute( [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16]; [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17]; - if (src0t == LM_GGML_TYPE_Q4_0 || src0t == LM_GGML_TYPE_Q4_1 || src0t == LM_GGML_TYPE_Q8_0 || + if (src0t == LM_GGML_TYPE_Q4_0 || src0t == LM_GGML_TYPE_Q4_1 || + src0t == LM_GGML_TYPE_Q5_0 || src0t == LM_GGML_TYPE_Q5_1 || src0t == LM_GGML_TYPE_Q8_0 || src0t == LM_GGML_TYPE_Q2_K) { // || src0t == LM_GGML_TYPE_Q4_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } @@ -1233,6 +1272,8 @@ void lm_ggml_metal_graph_compute( case LM_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break; case LM_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; case LM_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break; + case LM_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break; + case LM_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break; case LM_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break; case LM_GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break; case LM_GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break; diff --git a/cpp/ggml-metal.metal b/cpp/ggml-metal.metal index 99b9fd7a..69fc7136 100644 --- a/cpp/ggml-metal.metal +++ b/cpp/ggml-metal.metal @@ -18,6 +18,21 @@ typedef struct { uint8_t qs[QK4_1 / 2]; // nibbles / quants } block_q4_1; +#define QK5_0 32 +typedef struct { + half d; // delta + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_0 / 2]; // nibbles / quants +} block_q5_0; + +#define QK5_1 32 +typedef struct { + half d; // delta + half m; // min + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_1 / 2]; // nibbles / quants +} block_q5_1; + #define QK8_0 32 typedef struct { half d; // delta @@ -399,8 +414,11 @@ kernel void kernel_rms_norm( // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { float d = qb_curr->d; + float2 acc = 0.f; + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2); + for (int i = 0; i < 8; i+=2) { acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) + yl[i + 1] * (qs[i / 2] & 0x0F00); @@ -417,8 +435,11 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) { float d = qb_curr->d; float m = qb_curr->m; - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); + float2 acc = 0.f; + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); + for (int i = 0; i < 8; i+=2) { acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) + yl[i + 1] * (qs[i / 2] & 0x0F00); @@ -428,6 +449,49 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre return d * (acc[0] + acc[1]) + sumy * m; } +// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q5 quants begin (0 or QK5_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) 
+inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + + float2 acc = 0.f; + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2); + const uint32_t qh = *((device const uint32_t *)qb_curr->qh); + + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) + + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); + acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)) + + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); + } + return d * (sumy * -16.f + acc[0] + acc[1]); +} + +// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q5 quants begin (0 or QK5_1/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + float m = qb_curr->m; + + float2 acc = 0.f; + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2); + const uint32_t qh = *((device const uint32_t *)qb_curr->qh); + + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) + + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); + acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)) + + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); + } + return d * (acc[0] + acc[1]) + sumy * m; +} + // putting them in the kernel cause a significant performance penalty #define N_DST 4 // each SIMD group works on 4 rows #define N_SIMDGROUP 2 // number of SIMD groups in a thread group @@ -525,6 +589,43 @@ kernel void kernel_mul_mv_q4_1_f32( mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); } +kernel void kernel_mul_mv_q5_0_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01[[buffer(4)]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); +} + +kernel void kernel_mul_mv_q5_1_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01[[buffer(4)]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); +} + + #define NB_Q8_0 8 kernel void kernel_mul_mv_q8_0_f32( @@ -2149,6 +2250,62 @@ void dequantize_q4_1(device const 
block_q4_1 *xb, short il, thread type4x4 & reg } } +template +void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 3); + const float d = xb->d; + const float md = -16.h * xb->d; + const ushort mask = il ? 0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = il ? 4 : 0; + + const int gh_mv = il ? 12 : 0; + const int gh_bk = il ? 0 : 4; + + for (int i = 0; i < 8; i++) { + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg[i/2][2*(i%2)+0] = d * x0 + md; + reg[i/2][2*(i%2)+1] = d * x1 + md; + } +} + +template +void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 4); + const float d = xb->d; + const float m = xb->m; + const ushort mask = il ? 0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = il ? 4 : 0; + + const int gh_mv = il ? 12 : 0; + const int gh_bk = il ? 0 : 4; + + for (int i = 0; i < 8; i++) { + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg[i/2][2*(i%2)+0] = d * x0 + m; + reg[i/2][2*(i%2)+1] = d * x1 + m; + } +} + template void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { device const int8_t * qs = ((device const int8_t *)xb->qs); @@ -2490,6 +2647,8 @@ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows; @@ -2518,6 +2677,8 @@ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; diff --git a/cpp/ggml.c b/cpp/ggml.c index 9fc1bfd9..17ce01b8 100644 --- a/cpp/ggml.c +++ b/cpp/ggml.c @@ -5494,6 +5494,39 @@ struct lm_ggml_tensor * 
lm_ggml_view_tensor( return result; } +struct lm_ggml_tensor * lm_ggml_get_first_tensor(struct lm_ggml_context * ctx) { + struct lm_ggml_object * obj = ctx->objects_begin; + + char * const mem_buffer = ctx->mem_buffer; + + while (obj != NULL) { + if (obj->type == LM_GGML_OBJECT_TENSOR) { + return (struct lm_ggml_tensor *)(mem_buffer + obj->offs); + } + + obj = obj->next; + } + + return NULL; +} + +struct lm_ggml_tensor * lm_ggml_get_next_tensor(struct lm_ggml_context * ctx, struct lm_ggml_tensor * tensor) { + struct lm_ggml_object * obj = (struct lm_ggml_object *) ((char *)tensor - LM_GGML_OBJECT_SIZE); + obj = obj->next; + + char * const mem_buffer = ctx->mem_buffer; + + while (obj != NULL) { + if (obj->type == LM_GGML_OBJECT_TENSOR) { + return (struct lm_ggml_tensor *)(mem_buffer + obj->offs); + } + + obj = obj->next; + } + + return NULL; +} + struct lm_ggml_tensor * lm_ggml_get_tensor(struct lm_ggml_context * ctx, const char * name) { struct lm_ggml_object * obj = ctx->objects_begin; @@ -8647,6 +8680,7 @@ void lm_ggml_set_param( LM_GGML_ASSERT(tensor->grad == NULL); tensor->grad = lm_ggml_dup_tensor(ctx, tensor); + lm_ggml_format_name(tensor->grad, "%s (grad)", tensor->name); } // lm_ggml_compute_forward_dup diff --git a/cpp/ggml.h b/cpp/ggml.h index 45368eb9..6ece44c9 100644 --- a/cpp/ggml.h +++ b/cpp/ggml.h @@ -704,6 +704,9 @@ extern "C" { LM_GGML_API struct lm_ggml_tensor * lm_ggml_dup_tensor (struct lm_ggml_context * ctx, const struct lm_ggml_tensor * src); LM_GGML_API struct lm_ggml_tensor * lm_ggml_view_tensor(struct lm_ggml_context * ctx, struct lm_ggml_tensor * src); + // Context tensor enumeration and lookup + LM_GGML_API struct lm_ggml_tensor * lm_ggml_get_first_tensor(struct lm_ggml_context * ctx); + LM_GGML_API struct lm_ggml_tensor * lm_ggml_get_next_tensor (struct lm_ggml_context * ctx, struct lm_ggml_tensor * tensor); LM_GGML_API struct lm_ggml_tensor * lm_ggml_get_tensor(struct lm_ggml_context * ctx, const char * name); LM_GGML_API struct lm_ggml_tensor * lm_ggml_set_zero(struct lm_ggml_tensor * tensor); diff --git a/cpp/k_quants.c b/cpp/k_quants.c index 57548b90..eacebad1 100644 --- a/cpp/k_quants.c +++ b/cpp/k_quants.c @@ -462,12 +462,9 @@ void quantize_row_q2_K(const float * restrict x, void * restrict vy, int k) { } size_t lm_ggml_quantize_q2_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { - const int nb = k / QK_K; - - // TODO - collect histograms - although, at a second thought, I don't really care about them - (void)hist; + (void)hist; // TODO: collect histograms - for (int j = 0; j < nb; j += k) { + for (int j = 0; j < n; j += k) { block_q2_K * restrict y = (block_q2_K *)dst + j/QK_K; quantize_row_q2_K_reference(src + j, y, k); } @@ -678,12 +675,9 @@ void quantize_row_q3_K(const float * restrict x, void * restrict vy, int k) { } size_t lm_ggml_quantize_q3_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { - const int nb = k / QK_K; - - // TODO - collect histograms - although, at a second thought, I don't really care about them - (void)hist; + (void)hist; // TODO: collect histograms - for (int j = 0; j < nb; j += k) { + for (int j = 0; j < n; j += k) { block_q3_K * restrict y = (block_q3_K *)dst + j/QK_K; quantize_row_q3_K_reference(src + j, y, k); } @@ -846,9 +840,9 @@ void quantize_row_q4_K(const float * restrict x, void * restrict vy, int k) { size_t lm_ggml_quantize_q4_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { assert(k % QK_K 
== 0); - const int nb = k / QK_K; (void)hist; // TODO: collect histograms - for (int j = 0; j < nb; j += k) { + + for (int j = 0; j < n; j += k) { block_q4_K * restrict y = (block_q4_K *)dst + j/QK_K; quantize_row_q4_K_reference(src + j, y, k); } @@ -1052,9 +1046,9 @@ void quantize_row_q5_K(const float * restrict x, void * restrict vy, int k) { size_t lm_ggml_quantize_q5_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { assert(k % QK_K == 0); - const int nb = k / QK_K; - (void)hist; - for (int j = 0; j < nb; j += k) { + (void)hist; // TODO: collect histograms + + for (int j = 0; j < n; j += k) { block_q5_K * restrict y = (block_q5_K *)dst + j/QK_K; quantize_row_q5_K_reference(src + j, y, k); } @@ -1200,11 +1194,9 @@ void quantize_row_q6_K(const float * restrict x, void * restrict vy, int k) { size_t lm_ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist) { assert(k % QK_K == 0); - const int nb = k / QK_K; - - (void)hist; // TODO + (void)hist; // TODO: collect histograms - for (int j = 0; j < nb; j += k) { + for (int j = 0; j < n; j += k) { block_q6_K * restrict y = (block_q6_K *)dst + j/QK_K; quantize_row_q6_K_reference(src + j, y, k); } diff --git a/cpp/llama.cpp b/cpp/llama.cpp index eaf27cb2..e3fa6a5a 100644 --- a/cpp/llama.cpp +++ b/cpp/llama.cpp @@ -75,6 +75,7 @@ #include #include #include +#include #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -1194,6 +1195,8 @@ struct llama_vocab { std::unordered_map token_to_id; std::vector id_to_token; + std::unordered_map special_tokens_cache; + std::map, int> bpe_ranks; // default LLaMA special tokens @@ -1458,7 +1461,10 @@ static bool llama_kv_cache_find_slot( for (uint32_t i = 0; i < n_tokens; i++) { cache.cells[cache.head + i].pos = batch.pos[i]; - cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i]); + + for (int32_t j = 0; j < batch.n_seq_id[i]; j++) { + cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]); + } } return true; @@ -1538,6 +1544,9 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id cache.cells[i].pos = -1; cache.cells[i].seq_id.clear(); if (new_head == cache.size) new_head = i; + } else { + cache.cells[i].seq_id.clear(); + cache.cells[i].seq_id.insert(seq_id); } } @@ -2136,7 +2145,7 @@ static void llm_load_hparams( } // TODO: This should probably be in llama.h -static std::vector llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos); +static std::vector llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool special = false); static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch); static void llm_load_vocab( @@ -2252,6 +2261,101 @@ static void llm_load_vocab( GGUF_GET_KEY(ctx, vocab.special_unk_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_UNK_ID)); GGUF_GET_KEY(ctx, vocab.special_sep_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_SEP_ID)); GGUF_GET_KEY(ctx, vocab.special_pad_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_PAD_ID)); + + // build special tokens cache + { + // TODO: It is unclear (to me) at this point, whether special tokes are guaranteed to be of a deterministic type, + // and will always be correctly labeled in 'added_tokens.json' etc. 
+ // The assumption is, since special tokens aren't meant to be exposed to end user, they are designed + // to be unmatchable by the tokenizer, therefore tokens from the vocab, which are unmatchable by the tokenizer + // are special tokens. + // From testing, this appears to corelate 1:1 with special tokens. + // + + // Counting special tokens and verifying in only one direction + // is sufficient to detect difference in those two sets. + // + uint32_t special_tokens_count_by_type = 0; + uint32_t special_tokens_count_from_verification = 0; + + bool special_tokens_definition_mismatch = false; + + for (const auto & t : vocab.token_to_id) { + const auto & token = t.first; + const auto & id = t.second; + + // Count all non-normal tokens in the vocab while iterating + if (vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL) { + special_tokens_count_by_type++; + } + + // Skip single character tokens + if (token.length() > 1) { + bool is_tokenizable = false; + + // Split token string representation in two, in all possible ways + // and check if both halves can be matched to a valid token + for (unsigned i = 1; i < token.length();) { + const auto left = token.substr(0, i); + const auto right = token.substr(i); + + // check if we didnt partition in the middle of a utf sequence + auto utf = utf8_len(left.at(left.length() - 1)); + + if (utf == 1) { + if (vocab.token_to_id.find(left) != vocab.token_to_id.end() && + vocab.token_to_id.find(right) != vocab.token_to_id.end() ) { + is_tokenizable = true; + break; + } + i++; + } else { + // skip over the rest of multibyte utf sequence + i += utf - 1; + } + } + + if (!is_tokenizable) { + // Some tokens are multibyte, but they are utf sequences with equivalent text length of 1 + // it's faster to re-filter them here, since there are way less candidates now + + // Calculate a total "utf" length of a token string representation + size_t utf8_str_len = 0; + for (unsigned i = 0; i < token.length();) { + utf8_str_len++; + i += utf8_len(token.at(i)); + } + + // And skip the ones which are one character + if (utf8_str_len > 1) { + // At this point what we have left are special tokens only + vocab.special_tokens_cache[token] = id; + + // Count manually found special tokens + special_tokens_count_from_verification++; + + // If this manually found special token is not marked as such, flag a mismatch + if (vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL) { + special_tokens_definition_mismatch = true; + } + } + } + } + } + + if (special_tokens_definition_mismatch || special_tokens_count_from_verification != special_tokens_count_by_type) { + LLAMA_LOG_WARN("%s: mismatch in special tokens definition ( %u/%zu vs %u/%zu ).\n", + __func__, + special_tokens_count_from_verification, vocab.id_to_token.size(), + special_tokens_count_by_type, vocab.id_to_token.size() + ); + } else { + LLAMA_LOG_INFO("%s: special tokens definition check successful ( %u/%zu ).\n", + __func__, + special_tokens_count_from_verification, vocab.id_to_token.size() + ); + } + } } static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { @@ -2850,8 +2954,8 @@ static void llm_load_tensors( auto & layer = model.layers[i]; layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); - layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3*n_embd}, backend_split); - layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + layer.wqkv = ml.create_tensor(ctx, 
tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split); + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); @@ -3091,7 +3195,7 @@ static struct lm_ggml_cgraph * llm_build_llama( for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j]; + const llama_seq_id seq_id = batch.seq_id[j][0]; for (int i = 0; i < n_kv; ++i) { if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { @@ -3477,7 +3581,7 @@ static struct lm_ggml_cgraph * llm_build_baichaun( for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j]; + const llama_seq_id seq_id = batch.seq_id[j][0]; for (int i = 0; i < n_kv; ++i) { if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { @@ -3876,7 +3980,7 @@ static struct lm_ggml_cgraph * llm_build_refact( for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j]; + const llama_seq_id seq_id = batch.seq_id[j][0]; for (int i = 0; i < n_kv; ++i) { if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { @@ -4228,7 +4332,7 @@ static struct lm_ggml_cgraph * llm_build_falcon( for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j]; + const llama_seq_id seq_id = batch.seq_id[j][0]; for (int i = 0; i < n_kv; ++i) { if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { @@ -4580,7 +4684,7 @@ static struct lm_ggml_cgraph * llm_build_starcoder( for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j]; + const llama_seq_id seq_id = batch.seq_id[j][0]; for (int i = 0; i < n_kv; ++i) { if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { @@ -4811,7 +4915,7 @@ static struct lm_ggml_cgraph * llm_build_persimmon( for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j]; + const llama_seq_id seq_id = batch.seq_id[j][0]; for (int i = 0; i < n_kv; ++i) { if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; @@ -5209,7 +5313,7 @@ static struct lm_ggml_cgraph * llm_build_bloom( for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j]; + const llama_seq_id seq_id = batch.seq_id[j][0]; for (int i = 0; i < n_kv; ++i) { if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { @@ -5379,7 +5483,7 @@ static struct lm_ggml_cgraph * llm_build_mpt( const int64_t n_layer = hparams.n_layer; const int64_t n_ctx = cparams.n_ctx; const int64_t n_head = hparams.n_head; - const int64_t n_head_kv = hparams.n_head_kv; // == n_head for MPT, as there's no MQA/GQA + const int64_t n_head_kv = hparams.n_head_kv; const int64_t n_embd_head = hparams.n_embd_head(); const int64_t n_embd_gqa = hparams.n_embd_gqa(); @@ -5477,7 +5581,7 @@ static struct lm_ggml_cgraph * llm_build_mpt( for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { 
const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j]; + const llama_seq_id seq_id = batch.seq_id[j][0]; for (int i = 0; i < n_kv; ++i) { if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { @@ -5732,7 +5836,6 @@ static struct lm_ggml_cgraph * llama_build_graph( // // - lctx: llama context // - batch: batch to evaluate -// - n_threads: number of threads to use // // return 0 on success // return positive int on warning @@ -5778,8 +5881,11 @@ static int llama_decode_internal( // helpers for smoother batch API transistion // after deprecating the llama_eval calls, these will be removed - std::vector pos; - std::vector seq_id; + std::vector pos; + + std::vector n_seq_id; + std::vector seq_id_arr; + std::vector> seq_id; if (batch.pos == nullptr) { pos.resize(n_tokens); @@ -5791,12 +5897,18 @@ static int llama_decode_internal( } if (batch.seq_id == nullptr) { + n_seq_id.resize(n_tokens); seq_id.resize(n_tokens); + seq_id_arr.resize(n_tokens); for (uint32_t i = 0; i < n_tokens; i++) { - seq_id[i] = batch.all_seq_id; + n_seq_id[i] = 1; + seq_id[i].resize(1); + seq_id[i][0] = batch.all_seq_id; + seq_id_arr[i] = seq_id[i].data(); } - batch.seq_id = seq_id.data(); + batch.n_seq_id = n_seq_id.data(); + batch.seq_id = seq_id_arr.data(); } if (!llama_kv_cache_find_slot(kv_self, batch)) { @@ -5817,6 +5929,13 @@ static int llama_decode_internal( lm_ggml_allocr_alloc_graph(lctx.alloc, gf); + struct lm_ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; + struct lm_ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2]; + + LM_GGML_ASSERT(strcmp(res->name, "result_output") == 0); + LM_GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); + + #ifdef LM_GGML_USE_CUBLAS for (int i = 0; i < gf->n_leafs; i++) { lm_ggml_tensor * node = gf->leafs[i]; @@ -5834,6 +5953,12 @@ static int llama_decode_internal( } lm_ggml_cuda_set_mul_mat_q(cparams.mul_mat_q); + + // HACK: ggml-alloc may change the tensor backend when reusing a parent, so force output to be on the CPU here if needed + if (!lctx.embedding.empty()) { + embeddings->backend = LM_GGML_BACKEND_CPU; + } + res->backend = LM_GGML_BACKEND_CPU; #endif // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (lm_ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); @@ -5858,12 +5983,6 @@ static int llama_decode_internal( n_threads = 1; } - struct lm_ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; - struct lm_ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2]; - - LM_GGML_ASSERT(strcmp(res->name, "result_output") == 0); - LM_GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); - #if LM_GGML_USE_MPI const int64_t n_layer = hparams.n_layer; lm_ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer); @@ -6476,7 +6595,137 @@ struct llm_tokenizer_bpe { llm_bigram_bpe::queue work_queue; }; -static std::vector llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos) { +typedef enum FRAGMENT_BUFFER_VARIANT_TYPE{ + FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN, + FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT +} FRAGMENT_BUFFER_VARIANT_TYPE; + +struct fragment_buffer_variant{ + fragment_buffer_variant(llama_vocab::id _token) + : + type(FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN), + token(_token), + raw_text(_dummy), + offset(0), + length(0){} + fragment_buffer_variant(const std::string & _raw_text, int64_t _offset, int64_t _length) + : + type(FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT), + token((llama_vocab::id)-1), + raw_text(_raw_text), + offset(_offset), + length(_length){ + LM_GGML_ASSERT( _offset 
>= 0 ); + LM_GGML_ASSERT( _length >= 1 ); + LM_GGML_ASSERT( offset + length <= raw_text.length() ); + } + + const FRAGMENT_BUFFER_VARIANT_TYPE type; + const llama_vocab::id token; + const std::string _dummy; + const std::string & raw_text; + const uint64_t offset; + const uint64_t length; +}; + +// #define PRETOKENIZERDEBUG + +static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list & buffer) +{ + // for each special token + for (const auto & st: vocab.special_tokens_cache) { + const auto & special_token = st.first; + const auto & special_id = st.second; + + // for each text fragment + std::forward_list::iterator it = buffer.begin(); + while (it != buffer.end()) { + auto & fragment = (*it); + + // if a fragment is text ( not yet processed ) + if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { + auto * raw_text = &(fragment.raw_text); + + auto raw_text_base_offset = fragment.offset; + auto raw_text_base_length = fragment.length; + + // loop over the text + while (true) { + // find the first occurence of a given special token in this fragment + // passing offset argument only limit the "search area" but match coordinates + // are still relative to the source full raw_text + auto match = raw_text->find(special_token, raw_text_base_offset); + + // no occurences found, stop processing this fragment for a given special token + if (match == std::string::npos) break; + + // check if match is within bounds of offset <-> length + if (match + special_token.length() > raw_text_base_offset + raw_text_base_length) break; + +#ifdef PRETOKENIZERDEBUG + fprintf(stderr, "FF: (%ld %ld %ld) '%s'\n", raw_text->length(), raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str()); +#endif + auto source = std::distance(buffer.begin(), it); + + // if match is further than base offset + // then we have some text to the left of it + if (match > raw_text_base_offset) { + // left + const int64_t left_reminder_offset = raw_text_base_offset + 0; + const int64_t left_reminder_length = match - raw_text_base_offset; + buffer.emplace_after(it, (*raw_text), left_reminder_offset, left_reminder_length); + +#ifdef PRETOKENIZERDEBUG + fprintf(stderr, "FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str()); +#endif + it++; + } + + // special token + buffer.emplace_after(it, special_id); + it++; + + // right + if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) { + const int64_t right_reminder_offset = match + special_token.length(); + const int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length()); + buffer.emplace_after(it, (*raw_text), right_reminder_offset, right_reminder_length); + +#ifdef PRETOKENIZERDEBUG + fprintf(stderr, "FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str()); +#endif + + it++; + + if (source == 0) { + buffer.erase_after(buffer.before_begin()); + } else { + buffer.erase_after(std::next(buffer.begin(), (source-1))); + } + + // repeat for the right side + raw_text_base_offset = right_reminder_offset; + raw_text_base_length = right_reminder_length; + +#ifdef PRETOKENIZERDEBUG + fprintf(stderr, "RR: (%ld %ld) '%s'\n", raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str()); +#endif + } else { + if (source == 0) { + 
buffer.erase_after(buffer.before_begin()); + } else { + buffer.erase_after(std::next(buffer.begin(), (source-1))); + } + break; + } + } + } + it++; + } + } +} + +static std::vector llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool special) { std::vector output; // OG tokenizer behavior: @@ -6492,20 +6741,58 @@ static std::vector llama_tokenize_internal(const llama_vocab & return output; } + std::forward_list fragment_buffer; + fragment_buffer.emplace_front( raw_text, 0, raw_text.length() ); + + if (special) tokenizer_st_partition( vocab, fragment_buffer ); + switch (vocab.type) { case LLAMA_VOCAB_TYPE_SPM: { - // without adding this leading whitespace, we do not get the same results as the original tokenizer - raw_text = " " + raw_text; + for (const auto & fragment: fragment_buffer) + { + if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) + { + // without adding this leading whitespace, we do not get the same results as the original tokenizer - llm_tokenizer_spm tokenizer(vocab); - llama_escape_whitespace(raw_text); - tokenizer.tokenize(raw_text, output); + // TODO: It's likely possible to get rid of this string copy entirely + // by modifying llm_tokenizer_x to operate with string offsets like pre-tokenizer + // and passing 'add space prefix' as bool argument + // + auto raw_text = (special ? "" : " ") + fragment.raw_text.substr(fragment.offset, fragment.length); + +#ifdef PRETOKENIZERDEBUG + fprintf(stderr,"TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); +#endif + llm_tokenizer_spm tokenizer(vocab); + llama_escape_whitespace(raw_text); + tokenizer.tokenize(raw_text, output); + } + else // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) + { + output.push_back(fragment.token); + } + } } break; case LLAMA_VOCAB_TYPE_BPE: { - llm_tokenizer_bpe tokenizer(vocab); - tokenizer.tokenize(raw_text, output); + for (const auto & fragment: fragment_buffer) + { + if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) + { + auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); + +#ifdef PRETOKENIZERDEBUG + fprintf(stderr,"TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); +#endif + llm_tokenizer_bpe tokenizer(vocab); + tokenizer.tokenize(raw_text, output); + } + else // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) + { + output.push_back(fragment.token); + } + } } break; } @@ -8848,6 +9135,9 @@ void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llam } void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + if (seq_id_src == seq_id_dst) { + return; + } llama_kv_cache_seq_cp(ctx->kv_self, seq_id_src, seq_id_dst, p0, p1); } @@ -9300,7 +9590,7 @@ int llama_eval_embd( int n_past) { llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1); - llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, n_past, 1, 0, }; + llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, }; const int ret = llama_decode_internal(*ctx, batch); if (ret < 0) { @@ -9321,20 +9611,21 @@ struct llama_batch llama_batch_get_one( llama_pos pos_0, llama_seq_id seq_id) { return { - /*n_tokens =*/ n_tokens, - /*tokens =*/ tokens, - /*embd =*/ nullptr, - /*pos =*/ nullptr, - /*seq_id =*/ nullptr, - /*logits =*/ nullptr, - /*all_pos_0 =*/ pos_0, - /*all_pos_1 =*/ 1, - /*all_seq_id =*/ seq_id, + 
/*n_tokens =*/ n_tokens, + /*tokens =*/ tokens, + /*embd =*/ nullptr, + /*pos =*/ nullptr, + /*n_seq_id =*/ nullptr, + /*seq_id =*/ nullptr, + /*logits =*/ nullptr, + /*all_pos_0 =*/ pos_0, + /*all_pos_1 =*/ 1, + /*all_seq_id =*/ seq_id, }; } -struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd) { - llama_batch batch = { -1, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, }; +struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd, int32_t n_seq_max) { + llama_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, }; if (embd) { batch.embd = (float *) malloc(sizeof(float) * n_tokens * embd); @@ -9342,19 +9633,29 @@ struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd) { batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens); } - batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens); - batch.seq_id = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_tokens); - batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); + batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens); + batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens); + batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens); + for (int i = 0; i < n_tokens; ++i) { + batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max); + } + batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); return batch; } void llama_batch_free(struct llama_batch batch) { - if (batch.token) free(batch.token); - if (batch.embd) free(batch.embd); - if (batch.pos) free(batch.pos); - if (batch.seq_id) free(batch.seq_id); - if (batch.logits) free(batch.logits); + if (batch.token) free(batch.token); + if (batch.embd) free(batch.embd); + if (batch.pos) free(batch.pos); + if (batch.n_seq_id) free(batch.n_seq_id); + if (batch.seq_id) { + for (int i = 0; i < batch.n_tokens; ++i) { + free(batch.seq_id[i]); + } + free(batch.seq_id); + } + if (batch.logits) free(batch.logits); } int llama_decode( @@ -9419,15 +9720,15 @@ llama_token llama_token_eot(const struct llama_context * ctx) { return ctx->model.vocab.special_eot_id; } - int llama_tokenize( const struct llama_model * model, const char * text, int text_len, llama_token * tokens, int n_max_tokens, - bool add_bos) { - auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos); + bool add_bos, + bool special) { + auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos, special); if (n_max_tokens < (int) res.size()) { // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); diff --git a/cpp/llama.h b/cpp/llama.h index b84ddab7..5aae8f53 100644 --- a/cpp/llama.h +++ b/cpp/llama.h @@ -133,11 +133,12 @@ extern "C" { typedef struct llama_batch { int32_t n_tokens; - llama_token * token; - float * embd; - llama_pos * pos; - llama_seq_id * seq_id; - int8_t * logits; + llama_token * token; + float * embd; + llama_pos * pos; + int32_t * n_seq_id; + llama_seq_id ** seq_id; + int8_t * logits; // NOTE: helpers for smooth API transition - can be deprecated in the future // for future-proof code, use the above fields instead and ignore everything below @@ -446,7 +447,8 @@ extern "C" { llama_pos pos_0, llama_seq_id seq_id); - // Allocates a batch of tokens on the heap + // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens + // Each token can be assigned up to n_seq_max sequence ids // The batch has to be freed with llama_batch_free() // If embd != 0, llama_batch.embd will be allocated with 
size of n_tokens * embd * sizeof(float) // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token @@ -454,7 +456,8 @@ extern "C" { // All members are left uninitialized LLAMA_API struct llama_batch llama_batch_init( int32_t n_tokens, - int32_t embd); + int32_t embd, + int32_t n_seq_max); // Frees a batch of tokens allocated with llama_batch_init() LLAMA_API void llama_batch_free(struct llama_batch batch); @@ -511,17 +514,20 @@ extern "C" { // Tokenization // - // Convert the provided text into tokens. - // The tokens pointer must be large enough to hold the resulting tokens. - // Returns the number of tokens on success, no more than n_max_tokens - // Returns a negative number on failure - the number of tokens that would have been returned + /// @details Convert the provided text into tokens. + /// @param tokens The tokens pointer must be large enough to hold the resulting tokens. + /// @return Returns the number of tokens on success, no more than n_max_tokens + /// @return Returns a negative number on failure - the number of tokens that would have been returned + /// @param special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. + /// Does not insert a leading space. LLAMA_API int llama_tokenize( const struct llama_model * model, const char * text, int text_len, llama_token * tokens, int n_max_tokens, - bool add_bos); + bool add_bos, + bool special); // Token Id -> Piece. // Uses the vocabulary in the provided context. diff --git a/cpp/log.h b/cpp/log.h index 507110a2..b99e9e7a 100644 --- a/cpp/log.h +++ b/cpp/log.h @@ -592,38 +592,75 @@ inline std::string log_var_to_string_impl(const std::vector & var) return buf.str(); } -#define LOG_TOKENS_TOSTR_PRETTY(ctx, tokens) \ - [&tokens, &ctx]() \ - { \ - std::stringstream buf; \ - buf << "[ "; \ - \ - bool first = true; \ - for (const auto &token : tokens) \ - { \ - if (!first) \ - buf << ", "; \ - else \ - first = false; \ - \ - auto detokenized = llama_token_to_piece(ctx, token); \ - \ - detokenized.erase( \ - std::remove_if( \ - detokenized.begin(), \ - detokenized.end(), \ - [](const unsigned char c) { return !std::isprint(c); }), \ - detokenized.end()); \ - \ - buf \ - << "'" << detokenized << "'" \ - << ":" << std::to_string(token); \ - } \ - buf << " ]"; \ - \ - return buf.str(); \ - }() \ - .c_str() +template +inline std::string LOG_TOKENS_TOSTR_PRETTY(const C & ctx, const T & tokens) +{ + std::stringstream buf; + buf << "[ "; + + bool first = true; + for (const auto &token : tokens) + { + if (!first) { + buf << ", "; + } else { + first = false; + } + + auto detokenized = llama_token_to_piece(ctx, token); + + detokenized.erase( + std::remove_if( + detokenized.begin(), + detokenized.end(), + [](const unsigned char c) { return !std::isprint(c); }), + detokenized.end()); + + buf + << "'" << detokenized << "'" + << ":" << std::to_string(token); + } + buf << " ]"; + + return buf.str(); +} + +template +inline std::string LOG_BATCH_TOSTR_PRETTY(const C & ctx, const B & batch) +{ + std::stringstream buf; + buf << "[ "; + + bool first = true; + for (int i = 0; i < batch.n_tokens; ++i) + { + if (!first) { + buf << ", "; + } else { + first = false; + } + + auto detokenized = llama_token_to_piece(ctx, batch.token[i]); + + detokenized.erase( + std::remove_if( + detokenized.begin(), + detokenized.end(), + [](const unsigned char c) { return !std::isprint(c); }), + detokenized.end()); + + buf + << "\n" << std::to_string(i) + << ":token '" << detokenized << "'" + << ":pos " 
<< std::to_string(batch.pos[i]) + << ":n_seq_id " << std::to_string(batch.n_seq_id[i]) + << ":seq_id " << std::to_string(batch.seq_id[i][0]) + << ":logits " << std::to_string(batch.logits[i]); + } + buf << " ]"; + + return buf.str(); +} #ifdef LOG_DISABLE_LOGS diff --git a/cpp/rn-llama.hpp b/cpp/rn-llama.hpp index 5f7ef230..d38a94f8 100644 --- a/cpp/rn-llama.hpp +++ b/cpp/rn-llama.hpp @@ -5,7 +5,6 @@ #include #include "common.h" #include "llama.h" -#include "grammar-parser.h" namespace rnllama { @@ -139,15 +138,11 @@ struct llama_rn_context size_t n_remain = 0; std::vector embd; - std::vector last_n_tokens; llama_model *model = nullptr; llama_context *ctx = nullptr; gpt_params params; - llama_sampling_context ctx_sampling; - - grammar_parser::parse_state parsed_grammar; - llama_grammar *grammar = nullptr; + llama_sampling_context *ctx_sampling; bool truncated = false; bool stopped_eos = false; @@ -168,6 +163,10 @@ struct llama_rn_context llama_free_model(model); model = nullptr; } + if (ctx_sampling != nullptr) + { + llama_sampling_free(ctx_sampling); + } } void rewind() @@ -189,11 +188,10 @@ struct llama_rn_context n_remain = 0; n_past = 0; - if (grammar != nullptr) { - llama_grammar_free(grammar); - grammar = nullptr; - ctx_sampling = llama_sampling_context_init(params, NULL); + if (ctx_sampling != nullptr) { + llama_sampling_free(ctx_sampling); } + ctx_sampling = llama_sampling_init(params); } bool loadModel(gpt_params ¶ms_) @@ -205,35 +203,12 @@ struct llama_rn_context LOG_ERROR("unable to load model: %s", params_.model.c_str()); return false; } - - last_n_tokens.resize(params.n_ctx); - std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); return true; } bool loadGrammar() { - if (!params.grammar.empty()) { - parsed_grammar = grammar_parser::parse(params.grammar.c_str()); - // will be empty (default) if there are parse errors - if (parsed_grammar.rules.empty()) { - LOG_ERROR("grammar parse error, grammar: %s", params.grammar.c_str()); - return false; - } - grammar_parser::print_grammar(stderr, parsed_grammar); - - { - auto it = params.sampling_params.logit_bias.find(llama_token_eos(ctx)); - if (it != params.sampling_params.logit_bias.end() && it->second == -INFINITY) { - LOG_WARNING("EOS token is disabled, which will cause most grammars to fail"); - } - } - - std::vector grammar_rules(parsed_grammar.c_rules()); - grammar = llama_grammar_init( - grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); - } - ctx_sampling = llama_sampling_context_init(params, grammar); + ctx_sampling = llama_sampling_init(params); return true; } @@ -256,7 +231,7 @@ struct llama_rn_context std::vector new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep); const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left; new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end()); - std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin()); + std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), ctx_sampling->prev.begin()); LOG_VERBOSE("input truncated, n_ctx: %d, n_keep: %d, n_left: %d, new_tokens: %s", params.n_ctx, params.n_keep, n_left, tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend()).c_str() @@ -267,8 +242,8 @@ struct llama_rn_context else { const size_t ps = num_prompt_tokens; - std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0); - std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - 
ps); + std::fill(ctx_sampling->prev.begin(), ctx_sampling->prev.end() - ps, 0); + std::copy(prompt_tokens.begin(), prompt_tokens.end(), ctx_sampling->prev.end() - ps); } // compare the evaluated prompt with the new prompt @@ -367,23 +342,22 @@ struct llama_rn_context std::vector candidates; candidates.reserve(llama_n_vocab(model)); - result.tok = llama_sampling_sample(ctx, NULL, ctx_sampling, last_n_tokens, candidates); + result.tok = llama_sampling_sample(ctx_sampling, ctx, NULL); - llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + llama_token_data_array cur_p = { ctx_sampling->cur.data(), ctx_sampling->cur.size(), false }; const int32_t n_probs = params.sampling_params.n_probs; if (params.sampling_params.temp <= 0 && n_probs > 0) { // For llama_sample_token_greedy we need to sort candidates - llama_sample_softmax(ctx, &candidates_p); + llama_sample_softmax(ctx, &cur_p); } - for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i) + for (size_t i = 0; i < std::min(cur_p.size, (size_t)n_probs); ++i) { - result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p}); + result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p}); } - last_n_tokens.erase(last_n_tokens.begin()); - last_n_tokens.push_back(result.tok); + llama_sampling_accept(ctx_sampling, ctx, result.tok); if (tg) { num_tokens_predicted++; } diff --git a/cpp/sampling.cpp b/cpp/sampling.cpp index 8ce41945..0b246658 100644 --- a/cpp/sampling.cpp +++ b/cpp/sampling.cpp @@ -1,64 +1,81 @@ #include "sampling.h" -llama_sampling_context::~llama_sampling_context() { - for (auto & it : sequence_contexts) { - if (it.second.grammar != NULL) { - llama_grammar_free(it.second.grammar); - it.second.grammar = NULL; +struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params) { + struct llama_sampling_context * result = new llama_sampling_context(); + + result->params = params.sampling_params; + result->grammar = nullptr; + + // if there is a grammar, parse it + if (!params.grammar.empty()) { + result->parsed_grammar = grammar_parser::parse(params.grammar.c_str()); + + // will be empty (default) if there are parse errors + if (result->parsed_grammar.rules.empty()) { + fprintf(stderr, "%s: failed to parse grammar\n", __func__); + return nullptr; } + + std::vector grammar_rules(result->parsed_grammar.c_rules()); + + result->grammar = llama_grammar_init( + grammar_rules.data(), + grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root")); } + + result->prev.resize(params.n_ctx); + + return result; } -llama_sampling_context llama_sampling_context_init( - const struct gpt_params & params, - llama_grammar * grammar) { - llama_sampling_context result; +void llama_sampling_free(struct llama_sampling_context * ctx) { + if (ctx->grammar != NULL) { + llama_grammar_free(ctx->grammar); + } - result.params = params.sampling_params; - result.grammar = grammar; - return result; + delete ctx; } -// Note: Creates the context if it doesn't exist, so this always return something. 
-llama_sampler_sequence_context & llama_sampling_get_sequence_context( - llama_sampling_context & ctx_sampling, - const llama_seq_id seq) { - const auto it = ctx_sampling.sequence_contexts.find(seq); - if (it != ctx_sampling.sequence_contexts.end()) { - return it->second; +void llama_sampling_reset(llama_sampling_context * ctx) { + if (ctx->grammar != NULL) { + llama_grammar_free(ctx->grammar); } - llama_sampler_sequence_context new_ctx = { - 2.0f * ctx_sampling.params.mirostat_tau, - ctx_sampling.grammar != NULL ? llama_grammar_copy(ctx_sampling.grammar) : NULL, - }; - return ctx_sampling.sequence_contexts.insert({seq, new_ctx}).first->second; + + if (!ctx->parsed_grammar.rules.empty()) { + std::vector grammar_rules(ctx->parsed_grammar.c_rules()); + + ctx->grammar = llama_grammar_init( + grammar_rules.data(), + grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root")); + } + + std::fill(ctx->prev.begin(), ctx->prev.end(), 0); + ctx->cur.clear(); } -bool llama_sampling_context_reset( - llama_sampling_context & ctx_sampling, - const llama_seq_id seq) { - const auto it = ctx_sampling.sequence_contexts.find(seq); - if (it == ctx_sampling.sequence_contexts.end()) return false; - if (it->second.grammar != NULL) { - llama_grammar_free(it->second.grammar); - it->second.grammar = NULL; +void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) { + if (dst->grammar) { + llama_grammar_free(dst->grammar); + dst->grammar = nullptr; } - ctx_sampling.sequence_contexts.erase(it); - return true; + + if (src->grammar) { + dst->grammar = llama_grammar_copy(src->grammar); + } + + dst->prev = src->prev; } llama_token llama_sampling_sample( - struct llama_context * ctx, - struct llama_context * ctx_guidance, - struct llama_sampling_context & ctx_sampling, - const std::vector & last_tokens, - std::vector & candidates, - const int idx, - llama_seq_id seq) { - const int n_ctx = llama_n_ctx(ctx); - const int n_vocab = llama_n_vocab(llama_get_model(ctx)); - - const llama_sampling_params & params = ctx_sampling.params; + struct llama_sampling_context * ctx_sampling, + struct llama_context * ctx_main, + struct llama_context * ctx_cfg, + const int idx) { + const int n_ctx = llama_n_ctx(ctx_main); + const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); + + const llama_sampling_params & params = ctx_sampling->params; + const float temp = params.temp; const int32_t top_k = params.top_k <= 0 ? 
n_vocab : params.top_k; const float top_p = params.top_p; @@ -73,41 +90,45 @@ llama_token llama_sampling_sample( const float mirostat_eta = params.mirostat_eta; const bool penalize_nl = params.penalize_nl; + auto & prev = ctx_sampling->prev; + auto & cur = ctx_sampling->cur; + llama_token id = 0; - float * logits = llama_get_logits_ith(ctx, idx); + float * logits = llama_get_logits_ith(ctx_main, idx); // Apply params.logit_bias map for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { logits[it->first] += it->second; } - candidates.clear(); + cur.clear(); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); + cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); } - llama_token_data_array cur_p = { candidates.data(), candidates.size(), false }; + llama_token_data_array cur_p = { cur.data(), cur.size(), false }; - if (ctx_guidance) { - llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale); + if (ctx_cfg) { + llama_sample_classifier_free_guidance(ctx_main, &cur_p, ctx_cfg, params.cfg_scale); } // apply penalties - if (!last_tokens.empty()) { - const float nl_logit = logits[llama_token_nl(ctx)]; - const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx); + if (!prev.empty()) { + const float nl_logit = logits[llama_token_nl(ctx_main)]; + const int last_n_repeat = std::min(std::min((int)prev.size(), repeat_last_n), n_ctx); - llama_sample_repetition_penalty(ctx, &cur_p, - last_tokens.data() + last_tokens.size() - last_n_repeat, + llama_sample_repetition_penalty(ctx_main, &cur_p, + prev.data() + prev.size() - last_n_repeat, last_n_repeat, repeat_penalty); - llama_sample_frequency_and_presence_penalties(ctx, &cur_p, - last_tokens.data() + last_tokens.size() - last_n_repeat, + llama_sample_frequency_and_presence_penalties(ctx_main, &cur_p, + prev.data() + prev.size() - last_n_repeat, last_n_repeat, alpha_frequency, alpha_presence); if (!penalize_nl) { for (size_t idx = 0; idx < cur_p.size; idx++) { - if (cur_p.data[idx].id == llama_token_nl(ctx)) { + if (cur_p.data[idx].id == llama_token_nl(ctx_main)) { cur_p.data[idx].logit = nl_logit; break; } @@ -115,52 +136,58 @@ llama_token llama_sampling_sample( } } - llama_sampler_sequence_context & ctx_seq = llama_sampling_get_sequence_context(ctx_sampling, seq); - - if (ctx_seq.grammar != NULL) { - llama_sample_grammar(ctx, &cur_p, ctx_seq.grammar); + if (ctx_sampling->grammar != NULL) { + llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar); } if (temp <= 0) { // Greedy sampling - id = llama_sample_token_greedy(ctx, &cur_p); + id = llama_sample_token_greedy(ctx_main, &cur_p); } else { if (mirostat == 1) { const int mirostat_m = 100; - llama_sample_temp(ctx, &cur_p, temp); - id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_seq.mirostat_mu); + llama_sample_temp(ctx_main, &cur_p, temp); + id = llama_sample_token_mirostat(ctx_main, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu); } else if (mirostat == 2) { - llama_sample_temp(ctx, &cur_p, temp); - id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &ctx_seq.mirostat_mu); + llama_sample_temp(ctx_main, &cur_p, temp); + id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu); } else { // Temperature sampling size_t min_keep = std::max(1, params.n_probs); - 
llama_sample_top_k (ctx, &cur_p, top_k, min_keep); - llama_sample_tail_free (ctx, &cur_p, tfs_z, min_keep); - llama_sample_typical (ctx, &cur_p, typical_p, min_keep); - llama_sample_top_p (ctx, &cur_p, top_p, min_keep); - llama_sample_temp(ctx, &cur_p, temp); - - { - const int n_top = 10; - LOG("top %d candidates:\n", n_top); - - for (int i = 0; i < n_top; i++) { - const llama_token id = cur_p.data[i].id; - (void)id; // To avoid a warning that id is unused when logging is disabled. - LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p); - } - } - - id = llama_sample_token(ctx, &cur_p); - - LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str()); + llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); + llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); + llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); + llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); + llama_sample_temp (ctx_main, &cur_p, temp); + + id = llama_sample_token(ctx_main, &cur_p); + + //{ + // const int n_top = 10; + // LOG("top %d candidates:\n", n_top); + + // for (int i = 0; i < n_top; i++) { + // const llama_token id = cur_p.data[i].id; + // (void)id; // To avoid a warning that id is unused when logging is disabled. + // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx_main, id).c_str(), cur_p.data[i].p); + // } + //} + + LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx_main, id).c_str()); } } - if (ctx_seq.grammar != NULL) { - llama_grammar_accept_token(ctx, ctx_seq.grammar, id); - } - return id; } + +void llama_sampling_accept( + struct llama_sampling_context * ctx_sampling, + struct llama_context * ctx_main, + llama_token id) { + ctx_sampling->prev.erase(ctx_sampling->prev.begin()); + ctx_sampling->prev.push_back(id); + + if (ctx_sampling->grammar != NULL) { + llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id); + } +} diff --git a/cpp/sampling.h b/cpp/sampling.h index 0aab5d03..50afcbc1 100644 --- a/cpp/sampling.h +++ b/cpp/sampling.h @@ -2,6 +2,8 @@ #include "llama.h" +#include "grammar-parser.h" + #include #include #include @@ -34,75 +36,64 @@ typedef struct llama_sampling_params { } llama_sampling_params; -// per-sequence sampler context -typedef struct llama_sampler_sequence_context { - float mirostat_mu; // mirostat sampler state - llama_grammar * grammar; -} llama_sampler_sequence_context; - // general sampler context -typedef struct llama_sampling_context { - ~llama_sampling_context(); - - // parameters that will be used for sampling and when creating - // new llama_sampler_sequence_context instances +// TODO: move to llama.h +struct llama_sampling_context { + // parameters that will be used for sampling llama_sampling_params params; - // map of sequence ids to sampler contexts - std::unordered_map sequence_contexts; + // mirostat sampler state + float mirostat_mu; - // when non-NULL, new instances of llama_sampler_sequence_context - // will get a copy of the grammar here - // note: only the pointer is stored here, it is not a copy of - // the grammar and shouldn't be freed llama_grammar * grammar; -} llama_sampling_context; + + // internal + grammar_parser::parse_state parsed_grammar; + + // TODO: replace with ring-buffer + std::vector prev; + std::vector cur; +}; #include "common.h" // Create a new sampling context instance. 
-llama_sampling_context llama_sampling_context_init( - const struct gpt_params & params, - llama_grammar * grammar = NULL); - -// Fetches the sampler context for the specified sequence id (defaults to 0). -// If the context for that sequence id doesn't already exist, it will be created with -// default values based on the parameters in the ctx_sampling argument. -llama_sampler_sequence_context & llama_sampling_get_sequence_context( - llama_sampling_context & ctx_sampling, - const llama_seq_id seq = 0); - -// Reset the sampler context for the supplied sequence id (defaults to 0). -// This is necessary to reuse a sequence id or free memory used by sequences -// that are no longer required. -bool llama_sampling_context_reset( - llama_sampling_context & ctx_sampling, - const llama_seq_id seq = 0); +struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params); + +void llama_sampling_free(struct llama_sampling_context * ctx); + +// Reset the sampler context +// - clear prev tokens +// - reset grammar +void llama_sampling_reset(llama_sampling_context * ctx); + +// Copy the sampler context +void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst); // this is a common sampling function used across the examples for convenience // it can serve as a starting point for implementing your own sampling function // Note: When using multiple sequences, it is the caller's responsibility to call -// llama_sampling_context_reset when a sequence ends +// llama_sampling_reset when a sequence ends // // required: -// - ctx: context to use for sampling +// - ctx_main: context to use for sampling // - ctx_sampling: sampling-specific context // // optional: -// - ctx_guidance: context to use for classifier-free guidance, ignore if NULL -// - last_tokens: needed for repetition penalty, ignore if empty -// - idx: sample from llama_get_logits_ith(ctx, idx) -// - seq: sequence id to associate sampler state with +// - ctx_cfg: context to use for classifier-free guidance +// - idx: sample from llama_get_logits_ith(ctx, idx) // // returns: // - token: sampled token // - candidates: vector of candidate tokens // llama_token llama_sampling_sample( - struct llama_context * ctx, - struct llama_context * ctx_guidance, - struct llama_sampling_context & ctx_sampling, - const std::vector & last_tokens, - std::vector & candidates, - const int idx = 0, - llama_seq_id seq = 0); + struct llama_sampling_context * ctx_sampling, + struct llama_context * ctx_main, + struct llama_context * ctx_cfg, + int idx = 0); + +void llama_sampling_accept( + struct llama_sampling_context * ctx_sampling, + struct llama_context * ctx_main, + llama_token id); diff --git a/example/ios/Podfile.lock b/example/ios/Podfile.lock index bcfbed7b..c83a36e4 100644 --- a/example/ios/Podfile.lock +++ b/example/ios/Podfile.lock @@ -8,7 +8,7 @@ PODS: - hermes-engine/Pre-built (= 0.72.3) - hermes-engine/Pre-built (0.72.3) - libevent (2.1.12) - - llama-rn (0.3.0-rc.0): + - llama-rn (0.3.0-rc.1): - RCT-Folly - RCTRequired - RCTTypeSafety @@ -1242,7 +1242,7 @@ SPEC CHECKSUMS: glog: 04b94705f318337d7ead9e6d17c019bd9b1f6b1b hermes-engine: 10fbd3f62405c41ea07e71973ea61e1878d07322 libevent: 4049cae6c81cdb3654a443be001fb9bdceff7913 - llama-rn: 181274aa4c46da201545cdf45ccf0300e9bc0363 + llama-rn: 9d97cb43a97a1315dd6853226165291f0c661ea5 RCT-Folly: 424b8c9a7a0b9ab2886ffe9c3b041ef628fd4fb1 RCTRequired: a2faf4bad4e438ca37b2040cb8f7799baa065c18 RCTTypeSafety: cb09f3e4747b6d18331a15eb05271de7441ca0b3 diff --git 
a/ios/RNLlamaContext.mm b/ios/RNLlamaContext.mm index 2ce63f1c..ee5d1037 100644 --- a/ios/RNLlamaContext.mm +++ b/ios/RNLlamaContext.mm @@ -363,11 +363,7 @@ - (int)saveSession:(NSString *)path { } - (void)invalidate { - if (llama->grammar != nullptr) { - llama_grammar_free(llama->grammar); - } delete llama; - // llama_backend_free(); } diff --git a/llama.cpp b/llama.cpp index 1e0e873c..004797f6 160000 --- a/llama.cpp +++ b/llama.cpp @@ -1 +1 @@ -Subproject commit 1e0e873c373c33989beb6bc64d83cd572ab7fe2b +Subproject commit 004797f6ac135383f8c1d1f5bd415ddee2f79318 diff --git a/scripts/llama.cpp.patch b/scripts/llama.cpp.patch index b31f440f..e5e8b70f 100644 --- a/scripts/llama.cpp.patch +++ b/scripts/llama.cpp.patch @@ -1,6 +1,6 @@ ---- llama.cpp.orig 2023-10-12 09:34:44 -+++ llama.cpp 2023-10-12 09:36:38 -@@ -102,6 +102,17 @@ +--- llama.cpp.orig 2023-10-19 10:27:20 ++++ llama.cpp 2023-10-19 10:27:21 +@@ -103,6 +103,17 @@ #define LLAMA_LOG_WARN(...) llama_log_internal(LM_GGML_LOG_LEVEL_WARN , __VA_ARGS__) #define LLAMA_LOG_ERROR(...) llama_log_internal(LM_GGML_LOG_LEVEL_ERROR, __VA_ARGS__) @@ -18,7 +18,7 @@ // // helpers // -@@ -736,16 +747,16 @@ +@@ -737,16 +748,16 @@ if (prefetch > 0) { // Advise the kernel to preload the mapped memory
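For reference, the reworked sampling API introduced in cpp/sampling.h and cpp/sampling.cpp above replaces the per-sequence sampler map with a single heap-allocated llama_sampling_context that owns the parsed grammar, the mirostat state, and the prev/cur token buffers. A minimal per-token loop under that API might look like the following sketch; it is illustrative only, the function and variable names (generate_n_tokens, ctx_main, n_predict) are assumptions, and decoding of the sampled token is elided:

    #include "common.h"
    #include "sampling.h"

    // Sketch: one sampling context per generation stream.
    static void generate_n_tokens(llama_context * ctx_main, gpt_params & params, int n_predict) {
        llama_sampling_context * ctx_sampling = llama_sampling_init(params);
        if (ctx_sampling == nullptr) {
            return; // llama_sampling_init returns nullptr if the grammar failed to parse
        }

        for (int i = 0; i < n_predict; ++i) {
            // sample from the logits at idx = 0, with no classifier-free guidance context
            const llama_token id = llama_sampling_sample(ctx_sampling, ctx_main, nullptr);

            // record the token for repetition penalties and advance the grammar state
            llama_sampling_accept(ctx_sampling, ctx_main, id);

            // (decode id into the llama_context here before sampling the next token)
        }

        llama_sampling_free(ctx_sampling);
    }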
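When a sequence ends and the sampler state is reused, or when a parallel branch needs its own copy, the new reset/copy helpers from the same header apply. A short sketch under the same assumptions as above (ctx_sampling and params as in the previous snippet):

    // Clear the prev token window and rebuild the grammar from the already-parsed rules.
    llama_sampling_reset(ctx_sampling);

    // Fork the state for another branch: the grammar is deep-copied via llama_grammar_copy
    // and the prev token window is copied as well.
    llama_sampling_context * ctx_branch = llama_sampling_init(params);
    llama_sampling_cp(ctx_sampling, ctx_branch);
    llama_sampling_free(ctx_branch);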