diff --git a/inference/models/mixtral.cc b/inference/models/mixtral.cc index a728307383..c39ad3802f 100644 --- a/inference/models/mixtral.cc +++ b/inference/models/mixtral.cc @@ -88,8 +88,8 @@ void MIXTRAL::create_mixtral_model(FFModel &ff, .c_str()); } else { ff.residual_rms_norm( - token, // (1024, batch, sequence) - mlp_out, // (1024, batch, sequence) + token, // (hidden_dim, batch, sequence) + mlp_out, // (hidden_dim, batch, sequence) token_att_norm, mixtral_config.rms_norm_eps, mixtral_config.hidden_size,