From e1fb8a8ff8378701252a30002d0373de19ef2c82 Mon Sep 17 00:00:00 2001
From: Reyna Abhyankar
Date: Sun, 8 Oct 2023 23:27:28 -0700
Subject: [PATCH] Layer norm kernels

---
 .../include/kernels/layer_norm_kernels.h   |  35 +++--
 lib/kernels/src/cuda/layer_norm_kernels.cu | 134 +++++++++---------
 lib/kernels/src/hip/layer_norm_kernels.cpp | 130 +++++++++--------
 3 files changed, 160 insertions(+), 139 deletions(-)

diff --git a/lib/kernels/include/kernels/layer_norm_kernels.h b/lib/kernels/include/kernels/layer_norm_kernels.h
index a49e1b3483..fb49854653 100644
--- a/lib/kernels/include/kernels/layer_norm_kernels.h
+++ b/lib/kernels/include/kernels/layer_norm_kernels.h
@@ -5,42 +5,39 @@ namespace FlexFlow {
 
-class LayerNormPerDeviceState : public PerDeviceOpState {
-public:
-  LayerNormPerDeviceState(FFHandler handle,
-                          bool elementwise_affine_,
-                          int64_t effective_batch_size_,
-                          int64_t effective_num_elements_,
-                          bool profiling_,
-                          float eps_);
-
-public:
-  bool elementwise_affine;
-  int64_t effective_batch_size, effective_num_elements;
-  float eps;
+struct LayerNormPerDeviceState {
   float *mean, *rstd, *ds, *db, *scale, *bias;
-  char op_name[MAX_OPNAME];
-  DataType data_type;
 };
 
 namespace Kernels {
 namespace LayerNorm {
 
+LayerNormPerDeviceState init_kernel(PerDeviceFFHandle handle,
+                                    int64_t batch_size);
+
 void forward_kernel(ffStream_t stream,
-                    LayerNormPerDeviceState const *m,
+                    LayerNormPerDeviceState const &m,
                     GenericTensorAccessorR const &input,
                     GenericTensorAccessorW const &output,
                     GenericTensorAccessorW const &gamma,
-                    GenericTensorAccessorW const &beta);
+                    GenericTensorAccessorW const &beta,
+                    DataType data_type,
+                    int64_t batch_size,
+                    int64_t num_elements,
+                    float eps);
 
 void backward_kernel(ffStream_t stream,
-                     LayerNormPerDeviceState const *m,
+                     LayerNormPerDeviceState const &m,
                      GenericTensorAccessorR const &output_grad,
                      GenericTensorAccessorR const &input,
                      GenericTensorAccessorW const &input_grad,
                      GenericTensorAccessorR const &gamma,
                      GenericTensorAccessorW const &gamma_grad,
-                     GenericTensorAccessorW const &beta_grad);
+                     GenericTensorAccessorW const &beta_grad,
+                     DataType data_type,
+                     int64_t batch_size,
+                     int64_t num_elements,
+                     float eps);
 
 } // namespace LayerNorm
 } // namespace Kernels
diff --git a/lib/kernels/src/cuda/layer_norm_kernels.cu b/lib/kernels/src/cuda/layer_norm_kernels.cu
index 65d33bec5e..eb9e291750 100644
--- a/lib/kernels/src/cuda/layer_norm_kernels.cu
+++ b/lib/kernels/src/cuda/layer_norm_kernels.cu
@@ -24,86 +24,73 @@ constexpr int kCUDABlockReduceNumThreads = 512;
 constexpr int kCUDANumThreads = 256;
 constexpr int kColwiseReduceTileSize = 32;
 
-LayerNormPerDeviceState::LayerNormPerDeviceState(
-    FFHandler handle,
-    bool elementwise_affine_,
-    int64_t effective_batch_size_,
-    int64_t effective_num_elements_,
-    bool profiling_,
-    float eps_)
-    : PerDeviceOpState(handle) {
-  elementwise_affine = elementwise_affine_;
-  effective_batch_size = effective_batch_size_;
-  effective_num_elements = effective_num_elements_;
-  profiling = profiling_;
-  eps = eps_;
-  checkCUDA(cudaMalloc(&mean_ptr, sizeof(float) * effective_batch_size));
-  checkCUDA(cudaMalloc(&rstd_ptr, sizeof(float) * effective_batch_size));
-  checkCUDA(cudaMalloc(&ds_ptr, sizeof(float) * effective_batch_size));
-  checkCUDA(cudaMalloc(&db_ptr, sizeof(float) * effective_batch_size));
-  checkCUDA(cudaMalloc(&scale_ptr, sizeof(float) * effective_batch_size));
-  checkCUDA(cudaMalloc(&bias_ptr, sizeof(float) * effective_batch_size));
-}
-
 namespace Kernels {
 namespace LayerNorm {
 
+LayerNormPerDeviceState init_kernel(PerDeviceFFHandle handle,
+                                    int64_t batch_size) {
+  float *mean, *rstd, *ds, *db, *scale, *bias;
+  checkCUDA(cudaMalloc(&mean, sizeof(float) * batch_size));
+  checkCUDA(cudaMalloc(&rstd, sizeof(float) * batch_size));
+  checkCUDA(cudaMalloc(&ds, sizeof(float) * batch_size));
+  checkCUDA(cudaMalloc(&db, sizeof(float) * batch_size));
+  checkCUDA(cudaMalloc(&scale, sizeof(float) * batch_size));
+  checkCUDA(cudaMalloc(&bias, sizeof(float) * batch_size));
+
+  LayerNormPerDeviceState per_device_state = {mean, rstd, ds, db, scale, bias};
+  return per_device_state;
+}
+
 template <DataType T>
 struct ForwardKernel {
   void operator()(cudaStream_t stream,
-                  LayerNormPerDeviceState const *m,
+                  LayerNormPerDeviceState const &m,
                   GenericTensorAccessorR const &input,
                   GenericTensorAccessorW const &output,
                   GenericTensorAccessorW const &gamma,
-                  GenericTensorAccessorW const &beta) {
+                  GenericTensorAccessorW const &beta,
+                  int64_t batch_size,
+                  int64_t num_elements,
+                  float eps) {
     RowwiseMomentsCUDAKernel<T>
-        <<<m->effective_batch_size, kCUDABlockReduceNumThreads, 0, stream>>>(
-            m->effective_num_elements,
-            m->eps,
-            input.get<T>(),
-            m->mean_ptr,
-            m->rstd_ptr);
+        <<<batch_size, kCUDABlockReduceNumThreads, 0, stream>>>(
+            num_elements, eps, input.get<T>(), m.mean, m.rstd);
     LayerNormForwardCUDAKernel<T>
-        <<<m->effective_batch_size, kCUDANumThreads, 0, stream>>>(
-            m->effective_num_elements,
-            input.get<T>(),
-            m->mean_ptr,
-            m->rstd_ptr,
-            gamma.get<T>(),
-            beta.get<T>(),
-            output.get<T>());
+        <<<batch_size, kCUDANumThreads, 0, stream>>>(num_elements,
+                                                     input.get<T>(),
+                                                     m.mean,
+                                                     m.rstd,
+                                                     gamma.get<T>(),
+                                                     beta.get<T>(),
+                                                     output.get<T>());
   }
 }
 
 template <DataType T>
 struct BackwardKernel {
   void operator()(cudaStream_t stream,
-                  LayerNormPerDeviceState const *m,
+                  LayerNormPerDeviceState const &m,
                   GenericTensorAccessorR const &output_grad,
                   GenericTensorAccessorR const &input,
                   GenericTensorAccessorW const &input_grad,
                   GenericTensorAccessorR const &gamma,
                   GenericTensorAccessorW const &gamma_grad,
-                  GenericTensorAccessorW const &beta_grad) {
-    const int64_t M = m->effective_batch_size;
-    const int64_t N = m->effective_num_elements;
+                  GenericTensorAccessorW const &beta_grad,
+                  int64_t batch_size,
+                  int64_t num_elements,
+                  float eps) {
+    const int64_t M = batch_size;
+    const int64_t N = num_elements;
     ComputeInternalGradientsCUDAKernel<T>
         <<<M, kCUDABlockReduceNumThreads, 0, stream>>>(N,
                                                        output_grad.get<T>(),
                                                        input.get<T>(),
                                                        gamma.get<T>(),
-                                                       m->ds_ptr,
-                                                       m->db_ptr);
+                                                       m.ds,
+                                                       m.db);
     const int64_t B = (M + kCUDANumThreads - 1) / kCUDANumThreads;
-    ComputeGradientFusedParamsCUDAKernel<T>
-        <<<B, kCUDANumThreads, 0, stream>>>(M,
-                                            N,
-                                            m->mean_ptr,
-                                            m->rstd_ptr,
-                                            m->ds_ptr,
-                                            m->db_ptr,
-                                            m->scale_ptr,
-                                            m->bias_ptr);
+    ComputeGradientFusedParamsCUDAKernel<T><<<B, kCUDANumThreads, 0, stream>>>(
+        M, N, m.mean, m.rstd, m.ds, m.db, m.scale, m.bias);
     if (gamma_grad.get<T>() != NULL || beta_grad.get<T>() != NULL) {
       if (M < 512) {
         // For small batch size, do colwise reduce directly
@@ -113,8 +100,8 @@ struct BackwardKernel {
                 N,
                 output_grad.get<T>(),
                 input.get<T>(),
-                m->mean_ptr,
-                m->rstd_ptr,
+                m.mean,
+                m.rstd,
                 gamma_grad.get<T>(),
                 beta_grad.get<T>());
       } else {
@@ -127,8 +114,8 @@ struct BackwardKernel {
                 N,
                 output_grad.get<T>(),
                 input.get<T>(),
-                m->mean_ptr,
-                m->rstd_ptr,
+                m.mean,
+                m.rstd,
                 gamma_grad.get<T>(),
                 beta_grad.get<T>());
       }
@@ -137,24 +124,40 @@ struct BackwardKernel {
 }
 
 void forward_kernel(cudaStream_t stream,
-                    LayerNormPerDeviceState const *m,
+                    LayerNormPerDeviceState const &m,
                     GenericTensorAccessorR const &input,
                     GenericTensorAccessorW const &output,
                     GenericTensorAccessorW const &gamma,
-                    GenericTensorAccessorW const &beta) {
-  DataTypeDispatch1<ForwardKernel>{}(
-      m->data_type, stream, m, input, output, gamma, beta);
+                    GenericTensorAccessorW const &beta,
+                    DataType data_type,
+                    int64_t batch_size,
+                    int64_t num_elements,
+                    float eps) {
+  DataTypeDispatch1<ForwardKernel>{}(data_type,
+                                     stream,
+                                     m,
+                                     input,
+                                     output,
+                                     gamma,
+                                     beta,
+                                     batch_size,
+                                     num_elements,
+                                     eps);
 }
 
 void backward_kernel(cudaStream_t stream,
-                     LayerNormPerDeviceState const *m,
+                     LayerNormPerDeviceState const &m,
                      GenericTensorAccessorR const &output_grad,
                      GenericTensorAccessorR const &input,
                      GenericTensorAccessorW const &input_grad,
                      GenericTensorAccessorR const &gamma,
                      GenericTensorAccessorW const &gamma_grad,
-                     GenericTensorAccessorW const &beta_grad) {
-  DataTypeDispatch1<BackwardKernel>{}(m->data_type,
+                     GenericTensorAccessorW const &beta_grad,
+                     DataType data_type,
+                     int64_t batch_size,
+                     int64_t num_elements,
+                     float eps) {
+  DataTypeDispatch1<BackwardKernel>{}(data_type,
                                       stream,
                                       m,
                                       output_grad,
@@ -162,7 +165,10 @@ void backward_kernel(cudaStream_t stream,
                                       input_grad,
                                       gamma,
                                       gamma_grad,
-                                      beta_grad);
+                                      beta_grad,
+                                      batch_size,
+                                      num_elements,
+                                      eps);
 }
 
 template <typename T>
diff --git a/lib/kernels/src/hip/layer_norm_kernels.cpp b/lib/kernels/src/hip/layer_norm_kernels.cpp
index dc2685ef28..5cf82b213e 100644
--- a/lib/kernels/src/hip/layer_norm_kernels.cpp
+++ b/lib/kernels/src/hip/layer_norm_kernels.cpp
@@ -24,57 +24,53 @@ constexpr int kCUDABlockReduceNumThreads = 512;
 constexpr int kCUDANumThreads = 256;
 constexpr int kColwiseReduceTileSize = 32;
 
-LayerNormPerDeviceState::LayerNormPerDeviceState(
-    FFHandler handle,
-    bool elementwise_affine_,
-    int64_t effective_batch_size_,
-    int64_t effective_num_elements_,
-    bool profiling_,
-    float eps_)
-    : PerDeviceOpState(handle) {
-  elementwise_affine = elementwise_affine_;
-  effective_batch_size = effective_batch_size_;
-  effective_num_elements = effective_num_elements_;
-  profiling = profiling_;
-  eps = eps_;
-  checkCUDA(hipMalloc(&mean_ptr, sizeof(float) * effective_batch_size));
-  checkCUDA(hipMalloc(&rstd_ptr, sizeof(float) * effective_batch_size));
-  checkCUDA(hipMalloc(&ds_ptr, sizeof(float) * effective_batch_size));
-  checkCUDA(hipMalloc(&db_ptr, sizeof(float) * effective_batch_size));
-  checkCUDA(hipMalloc(&scale_ptr, sizeof(float) * effective_batch_size));
-  checkCUDA(hipMalloc(&bias_ptr, sizeof(float) * effective_batch_size));
-}
-
 namespace Kernels {
 namespace LayerNorm {
 
+LayerNormPerDeviceState init_kernel(PerDeviceFFHandle handle,
+                                    int64_t batch_size) {
+  float *mean, *rstd, *ds, *db, *scale, *bias;
+  checkCUDA(hipMalloc(&mean, sizeof(float) * batch_size));
+  checkCUDA(hipMalloc(&rstd, sizeof(float) * batch_size));
+  checkCUDA(hipMalloc(&ds, sizeof(float) * batch_size));
+  checkCUDA(hipMalloc(&db, sizeof(float) * batch_size));
+  checkCUDA(hipMalloc(&scale, sizeof(float) * batch_size));
+  checkCUDA(hipMalloc(&bias, sizeof(float) * batch_size));
+
+  LayerNormPerDeviceState per_device_state = {mean, rstd, ds, db, scale, bias};
+  return per_device_state;
+}
+
 template <DataType T>
 struct ForwardKernel {
   void operator()(hipStream_t stream,
-                  LayerNormPerDeviceState const *m,
+                  LayerNormPerDeviceState const &m,
                   GenericTensorAccessorR const &input,
                   GenericTensorAccessorW const &output,
                   GenericTensorAccessorW const &gamma,
-                  GenericTensorAccessorW const &beta) {
+                  GenericTensorAccessorW const &beta,
+                  int64_t batch_size,
+                  int64_t num_elements,
+                  float eps) {
    hipLaunchKernelGGL(HIP_KERNEL_NAME(RowwiseMomentsCUDAKernel<T>),
-                       m->effective_batch_size,
+                       batch_size,
                        kCUDABlockReduceNumThreads,
                        0,
                        stream,
-                       m->effective_num_elements,
-                       m->eps,
+                       num_elements,
+                       eps,
                        input.get<T>(),
-                       m->mean_ptr,
-                       m->rstd_ptr);
+                       m.mean,
+                       m.rstd);
    hipLaunchKernelGGL(HIP_KERNEL_NAME(LayerNormForwardCUDAKernel<T>),
-                       m->effective_batch_size,
+                       batch_size,
                        kCUDANumThreads,
                        0,
                        stream,
-                       m->effective_num_elements,
+                       num_elements,
                        input.get<T>(),
-                       m->mean_ptr,
-                       m->rstd_ptr,
+                       m.mean,
+                       m.rstd,
                        gamma.get<T>(),
                        beta.get<T>(),
                        output.get<T>());
@@ -84,15 +80,18 @@ struct ForwardKernel {
 template <DataType T>
 struct BackwardKernel {
   void operator()(hipStream_t stream,
-                  LayerNormPerDeviceState const *m,
+                  LayerNormPerDeviceState const &m,
                   GenericTensorAccessorR const &output_grad,
                   GenericTensorAccessorR const &input,
                   GenericTensorAccessorW const &input_grad,
                   GenericTensorAccessorR const &gamma,
                   GenericTensorAccessorW const &gamma_grad,
-                  GenericTensorAccessorW const &beta_grad) {
-    const int64_t M = m->effective_batch_size;
-    const int64_t N = m->effective_num_elements;
+                  GenericTensorAccessorW const &beta_grad,
+                  int64_t batch_size,
+                  int64_t num_elements,
+                  float eps) {
+    const int64_t M = batch_size;
+    const int64_t N = num_elements;
    hipLaunchKernelGGL(HIP_KERNEL_NAME(ComputeInternalGradientsCUDAKernel<T>),
                        M,
                        kCUDABlockReduceNumThreads,
                        0,
@@ -102,8 +101,8 @@ struct BackwardKernel {
                        output_grad.get<T>(),
                        input.get<T>(),
                        gamma.get<T>(),
-                       m->ds_ptr,
-                       m->db_ptr);
+                       m.ds,
+                       m.db);
    const int64_t B = (M + kCUDANumThreads - 1) / kCUDANumThreads;
    hipLaunchKernelGGL(HIP_KERNEL_NAME(ComputeGradientFusedParamsCUDAKernel<T>),
                        B,
                        kCUDANumThreads,
                        0,
                        stream,
                        M,
                        N,
-                       m->mean_ptr,
-                       m->rstd_ptr,
-                       m->ds_ptr,
-                       m->db_ptr,
-                       m->scale_ptr,
-                       m->bias_ptr);
+                       m.mean,
+                       m.rstd,
+                       m.ds,
+                       m.db,
+                       m.scale,
+                       m.bias);
    if (gamma_grad.get<T>() != NULL || beta_grad.get<T>() != NULL) {
      if (M < 512) {
        // For small batch size, do colwise reduce directly
@@ -132,8 +131,8 @@ struct BackwardKernel {
                 N,
                 output_grad.get<T>(),
                 input.get<T>(),
-                m->mean_ptr,
-                m->rstd_ptr,
+                m.mean,
+                m.rstd,
                 gamma_grad.get<T>(),
                 beta_grad.get<T>());
      } else {
@@ -150,8 +149,8 @@ struct BackwardKernel {
                 N,
                 output_grad.get<T>(),
                 input.get<T>(),
-                m->mean_ptr,
-                m->rstd_ptr,
+                m.mean,
+                m.rstd,
                 gamma_grad.get<T>(),
                 beta_grad.get<T>());
      }
@@ -159,24 +158,40 @@ struct BackwardKernel {
 }
 
 void forward_kernel(hipStream_t stream,
-                    LayerNormPerDeviceState const *m,
+                    LayerNormPerDeviceState const &m,
                     GenericTensorAccessorR const &input,
                     GenericTensorAccessorW const &output,
                     GenericTensorAccessorW const &gamma,
-                    GenericTensorAccessorW const &beta) {
-  DataTypeDispatch1<ForwardKernel>{}(
-      m->data_type, stream, m, input, output, gamma, beta);
+                    GenericTensorAccessorW const &beta,
+                    DataType data_type,
+                    int64_t batch_size,
+                    int64_t num_elements,
+                    float eps) {
+  DataTypeDispatch1<ForwardKernel>{}(data_type,
+                                     stream,
+                                     m,
+                                     input,
+                                     output,
+                                     gamma,
+                                     beta,
+                                     batch_size,
+                                     num_elements,
+                                     eps);
 }
 
 void backward_kernel(hipStream_t stream,
-                     LayerNormPerDeviceState const *m,
+                     LayerNormPerDeviceState const &m,
                      GenericTensorAccessorR const &output_grad,
                      GenericTensorAccessorR const &input,
                      GenericTensorAccessorW const &input_grad,
                      GenericTensorAccessorR const &gamma,
                      GenericTensorAccessorW const &gamma_grad,
-                     GenericTensorAccessorW const &beta_grad) {
-  DataTypeDispatch1<BackwardKernel>{}(m->data_type,
+                     GenericTensorAccessorW const &beta_grad,
+                     DataType data_type,
+                     int64_t batch_size,
+                     int64_t num_elements,
+                     float eps) {
+  DataTypeDispatch1<BackwardKernel>{}(data_type,
                                       stream,
                                       m,
                                       output_grad,
@@ -184,7 +199,10 @@ struct BackwardKernel {
                                       input_grad,
                                       gamma,
                                       gamma_grad,
-                                      beta_grad);
+                                      beta_grad,
+                                      batch_size,
+                                      num_elements,
+                                      eps);
 }
 
 template <typename T>
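
Review note (not part of the patch): a minimal usage sketch of the refactored interface. It assumes a valid PerDeviceFFHandle, an ffStream_t, and already-materialized GenericTensorAccessorR/W objects supplied by the caller; the batch size, element count, eps value, and the DT_FLOAT choice below are illustrative assumptions, not taken from this diff.

// Hedged sketch of driving the new API from op-level code.
// Only init_kernel / forward_kernel and LayerNormPerDeviceState come from
// this patch; everything else here is assumed surrounding context.
#include "kernels/layer_norm_kernels.h"

using namespace FlexFlow;

void layer_norm_forward_example(PerDeviceFFHandle handle,
                                ffStream_t stream,
                                GenericTensorAccessorR const &input,
                                GenericTensorAccessorW const &output,
                                GenericTensorAccessorW const &gamma,
                                GenericTensorAccessorW const &beta) {
  int64_t batch_size = 128;    // illustrative effective batch size
  int64_t num_elements = 1024; // illustrative per-row element count
  float eps = 1e-5f;

  // The six device buffers (mean, rstd, ds, db, scale, bias) now live in a
  // plain-old-data LayerNormPerDeviceState returned by init_kernel, instead
  // of being allocated inside a PerDeviceOpState constructor.
  LayerNormPerDeviceState state =
      Kernels::LayerNorm::init_kernel(handle, batch_size);

  // Shape, dtype, and eps are passed explicitly per call rather than being
  // read from the per-device state (DT_FLOAT assumed for illustration).
  Kernels::LayerNorm::forward_kernel(stream,
                                     state,
                                     input,
                                     output,
                                     gamma,
                                     beta,
                                     DT_FLOAT,
                                     batch_size,
                                     num_elements,
                                     eps);
}

Passing the shape and eps per call keeps the per-device state a trivially copyable struct, which is the apparent motivation for dropping the LayerNormPerDeviceState class in this patch.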