[OPT] improve rms_norm kernel #258

Merged
merged 13 commits into from
Nov 20, 2024
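
For context on the change: rms_norm_kernel normalizes each token row of a [num_tokens, hidden_size] tensor by the reciprocal root-mean-square of its elements and then applies a per-element weight. A minimal scalar CPU reference for that math (not part of this PR, provided only to make the diff below easier to follow) could look like this:

#include <cmath>
#include <cstdint>

// Hypothetical scalar reference illustrating what rms_norm_kernel computes:
// out = x * rsqrt(mean(x^2) + epsilon) * weight, applied independently to
// each row of a row-major [num_tokens, hidden_size] matrix.
void rms_norm_reference(float* out, const float* input, const float* weight,
                        float epsilon, int64_t num_tokens,
                        int64_t hidden_size) {
  for (int64_t row = 0; row < num_tokens; ++row) {
    const float* x = input + row * hidden_size;
    float sum_squares = 0.0f;
    for (int64_t i = 0; i < hidden_size; ++i) {
      sum_squares += x[i] * x[i];
    }
    const float scale = 1.0f / std::sqrt(sum_squares / hidden_size + epsilon);
    for (int64_t i = 0; i < hidden_size; ++i) {
      out[row * hidden_size + i] = x[i] * scale * weight[i];
    }
  }
}
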
144 changes: 86 additions & 58 deletions csrc/layernorm_kernels.cu
@@ -23,86 +23,88 @@

namespace vllm {

template <typename scalar_t>
struct __align__(16) vec8_t {
scalar_t x, y, z, w, u, v, s, t;
#ifdef __HIP__MI300_MI250__

__device__ vec8_t() : x(0), y(0), z(0), w(0), u(0), v(0), s(0), t(0) {}
__device__ vec8_t(scalar_t x, scalar_t y, scalar_t z, scalar_t w, scalar_t u,
scalar_t v, scalar_t s, scalar_t t)
: x(x), y(y), z(z), w(w), u(u), v(v), s(s), t(t) {}
// TODO(woosuk): Further optimize this kernel.
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
rms_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens,
const int hidden_size, const int vec_hidden_size) {
__shared__ float s_variance;
float v8_variance_sum = 0.0f;

__device__ vec8_t operator*(const vec8_t& other) const {
return vec8_t(x * other.x, y * other.y, z * other.z, w * other.w,
u * other.u, v * other.v, s * other.s, t * other.t);
}
const int64_t tx = threadIdx.x;
const int64_t bx = blockIdx.x;
const int64_t num_threads = blockDim.x;

__device__ vec8_t operator*(const float& scale) const {
return vec8_t(x * scale, y * scale, z * scale, w * scale, u * scale,
v * scale, s * scale, t * scale);
}
auto* __restrict__ out_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(out);
auto* __restrict__ input_v =
reinterpret_cast<const _f16Vec<scalar_t, width>*>(
input + bx * static_cast<int64_t>(hidden_size));
auto* __restrict__ weight_v =
reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);

__device__ vec8_t operator+(const vec8_t& other) const {
return vec8_t(x + other.x, y + other.y, z + other.z, w + other.w,
u + other.u, v + other.v, s + other.s, t + other.t);
// Compute variance. Be careful: hidden_size must be a multiple of the vector width.
for (int idx = tx; idx < vec_hidden_size; idx += num_threads) {
_f16Vec<scalar_t, width> temp = input_v[idx];
v8_variance_sum += temp.sum_squares();
}

__device__ void operator+=(const vec8_t& other) {
x += other.x;
y += other.y;
z += other.z;
w += other.w;
u += other.u;
v += other.v;
s += other.s;
t += other.t;
}
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;

__device__ scalar_t sum() const { return x + y + z + w + u + v + s + t; }
};
float variance =
BlockReduce(reduceStore).Reduce(v8_variance_sum, cub::Sum{}, num_threads);

#ifdef __HIP__MI300_MI250__
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();

// TODO(woosuk): Further optimize this kernel.
template <typename scalar_t>
__global__ void rms_norm_kernel(
scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens, const int hidden_size) {
__shared__ float s_variance;
variance = s_variance;

vec8_t<scalar_t> v8_variance = {0, 0, 0, 0, 0, 0, 0, 0};
for (int idx = tx; idx < vec_hidden_size; idx += num_threads) {
_f16Vec<scalar_t, width> temp = input_v[idx];
temp *= variance;
temp *= weight_v[idx];
out_v[bx * static_cast<int64_t>(vec_hidden_size) + idx] = temp;
}
}

vec8_t<scalar_t>* vectorized_out = reinterpret_cast<vec8_t<scalar_t>*>(out);
vec8_t<scalar_t> const* vectorized_in =
reinterpret_cast<vec8_t<scalar_t> const*>(input);
vec8_t<scalar_t> const* vectorized_weight =
reinterpret_cast<vec8_t<scalar_t> const*>(weight);
const int vec_hidden_size = hidden_size >> 3;
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
rms_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens,
const int hidden_size, const int vec_hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;

// Compute variance. Be careful: hidden_size should be a multiple of 4.
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
vec8_t<scalar_t> x = vectorized_in[blockIdx.x * vec_hidden_size + idx];
v8_variance += x * x;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x =
(float)input[blockIdx.x * static_cast<int64_t>(hidden_size) + idx];
variance += x * x;
}
float v8_variance_sum = v8_variance.sum();

using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
float variance =
BlockReduce(reduceStore).Reduce(v8_variance_sum, cub::Sum{}, blockDim.x);
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);

if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();

for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
vec8_t<scalar_t> v8_in = vectorized_in[blockIdx.x * vec_hidden_size + idx];
vec8_t<scalar_t> v8_w = vectorized_weight[idx];
vectorized_out[blockIdx.x * vec_hidden_size + idx] =
v8_in * s_variance * v8_w;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x =
(float)input[blockIdx.x * static_cast<int64_t>(hidden_size) + idx];
out[blockIdx.x * static_cast<int64_t>(hidden_size) + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
}
}

@@ -254,10 +256,35 @@ struct Vec<c10::BFloat16, 8> {

} // namespace vllm

#define LAUNCH_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { \
vllm::rms_norm_kernel<scalar_t, width><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size, \
vec_hidden_size); \
});

void rms_norm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
double epsilon) {
#ifdef __HIP__MI300_MI250__
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
int vec_size = 16 / input.element_size();
int vec_hidden_size = hidden_size / vec_size;

dim3 grid(num_tokens);
dim3 block(std::min(vec_hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

if (vec_size % 8 == 0) {
LAUNCH_RMS_NORM(8);
} else {
LAUNCH_RMS_NORM(0);
}
#else
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;

@@ -270,6 +297,7 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
});
#endif
}

#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
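Note on the dispatch above: the two rms_norm_kernel definitions are alternative overloads selected at compile time. SFINAE on the __global__ return type keeps only the vectorized version when width > 0 and _typeConvert<scalar_t>::exists, and only the scalar fallback otherwise, which is why LAUNCH_RMS_NORM(8) and LAUNCH_RMS_NORM(0) can share one kernel name. A stripped-down sketch of that selection mechanism (simplified trait and kernel names, not the PR's code):

#include <cstdio>
#include <type_traits>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

// Stand-in for vllm::_typeConvert<scalar_t>::exists: true only for types
// that have a packed vector representation in this sketch.
template <typename T>
struct has_vec_type : std::false_type {};
template <>
struct has_vec_type<__half> : std::true_type {};

// Vectorized variant: viable only when width > 0 and the type is vectorizable.
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && has_vec_type<scalar_t>::value>
demo_kernel() {
  if (threadIdx.x == 0) printf("vectorized path, width=%d\n", width);
}

// Scalar fallback: viable when width == 0 or the type is not vectorizable.
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !has_vec_type<scalar_t>::value>
demo_kernel() {
  if (threadIdx.x == 0) printf("scalar fallback\n");
}

int main() {
  demo_kernel<__half, 8><<<1, 1>>>();  // resolves to the vectorized overload
  demo_kernel<float, 0><<<1, 1>>>();   // resolves to the scalar fallback
  cudaDeviceSynchronize();
  return 0;
}

With explicit template arguments, only one overload survives substitution for each (scalar_t, width) pair, so the launch sites generated by LAUNCH_RMS_NORM resolve unambiguously.
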
4 changes: 2 additions & 2 deletions csrc/type_convert.cuh
@@ -49,7 +49,7 @@ struct _typeConvert<c10::Half> {
}
};

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
// CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely
template <>
@@ -162,4 +162,4 @@ struct alignas(16) _f16Vec {
return result;
}
};
} // namespace vllm
} // namespace vllm
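
As a quick sanity check for the new vectorized path (including the bf16 case that the relaxed guard above enables on ROCm), the custom op can be compared against a pure-ATen reference. This is an illustrative sketch only, assuming the translation unit is linked against the built extension that defines rms_norm with the signature shown in the diff:

#include <torch/torch.h>
#include <iostream>

// Declaration matching csrc/layernorm_kernels.cu; assumed to be provided by
// the built extension at link time.
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
              double epsilon);

int main() {
  const int64_t num_tokens = 32, hidden_size = 4096;
  const double epsilon = 1e-6;
  auto opts = torch::dtype(torch::kBFloat16).device(torch::kCUDA);
  torch::Tensor input = torch::randn({num_tokens, hidden_size}, opts);
  torch::Tensor weight = torch::randn({hidden_size}, opts);
  torch::Tensor out = torch::empty_like(input);

  rms_norm(out, input, weight, epsilon);

  // Pure-ATen reference in fp32: x * rsqrt(mean(x^2) + eps) * weight.
  torch::Tensor x = input.to(torch::kFloat);
  torch::Tensor ref =
      x * torch::rsqrt(x.pow(2).mean({-1}, /*keepdim=*/true) + epsilon);
  ref = ref.to(input.scalar_type()) * weight;

  std::cout << "max abs diff: "
            << (out.to(torch::kFloat) - ref.to(torch::kFloat))
                   .abs().max().item<float>()
            << std::endl;
  return 0;
}

The reference rounds x * scale back to the input dtype before multiplying by the weight, matching the cast order used in the scalar kernel above.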