
Commit

cleanup mixtral.cc
hugolatendresse committed Dec 15, 2024
1 parent 4404694 commit e133fc6
Showing 1 changed file with 18 additions and 43 deletions.
61 changes: 18 additions & 43 deletions inference/models/mixtral.cc
@@ -66,7 +66,6 @@ void MIXTRAL::create_mixtral_model(FFModel &ff,
  NULL,
  embed_init,
  "embed_tokens");
- // token has dimensions (hidden_size, 1, 128)

  Tensor mlp_out = nullptr;

@@ -88,13 +87,9 @@ void MIXTRAL::create_mixtral_model(FFModel &ff,
  std::string("layers." + std::to_string(i) + ".input_layernorm")
  .c_str());
  } else {
- // printf("before first rms norm in layer %d token has %d dims\n",i, token->num_dims);
- // printf("before first rms norm in layer %d mlp_out has %d dims\n",i, token->num_dims);
- // printf("before first rms norm in layer %d token dims are %d %d %d %d\n",i, token->dims[0], token->dims[1], token->dims[2], token->dims[3]);
- // printf("before first rms norm in layer %d, mlp_out dims are %d %d %d %d\n",i, mlp_out->dims[0], mlp_out->dims[1], mlp_out->dims[2], mlp_out->dims[3]);
  ff.residual_rms_norm(
- token, // (1024, 1, 128) confirmed 3 dims
- mlp_out, // (1024, 1, 128) confirmed 3 dims
+ token, // (1024, batch, sequence)
+ mlp_out, // (1024, batch, sequence)
  token_att_norm,
  mixtral_config.rms_norm_eps,
  mixtral_config.hidden_size,
@@ -105,9 +100,7 @@ void MIXTRAL::create_mixtral_model(FFModel &ff,
  token = token_att_norm[0];
  att_norm = token_att_norm[1];
  }
- // token has dimensions (hidden_size, 1, 128)
-
-
+ // token has dimensions (hidden_size, batch, sequence)

  Tensor qkv_proj = ff.dense(
  att_norm,
@@ -225,12 +218,12 @@ void MIXTRAL::create_mixtral_model(FFModel &ff,
  DT_NONE,
  std::string("layers." + std::to_string(i) + ".post_attention_layernorm")
  .c_str());
- token = token_ff_norm[0]; // token has dimensions (hidden_size, 1, 128)
+ token = token_ff_norm[0]; // token has dimensions (hidden_size, batch, sequence)
  Tensor ff_norm = token_ff_norm[1];

  // MoE
  Tensor gate = ff.dense(
- ff_norm, // (hidden_size, 1, 128)
+ ff_norm, // (hidden_size, batch, sequence)
  mixtral_config.num_local_experts,
  AC_MODE_NONE,
  false,
@@ -243,7 +236,7 @@ void MIXTRAL::create_mixtral_model(FFModel &ff,
  std::string("layers." + std::to_string(i) + ".block_sparse_moe_gate")
  .c_str());
  gate = ff.softmax(
- gate, // (num_experts, 1, 128)
+ gate, // (num_experts, batch, sequence)
  0,
  DT_NONE,
  std::string("layers." + std::to_string(i) + ".block_sparse_moe_softmax")
@@ -252,43 +245,30 @@ void MIXTRAL::create_mixtral_model(FFModel &ff,
  Tensor topk_out[2] = {nullptr, nullptr};

  ff.top_k(
- gate, // (num_experts, 1, 128)
+ gate, // (num_experts, batch, sequence)
  topk_out,
  mixtral_config.num_experts_per_tok,
  false,
  std::string("layers." + std::to_string(i) + ".block_sparse_moe_topk")
  .c_str());
- Tensor topk_values = topk_out[0]; // (experts_per_tok, 1, 128) (confirmed 3 dims)
- Tensor topk_indices = topk_out[1]; // (experts_per_tok, 1, 128) (confirmed 3 dims)
+ Tensor topk_values = topk_out[0]; // (experts_per_tok, batch, sequence)
+ Tensor topk_indices = topk_out[1]; // (experts_per_tok, batch, sequence)

  Tensor grouped_tokens[mixtral_config.num_local_experts] = {nullptr};
- ff.group_by( // TODO this group_by does not crash, but it sets all tokens to 0 or something! Need to figure out why it make outptu tokens all the same
- ff_norm, // (hidden_size, 1, 128)
+ ff.group_by(
+ ff_norm, // (hidden_size, batch, sequence)
  topk_indices,
  grouped_tokens,
  mixtral_config.num_local_experts,
- 0.0f, // TODO understand why this does not cause a dimension of 128? maybe the 128 is never set?
+ 0.0f,
  std::string("layers." + std::to_string(i) + ".block_sparse_moe_groupby")
  .c_str());

- // Can use this to create a grouped_tokens2 used no where just to see if group_by can run successfully
- // Tensor grouped_tokens2[mixtral_config.num_local_experts] = {nullptr};
- // ff.group_by(
- // ff_norm, // (hidden_size, 1, 128)
- // topk_indices,
- // grouped_tokens2,
- // mixtral_config.num_local_experts,
- // 1.0f, // TODO understand why this does not cause a dimension of 128? maybe the 128 is never set?
- // std::string("layers." + std::to_string(i) + ".block_sparse_moe_groupby")
- // .c_str());
-
-
  Tensor aggregate_inputs[4 + mixtral_config.num_local_experts] = {nullptr};

  for (int expert_idx = 0; expert_idx < mixtral_config.num_local_experts;
  expert_idx++) {
- // grouped_tokens[expert_idx] = ff_norm; // TODO this is a dirty fix. Restore using group_by!
- Tensor w1 = ff.dense(grouped_tokens[expert_idx], // (hidden_size, 1, result of calc in groupby)
+ Tensor w1 = ff.dense(grouped_tokens[expert_idx], // (hidden_size, batch, max tokens per expert)
  mixtral_config.intermediate_size,
  AC_MODE_NONE,
  false,
@@ -348,10 +328,7 @@ void MIXTRAL::create_mixtral_model(FFModel &ff,
  // Tensor topk_values_reduced = ff.reduce_sum(topk_values, {0}, true);
  // topk_values = ff.divide(topk_values, topk_values_reduced);

- // mlp_out = aggregate_inputs[5]; // TODO don't use only one expert
-
- // Everything below is needed to use aggregate // TODO try not needing the _dummy stuff
-
+ // TODO have 2 fixed inputs instead of 4
  Tensor topk_values_DUMMY = ff.softmax(
  topk_values,
  -1,
@@ -360,7 +337,7 @@ void MIXTRAL::create_mixtral_model(FFModel &ff,
  .c_str());

  Tensor gate_DUMMY = ff.softmax(
- gate, // (num_experts, 1, 128)
+ gate, // (num_experts, batch, sequence)
  -1,
  DT_NONE,
  std::string("layers." + std::to_string(i) + ".dummy")
@@ -370,17 +347,15 @@ void MIXTRAL::create_mixtral_model(FFModel &ff,
  aggregate_inputs[1] = topk_indices;
  aggregate_inputs[2] = topk_values_DUMMY;
  aggregate_inputs[3] = gate_DUMMY;
- //

  mlp_out = ff.aggregate(aggregate_inputs,
  mixtral_config.num_local_experts,
  0.0f,
  std::string("layers." + std::to_string(i) +
  ".block_sparse_moe_experts_aggregate")
  .c_str());

- // printf("mlp_out in layer %d dims are %d %d %d %d\n",i, mlp_out->dims[0], mlp_out->dims[1], mlp_out->dims[2], mlp_out->dims[3]);
- assert(mlp_out->dims[0] == mixtral_config.hidden_size && "mlp_out dims[0] != hidden_size");
- // printf("seq length is now %d\n", mlp_out->dims[2]);
+ assert(mlp_out->dims[0] == mixtral_config.hidden_size && "mlp_out dims[0] != hidden_size");

  }

Expand Down Expand Up @@ -414,7 +389,7 @@ void MIXTRAL::create_mixtral_model(FFModel &ff,
Tensor softmax = ff.softmax(dense, -1);
output = ff.sampling(softmax, generation_config.topp);
} else {
Tensor softmax = ff.softmax(dense, -1); // TODO added that to copy llama, see if needed in HF transformers impl.
Tensor softmax = ff.softmax(dense, -1);
output = ff.argmax(softmax, /*beam_Search*/ false);
}

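For readers unfamiliar with the MoE block this file builds, the gate -> softmax -> top_k -> group_by -> per-expert dense -> aggregate chain in the diff reduces, per token, to picking the top-k experts by gate probability and taking a weighted sum of their MLP outputs. The standalone sketch below is not part of the commit: it uses toy sizes, a placeholder expert function, and the renormalization that the commented-out reduce_sum/divide lines hint at, purely to illustrate the routing math in plain C++.

// Standalone illustration of per-token MoE routing; not FlexFlow code.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  int const num_experts = 8;      // stands in for mixtral_config.num_local_experts
  int const experts_per_tok = 2;  // stands in for mixtral_config.num_experts_per_tok
  int const hidden_size = 4;      // toy hidden size

  // Gate logits for one token (what the block_sparse_moe_gate dense layer produces).
  std::vector<float> gate = {0.1f, 2.0f, -1.0f, 0.5f, 1.5f, 0.0f, -0.5f, 0.3f};

  // Softmax over experts.
  float max_logit = *std::max_element(gate.begin(), gate.end());
  std::vector<float> probs(num_experts);
  float sum = 0.0f;
  for (int e = 0; e < num_experts; e++) {
    probs[e] = std::exp(gate[e] - max_logit);
    sum += probs[e];
  }
  for (float &p : probs) {
    p /= sum;
  }

  // Top-k expert indices by probability (the role of ff.top_k).
  std::vector<int> order(num_experts);
  std::iota(order.begin(), order.end(), 0);
  std::partial_sort(order.begin(), order.begin() + experts_per_tok, order.end(),
                    [&](int a, int b) { return probs[a] > probs[b]; });

  // Placeholder for each selected expert's MLP output on this token
  // (in the model this is the per-expert dense stack applied to the grouped tokens).
  auto expert_out = [&](int e) {
    return std::vector<float>(hidden_size, static_cast<float>(e + 1));
  };

  // Aggregate: weighted sum of the selected experts' outputs,
  // with the weights renormalized over the chosen experts.
  std::vector<float> mlp_out(hidden_size, 0.0f);
  float weight_sum = 0.0f;
  for (int k = 0; k < experts_per_tok; k++) {
    weight_sum += probs[order[k]];
  }
  for (int k = 0; k < experts_per_tok; k++) {
    float w = probs[order[k]] / weight_sum;
    std::vector<float> out = expert_out(order[k]);
    for (int d = 0; d < hidden_size; d++) {
      mlp_out[d] += w * out[d];
    }
  }

  for (int d = 0; d < hidden_size; d++) {
    std::printf("mlp_out[%d] = %f\n", d, mlp_out[d]);
  }
  return 0;
}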
