
Commit dc15e34

style: format
chenzhuofu committed Nov 23, 2024
1 parent 1f6dab4 commit dc15e34
Showing 1 changed file with 50 additions and 22 deletions.
72 changes: 50 additions & 22 deletions src/ops/kernels/residual_rms_norm_kernels.cu
@@ -13,8 +13,6 @@
  * limitations under the License.
  */
 
-#include <numeric>
-#include "flashinfer/utils.cuh"
 #include "flashinfer/math.cuh"
 #include "flashinfer/utils.cuh"
 #include "flashinfer/vec_dtypes.cuh"
@@ -23,6 +21,7 @@
 #include "flexflow/ops/residual_rms_norm.h"
 #include "flexflow/utils/cuda_helper.h"
 #include <cublas_v2.h>
+#include <numeric>
 
 namespace FlexFlow {
 // declare Legion names
@@ -61,12 +60,17 @@ ResidualRMSNormMeta::~ResidualRMSNormMeta(void) {
   }
 }
 
-// Adopted from flashinfer (https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/norm.cuh)
+// Adopted from flashinfer
+// (https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/norm.cuh)
 // Main modification is for non-inplace computation
 template <uint32_t VEC_SIZE, typename T>
-__global__ void FusedAddRMSNormKernel(T const * __restrict__ input, T const * __restrict__ residual, T const * __restrict__ weight,
-                                      T* __restrict__ output, T* __restrict__ residual_output,
-                                      const uint32_t d, float eps) {
+__global__ void FusedAddRMSNormKernel(T const *__restrict__ input,
+                                      T const *__restrict__ residual,
+                                      T const *__restrict__ weight,
+                                      T *__restrict__ output,
+                                      T *__restrict__ residual_output,
+                                      const uint32_t d,
+                                      float eps) {
   const uint32_t bx = blockIdx.x;
   const uint32_t tx = threadIdx.x, ty = threadIdx.y;
   constexpr uint32_t warp_size = 32;
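
For orientation: the kernel assigns one thread block per row (bx indexes the batch dimension) and strides over the row in VEC_SIZE-wide chunks. A small standalone restatement of the addressing used in the loads and stores throughout the kernel; the helper name chunk_offset is hypothetical, not part of the commit:

#include <cstdint>

// Start of the VEC_SIZE-wide chunk that thread `thread_id` touches on
// iteration `i` of row `bx`, where num_threads = blockDim.x * blockDim.y
// and thread_id = ty * warp_size + tx.
constexpr uint32_t chunk_offset(uint32_t bx, uint32_t d, uint32_t i,
                                uint32_t num_threads, uint32_t thread_id,
                                uint32_t vec_size) {
  return bx * d + (i * num_threads + thread_id) * vec_size;
}
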
@@ -86,8 +90,10 @@ __global__ void FusedAddRMSNormKernel(T const * __restrict__ input, T const * __
     flashinfer::vec_t<T, VEC_SIZE> residual_output_vec;
     residual_output_vec.fill(0);
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-      residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE +
+                     thread_id * VEC_SIZE);
+      residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE +
+                        thread_id * VEC_SIZE);
     }
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; j++) {
@@ -97,7 +103,9 @@ __global__ void FusedAddRMSNormKernel(T const * __restrict__ input, T const * __
       residual_output_vec[j] = (T)x;
     }
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      residual_output_vec.store(residual_output + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      residual_output_vec.store(residual_output + bx * d +
+                                i * num_threads * VEC_SIZE +
+                                thread_id * VEC_SIZE);
     }
   }
 
@@ -132,37 +140,50 @@ __global__ void FusedAddRMSNormKernel(T const * __restrict__ input, T const * __
     residual_output_vec.fill(0);
     output_vec.fill(0);
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-      weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-      residual_output_vec.load(residual_output + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE +
+                     thread_id * VEC_SIZE);
+      weight_vec.load(weight + i * num_threads * VEC_SIZE +
+                      thread_id * VEC_SIZE);
+      residual_output_vec.load(residual_output + bx * d +
+                               i * num_threads * VEC_SIZE +
+                               thread_id * VEC_SIZE);
     }
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; j++) {
-      output_vec[j] = float(residual_output_vec[j]) * rms_rcp * float(weight_vec[j]);
+      output_vec[j] =
+          float(residual_output_vec[j]) * rms_rcp * float(weight_vec[j]);
     }
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      output_vec.store(output + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      output_vec.store(output + bx * d + i * num_threads * VEC_SIZE +
+                       thread_id * VEC_SIZE);
     }
   }
 }
 
 template <typename T>
-cudaError_t FusedAddRMSNorm(T const * input, T const * residual, T const * weight,
-                            T * output, T * residual_output,
-                            uint32_t batch_size, uint32_t d,
-                            float eps = 1e-5, cudaStream_t stream = 0) {
+cudaError_t FusedAddRMSNorm(T const *input,
+                            T const *residual,
+                            T const *weight,
+                            T *output,
+                            T *residual_output,
+                            uint32_t batch_size,
+                            uint32_t d,
+                            float eps = 1e-5,
+                            cudaStream_t stream = 0) {
   const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
 
   const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
   const uint32_t num_warps = flashinfer::ceil_div(block_size, 32);
   dim3 nblks(batch_size);
   dim3 nthrs(32, num_warps);
   const uint32_t smem_size = num_warps * sizeof(float);
-  void* args[] = {&input, &residual, &weight, &output, &residual_output, &d, &eps};
+  void *args[] = {
+      &input, &residual, &weight, &output, &residual_output, &d, &eps};
 
   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
     auto kernel = FusedAddRMSNormKernel<VEC_SIZE, T>;
-    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+    FLASHINFER_CUDA_CALL(cudaLaunchKernel(
+        (void *)kernel, nblks, nthrs, args, smem_size, stream));
   });
 
   return cudaSuccess;
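
End to end, the two phases compute residual_output = input + residual, then an RMS normalization of residual_output scaled per channel by weight; this is the "non-inplace" variant the comment above refers to, writing results to separate buffers instead of overwriting the inputs. A minimal single-threaded reference of those semantics, assuming eps is added to the mean of squares as in flashinfer's norm kernel; float only, and the name reference_fused_add_rms_norm is hypothetical:

#include <cmath>
#include <cstdint>

// Single-threaded reference of the fused add + RMSNorm semantics above.
void reference_fused_add_rms_norm(float const *input, float const *residual,
                                  float const *weight, float *output,
                                  float *residual_output, uint32_t batch_size,
                                  uint32_t d, float eps = 1e-5f) {
  for (uint32_t b = 0; b < batch_size; b++) {
    float sum_sq = 0.f;
    for (uint32_t j = 0; j < d; j++) {
      float const x = input[b * d + j] + residual[b * d + j];
      residual_output[b * d + j] = x; // phase 1: materialize the sum
      sum_sq += x * x;
    }
    float const rms_rcp = 1.f / std::sqrt(sum_sq / float(d) + eps);
    for (uint32_t j = 0; j < d; j++) { // phase 2: normalize and scale
      output[b * d + j] = residual_output[b * d + j] * rms_rcp * weight[j];
    }
  }
}
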
@@ -191,8 +212,15 @@ void forward_kernel(ResidualRMSNormMeta const *m,
   int num_threads =
       std::max(kernel1_parallelism.second, kernel2_parallelism.second);
 
-  checkCUDA(FusedAddRMSNorm<T>(
-      input1_ptr, input2_ptr, weight_ptr, output_ptr, residual_output_ptr, batch_size, m->in_dim, m->eps, stream));
+  checkCUDA(FusedAddRMSNorm<T>(input1_ptr,
+                               input2_ptr,
+                               weight_ptr,
+                               output_ptr,
+                               residual_output_ptr,
+                               batch_size,
+                               m->in_dim,
+                               m->eps,
+                               stream));
 }
 
 void forward_kernel_wrapper(ResidualRMSNormMeta const *m,

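Stepping back from the diff: the dispatch in FusedAddRMSNorm picks the widest vector that both fits a 16-byte load and divides d, then sizes the block from it. A worked check of that arithmetic for a representative half-precision case; d = 4096 is an assumed example, not taken from the commit:

#include <algorithm>
#include <cstdint>
#include <numeric>

int main() {
  uint32_t const d = 4096;     // assumed hidden size, for illustration
  uint32_t const sizeof_t = 2; // sizeof(half)
  uint32_t const vec_size = std::gcd(16u / sizeof_t, d);              // 8
  uint32_t const block_size = std::min<uint32_t>(1024, d / vec_size); // 512
  uint32_t const num_warps = (block_size + 31) / 32; // ceil_div -> 16
  // Resulting launch shape: grid = (batch_size), block = (32, num_warps),
  // dynamic shared memory = num_warps * sizeof(float) = 64 bytes.
  return (vec_size == 8 && block_size == 512 && num_warps == 16) ? 0 : 1;
}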