[Kernel][Triton] Add Triton implementation for scaled_mm_triton to support fp8 and int8 SmoothQuant, symmetric case (vllm-project#9857)

Signed-off-by: Randall Smith <[email protected]>
Signed-off-by: Loc Huynh <[email protected]>
rasmith authored and JC1DA committed Nov 11, 2024
1 parent 6c5f335 commit 446e514
Showing 3 changed files with 299 additions and 0 deletions.
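The "symmetric case" in the commit title means zero-point-free quantization: each activation row (and each weight output channel) carries a single floating-point scale, and the kernel applies those scales to the accumulator of A @ B. For orientation, a minimal, illustrative sketch of how such int8 inputs and per-token scales might be produced (the helper name is hypothetical, not part of this commit):

import torch

def quantize_symmetric_per_token(x: torch.Tensor):
    # Symmetric int8 quantization: the largest |value| in each row maps to
    # 127, so no zero-point is needed.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return x_q, scale  # x_q: int8 [M, K], scale: fp32 [M, 1]

Weights would be quantized the same way along their output channels, yielding the [N, 1] scale_b that the new kernel accepts.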
106 changes: 106 additions & 0 deletions tests/kernels/test_triton_scaled_mm.py
@@ -0,0 +1,106 @@
"""Tests for the triton_scaled_mm kernel
Run `pytest tests/kernels/test_triton_scaled_mm.py`.
"""
import importlib
from typing import Optional, Type

import pytest
import torch

from vllm.platforms import current_platform

device = "cuda"


def scaled_mm_torch(a: torch.Tensor,
                    b: torch.Tensor,
                    scale_a: torch.Tensor,
                    scale_b: torch.Tensor,
                    out_dtype: Type[torch.dtype],
                    bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    out = torch.mm(a.to(torch.float32), b.to(torch.float32))
    out = scale_a * out
    out = scale_b.T * out
    out = out.to(out_dtype)
    if bias is not None:
        out = out + bias

    return out


def get_8bit_types():
    types = [torch.int8]
    supports_fp8 = current_platform.has_device_capability(89)
    if current_platform.is_rocm() and supports_fp8:
        types.append(torch.float8_e4m3fnuz)
    elif current_platform.is_cuda() and supports_fp8:
        types.append(torch.float8_e4m3fn)
    return types


@pytest.mark.parametrize("M", [1, 33, 64, 512])
@pytest.mark.parametrize("N", [256, 971, 20486])
@pytest.mark.parametrize("K", [128, 496, 1024])
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("in_dtype", get_8bit_types())
@pytest.mark.parametrize("use_scalar_scale_a", [True, False])
@pytest.mark.parametrize("use_scalar_scale_b", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
def test_scaled_mm(M, N, K, in_dtype, out_dtype, use_scalar_scale_a,
                   use_scalar_scale_b, use_bias):
    is_floating_point_type = lambda t: torch.tensor([1, 1], dtype=t
                                                    ).is_floating_point()

    current_platform.seed_everything(0)

    # NOTE: If the matrices are large enough, an output value such as 65504.4
    # can be produced, which easily overflows to inf when multiplied in
    # float16/bfloat16. One function (e.g. the function under test) can then
    # produce a finite value while the other (e.g. the golden reference)
    # produces inf, causing assert_close/allclose to fail even though the
    # values would have been "close" had no overflow occurred.
    #
    # So, the values here are kept small enough to avoid this situation.
    if is_floating_point_type(in_dtype):
        a = (0.25 * torch.rand(
            (M, K), dtype=torch.float32, device=device)).to(in_dtype)
        b = (0.25 * torch.rand(
            (K, N), dtype=torch.float32, device=device)).to(in_dtype)
    else:
        a = torch.randint(-32, 32, (M, K), dtype=in_dtype, device=device)
        b = torch.randint(-32, 32, (K, N), dtype=in_dtype, device=device)

    if use_scalar_scale_a:
        scale_a = torch.rand((1, 1), device=device)
    else:
        scale_a = 0.25 * torch.rand((M, 1), device=device)

    if use_scalar_scale_b:
        scale_b = torch.rand((1, 1), device=device)
    else:
        scale_b = 0.25 * torch.rand((N, 1), device=device)

    bias = None
    if use_bias:
        bias = torch.rand((N, ), device=device, dtype=out_dtype)

    triton_scaled_mm_module = importlib.import_module(
        "vllm.model_executor.layers.quantization.compressed_tensors."
        "triton_scaled_mm")
    triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm

    c_check = triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

    a_cpu = a.cpu()
    b_cpu = b.cpu()
    scale_a_cpu = scale_a.cpu()
    scale_b_cpu = scale_b.cpu()
    bias_cpu = None if bias is None else bias.cpu()

    c_actual = scaled_mm_torch(a_cpu, b_cpu, scale_a_cpu, scale_b_cpu,
                               out_dtype, bias_cpu)

    c_check_cpu = c_check.cpu()
    torch.testing.assert_close(c_check_cpu, c_actual, rtol=1e-1, atol=1e-1)
9 changes: 9 additions & 0 deletions vllm/_custom_ops.py
@@ -1,5 +1,6 @@
import contextlib
import functools
import importlib
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import torch
@@ -486,6 +487,14 @@ def cutlass_scaled_mm(a: torch.Tensor,

    m = a.shape[0]
    n = b.shape[1]

    if current_platform.is_rocm():
        triton_scaled_mm_module = importlib.import_module(
            "vllm.model_executor.layers.quantization.compressed_tensors."
            "triton_scaled_mm")
        triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
        return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

    out = torch.empty((m, n), dtype=out_dtype, device=a.device)

    torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
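With this dispatch in place, callers of the public wrapper are unchanged: on ROCm the call is served by the new Triton kernel, elsewhere it still goes through the CUTLASS custom op. A rough, illustrative sketch (tensor names are placeholders; shapes follow the assertions in triton_scaled_mm further down):

from vllm import _custom_ops as ops

# a_q: [M, K] int8, w_q: [K, N] int8, scale_a: [M, 1], scale_b: [N, 1]
out = ops.cutlass_scaled_mm(a_q, w_q, scale_a, scale_b, torch.float16, bias)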
184 changes: 184 additions & 0 deletions vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py
@@ -0,0 +1,184 @@
from typing import Optional, Type

import torch
import triton
import triton.language as tl


def is_weak_contiguous(x: torch.Tensor):
    strides = x.stride()
    sizes = x.shape
    is_not_transpose = strides[0] == 1 and (strides[1] >= max(1, sizes[0]))
    is_transpose = strides[1] == 1 and (strides[0] >= max(1, sizes[1]))
    return is_transpose or is_not_transpose
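
# Illustrative note (hypothetical tensors): a plain row-major tensor and its
# transposed view both satisfy this check, while a view that is strided in
# both dimensions does not, e.g.
#   is_weak_contiguous(torch.empty(4, 6))          -> True   (strides (6, 1))
#   is_weak_contiguous(torch.empty(4, 6).t())      -> True   (strides (1, 6))
#   is_weak_contiguous(torch.empty(4, 6)[:, ::2])  -> False  (strides (6, 2))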


@triton.jit
def scaled_mm_kernel(a_ptr, b_ptr, scale_a_ptr, scale_b_ptr, c_ptr, bias_ptr,
                     M, N, K, stride_am, stride_ak, stride_bk, stride_bn,
                     stride_cm, stride_cn, ACCUMULATOR_DTYPE: tl.constexpr,
                     BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                     BLOCK_SIZE_K: tl.constexpr,
                     BLOCK_SIZE_SCALE_A: tl.constexpr,
                     BLOCK_SIZE_SCALE_B: tl.constexpr):
    pid = tl.program_id(axis=0)

    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n

    accumulator_dtype = ACCUMULATOR_DTYPE
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N),
                           dtype=accumulator_dtype)

    # NOTE: Some tensor inputs are large enough to cause int32 overflow, so it
    # is necessary to use tl.int64 for all the offsets; otherwise a segfault
    # will eventually occur.

    # Offsets and masks.
    offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    masks_am = offsets_am < M

    offsets_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
    masks_bn = offsets_bn < N

    offsets_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
    offsets_a = (stride_am * offsets_am[:, None] +
                 stride_ak * offsets_k[None, :])
    offsets_b = (stride_bk * offsets_k[:, None] +
                 stride_bn * offsets_bn[None, :])

    # NOTE: BLOCK_SIZE_SCALE_A could be 1 or BLOCK_SIZE_M, so we need to
    # create appropriate offsets and masks for each case. Same goes for
    # BLOCK_SIZE_SCALE_B.
    offsets_scale_am = (tl.arange(0, BLOCK_SIZE_SCALE_A) +
                        (BLOCK_SIZE_SCALE_A > 1) * pid_m * BLOCK_SIZE_M)
    masks_scale_am = offsets_scale_am < M

    offsets_scale_bn = (tl.arange(0, BLOCK_SIZE_SCALE_B) +
                        (BLOCK_SIZE_SCALE_B > 1) * pid_n * BLOCK_SIZE_N)
    masks_scale_bn = offsets_scale_bn < N

    a_ptrs = a_ptr + offsets_a
    b_ptrs = b_ptr + offsets_b

    scale_a_ptrs = scale_a_ptr + offsets_scale_am
    scale_b_ptrs = scale_b_ptr + offsets_scale_bn

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        masks_k = offsets_k < K
        masks_a = masks_am[:, None] & masks_k[None, :]
        a = tl.load(a_ptrs, mask=masks_a)

        masks_b = masks_k[:, None] & masks_bn[None, :]
        b = tl.load(b_ptrs, mask=masks_b)

        # Accumulate results.
        accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)

        offsets_k += BLOCK_SIZE_K
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Apply the scales at the end.
    masks_scale_a = masks_scale_am[:, None] & (tl.arange(0, 1) < 1)[:, None]
    scale_a = tl.load(scale_a_ptrs[:, None], masks_scale_a)
    # Need to broadcast to the appropriate size; if scale_a is already
    # (BLOCK_SIZE_M, 1) it will broadcast to its own shape. Same goes for
    # scale_b below.
    scale_a = scale_a.broadcast_to((BLOCK_SIZE_M, 1))
    accumulator = scale_a * accumulator.to(tl.float32)

    masks_scale_b = masks_scale_bn[:, None] & (tl.arange(0, 1) < 1)[None, :]
    scale_b = tl.load(scale_b_ptrs[:, None], masks_scale_b)
    scale_b = scale_b.broadcast_to((BLOCK_SIZE_N, 1))
    accumulator = scale_b.T * accumulator.to(tl.float32)

    # Convert to the output dtype.
    c = accumulator.to(c_ptr.type.element_ty)

    # The bias is already in the output dtype, so add it after the conversion.
    if bias_ptr:
        offsets_bias = offsets_bn
        bias_ptrs = bias_ptr + offsets_bias
        bias_mask = offsets_bias < N
        bias = tl.load(bias_ptrs, bias_mask)
        c += bias

    # Save the output.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
    offs_cm = offs_cm.to(tl.int64)
    offs_cn = offs_cn.to(tl.int64)
    c_ptrs = (c_ptr + stride_cm * offs_cm[:, None] +
              stride_cn * offs_cn[None, :])
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)

    tl.store(c_ptrs, c, mask=c_mask)


# input - [M, K]
# weight - [K, N]
def triton_scaled_mm(input: torch.Tensor,
                     weight: torch.Tensor,
                     scale_a: torch.Tensor,
                     scale_b: torch.Tensor,
                     out_dtype: Type[torch.dtype],
                     bias: Optional[torch.Tensor] = None,
                     block_size_m: int = 32,
                     block_size_n: int = 32,
                     block_size_k: int = 32) -> torch.Tensor:
    M, K = input.shape
    N = weight.shape[1]

    assert N > 0 and K > 0 and M > 0
    assert weight.shape[0] == K
    assert input.dtype == weight.dtype
    assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point()
    assert scale_a.shape == torch.Size([1, 1]) or scale_a.shape == torch.Size(
        [M, 1])
    assert scale_b.shape == torch.Size([1, 1]) or scale_b.shape == torch.Size(
        [N, 1])
    assert out_dtype.is_floating_point
    assert bias is None or bias.is_floating_point()
    assert is_weak_contiguous(input)
    assert is_weak_contiguous(weight)

    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
        N, META['BLOCK_SIZE_N']), )

    result = torch.empty((M, N), dtype=out_dtype, device=input.device)

    has_scalar = lambda x: x.shape[0] == 1 and x.shape[1] == 1

    block_size_sa = 1 if has_scalar(scale_a) else block_size_m
    block_size_sb = 1 if has_scalar(scale_b) else block_size_n

    accumulator_dtype = tl.float32 if input.is_floating_point() else tl.int32

    # A = input, B = weight, C = result
    # A = M x K, B = K x N, C = M x N
    scaled_mm_kernel[grid](input,
                           weight,
                           scale_a,
                           scale_b,
                           result,
                           bias,
                           M,
                           N,
                           K,
                           input.stride(0),
                           input.stride(1),
                           weight.stride(0),
                           weight.stride(1),
                           result.stride(0),
                           result.stride(1),
                           accumulator_dtype,
                           BLOCK_SIZE_M=block_size_m,
                           BLOCK_SIZE_N=block_size_n,
                           BLOCK_SIZE_K=block_size_k,
                           BLOCK_SIZE_SCALE_A=block_size_sa,
                           BLOCK_SIZE_SCALE_B=block_size_sb)

    return result.to(out_dtype)
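For reference, a minimal sketch of calling this entry point directly, mirroring what the new test does (illustrative values; assumes a CUDA or ROCm device is available):

import torch
from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import (
    triton_scaled_mm)

M, K, N = 16, 128, 256
a = torch.randint(-32, 32, (M, K), dtype=torch.int8, device="cuda")
b = torch.randint(-32, 32, (K, N), dtype=torch.int8, device="cuda")
scale_a = torch.rand((M, 1), device="cuda")  # per-token scales, or (1, 1) for a scalar
scale_b = torch.rand((N, 1), device="cuda")  # per-channel scales, or (1, 1) for a scalar
bias = torch.rand((N, ), device="cuda", dtype=torch.float16)

c = triton_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias)  # c: [M, N] float16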
