Layer norm kernels
reyna-abhyankar committed Oct 9, 2023
1 parent e09080c commit e1fb8a8
Showing 3 changed files with 160 additions and 139 deletions.
35 changes: 16 additions & 19 deletions lib/kernels/include/kernels/layer_norm_kernels.h
@@ -5,42 +5,39 @@

namespace FlexFlow {

class LayerNormPerDeviceState : public PerDeviceOpState {
public:
LayerNormPerDeviceState(FFHandler handle,
bool elementwise_affine_,
int64_t effective_batch_size_,
int64_t effective_num_elements_,
bool profiling_,
float eps_);

public:
bool elementwise_affine;
int64_t effective_batch_size, effective_num_elements;
float eps;
struct LayerNormPerDeviceState {
float *mean, *rstd, *ds, *db, *scale, *bias;
char op_name[MAX_OPNAME];
DataType data_type;
};

namespace Kernels {
namespace LayerNorm {

LayerNormPerDeviceState init_kernel(PerDeviceFFHandle handle,
int64_t batch_size);

void forward_kernel(ffStream_t stream,
LayerNormPerDeviceState const *m,
LayerNormPerDeviceState const &m,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output,
GenericTensorAccessorW const &gamma,
GenericTensorAccessorW const &beta);
GenericTensorAccessorW const &beta,
DataType data_type,
int64_t batch_size,
int64_t num_elements,
float eps);

void backward_kernel(ffStream_t stream,
LayerNormPerDeviceState const *m,
LayerNormPerDeviceState const &m,
GenericTensorAccessorR const &output_grad,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &input_grad,
GenericTensorAccessorR const &gamma,
GenericTensorAccessorW const &gamma_grad,
GenericTensorAccessorW const &beta_grad);
GenericTensorAccessorW const &beta_grad,
DataType data_type,
int64_t batch_size,
int64_t num_elements,
float eps);

} // namespace LayerNorm
} // namespace Kernels
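
For orientation, a minimal sketch of how the refactored interface above might be driven from host code. Only init_kernel, forward_kernel, and their parameter lists come from this header; the wrapper function, the stream setup, the accessor arguments, and the DT_FLOAT enumerator are assumptions for illustration, not part of this commit.

#include "kernels/layer_norm_kernels.h"

using namespace FlexFlow;

// Hypothetical call site (a sketch under the assumptions noted above).
void layer_norm_forward_example(PerDeviceFFHandle handle,
                                GenericTensorAccessorR const &input,
                                GenericTensorAccessorW const &output,
                                GenericTensorAccessorW const &gamma,
                                GenericTensorAccessorW const &beta,
                                int64_t batch_size,
                                int64_t num_elements) {
  cudaStream_t stream;
  checkCUDA(cudaStreamCreate(&stream));

  // init_kernel returns the per-device state by value after allocating the
  // six device buffers with cudaMalloc.
  LayerNormPerDeviceState state =
      Kernels::LayerNorm::init_kernel(handle, batch_size);

  // data_type, batch_size, num_elements, and eps are now passed explicitly
  // rather than read from the per-device state (DT_FLOAT is assumed here).
  Kernels::LayerNorm::forward_kernel(stream,
                                     state,
                                     input,
                                     output,
                                     gamma,
                                     beta,
                                     DT_FLOAT,
                                     batch_size,
                                     num_elements,
                                     1e-5f /*eps*/);

  checkCUDA(cudaStreamSynchronize(stream));
  checkCUDA(cudaStreamDestroy(stream));
}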
134 changes: 70 additions & 64 deletions lib/kernels/src/cuda/layer_norm_kernels.cu
@@ -24,86 +24,73 @@ constexpr int kCUDABlockReduceNumThreads = 512;
constexpr int kCUDANumThreads = 256;
constexpr int kColwiseReduceTileSize = 32;

LayerNormPerDeviceState::LayerNormPerDeviceState(
FFHandler handle,
bool elementwise_affine_,
int64_t effective_batch_size_,
int64_t effective_num_elements_,
bool profiling_,
float eps_)
: PerDeviceOpState(handle) {
elementwise_affine = elementwise_affine_;
effective_batch_size = effective_batch_size_;
effective_num_elements = effective_num_elements_;
profiling = profiling_;
eps = eps_;
checkCUDA(cudaMalloc(&mean_ptr, sizeof(float) * effective_batch_size));
checkCUDA(cudaMalloc(&rstd_ptr, sizeof(float) * effective_batch_size));
checkCUDA(cudaMalloc(&ds_ptr, sizeof(float) * effective_batch_size));
checkCUDA(cudaMalloc(&db_ptr, sizeof(float) * effective_batch_size));
checkCUDA(cudaMalloc(&scale_ptr, sizeof(float) * effective_batch_size));
checkCUDA(cudaMalloc(&bias_ptr, sizeof(float) * effective_batch_size));
}

namespace Kernels {
namespace LayerNorm {

LayerNormPerDeviceState init_kernel(PerDeviceFFHandle handle,
int64_t batch_size) {
float *mean, *rstd, *ds, *db, *scale, *bias;
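// mean/rstd hold the per-row statistics written by the forward pass and
// reused by the backward pass; ds/db and scale/bias are scratch buffers used
// only by the backward pass.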
checkCUDA(cudaMalloc(&mean, sizeof(float) * batch_size));
checkCUDA(cudaMalloc(&rstd, sizeof(float) * batch_size));
checkCUDA(cudaMalloc(&ds, sizeof(float) * batch_size));
checkCUDA(cudaMalloc(&db, sizeof(float) * batch_size));
checkCUDA(cudaMalloc(&scale, sizeof(float) * batch_size));
checkCUDA(cudaMalloc(&bias, sizeof(float) * batch_size));

LayerNormPerDeviceState per_device_state = {mean, rstd, ds, db, scale, bias};
return per_device_state;
}

template <DataType T>
struct ForwardKernel {
void operator()(cudaStream_t stream,
LayerNormPerDeviceState const *m,
LayerNormPerDeviceState const &m,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output,
GenericTensorAccessorW const &gamma,
GenericTensorAccessorW const &beta) {
GenericTensorAccessorW const &beta,
int64_t batch_size,
int64_t num_elements,
float eps) {
RowwiseMomentsCUDAKernel<float>
<<<m->effective_batch_size, kCUDABlockReduceNumThreads, 0, stream>>>(
m->effective_num_elements,
m->eps,
input.get<T>(),
m->mean_ptr,
m->rstd_ptr);
<<<batch_size, kCUDABlockReduceNumThreads, 0, stream>>>(
num_elements, eps, input.get<T>(), m.mean, m.rstd);
LayerNormForwardCUDAKernel<float>
<<<m->effective_batch_size, kCUDANumThreads, 0, stream>>>(
m->effective_num_elements,
input.get<T>(),
m->mean_ptr,
m->rstd_ptr,
gamma.get<T>(),
beta.get<T>(),
output.get<T>());
<<<batch_size, kCUDANumThreads, 0, stream>>>(num_elements,
input.get<T>(),
m.mean,
m.rstd,
gamma.get<T>(),
beta.get<T>(),
output.get<T>());
}
};

template <DataType T>
struct BackwardKernel {
void operator()(cudaStream_t stream,
LayerNormPerDeviceState const *m,
LayerNormPerDeviceState const &m,
GenericTensorAccessorR const &output_grad,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &input_grad,
GenericTensorAccessorR const &gamma,
GenericTensorAccessorW const &gamma_grad,
GenericTensorAccessorW const &beta_grad) {
const int64_t M = m->effective_batch_size;
const int64_t N = m->effective_num_elements;
GenericTensorAccessorW const &beta_grad,
int64_t batch_size,
int64_t num_elements,
float eps) {
const int64_t M = batch_size;
const int64_t N = num_elements;
ComputeInternalGradientsCUDAKernel<T>
<<<M, kCUDABlockReduceNumThreads, 0, stream>>>(N,
output_grad.get<T>(),
input.get<T>(),
gamma.get<T>(),
m->ds_ptr,
m->db_ptr);
m.ds,
m.db);
const int64_t B = (M + kCUDANumThreads - 1) / kCUDANumThreads;
ComputeGradientFusedParamsCUDAKernel<T>
<<<B, kCUDANumThreads, 0, stream>>>(M,
N,
m->mean_ptr,
m->rstd_ptr,
m->ds_ptr,
m->db_ptr,
m->scale_ptr,
m->bias_ptr);
ComputeGradientFusedParamsCUDAKernel<T><<<B, kCUDANumThreads, 0, stream>>>(
M, N, m.mean, m.rstd, m.ds, m.db, m.scale, m.bias);
if (gamma_grad.get<T>() != NULL || beta_grad.get<T>() != NULL) {
if (M < 512) {
// For small batch size, do colwise reduce directly
@@ -113,8 +100,8 @@ struct BackwardKernel {
N,
output_grad.get<T>(),
input.get<T>(),
m->mean_ptr,
m->rstd_ptr,
m.mean,
m.rstd,
gamma_grad.get<T>(),
beta_grad.get<T>());
} else {
@@ -127,8 +114,8 @@
N,
output_grad.get<T>(),
input.get<T>(),
m->mean_ptr,
m->rstd_ptr,
m.mean,
m.rstd,
gamma_grad.get<T>(),
beta_grad.get<T>());
}
@@ -137,32 +124,51 @@
};

void forward_kernel(cudaStream_t stream,
LayerNormPerDeviceState const *m,
LayerNormPerDeviceState const &m,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output,
GenericTensorAccessorW const &gamma,
GenericTensorAccessorW const &beta) {
DataTypeDispatch1<ForwardKernel>{}(
m->data_type, stream, m, input, output, gamma, beta);
GenericTensorAccessorW const &beta,
DataType data_type,
int64_t batch_size,
int64_t num_elements,
float eps) {
DataTypeDispatch1<ForwardKernel>{}(data_type,
stream,
m,
input,
output,
gamma,
beta,
batch_size,
num_elements,
eps);
}

void backward_kernel(cudaStream_t stream,
LayerNormPerDeviceState const *m,
LayerNormPerDeviceState const &m,
GenericTensorAccessorR const &output_grad,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &input_grad,
GenericTensorAccessorR const &gamma,
GenericTensorAccessorW const &gamma_grad,
GenericTensorAccessorW const &beta_grad) {
DataTypeDispatch1<BackwardKernel>{}(m->data_type,
GenericTensorAccessorW const &beta_grad,
DataType data_type,
int64_t batch_size,
int64_t num_elements,
float eps) {
DataTypeDispatch1<BackwardKernel>{}(data_type,
stream,
m,
output_grad,
input,
input_grad,
gamma,
gamma_grad,
beta_grad);
beta_grad,
batch_size,
num_elements,
eps);
}
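
A side note on the dispatch pattern used by these wrappers: DataTypeDispatch1<ForwardKernel>{}(data_type, ...) presumably maps the runtime DataType value onto the compile-time template parameter of ForwardKernel/BackwardKernel and then invokes the functor. A simplified stand-in is sketched below; the enumerator names and the switch are assumptions, not the actual FlexFlow implementation.

#include <cassert>
#include <utility>

// Illustrative stand-in for a DataTypeDispatch1-style dispatcher (assumed
// behavior only). F is a functor template parameterized by a DataType value,
// e.g. ForwardKernel or BackwardKernel above.
template <template <DataType> class F, typename... Args>
void dispatch_on_data_type(DataType data_type, Args &&...args) {
  switch (data_type) {
    case DT_FLOAT: // enumerator name is an assumption
      F<DT_FLOAT>{}(std::forward<Args>(args)...);
      break;
    case DT_DOUBLE: // enumerator name is an assumption
      F<DT_DOUBLE>{}(std::forward<Args>(args)...);
      break;
    default:
      assert(false && "unsupported DataType");
  }
}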

template <typename T>
