add scaled rms
charlifu committed Sep 12, 2024
1 parent e12fafc commit 11df7bd
Showing 6 changed files with 268 additions and 3 deletions.
219 changes: 219 additions & 0 deletions csrc/layernorm_kernels.cu
@@ -3,6 +3,7 @@
#include <c10/cuda/CUDAGuard.h>

#include "dispatch_utils.h"
#include "attention/attention_dtypes.h"
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
@@ -13,6 +14,8 @@
#include <hip/hip_fp16.h>
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
#include "quantization/fp8/amd/hip_float8.h"
#include "quantization/fp8/amd/quant_utils.cuh"

using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162;
@@ -51,6 +54,38 @@ __global__ void rms_norm_kernel(
}
}

template <typename scalar_t>
__global__ void scaled_rms_norm_kernel(
c10::Float8_e4m3fnuz* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float scale, const float epsilon, const int num_tokens,
const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;

for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float)input[blockIdx.x * hidden_size + idx];
variance += x * x;
}

using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);

if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();

for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * hidden_size + idx];
float r = (x * s_variance) * weight[idx] * scale;
out[blockIdx.x * hidden_size + idx] = c10::Float8_e4m3fnuz(
hip_fp8(r).data, c10::Float8_e4m3fnuz::from_bits());
}
}

/* Converter structs for the conversion from torch types to HIP/CUDA types,
and the associated type conversions within HIP/CUDA. These helpers need
to be implemented for now because the relevant type conversion
@@ -291,6 +326,124 @@ fused_add_rms_norm_kernel(
}
}

/* Function specialization in the case of FP16/BF16 tensors.
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
memory latency bottleneck. */

template <>
struct Vec<c10::Float8_e4m3fnuz, 8> {
using Type = uint2;
};

template <>
struct Vec<c10::Half, 8> {
using Type = uint4;
};

template <>
struct Vec<c10::BFloat16, 8> {
using Type = bf16_8_t;
};

template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
scaled_fused_add_rms_norm_kernel(
c10::Float8_e4m3fnuz* __restrict__ out, // [..., hidden_size]
scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const float scale, const int num_tokens,
const int hidden_size) {
using in_v_t = typename Vec<scalar_t, width>::Type;
using out_v_t = typename Vec<c10::Float8_e4m3fnuz, width>::Type;
// Sanity checks on our vector struct and type-punned pointer arithmetic
static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);

const int vec_hidden_size = hidden_size / width;
__shared__ float s_variance;
float variance = 0.0f;
/* These and the argument pointers are all declared `restrict` as they are
not aliased in practice. Argument pointers should not be dereferenced
in this kernel as that would be undefined behavior */
auto* __restrict__ out_v =
reinterpret_cast<out_v_t*>(out);
auto* __restrict__ input_v =
reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
auto* __restrict__ residual_v =
reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
auto* __restrict__ weight_v =
reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);

for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_f16Vec<scalar_t, width> temp = input_v[id];
temp += residual_v[id];
variance += temp.sum_squares();
residual_v[id] = temp;
}

using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);

if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();

for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_f16Vec<scalar_t, width> temp = residual_v[id];
temp *= s_variance;
temp *= weight_v[idx];
out_v_t temp_quant = fp8::scaled_vec_conversion<out_v_t, in_v_t>(
*reinterpret_cast<in_v_t*>(&temp), scale);
out_v[id] = temp_quant;
}
}

/* Generic scaled_fused_add_rms_norm_kernel
The width field is not used here but necessary for other specializations.
*/
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
scaled_fused_add_rms_norm_kernel(
c10::Float8_e4m3fnuz* __restrict__ out, // [..., hidden_size]
scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const float scale, const int num_tokens,
const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;

for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
scalar_t z = input[blockIdx.x * hidden_size + idx];
z += residual[blockIdx.x * hidden_size + idx];
float x = (float)z;
variance += x * x;
residual[blockIdx.x * hidden_size + idx] = z;
}

using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);

if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();

for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)residual[blockIdx.x * hidden_size + idx];
float r = (x * s_variance) * (float)weight[idx] / scale;
out[blockIdx.x * hidden_size + idx] = c10::Float8_e4m3fnuz(
hip_fp8(r).data, c10::Float8_e4m3fnuz::from_bits());
}
}

} // namespace vllm

void rms_norm(torch::Tensor& out, // [..., hidden_size]
@@ -311,6 +464,25 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
});
}

void scaled_rms_norm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
torch::Tensor& scale, double epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;

dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "scaled_rms_norm_kernel", [&] {
vllm::scaled_rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<c10::Float8_e4m3fnuz>(), input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), 1.0 / (*scale.data_ptr<float>()), epsilon,
num_tokens, hidden_size);
});
}

#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
@@ -355,3 +527,50 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
LAUNCH_FUSED_ADD_RMS_NORM(0);
}
}

#define LAUNCH_SCALED_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "scaled_fused_add_rms_norm_kernel", [&] { \
vllm::scaled_fused_add_rms_norm_kernel<scalar_t, width> \
<<<grid, block, 0, stream>>>( \
out.data_ptr<c10::Float8_e4m3fnuz>(), \
input.data_ptr<scalar_t>(), residual.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), epsilon, \
*scale.data_ptr<float>(), num_tokens, hidden_size); \
});

void scaled_fused_add_rms_norm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
torch::Tensor& scale, double epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;

dim3 grid(num_tokens);
/* This kernel is memory-latency bound in many scenarios.
When num_tokens is large, a smaller block size allows
for increased block occupancy on CUs and better latency
hiding on global mem ops. */
const int max_block_size = (num_tokens < 256) ? 1024 : 256;
dim3 block(std::min(hidden_size, max_block_size));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
/*If the tensor types are FP16/BF16, try to use the optimized kernel
with packed + vectorized ops.
Max optimization is achieved with a width-8 vector of FP16/BF16s
since we can load at most 128 bits at once in a global memory op.
However, this requires each tensor's data to be aligned to 16
bytes.
*/
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned =
inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
if (ptrs_are_aligned && hidden_size % 8 == 0) {
LAUNCH_SCALED_FUSED_ADD_RMS_NORM(8);
} else {
LAUNCH_SCALED_FUSED_ADD_RMS_NORM(0);
}
}
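
For reference, the arithmetic implemented by the two new kernels can be written as a short unfused PyTorch sketch. This is an illustration only (the *_ref names below are hypothetical, not part of the commit); it assumes, as the generic kernel and the host launchers above suggest, that the output is the RMS-normalized, weighted activation divided by the quantization scale and cast to float8_e4m3fnuz.

import torch

def scaled_rms_norm_ref(x: torch.Tensor, weight: torch.Tensor,
                        scale: float, eps: float) -> torch.Tensor:
    # Unfused sketch: RMS-normalize in fp32, apply the weight, divide by
    # the quantization scale, then cast to FP8 (e4m3fnuz on ROCm).
    x32 = x.float()
    variance = x32.pow(2).mean(dim=-1, keepdim=True)
    normed = x32 * torch.rsqrt(variance + eps)
    return (normed * weight.float() / scale).to(torch.float8_e4m3fnuz)

def scaled_fused_add_rms_norm_ref(x, residual, weight, scale, eps):
    # Fused-add variant: the residual buffer is updated in place with
    # x + residual, and the norm is taken over that sum.
    residual.add_(x)
    return scaled_rms_norm_ref(residual, weight, scale, eps), residual
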
6 changes: 6 additions & 0 deletions csrc/ops.h
@@ -32,6 +32,12 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon);

void scaled_rms_norm(torch::Tensor& out, torch::Tensor& input,
                     torch::Tensor& weight, torch::Tensor& scale,
                     double epsilon);

void scaled_fused_add_rms_norm(torch::Tensor& out, torch::Tensor& input,
                               torch::Tensor& residual, torch::Tensor& weight,
                               torch::Tensor& scale, double epsilon);

void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox);
12 changes: 12 additions & 0 deletions csrc/torch_bindings.cpp
@@ -93,6 +93,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"float epsilon) -> ()");
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);

// Apply Root Mean Square (RMS) Normalization to the input tensor and
// quantize the result to FP8 using the given scale.
ops.def(
"scaled_rms_norm(Tensor! out, Tensor input, Tensor weight, Tensor scale, float epsilon) -> "
"()");
ops.impl("scaled_rms_norm", torch::kCUDA, &scaled_rms_norm);

// Fused residual Add and RMS Normalization with FP8-quantized output.
ops.def(
"scaled_fused_add_rms_norm(Tensor! out, Tensor input, Tensor! residual, Tensor weight, "
"Tensor scale, float epsilon) -> ()");
ops.impl("scaled_fused_add_rms_norm", torch::kCUDA, &scaled_fused_add_rms_norm);

// Rotary embedding
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def(
13 changes: 13 additions & 0 deletions vllm/_custom_ops.py
@@ -170,6 +170,19 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)


def scaled_rms_norm(out: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor, scale: torch.Tensor,
epsilon: float) -> None:
torch.ops._C.scaled_rms_norm(out, input, weight, scale, epsilon)


def scaled_fused_add_rms_norm(out: torch.Tensor, input: torch.Tensor,
residual: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor, epsilon: float) -> None:
torch.ops._C.scaled_fused_add_rms_norm(out, input, residual, weight, scale,
epsilon)


def advance_step(num_seqs: int, num_queries: int, block_size: int,
input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor,
input_positions: torch.Tensor, seq_lens: torch.Tensor,
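
A minimal calling sketch for the two new wrappers (illustration only, not part of the commit). It assumes a ROCm build in which these ops are registered and torch.float8_e4m3fnuz is available; because the C++ launchers read *scale.data_ptr<float>() on the host, the scale is kept as a CPU float32 tensor here, which is an assumption about the expected placement.

import torch
from vllm import _custom_ops as ops

num_tokens, hidden_size, eps = 8, 4096, 1e-6
x = torch.randn(num_tokens, hidden_size, dtype=torch.float16, device="cuda")
residual = torch.randn_like(x)
weight = torch.ones(hidden_size, dtype=torch.float16, device="cuda")
scale = torch.tensor([0.5], dtype=torch.float32)  # host-readable scale (assumed)

out = torch.empty_like(x, dtype=torch.float8_e4m3fnuz)
ops.scaled_rms_norm(out, x, weight, scale, eps)

out_fused = torch.empty_like(x, dtype=torch.float8_e4m3fnuz)
ops.scaled_fused_add_rms_norm(out_fused, x, residual, weight, scale, eps)
# residual now holds x + residual, mirroring the existing fused_add_rms_norm path.
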
12 changes: 12 additions & 0 deletions vllm/model_executor/layers/layernorm.py
@@ -47,9 +47,21 @@ def forward_cuda(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
scale: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
from vllm import _custom_ops as ops

if scale is not None:
out = torch.empty_like(x, dtype=torch.float8_e4m3fnuz)
if residual is not None:
ops.scaled_fused_add_rms_norm(out, x, residual,
self.weight.data, scale,
self.variance_epsilon)
return out, residual
ops.scaled_rms_norm(out, x, self.weight.data, scale,
self.variance_epsilon)
return out

if residual is not None:
ops.fused_add_rms_norm(
x,
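
At the layer level, the new scale argument switches the return dtype: with scale=None the existing fp16/bf16 paths run unchanged, while a scale tensor makes forward_cuda return float8_e4m3fnuz. A hypothetical usage sketch (the scale value and its placement are assumptions):

import torch
from vllm.model_executor.layers.layernorm import RMSNorm

norm = RMSNorm(4096, eps=1e-6).cuda().half()
x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
scale = torch.tensor([0.5], dtype=torch.float32)

y = norm.forward_cuda(x)                    # unchanged path: fp16 output
y_fp8 = norm.forward_cuda(x, scale=scale)   # new path: float8_e4m3fnuz output
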
9 changes: 6 additions & 3 deletions vllm/model_executor/models/llama.py
@@ -207,6 +207,7 @@ def __init__(
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.use_fp8 = isinstance(quant_config, Fp8Config)
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
@@ -255,12 +256,13 @@ def forward(
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
scale = None if not self.use_fp8 else self.self_attn.qkv_proj.input_scale
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.input_layernorm(hidden_states, None, scale)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states, residual, scale)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
@@ -269,8 +271,9 @@
)

# Fully Connected
scale = None if not self.use_fp8 else self.mlp.gate_up_proj.input_scale
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states, residual, scale)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
