Skip to content

Commit

Permalink
aggregate
Browse files Browse the repository at this point in the history
  • Loading branch information
hugolatendresse committed Dec 16, 2024
1 parent f289afc commit 01641a4
Showing 1 changed file with 1 addition and 28 deletions.
29 changes: 1 addition & 28 deletions src/ops/aggregate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,37 +136,10 @@ Aggregate::Aggregate(FFModel &model,
assert(inputs[2]->num_dims >= 2 + 1);
assert(inputs[3]->num_dims >= 2 + 1);

// TODO uncomment all those assertions
// assert(n <= AGGREGATE_MAX_N && "Increase AGGREGATE_MAX_N in #define");
// assert(inputs[0]->dims[0].size <= AGGREGATE_MAX_K &&
// "Increase AGGREGATE_MAX_K in #define");
// assert(inputs[0]->dims[1].size <= AGGREGATE_MAX_BATCH_SIZE &&
// "Increase AGGREGATE_MAX_BATCH_SIZE in #define");
//
// assert(n + FIXED_ARG_CNT == numInputs);
// assert(n > 0);
// //printf("In Aggregate::Aggregate, inputs[0]->num_dims = %d\n", inputs[0]->num_dims);
// //printf("In Aggregate::Aggregate, inputs[0] dims are %d %d %d %d\n", inputs[0]->dims[0].size, inputs[0]->dims[1].size, inputs[0]->dims[2].size, inputs[0]->dims[3].size);
// // TODO the inequalities below used to be equalities, not sure it's a good idea to switch to inequalities
// assert(inputs[0]->num_dims >= 2 + 1); // inputs[0] has dims (experts_per_token, 1, 128, 1) (confirmed dim count)
// assert(inputs[1]->num_dims >= 2 + 1);
// assert(inputs[2]->num_dims >= 2 + 1);
// assert(inputs[3]->num_dims >= 2 + 1);
//
// for (int i = 0; i < inputs[0]->num_dims; i++) {
// assert(inputs[0]->dims[i] == inputs[1]->dims[i]);
// assert(inputs[0]->dims[i] == inputs[2]->dims[i]);
// }
// assert(inputs[0]->dims[1] == inputs[3]->dims[1]);
// assert(inputs[3]->dims[0].size == n);

// expert inputs
int num_dim = inputs[FIXED_ARG_CNT]->num_dims; // 3
int out_dim = inputs[FIXED_ARG_CNT]->dims[0].size;
// for (int i = 1; i < n; i++) {
// assert(inputs[i + FIXED_ARG_CNT]->num_dims == num_dim);
// assert(inputs[i + FIXED_ARG_CNT]->dims[0].size == out_dim);
// }

// Set output shape
ParallelDim dims[MAX_TENSOR_DIM];
for (int i = 0; i < num_dim - 1; i++) {
Expand Down

0 comments on commit 01641a4

Please sign in to comment.