diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu
index 7a7a25d2173d2..0df01d91215f2 100644
--- a/csrc/layernorm_kernels.cu
+++ b/csrc/layernorm_kernels.cu
@@ -3,6 +3,7 @@
 #include <c10/cuda/CUDAGuard.h>
 
 #include "dispatch_utils.h"
+#include "attention/attention_dtypes.h"
 #ifndef USE_ROCM
   #include <cuda_bf16.h>
   #include <cuda_fp16.h>
@@ -13,6 +14,8 @@
   #include <hip/hip_bf16.h>
   #include <hip/hip_fp16.h>
   #include <hipcub/hipcub.hpp>
+  #include "quantization/fp8/amd/hip_float8.h"
+  #include "quantization/fp8/amd/quant_utils.cuh"
 
   using __nv_bfloat16 = __hip_bfloat16;
   using __nv_bfloat162 = __hip_bfloat162;
@@ -51,6 +54,38 @@ __global__ void rms_norm_kernel(
   }
 }
 
+template <typename scalar_t>
+__global__ void scaled_rms_norm_kernel(
+    c10::Float8_e4m3fnuz* __restrict__ out,  // [..., hidden_size]
+    const scalar_t* __restrict__ input,      // [..., hidden_size]
+    const scalar_t* __restrict__ weight,     // [hidden_size]
+    const float scale, const float epsilon, const int num_tokens,
+    const int hidden_size) {
+  __shared__ float s_variance;
+  float variance = 0.0f;
+
+  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
+    const float x = (float)input[blockIdx.x * hidden_size + idx];
+    variance += x * x;
+  }
+
+  using BlockReduce = cub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStore;
+  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+
+  if (threadIdx.x == 0) {
+    s_variance = rsqrtf(variance / hidden_size + epsilon);
+  }
+  __syncthreads();
+
+  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
+    float x = (float)input[blockIdx.x * hidden_size + idx];
+    float r = (x * s_variance) * weight[idx] * scale;
+    out[blockIdx.x * hidden_size + idx] = c10::Float8_e4m3fnuz(
+        hip_fp8(r).data, c10::Float8_e4m3fnuz::from_bits());
+  }
+}
+
 /* Converter structs for the conversion from torch types to HIP/CUDA types,
    and the associated type conversions within HIP/CUDA. These helpers need
    to be implemented for now because the relevant type conversion
@@ -291,6 +326,124 @@ fused_add_rms_norm_kernel(
   }
 }
 
+/* Function specialization in the case of FP16/BF16 tensors.
+   Additional optimizations we can make in this case are
+   packed and vectorized operations, which help with the
+   memory latency bottleneck. */
+
+template <>
+struct Vec<c10::Float8_e4m3fnuz, 8> {
+  using Type = uint2;
+};
+
+template <>
+struct Vec<c10::Half, 8> {
+  using Type = uint4;
+};
+
+template <>
+struct Vec<c10::BFloat16, 8> {
+  using Type = bf16_8_t;
+};
+
+template <typename scalar_t, int width>
+__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
+scaled_fused_add_rms_norm_kernel(
+    c10::Float8_e4m3fnuz* __restrict__ out,  // [..., hidden_size]
+    scalar_t* __restrict__ input,            // [..., hidden_size]
+    scalar_t* __restrict__ residual,         // [..., hidden_size]
+    const scalar_t* __restrict__ weight,     // [hidden_size]
+    const float epsilon, const float scale, const int num_tokens,
+    const int hidden_size) {
+  using in_v_t = typename Vec<scalar_t, width>::Type;
+  using out_v_t = typename Vec<c10::Float8_e4m3fnuz, width>::Type;
+  // Sanity checks on our vector struct and type-punned pointer arithmetic
+  static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
+  static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
+
+  const int vec_hidden_size = hidden_size / width;
+  __shared__ float s_variance;
+  float variance = 0.0f;
+  /* These and the argument pointers are all declared `restrict` as they are
+     not aliased in practice. Argument pointers should not be dereferenced
+     in this kernel as that would be undefined behavior */
+  auto* __restrict__ out_v = reinterpret_cast<out_v_t*>(out);
+  auto* __restrict__ input_v =
+      reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
+  auto* __restrict__ residual_v =
+      reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
+  auto* __restrict__ weight_v =
+      reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
+
+  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
+    int id = blockIdx.x * vec_hidden_size + idx;
+    _f16Vec<scalar_t, width> temp = input_v[id];
+    temp += residual_v[id];
+    variance += temp.sum_squares();
+    residual_v[id] = temp;
+  }
+
+  using BlockReduce = cub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStore;
+  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+
+  if (threadIdx.x == 0) {
+    s_variance = rsqrtf(variance / hidden_size + epsilon);
+  }
+  __syncthreads();
+
+  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
+    int id = blockIdx.x * vec_hidden_size + idx;
+    _f16Vec<scalar_t, width> temp = residual_v[id];
+    temp *= s_variance;
+    temp *= weight_v[idx];
+    out_v_t temp_quant = fp8::scaled_vec_conversion<out_v_t, in_v_t>(
+        *reinterpret_cast<in_v_t*>(&temp), scale);
+    out_v[id] = temp_quant;
+  }
+}
+
+/* Generic scaled_fused_add_rms_norm_kernel
+   The width field is not used here but necessary for other specializations.
+ */
+template <typename scalar_t, int width>
+__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
+scaled_fused_add_rms_norm_kernel(
+    c10::Float8_e4m3fnuz* __restrict__ out,  // [..., hidden_size]
+    scalar_t* __restrict__ input,            // [..., hidden_size]
+    scalar_t* __restrict__ residual,         // [..., hidden_size]
+    const scalar_t* __restrict__ weight,     // [hidden_size]
+    const float epsilon, const float scale, const int num_tokens,
+    const int hidden_size) {
+  __shared__ float s_variance;
+  float variance = 0.0f;
+
+  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
+    scalar_t z = input[blockIdx.x * hidden_size + idx];
+    z += residual[blockIdx.x * hidden_size + idx];
+    float x = (float)z;
+    variance += x * x;
+    residual[blockIdx.x * hidden_size + idx] = z;
+  }
+
+  using BlockReduce = cub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStore;
+  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+
+  if (threadIdx.x == 0) {
+    s_variance = rsqrtf(variance / hidden_size + epsilon);
+  }
+  __syncthreads();
+
+  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
+    float x = (float)residual[blockIdx.x * hidden_size + idx];
+    float r = (x * s_variance) * (float)weight[idx] / scale;
+    out[blockIdx.x * hidden_size + idx] = c10::Float8_e4m3fnuz(
+        hip_fp8(r).data, c10::Float8_e4m3fnuz::from_bits());
+  }
+}
+
 }  // namespace vllm
 
 void rms_norm(torch::Tensor& out,    // [..., hidden_size]
@@ -311,6 +464,25 @@ void rms_norm(torch::Tensor& out,    // [..., hidden_size]
       });
 }
 
+void scaled_rms_norm(torch::Tensor& out,     // [..., hidden_size]
+                     torch::Tensor& input,   // [..., hidden_size]
+                     torch::Tensor& weight,  // [hidden_size]
+                     torch::Tensor& scale, double epsilon) {
+  int hidden_size = input.size(-1);
+  int num_tokens = input.numel() / hidden_size;
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(hidden_size, 1024));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "scaled_rms_norm_kernel", [&] {
+    vllm::scaled_rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        out.data_ptr<c10::Float8_e4m3fnuz>(), input.data_ptr<scalar_t>(),
+        weight.data_ptr<scalar_t>(), 1.0 / (*scale.data_ptr<float>()),
+        epsilon, num_tokens, hidden_size);
+  });
+}
+
 #define LAUNCH_FUSED_ADD_RMS_NORM(width)                                       \
   VLLM_DISPATCH_FLOATING_TYPES(                                                \
       input.scalar_type(), "fused_add_rms_norm_kernel", [&] {                  \
@@ -355,3 +527,50 @@ void fused_add_rms_norm(torch::Tensor& input,     // [..., hidden_size]
     LAUNCH_FUSED_ADD_RMS_NORM(0);
   }
 }
+
+#define LAUNCH_SCALED_FUSED_ADD_RMS_NORM(width)                                \
+  VLLM_DISPATCH_FLOATING_TYPES(                                                \
+      input.scalar_type(), "scaled_fused_add_rms_norm_kernel", [&] {           \
+        vllm::scaled_fused_add_rms_norm_kernel<scalar_t, width>                \
+            <<<grid, block, 0, stream>>>(                                      \
+                out.data_ptr<c10::Float8_e4m3fnuz>(),                          \
+                input.data_ptr<scalar_t>(), residual.data_ptr<scalar_t>(),     \
+                weight.data_ptr<scalar_t>(), epsilon,                          \
+                *scale.data_ptr<float>(), num_tokens, hidden_size);            \
+      });
+
+void scaled_fused_add_rms_norm(torch::Tensor& out,       // [..., hidden_size]
+                               torch::Tensor& input,     // [..., hidden_size]
+                               torch::Tensor& residual,  // [..., hidden_size]
+                               torch::Tensor& weight,    // [hidden_size]
+                               torch::Tensor& scale, double epsilon) {
+  int hidden_size = input.size(-1);
+  int num_tokens = input.numel() / hidden_size;
+
+  dim3 grid(num_tokens);
+  /* This kernel is memory-latency bound in many scenarios.
+     When num_tokens is large, a smaller block size allows
+     for increased block occupancy on CUs and better latency
+     hiding on global mem ops. */
+  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
+  dim3 block(std::min(hidden_size, max_block_size));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  /* If the tensor types are FP16/BF16, try to use the optimized kernel
+     with packed + vectorized ops.
+     Max optimization is achieved with a width-8 vector of FP16/BF16s
+     since we can load at most 128 bits at once in a global memory op.
+     However, this requires each tensor's data to be aligned to 16
+     bytes.
+   */
+  auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
+  auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
+  auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
+  bool ptrs_are_aligned =
+      inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
+  if (ptrs_are_aligned && hidden_size % 8 == 0) {
+    LAUNCH_SCALED_FUSED_ADD_RMS_NORM(8);
+  } else {
+    LAUNCH_SCALED_FUSED_ADD_RMS_NORM(0);
+  }
+}
diff --git a/csrc/ops.h b/csrc/ops.h
index 5791c5a71228b..1d95c21bf0fc2 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -32,6 +32,12 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
 void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
                         torch::Tensor& weight, double epsilon);
 
+void scaled_rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, torch::Tensor& scale,
+                     double epsilon);
+
+void scaled_fused_add_rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& residual,
+                               torch::Tensor& weight, torch::Tensor& scale, double epsilon);
+
 void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
                       torch::Tensor& key, int64_t head_size,
                       torch::Tensor& cos_sin_cache, bool is_neox);
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 12d9a45d536a9..651dad9efabd7 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -93,6 +93,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "float epsilon) -> ()");
   ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
 
+  // Apply Root Mean Square (RMS) Normalization to the input tensor with scaled output.
+  ops.def(
+      "scaled_rms_norm(Tensor! out, Tensor input, Tensor weight, Tensor scale, float epsilon) -> "
+      "()");
+  ops.impl("scaled_rms_norm", torch::kCUDA, &scaled_rms_norm);
+
+  // Fused Add and RMS Normalization with scaled output.
+  ops.def(
+      "scaled_fused_add_rms_norm(Tensor! out, Tensor input, Tensor! residual, Tensor weight, "
+      "Tensor scale, float epsilon) -> ()");
+  ops.impl("scaled_fused_add_rms_norm", torch::kCUDA, &scaled_fused_add_rms_norm);
+
   // Rotary embedding
   // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
   ops.def(
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 7e05daae50070..c618471a9f8fc 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -170,6 +170,19 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
     torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
 
 
+def scaled_rms_norm(out: torch.Tensor, input: torch.Tensor,
+                    weight: torch.Tensor, scale: torch.Tensor,
+                    epsilon: float) -> None:
+    torch.ops._C.scaled_rms_norm(out, input, weight, scale, epsilon)
+
+
+def scaled_fused_add_rms_norm(out: torch.Tensor, input: torch.Tensor,
+                              residual: torch.Tensor, weight: torch.Tensor,
+                              scale: torch.Tensor, epsilon: float) -> None:
+    torch.ops._C.scaled_fused_add_rms_norm(out, input, residual, weight, scale,
+                                           epsilon)
+
+
 def advance_step(num_seqs: int, num_queries: int, block_size: int,
                  input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor,
                  input_positions: torch.Tensor, seq_lens: torch.Tensor,
diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index e3d588efd9b6d..a2a825b27b632 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -47,9 +47,21 @@ def forward_cuda(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
+        scale: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         from vllm import _custom_ops as ops
 
+        if scale is not None:
+            out = torch.empty_like(x, dtype=torch.float8_e4m3fnuz)
+            if residual is not None:
+                ops.scaled_fused_add_rms_norm(out, x, residual,
+                                              self.weight.data, scale,
+                                              self.variance_epsilon)
+                return out, residual
+            ops.scaled_rms_norm(out, x, self.weight.data, scale,
+                                self.variance_epsilon)
+            return out
+
         if residual is not None:
             ops.fused_add_rms_norm(
                 x,
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index e075a8acd000c..edc4233dd878c 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -207,6 +207,7 @@ def __init__(
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
+        self.use_fp8 = isinstance(quant_config, Fp8Config)
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         if rope_scaling is not None and getattr(
@@ -255,12 +256,13 @@ def forward(
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
+        scale = None if not self.use_fp8 else self.self_attn.qkv_proj.input_scale
         if residual is None:
             residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
+            hidden_states = self.input_layernorm(hidden_states, None, scale)
         else:
             hidden_states, residual = self.input_layernorm(
-                hidden_states, residual)
+                hidden_states, residual, scale)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
@@ -269,8 +271,9 @@
         )
 
         # Fully Connected
+        scale = None if not self.use_fp8 else self.mlp.gate_up_proj.input_scale
         hidden_states, residual = self.post_attention_layernorm(
-            hidden_states, residual)
+            hidden_states, residual, scale)
         hidden_states = self.mlp(hidden_states)
         return hidden_states, residual
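
Below is a minimal, hypothetical host-side sanity check for the two new ops, written against the Python wrappers added in vllm/_custom_ops.py. It is a sketch, not part of the patch: it assumes a ROCm build of vLLM whose compiled _C extension exposes scaled_rms_norm and scaled_fused_add_rms_norm, and a PyTorch/device combination that supports torch.float8_e4m3fnuz tensors. The fp32 reference mirrors the generic (non-vectorized) kernel path, which writes rmsnorm(x) * weight / scale before quantizing; the tolerances are loose guesses rather than validated bounds, since the hip_fp8 rounding is not modeled exactly.

# Hypothetical sanity check -- assumes a ROCm build exposing the new ops.
import torch
from vllm import _custom_ops as ops


def ref_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # fp32 reference matching the kernels' accumulation dtype.
    xf = x.float()
    variance = xf.pow(2).mean(dim=-1, keepdim=True)
    return xf * torch.rsqrt(variance + eps) * weight.float()


torch.manual_seed(0)
num_tokens, hidden_size, eps = 4, 4096, 1e-6
x = torch.randn(num_tokens, hidden_size, dtype=torch.float16, device="cuda")
w = torch.randn(hidden_size, dtype=torch.float16, device="cuda")
scale = torch.tensor(0.5, dtype=torch.float32, device="cuda")  # stand-in for qkv_proj.input_scale

# scaled_rms_norm: out is fp8, holding rmsnorm(x) * w / scale.
out = torch.empty_like(x, dtype=torch.float8_e4m3fnuz)
ops.scaled_rms_norm(out, x, w, scale, eps)
ref = (ref_rms_norm(x, w, eps) / scale).to(torch.float8_e4m3fnuz)
torch.testing.assert_close(out.float(), ref.float(), atol=0.2, rtol=0.2)

# scaled_fused_add_rms_norm: residual is updated in place to x + residual,
# and out holds rmsnorm(x + residual) * w / scale in fp8.
residual = torch.randn_like(x)
expected_residual = (x + residual).clone()
out2 = torch.empty_like(x, dtype=torch.float8_e4m3fnuz)
ops.scaled_fused_add_rms_norm(out2, x, residual, w, scale, eps)
torch.testing.assert_close(residual, expected_residual, atol=1e-2, rtol=1e-2)
ref2 = (ref_rms_norm(expected_residual, w, eps) / scale).to(torch.float8_e4m3fnuz)
torch.testing.assert_close(out2.float(), ref2.float(), atol=0.2, rtol=0.2)

The same check exercises the code path used in llama.py: when the layer is FP8-quantized, RMSNorm.forward_cuda receives the projection's input_scale and returns an fp8 tensor that feeds the subsequent quantized GEMM directly, instead of producing an fp16/bf16 output that would be quantized separately.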