From e12fafc9e5fc612a5f1e2df2d9125860b5043a17 Mon Sep 17 00:00:00 2001
From: charlifu
Date: Wed, 11 Sep 2024 20:08:35 +0000
Subject: [PATCH] add scaled act

---
 csrc/activation_kernels.cu                    | 42 +++++++++++++++++++
 csrc/ops.h                                    |  3 ++
 csrc/torch_bindings.cpp                       |  4 ++
 vllm/_custom_ops.py                           |  5 +++
 vllm/model_executor/layers/activation.py      | 15 +++++--
 .../schemes/compressed_tensors_w8a8_fp8.py    |  2 +
 .../layers/quantization/fbgemm_fp8.py         |  2 +
 .../model_executor/layers/quantization/fp8.py |  6 ++-
 .../layers/quantization/utils/w8a8_utils.py   | 16 ++++--
 vllm/model_executor/models/llama.py           |  5 ++-
 10 files changed, 88 insertions(+), 12 deletions(-)

diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu
index 5ed1dc3b8f792..2dff7de3fd3fd 100644
--- a/csrc/activation_kernels.cu
+++ b/csrc/activation_kernels.cu
@@ -7,6 +7,10 @@
 #include "cuda_compat.h"
 #include "dispatch_utils.h"
 
+#ifdef USE_ROCM
+  #include "quantization/fp8/amd/hip_float8.h"
+#endif
+
 namespace vllm {
 
 // Activation and gating kernel template.
@@ -23,6 +27,22 @@ __global__ void act_and_mul_kernel(
   }
 }
 
+// Scaled activation and gating kernel template.
+template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
+__global__ void scaled_act_and_mul_kernel(
+    c10::Float8_e4m3fnuz* __restrict__ out,  // [..., d]
+    const scalar_t* __restrict__ input,      // [..., 2, d]
+    const int d, const float scale) {
+  const int64_t token_idx = blockIdx.x;
+  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
+    const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
+    const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
+    float r = ACT_FN(x) * y * scale;
+    out[token_idx * d + idx] = c10::Float8_e4m3fnuz(
+        hip_fp8(r).data, c10::Float8_e4m3fnuz::from_bits());
+  }
+}
+
 template <typename T>
 __device__ __forceinline__ T silu_kernel(const T& x) {
   // x * sigmoid(x)
@@ -69,12 +89,34 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
                                          input.data_ptr<scalar_t>(), d); \
       });
 
+// Launch scaled activation and gating kernel.
+#define LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(KERNEL)                          \
+  int d = input.size(-1) / 2;                                                 \
+  int64_t num_tokens = input.numel() / input.size(-1);                        \
+  dim3 grid(num_tokens);                                                      \
+  dim3 block(std::min(d, 1024));                                              \
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));           \
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();               \
+  VLLM_DISPATCH_FLOATING_TYPES(                                               \
+      input.scalar_type(), "scaled_act_and_mul_kernel", [&] {                 \
+        vllm::scaled_act_and_mul_kernel<scalar_t, KERNEL<scalar_t>>           \
+            <<<grid, block, 0, stream>>>(out.data_ptr<c10::Float8_e4m3fnuz>(), \
+                                         input.data_ptr<scalar_t>(), d,       \
+                                         1.0 / (*scale.data_ptr<float>()));   \
+      });
+
 void silu_and_mul(torch::Tensor& out,    // [..., d]
                   torch::Tensor& input)  // [..., 2 * d]
 {
   LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
 }
 
+void scaled_silu_and_mul(torch::Tensor& out,    // [..., d]
+                         torch::Tensor& input,  // [..., 2 * d]
+                         torch::Tensor& scale) {
+  LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
+}
+
 void gelu_and_mul(torch::Tensor& out,    // [..., d]
                   torch::Tensor& input)  // [..., 2 * d]
 {
diff --git a/csrc/ops.h b/csrc/ops.h
index 3b8fd602aed68..5791c5a71228b 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -44,6 +44,9 @@ void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
 
 void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
 
+void scaled_silu_and_mul(torch::Tensor& out, torch::Tensor& input,
+                         torch::Tensor& scale);
+
 void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
 
 void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 8c682f39d1d80..12d9a45d536a9 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -52,6 +52,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
   ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
 
+  // Activation function used in SwiGLU, with the result quantized to FP8.
+  ops.def("scaled_silu_and_mul(Tensor! out, Tensor input, Tensor scale) -> ()");
+  ops.impl("scaled_silu_and_mul", torch::kCUDA, &scaled_silu_and_mul);
+
   // Activation function used in GeGLU with `none` approximation.
   ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
   ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 09a2cdc4174be..7e05daae50070 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -49,6 +49,11 @@ def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
     torch.ops._C.silu_and_mul(out, x)
 
 
+def scaled_silu_and_mul(out: torch.Tensor, x: torch.Tensor,
+                        scale: torch.Tensor) -> None:
+    torch.ops._C.scaled_silu_and_mul(out, x, scale)
+
+
 def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
     torch.ops._C.gelu_and_mul(out, x)
 
diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index 4c14fe476ee4a..589a9b3bdc3f6 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -28,13 +28,22 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         return F.silu(x[..., :d]) * x[..., d:]
 
-    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_cuda(self,
+                     x: torch.Tensor,
+                     scale: Optional[torch.Tensor] = None) -> torch.Tensor:
         from vllm import _custom_ops as ops
 
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
-        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        ops.silu_and_mul(out, x)
+        if scale is None:
+            out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+            ops.silu_and_mul(out, x)
+        else:
+            # for scaled fp8 output
+            out = torch.empty(output_shape,
+                              dtype=torch.float8_e4m3fnuz,
+                              device=x.device)
+            ops.scaled_silu_and_mul(out, x, scale)
         return out
 
     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
index 5931ec36c97d5..f66cd13837edf 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -22,6 +22,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
 
     def __init__(self, strategy: str, is_static_input_scheme: bool):
         self.strategy = strategy
+        self.out_dtype = torch.get_default_dtype()
         self.is_static_input_scheme = is_static_input_scheme
         self.cutlass_fp8_supported = cutlass_fp8_supported()
 
@@ -137,6 +138,7 @@ def apply_weights(self,
             input=x,
             weight=layer.weight,
             weight_scale=layer.weight_scale,
+            out_dtype=self.out_dtype,
             input_scale=layer.input_scale,
             bias=bias,
             cutlass_fp8_supported=self.cutlass_fp8_supported,
diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py
index 0b1f6ff685200..8eefc30a2038d 100644
--- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py
+++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py
@@ -76,6 +76,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
     def __init__(self, quant_config: FBGEMMFp8Config):
         self.quant_config = quant_config
         self.cutlass_fp8_supported = cutlass_fp8_supported()
+        self.out_dtype = torch.get_default_dtype()
 
     def create_weights(
         self,
@@ -164,6 +165,7 @@ def apply(self,
             input=x,
             weight=layer.weight,
             weight_scale=layer.weight_scale,
+            out_dtype=self.out_dtype,
             input_scale=None,
             input_scale_ub=layer.input_scale_ub,
             bias=bias,
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 2b917e2d4e1a9..420c3e5632069 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -1,9 +1,9 @@
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
+import torch.nn.functional as F
 from torch.nn import Module
 from torch.nn.parameter import Parameter
-import torch.nn.functional as F
 
 import vllm.envs as envs
 from vllm import _custom_ops as ops
@@ -123,6 +123,7 @@ def __init__(self, quant_config: Fp8Config):
         # kernel for fast weight-only FP8 quantization
         capability = current_platform.get_device_capability()
         capability = capability[0] * 10 + capability[1]
+        self.out_dtype = torch.get_default_dtype()
         self.use_marlin = capability < 89 or envs.VLLM_TEST_FORCE_FP8_MARLIN
         # Disable marlin for rocm
         if is_hip():
@@ -246,7 +247,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
 
             # Pad the weight
             if envs.VLLM_FP8_PADDING:
-                weight = F.pad(weight, (0,256), "constant", 0)[...,:-256]
+                weight = F.pad(weight, (0, 256), "constant", 0)[..., :-256]
                 torch.cuda.empty_cache()
 
             # Update layer with new values.
@@ -280,6 +281,7 @@ def apply(self,
             input=x,
             weight=layer.weight,
             weight_scale=layer.weight_scale,
+            out_dtype=self.out_dtype,
             input_scale=layer.input_scale,
             bias=bias,
             cutlass_fp8_supported=self.cutlass_fp8_supported,
diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
index f04db27902dbd..a73a08856da40 100644
--- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -87,6 +87,7 @@ def apply_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
     weight_scale: torch.Tensor,
+    out_dtype: torch.dtype,
     input_scale: Optional[torch.Tensor] = None,
     input_scale_ub: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
@@ -119,11 +120,14 @@
         # Note: we pad the input because torch._scaled_mm is more performant
         # for matrices with batch dimension > 16.
         # This could change in the future.
-        qinput, x_scale = ops.scaled_fp8_quant(
-            input,
-            input_scale,
-            num_token_padding=17,
-            use_per_token_if_dynamic=use_per_token_if_dynamic)
+        if input.dtype != torch.float8_e4m3fnuz:
+            qinput, x_scale = ops.scaled_fp8_quant(
+                input,
+                input_scale,
+                num_token_padding=17,
+                use_per_token_if_dynamic=use_per_token_if_dynamic)
+        else:
+            qinput, x_scale = input, input_scale
 
         per_tensor_weights = (weight_scale.numel() == 1)
         per_tensor_activations = (x_scale.numel() == 1)
@@ -137,7 +141,7 @@
         output = torch._scaled_mm(
             qinput,
             weight,
-            out_dtype=input.dtype,
+            out_dtype=out_dtype,
             scale_a=x_scale,
             scale_b=weight_scale,
             scale_result=TORCH_SCALED_MM_SCALE_RESULT,
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index e5e715fcaf8f8..e075a8acd000c 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -42,6 +42,7 @@
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     get_compressed_tensors_cache_scale)
+from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -79,6 +80,7 @@ def __init__(
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.down_proj")
+        self.use_fp8 = isinstance(quant_config, Fp8Config)
        if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -95,7 +97,8 @@ def forward(self, x):
             x = out.view(x.shape[0], x.shape[1], out.shape[1])
         else:
             gate_up, _ = self.gate_up_proj(x)
-            x = self.act_fn(gate_up)
+            x = self.act_fn(
+                gate_up, self.down_proj.input_scale if self.use_fp8 else None)
             x, _ = self.down_proj(x)
         return x
 
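
For reference, a minimal PyTorch sketch (not part of the patch) of the numerics the fused scaled_silu_and_mul op is meant to implement: SiLU(gate) * up, divided by the static activation scale, then cast to FP8. The explicit clamp is an assumption standing in for the saturating hip_fp8 conversion used in the kernel; torch.float8_e4m3fnuz is the ROCm FP8 format this patch targets.

import torch
import torch.nn.functional as F

def scaled_silu_and_mul_ref(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # x: [..., 2 * d], gate half followed by up-projection half.
    # scale: scalar tensor; the kernel multiplies by 1.0 / scale before casting.
    d = x.shape[-1] // 2
    out = F.silu(x[..., :d]) * x[..., d:]
    out = out / scale
    # Saturate to the representable FP8 range before the cast.
    finfo = torch.finfo(torch.float8_e4m3fnuz)
    return out.clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fnuz)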
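
And a hedged sketch of how the pieces are intended to fit together at the Python level once the extension is built with this patch (a ROCm FP8 build is assumed; the shapes and scale value are placeholders, not values from the patch):

import torch
from vllm.model_executor.layers.activation import SiluAndMul

act_fn = SiluAndMul()
gate_up = torch.randn(8, 2 * 11008, dtype=torch.float16, device="cuda")
input_scale = torch.tensor(0.05, device="cuda")  # e.g. down_proj.input_scale

# With a scale, forward_cuda allocates an FP8 output and calls the fused op,
# so apply_fp8_linear can skip its own scaled_fp8_quant call because the
# incoming activation is already torch.float8_e4m3fnuz.
x_fp8 = act_fn.forward_cuda(gate_up, scale=input_scale)
assert x_fp8.dtype == torch.float8_e4m3fnuz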