Skip to content

Commit

Permalink
dropout
Browse files Browse the repository at this point in the history
  • Loading branch information
xinhaoc committed Dec 20, 2023
1 parent f65044d commit bcab56a
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 48 deletions.
7 changes: 7 additions & 0 deletions include/flexflow/ops/dropout.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@
#include "flexflow/node.h"
#include "flexflow/operator.h"
#include "flexflow/ops/dropout_params.h"
#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
#include <curand.h>
#include <curand_kernel.h>
#elif defined(FF_USE_HIP_ROCM)
#include <hiprand/hiprand.h>
#include <hiprand/hiprand_kernel.h>
#endif

namespace FlexFlow {

Expand Down
16 changes: 12 additions & 4 deletions include/flexflow/ops/kernels/dropout_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "flexflow/fftype.h"
#include "flexflow/op_meta.h"
#include "flexflow/ops/dropout.h"
#include "flexflow/accessor.h"

namespace FlexFlow {

Expand All @@ -17,33 +18,40 @@ class DropoutMeta : public OpMeta {
~DropoutMeta(void);
Realm::RegionInstance reserveInst;
#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
curandState *state;
cudnnTensorDescriptor_t inputTensor, outputTensor;
cudnnDropoutDescriptor_t dropoutDesc;
#else
miopenTensorDescriptor_t inputTensor, outputTensor;
miopenDropoutDescriptor_t dropoutDesc;
hiprandState *state;
#endif
void *reserveSpace, *dropoutStates;
size_t reserveSpaceSize, dropoutStateSize;
size_t num_elements;
long long seed;
float rate;
};

namespace Kernels {
namespace Dropout {
void forward_kernel_wrapper(DropoutMeta *m,
float const *input_ptr,
float *output_ptr);
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output);
void backward_kernel_wrapper(DropoutMeta *m,
float const *output_grad_ptr,
float *input_grad_ptr);
GenericTensorAccessorR const &output_grad,
GenericTensorAccessorW const &input_grad);

namespace Internal {
void forward_kernel(DropoutMeta *m,
float const *input_ptr,
float *output_ptr,
size_t num_elements,
ffStream_t stream);
void backward_kernel(DropoutMeta *m,
float const *output_grad_ptr,
float *input_grad_ptr,
size_t num_elements,
ffStream_t stream);
} // namespace Internal
} // namespace Dropout
Expand Down
38 changes: 25 additions & 13 deletions src/ops/dropout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ using PCG::Node;

using namespace FlexFlow::Kernels::Dropout;

Tensor FFModel::dropout(const Tensor input,
Tensor FFModel::dropout(Tensor const input,
float rate,
unsigned long long seed,
char const *name) {
Expand Down Expand Up @@ -86,7 +86,7 @@ bool operator==(DropoutParams const &lhs, DropoutParams const &rhs) {
}

Dropout::Dropout(FFModel &model,
const ParallelTensor _input,
ParallelTensor const _input,
float _rate,
unsigned long long _seed,
char const *name)
Expand All @@ -111,12 +111,12 @@ Dropout::Dropout(FFModel &model,

Dropout::Dropout(FFModel &model,
Dropout const &other,
const ParallelTensor input)
ParallelTensor const input)
: Dropout(model, input, other.rate, other.seed, other.name) {}

Dropout::Dropout(FFModel &model,
DropoutParams const &params,
const ParallelTensor input,
ParallelTensor const input,
char const *name)
: Dropout(model, input, params.rate, params.seed, name) {}

Expand Down Expand Up @@ -210,12 +210,12 @@ void Dropout::forward_task(Task const *task,
assert(task->regions.size() == 2);
// const Dropout* dropout = (const Dropout*) task->args;
DropoutMeta *m = *((DropoutMeta **)task->local_args);
float const *input_ptr = helperGetTensorPointerRO<float>(
regions[0], task->regions[0], FID_DATA, ctx, runtime);
float *output_ptr = helperGetTensorPointerWO<float>(
regions[1], task->regions[1], FID_DATA, ctx, runtime);

forward_kernel_wrapper(m, input_ptr, output_ptr);

GenericTensorAccessorR input = helperGetGenericTensorAccessorRO(
m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime);
GenericTensorAccessorW output = helperGetGenericTensorAccessorWO(
m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime);
forward_kernel_wrapper(m, input, output);
}

void Dropout::backward(FFModel const &ff) {
Expand Down Expand Up @@ -264,7 +264,13 @@ void Dropout::backward_task(Task const *task,
float const *output_grad_ptr = helperGetTensorPointerRO<float>(
regions[1], task->regions[1], FID_DATA, ctx, runtime);

backward_kernel_wrapper(m, output_grad_ptr, input_grad_ptr);

GenericTensorAccessorW input_grad = helperGetGenericTensorAccessorRW(
m->output_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime);
GenericTensorAccessorR output_grad = helperGetGenericTensorAccessorRO(
m->input_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime);

backward_kernel_wrapper(m, output_grad, input_grad);
}

void Dropout::serialize(Legion::Serializer &sez) const {
Expand Down Expand Up @@ -304,30 +310,36 @@ bool Dropout::measure_operator_cost(Simulator *sim,
sim->free_all();
float *input_ptr = (float *)sim->allocate(sub_input.get_volume(), DT_FLOAT);
assert(input_ptr != NULL);

GenericTensorAccessorR input_acc(m->input_type[0], sub_input.get_domain(), input_ptr);
cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset);

float *output_ptr = (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT);
assert(output_ptr != NULL);

GenericTensorAccessorW output_acc(m->output_type[0], sub_input.get_domain(), output_ptr);
cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset);

assert(m->profiling == false);

std::function<void()> forward, backward;
forward = [&] { forward_kernel_wrapper(m, input_ptr, output_ptr); };
forward = [&] { forward_kernel_wrapper(m, input_acc, output_acc); };
if (sim->computationMode == COMP_MODE_TRAINING) {
float *input_grad_ptr =
(float *)sim->allocate(sub_input.get_volume(), DT_FLOAT);
assert(input_grad_ptr != NULL);
GenericTensorAccessorW input_grad_acc(m->output_type[0], sub_input.get_domain(), input_grad_ptr);
cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset);

float *output_grad_ptr =
(float *)sim->allocate(sub_output.get_volume(), DT_FLOAT);
assert(output_grad_ptr != NULL);
GenericTensorAccessorR output_grad_acc(m->output_type[0], sub_input.get_domain(), output_grad_ptr);
cost_metrics.outputs_memory +=
cost_metrics.total_mem_diff_from(sim->offset);

backward = [&] {
backward_kernel_wrapper(m, output_grad_ptr, input_grad_ptr);
backward_kernel_wrapper(m, output_grad_acc, input_grad_acc);
};
}

Expand Down
126 changes: 101 additions & 25 deletions src/ops/kernels/dropout_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ DropoutMeta::DropoutMeta(FFHandler handler,
Domain const &output_domain)
: OpMeta(handler) {
profiling = dropout->profiling;
rate = dropout->rate;
seed = dropout->seed;
input_type[0] = dropout->data_type;
output_type[0] = dropout->data_type;

checkCUDNN(miopenCreateTensorDescriptor(&inputTensor));
checkCUDNN(miopenCreateTensorDescriptor(&outputTensor));
checkCUDNN(miopenCreateDropoutDescriptor(&dropoutDesc));
Expand Down Expand Up @@ -78,56 +83,127 @@ DropoutMeta::~DropoutMeta(void) {
namespace Kernels {
namespace Dropout {

/* Elementwise inverted-dropout forward pass.
 *
 * p:            drop probability; each element is zeroed with probability p
 *               and survivors are scaled by 1/(1-p) so the expected
 *               activation is unchanged at train time.
 * seed:         RNG seed; combined with the element index so each element
 *               gets an independent, reproducible draw (backward reuses the
 *               same (seed, i) pair to reconstruct the identical mask).
 * num_elements: total element count of the flattened tensor.
 */
__global__ void dropout_forward_kernel(float p,
                                       long long seed,
                                       size_t num_elements,
                                       float const *input_ptr,
                                       float *output_ptr) {
  CUDA_KERNEL_LOOP(i, num_elements) {
    // Inverted dropout: survivors are scaled by 1/(keep probability).
    // The original used 1/p, which mis-scales the kept activations.
    float scale = 1.0f / (1.0f - p);
    // Philox counter-based RNG: (seed, subsequence=i) yields a per-element
    // stream with no RNG state stored in global memory.
    hiprandStatePhilox4_32_10_t state;
    hiprand_init(seed, i, 0, &state);
    float rand = hiprand_uniform(&state);
    // Drop based on the random draw, NOT the input value: the original
    // compared input_ptr[i] < p and never used `rand`, which deterministically
    // zeroed small activations instead of sampling a dropout mask.
    if (rand < p) {
      output_ptr[i] = 0.0f;
    } else {
      output_ptr[i] = input_ptr[i] * scale;
    }
  }
}

/* Elementwise inverted-dropout backward pass.
 *
 * input_ptr is the incoming output gradient; output_ptr receives the input
 * gradient. Re-derives the forward mask by re-seeding Philox with the same
 * (seed, i) pair used in dropout_forward_kernel: gradients of dropped
 * elements are zeroed, survivors are scaled by the same 1/(1-p) factor.
 */
__global__ void dropout_backward_kernel(float p,
                                        long long seed,
                                        size_t num_elements,
                                        float const *input_ptr,
                                        float *output_ptr) {
  CUDA_KERNEL_LOOP(i, num_elements) {
    // Must match the forward scale 1/(1-p); the original 1/p was wrong.
    float scale = 1.0f / (1.0f - p);
    // Identical (seed, i) seeding as the forward kernel reproduces the
    // exact same uniform draw, hence the same dropout mask.
    hiprandStatePhilox4_32_10_t state;
    hiprand_init(seed, i, 0, &state);
    float rand = hiprand_uniform(&state);
    // Gate on the random draw, not the gradient's magnitude (the original
    // compared input_ptr[i] < p and ignored `rand`).
    if (rand < p) {
      output_ptr[i] = 0.0f;
    } else {
      output_ptr[i] = input_ptr[i] * scale;
    }
  }
}

void forward_kernel_wrapper(DropoutMeta *m,
float const *input_ptr,
float *output_ptr) {
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output) {
hipStream_t stream;
checkCUDA(get_legion_stream(&stream));
Internal::forward_kernel(m, input_ptr, output_ptr, stream);

Internal::forward_kernel(m,
input.get_float_ptr(),
output.get_float_ptr(),
input.domain.get_volume(),
stream);

// printf("dropout %d\n", input.domain.get_volume());
// assert(false);
}

void backward_kernel_wrapper(DropoutMeta *m,
float const *output_grad_ptr,
float *input_grad_ptr) {
GenericTensorAccessorR const &output_grad,
GenericTensorAccessorW const &input_grad) {
hipStream_t stream;
checkCUDA(get_legion_stream(&stream));
Internal::backward_kernel(m, output_grad_ptr, input_grad_ptr, stream);
Internal::backward_kernel(m,
output_grad.get_float_ptr(),
input_grad.get_float_ptr(),
output_grad.domain.get_volume(),
stream);
}

namespace Internal {

/* Launches the custom dropout forward kernel on the given HIP stream.
 *
 * The kernel's parameter order is (p, seed, num_elements, in, out): the
 * drop rate goes first and the seed second. The original launch passed
 * m->seed into the `p` parameter and m->rate into `seed`, silently
 * corrupting the dropout probability — fixed below.
 */
void forward_kernel(DropoutMeta *m,
                    float const *input_ptr,
                    float *output_ptr,
                    size_t num_elements,
                    hipStream_t stream) {
  checkCUDNN(miopenSetStream(m->handle.dnn, stream));
  int parallelism = num_elements;
  hipLaunchKernelGGL(HIP_KERNEL_NAME(dropout_forward_kernel),
                     GET_BLOCKS(parallelism),
                     min(CUDA_NUM_THREADS, parallelism),
                     0,
                     stream,
                     m->rate,
                     m->seed,
                     num_elements,
                     input_ptr,
                     output_ptr);
  // The previous miopenDropoutForward path was removed in favor of the
  // custom kernel above so the CUDA and ROCm builds share one RNG scheme.
}

/* Launches the custom dropout backward kernel on the given HIP stream.
 *
 * Same parameter-order fix as forward_kernel: the kernel expects
 * (p, seed, num_elements, grad_out, grad_in), so m->rate must precede
 * m->seed; the original launch had them swapped.
 */
void backward_kernel(DropoutMeta *m,
                     float const *output_grad_ptr,
                     float *input_grad_ptr,
                     size_t num_elements,
                     hipStream_t stream) {
  checkCUDNN(miopenSetStream(m->handle.dnn, stream));
  int parallelism = num_elements;
  hipLaunchKernelGGL(HIP_KERNEL_NAME(dropout_backward_kernel),
                     GET_BLOCKS(parallelism),
                     min(CUDA_NUM_THREADS, parallelism),
                     0,
                     stream,
                     m->rate,
                     m->seed,
                     num_elements,
                     output_grad_ptr,
                     input_grad_ptr);
  // The previous miopenDropoutBackward path was removed in favor of the
  // custom kernel above so the CUDA and ROCm builds share one RNG scheme.
}

} // namespace Internal
Expand Down
12 changes: 6 additions & 6 deletions src/ops/kernels/dropout_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -75,19 +75,19 @@ namespace Kernels {
namespace Dropout {

void forward_kernel_wrapper(DropoutMeta *m,
float const *input_ptr,
float *output_ptr) {
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output) {
cudaStream_t stream;
checkCUDA(get_legion_stream(&stream));
Internal::forward_kernel(m, input_ptr, output_ptr, stream);
Internal::forward_kernel(m, input.get_float_ptr(), output.get_float_ptr(), stream);
}

void backward_kernel_wrapper(DropoutMeta *m,
float const *output_grad_ptr,
float *input_grad_ptr) {
GenericTensorAccessorR const &output_grad,
GenericTensorAccessorW const &input_grad) {
cudaStream_t stream;
checkCUDA(get_legion_stream(&stream));
Internal::backward_kernel(m, output_grad_ptr, input_grad_ptr, stream);
Internal::backward_kernel(m, output_grad.get_float_ptr(), input_grad.get_float_ptr(), stream);
}

namespace Internal {
Expand Down

0 comments on commit bcab56a

Please sign in to comment.