
Commit: Enable other datatypes
Faraz9877 committed Nov 18, 2024
1 parent 6870093 commit 8d7b0df
Showing 3 changed files with 69 additions and 33 deletions.
96 changes: 66 additions & 30 deletions benchmarks/cutlass_benchmarks/sparse_mm/bench_v1.py
@@ -5,7 +5,7 @@
import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from utils import make_rand_sparse_tensors, to_fp16
from utils import make_rand_sparse_tensors, to_fp16, to_bf16

Check failure (GitHub Actions / ruff): benchmarks/cutlass_benchmarks/sparse_mm/bench_v1.py:8:45: F401 `utils.to_fp16` imported but unused

import vllm._custom_ops as ops

@@ -32,12 +32,22 @@ def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
sub_label: str) -> Iterable[TMeasurement]:
assert dtype == torch.int8
a_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)

# Create tensors
b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
aT = a.t()
bT = b.t()
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)

out = ops.cutlass_scaled_sparse_mm(b_compressed, e, aT, scale_b, scale_a, torch.bfloat16)

Check failure (GitHub Actions / ruff): benchmarks/cutlass_benchmarks/sparse_mm/bench_v1.py:44:81: E501 Line too long (93 > 80)
out_ref = ops.cutlass_scaled_mm(a, bT, scale_a, scale_b, torch.bfloat16)

if not torch.allclose(out.t(), out_ref):
print("Incorrect result")
exit()

timers = []

# pytorch impl - bfloat16
Expand All @@ -55,25 +65,25 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
# cutlass impl: bf16 output
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b,
ops.cutlass_scaled_sparse_mm, b_compressed, e, aT, scale_b, scale_a,

Check failure (GitHub Actions / ruff): benchmarks/cutlass_benchmarks/sparse_mm/bench_v1.py:68:81: E501 Line too long (85 > 80)
torch.bfloat16))

# cutlass with bias: bf16 output
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm_bias",
ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, torch.bfloat16,
ops.cutlass_scaled_sparse_mm, b_compressed, e, aT, scale_b, scale_a, torch.bfloat16,

Check failure (GitHub Actions / ruff): benchmarks/cutlass_benchmarks/sparse_mm/bench_v1.py:74:81: E501 Line too long (101 > 80)
bias))

# cutlass impl: fp16 output
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_fp16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b,
ops.cutlass_scaled_sparse_mm, b_compressed, e, aT, scale_b, scale_a,

Check failure (GitHub Actions / ruff): benchmarks/cutlass_benchmarks/sparse_mm/bench_v1.py:80:81: E501 Line too long (85 > 80)
torch.float16))

# cutlass with bias: fp16 output
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_fp16_scaled_sparse_mm_bias",
ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, torch.float16,
ops.cutlass_scaled_sparse_mm, b_compressed, e, aT, scale_b, scale_a, torch.float16,

Check failure (GitHub Actions / ruff): benchmarks/cutlass_benchmarks/sparse_mm/bench_v1.py:86:81: E501 Line too long (100 > 80)
bias.to(dtype=torch.float16)))

return timers
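
The rewritten int8 setup calls the sparse kernel with the operands swapped, passing the compressed weight first and the transposed activation second, then transposes the output before comparing it with the dense reference; the fp16 and bf16 paths further down follow the same pattern. Below is a minimal dense PyTorch sketch of the identity this relies on, under the assumption that `a` is (m, k) and `b` is (n, k), as the `a @ bT` reference suggests; it is an illustration only, not the CUTLASS kernel.

```python
import torch

# Toy shapes standing in for make_rand_sparse_tensors output (assumed (m, k) and (n, k)).
m, k, n = 16, 64, 32
a = torch.randn(m, k)
b = torch.randn(n, k)

# Swapped call order: B times A^T yields an (n, m) result ...
out_swapped = (b @ a.t()).t()   # ... which transposes back to (m, n), like out.t() above

# Dense reference, mirroring out_ref in the benchmark.
out_ref = a @ b.t()

# (B @ A^T)^T == A @ B^T, up to floating-point rounding.
assert torch.allclose(out_swapped, out_ref, rtol=1e-5, atol=1e-5)
```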
@@ -172,12 +182,28 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
def bench_fp16(dtype: torch.dtype, m: int, k: int, n: int, label: str,
sub_label: str) -> Iterable[TMeasurement]:
assert dtype == torch.float16
a_compressed, e, a, b = make_rand_sparse_tensors(torch.float16, m, n, k)

m, k, n = 1, 128, 256

# Create tensors
b_compressed, e, a, b = make_rand_sparse_tensors(torch.float16, m, n, k)
aT = a.t()
bT = b.t()
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)

out = ops.cutlass_scaled_sparse_mm(b_compressed, e, aT, scale_b, scale_a, torch.bfloat16)
out_ref = to_bf16(a@bT)

if not torch.allclose(out.t(), out_ref, rtol=1e-2, atol=1e-2):
print("Incorrect result")
print(out.t())
print(out_ref)
exit()
else:
print("Correct result")

timers = []

# # pytorch impl w. bf16
@@ -213,38 +239,48 @@ def bench_fp16(dtype: torch.dtype, m: int, k: int, n: int, label: str,
# cutlass impl: bf16 output
timers.append(
bench_fn(label, sub_label, "cutlass_fp16_fp16_bf16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b,
ops.cutlass_scaled_sparse_mm, b_compressed, e, aT, scale_b, scale_a,
torch.bfloat16))

# cutlass impl: fp16 output
timers.append(
bench_fn(label, sub_label, "cutlass_fp16_fp16_fp16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, torch.float16))
ops.cutlass_scaled_sparse_mm, b_compressed, e, aT, scale_b, scale_a, torch.float16))

# cutlass impl: bf16 output, with bias
timers.append(
bench_fn(label, sub_label, "cutlass_fp16_fp16_bf16_scaled_sparse_mm_bias",
ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, torch.bfloat16,
bias))
# # cutlass impl: bf16 output, with bias
# timers.append(
# bench_fn(label, sub_label, "cutlass_fp16_fp16_bf16_scaled_sparse_mm_bias",
# ops.cutlass_scaled_sparse_mm, b_compressed, e, aT, scale_b, scale_a, torch.bfloat16,
# bias))

# cutlass impl: fp16 output, with bias
timers.append(
bench_fn(label, sub_label, "cutlass_fp16_fp16_fp16_scaled_sparse_mm_bias",
ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, torch.float16,
bias.to(dtype=torch.float16)))
# # cutlass impl: fp16 output, with bias
# timers.append(
# bench_fn(label, sub_label, "cutlass_fp16_fp16_fp16_scaled_sparse_mm_bias",
# ops.cutlass_scaled_sparse_mm, b_compressed, e, aT, scale_b, scale_a, torch.float16,
# bias.to(dtype=torch.float16)))

return timers
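
Unlike the int8 and bf16 checks, the fp16 check above relaxes `allclose` to `rtol=1e-2, atol=1e-2`, because the half-precision product and the bf16 reference are rounded along different paths. A tiny sketch of that effect, assuming only a CUDA device (which the benchmark already requires):

```python
import torch

x = torch.randn(1, 128, device="cuda", dtype=torch.float16)
w = torch.randn(128, 256, device="cuda", dtype=torch.float16)

y_fp16 = (x @ w).to(torch.bfloat16)                  # fp16 matmul, cast to bf16
y_ref = (x.float() @ w.float()).to(torch.bfloat16)   # fp32 matmul, cast to bf16

# Bitwise equality is too strict across different accumulation/rounding paths;
# a small rtol/atol absorbs the noise, as in the fp16 check above.
print(torch.equal(y_fp16, y_ref))                            # may be False
print(torch.allclose(y_fp16, y_ref, rtol=1e-2, atol=1e-2))   # typically True
```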


def bench_bf16(dtype: torch.dtype, m: int, k: int, n: int, label: str,
sub_label: str) -> Iterable[TMeasurement]:
assert dtype == torch.bfloat16
a_compressed, e, a, b = make_rand_sparse_tensors(torch.bfloat16, m, n, k)

# Create tensors
b_compressed, e, a, b = make_rand_sparse_tensors(torch.bfloat16, m, n, k)
aT = a.t()
bT = b.t()
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)

out = ops.cutlass_scaled_sparse_mm(b_compressed, e, aT, scale_b, scale_a, torch.bfloat16)
out_ref = to_bf16(a@bT)

if not torch.allclose(out.t(), out_ref):
print("Incorrect result")
exit()

timers = []

# # pytorch impl w. bf16
@@ -280,37 +316,37 @@ def bench_bf16(dtype: torch.dtype, m: int, k: int, n: int, label: str,
# cutlass impl: bf16 output
timers.append(
bench_fn(label, sub_label, "cutlass_bf16_bf16_bf16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b,
ops.cutlass_scaled_sparse_mm, b_compressed, e, aT, scale_b, scale_a,
torch.bfloat16))

# cutlass impl: fp16 output
timers.append(
bench_fn(label, sub_label, "cutlass_bf16_bf16_fp16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, torch.float16))
ops.cutlass_scaled_sparse_mm, b_compressed, e, aT, scale_b, scale_a, torch.float16))

# cutlass impl: bf16 output, with bias
timers.append(
bench_fn(label, sub_label, "cutlass_bf16_bf16_bf16_scaled_sparse_mm_bias",
ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, torch.bfloat16,
ops.cutlass_scaled_sparse_mm, b_compressed, e, aT, scale_b, scale_a, torch.bfloat16,
bias))

# cutlass impl: fp16 output, with bias
timers.append(
bench_fn(label, sub_label, "cutlass_bf16_bf16_fp16_scaled_sparse_mm_bias",
ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, torch.float16,
ops.cutlass_scaled_sparse_mm, b_compressed, e, aT, scale_b, scale_a, torch.float16,
bias.to(dtype=torch.float16)))

return timers


def bench_v1(dtype: torch.dtype, m: int, k: int, n: int, label: str,
sub_label: str) -> Iterable[TMeasurement]:
# if dtype == torch.int8:
# return bench_int8(dtype, m, k, n, label, sub_label)
if dtype == torch.int8:
return bench_int8(dtype, m, k, n, label, sub_label)
if dtype == torch.float8_e4m3fn:
return bench_fp8(dtype, m, k, n, label, sub_label)
# if dtype == torch.float16:
# return bench_fp16(dtype, m, k, n, label, sub_label)
# if dtype == torch.bfloat16:
# return bench_bf16(dtype, m, k, n, label, sub_label)
if dtype == torch.float16:
return bench_fp16(dtype, m, k, n, label, sub_label)
if dtype == torch.bfloat16:
return bench_bf16(dtype, m, k, n, label, sub_label)
raise ValueError("unsupported type")
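
With all four branches of the dispatcher enabled, every supported dtype can be benchmarked from one loop. A minimal, hypothetical driver is sketched below; the real entry point is not part of this diff, and the shapes and labels are made up.

```python
import torch
# Assumes bench_v1 from the file above is in scope (same module or imported).

shapes = [(16, 4096, 4096), (32, 4096, 11008)]   # hypothetical GEMM sizes
dtypes = [torch.int8, torch.float8_e4m3fn, torch.float16, torch.bfloat16]

for dtype in dtypes:
    for m, k, n in shapes:
        timers = bench_v1(dtype, m, k, n,
                          label=f"sparse_mm-{dtype}",
                          sub_label=f"m={m} k={k} n={n}")
        for t in timers:
            print(t)
```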
2 changes: 1 addition & 1 deletion csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
@@ -226,7 +226,7 @@ void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor co
out, a, e, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
else {
else { // a.dtype() == torch::kBFloat16
TORCH_CHECK(a.dtype() == torch::kBFloat16);
TORCH_CHECK(e.dtype() == torch::kUInt8);
TORCH_CHECK(b.dtype() == torch::kBFloat16);
4 changes: 2 additions & 2 deletions vllm/_custom_ops.py
@@ -542,8 +542,8 @@ def cutlass_scaled_sparse_mm(a: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
# assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
assert bias is None or bias.shape[0] == b.shape[
1] and bias.dtype == out_dtype
assert bias is None or bias.shape[0] == a.shape[0] \
and bias.dtype == out_dtype

m = a.shape[0]
n = b.shape[1]
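
The relaxed assertion ties the bias length to `a.shape[0]` (the same quantity the wrapper binds to `m`) instead of to `b.shape[1]`. That matches the benchmark above, where the first operand passed to the kernel is `b_compressed` with n rows and the bias is created with shape `(n,)`. A toy shape check under those assumptions (shape bookkeeping only, not the real kernel):

```python
import torch

m, k, n = 16, 128, 256

# Stand-ins for the benchmark's operands; the actual compressed layout is irrelevant here.
b_compressed = torch.empty(n, k)   # first argument to cutlass_scaled_sparse_mm
aT = torch.empty(k, m)             # second argument
bias = torch.zeros(n)

# Mirrors the updated assertion: the bias must match the first operand's row count,
# which the wrapper binds to m.
assert bias.shape[0] == b_compressed.shape[0]
```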
