Skip to content

Commit

Permalink
fix backward gelu, layernorm (#1187)
Browse files Browse the repository at this point in the history
  • Loading branch information
xinhaoc authored Oct 10, 2023
1 parent 50ff264 commit 5e34846
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 4 deletions.
5 changes: 3 additions & 2 deletions src/ops/element_unary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,9 @@ __global__ void elewise_unary_backward_kernel(coord_t volume,
case OP_GELU: {
input_grad[i] =
(T)(output_grad[i] *
(0.5 * static_cast<T>(erfc(-input[i] * M_SQRT1_2)) -
0.5 * M_SQRT1_2 * input[i] * exp(-input[i] * input[i] * 0.5)));
(0.5 * static_cast<T>(erfc(-input[i] * M_SQRT1_2)) +
0.5 * M_SQRT1_2 * input[i] *
((2 / sqrt(M_PI)) * exp(-input[i] * input[i] * 0.5))));
break;
}
case OP_RSQRT: {
Expand Down
5 changes: 3 additions & 2 deletions src/ops/element_unary.cu
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,9 @@ __global__ void elewise_unary_backward_kernel(coord_t volume,
case OP_GELU: {
input_grad[i] =
(T)(output_grad[i] *
(0.5 * static_cast<T>(erfc(-input[i] * M_SQRT1_2)) -
0.5 * M_SQRT1_2 * input[i] * exp(-input[i] * input[i] * 0.5)));
(0.5 * static_cast<T>(erfc(-input[i] * M_SQRT1_2)) +
0.5 * M_SQRT1_2 * input[i] *
((2 / sqrt(M_PI)) * exp(-input[i] * input[i] * 0.5))));
break;
}
case OP_RSQRT: {
Expand Down
94 changes: 94 additions & 0 deletions src/ops/layer_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,82 @@ __global__ void GammaBetaBackwardCUDAKernel(int64_t M,
}
}

template <typename T>
__device__ __inline__ void compute_gI(T const *__restrict__ dY,
T const *__restrict__ X,
T const *__restrict__ mean,
T const *__restrict__ rstd,
T const *__restrict__ gamma,
T *dX,
int const N,
T *buf) {
auto const i1 = blockIdx.x;
const T mean_val = mean[i1];
const T rstd_val = rstd[i1];
T stats_x1{0}, stats_x2{0};
constexpr int unroll = 4;
auto l = unroll * threadIdx.x;
T const *X_i = X + i1 * N;
T const *dY_i = dY + i1 * N;
T *dX_i = dX + i1 * N;
// vectorized reads don't improve perf, so use regular unrolling

for (; l + unroll - 1 < N; l += blockDim.x * unroll) {
#pragma unroll
for (int k = 0; k < unroll; k++) {
T gamma_val = (gamma != nullptr) ? static_cast<T>(gamma[l + k]) : T(1);
const T c_h = static_cast<T>(X_i[l + k]);
const T c_loss = static_cast<T>(dY_i[l + k]);
stats_x1 += c_loss * gamma_val;
stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val;
}
}
for (; l < N; l++) {
T gamma_val = (gamma != nullptr) ? static_cast<T>(gamma[l]) : T(1);
const T c_h = static_cast<T>(X_i[l]);
const T c_loss = static_cast<T>(dY_i[l]);
stats_x1 += c_loss * gamma_val;
stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val;
}

stats_x1 = BlockReduceSum(stats_x1, buf);
stats_x2 = BlockReduceSum(stats_x2, buf);
if (threadIdx.x == 0) {
buf[0] = stats_x1;
buf[1] = stats_x2;
}
__syncthreads();
stats_x1 = buf[0];
stats_x2 = buf[1];
T fH = N;
T term1 = (T(1) / fH) * rstd_val;

for (int l = threadIdx.x; l < N; l += blockDim.x) {
const T x = X_i[l];
const T dy = dY_i[l];
T gamma_val = (gamma != nullptr) ? static_cast<T>(gamma[l]) : T(1);
T f_grad_input = fH * gamma_val * dy;
f_grad_input -= (x - mean_val) * rstd_val * stats_x2;
f_grad_input -= stats_x1;
f_grad_input *= term1;
dX_i[l] = f_grad_input;
}
}

template <typename T>
__global__ void layer_norm_grad_input_kernel(T const *__restrict__ dY,
T const *__restrict__ X,
T const *__restrict__ mean,
T const *__restrict__ rstd,
T const *__restrict__ gamma,
T *dX,
int const N) {
alignas(sizeof(double)) extern __shared__ char s_data1[];
T *buf = reinterpret_cast<T *>(&s_data1);

compute_gI(dY, X, mean, rstd, gamma, dX, N, buf);
}

/*static*/
template <typename T>
void LayerNorm::backward_kernel(LayerNormMeta const *m,
Expand Down Expand Up @@ -401,6 +477,24 @@ void LayerNorm::backward_kernel(LayerNormMeta const *m,
static_cast<T *>(m->db_ptr),
static_cast<T *>(m->scale_ptr),
static_cast<T *>(m->bias_ptr));

int const warp_size = C10_WARP_SIZE;
int const num_threads = 128;
const dim3 blocks(M);
int nshared = (num_threads / warp_size) * sizeof(T);

hipLaunchKernelGGL(HIP_KERNEL_NAME(layer_norm_grad_input_kernel),
blocks,
num_threads,
nshared,
stream,
output_grad_ptr,
input_ptr,
static_cast<T *>(m->mean_ptr),
static_cast<T *>(m->rstd_ptr),
gamma_ptr,
input_grad_ptr,
N);
if (gamma_grad_ptr != NULL || beta_grad_ptr != NULL) {
if (M < 512) {
// For small batch size, do colwise reduce directly
Expand Down
89 changes: 89 additions & 0 deletions src/ops/layer_norm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,82 @@ __global__ void GammaBetaBackwardCUDAKernel(int64_t M,
}
}

template <typename T>
__device__ __inline__ void compute_gI(T const *__restrict__ dY,
T const *__restrict__ X,
T const *__restrict__ mean,
T const *__restrict__ rstd,
T const *__restrict__ gamma,
T *dX,
int const N,
T *buf) {
auto const i1 = blockIdx.x;
const T mean_val = mean[i1];
const T rstd_val = rstd[i1];
T stats_x1{0}, stats_x2{0};
constexpr int unroll = 4;
auto l = unroll * threadIdx.x;
T const *X_i = X + i1 * N;
T const *dY_i = dY + i1 * N;
T *dX_i = dX + i1 * N;
// vectorized reads don't improve perf, so use regular unrolling

for (; l + unroll - 1 < N; l += blockDim.x * unroll) {
#pragma unroll
for (int k = 0; k < unroll; k++) {
T gamma_val = (gamma != nullptr) ? static_cast<T>(gamma[l + k]) : T(1);
const T c_h = static_cast<T>(X_i[l + k]);
const T c_loss = static_cast<T>(dY_i[l + k]);
stats_x1 += c_loss * gamma_val;
stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val;
}
}
for (; l < N; l++) {
T gamma_val = (gamma != nullptr) ? static_cast<T>(gamma[l]) : T(1);
const T c_h = static_cast<T>(X_i[l]);
const T c_loss = static_cast<T>(dY_i[l]);
stats_x1 += c_loss * gamma_val;
stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val;
}

stats_x1 = BlockReduceSum(stats_x1, buf);
stats_x2 = BlockReduceSum(stats_x2, buf);
if (threadIdx.x == 0) {
buf[0] = stats_x1;
buf[1] = stats_x2;
}
__syncthreads();
stats_x1 = buf[0];
stats_x2 = buf[1];
T fH = N;
T term1 = (T(1) / fH) * rstd_val;

for (int l = threadIdx.x; l < N; l += blockDim.x) {
const T x = X_i[l];
const T dy = dY_i[l];
T gamma_val = (gamma != nullptr) ? static_cast<T>(gamma[l]) : T(1);
T f_grad_input = fH * gamma_val * dy;
f_grad_input -= (x - mean_val) * rstd_val * stats_x2;
f_grad_input -= stats_x1;
f_grad_input *= term1;
dX_i[l] = f_grad_input;
}
}

template <typename T>
__global__ void layer_norm_grad_input_kernel(T const *__restrict__ dY,
T const *__restrict__ X,
T const *__restrict__ mean,
T const *__restrict__ rstd,
T const *__restrict__ gamma,
T *dX,
int const N) {
alignas(sizeof(double)) extern __shared__ char s_data1[];
T *buf = reinterpret_cast<T *>(&s_data1);

compute_gI(dY, X, mean, rstd, gamma, dX, N, buf);
}

/*static*/
template <typename T>
void LayerNorm::backward_kernel(LayerNormMeta const *m,
Expand Down Expand Up @@ -497,6 +573,19 @@ void LayerNorm::backward_kernel(LayerNormMeta const *m,
static_cast<T *>(m->db_ptr),
static_cast<T *>(m->scale_ptr),
static_cast<T *>(m->bias_ptr));
int const warp_size = C10_WARP_SIZE;
int const num_threads = 128;
const dim3 blocks(M);
int nshared = (num_threads / warp_size) * sizeof(T);
layer_norm_grad_input_kernel<<<blocks, num_threads, nshared, stream>>>(
output_grad_ptr,
input_ptr,
static_cast<T *>(m->mean_ptr),
static_cast<T *>(m->rstd_ptr),
gamma_ptr,
input_grad_ptr,
N);

if (gamma_grad_ptr != NULL || beta_grad_ptr != NULL) {
if (M < 512) {
// For small batch size, do colwise reduce directly
Expand Down

0 comments on commit 5e34846

Please sign in to comment.