diff --git a/src/ops/group_by.cc b/src/ops/group_by.cc index 8c0bbcbbf3..484081c0f5 100644 --- a/src/ops/group_by.cc +++ b/src/ops/group_by.cc @@ -72,8 +72,8 @@ void FFModel::group_by(const Tensor input, if (alpha != 0.0f) { dims[num_dims - 1] = (int)ceil(alpha * k_experts_per_tok / num_local_experts * input->dims[num_dims - 1]); - } else { // MK: added this for dummy groupby - dims[num_dims - 1] = 128; // TODO remove magic number + } else { + dims[num_dims - 1] = input->dims[num_dims - 1]; } for (int i = 0; i < num_local_experts; i++) { @@ -158,12 +158,13 @@ Group_by::Group_by(FFModel &model, for (int i = 0; i < num_dims; i++) { dims[i] = inputs[0]->dims[i]; } - // set max expert size + // set max expert size if (alpha != 0.0f) { - dims[num_dims - 2].size = (int)ceil(alpha * k_experts_per_tok / n * inputs[0]->dims[2].size); // was inputs[0]->dims[1].size - - } + dims[num_dims - 2].size = (int)ceil(alpha * k_experts_per_tok / n * inputs[0]->dims[2].size); + } else { + dims[num_dims - 2].size = inputs[0]->dims[2].size; + } for (int i = 0; i < n; i++) { outputs[i] = model.create_parallel_tensor_legion_ordering(