Commit

add scaled act
charlifu committed Sep 11, 2024
1 parent d4c6b3d commit e12fafc
Showing 10 changed files with 88 additions and 12 deletions.
42 changes: 42 additions & 0 deletions csrc/activation_kernels.cu
@@ -7,6 +7,10 @@
#include "cuda_compat.h"
#include "dispatch_utils.h"

#ifdef USE_ROCM
#include "quantization/fp8/amd/hip_float8.h"
#endif

namespace vllm {

// Activation and gating kernel template.
@@ -23,6 +27,22 @@ __global__ void act_and_mul_kernel(
}
}

// Scaled activation and gating kernel template.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void scaled_act_and_mul_kernel(
c10::Float8_e4m3fnuz* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d, const float scale) {
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
float r = ACT_FN(x) * y * scale;
out[token_idx * d + idx] = c10::Float8_e4m3fnuz(
hip_fp8(r).data, c10::Float8_e4m3fnuz::from_bits());
}
}

template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
// x * sigmoid(x)
@@ -69,12 +89,34 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
input.data_ptr<scalar_t>(), d); \
});

// Launch scaled activation and gating kernel.
#define LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(KERNEL) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "scaled_act_and_mul_kernel", [&] { \
vllm::scaled_act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<c10::Float8_e4m3fnuz>(), \
input.data_ptr<scalar_t>(), d, \
1.0 / (*scale.data_ptr<float>())); \
});

void silu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}

void scaled_silu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., 2 * d]
torch::Tensor& scale) {
LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}
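
For reference, the math of the new fused path (kernel plus launch macro above) is roughly the PyTorch sketch below. Note that the launch macro passes the reciprocal of the scale tensor, so the kernel's scale argument is effectively 1 / scale. The helper name is illustrative, and torch.float8_e4m3fnuz is assumed to be available in the PyTorch build.

import torch
import torch.nn.functional as F

def scaled_silu_and_mul_ref(input: torch.Tensor,   # [..., 2 * d]
                            scale: torch.Tensor) -> torch.Tensor:
    d = input.shape[-1] // 2
    x, y = input[..., :d], input[..., d:]
    out = F.silu(x.float()) * y.float() / scale    # kernel computes ACT_FN(x) * y * (1 / scale) in float
    return out.to(torch.float8_e4m3fnuz)           # cast to the ROCm fp8 format (e4m3fnuz)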

void gelu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
3 changes: 3 additions & 0 deletions csrc/ops.h
@@ -44,6 +44,9 @@ void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,

void silu_and_mul(torch::Tensor& out, torch::Tensor& input);

void scaled_silu_and_mul(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& scale);

void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
4 changes: 4 additions & 0 deletions csrc/torch_bindings.cpp
@@ -52,6 +52,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);

  // Activation function used in SwiGLU, fused with scaling to fp8 output.
ops.def("scaled_silu_and_mul(Tensor! out, Tensor input, Tensor scale) -> ()");
ops.impl("scaled_silu_and_mul", torch::kCUDA, &scaled_silu_and_mul);

// Activation function used in GeGLU with `none` approximation.
ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
5 changes: 5 additions & 0 deletions vllm/_custom_ops.py
@@ -49,6 +49,11 @@ def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.silu_and_mul(out, x)


def scaled_silu_and_mul(out: torch.Tensor, x: torch.Tensor,
scale: torch.Tensor) -> None:
torch.ops._C.scaled_silu_and_mul(out, x, scale)
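
A minimal call sketch for this wrapper (assumes a ROCm build whose _C extension includes the new op; shapes and the scale value are illustrative):

import torch
from vllm import _custom_ops as ops

x = torch.randn(8, 2 * 4096, dtype=torch.float16, device="cuda")        # [num_tokens, 2 * d]
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")         # static fp8 output scale
out = torch.empty(8, 4096, dtype=torch.float8_e4m3fnuz, device="cuda")  # [num_tokens, d]
ops.scaled_silu_and_mul(out, x, scale)   # out = fp8(silu(x[:, :d]) * x[:, d:] / scale)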


def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_and_mul(out, x)

15 changes: 12 additions & 3 deletions vllm/model_executor/layers/activation.py
@@ -28,13 +28,22 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]

def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
def forward_cuda(self,
x: torch.Tensor,
scale: Optional[torch.Tensor] = None) -> torch.Tensor:
from vllm import _custom_ops as ops

d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
ops.silu_and_mul(out, x)
if scale is None:
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
ops.silu_and_mul(out, x)
else:
# for scaled fp8 output
out = torch.empty(output_shape,
dtype=torch.float8_e4m3fnuz,
device=x.device)
ops.scaled_silu_and_mul(out, x, scale)
return out
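
Usage sketch for the updated layer (the fp8 path is taken only when a scale tensor is passed; shapes are illustrative, and calling forward_cuda directly is just for demonstration):

import torch
from vllm.model_executor.layers.activation import SiluAndMul

act = SiluAndMul()
gate_up = torch.randn(4, 2 * 1024, dtype=torch.float16, device="cuda")

y = act.forward_cuda(gate_up)             # unchanged path: float16 output
scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
y_fp8 = act.forward_cuda(gate_up, scale)  # new path: torch.float8_e4m3fnuz output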

def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
2 changes: 2 additions & 0 deletions vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -22,6 +22,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):

def __init__(self, strategy: str, is_static_input_scheme: bool):
self.strategy = strategy
self.out_dtype = torch.get_default_dtype()
self.is_static_input_scheme = is_static_input_scheme
self.cutlass_fp8_supported = cutlass_fp8_supported()
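
out_dtype is captured at construction time here (and in the fbgemm and fp8 linear methods below) because the GEMM input may now arrive already quantized, so input.dtype is no longer a proxy for the desired output dtype. A small illustration, assuming a PyTorch build that accepts bfloat16 as the default dtype and exposes the fp8 dtypes:

import torch

torch.set_default_dtype(torch.bfloat16)
out_dtype = torch.get_default_dtype()               # the model's compute dtype (bfloat16 here)
x = torch.empty(4, 8, dtype=torch.float8_e4m3fnuz)  # activation produced by scaled_silu_and_mul
assert x.dtype != out_dtype                         # input.dtype would be the wrong GEMM output dtype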

@@ -137,6 +138,7 @@ def apply_weights(self,
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=layer.input_scale,
bias=bias,
cutlass_fp8_supported=self.cutlass_fp8_supported,
2 changes: 2 additions & 0 deletions vllm/model_executor/layers/quantization/fbgemm_fp8.py
@@ -76,6 +76,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
def __init__(self, quant_config: FBGEMMFp8Config):
self.quant_config = quant_config
self.cutlass_fp8_supported = cutlass_fp8_supported()
self.out_dtype = torch.get_default_dtype()

def create_weights(
self,
@@ -164,6 +165,7 @@ def apply(self,
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=None,
input_scale_ub=layer.input_scale_ub,
bias=bias,
6 changes: 4 additions & 2 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -1,9 +1,9 @@
from typing import Any, Callable, Dict, List, Optional

import torch
import torch.nn.functional as F
from torch.nn import Module
from torch.nn.parameter import Parameter
import torch.nn.functional as F

import vllm.envs as envs
from vllm import _custom_ops as ops
@@ -123,6 +123,7 @@ def __init__(self, quant_config: Fp8Config):
# kernel for fast weight-only FP8 quantization
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
self.out_dtype = torch.get_default_dtype()
self.use_marlin = capability < 89 or envs.VLLM_TEST_FORCE_FP8_MARLIN
# Disable marlin for rocm
if is_hip():
@@ -246,7 +247,7 @@ def process_weights_after_loading(self, layer: Module) -> None:

# Pad the weight
if envs.VLLM_FP8_PADDING:
weight = F.pad(weight, (0,256), "constant", 0)[...,:-256]
weight = F.pad(weight, (0, 256), "constant", 0)[..., :-256]
torch.cuda.empty_cache()

# Update layer with new values.
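
A quick illustration of the VLLM_FP8_PADDING trick above: F.pad grows the underlying storage by 256 elements per row, and the trailing slice restores the original logical shape, so downstream code sees the same weight while the allocation stays padded (dummy shape and default dtype):

import torch
import torch.nn.functional as F

w = torch.zeros(16, 4096)
padded = F.pad(w, (0, 256), "constant", 0)[..., :-256]
print(padded.shape)     # torch.Size([16, 4096]) -- logical shape unchanged
print(padded.stride())  # (4352, 1) -- row stride reflects the padded storage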
@@ -280,6 +281,7 @@ def apply(self,
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=layer.input_scale,
bias=bias,
cutlass_fp8_supported=self.cutlass_fp8_supported,
16 changes: 10 additions & 6 deletions vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -87,6 +87,7 @@ def apply_fp8_linear(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
out_dtype: torch.dtype,
input_scale: Optional[torch.Tensor] = None,
input_scale_ub: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
@@ -119,11 +120,14 @@ def apply_fp8_linear(
# Note: we pad the input because torch._scaled_mm is more performant
# for matrices with batch dimension > 16.
# This could change in the future.
qinput, x_scale = ops.scaled_fp8_quant(
input,
input_scale,
num_token_padding=17,
use_per_token_if_dynamic=use_per_token_if_dynamic)
if input.dtype != torch.float8_e4m3fnuz:
qinput, x_scale = ops.scaled_fp8_quant(
input,
input_scale,
num_token_padding=17,
use_per_token_if_dynamic=use_per_token_if_dynamic)
else:
qinput, x_scale = input, input_scale
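
The else branch works because the fused scaled_silu_and_mul kernel already divided by the layer's static input_scale before casting to fp8, which is the same per-tensor static quantization that ops.scaled_fp8_quant would otherwise apply here (ignoring the token padding). A rough sketch of the equivalence, with illustrative shapes:

import torch
import torch.nn.functional as F

d = 1024
gate_up = torch.randn(4, 2 * d, dtype=torch.float16)
input_scale = torch.tensor(0.5)                          # static activation scale of the next linear layer

act = F.silu(gate_up[..., :d].float()) * gate_up[..., d:].float()
qinput = (act / input_scale).to(torch.float8_e4m3fnuz)   # what the fused kernel already produced
x_scale = input_scale                                    # so the same scale is reused downstream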

per_tensor_weights = (weight_scale.numel() == 1)
per_tensor_activations = (x_scale.numel() == 1)
@@ -137,7 +141,7 @@
output = torch._scaled_mm(
qinput,
weight,
out_dtype=input.dtype,
out_dtype=out_dtype,
scale_a=x_scale,
scale_b=weight_scale,
scale_result=TORCH_SCALED_MM_SCALE_RESULT,
5 changes: 4 additions & 1 deletion vllm/model_executor/models/llama.py
@@ -42,6 +42,7 @@
QuantizationConfig)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale)
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -79,6 +80,7 @@ def __init__(
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj")
self.use_fp8 = isinstance(quant_config, Fp8Config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
@@ -95,7 +97,8 @@ def forward(self, x):
x = out.view(x.shape[0], x.shape[1], out.shape[1])
else:
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x = self.act_fn(
gate_up, self.down_proj.input_scale if self.use_fp8 else None)
x, _ = self.down_proj(x)
return x
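
Taken together, the fp8 fast path through LlamaMLP now looks roughly like this (a schematic of the data flow, not the full module; attribute names follow the diff):

def mlp_forward(self, x):
    gate_up, _ = self.gate_up_proj(x)        # fp8 GEMM, returns activations in the model dtype
    scale = self.down_proj.input_scale if self.use_fp8 else None
    x = self.act_fn(gate_up, scale)          # fused SiLU-and-mul; emits fp8 when a scale is given
    x, _ = self.down_proj(x)                 # apply_fp8_linear sees fp8 input and skips requantization
    return x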

Expand Down
