Minor test and benchmarks updates
ilmarkov committed Nov 15, 2024
1 parent 68512d4 commit 72d6cd3
Showing 3 changed files with 15 additions and 8 deletions.
6 changes: 3 additions & 3 deletions benchmarks/cusparseLt_benchmarks/benchmark_24.py
@@ -24,8 +24,8 @@
 # helpers
 def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
                       k: int) -> Tuple[torch.Tensor, torch.Tensor]:
-    a = get_random_mat(m, k, dtype)
-    b = get_random_mat(n, k, dtype).t()
+    a = get_random_mat(n, k, dtype)
+    b = get_random_mat(m, k, dtype).t()
     return a, b
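For context, a shape-only sketch of the helper as it now reads, assuming (as elsewhere in this benchmark) that get_random_mat(rows, cols, dtype) returns a rows x cols CUDA tensor:

a, b = make_rand_tensors(torch.float16, m=64, n=128, k=256)
assert a.shape == (128, 256)  # first operand is now sized (n, k); it was (m, k) before
assert b.shape == (256, 64)   # second operand is (m, k) transposed to (k, m); it was (n, k).t()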


@@ -213,7 +213,7 @@ def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
     KNs = model_shapes(model, tp_size)
     MKNs = []
     for m in Ms:
-        assert m % 32 == 0, "Batch size has to be a multiple of 32"
+        assert m % 16 == 0, "Batch size has to be a multiple of 16"
         for k, n in KNs:
            if k % 32 or n % 32:
                continue
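The assertion now only requires batch sizes to be multiples of 16, while the inner filter still restricts k and n to multiples of 32. A small standalone sketch of the resulting shape filter (the Ms and KNs values below are hypothetical):

Ms = [16, 48, 96]                  # hypothetical batch sizes; 16 and 48 were rejected before
KNs = [(4096, 4096), (4096, 100)]  # hypothetical (k, n) pairs from model_shapes()
MKNs = []
for m in Ms:
    assert m % 16 == 0, "Batch size has to be a multiple of 16"
    for k, n in KNs:
        if k % 32 or n % 32:       # (4096, 100) is skipped because 100 % 32 != 0
            continue
        MKNs.append((m, k, n))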
15 changes: 10 additions & 5 deletions tests/kernels/test_semi_structured.py
@@ -138,7 +138,10 @@ def test_torch_semi_structured_sparse_dense_T_fp8_matmul():
     # Cached version
     B = torch.full((N, K), .25, device='cuda', dtype=dtype).t()
     C = dense_matmul(A_pruned, B, dtype=dtype).to(torch.float32)
-    C_sparse = semi_structured_sparse_dense_gemm(A, B).to(torch.float32)
+    C_sparse = semi_structured_sparse_dense_gemm(A,
+                                                 B,
+                                                 out_dtype=torch.bfloat16).to(
+                                                     torch.float32)
     torch.testing.assert_close(C, C_sparse, rtol=1e-1, atol=1e-1)

     # Noncached version
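The test now asks the wrapper for a bfloat16 result and only upcasts to float32 for the comparison. A hedged sketch of the dtype plumbing this exercises, using the names from the test; the assumption that out_dtype selects the result dtype is mine, not stated in the diff:

out = semi_structured_sparse_dense_gemm(A, B, out_dtype=torch.bfloat16)
assert out.dtype == torch.bfloat16  # assumption: out_dtype selects the output dtype
torch.testing.assert_close(out.to(torch.float32), C, rtol=1e-1, atol=1e-1)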
@@ -174,8 +177,9 @@ def test_torch_semi_structured_dense_sparse_T_matmul(mnk, dtype):
     not is_semi_structured_supported()
     or not is_quant_method_supported("modelopt"),
     reason="Semi structured fp8 matmul is not supported on this GPU type.")
-def test_torch_semi_structured_dense_sparse_T_fp8_matmul():
-    M, N, K = (32, 64, 32)
+@pytest.mark.parametrize("mnk", MNK)
+def test_torch_semi_structured_dense_sparse_T_fp8_matmul(mnk):
+    M, N, K = mnk
     dtype = torch.float8_e4m3fn
     B_T_pruned = generate_pruned_semi_structured_mat(N, K, dtype=dtype)
     B_T = compress_to_torch_sparse_semi_structured_mat(B_T_pruned)
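This test and the int8 scaled test further down now iterate over the shared MNK shape list instead of a single hard-coded (32, 64, 32). A minimal standalone sketch of the pytest pattern (the MNK values here are hypothetical stand-ins for the module-level list):

import pytest

MNK = [(32, 64, 32), (64, 128, 256), (128, 128, 128)]  # hypothetical shapes

@pytest.mark.parametrize("mnk", MNK)
def test_example(mnk):
    M, N, K = mnk  # pytest runs one test case per (M, N, K) tuple
    assert M > 0 and N > 0 and K > 0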
@@ -290,9 +294,10 @@ def test_torch_semi_structured_dense_sparse_T_fp8_scaled_matmul():
 @pytest.mark.skipif(
     not is_semi_structured_supported(),
     reason="Semi structured matmul is not supported on this GPU type.")
-def test_torch_semi_structured_sparse_dense_t_int8_scaled_matmul():
+@pytest.mark.parametrize("mnk", MNK)
+def test_torch_semi_structured_sparse_dense_t_int8_scaled_matmul(mnk):
     dtype = torch.int8
-    M, N, K = (32, 64, 32)
+    M, N, K = mnk
     A_pruned = generate_pruned_semi_structured_mat(M, K, dtype)
     A = compress_to_torch_sparse_semi_structured_mat(A_pruned)
     B = get_random_mat(N, K, dtype)
2 changes: 2 additions & 0 deletions
Expand Up @@ -234,6 +234,8 @@ def matmul_(a, b, **kwargs):

scale = scale_a * scale_b
if a_packed.dtype == torch.float8_e4m3fn:
if not (per_tensor_activations and per_tensor_weights):
scale = scale[:, None]
result = matmul_(a_packed.packed, b_dense, out_dtype=torch.float32)
result = torch.narrow(result, 1, 0, col)
result = result * scale
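The added branch turns the combined scale into a column vector whenever either operand is not per-tensor quantized, so each row of the fp32 result is scaled by its own per-token factor. A standalone sketch of the broadcasting (shapes and flag values are hypothetical; the real code derives them from the packed operands):

import torch

result = torch.ones(4, 8)        # hypothetical (rows, cols) fp32 GEMM output
scale_a = torch.full((4,), 0.5)  # hypothetical per-token activation scales, one per row
scale_b = torch.tensor(2.0)      # per-tensor weight scale
per_tensor_activations = False
per_tensor_weights = True

scale = scale_a * scale_b        # shape (4,)
if not (per_tensor_activations and per_tensor_weights):
    scale = scale[:, None]       # (4,) -> (4, 1) so it broadcasts across columns
out = result * scale             # each row scaled by its own factor; shape (4, 8)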
Expand Down
