[FusedOp] Fix segment fault (#1511)
* minor bug fix

* fix
jiazhihao authored Sep 27, 2024
1 parent 9da5546 commit 64c258f
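
The commit description is terse, so a brief note on what the diff below does: it replaces the fixed-size, stack-allocated accessor arrays (GenericTensorAccessorR input_accessor[MAX_NUM_INPUTS] and friends) with std::vector containers filled via push_back, moves the per-operator my_* accessors into the body of the reverse-traversal loop, and passes my_input_grad_accessor.data() where a raw pointer is expected. One plausible reading of the segfault, offered as an assumption rather than a confirmed root cause: guards such as assert(fused->numInputs <= MAX_NUM_INPUTS) compile away under NDEBUG, so a count above the compile-time cap could write past the array in a release build, while a vector grows to the runtime count. A minimal sketch of that hazard, with stand-in names (Accessor, collect, and this MAX_NUM_INPUTS are illustrative, not FlexFlow APIs):

// Minimal sketch (not FlexFlow code; names are stand-ins): the assert that
// guards a fixed-size array compiles away under NDEBUG, so a count above
// the compile-time cap would write past the array. A std::vector filled
// with push_back is sized by the runtime count instead.
#include <cassert>
#include <vector>

constexpr int MAX_NUM_INPUTS = 16; // stand-in for the real cap

struct Accessor { // stand-in for GenericTensorAccessorR
  int id;
};

std::vector<Accessor> collect(int num_inputs) {
  assert(num_inputs <= MAX_NUM_INPUTS); // removed entirely in release builds
  // Old pattern: Accessor out[MAX_NUM_INPUTS]; out[i] = ...;
  // writes past the array whenever num_inputs exceeds the cap.
  std::vector<Accessor> out;
  out.reserve(num_inputs); // sized by the runtime count, not the cap
  for (int i = 0; i < num_inputs; i++) {
    out.push_back(Accessor{i});
  }
  return out;
}

int main() {
  return collect(12).size() == 12 ? 0 : 1;
}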
Showing 1 changed file with 35 additions and 34 deletions.
src/ops/fused.cu
@@ -1678,77 +1678,77 @@ __host__ void FusedOp::backward_task(Task const *task,
     int sum = fused->numInputs + fused->numWeights + fused->numOutputs;
     assert(sum * 2 == (int)regions.size());
   }
-  GenericTensorAccessorR input_accessor[MAX_NUM_INPUTS];
-  GenericTensorAccessorW input_grad_accessor[MAX_NUM_INPUTS];
-  GenericTensorAccessorR weight_accessor[MAX_NUM_WEIGHTS];
-  GenericTensorAccessorW weight_grad_accessor[MAX_NUM_WEIGHTS];
-  GenericTensorAccessorR output_accessor[MAX_NUM_OUTPUTS];
-  GenericTensorAccessorW output_grad_accessor[MAX_NUM_OUTPUTS];
+  std::vector<GenericTensorAccessorR> input_accessor;
+  std::vector<GenericTensorAccessorW> input_grad_accessor;
+  std::vector<GenericTensorAccessorR> weight_accessor;
+  std::vector<GenericTensorAccessorW> weight_grad_accessor;
+  std::vector<GenericTensorAccessorR> output_accessor;
+  std::vector<GenericTensorAccessorW> output_grad_accessor;
   int roff = 0;
   assert(fused->numInputs <= MAX_NUM_INPUTS);
   for (int i = 0; i < fused->numInputs; i++) {
-    input_accessor[i] =
+    input_accessor.push_back(
         helperGetGenericTensorAccessorRO(fused->input_data_types[i],
                                          regions[i],
                                          task->regions[i],
                                          FID_DATA,
                                          ctx,
-                                         runtime);
+                                         runtime));
   }
   roff += fused->numInputs;
   assert(fused->numWeights <= MAX_NUM_WEIGHTS);
   for (int i = 0; i < fused->numWeights; i++) {
-    weight_accessor[i] =
+    weight_accessor.push_back(
         helperGetGenericTensorAccessorRO(fused->weight_data_types[i],
                                          regions[i + roff],
                                          task->regions[i + roff],
                                          FID_DATA,
                                          ctx,
-                                         runtime);
+                                         runtime));
   }
   roff += fused->numWeights;
   assert(fused->numOutputs <= MAX_NUM_OUTPUTS);
   for (int i = 0; i < fused->numOutputs; i++) {
-    output_accessor[i] =
+    output_accessor.push_back(
         helperGetGenericTensorAccessorRO(fused->output_data_types[i],
                                          regions[i + roff],
                                          task->regions[i + roff],
                                          FID_DATA,
                                          ctx,
-                                         runtime);
+                                         runtime));
   }
   roff += fused->numOutputs;
   for (int i = 0; i < fused->numInputs; i++) {
-    input_grad_accessor[i] =
+    input_grad_accessor.push_back(
         helperGetGenericTensorAccessorRW(fused->input_data_types[i],
                                          regions[i + roff],
                                          task->regions[i + roff],
                                          FID_DATA,
                                          ctx,
-                                         runtime);
+                                         runtime));
     assert(input_grad_accessor[i].domain == input_accessor[i].domain);
   }
   roff += fused->numInputs;
   for (int i = 0; i < fused->numWeights; i++) {
-    weight_grad_accessor[i] =
+    weight_grad_accessor.push_back(
         helperGetGenericTensorAccessorRW(fused->weight_data_types[i],
                                          regions[i + roff],
                                          task->regions[i + roff],
                                          FID_DATA,
                                          ctx,
-                                         runtime);
+                                         runtime));
     assert(weight_grad_accessor[i].domain.get_volume() ==
            weight_accessor[i].domain.get_volume());
   }
   roff += fused->numWeights;
   for (int i = 0; i < fused->numOutputs; i++) {
-    output_grad_accessor[i] =
+    output_grad_accessor.push_back(
         helperGetGenericTensorAccessorRW(fused->output_data_types[i],
                                          regions[i + roff],
                                          task->regions[i + roff],
                                          FID_DATA,
                                          ctx,
-                                         runtime);
+                                         runtime));
     assert(output_grad_accessor[i].domain == output_accessor[i].domain);
   }
   roff += fused->numOutputs;
@@ -1767,12 +1767,6 @@ __host__ void FusedOp::backward_task(Task const *task,
   }

   int ioff = 0, woff = 0, ooff = 0;
-  GenericTensorAccessorR my_input_accessor[MAX_NUM_INPUTS];
-  GenericTensorAccessorR my_weight_accessor[MAX_NUM_WEIGHTS];
-  GenericTensorAccessorR my_output_accessor[MAX_NUM_OUTPUTS];
-  GenericTensorAccessorW my_input_grad_accessor[MAX_NUM_INPUTS];
-  GenericTensorAccessorW my_weight_grad_accessor[MAX_NUM_WEIGHTS];
-  GenericTensorAccessorW my_output_grad_accessor[MAX_NUM_OUTPUTS];
   // Do backpropagation in the reverse ordering
   for (int op = 0; op < fused->numOperators; op++) {
     ioff += fused->op_num_inputs[op];
@@ -1781,36 +1775,43 @@ __host__ void FusedOp::backward_task(Task const *task,
   }

   for (int op = fused->numOperators - 1; op >= 0; op--) {
+    std::vector<GenericTensorAccessorR> my_input_accessor;
+    std::vector<GenericTensorAccessorR> my_weight_accessor;
+    std::vector<GenericTensorAccessorR> my_output_accessor;
+    std::vector<GenericTensorAccessorW> my_input_grad_accessor;
+    std::vector<GenericTensorAccessorW> my_weight_grad_accessor;
+    std::vector<GenericTensorAccessorW> my_output_grad_accessor;
     ioff -= fused->op_num_inputs[op];
     woff -= fused->op_num_weights[op];
     ooff -= fused->op_num_outputs[op];
     for (int i = 0; i < fused->op_num_inputs[op]; i++) {
       int my_off = fused->op_input_idx[i + ioff];
       if (fused->op_input_source[i + ioff] == SOURCE_INPUT) {
-        my_input_accessor[i] = input_accessor[my_off];
-        my_input_grad_accessor[i] = input_grad_accessor[my_off];
+        my_input_accessor.push_back(input_accessor[my_off]);
+        my_input_grad_accessor.push_back(input_grad_accessor[my_off]);
         assert(my_input_grad_accessor[i].domain == my_input_accessor[i].domain);
       } else if (fused->op_input_source[i + ioff] == SOURCE_OUTPUT) {
-        my_input_accessor[i] = output_accessor[my_off];
-        my_input_grad_accessor[i] = output_grad_accessor[my_off];
+        my_input_accessor.push_back(output_accessor[my_off]);
+        my_input_grad_accessor.push_back(output_grad_accessor[my_off]);
         assert(my_input_grad_accessor[i].domain == my_input_accessor[i].domain);
       } else {
         assert(false);
       }
     }
     for (int i = 0; i < fused->op_num_weights[op]; i++) {
       assert(fused->op_weight_source[i + woff] == SOURCE_WEIGHT);
-      my_weight_accessor[i] = weight_accessor[fused->op_weight_idx[i + woff]];
-      my_weight_grad_accessor[i] =
-          weight_grad_accessor[fused->op_weight_idx[i + woff]];
+      my_weight_accessor.push_back(
+          weight_accessor[fused->op_weight_idx[i + woff]]);
+      my_weight_grad_accessor.push_back(
+          weight_grad_accessor[fused->op_weight_idx[i + woff]]);
       assert(my_weight_grad_accessor[i].domain.get_volume() ==
              my_weight_accessor[i].domain.get_volume());
     }
     for (int i = 0; i < fused->op_num_outputs[op]; i++) {
       assert(fused->op_output_source[i + ooff] == SOURCE_OUTPUT);
       int my_off = fused->op_output_idx[i + ooff];
-      my_output_accessor[i] = output_accessor[my_off];
-      my_output_grad_accessor[i] = output_grad_accessor[my_off];
+      my_output_accessor.push_back(output_accessor[my_off]);
+      my_output_grad_accessor.push_back(output_grad_accessor[my_off]);
       assert(my_output_grad_accessor[i].domain == my_output_accessor[i].domain);
     }
     switch (fused->op_op_type[op]) {
@@ -1880,7 +1881,7 @@ __host__ void FusedOp::backward_task(Task const *task,
         int num_inputs = fused->op_num_inputs[op];
         Kernels::Concat::backward_kernel_wrapper(m,
                                                  my_output_grad_accessor[0],
-                                                 my_input_grad_accessor,
+                                                 my_input_grad_accessor.data(),
                                                  num_inputs,
                                                  m->legion_axis);
         break;
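
A second detail worth noting: the my_* vectors are declared inside the per-operator loop rather than once at function scope. With the push_back style this is load-bearing, since a loop-wide vector would keep growing across iterations and my_input_accessor[i] would then refer to a stale entry from an earlier operator; and because my_input_grad_accessor is now a vector rather than an array, the Concat backward call obtains its raw pointer via .data(). A toy sketch of the scoping point (illustrative code, not FlexFlow's):

// Toy sketch (not FlexFlow code) of why the per-operator vectors are
// declared inside the loop: a vector declared once outside would keep
// growing across iterations, so element [i] would alias an earlier
// operator's entry instead of this operator's.
#include <cassert>
#include <vector>

int main() {
  for (int op = 0; op < 3; op++) {
    std::vector<int> my_inputs; // fresh and empty for every operator
    for (int i = 0; i < 2; i++) {
      my_inputs.push_back(op * 10 + i);
    }
    // Index [i] now refers to exactly what this iteration pushed.
    assert(my_inputs.size() == 2);
    assert(my_inputs[0] == op * 10);
  }
  return 0;
}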