diff --git a/README.md b/README.md index 67addd9b..4e7ddb82 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ C++ implementation of [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B), [ChatGL Highlights: * Pure C++ implementation based on [ggml](https://github.com/ggerganov/ggml), working in the same way as [llama.cpp](https://github.com/ggerganov/llama.cpp). * Accelerated memory-efficient CPU inference with int4/int8 quantization, optimized KV cache and parallel computing. +* Support for P-Tuning v2 and LoRA finetuned models. * Streaming generation with typewriter effect. * Python binding, web demo, api servers and more possibilities. @@ -68,7 +69,9 @@ You are free to try any of the below quantization types by specifying `-t * `f16`: half precision floating point weights without quantization. * `f32`: single precision floating point weights without quantization. -For LoRA model, add `-l ` flag to merge your LoRA weights into the base model. +For LoRA models, add the `-l ` flag to merge your LoRA weights into the base model. For example, run `python3 chatglm_cpp/convert.py -i THUDM/chatglm3-6b -t q4_0 -o chatglm3-ggml-lora.bin -l shibing624/chatglm3-6b-csc-chinese-lora` to merge public LoRA weights from Hugging Face. + +For P-Tuning v2 models using the [official finetuning script](https://github.com/THUDM/ChatGLM3/tree/main/finetune_demo), the additional weights are automatically detected by `convert.py`. If `past_key_values` appears in the output weight list, the P-Tuning checkpoint was converted successfully. **Build & Run** diff --git a/chatglm.cpp b/chatglm.cpp index d0e70df5..bd3af4b5 100644 --- a/chatglm.cpp +++ b/chatglm.cpp @@ -171,7 +171,7 @@ void ggml_graph_compute_helper(std::vector &buf, ggml_cgraph // for debugging purpose [[maybe_unused]] static inline ggml_tensor *add_zero(ggml_context *ctx, ggml_tensor *tensor) { - ggml_tensor *zeros = ggml_new_tensor(ctx, tensor->type, tensor->n_dims, tensor->ne); + ggml_tensor *zeros = ggml_new_tensor(ctx, GGML_TYPE_F32, tensor->n_dims, tensor->ne); ggml_set_f32(zeros, 0); tensor_to_device(zeros); ggml_tensor *out = tensor_assign_buffers(ggml_add(ctx, tensor, zeros)); @@ -452,15 +452,12 @@ ggml_tensor *RMSNorm::forward(ModelContext *ctx, ggml_tensor *input) const { static ggml_tensor *apply_activation_inplace(ggml_context *ctx, ggml_tensor *hidden_states, ActivationType hidden_act) { switch (hidden_act) { case ActivationType::GELU: - hidden_states = tensor_assign_buffers(ggml_gelu_inplace(ctx, hidden_states)); - break; + return tensor_assign_buffers(ggml_gelu_inplace(ctx, hidden_states)); case ActivationType::SILU: - hidden_states = tensor_assign_buffers(ggml_silu_inplace(ctx, hidden_states)); - break; + return tensor_assign_buffers(ggml_silu_inplace(ctx, hidden_states)); default: CHATGLM_THROW << "Unknown activation type " << (int)hidden_act; } - return hidden_states; } ggml_tensor *BasicMLP::forward(ModelContext *ctx, ggml_tensor *hidden_states) const { @@ -515,9 +512,9 @@ std::string to_string(ModelType model_type) { } static ggml_tensor *apply_rotary_emb_basic(ModelContext *ctx, ggml_tensor *layer, ggml_tensor *position_ids, int n_ctx, - RopeType rope_type, int dim_scale) { - // tensor a (activation) is of shape [qlen, heads, head_size] - // tensor b (position_ids) is of shape [qlen] + RopeType rope_type, float rope_theta, int dim_scale) { + // tensor a (activation) is of shape [s, #h, d] + // tensor b (position_ids) is of shape [s] ggml_context *gctx = ctx->ctx_b.get(); #ifdef GGML_USE_CUBLAS if (!ggml_is_contiguous(layer)) { @@ -526,14 +523,14
@@ static ggml_tensor *apply_rotary_emb_basic(ModelContext *ctx, ggml_tensor *layer #endif const int head_size = layer->ne[0]; const int rope_dim = head_size / dim_scale; - layer = tensor_assign_buffers( - ggml_rope_inplace(gctx, layer, position_ids, rope_dim, (int)rope_type, n_ctx)); // [qlen, heads, head_size] + layer = tensor_assign_buffers(ggml_rope_custom_inplace(gctx, layer, position_ids, rope_dim, (int)rope_type, n_ctx, + rope_theta, 1.f)); // [s, #h, d] return layer; } static ggml_tensor *apply_rotary_emb_glm(ModelContext *ctx, ggml_tensor *layer, ggml_tensor *position_ids, int n_ctx) { - // tensor a (activation) is of shape [qlen, heads, head_size] - // tensor b (position_ids) is of shape [2 * qlen] + // tensor a (activation) is of shape [s, #h, d] + // tensor b (position_ids) is of shape [2 * s] ggml_context *gctx = ctx->ctx_b.get(); const int head_size = layer->ne[0]; @@ -556,9 +553,9 @@ static ggml_tensor *apply_rotary_emb_glm(ModelContext *ctx, ggml_tensor *layer, #endif a1_rope = tensor_assign_buffers( - ggml_rope_inplace(gctx, a1_rope, b1, rope_dim, (int)RopeType::NEOX, n_ctx)); // [qlen, heads, head_size/2] + ggml_rope_inplace(gctx, a1_rope, b1, rope_dim, (int)RopeType::NEOX, n_ctx)); // [s, #h, d/2] a2_rope = tensor_assign_buffers( - ggml_rope_inplace(gctx, a2_rope, b2, rope_dim, (int)RopeType::NEOX, n_ctx)); // [qlen, heads, head_size/2] + ggml_rope_inplace(gctx, a2_rope, b2, rope_dim, (int)RopeType::NEOX, n_ctx)); // [s, #h, d/2] #ifdef GGML_USE_CUBLAS a1_rope = ggml_cpy(gctx, a1_rope, a1); @@ -570,22 +567,48 @@ static ggml_tensor *apply_rotary_emb_glm(ModelContext *ctx, ggml_tensor *layer, return layer; } +[[maybe_unused]] static ggml_tensor *apply_rotary_emb_glm2(ModelContext *ctx, ggml_tensor *layer, + ggml_tensor *position_ids) { + // layer: [s, #h, d], position_ids: [s] + ggml_context *gctx = ctx->ctx_b.get(); +#ifdef GGML_USE_CUBLAS + if (!ggml_is_contiguous(layer)) { + layer = tensor_assign_buffers(ggml_cont(gctx, layer)); + } +#endif + const int head_size = layer->ne[0]; + const int rope_dim = head_size / 2; + ggml_tensor *roped_layer = + tensor_assign_buffers(ggml_rope(gctx, layer, position_ids, rope_dim, (int)RopeType::GPTJ, 0)); // [s, #h, d] + + ggml_tensor *roped_layer_view = tensor_assign_buffers( + ggml_view_3d(gctx, roped_layer, rope_dim, roped_layer->ne[1], roped_layer->ne[2], roped_layer->nb[1], + roped_layer->nb[2], rope_dim * roped_layer->nb[0])); // [s, #h, d/2] + + ggml_tensor *layer_view = + tensor_assign_buffers(ggml_view_3d(gctx, layer, rope_dim, layer->ne[1], layer->ne[2], layer->nb[1], + layer->nb[2], rope_dim * layer->nb[0])); // [s, #h, d/2] + + ggml_build_forward_expand(&ctx->gf, ggml_cpy(gctx, layer_view, roped_layer_view)); + + return roped_layer; +} + static ggml_tensor *apply_rotary_emb(ModelContext *ctx, ggml_tensor *layer, ggml_tensor *position_ids, int n_ctx, - RopeType rope_type, int dim_scale) { + RopeType rope_type, float rope_theta, int dim_scale) { switch (rope_type) { case RopeType::GPTJ: case RopeType::NEOX: - layer = apply_rotary_emb_basic(ctx, layer, position_ids, n_ctx, rope_type, dim_scale); - break; + return apply_rotary_emb_basic(ctx, layer, position_ids, n_ctx, rope_type, rope_theta, dim_scale); case RopeType::CHATGLM: - layer = apply_rotary_emb_glm(ctx, layer, position_ids, n_ctx); - break; + return apply_rotary_emb_glm(ctx, layer, position_ids, n_ctx); + // case RopeType::CHATGLM2: + // return apply_rotary_emb_glm2(ctx, layer, position_ids); case RopeType::DISABLED: - break; + return layer; default: CHATGLM_THROW << 
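The switch from `ggml_rope_inplace` to `ggml_rope_custom_inplace` threads a configurable base frequency (`rope_theta`) through the rotary embedding instead of relying on the hard-coded 10000. Below is a minimal NumPy sketch of the GPT-J-style (interleaved-pair) rotation this corresponds to, assuming the standard RoPE formulation; the function and variable names are illustrative and not part of the codebase:

```python
import numpy as np

def rope_gptj(x: np.ndarray, position_ids: np.ndarray, rope_dim: int, theta: float = 10000.0) -> np.ndarray:
    """Rotate the first `rope_dim` channels of x (shape [seq, heads, head_size]) in adjacent pairs."""
    out = x.copy()
    half = rope_dim // 2
    # Per-pair inverse frequencies; `theta` is the tunable base (rope_theta in this patch).
    inv_freq = theta ** (-2.0 * np.arange(half) / rope_dim)                # [rope_dim/2]
    angles = position_ids[:, None].astype(np.float64) * inv_freq[None, :]  # [seq, rope_dim/2]
    cos = np.cos(angles)[:, None, :]  # broadcast over heads
    sin = np.sin(angles)[:, None, :]
    x_even, x_odd = x[..., 0:rope_dim:2], x[..., 1:rope_dim:2]
    out[..., 0:rope_dim:2] = x_even * cos - x_odd * sin
    out[..., 1:rope_dim:2] = x_even * sin + x_odd * cos
    return out
```

ChatGLM2's `rope_ratio` then simply scales this base, which is what the converter writes into the v2 config as `10000.0 * rope_ratio`.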
"Unknown rope type " << (int)rope_type; } - return layer; } static inline ggml_tensor *apply_attention_mask_causal(ModelContext *ctx, ggml_tensor *attn_scores, int n_past) { @@ -593,16 +616,18 @@ static inline ggml_tensor *apply_attention_mask_causal(ModelContext *ctx, ggml_t } static ggml_tensor *apply_attention_mask_glm(ModelContext *ctx, ggml_tensor *attn_scores, int n_past) { - // attn_scores is of shape [heads, qlen, klen] + // attn_scores: [#h, s, kvs] + // semantic: attn_scores[:, :-1, -1] = -inf ggml_context *gctx = ctx->ctx_b.get(); + const int kvlen = attn_scores->ne[0]; const int qlen = attn_scores->ne[1]; const int num_attention_heads = attn_scores->ne[2]; ggml_tensor *inf = ggml_new_tensor_3d(gctx, attn_scores->type, 1, qlen - 1, num_attention_heads); ggml_set_f32(inf, -INFINITY); tensor_to_device(inf); // TODO: optimize - ggml_tensor *masked_attn_scores = tensor_assign_buffers( - ggml_view_3d(gctx, attn_scores, 1, qlen - 1, num_attention_heads, qlen * ggml_element_size(attn_scores), - qlen * qlen * ggml_element_size(attn_scores), (qlen - 1) * ggml_element_size(attn_scores))); + ggml_tensor *masked_attn_scores = + tensor_assign_buffers(ggml_view_3d(gctx, attn_scores, 1, qlen - 1, num_attention_heads, attn_scores->nb[1], + attn_scores->nb[2], (kvlen - 1) * attn_scores->nb[0])); ggml_build_forward_expand(&ctx->gf, ggml_cpy(gctx, inf, masked_attn_scores)); return attn_scores; } @@ -629,12 +654,12 @@ ggml_tensor *BasicAttention::forward(ModelContext *ctx, ggml_tensor *hidden_stat const int num_shared_q_heads = num_attention_heads / num_kv_heads; const bool is_gqa = num_shared_q_heads > 1; - ggml_tensor *qkv = query_key_value.forward(ctx, hidden_states); // [qlen, hidden + 2 * kv_hidden] + ggml_tensor *qkv = query_key_value.forward(ctx, hidden_states); // [sq, (#h + 2 * #kvh) * d] // split mixed qkv into separate query, key and value - ggml_tensor *query_layer; // [qlen, heads, head_size] - ggml_tensor *key_layer; // [qlen, kv_heads, head_size] - ggml_tensor *value_layer; // [qlen, kv_heads, head_size] + ggml_tensor *query_layer; // [s, #h, d] + ggml_tensor *key_layer; // [s, #kvh, d] + ggml_tensor *value_layer; // [s, #kvh, d] if (interleaved_qkv) { CHATGLM_CHECK(!is_gqa) << "interleaved qkv is not supported for GQA"; @@ -655,42 +680,39 @@ ggml_tensor *BasicAttention::forward(ModelContext *ctx, ggml_tensor *hidden_stat qkv->nb[1], (hidden_size + head_size * num_kv_heads) * ggml_element_size(qkv)); } - query_layer = apply_rotary_emb(ctx, query_layer, position_ids, n_ctx, rope_type, rope_dim_scale); - key_layer = apply_rotary_emb(ctx, key_layer, position_ids, n_ctx, rope_type, rope_dim_scale); + query_layer = apply_rotary_emb(ctx, query_layer, position_ids, n_ctx, rope_type, rope_theta, rope_dim_scale); + key_layer = apply_rotary_emb(ctx, key_layer, position_ids, n_ctx, rope_type, rope_theta, rope_dim_scale); - query_layer = - tensor_assign_buffers(ggml_cont(gctx, ggml_permute(gctx, query_layer, 0, 2, 1, 3))); // [heads, qlen, head_size] + query_layer = tensor_assign_buffers(ggml_cont(gctx, ggml_permute(gctx, query_layer, 0, 2, 1, 3))); // [#h, s, d] if (num_shared_q_heads > 1) { - query_layer = - tensor_assign_buffers(ggml_reshape_3d(gctx, query_layer, head_size, num_shared_q_heads * qlen, - num_kv_heads)); // [kv_heads, shared_qheads * qlen, head_size] + query_layer = tensor_assign_buffers(ggml_reshape_3d(gctx, query_layer, head_size, num_shared_q_heads * qlen, + num_kv_heads)); // [#kvh, (#h/#kvh) * s, d] } - key_layer = tensor_assign_buffers(ggml_permute(gctx, key_layer, 0, 2, 
1, 3)); // [kv_heads, qlen, head_size] - - value_layer = tensor_assign_buffers(ggml_permute(gctx, value_layer, 1, 2, 0, 3)); // [kv_heads, head_size, qlen] + key_layer = tensor_assign_buffers(ggml_permute(gctx, key_layer, 0, 2, 1, 3)); // [#kvh, s, d] + value_layer = tensor_assign_buffers(ggml_permute(gctx, value_layer, 1, 2, 0, 3)); // [#kvh, d, s] // store key & value to cache ggml_tensor *k_cache_view = tensor_assign_buffers( ggml_view_3d(gctx, k_cache, head_size, qlen, num_kv_heads, k_cache->nb[1], k_cache->nb[2], - n_past * head_size * ggml_element_size(k_cache))); // [kv_heads, qlen, head_size] + (num_virtual_tokens + n_past) * head_size * ggml_element_size(k_cache))); // [#kvh, s, d] ggml_build_forward_expand(&ctx->gf, ggml_cpy(gctx, key_layer, k_cache_view)); ggml_tensor *v_cache_view = tensor_assign_buffers(ggml_view_3d(gctx, v_cache, qlen, head_size, num_kv_heads, v_cache->nb[1], v_cache->nb[2], - n_past * ggml_element_size(v_cache))); // [kv_heads, head_size, qlen] + (num_virtual_tokens + n_past) * ggml_element_size(v_cache))); // [#kvh, d, s] ggml_build_forward_expand(&ctx->gf, ggml_cpy(gctx, value_layer, v_cache_view)); // concat key & value with past kv - key_layer = tensor_assign_buffers(ggml_view_3d(gctx, k_cache, head_size, n_past + qlen, num_kv_heads, - k_cache->nb[1], k_cache->nb[2], - 0)); // [kv_heads, klen, head_size] - value_layer = tensor_assign_buffers(ggml_view_3d(gctx, v_cache, n_past + qlen, head_size, num_kv_heads, - v_cache->nb[1], v_cache->nb[2], - 0)); // [kv_heads, head_size, klen] + key_layer = tensor_assign_buffers(ggml_view_3d(gctx, k_cache, head_size, num_virtual_tokens + n_past + qlen, + num_kv_heads, k_cache->nb[1], k_cache->nb[2], + 0)); // [#kvh, kvs, d] + value_layer = tensor_assign_buffers(ggml_view_3d(gctx, v_cache, num_virtual_tokens + n_past + qlen, head_size, + num_kv_heads, v_cache->nb[1], v_cache->nb[2], + 0)); // [#kvh, d, kvs] // attention ggml_tensor *attn_scores = - tensor_assign_buffers(ggml_mul_mat(gctx, key_layer, query_layer)); // [kv_heads, shared_qheads * qlen, klen] + tensor_assign_buffers(ggml_mul_mat(gctx, key_layer, query_layer)); // [#kvh, (#h/#kvh) * s, kvs] attn_scores = tensor_assign_buffers(ggml_scale_inplace(gctx, attn_scores, ggml_new_f32(gctx, 1.f / std::sqrt(head_size)))); if (use_alibi) { @@ -699,27 +721,27 @@ ggml_tensor *BasicAttention::forward(ModelContext *ctx, ggml_tensor *hidden_stat if (n_past == 0) { // build attention mask for context input if (num_shared_q_heads > 1) { - attn_scores = ggml_reshape_3d(gctx, attn_scores, n_past + qlen, qlen, - num_attention_heads); // [heads, qlen, klen] + attn_scores = ggml_reshape_3d(gctx, attn_scores, num_virtual_tokens + n_past + qlen, qlen, + num_attention_heads); // [#h, s, kvs] } - attn_scores = apply_attention_mask(ctx, attn_scores, n_past, attn_mask_type); + attn_scores = apply_attention_mask(ctx, attn_scores, num_virtual_tokens + n_past, attn_mask_type); if (num_shared_q_heads > 1) { - attn_scores = ggml_reshape_3d(gctx, attn_scores, n_past + qlen, num_shared_q_heads * qlen, - num_kv_heads); // [kv_heads, shared_qheads * qlen, klen] + attn_scores = + ggml_reshape_3d(gctx, attn_scores, num_virtual_tokens + n_past + qlen, num_shared_q_heads * qlen, + num_kv_heads); // [#kvh, (#h/#kvh) * s, kvs] } } ggml_tensor *attn_probs = - tensor_assign_buffers(ggml_soft_max_inplace(gctx, attn_scores)); // [kv_heads, shared_qheads * qlen, klen] + tensor_assign_buffers(ggml_soft_max_inplace(gctx, attn_scores)); // [#kvh, (#h/#kvh) * s, kvs] - ggml_tensor *context_layer = 
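The rewritten `apply_attention_mask_glm` view is now expressed via the tensor's own strides and the key/value length, but the semantic is unchanged and is spelled out in the new comment: every query position except the last is forbidden from attending to the last key position. A tiny PyTorch sketch of that semantic (illustrative only, not project code):

```python
import torch

def glm_attention_mask_(attn_scores: torch.Tensor) -> torch.Tensor:
    """attn_scores: [num_heads, qlen, kvlen] attention logits for the prompt pass.

    GLM's prefix mask: only the final query token may attend to the final key position,
    i.e. attn_scores[:, :-1, -1] = -inf, matching the comment in apply_attention_mask_glm.
    """
    attn_scores[:, :-1, -1] = float("-inf")
    return attn_scores
```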
tensor_assign_buffers( - ggml_mul_mat(gctx, value_layer, attn_probs)); // [kv_heads, shared_qheads * qlen, head_size] + ggml_tensor *context_layer = + tensor_assign_buffers(ggml_mul_mat(gctx, value_layer, attn_probs)); // [#kvh, (#h/#kvh) * s, d] if (num_shared_q_heads > 1) { context_layer = ggml_reshape_3d(gctx, context_layer, head_size, qlen, - num_attention_heads); // [heads, qlen, head_size] + num_attention_heads); // [#h, s, d] } - context_layer = tensor_assign_buffers( - ggml_cont(gctx, ggml_permute(gctx, context_layer, 0, 2, 1, 3))); // [qlen, heads, head_size] - context_layer = tensor_assign_buffers(ggml_reshape_2d(gctx, context_layer, hidden_size, qlen)); // [qlen, hidden] + context_layer = tensor_assign_buffers(ggml_cont(gctx, ggml_permute(gctx, context_layer, 0, 2, 1, 3))); // [s, #h, d] + context_layer = tensor_assign_buffers(ggml_reshape_2d(gctx, context_layer, hidden_size, qlen)); // [s, #h * d] ggml_tensor *attn_output = dense.forward(ctx, context_layer); return attn_output; @@ -730,8 +752,8 @@ BaseModelForCausalLM::BaseModelForCausalLM(ModelConfig config, size_t mem_size, ctx_.dtype = config.dtype; const size_t ctx_w_size = num_weights * ggml_tensor_overhead(); const size_t ctx_kv_size = 2 * config.num_hidden_layers * - (config.max_length * config.hidden_size / config.num_attention_heads * - config.num_kv_heads * ggml_type_size(GGML_TYPE_F16) + + ((config.max_length + config.num_virtual_tokens) * config.hidden_size / + config.num_attention_heads * config.num_kv_heads * ggml_type_size(GGML_TYPE_F16) + ggml_tensor_overhead()); ctx_.ctx_w = make_unique_ggml_context(ctx_w_size, nullptr, true); ctx_.ctx_kv = make_unique_ggml_context(ctx_kv_size + 1 * MB, nullptr, false); // 1MB extra for MPS @@ -1223,6 +1245,21 @@ ChatGLM2ForCausalLM::ChatGLM2ForCausalLM(const ModelConfig &config) } void ChatGLM2ForCausalLM::load(ModelLoader &loader) { + if (config.num_virtual_tokens > 0) { + const int head_size = config.hidden_size / config.num_attention_heads; + auto prefix_cache_ctx = make_unique_ggml_context( + ggml_tensor_overhead() + config.num_hidden_layers * 2 * config.num_kv_heads * config.num_virtual_tokens * + head_size * ggml_type_size(GGML_TYPE_F16), + nullptr, false); + ggml_tensor *past_key_values = + ggml_new_tensor_4d(prefix_cache_ctx.get(), GGML_TYPE_F16, head_size, config.num_virtual_tokens, + config.num_kv_heads, config.num_hidden_layers * 2); + CHATGLM_CHECK(ggml_used_mem(prefix_cache_ctx.get()) == ggml_get_mem_size(prefix_cache_ctx.get())) + << "corrupted prefix cache"; + loader.read_tensor("past_key_values", past_key_values); + load_prefix_cache(past_key_values); + } + std::unordered_map glu_name_map; for (int i = 0; i < config.num_hidden_layers; i++) { std::string layer_prefix = "transformer.encoder.layers." 
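With P-Tuning v2, the first `num_virtual_tokens` slots of each layer's KV cache are reserved for the prefix, so freshly computed keys and values are written at offset `num_virtual_tokens + n_past`, and attention reads back `num_virtual_tokens + n_past + qlen` positions. A rough PyTorch sketch of that bookkeeping, with hypothetical names (the real code works on strided ggml views):

```python
import torch

def write_kv_cache(k_cache: torch.Tensor, key_layer: torch.Tensor, n_past: int, num_virtual_tokens: int) -> torch.Tensor:
    """k_cache: [#kvh, max_length + num_virtual_tokens, head_size]; key_layer: [#kvh, qlen, head_size]."""
    qlen = key_layer.shape[1]
    start = num_virtual_tokens + n_past            # prefix slots come first, then everything generated so far
    k_cache[:, start:start + qlen, :] = key_layer
    return k_cache[:, :start + qlen, :]            # keys visible to attention: prefix + past + current tokens
```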
+ std::to_string(i) + '.'; @@ -1705,7 +1742,7 @@ Pipeline::Pipeline(const std::string &path, int max_length) { if (max_length > 0) { CHATGLM_CHECK(max_length <= config.max_length) << "Requested max_length (" << max_length << ") exceeds the max possible model sequence length (" - << config.max_length; + << config.max_length << ")"; config.max_length = max_length; } }; @@ -1722,11 +1759,17 @@ Pipeline::Pipeline(const std::string &path, int max_length) { // load version int version = loader.read_basic(); if (model_type == ModelType::CHATGLM) { - CHATGLM_CHECK(version == 1) << "only support version 1 for now but got " << version; - // load config - ModelConfig config(model_type, loader.read_basic(), 1e-5f, ActivationType::GELU, true, true, - true, false, RopeType::CHATGLM, -1, AttentionMaskType::CHATGLM); + ModelConfig config; + if (version == 1) { + config = ModelConfig(model_type, loader.read_basic(), 1e-5f, ActivationType::GELU, true, + true, true, false, RopeType::CHATGLM, 10000.f, -1, AttentionMaskType::CHATGLM, 0); + } else if (version == 2) { + config = ModelConfig(model_type, loader.read_basic(), ActivationType::GELU, true, true, + true, false, RopeType::CHATGLM, -1, AttentionMaskType::CHATGLM); + } else { + CHATGLM_THROW << "only support version 1 or 2 for now but got " << version; + } _update_config_max_length(config, max_length); // load tokenizer @@ -1739,11 +1782,17 @@ Pipeline::Pipeline(const std::string &path, int max_length) { model = std::make_unique(config); model->load(loader); } else if (model_type == ModelType::CHATGLM2 || model_type == ModelType::CHATGLM3) { - CHATGLM_CHECK(version == 1) << "only support version 1 for now but got " << version; - // load config - ModelConfig config(model_type, loader.read_basic(), 1e-5f, ActivationType::SILU, true, false, - false, false, RopeType::GPTJ, 2, AttentionMaskType::CAUSAL); + ModelConfig config; + if (version == 1) { + config = ModelConfig(model_type, loader.read_basic(), 1e-5f, ActivationType::SILU, true, + false, false, false, RopeType::GPTJ, 10000.f, 2, AttentionMaskType::CAUSAL, 0); + } else if (version == 2) { + config = ModelConfig(model_type, loader.read_basic(), ActivationType::SILU, true, false, + false, false, RopeType::GPTJ, 2, AttentionMaskType::CAUSAL); + } else { + CHATGLM_THROW << "only support version 1 or 2 for now but got " << version; + } _update_config_max_length(config, max_length); // load tokenizer @@ -1770,7 +1819,7 @@ Pipeline::Pipeline(const std::string &path, int max_length) { // load config ModelConfig config(model_type, loader.read_basic(), 1e-6f, ActivationType::SILU, false, false, - false, false, RopeType::NEOX, 1, AttentionMaskType::CAUSAL); + false, false, RopeType::NEOX, 10000.f, 1, AttentionMaskType::CAUSAL, 0); _update_config_max_length(config, max_length); // load tokenizer @@ -1789,7 +1838,7 @@ Pipeline::Pipeline(const std::string &path, int max_length) { // load config ModelConfig config(model_type, loader.read_basic(), 1e-6f, ActivationType::SILU, false, false, - false, true, RopeType::DISABLED, -1, AttentionMaskType::CAUSAL); + false, true, RopeType::DISABLED, 10000.f, -1, AttentionMaskType::CAUSAL, 0); _update_config_max_length(config, max_length); // load tokenizer @@ -1811,10 +1860,10 @@ Pipeline::Pipeline(const std::string &path, int max_length) { ModelConfig config; if (rec.hidden_size == 4096) { config = ModelConfig(model_type, rec, 1e-6f, ActivationType::SILU, true, true, false, false, RopeType::NEOX, - 1, AttentionMaskType::CAUSAL); + 10000.f, 1, AttentionMaskType::CAUSAL, 0); } 
else { config = ModelConfig(model_type, rec, 1e-6f, ActivationType::SILU, false, false, false, false, - RopeType::NEOX, 1, AttentionMaskType::CAUSAL); + RopeType::NEOX, 10000.f, 1, AttentionMaskType::CAUSAL, 0); } _update_config_max_length(config, max_length); diff --git a/chatglm.h b/chatglm.h index 91a6edcd..386ea7e7 100644 --- a/chatglm.h +++ b/chatglm.h @@ -76,10 +76,27 @@ struct ConfigRecordV1 { }; // For compatibility -struct ConfigRecordV2 : public ConfigRecordV1 { +struct ConfigRecordV1GQA : public ConfigRecordV1 { int num_kv_heads; }; +// TODO: use json to serialize config +struct ConfigRecordV2 { + ggml_type dtype; + int vocab_size; + int hidden_size; + int num_attention_heads; + int num_key_value_heads; + int num_hidden_layers; + int intermediate_size; + float norm_eps; + int num_virtual_tokens; + float rope_theta; + int max_length; + int eos_token_id; + int pad_token_id; +}; + enum class ActivationType { GELU, SILU, @@ -89,6 +106,7 @@ enum class RopeType { GPTJ = 0, NEOX = 2, CHATGLM = 4, + CHATGLM2 = 8, DISABLED = 10000, }; @@ -105,33 +123,44 @@ class ModelConfig { ModelConfig(ModelType model_type, ggml_type dtype, int vocab_size, int hidden_size, int num_attention_heads, int num_kv_heads, int num_hidden_layers, int intermediate_size, float norm_eps, ActivationType hidden_act, bool use_qkv_bias, bool use_dense_bias, bool interleaved_qkv, bool use_alibi, - RopeType rope_type, int rope_dim_scale, AttentionMaskType attn_mask_type, int max_length, - int bos_token_id, int eos_token_id, int pad_token_id, int sep_token_id, - std::vector extra_eos_token_ids) + RopeType rope_type, float rope_theta, int rope_dim_scale, AttentionMaskType attn_mask_type, + int num_virtual_tokens, int max_length, int bos_token_id, int eos_token_id, int pad_token_id, + int sep_token_id, std::vector extra_eos_token_ids) : model_type(model_type), dtype(dtype), vocab_size(vocab_size), hidden_size(hidden_size), num_attention_heads(num_attention_heads), num_kv_heads(num_kv_heads), num_hidden_layers(num_hidden_layers), intermediate_size(intermediate_size), norm_eps(norm_eps), hidden_act(hidden_act), use_qkv_bias(use_qkv_bias), use_dense_bias(use_dense_bias), interleaved_qkv(interleaved_qkv), use_alibi(use_alibi), rope_type(rope_type), - rope_dim_scale(rope_dim_scale), attn_mask_type(attn_mask_type), max_length(max_length), - bos_token_id(bos_token_id), eos_token_id(eos_token_id), pad_token_id(pad_token_id), - sep_token_id(sep_token_id), extra_eos_token_ids(std::move(extra_eos_token_ids)) {} + rope_theta(rope_theta), rope_dim_scale(rope_dim_scale), attn_mask_type(attn_mask_type), + num_virtual_tokens(num_virtual_tokens), max_length(max_length), bos_token_id(bos_token_id), + eos_token_id(eos_token_id), pad_token_id(pad_token_id), sep_token_id(sep_token_id), + extra_eos_token_ids(std::move(extra_eos_token_ids)) {} ModelConfig(ModelType model_type, const ConfigRecordV1 &rec, float norm_eps, ActivationType hidden_act, bool use_qkv_bias, bool use_dense_bias, bool interleaved_qkv, bool use_alibi, RopeType rope_type, - int rope_dim_scale, AttentionMaskType attn_mask_type) + float rope_theta, int rope_dim_scale, AttentionMaskType attn_mask_type, int num_virtual_tokens) : ModelConfig(model_type, rec.dtype, rec.vocab_size, rec.hidden_size, rec.num_attention_heads, rec.num_attention_heads, rec.num_hidden_layers, rec.intermediate_size, norm_eps, hidden_act, - use_qkv_bias, use_dense_bias, interleaved_qkv, use_alibi, rope_type, rope_dim_scale, - attn_mask_type, rec.max_length, rec.bos_token_id, rec.eos_token_id, 
rec.pad_token_id, - rec.sep_token_id, {}) {} + use_qkv_bias, use_dense_bias, interleaved_qkv, use_alibi, rope_type, rope_theta, rope_dim_scale, + attn_mask_type, num_virtual_tokens, rec.max_length, rec.bos_token_id, rec.eos_token_id, + rec.pad_token_id, rec.sep_token_id, {}) {} - ModelConfig(ModelType model_type, const ConfigRecordV2 &rec, float norm_eps, ActivationType hidden_act, + ModelConfig(ModelType model_type, const ConfigRecordV1GQA &rec, float norm_eps, ActivationType hidden_act, bool use_qkv_bias, bool use_dense_bias, bool interleaved_qkv, bool use_alibi, RopeType rope_type, - int rope_dim_scale, AttentionMaskType attn_mask_type) + float rope_theta, int rope_dim_scale, AttentionMaskType attn_mask_type, int num_virtual_tokens) : ModelConfig(model_type, rec.dtype, rec.vocab_size, rec.hidden_size, rec.num_attention_heads, rec.num_kv_heads, rec.num_hidden_layers, rec.intermediate_size, norm_eps, hidden_act, use_qkv_bias, use_dense_bias, - interleaved_qkv, use_alibi, rope_type, rope_dim_scale, attn_mask_type, rec.max_length, - rec.bos_token_id, rec.eos_token_id, rec.pad_token_id, rec.sep_token_id, {}) {} + interleaved_qkv, use_alibi, rope_type, rope_theta, rope_dim_scale, attn_mask_type, + num_virtual_tokens, rec.max_length, rec.bos_token_id, rec.eos_token_id, rec.pad_token_id, + rec.sep_token_id, {}) {} + + ModelConfig(ModelType model_type, const ConfigRecordV2 &rec, ActivationType hidden_act, bool use_qkv_bias, + bool use_dense_bias, bool interleaved_qkv, bool use_alibi, RopeType rope_type, int rope_dim_scale, + AttentionMaskType attn_mask_type) + : ModelConfig(model_type, rec.dtype, rec.vocab_size, rec.hidden_size, rec.num_attention_heads, + rec.num_key_value_heads, rec.num_hidden_layers, rec.intermediate_size, rec.norm_eps, hidden_act, + use_qkv_bias, use_dense_bias, interleaved_qkv, use_alibi, rope_type, rec.rope_theta, + rope_dim_scale, attn_mask_type, rec.num_virtual_tokens, rec.max_length, -1, rec.eos_token_id, + rec.pad_token_id, -1, {}) {} std::string model_type_name() const { return to_string(model_type); } @@ -151,8 +180,10 @@ class ModelConfig { bool interleaved_qkv; bool use_alibi; RopeType rope_type; + float rope_theta; int rope_dim_scale; AttentionMaskType attn_mask_type; + int num_virtual_tokens; int max_length; int bos_token_id; int eos_token_id; @@ -388,16 +419,17 @@ class BasicAttention { BasicAttention() = default; BasicAttention(ModelContext *ctx, int hidden_size, int num_attention_heads, int num_kv_heads, int max_length, bool use_qkv_bias, bool use_dense_bias, bool interleaved_qkv, bool use_alibi, RopeType rope_type, - int rope_dim_scale, AttentionMaskType attn_mask_type) + float rope_theta, int rope_dim_scale, AttentionMaskType attn_mask_type, int num_virtual_tokens) : num_attention_heads(num_attention_heads), num_kv_heads(num_kv_heads), interleaved_qkv(interleaved_qkv), - use_alibi(use_alibi), rope_type(rope_type), rope_dim_scale(rope_dim_scale), attn_mask_type(attn_mask_type), + use_alibi(use_alibi), rope_type(rope_type), rope_theta(rope_theta), rope_dim_scale(rope_dim_scale), + attn_mask_type(attn_mask_type), num_virtual_tokens(num_virtual_tokens), query_key_value(ctx, hidden_size, hidden_size + 2 * (hidden_size / num_attention_heads) * num_kv_heads, use_qkv_bias), dense(ctx, hidden_size, hidden_size, use_dense_bias), - k_cache(ggml_new_tensor_3d(ctx->ctx_kv.get(), GGML_TYPE_F16, hidden_size / num_attention_heads, max_length, - num_kv_heads)), - v_cache(ggml_new_tensor_3d(ctx->ctx_kv.get(), GGML_TYPE_F16, max_length, hidden_size / num_attention_heads, - 
num_kv_heads)) {} + k_cache(ggml_new_tensor_3d(ctx->ctx_kv.get(), GGML_TYPE_F16, hidden_size / num_attention_heads, + max_length + num_virtual_tokens, num_kv_heads)), + v_cache(ggml_new_tensor_3d(ctx->ctx_kv.get(), GGML_TYPE_F16, max_length + num_virtual_tokens, + hidden_size / num_attention_heads, num_kv_heads)) {} ggml_tensor *forward(ModelContext *ctx, ggml_tensor *hidden_states, ggml_tensor *position_ids, int n_past, int n_ctx) const; @@ -408,12 +440,14 @@ class BasicAttention { bool interleaved_qkv; bool use_alibi; RopeType rope_type; + float rope_theta; int rope_dim_scale; AttentionMaskType attn_mask_type; + int num_virtual_tokens; Linear query_key_value; Linear dense; - ggml_tensor *k_cache; // [kv_heads, max_len, head_size] - ggml_tensor *v_cache; // [kv_heads, head_size, max_len] + ggml_tensor *k_cache; // [#kvh, s, d] + ggml_tensor *v_cache; // [#kvh, d, s] }; template @@ -422,11 +456,12 @@ class BasicBlock { BasicBlock() = default; BasicBlock(ModelContext *ctx, int hidden_size, int num_attention_heads, int num_kv_heads, int intermediate_size, int max_length, float norm_eps, ActivationType hidden_act, bool use_qkv_bias, bool use_dense_bias, - bool interleaved_qkv, bool use_alibi, RopeType rope_type, int rope_dim_scale, - AttentionMaskType attn_mask_type) + bool interleaved_qkv, bool use_alibi, RopeType rope_type, float rope_theta, int rope_dim_scale, + AttentionMaskType attn_mask_type, int num_virtual_tokens) : input_layernorm(ctx, hidden_size, false, norm_eps), attention(ctx, hidden_size, num_attention_heads, num_kv_heads, max_length, use_qkv_bias, use_dense_bias, - interleaved_qkv, use_alibi, rope_type, rope_dim_scale, attn_mask_type), + interleaved_qkv, use_alibi, rope_type, rope_theta, rope_dim_scale, attn_mask_type, + num_virtual_tokens), post_attention_layernorm(ctx, hidden_size, false, norm_eps), mlp(ctx, hidden_size, intermediate_size, hidden_act) {} @@ -517,16 +552,44 @@ class BasicModel { return hidden_states; } + void load_prefix_cache(const ModelConfig &config, ggml_tensor *past_key_values) { + ggml_cgraph gf{}; + auto ctx = make_unique_ggml_context(config.num_hidden_layers * 7 * ggml_tensor_overhead(), nullptr, false); + const int head_size = config.hidden_size / config.num_attention_heads; + for (size_t i = 0; i < layers.size(); i++) { + auto &attn = layers[i].attention; + ggml_tensor *virtual_key = ggml_view_3d(ctx.get(), past_key_values, head_size, config.num_virtual_tokens, + config.num_kv_heads, past_key_values->nb[1], past_key_values->nb[2], + i * 2 * past_key_values->nb[3]); // [#h, v, d] + ggml_tensor *k_cache_view = + ggml_view_3d(ctx.get(), attn.k_cache, head_size, config.num_virtual_tokens, config.num_kv_heads, + attn.k_cache->nb[1], attn.k_cache->nb[2], 0); // [#h, v, d] + ggml_build_forward_expand(&gf, ggml_cpy(ctx.get(), virtual_key, k_cache_view)); + + ggml_tensor *virtual_value = ggml_view_3d( + ctx.get(), past_key_values, head_size, config.num_virtual_tokens, config.num_kv_heads, + past_key_values->nb[1], past_key_values->nb[2], (i * 2 + 1) * past_key_values->nb[3]); // [#h, v, d] + virtual_value = ggml_permute(ctx.get(), virtual_value, 1, 0, 2, 3); // [#h, d, v] + ggml_tensor *v_cache_view = + ggml_view_3d(ctx.get(), attn.v_cache, config.num_virtual_tokens, head_size, config.num_kv_heads, + attn.v_cache->nb[1], attn.v_cache->nb[2], 0); // [#h, d, v] + ggml_build_forward_expand(&gf, ggml_cpy(ctx.get(), virtual_value, v_cache_view)); + } + CHATGLM_CHECK(ggml_used_mem(ctx.get()) == ggml_get_mem_size(ctx.get())) << "corrupted prefix cache context"; + 
std::vector compute_buffer; + ggml_graph_compute_helper(compute_buffer, &gf, 0); + } + private: std::vector build_layers(ModelContext *ctx, const ModelConfig &config) { std::vector layers; layers.reserve(config.num_hidden_layers); for (int layer_id = 0; layer_id < config.num_hidden_layers; layer_id++) { - // TODO: reduce max length? 32k might be too large for cpu inference layers.emplace_back(ctx, config.hidden_size, config.num_attention_heads, config.num_kv_heads, config.intermediate_size, config.max_length, config.norm_eps, config.hidden_act, config.use_qkv_bias, config.use_dense_bias, config.interleaved_qkv, config.use_alibi, - config.rope_type, config.rope_dim_scale, config.attn_mask_type); + config.rope_type, config.rope_theta, config.rope_dim_scale, config.attn_mask_type, + config.num_virtual_tokens); } return layers; } @@ -745,6 +808,8 @@ class BasicModelForCausalLM : public BaseModelForCausalLM { return lm_logits; } + void load_prefix_cache(ggml_tensor *past_key_values) { transformer.load_prefix_cache(config, past_key_values); } + protected: void to_cpu() { for (auto &item : state_dict_) { @@ -818,13 +883,14 @@ class GLMBlock : public BasicBlock { GLMBlock() = default; GLMBlock(ModelContext *ctx, int hidden_size, int num_attention_heads, int num_kv_heads, int intermediate_size, int max_length, float norm_eps, ActivationType hidden_act, bool use_qkv_bias, bool use_dense_bias, - bool interleaved_qkv, bool use_alibi, RopeType rope_type, int rope_dim_scale, - AttentionMaskType attn_mask_type) - : BasicBlock( - LayerNorm(ctx, hidden_size, false, norm_eps), - BasicAttention(ctx, hidden_size, num_attention_heads, num_attention_heads, max_length, use_qkv_bias, - use_dense_bias, interleaved_qkv, use_alibi, rope_type, rope_dim_scale, attn_mask_type), - LayerNorm(ctx, hidden_size, false, norm_eps), BasicMLP(ctx, hidden_size, intermediate_size, hidden_act)), + bool interleaved_qkv, bool use_alibi, RopeType rope_type, float rope_theta, int rope_dim_scale, + AttentionMaskType attn_mask_type, int num_virtual_tokens) + : BasicBlock(LayerNorm(ctx, hidden_size, false, norm_eps), + BasicAttention(ctx, hidden_size, num_attention_heads, num_attention_heads, max_length, + use_qkv_bias, use_dense_bias, interleaved_qkv, use_alibi, rope_type, rope_theta, + rope_dim_scale, attn_mask_type, num_virtual_tokens), + LayerNorm(ctx, hidden_size, false, norm_eps), + BasicMLP(ctx, hidden_size, intermediate_size, hidden_act)), alpha_value(std::sqrt(2.f * 28)) {} ggml_tensor *forward(ModelContext *ctx, ggml_tensor *hidden_states, ggml_tensor *position_ids, int n_past, diff --git a/chatglm_cpp/__init__.py b/chatglm_cpp/__init__.py index 92d9b724..babfac5c 100644 --- a/chatglm_cpp/__init__.py +++ b/chatglm_cpp/__init__.py @@ -6,7 +6,7 @@ import chatglm_cpp._C as _C from chatglm_cpp._C import ChatMessage -__version__ = "0.3.1" +__version__ = "0.3.2" @dataclass diff --git a/chatglm_cpp/convert.py b/chatglm_cpp/convert.py index 87b6924d..a98168c1 100644 --- a/chatglm_cpp/convert.py +++ b/chatglm_cpp/convert.py @@ -128,8 +128,6 @@ def quantize_q5_1(tensor: torch.Tensor) -> torch.Tensor: def dump_tensor(f, name: str, tensor: torch.Tensor, ggml_type: GGMLType): - assert tensor.dtype == torch.float32 - # tensor name f.write(struct.pack("i", len(name.encode()))) f.write(name.encode()) @@ -165,7 +163,9 @@ def dump_state_dict(f, weight_names, state_dict, quantization_bit, ggml_type): tensor_info = [] for name in tqdm(weight_names, desc="Processing model states"): tensor = state_dict[name] - if tensor.ndim == 2: + if name == 
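`load_prefix_cache` walks the converted `past_key_values` tensor, which the converter lays out as `[num_layers * 2, num_kv_heads, num_virtual_tokens, head_size]` with keys and values interleaved per layer, and copies each layer's slices into the head of its K/V caches (values transposed into the `[#kvh, d, v]` cache layout). A PyTorch sketch of the per-layer split, assuming that layout; the helper name is illustrative:

```python
import torch

def split_prefix_cache(past_key_values: torch.Tensor, layer_id: int):
    """past_key_values: [num_layers * 2, num_kv_heads, num_virtual_tokens, head_size] (fp16),
    ordered as (layer0 K, layer0 V, layer1 K, layer1 V, ...)."""
    virtual_key = past_key_values[2 * layer_id]                  # [#kvh, v, d] -> start of k_cache
    virtual_value = past_key_values[2 * layer_id + 1]            # [#kvh, v, d]
    virtual_value = virtual_value.transpose(1, 2).contiguous()   # [#kvh, d, v] -> start of v_cache
    return virtual_key, virtual_value
```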
"past_key_values": + tensor_ggml_type = GGMLType.F16 + elif tensor.ndim == 2: # 2d weight: should quantize it if needed # step 1: de-quantize it back to float32 @@ -191,7 +191,7 @@ def dump_state_dict(f, weight_names, state_dict, quantization_bit, ggml_type): tensor_ggml_type = GGMLType.F32 dump_tensor(f, name, tensor, tensor_ggml_type) - tensor_info.append((name, tensor.shape, tensor_ggml_type.name)) + tensor_info.append((name, tuple(tensor.shape), tensor_ggml_type.name)) print(tabulate(tensor_info, headers=["name", "shape", "dtype"], tablefmt="psql")) @@ -200,12 +200,25 @@ class BaseConverter: @classmethod def convert(cls, f, model, tokenizer, ggml_type): f.write(b"ggml") # magic - f.write(struct.pack("ii", cls.MODEL_TYPE.value, 1)) # model type & version + f.write(struct.pack("i", cls.MODEL_TYPE.value)) # model type cls.dump_config(f, model.config, ggml_type) cls.dump_tokenizer(f, tokenizer) cls.dump_model(f, model, ggml_type) +def get_prefix_cache(prefix_encoder, pre_seq_len, num_layers, num_kv_heads, head_size): + prefix_tokens = torch.arange(pre_seq_len, dtype=torch.long) + with torch.no_grad(): + past_key_values = prefix_encoder(prefix_tokens) + past_key_values = ( + past_key_values.to(torch.half) + .view(pre_seq_len, num_layers * 2, num_kv_heads, head_size) + .permute(1, 2, 0, 3) + .contiguous() + ) + return past_key_values + + class ChatGLMConverter(BaseConverter): MODEL_TYPE = ModelType.CHATGLM @@ -215,20 +228,24 @@ def dump_config(f, config, ggml_type): assert ( config.inner_hidden_size == 4 * config.hidden_size ), "unimplemented: inner_hidden_size should be 4 times hidden_size" + + config_version = 2 config_values = [ ggml_type.value, config.vocab_size, config.hidden_size, config.num_attention_heads, + config.num_attention_heads, config.num_layers, config.inner_hidden_size, + config.layernorm_epsilon, + config.pre_seq_len if config.pre_seq_len is not None else 0, + 10000.0, # rope_theta config.max_sequence_length, - config.bos_token_id if config.bos_token_id is not None else -1, config.eos_token_id if config.eos_token_id is not None else -1, config.pad_token_id if config.pad_token_id is not None else -1, - config.sep_token_id if config.sep_token_id is not None else -1, ] - f.write(struct.pack("i" * len(config_values), *config_values)) + f.write(struct.pack("iiiiiiiififiii", config_version, *config_values)) @staticmethod def dump_tokenizer(f, tokenizer): @@ -268,8 +285,8 @@ def dump_model(f, model, ggml_type): class ChatGLM2Converter(BaseConverter): MODEL_TYPE = ModelType.CHATGLM2 - @staticmethod - def dump_config(f, config, ggml_type): + @classmethod + def dump_config(cls, f, config, ggml_type): assert config.add_bias_linear is False, "unimplemented: add_bias_linear must be false" assert config.add_qkv_bias is True, "unimplemented: add_qkv_bias must be true" assert ( @@ -283,22 +300,24 @@ def dump_config(f, config, ggml_type): assert config.post_layer_norm is True, "unimplemented: post_layer_norm must be true" assert config.rmsnorm is True, "unimplemented: rmsnorm must be true" + config_version = 2 config_values = [ ggml_type.value, config.padded_vocab_size, config.hidden_size, config.num_attention_heads, + config.multi_query_group_num, config.num_layers, config.ffn_hidden_size, + config.layernorm_epsilon, + config.pre_seq_len if config.pre_seq_len is not None else 0, + 10000.0 * getattr(config, "rope_ratio", 1), # rope_theta config.seq_length, - config.bos_token_id if config.bos_token_id is not None else -1, config.eos_token_id if config.eos_token_id is not None else -1, 
config.pad_token_id if config.pad_token_id is not None else -1, - config.sep_token_id if config.sep_token_id is not None else -1, - config.multi_query_group_num, ] - f.write(struct.pack("i" * len(config_values), *config_values)) + f.write(struct.pack("iiiiiiiififiii", config_version, *config_values)) @staticmethod def dump_tokenizer(f, tokenizer): @@ -308,8 +327,24 @@ def dump_tokenizer(f, tokenizer): @staticmethod def dump_model(f, model, ggml_type): - weight_names = ["transformer.embedding.word_embeddings.weight"] - for i in range(model.config.num_layers): + config = model.config + + state_dict = model.state_dict() + + weight_names = [] + if config.pre_seq_len is not None and config.pre_seq_len > 0: + past_key_values = get_prefix_cache( + model.transformer.prefix_encoder, + config.pre_seq_len, + config.num_layers, + config.multi_query_group_num, + config.kv_channels, + ) + state_dict["past_key_values"] = past_key_values + weight_names.append("past_key_values") + + weight_names.append("transformer.embedding.word_embeddings.weight") + for i in range(config.num_layers): weight_names += [ f"transformer.encoder.layers.{i}.input_layernorm.weight", f"transformer.encoder.layers.{i}.self_attention.query_key_value.weight", @@ -323,7 +358,7 @@ def dump_model(f, model, ggml_type): "transformer.encoder.final_layernorm.weight", "transformer.output_layer.weight", ] - dump_state_dict(f, weight_names, model.state_dict(), model.config.quantization_bit, ggml_type) + dump_state_dict(f, weight_names, state_dict, config.quantization_bit, ggml_type) class ChatGLM3Converter(ChatGLM2Converter): @@ -331,10 +366,11 @@ class ChatGLM3Converter(ChatGLM2Converter): class BaichuanConverter(BaseConverter): - @staticmethod - def dump_config(f, config, ggml_type): + @classmethod + def dump_config(cls, f, config, ggml_type): assert config.hidden_act == "silu", "unimplemented: hidden_act must be silu" + config_version = 1 config_values = [ ggml_type.value, config.vocab_size, @@ -349,7 +385,7 @@ def dump_config(f, config, ggml_type): config.sep_token_id if config.sep_token_id is not None else -1, ] - f.write(struct.pack("i" * len(config_values), *config_values)) + f.write(struct.pack("i" * (1 + len(config_values)), config_version, *config_values)) @staticmethod def dump_tokenizer(f, tokenizer): @@ -397,6 +433,7 @@ class InternLMConverter(BaseConverter): def dump_config(f, config, ggml_type): assert config.hidden_act == "silu", "unimplemented: hidden_act must be silu" + config_version = 1 config_values = [ ggml_type.value, config.vocab_size, @@ -411,7 +448,7 @@ def dump_config(f, config, ggml_type): config.sep_token_id if config.sep_token_id is not None else -1, ] - f.write(struct.pack("i" * len(config_values), *config_values)) + f.write(struct.pack("i" * (1 + len(config_values)), config_version, *config_values)) @staticmethod def dump_tokenizer(f, tokenizer): @@ -485,6 +522,8 @@ def convert(f: BinaryIO, model_name_or_path: str, lora_model_name_or_path: Optio model = PeftModel.from_pretrained(model, lora_model_name_or_path) model = model.merge_and_unload() + model = model.eval() + if model.config.model_type == "chatglm": if hasattr(model.config, "multi_query_attention"): # ChatGLM3 shares the same architecture and model config with ChatGLM2, but its tokenizer further supports system prompts, diff --git a/chatglm_test.cpp b/chatglm_test.cpp index 3daf5351..fcfe53ac 100644 --- a/chatglm_test.cpp +++ b/chatglm_test.cpp @@ -22,11 +22,22 @@ static inline void expect_all_close(ggml_tensor *a, ggml_tensor *b, float atol = 
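The converters now emit a version-2 config record: the model type, a `config_version` of 2, and then the `ConfigRecordV2` fields in declaration order, packed as `"iiiiiiiififiii"` (with `norm_eps` and `rope_theta` as floats). As a sanity check, a small standalone reader for that header might look like the sketch below; it mirrors the struct in `chatglm.h` but is illustrative and not part of the project:

```python
import struct
from typing import BinaryIO

V2_FIELDS = [
    "dtype", "vocab_size", "hidden_size", "num_attention_heads", "num_key_value_heads",
    "num_hidden_layers", "intermediate_size", "norm_eps", "num_virtual_tokens",
    "rope_theta", "max_length", "eos_token_id", "pad_token_id",
]
V2_FORMAT = "iiiiiiifif" + "iii"  # 7 ints, float, int, float, then 3 ints

def read_config_v2(f: BinaryIO) -> dict:
    assert f.read(4) == b"ggml", "bad magic"
    model_type, version = struct.unpack("ii", f.read(8))
    assert version == 2, f"expected config version 2, got {version}"
    values = struct.unpack(V2_FORMAT, f.read(struct.calcsize(V2_FORMAT)))
    return {"model_type": model_type, **dict(zip(V2_FIELDS, values))}
```

Version-1 records keep the old field order, which is why `Pipeline` still branches on the version it reads.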
ASSERT_EQ(a->type, GGML_TYPE_F32); ASSERT_EQ(ggml_nelements(a), ggml_nelements(b)); int64_t numel = ggml_nelements(a); + float max_abs_diff = 0.f; + int64_t num_mismatch = 0; for (int64_t i = 0; i < numel; i++) { float ai = ((float *)a->data)[i]; float bi = ((float *)b->data)[i]; - EXPECT_LT(std::abs(ai - bi), atol + rtol * std::abs(bi)) << "diff " << ai << " vs " << bi; + float abs_diff = std::abs(ai - bi); + max_abs_diff = std::max(max_abs_diff, abs_diff); + if (abs_diff >= atol + rtol * std::abs(bi)) { + num_mismatch++; + } } + EXPECT_EQ(num_mismatch, 0) << "Tensors are not close!\n\n" + << "Mismatched elements: " << num_mismatch << " / " << numel << " (" + << num_mismatch * 100 / numel << "%)\n" + << "Greatest absolute difference: " << max_abs_diff << " (up to " << std::scientific + << atol << " allowed)\n"; } static inline char *read_tensor_data(char *ptr, ggml_tensor *tensor) { @@ -277,16 +288,13 @@ class ChatGLMTest : public ::testing::Test { float perf_device_graph_compute() { return _perf_graph_compute_impl(); } template - void test_model(const Model &model, const ModelConfig &config, const fs::path &data_path, int seq_len, + void test_model(Model &model, const ModelConfig &config, const fs::path &data_path, int seq_len, const std::vector &all_weights) { ASSERT_EQ(config.num_hidden_layers, 1); MappedFile mapped_file(data_path.string()); char *ptr = mapped_file.data; - tensor_to_device(model.layers[0].attention.k_cache); - tensor_to_device(model.layers[0].attention.v_cache); - ggml_tensor *x1 = ggml_new_tensor_1d(ctx.ctx_b.get(), GGML_TYPE_I32, seq_len); ggml_tensor *ref_y1 = ggml_new_tensor_2d(ctx.ctx_b.get(), GGML_TYPE_F32, config.hidden_size, seq_len); ggml_tensor *x2 = ggml_new_tensor_1d(ctx.ctx_b.get(), GGML_TYPE_I32, 1); @@ -299,6 +307,18 @@ class ChatGLMTest : public ::testing::Test { std::vector cpu_tensors{model.word_embeddings.weight, x1, x2, x3}; + if (config.num_virtual_tokens > 0) { + const int head_size = config.hidden_size / config.num_attention_heads; + ggml_tensor *past_key_values = + ggml_new_tensor_4d(ctx.ctx_b.get(), GGML_TYPE_F16, head_size, config.num_virtual_tokens, + config.num_kv_heads, config.num_hidden_layers * 2); // [l * 2, #h, v, d] + ptr = read_tensor_data(ptr, past_key_values); + model.load_prefix_cache(config, past_key_values); + } + + tensor_to_device(model.layers[0].attention.k_cache); + tensor_to_device(model.layers[0].attention.v_cache); + for (auto tensor : all_tensors) { ptr = read_tensor_data(ptr, tensor); if (std::find(cpu_tensors.begin(), cpu_tensors.end(), tensor) == cpu_tensors.end()) { @@ -310,6 +330,7 @@ class ChatGLMTest : public ::testing::Test { // self attention { + reset_cgraph(); ggml_tensor *out_y1 = model.forward(&ctx, x1, 0, seq_len); EXPECT_EQ(out_y1->backend, ref_y1->backend); out_y1->backend = GGML_BACKEND_CPU; @@ -320,8 +341,8 @@ class ChatGLMTest : public ::testing::Test { } // cross attention - reset_cgraph(); { + reset_cgraph(); ggml_tensor *out_y2 = model.forward(&ctx, x2, seq_len, seq_len); EXPECT_EQ(out_y2->backend, ref_y2->backend); out_y2->backend = GGML_BACKEND_CPU; @@ -330,8 +351,8 @@ class ChatGLMTest : public ::testing::Test { expect_all_close(ref_y2, out_y2, 5e-4); } - reset_cgraph(); { + reset_cgraph(); ggml_tensor *out_y3 = model.forward(&ctx, x3, seq_len + 1, seq_len); EXPECT_EQ(out_y3->backend, ref_y3->backend); out_y3->backend = GGML_BACKEND_CPU; @@ -585,8 +606,45 @@ TEST_F(ChatGLMTest, GLMModel) { ModelType::CHATGLM, GGML_TYPE_F32, /*vocab_size=*/5, /*hidden_size=*/32, /*num_attention_heads=*/8, 
/*num_kv_heads=*/8, /*num_hidden_layers=*/1, /*intermediate_size=*/128, /*norm_eps=*/1e-5f, /*hidden_act=*/ActivationType::GELU, /*use_qkv_bias=*/true, /*use_dense_bias=*/true, - /*interleaved_qkv=*/true, /*use_alibi=*/false, /*rope_type=*/RopeType::CHATGLM, /*rope_dim_scale=*/-1, - /*attn_mask_type=*/AttentionMaskType::CHATGLM, + /*interleaved_qkv=*/true, /*use_alibi=*/false, /*rope_type=*/RopeType::CHATGLM, /*rope_theta=*/10000.f, + /*rope_dim_scale=*/-1, + /*attn_mask_type=*/AttentionMaskType::CHATGLM, /*num_virtual_tokens=*/0, + /*max_length=*/8, /*bos_token_id=*/-1, /*eos_token_id=*/-1, /*pad_token_id=*/-1, /*sep_token_id=*/-1, + /*extra_eos_token_ids=*/{}); + + constexpr int seq_len = 3; + + ChatGLMModel model(&ctx, config); + + std::vector all_weights{model.word_embeddings.weight, + model.layers[0].input_layernorm.weight, + model.layers[0].input_layernorm.bias, + model.layers[0].attention.query_key_value.weight, + model.layers[0].attention.query_key_value.bias, + model.layers[0].attention.dense.weight, + model.layers[0].attention.dense.bias, + model.layers[0].post_attention_layernorm.weight, + model.layers[0].post_attention_layernorm.bias, + model.layers[0].mlp.dense_h_to_4h.weight, + model.layers[0].mlp.dense_h_to_4h.bias, + model.layers[0].mlp.dense_4h_to_h.weight, + model.layers[0].mlp.dense_4h_to_h.bias, + model.final_layernorm.weight, + model.final_layernorm.bias}; + + test_model(model, config, data_path, seq_len, all_weights); +} + +TEST_F(ChatGLMTest, GLMPTuningV2Model) { + fs::path data_path = fs::path(__FILE__).parent_path() / "tests/data/glm_ptuning_v2_model.data"; + + ModelConfig config( + ModelType::CHATGLM, GGML_TYPE_F32, /*vocab_size=*/5, /*hidden_size=*/32, /*num_attention_heads=*/8, + /*num_kv_heads=*/8, /*num_hidden_layers=*/1, /*intermediate_size=*/128, /*norm_eps=*/1e-5f, + /*hidden_act=*/ActivationType::GELU, /*use_qkv_bias=*/true, /*use_dense_bias=*/true, + /*interleaved_qkv=*/true, /*use_alibi=*/false, /*rope_type=*/RopeType::CHATGLM, /*rope_theta=*/10000.f, + /*rope_dim_scale=*/-1, + /*attn_mask_type=*/AttentionMaskType::CHATGLM, /*num_virtual_tokens=*/5, /*max_length=*/8, /*bos_token_id=*/-1, /*eos_token_id=*/-1, /*pad_token_id=*/-1, /*sep_token_id=*/-1, /*extra_eos_token_ids=*/{}); @@ -620,8 +678,9 @@ TEST_F(ChatGLMTest, GLM2Model) { ModelType::CHATGLM2, GGML_TYPE_F32, /*vocab_size=*/5, /*hidden_size=*/32, /*num_attention_heads=*/8, /*num_kv_heads=*/2, /*num_hidden_layers=*/1, /*intermediate_size=*/48, /*norm_eps=*/1e-5f, /*hidden_act=*/ActivationType::SILU, /*use_qkv_bias=*/true, /*use_dense_bias=*/false, - /*interleaved_qkv=*/false, /*use_alibi=*/false, /*rope_type=*/RopeType::GPTJ, /*rope_dim_scale=*/2, - /*attn_mask_type=*/AttentionMaskType::CAUSAL, + /*interleaved_qkv=*/false, /*use_alibi=*/false, /*rope_type=*/RopeType::GPTJ, /*rope_theta=*/10000.f, + /*rope_dim_scale=*/2, + /*attn_mask_type=*/AttentionMaskType::CAUSAL, /*num_virtual_tokens=*/0, /*max_length=*/8, /*bos_token_id=*/-1, /*eos_token_id=*/-1, /*pad_token_id=*/-1, /*sep_token_id=*/-1, /*extra_eos_token_ids=*/{}); @@ -629,9 +688,6 @@ TEST_F(ChatGLMTest, GLM2Model) { ChatGLM2Model model(&ctx, config); - tensor_to_device(model.layers[0].attention.k_cache); - tensor_to_device(model.layers[0].attention.v_cache); - std::vector all_weights{model.word_embeddings.weight, model.layers[0].input_layernorm.weight, model.layers[0].attention.query_key_value.weight, @@ -653,8 +709,9 @@ TEST_F(ChatGLMTest, GLM3Model) { ModelType::CHATGLM3, GGML_TYPE_F32, /*vocab_size=*/5, /*hidden_size=*/32, 
/*num_attention_heads=*/8, /*num_kv_heads=*/2, /*num_hidden_layers=*/1, /*intermediate_size=*/48, /*norm_eps=*/1e-5f, /*hidden_act=*/ActivationType::SILU, /*use_qkv_bias=*/true, /*use_dense_bias=*/false, - /*interleaved_qkv=*/false, /*use_alibi=*/false, /*rope_type=*/RopeType::GPTJ, /*rope_dim_scale=*/2, - /*attn_mask_type=*/AttentionMaskType::CAUSAL, + /*interleaved_qkv=*/false, /*use_alibi=*/false, /*rope_type=*/RopeType::GPTJ, /*rope_theta=*/10000.f, + /*rope_dim_scale=*/2, + /*attn_mask_type=*/AttentionMaskType::CAUSAL, /*num_virtual_tokens=*/0, /*max_length=*/8, /*bos_token_id=*/-1, /*eos_token_id=*/-1, /*pad_token_id=*/-1, /*sep_token_id=*/-1, /*extra_eos_token_ids=*/{}); @@ -662,8 +719,36 @@ TEST_F(ChatGLMTest, GLM3Model) { ChatGLM3Model model(&ctx, config); - tensor_to_device(model.layers[0].attention.k_cache); - tensor_to_device(model.layers[0].attention.v_cache); + std::vector all_weights{model.word_embeddings.weight, + model.layers[0].input_layernorm.weight, + model.layers[0].attention.query_key_value.weight, + model.layers[0].attention.query_key_value.bias, + model.layers[0].attention.dense.weight, + model.layers[0].post_attention_layernorm.weight, + model.layers[0].mlp.gate_proj.weight, + model.layers[0].mlp.up_proj.weight, + model.layers[0].mlp.down_proj.weight, + model.final_layernorm.weight}; + + test_model(model, config, data_path, seq_len, all_weights); +} + +TEST_F(ChatGLMTest, GLM3PTuningV2Model) { + fs::path data_path = fs::path(__FILE__).parent_path() / "tests/data/glm3_ptuning_v2_model.data"; + + ModelConfig config( + ModelType::CHATGLM3, GGML_TYPE_F32, /*vocab_size=*/5, /*hidden_size=*/32, /*num_attention_heads=*/8, + /*num_kv_heads=*/2, /*num_hidden_layers=*/1, /*intermediate_size=*/48, /*norm_eps=*/1e-5f, + /*hidden_act=*/ActivationType::SILU, /*use_qkv_bias=*/true, /*use_dense_bias=*/false, + /*interleaved_qkv=*/false, /*use_alibi=*/false, /*rope_type=*/RopeType::GPTJ, /*rope_theta=*/10000.f, + /*rope_dim_scale=*/2, + /*attn_mask_type=*/AttentionMaskType::CAUSAL, /*num_virtual_tokens=*/5, + /*max_length=*/8, /*bos_token_id=*/-1, /*eos_token_id=*/-1, /*pad_token_id=*/-1, /*sep_token_id=*/-1, + /*extra_eos_token_ids=*/{}); + + constexpr int seq_len = 3; + + ChatGLM3Model model(&ctx, config); std::vector all_weights{model.word_embeddings.weight, model.layers[0].input_layernorm.weight, @@ -686,8 +771,9 @@ TEST_F(ChatGLMTest, Baichuan7BModel) { ModelType::BAICHUAN7B, GGML_TYPE_F32, /*vocab_size=*/5, /*hidden_size=*/32, /*num_attention_heads=*/8, /*num_kv_heads=*/8, /*num_hidden_layers=*/1, /*intermediate_size=*/32 * 3, /*norm_eps=*/1e-6f, /*hidden_act=*/ActivationType::SILU, /*use_qkv_bias=*/false, /*use_dense_bias=*/false, - /*interleaved_qkv=*/false, /*use_alibi=*/false, /*rope_type=*/RopeType::NEOX, /*rope_dim_scale=*/1, - /*attn_mask_type=*/AttentionMaskType::CAUSAL, + /*interleaved_qkv=*/false, /*use_alibi=*/false, /*rope_type=*/RopeType::NEOX, /*rope_theta=*/10000.f, + /*rope_dim_scale=*/1, + /*attn_mask_type=*/AttentionMaskType::CAUSAL, /*num_virtual_tokens=*/0, /*max_length=*/8, /*bos_token_id=*/-1, /*eos_token_id=*/-1, /*pad_token_id=*/-1, /*sep_token_id=*/-1, /*extra_eos_token_ids=*/{}); @@ -715,8 +801,9 @@ TEST_F(ChatGLMTest, Baichuan13BModel) { ModelType::BAICHUAN13B, GGML_TYPE_F32, /*vocab_size=*/5, /*hidden_size=*/32, /*num_attention_heads=*/8, /*num_kv_heads=*/8, /*num_hidden_layers=*/1, /*intermediate_size=*/32 * 3, /*norm_eps=*/1e-6f, /*hidden_act=*/ActivationType::SILU, /*use_qkv_bias=*/false, /*use_dense_bias=*/false, - /*interleaved_qkv=*/false, 
/*use_alibi=*/true, /*rope_type=*/RopeType::DISABLED, /*rope_dim_scale=*/-1, - /*attn_mask_type=*/AttentionMaskType::CAUSAL, + /*interleaved_qkv=*/false, /*use_alibi=*/true, /*rope_type=*/RopeType::DISABLED, /*rope_theta=*/10000.f, + /*rope_dim_scale=*/-1, + /*attn_mask_type=*/AttentionMaskType::CAUSAL, /*num_virtual_tokens=*/0, /*max_length=*/8, /*bos_token_id=*/-1, /*eos_token_id=*/-1, /*pad_token_id=*/-1, /*sep_token_id=*/-1, /*extra_eos_token_ids=*/{}); @@ -744,8 +831,9 @@ TEST_F(ChatGLMTest, InternLMModel) { ModelType::INTERNLM, GGML_TYPE_F32, /*vocab_size=*/5, /*hidden_size=*/32, /*num_attention_heads=*/8, /*num_kv_heads=*/8, /*num_hidden_layers=*/1, /*intermediate_size=*/32 * 3, /*norm_eps=*/1e-6f, /*hidden_act=*/ActivationType::SILU, /*use_qkv_bias=*/true, /*use_dense_bias=*/true, - /*interleaved_qkv=*/false, /*use_alibi=*/false, /*rope_type=*/RopeType::NEOX, /*rope_dim_scale=*/1, - /*attn_mask_type=*/AttentionMaskType::CAUSAL, + /*interleaved_qkv=*/false, /*use_alibi=*/false, /*rope_type=*/RopeType::NEOX, /*rope_theta=*/10000.f, + /*rope_dim_scale=*/1, + /*attn_mask_type=*/AttentionMaskType::CAUSAL, /*num_virtual_tokens=*/0, /*max_length=*/8, /*bos_token_id=*/-1, /*eos_token_id=*/-1, /*pad_token_id=*/-1, /*sep_token_id=*/-1, /*extra_eos_token_ids=*/{}); @@ -1128,7 +1216,8 @@ TEST(Pipeline, ChatGLM3) { { ChatMessage output = pipeline.chat(messages, gen_config); EXPECT_EQ(output.role, ChatMessage::ROLE_ASSISTANT); - EXPECT_EQ(output.content, "根据您的要求,我使用随机数生成器API生成了一个在0和100之间的随机数,结果为22。"); + EXPECT_EQ(output.content, + "根据API调用结果,我为您生成了一个随机数,随机数的范围在0到100之间。这个随机数是22。"); } } @@ -1143,11 +1232,14 @@ TEST(Pipeline, ChatGLM3) { { ChatMessage output = pipeline.chat(messages, gen_config); EXPECT_EQ(output.role, ChatMessage::ROLE_ASSISTANT); - EXPECT_EQ(output.content, "好的,我会为您列出100以内的所有质数。\n\n质数是指只能被1和它本身整除的大于1" - "的整数。例如,2、3、5、7等都是质数。\n\n让我们开始吧!"); + EXPECT_EQ(output.content, R"(好的,我会为您列出100以内的所有质数。 + +质数是指只能被1和它本身整除的正整数。例如,2、3、5、7等都是质数。 + +让我们开始吧!)"); EXPECT_EQ(output.tool_calls.front().code.input, R"(```python +# Function to check if a number is prime def is_prime(n): - """Check if a number is prime.""" if n <= 1: return False if n <= 3: @@ -1162,8 +1254,8 @@ def is_prime(n): return True # Get all prime numbers up to 100 -primes_upto_100 = [i for i in range(2, 101) if is_prime(i)] -primes_upto_100 +primes_up_to_100 = [i for i in range(2, 101) if is_prime(i)] +primes_up_to_100 ```)"); messages.emplace_back(std::move(output)); } diff --git a/tests/data/glm3_ptuning_v2_model.data b/tests/data/glm3_ptuning_v2_model.data new file mode 100644 index 00000000..730a0a49 Binary files /dev/null and b/tests/data/glm3_ptuning_v2_model.data differ diff --git a/tests/data/glm_ptuning_v2_model.data b/tests/data/glm_ptuning_v2_model.data new file mode 100644 index 00000000..c7fc1549 Binary files /dev/null and b/tests/data/glm_ptuning_v2_model.data differ diff --git a/tests/test_convert.py b/tests/test_convert.py index 816b163a..ff51abcb 100644 --- a/tests/test_convert.py +++ b/tests/test_convert.py @@ -3,7 +3,14 @@ import torch import torch.nn.functional as F -from chatglm_cpp.convert import quantize_q4_0, quantize_q4_1, quantize_q5_0, quantize_q5_1, quantize_q8_0 +from chatglm_cpp.convert import ( + get_prefix_cache, + quantize_q4_0, + quantize_q4_1, + quantize_q5_0, + quantize_q5_1, + quantize_q8_0, +) HERE = Path(__file__).resolve().parent @@ -168,10 +175,6 @@ def test_quantize_q5_1(): assert (q_tensor == ggml_q_tensor).all() -CHATGLM_MODEL_PATH = Path( - 
"~/.cache/huggingface/hub/models--THUDM--chatglm-6b/snapshots/619e736c6d4cd139840579c5482063b75bed5666" -).expanduser() - CHATGLM2_MODEL_PATH = Path( "~/.cache/huggingface/hub/models--THUDM--chatglm2-6b/snapshots/b1502f4f75c71499a3d566b14463edd62620ce9f" ).expanduser() @@ -242,6 +245,44 @@ def make_data_rms_norm(): def make_data_glm_model(): + def _forward_steps(model, seq_len): + # self attention + x1 = torch.arange(seq_len, dtype=torch.int64)[None, :] + position_ids = torch.tensor([[[0, 1, 1], [0, 0, 1]]]) + attn_mask = torch.tensor([[0, 0, 1], [0, 0, 1], [0, 0, 0]], dtype=torch.bool)[None, None, :] + with torch.no_grad(): + out = model(x1, position_ids=position_ids, attention_mask=attn_mask, use_cache=True) + y1 = out.last_hidden_state + kv_cache = out.past_key_values + + # cross attention + x2 = torch.tensor([[seq_len]], dtype=torch.int64) + position_ids = torch.tensor([[[1], [2]]]) + attn_mask = None + with torch.no_grad(): + out = model( + x2, position_ids=position_ids, attention_mask=attn_mask, past_key_values=kv_cache, use_cache=True + ) + y2 = out.last_hidden_state + kv_cache = out.past_key_values + + # cross attention + x3 = torch.tensor([[seq_len + 1]], dtype=torch.int64) + position_ids = torch.tensor([[[1], [3]]]) + attn_mask = None + with torch.no_grad(): + out = model( + x3, position_ids=position_ids, attention_mask=attn_mask, past_key_values=kv_cache, use_cache=True + ) + y3 = out.last_hidden_state + kv_cache = out.past_key_values + + return x1, y1, x2, y2, x3, y3 + + CHATGLM_MODEL_PATH = Path( + "~/.cache/huggingface/hub/models--THUDM--chatglm-6b/snapshots/8b7d33596d18c5e83e2da052d05ca4db02e60620" + ).expanduser() + sys.path.append(str(CHATGLM_MODEL_PATH)) from modeling_chatglm import ChatGLMModel from transformers import AutoConfig @@ -260,36 +301,54 @@ def make_data_glm_model(): seq_len = 3 - # self attention - x1 = torch.arange(seq_len, dtype=torch.int64)[None, :] - position_ids = torch.tensor([[[0, 1, 1], [0, 0, 1]]]) - attn_mask = torch.tensor([[0, 0, 1], [0, 0, 1], [0, 0, 0]], dtype=torch.bool)[None, None, :] - with torch.no_grad(): - out = m(x1, position_ids=position_ids, attention_mask=attn_mask, use_cache=True) - y1 = out.last_hidden_state - kv_cache = out.past_key_values + x1, y1, x2, y2, x3, y3 = _forward_steps(m, seq_len) - # cross attention - x2 = torch.tensor([[seq_len]], dtype=torch.int64) - position_ids = torch.tensor([[[1], [2]]]) - attn_mask = None - with torch.no_grad(): - out = m(x2, position_ids=position_ids, attention_mask=attn_mask, past_key_values=kv_cache, use_cache=True) - y2 = out.last_hidden_state - kv_cache = out.past_key_values + print(m) - # cross attention - x3 = torch.tensor([[seq_len + 1]], dtype=torch.int64) - position_ids = torch.tensor([[[1], [3]]]) - attn_mask = None - with torch.no_grad(): - out = m(x3, position_ids=position_ids, attention_mask=attn_mask, past_key_values=kv_cache, use_cache=True) - y3 = out.last_hidden_state - kv_cache = out.past_key_values + with open(HERE / "data/glm_model.data", "wb") as f: + m.word_embeddings.weight.data.numpy().tofile(f) + m.layers[0].input_layernorm.weight.data.numpy().tofile(f) + m.layers[0].input_layernorm.bias.data.numpy().tofile(f) + m.layers[0].attention.query_key_value.weight.data.numpy().tofile(f) + m.layers[0].attention.query_key_value.bias.data.numpy().tofile(f) + m.layers[0].attention.dense.weight.data.numpy().tofile(f) + m.layers[0].attention.dense.bias.data.numpy().tofile(f) + m.layers[0].post_attention_layernorm.weight.data.numpy().tofile(f) + 
@@ -260,36 +301,54 @@ def make_data_glm_model():
 
     seq_len = 3
 
-    # self attention
-    x1 = torch.arange(seq_len, dtype=torch.int64)[None, :]
-    position_ids = torch.tensor([[[0, 1, 1], [0, 0, 1]]])
-    attn_mask = torch.tensor([[0, 0, 1], [0, 0, 1], [0, 0, 0]], dtype=torch.bool)[None, None, :]
-    with torch.no_grad():
-        out = m(x1, position_ids=position_ids, attention_mask=attn_mask, use_cache=True)
-    y1 = out.last_hidden_state
-    kv_cache = out.past_key_values
+    x1, y1, x2, y2, x3, y3 = _forward_steps(m, seq_len)
 
-    # cross attention
-    x2 = torch.tensor([[seq_len]], dtype=torch.int64)
-    position_ids = torch.tensor([[[1], [2]]])
-    attn_mask = None
-    with torch.no_grad():
-        out = m(x2, position_ids=position_ids, attention_mask=attn_mask, past_key_values=kv_cache, use_cache=True)
-    y2 = out.last_hidden_state
-    kv_cache = out.past_key_values
+    print(m)
 
-    # cross attention
-    x3 = torch.tensor([[seq_len + 1]], dtype=torch.int64)
-    position_ids = torch.tensor([[[1], [3]]])
-    attn_mask = None
-    with torch.no_grad():
-        out = m(x3, position_ids=position_ids, attention_mask=attn_mask, past_key_values=kv_cache, use_cache=True)
-    y3 = out.last_hidden_state
-    kv_cache = out.past_key_values
+    with open(HERE / "data/glm_model.data", "wb") as f:
+        m.word_embeddings.weight.data.numpy().tofile(f)
+        m.layers[0].input_layernorm.weight.data.numpy().tofile(f)
+        m.layers[0].input_layernorm.bias.data.numpy().tofile(f)
+        m.layers[0].attention.query_key_value.weight.data.numpy().tofile(f)
+        m.layers[0].attention.query_key_value.bias.data.numpy().tofile(f)
+        m.layers[0].attention.dense.weight.data.numpy().tofile(f)
+        m.layers[0].attention.dense.bias.data.numpy().tofile(f)
+        m.layers[0].post_attention_layernorm.weight.data.numpy().tofile(f)
+        m.layers[0].post_attention_layernorm.bias.data.numpy().tofile(f)
+        m.layers[0].mlp.dense_h_to_4h.weight.data.numpy().tofile(f)
+        m.layers[0].mlp.dense_h_to_4h.bias.data.numpy().tofile(f)
+        m.layers[0].mlp.dense_4h_to_h.weight.data.numpy().tofile(f)
+        m.layers[0].mlp.dense_4h_to_h.bias.data.numpy().tofile(f)
+        m.final_layernorm.weight.data.numpy().tofile(f)
+        m.final_layernorm.bias.data.numpy().tofile(f)
+
+        x1.int().numpy().tofile(f)
+        y1.data.numpy().tofile(f)
+        x2.int().numpy().tofile(f)
+        y2.data.numpy().tofile(f)
+        x3.int().numpy().tofile(f)
+        y3.data.numpy().tofile(f)
+
+    # p-tuning v2
+    config.pre_seq_len = 5
+    m = ChatGLMModel(config).float().eval()
+    for param in m.parameters():
+        param.data.uniform_(-0.5, 0.5)
+
+    x1, y1, x2, y2, x3, y3 = _forward_steps(m, seq_len)
 
     print(m)
 
-    with open(HERE / "data/glm_model.data", "wb") as f:
+    past_key_values = get_prefix_cache(
+        m.prefix_encoder,
+        config.pre_seq_len,
+        config.num_layers,
+        config.num_attention_heads,
+        config.hidden_size // config.num_attention_heads,
+    )
+
+    with open(HERE / "data/glm_ptuning_v2_model.data", "wb") as f:
+        past_key_values.data.numpy().tofile(f)
         m.word_embeddings.weight.data.numpy().tofile(f)
         m.layers[0].input_layernorm.weight.data.numpy().tofile(f)
         m.layers[0].input_layernorm.bias.data.numpy().tofile(f)
@@ -385,7 +444,44 @@ def make_data_glm2_model():
 
 
 def make_data_glm3_model():
-    CHATGLM3_MODEL_PATH = Path("./chatglm3-6b").expanduser()
+
+    def _forward_steps(model, seq_len):
+        # self attention
+        x1 = torch.arange(seq_len, dtype=torch.int64)[None, :]
+        position_ids = torch.arange(seq_len, dtype=torch.int64)[None, :]
+        attn_mask = torch.ones(1, seq_len, dtype=torch.int64)
+        with torch.no_grad():
+            out = model(x1, position_ids=position_ids, attention_mask=attn_mask, use_cache=True)
+        y1 = out.last_hidden_state
+        kv_cache = out.past_key_values
+
+        # cross attention
+        x2 = torch.tensor([[seq_len]], dtype=torch.int64)
+        position_ids = torch.tensor([[seq_len]], dtype=torch.int64)
+        attn_mask = torch.ones(1, seq_len + 1, dtype=torch.int64)
+        with torch.no_grad():
+            out = model(
+                x2, position_ids=position_ids, attention_mask=attn_mask, past_key_values=kv_cache, use_cache=True
+            )
+        y2 = out.last_hidden_state
+        kv_cache = out.past_key_values
+
+        # cross attention
+        x3 = torch.tensor([[seq_len + 1]], dtype=torch.int64)
+        position_ids = torch.tensor([[seq_len + 1]], dtype=torch.int64)
+        attn_mask = torch.ones(1, seq_len + 2, dtype=torch.int64)
+        with torch.no_grad():
+            out = model(
+                x3, position_ids=position_ids, attention_mask=attn_mask, past_key_values=kv_cache, use_cache=True
+            )
+        y3 = out.last_hidden_state
+        kv_cache = out.past_key_values
+
+        return x1, y1, x2, y2, x3, y3
+
+    CHATGLM3_MODEL_PATH = Path(
+        "~/.cache/huggingface/hub/models--THUDM--chatglm3-6b/snapshots/a5ba5501eb873d40d48bd0983bd2a8dd006bb838"
+    ).expanduser()
+
     sys.path.append(str(CHATGLM3_MODEL_PATH))
     from modeling_chatglm import ChatGLMModel
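Both P-Tuning v2 reference files start with the `past_key_values` tensor so that the C++ tests can load the learned prefix next to the randomly initialized weights. Conceptually, those `num_virtual_tokens` positions are just extra key/value entries that every query may attend to; the snippet below is a toy, single-head illustration of that mechanism and is not code from this repository:

```python
import torch
import torch.nn.functional as F

def attention_with_prefix(q, k, v, prefix_k, prefix_v):
    """Scaled dot-product attention with a learned prefix prepended to the KV cache.

    q: [q_len, head_size]; k, v: [kv_len, head_size];
    prefix_k, prefix_v: [pre_seq_len, head_size] (illustration only).
    """
    k = torch.cat([prefix_k, k], dim=0)  # prefix positions come first
    v = torch.cat([prefix_v, v], dim=0)
    scores = q @ k.T / k.shape[-1] ** 0.5
    # A causal mask would only restrict the real tokens; the prefix stays visible to all queries.
    return F.softmax(scores, dim=-1) @ v
```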
@@ -407,36 +503,44 @@ def make_data_glm3_model():
 
     seq_len = 3
 
-    # self attention
-    x1 = torch.arange(seq_len, dtype=torch.int64)[None, :]
-    position_ids = torch.arange(seq_len, dtype=torch.int64)[None, :]
-    attn_mask = torch.ones(1, seq_len, dtype=torch.int64)
-    with torch.no_grad():
-        out = m(x1, position_ids=position_ids, attention_mask=attn_mask, use_cache=True)
-    y1 = out.last_hidden_state
-    kv_cache = out.past_key_values
+    x1, y1, x2, y2, x3, y3 = _forward_steps(m, seq_len)
 
-    # cross attention
-    x2 = torch.tensor([[seq_len]], dtype=torch.int64)
-    position_ids = torch.tensor([[seq_len]], dtype=torch.int64)
-    attn_mask = torch.ones(1, seq_len + 1, dtype=torch.int64)
-    with torch.no_grad():
-        out = m(x2, position_ids=position_ids, attention_mask=attn_mask, past_key_values=kv_cache, use_cache=True)
-    y2 = out.last_hidden_state
-    kv_cache = out.past_key_values
+    print(m)
 
-    # cross attention
-    x3 = torch.tensor([[seq_len + 1]], dtype=torch.int64)
-    position_ids = torch.tensor([[seq_len + 1]], dtype=torch.int64)
-    attn_mask = torch.ones(1, seq_len + 2, dtype=torch.int64)
-    with torch.no_grad():
-        out = m(x3, position_ids=position_ids, attention_mask=attn_mask, past_key_values=kv_cache, use_cache=True)
-    y3 = out.last_hidden_state
-    kv_cache = out.past_key_values
+    with open(HERE / "data/glm3_model.data", "wb") as f:
+        m.embedding.word_embeddings.weight.data.numpy().tofile(f)
+        m.encoder.layers[0].input_layernorm.weight.data.numpy().tofile(f)
+        m.encoder.layers[0].self_attention.query_key_value.weight.data.numpy().tofile(f)
+        m.encoder.layers[0].self_attention.query_key_value.bias.data.numpy().tofile(f)
+        m.encoder.layers[0].self_attention.dense.weight.data.numpy().tofile(f)
+        m.encoder.layers[0].post_attention_layernorm.weight.data.numpy().tofile(f)
+        m.encoder.layers[0].mlp.dense_h_to_4h.weight.data.numpy().tofile(f)
+        m.encoder.layers[0].mlp.dense_4h_to_h.weight.data.numpy().tofile(f)
+        m.encoder.final_layernorm.weight.data.numpy().tofile(f)
+
+        x1.int().numpy().tofile(f)
+        y1.numpy().tofile(f)
+        x2.int().numpy().tofile(f)
+        y2.numpy().tofile(f)
+        x3.int().numpy().tofile(f)
+        y3.numpy().tofile(f)
+
+    # p-tuning v2
+    config.pre_seq_len = 5
+    m = ChatGLMModel(config).float().eval()
+    for param in m.parameters():
+        param.data.uniform_(-0.5, 0.5)
+
+    x1, y1, x2, y2, x3, y3 = _forward_steps(m, seq_len)
 
     print(m)
 
-    with open(HERE / "data/glm3_model.data", "wb") as f:
+    past_key_values = get_prefix_cache(
+        m.prefix_encoder, config.pre_seq_len, config.num_layers, config.multi_query_group_num, config.kv_channels
+    )
+
+    with open(HERE / "data/glm3_ptuning_v2_model.data", "wb") as f:
+        past_key_values.data.numpy().tofile(f)
         m.embedding.word_embeddings.weight.data.numpy().tofile(f)
         m.encoder.layers[0].input_layernorm.weight.data.numpy().tofile(f)
         m.encoder.layers[0].self_attention.query_key_value.weight.data.numpy().tofile(f)