Commit f65044d
customized kernel for broadcasting add.
xinhaoc committed Nov 3, 2023
1 parent 17a1c4e commit f65044d
Showing 4 changed files with 40 additions and 4 deletions.
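
In short, this commit threads a batch_size field from ElementBinary through ElementBinaryMeta, derives a replicate_size for the broadcast operand at init time, and routes OP_EW_ADD with broadcast_input2 set to a dedicated HIP kernel instead of miopenOpTensor. As a reading aid, here is a minimal CPU sketch of a broadcast add in which one replicate_size-length slice of in2 is reused for every batch; the helper name and shapes are illustrative, not part of the commit:

#include <cstddef>
#include <vector>

// out[b * R + j] = in1[b * R + j] + in2[j]  for b in [0, B), j in [0, R)
std::vector<float> broadcast_add_reference(std::vector<float> const &in1,
                                           std::vector<float> const &in2,
                                           size_t B,
                                           size_t R) {
  std::vector<float> out(B * R);
  for (size_t b = 0; b < B; b++) {
    for (size_t j = 0; j < R; j++) {
      out[b * R + j] = in1[b * R + j] + in2[j];
    }
  }
  return out;
}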
1 change: 1 addition & 0 deletions include/flexflow/ops/element_binary.h
@@ -58,6 +58,7 @@ class ElementBinary : public Op {
public:
  bool inplace_a, has_same_operands;
  bool broadcast_input1, broadcast_input2;
  int batch_size;
};

}; // namespace FlexFlow
2 changes: 2 additions & 0 deletions include/flexflow/ops/kernels/element_binary_kernels.h
@@ -22,6 +22,8 @@ class ElementBinaryMeta : public OpMeta {
  OperatorType op_type;
  bool inplace_a, has_same_operands;
  bool broadcast_input1, broadcast_input2;
  int batch_size;
  size_t replicate_size;
  char op_name[MAX_OPNAME];
};

9 changes: 9 additions & 0 deletions src/ops/element_binary.cc
@@ -211,6 +211,9 @@ ElementBinary::ElementBinary(FFModel &model,
      numdim, dims, in1->data_type, this);
  broadcast_input1 = (inputs[0]->get_volume() != outputs[0]->get_volume());
  broadcast_input2 = (inputs[1]->get_volume() != outputs[0]->get_volume());

  batch_size = dims[numdim - 2].size;

}

ElementBinary::ElementBinary(
@@ -337,6 +340,8 @@ OpMeta *ElementBinary::init_task(Task const *task,
  m->has_same_operands = eb->has_same_operands;
  m->broadcast_input1 = eb->broadcast_input1;
  m->broadcast_input2 = eb->broadcast_input2;
  m->batch_size = eb->batch_size;

  std::strcpy(m->op_name, eb->name);
  Domain input1_domain = runtime->get_index_space_domain(
      ctx, task->regions[0].region.get_index_space());
@@ -368,6 +373,10 @@ OpMeta *ElementBinary::init_task(Task const *task,
  } else {
    output_domain = input1_domain;
  }
  m->replicate_size = m->broadcast_input1
                          ? (input1_domain.get_volume() / m->batch_size)
                          : (input2_domain.get_volume() / m->batch_size);

  assert(task->regions.size() == regions.size());
  assert(regions.size() == num_regions);
  init_kernel(m, input1_domain, input2_domain, output_domain);
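
To make the two new fields concrete: batch_size comes from the output's second-from-last dimension (dims[numdim - 2].size), and replicate_size is the broadcast operand's volume divided by batch_size, so batch_size * replicate_size recovers that operand's volume (assuming the division is exact). A worked example with made-up sizes, not taken from the commit:

#include <cassert>
#include <cstddef>

int main() {
  size_t batch_size = 8;            // stands in for dims[numdim - 2].size
  size_t input2_volume = 8 * 4096;  // hypothetical broadcast operand volume
  size_t replicate_size = input2_volume / batch_size; // 4096
  // forward_kernel below launches batch_size * replicate_size work items
  assert(batch_size * replicate_size == input2_volume);
  return 0;
}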
32 changes: 28 additions & 4 deletions src/ops/kernels/element_binary_kernels.cpp
@@ -72,15 +72,12 @@ void forward_kernel_wrapper(ElementBinaryMeta const *m,
                            float *out_ptr) {
  hipStream_t stream;
  checkCUDA(get_legion_stream(&stream));

  hipEvent_t t_start, t_end;
  if (m->profiling) {
    hipEventCreate(&t_start);
    hipEventCreate(&t_end);
    hipEventRecord(t_start, stream);
  }
  // print_tensor<float>(in1_ptr, in1_domain.get_volume(), "input1:");
  // print_tensor<float>(in2_ptr, in2_domain.get_volume(), "input2:");
  Internal::forward_kernel(m, in1_ptr, in2_ptr, out_ptr, stream);
  // print_tensor<float>(out_ptr, in1_domain.get_volume(), "output:");
  if (m->profiling) {
@@ -199,6 +196,21 @@ __global__ void elewise_binary_forward_kernel(coord_t volume,
  }
}

// for simplicity, assume the replicated dimension is the batch dimension
__global__ void
    elewise_binary_forward_kernel_broadcast2(float const *in1_ptr,
                                             float const *in2_ptr,
                                             float *output_ptr,
                                             size_t volume,
                                             size_t batch_size,
                                             size_t replicate_size) {
  CUDA_KERNEL_LOOP(i, volume) {
    size_t batch = i / replicate_size;
    output_ptr[i] =
        in1_ptr[i] + in2_ptr[batch * replicate_size + i % replicate_size];
  }
}

__global__ void elewise_binary_backward_kernel(coord_t volume,
                                               float const alpha,
                                               float const beta,
@@ -245,7 +257,6 @@ void forward_kernel(ElementBinaryMeta const *m,
                    hipStream_t stream) {
  checkCUDA(hipblasSetStream(m->handle.blas, stream));
  checkCUDNN(miopenSetStream(m->handle.dnn, stream));

  float alpha1 = 1.0f, alpha2 = 1.0f, beta = 0.0f;
  switch (m->op_type) {
    case OP_EW_SUB:
@@ -284,6 +295,19 @@ void forward_kernel(ElementBinaryMeta const *m,
                              &alpha1,
                              m->outputTensor,
                              out_ptr));
  } else if (m->op_type == OP_EW_ADD && m->broadcast_input2) {
    int parallelism = m->batch_size * m->replicate_size;
    hipLaunchKernelGGL(elewise_binary_forward_kernel_broadcast2,
                       GET_BLOCKS(parallelism),
                       CUDA_NUM_THREADS,
                       0,
                       stream,
                       in1_ptr,
                       in2_ptr,
                       out_ptr,
                       m->batch_size * m->replicate_size,
                       m->batch_size,
                       m->replicate_size);
  } else {
    checkCUDNN(miopenOpTensor(m->handle.dnn,
                              m->opDesc,
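
One detail worth flagging in the new kernel: since batch = i / replicate_size, the subscript batch * replicate_size + i % replicate_size is algebraically equal to i, so as committed the kernel reads in2_ptr[i] over the first batch_size * replicate_size elements. A broadcast that reuses a single replicate_size-length slice of in2 across all batches would index in2_ptr[i % replicate_size] instead. A quick host-side check of the identity, illustrative and not part of the commit:

#include <cassert>
#include <cstddef>

int main() {
  size_t const replicate_size = 4;
  for (size_t i = 0; i < 3 * replicate_size; i++) {
    size_t batch = i / replicate_size;
    // The kernel's subscript just recombines the decomposition of i:
    assert(batch * replicate_size + i % replicate_size == i);
  }
  return 0;
}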
