Fix nan
li-plus committed Jun 20, 2024
1 parent f0daf01 commit 5825cc2
Showing 3 changed files with 46 additions and 28 deletions.
CMakeLists.txt (2 changes: 0 additions & 2 deletions)
@@ -74,8 +74,6 @@ file(GLOB CPP_SOURCES
     ${PROJECT_SOURCE_DIR}/*.cpp
     ${PROJECT_SOURCE_DIR}/tests/*.cpp)
 
-set_source_files_properties(${CPP_SOURCES} PROPERTIES COMPILE_FLAGS "-pedantic-errors")
-
 add_library(chatglm STATIC chatglm.cpp)
 target_link_libraries(chatglm PUBLIC ggml sentencepiece-static re2)

chatglm.cpp (68 changes: 45 additions & 23 deletions)
@@ -1,4 +1,5 @@
 #include "chatglm.h"
+#include <ggml-quants.h>
 #include <algorithm>
 #include <codecvt>
 #include <cstring>
@@ -65,50 +66,70 @@ static std::string strides_to_string(ggml_tensor *tensor) {
 }
 
 std::string to_string(ggml_tensor *tensor, bool with_data) {
-    std::vector<no_init<char>> buf(ggml_nbytes(tensor));
-    ggml_backend_tensor_get(tensor, buf.data(), 0, buf.size());
+    std::vector<char> buf(ggml_nbytes(tensor));
+    if (tensor->buffer) {
+        ggml_backend_tensor_get(tensor, buf.data(), 0, buf.size());
+    } else {
+        memcpy(buf.data(), tensor->data, buf.size());
+    }
+
+    std::vector<float> float_buf(ggml_nelements(tensor));
+
+    switch (tensor->type) {
+    case GGML_TYPE_F32:
+        memcpy(float_buf.data(), buf.data(), buf.size());
+        break;
+    case GGML_TYPE_F16:
+        ggml_fp16_to_fp32_row((const ggml_fp16_t *)buf.data(), float_buf.data(), ggml_nelements(tensor));
+        break;
+    case GGML_TYPE_Q4_0:
+        dequantize_row_q4_0((block_q4_0 *)buf.data(), float_buf.data(), ggml_nelements(tensor));
+        break;
+    case GGML_TYPE_Q4_1:
+        dequantize_row_q4_1((block_q4_1 *)buf.data(), float_buf.data(), ggml_nelements(tensor));
+        break;
+    case GGML_TYPE_Q5_0:
+        dequantize_row_q5_0((block_q5_0 *)buf.data(), float_buf.data(), ggml_nelements(tensor));
+        break;
+    case GGML_TYPE_Q5_1:
+        dequantize_row_q5_1((block_q5_1 *)buf.data(), float_buf.data(), ggml_nelements(tensor));
+        break;
+    case GGML_TYPE_Q8_0:
+        dequantize_row_q8_0((block_q8_0 *)buf.data(), float_buf.data(), ggml_nelements(tensor));
+        break;
+    default:
+        CHATGLM_THROW << "Unsupported dtype " << tensor->type;
+    }
 
     std::ostringstream oss;
     oss << "ggml_tensor(";
 
     if (with_data) {
-        if (ggml_n_dims(tensor) > 3)
+        const int n_dims = ggml_n_dims(tensor);
+        if (n_dims > 3)
             oss << "[";
         for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
-            if (ggml_n_dims(tensor) > 2)
+            if (n_dims > 2)
                 oss << (i3 > 0 ? ",\n\n[" : "[");
             for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
-                if (ggml_n_dims(tensor) > 1)
+                if (n_dims > 1)
                     oss << (i2 > 0 ? ",\n\n[" : "[");
                 for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
                     oss << (i1 > 0 ? ",\n[" : "[");
                     for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
-                        char *ptr = (char *)buf.data() + i3 * tensor->nb[3] + i2 * tensor->nb[2] + i1 * tensor->nb[1] +
-                                    i0 * tensor->nb[0];
                         oss << (i0 > 0 ? ", " : "");
-                        if (tensor->type == GGML_TYPE_I32) {
-                            oss << *(int *)ptr;
-                        } else {
-                            float val;
-                            if (tensor->type == GGML_TYPE_F32) {
-                                val = *(float *)ptr;
-                            } else if (tensor->type == GGML_TYPE_F16) {
-                                val = ggml_fp16_to_fp32(*(ggml_fp16_t *)ptr);
-                            } else {
-                                CHATGLM_THROW << "unimplemented";
-                            }
-                            oss << std::setw(7) << std::fixed << std::setprecision(4) << val;
-                        }
+                        const int i = ((i3 * tensor->ne[2] + i2) * tensor->ne[1] + i1) * tensor->ne[0] + i0;
+                        oss << std::setw(7) << std::fixed << std::setprecision(4) << float_buf[i];
                     }
                     oss << "]";
                 }
-                if (ggml_n_dims(tensor) > 1)
+                if (n_dims > 1)
                     oss << "]";
             }
-            if (ggml_n_dims(tensor) > 2)
+            if (n_dims > 2)
                 oss << "]";
         }
-        if (ggml_n_dims(tensor) > 3)
+        if (n_dims > 3)
             oss << "]";
         oss << ", ";
     }
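
A note on the rewrite above: the old per-element type dispatch is replaced by a single up-front dequantization into float_buf, which is then addressed with a flat row-major index. That index is equivalent to the old stride-based pointer arithmetic only for contiguous tensors, which to_string appears to assume. A standalone sketch (the shape values are illustrative, not from the repository) demonstrating the equivalence:

#include <cassert>
#include <cstdio>

int main() {
    const int ne[4] = {5, 4, 3, 2}; // ne0..ne3, innermost dimension first as in ggml
    // Contiguous strides in elements: nb0 = 1, nb1 = ne0, nb2 = ne0*ne1, nb3 = ne0*ne1*ne2.
    const int nb[4] = {1, ne[0], ne[0] * ne[1], ne[0] * ne[1] * ne[2]};
    for (int i3 = 0; i3 < ne[3]; i3++)
        for (int i2 = 0; i2 < ne[2]; i2++)
            for (int i1 = 0; i1 < ne[1]; i1++)
                for (int i0 = 0; i0 < ne[0]; i0++) {
                    const int flat = ((i3 * ne[2] + i2) * ne[1] + i1) * ne[0] + i0;
                    const int strided = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
                    assert(flat == strided); // holds only for contiguous layouts
                }
    printf("flat indexing matches strided indexing for contiguous layout\n");
    return 0;
}
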
@@ -731,6 +752,7 @@ ggml_tensor *BaseModelForCausalLM::forward_graph_compute(const std::vector<int>
     ggml_set_input(curr_input_ids);
 
     ggml_tensor *lm_logits = forward(mctx_.get(), curr_input_ids, n_past, n_ctx, is_decoding);
+    ggml_set_output(lm_logits);
 
     ggml_build_forward_expand(mctx_->gf, lm_logits);
     CHATGLM_CHECK(ggml_gallocr_alloc_graph(mctx_->allocr.get(), mctx_->gf));
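
This one-line addition looks like the core of the NaN fix: ggml's graph allocator is allowed to reuse the memory of intermediate tensors while the graph runs, so a node that is read back after compute must carry GGML_TENSOR_FLAG_OUTPUT (set via ggml_set_output), or its buffer can be recycled by later nodes and the logits come back as garbage or NaN. A minimal self-contained sketch of the input/output-flag pattern, assuming a 2024-era ggml with the backend and gallocr APIs:

#include <ggml.h>
#include <ggml-alloc.h>
#include <ggml-backend.h>
#include <cstdio>

int main() {
    // The context holds only tensor metadata; the graph allocator owns the data.
    ggml_init_params params = {ggml_tensor_overhead() * 8 + ggml_graph_overhead(), nullptr, /*no_alloc=*/true};
    ggml_context *ctx = ggml_init(params);

    ggml_tensor *x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    ggml_set_input(x); // input buffers are never reused by the allocator

    ggml_tensor *y = ggml_scale(ctx, x, 2.0f);
    ggml_set_output(y); // without this flag, y's memory may be recycled mid-graph

    ggml_cgraph *gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);

    ggml_backend_t backend = ggml_backend_cpu_init();
    ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
    ggml_gallocr_alloc_graph(allocr, gf);

    const float in[4] = {1, 2, 3, 4};
    ggml_backend_tensor_set(x, in, 0, sizeof(in));
    ggml_backend_graph_compute(backend, gf);

    float out[4];
    ggml_backend_tensor_get(y, out, 0, sizeof(out));
    printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]); // 2 4 6 8

    ggml_gallocr_free(allocr);
    ggml_backend_free(backend);
    ggml_free(ctx);
    return 0;
}
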
chatglm.h (4 changes: 1 addition & 3 deletions)
@@ -148,9 +148,7 @@ class ModelConfig {
                      num_virtual_tokens, rec.max_length, rec.bos_token_id, rec.eos_token_id, rec.pad_token_id,
                      rec.sep_token_id, {}) {}
 
-    ModelConfig(ModelType model_type, const ConfigRecordV1GQA &rec, float norm_eps, ActivationType hidden_act,
-                bool use_qkv_bias, bool use_dense_bias, bool interleaved_qkv, RopeType rope_type, float rope_theta,
-                AttentionMaskType attn_mask_type, int num_virtual_tokens)
+    ModelConfig(ModelType model_type, const ConfigRecordV1GQA &rec, float norm_eps, float rope_theta, int num_virtual_tokens)
         : ModelConfig(model_type, rec.dtype, rec.vocab_size, rec.hidden_size, rec.num_attention_heads, rec.num_kv_heads,
                       rec.num_hidden_layers, rec.intermediate_size, norm_eps, rope_theta, num_virtual_tokens,
                       rec.max_length, rec.bos_token_id, rec.eos_token_id, rec.pad_token_id, rec.sep_token_id, {}) {}
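
For illustration, a hypothetical call site for the slimmed-down constructor; the enum value and hyperparameters below are placeholders, not taken from the repository:

// Hypothetical usage of the simplified ModelConfig constructor; the record
// contents and numeric values are illustrative only.
ConfigRecordV1GQA rec{}; // in practice, deserialized from the model file
ModelConfig config(ModelType::CHATGLM2, rec,
                   /*norm_eps=*/1e-5f,
                   /*rope_theta=*/10000.f,
                   /*num_virtual_tokens=*/0);
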
