Skip to content

Commit

Permalink
Update code
Browse files Browse the repository at this point in the history
  • Loading branch information
Faraz9877 committed Dec 12, 2024
1 parent c7a3a7d commit 94a4945
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 378 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "broadcast_load_epilogue_c3x.hpp"
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"

/*
This file defines custom epilogues for fusing channel scales, token scales,
Expand Down
87 changes: 1 addition & 86 deletions tests/kernels/test_cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Run `pytest tests/kernels/test_cutlass.py`.
"""
from typing import Optional, Tuple, Type
from typing import Optional, Type

import pytest
import torch
Expand Down Expand Up @@ -55,62 +55,6 @@ def rand_int8(shape: tuple, device: str = "cuda"):
return to_int8(torch.rand(shape, device=device) * 255 - 128)


def to_bf16(tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(dtype=torch.bfloat16)


def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(dtype=torch.float16)


def prune_to_2_4(tensor):
# Reshape tensor to [N, 4] where N is number of groups of 4
original_shape = tensor.shape
reshaped = tensor.reshape(-1, 4)

# Get indices of top 2 absolute values in each group of 4
_, indices = torch.topk(torch.abs(reshaped), k=2, dim=1)

# Create binary mask
mask = torch.zeros_like(reshaped)
mask.scatter_(dim=1,
index=indices,
src=torch.ones_like(indices, dtype=mask.dtype))

# Apply mask and reshape back
pruned = reshaped * mask

# Turn all -0.0 to 0.0
pruned[pruned == -0.0] = 0.0

return pruned.reshape(original_shape)


def make_rand_sparse_tensors(
dtype: torch.dtype, m: int, n: int, k: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
a = torch.randn((m, k), device='cuda') * 5
b = torch.randn((n, k), device='cuda').t() * 5

b = prune_to_2_4(b.t()).t()

if dtype == torch.int8:
a, b = to_int8(a), to_int8(b)
elif dtype == torch.float8_e4m3fn:
a, b = to_fp8(a), to_fp8(b)
elif dtype == torch.float16:
a, b = to_fp16(a), to_fp16(b)
elif dtype == torch.bfloat16:
a, b = to_bf16(a), to_bf16(b)
else:
raise ValueError("unsupported dtype")

b_compressed, e = ops.cutlass_compress_entry(b.t())

# Compressed B, Metadata, Original A, B
return b_compressed, e, a, b


def baseline_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
Expand Down Expand Up @@ -459,35 +403,6 @@ def test_cutlass_subset():
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)


# Test working with a subset of A and B for sparse matmul
def test_cutlass_sparse_subset():
big_m = 1024
m, n, k = 512, 512, 512

# Create tensors
b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn,
big_m, n, k)
a = whole_a[0:m, 0:k]
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10

print("in test")

out = ops.cutlass_scaled_sparse_mm(a,
b_comp,
e,
scale_a,
scale_b,
out_dtype=torch.bfloat16)
baseline = baseline_scaled_mm(a,
b,
scale_a,
scale_b,
out_dtype=torch.bfloat16)

torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)


# Test to make sure cuda graphs work
class CutlassLayer(torch.nn.Module):

Expand Down
Loading

0 comments on commit 94a4945

Please sign in to comment.