Skip to content

Commit

Permalink
imatrix: be able to specify the name of the output tensor
Browse files Browse the repository at this point in the history
For some models the same tensor is used for token embeddings and
output. This tensor tends to be named token_embedding.weight rather
than output.weight, which prevents us from collecting imatrix data
for this tensor. With this commit, the name of the output tensor
can be passed to the imatrix tool via --output-tensor-name.
  • Loading branch information
Kawrakow committed Jun 26, 2024
1 parent 71725a9 commit 0a3a2c4
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 1 deletion.
8 changes: 8 additions & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1599,6 +1599,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.process_output = true;
return true;
}
if (arg == "--output-tensor-name") {
if (++i >= argc) {
invalid_param = true;
return true;
}
params.output_tensor_name = argv[i];
return true;
}
if (arg == "--no-ppl") {
params.compute_ppl = false;
return true;
Expand Down
1 change: 1 addition & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ struct gpt_params {

// imatrix params
std::string out_file = "imatrix.dat"; // save the resulting imatrix to this file
std::string output_tensor_name = "output.weight"; // name of the output tensor

int32_t n_out_freq = 10; // output the imatrix every n_out_freq iterations
int32_t n_save_freq = 0; // save the imatrix every n_save_freq iterations
Expand Down
3 changes: 2 additions & 1 deletion examples/imatrix/imatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
if (t->op != GGML_OP_MUL_MAT) return false;
// why are small batches ignored (<16 tokens)?
if (src1->ne[1] < 16 || src1->type != GGML_TYPE_F32) return false;
if (!(wname.substr(0, 4) == "blk." || (m_params.process_output && wname == "output.weight"))) return false;
//printf("wname = %s\n", wname.c_str());
if (!(wname.substr(0, 4) == "blk." || (m_params.process_output && wname == m_params.output_tensor_name))) return false;
return true;
}

Expand Down

0 comments on commit 0a3a2c4

Please sign in to comment.