ZHIHAO COMMENTS
hugolatendresse committed Dec 11, 2024
1 parent bab8ccb commit f9ed445
Showing 6 changed files with 25 additions and 5 deletions.
4 changes: 2 additions & 2 deletions inference/models/mixtral.cc
@@ -273,7 +273,7 @@ void MIXTRAL::create_mixtral_model(FFModel &ff,
Tensor aggregate_inputs[4 + mixtral_config.num_local_experts] = {nullptr};
for (int expert_idx = 0; expert_idx < mixtral_config.num_local_experts;
expert_idx++) {
grouped_tokens[expert_idx] = ff_norm; // TODO this is a dirty fix. Restore using group_by!
// grouped_tokens[expert_idx] = ff_norm; // TODO this is a dirty fix. Restore using group_by!
Tensor w1 = ff.dense(grouped_tokens[expert_idx], // (hidden_size, 1, result of calc in groupby)
mixtral_config.intermediate_size,
AC_MODE_NONE,
@@ -336,7 +336,7 @@ void MIXTRAL::create_mixtral_model(FFModel &ff,

aggregate_inputs[0] = topk_values;
aggregate_inputs[1] = topk_indices;
aggregate_inputs[2] = topk_values;
aggregate_inputs[2] = topk_values; // TODO Causes Legion runtime error!!
aggregate_inputs[3] = gate;
mlp_out = aggregate_inputs[5]; // TODO don't use only one expert
// mlp_out = ff.aggregate(aggregate_inputs,
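The two TODOs above ("Restore using group_by!" and "don't use only one expert") describe the intended MoE wiring: route tokens to experts with group_by, run each expert MLP only on its routed tokens, then recombine the expert outputs with aggregate. A minimal sketch of that wiring follows; it is not part of this commit, the group_by/aggregate argument lists and the expert_mlp helper are assumptions, and only ff.dense, ff.aggregate, and the tensor names are taken from the diff above.

// Hedged sketch of the routing the TODOs point at (not the committed code).
// Assumed: ff.group_by writes one tensor per expert into an output array, and
// expert_mlp is a hypothetical helper wrapping the per-expert dense layers
// (w1 and friends shown above).
Tensor grouped_tokens[MAX_EXPERTS] = {nullptr}; // MAX_EXPERTS is illustrative
ff.group_by(ff_norm, topk_indices, grouped_tokens,
            mixtral_config.num_local_experts, /*alpha=*/1.0f);
for (int expert_idx = 0; expert_idx < mixtral_config.num_local_experts;
     expert_idx++) {
  // Each expert sees only its routed tokens instead of the full ff_norm.
  aggregate_inputs[4 + expert_idx] =
      expert_mlp(ff, grouped_tokens[expert_idx], mixtral_config);
}
mlp_out = ff.aggregate(aggregate_inputs,
                       mixtral_config.num_local_experts,
                       /*lambda_bal=*/0.0f);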
15 changes: 15 additions & 0 deletions src/ops/aggregate.cc
@@ -313,6 +313,7 @@ OpMeta *Aggregate::init_task(Task const *task,
// ... including some steps with GenericTensorAccessorR
// Should I include?

// Only needed to allocate memory in the kernel
AggregateMeta *m = new AggregateMeta(handle, agg, gpu_mem_allocator);
for (int i = 0; i < 10; i++) { // TODO 10 is a magic number
m->input_type[i] = agg->inputs[i]->data_type;
@@ -489,6 +490,18 @@ void Aggregate::forward_task(Task const *task,
ctx, task->regions[total_input_cnt].region.get_index_space());


Aggregate::forward_kernel_wrapper(m,
bc,
exp_preds,
acc_gate_assign.ptr(rect_gate_assign),
acc_gate_pred.ptr(rect_gate_pred),
acc_output.ptr(rect_output),
n,
k,
rows,
batch_size,
out_dim);

// TODO One of those three lines causes the mismatch error
// get gate_pred, gate_assign, output
//AccessorRO<float, 3> const acc_gate_pred(regions[0], FID_DATA); // This one alone does cause the problem
@@ -534,6 +547,7 @@ void Aggregate::forward_task(Task const *task,

// printf("CALLING FOWARD_KERNEL_WRAPPER IN FORWARD_TASK\n");

// From ZJ: we lose the shape of the tensors with this approach. The approach in sigmoid_silu_multi is recommended
// Aggregate::forward_kernel_wrapper(m,
// exp_preds,
// acc_gate_assign.ptr(rect_gate_assign),
@@ -604,6 +618,7 @@ void Aggregate::inference_task(Task const *task,

// TODO should we have an inference_kernel wrapper?
Aggregate::forward_kernel_wrapper(m,
bc,
exp_preds,
acc_gate_assign.ptr(rect_gate_assign),
acc_gate_pred.ptr(rect_gate_pred),
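The "From ZJ" note above is the key design point in this file: once only raw pointers reach forward_kernel_wrapper, the tensor shapes are lost at the task level. The recommended alternative (the pattern used in sigmoid_silu_multi) is to hand the GenericTensorAccessor objects themselves to the wrapper and extract typed pointers only at the kernel launch. Below is a minimal sketch of that pattern, not code from this repository: example_kernel_wrapper and the commented-out kernel launches are hypothetical names, while GenericTensorAccessorR/W, data_type, domain, and get_float_ptr/get_half_ptr are the accessor members that appear elsewhere in this diff.

// Sketch: pass accessors, not raw pointers, so shape information travels with the data.
void example_kernel_wrapper(GenericTensorAccessorR const &input,
                            GenericTensorAccessorW const &output,
                            cudaStream_t stream) {
  // The Domain is still available here, so dimension checks can live in the wrapper.
  assert(input.domain.get_volume() == output.domain.get_volume());
  size_t num_elements = input.domain.get_volume();
  if (input.data_type == DT_FLOAT) {
    // Typed pointers are extracted only at the launch site.
    float const *in_ptr = input.get_float_ptr();
    float *out_ptr = output.get_float_ptr();
    // example_float_kernel<<<GET_BLOCKS(num_elements), CUDA_NUM_THREADS, 0, stream>>>(
    //     num_elements, in_ptr, out_ptr);
  } else if (input.data_type == DT_HALF) {
    half const *in_ptr = input.get_half_ptr();
    half *out_ptr = output.get_half_ptr();
    // example_half_kernel<<<GET_BLOCKS(num_elements), CUDA_NUM_THREADS, 0, stream>>>(
    //     num_elements, in_ptr, out_ptr);
  }
}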
2 changes: 2 additions & 0 deletions src/ops/aggregate.cu
@@ -200,6 +200,7 @@ __global__ void agg_backward_kernel(float **exp_preds,

/*static*/
void Aggregate::forward_kernel_wrapper(AggregateMeta const *m,
BatchConfig const *bc,
float **exp_preds,
int const *acc_gate_assign_ptr,
float const *acc_gate_pred_ptr,
@@ -307,6 +308,7 @@ void Aggregate::backward_kernel_wrapper(AggregateMeta const *m,
}
}
// Only needed if we allocate memory, which is not our case
AggregateMeta::AggregateMeta(FFHandler handler,
Aggregate const *aggr,
MemoryAllocator &gpu_mem_allocator)
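Since this commit threads a BatchConfig through the wrapper, the full parameter list of Aggregate::forward_kernel_wrapper, reconstructed as a sketch from the definition above and the call sites in aggregate.cc, looks roughly like this (parameter names after acc_gate_pred_ptr are inferred from the call arguments and may differ from the actual header):

// Inferred shape of the updated signature; bc is the parameter added in this commit.
static void forward_kernel_wrapper(AggregateMeta const *m,
                                   BatchConfig const *bc,
                                   float **exp_preds,
                                   int const *acc_gate_assign_ptr,
                                   float const *acc_gate_pred_ptr,
                                   float *output_ptr,
                                   int n,
                                   int k,
                                   int rows,
                                   int batch_size,
                                   int out_dim);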
4 changes: 3 additions & 1 deletion src/ops/inc_multihead_self_attention.cc
@@ -493,7 +493,8 @@ OpMeta *IncMultiHeadSelfAttention::init_task(
// printf("running IncMultiHeadSelfAttention::init_task\n");
IncMultiHeadSelfAttention const *attn = (IncMultiHeadSelfAttention *)task->args;
FFHandler handle = *((FFHandler const *)task->local_args);

// We call the below to get the shape info, so we can do the assertions
// I also shouldn't care about offloading
GenericTensorAccessorR input =
helperGetGenericTensorAccessorRO(attn->inputs[0]->data_type,
regions[0],
@@ -745,6 +746,7 @@ bool IncMultiHeadSelfAttention::get_int_parameter(PMParameter para,
}
}

// Just for benchmarking, don't need that
bool IncMultiHeadSelfAttention::measure_operator_cost(
Simulator *sim, MachineView const &mv, CostMetrics &cost_metrics) const {
return false;
3 changes: 2 additions & 1 deletion src/ops/inc_multihead_self_attention.cu
@@ -908,7 +908,7 @@ void compute_attention_kernel_generation(IncMultiHeadSelfAttentionMeta const *m,
int const per_head_size = m->qProjSize;
float scale = (*m->qk_prod_scaling) ? 1.0f / sqrt(m->kProjSize) : 1.0f;
size_t smem_sz;
if (per_head_size == 32) {
if (per_head_size == 32) { // ok to do that
constexpr int THREADS_PER_VALUE_32 = threads_per_value_t<DT, 32>::value;
LAUNCH_ATTENTION_SCORE_KERNEL(
DT, 32, 32, 4, THREADS_PER_VALUE_32, 128, stream);
@@ -1517,6 +1517,7 @@ void IncMultiHeadSelfAttention::inference_kernel_wrapper(
assert(input.data_type == output.data_type);
if (input.data_type == DT_HALF) {
// calling input.get_inc_ptr() below would cause a Legion error (type mismatch in get_index_space_domain)
Kernels::IncMultiHeadAttention::inference_kernel(
m, bc, shard_id, input.get_half_ptr(), output.get_half_ptr(), stream);
} else if (input.data_type == DT_FLOAT) {
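The "ok to do that" comment above refers to the runtime-to-compile-time dispatch on per_head_size: the head size is only known at runtime, so it is compared against a small set of supported sizes and forwarded to a kernel instantiated for that size (via threads_per_value_t<DT, N> and LAUNCH_ATTENTION_SCORE_KERNEL in the real code). A self-contained sketch of the pattern, with illustrative names rather than the file's actual kernels:

#include <cassert>
#include <cstdio>

// Stand-in for the templated kernel launch; HEAD_SIZE is a compile-time constant.
template <typename DT, int HEAD_SIZE>
void launch_attention_for_head_size() {
  std::printf("launching attention kernel for head size %d\n", HEAD_SIZE);
}

template <typename DT>
void dispatch_on_head_size(int per_head_size) {
  // Map the runtime head size onto the supported compile-time instantiations
  // (the set of sizes here is illustrative).
  switch (per_head_size) {
    case 32:
      launch_attention_for_head_size<DT, 32>();
      break;
    case 64:
      launch_attention_for_head_size<DT, 64>();
      break;
    case 128:
      launch_attention_for_head_size<DT, 128>();
      break;
    default:
      assert(false && "unsupported per_head_size");
  }
}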
2 changes: 1 addition & 1 deletion src/ops/sigmoid_silu_multi.cu
@@ -181,7 +181,7 @@ void SigmoidSiluMulti::inference_kernel_wrapper(
min(CUDA_NUM_THREADS, num_elements),
0,
stream>>>(input1.domain.get_volume(),
input1.get_float_ptr(),
input1.get_float_ptr(), // Ultimately we get pointers here, whereas in the mixtral branch we pass pointers to this func.
input2.get_float_ptr(),
output.get_float_ptr());
} else if (m->input_type[0] == DT_HALF) {
