Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update test_matmul.py
Browse files Browse the repository at this point in the history
Jokeren authored Jan 6, 2024
1 parent af92b88 commit bbc43f6
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions test/ops/test_matmul.py
Original file line number Diff line number Diff line change
@@ -39,7 +39,8 @@ def test_segment_matmul_autograd(dtype, device):

@withCUDA
@pytest.mark.parametrize('dtype', [torch.float, torch.bfloat16])
def test_grouped_matmul_autograd(dtype, device):
@pytest.mark.parametrize('transposed', [True, False])
def test_grouped_matmul_autograd(dtype, transposed, device):
if device.type == 'cuda' and dtype == torch.bfloat16:
pytest.skip('CUDA does not support bfloat16')

@@ -48,11 +49,19 @@ def test_grouped_matmul_autograd(dtype, device):
torch.randn(6, 9, device=device, requires_grad=True),
torch.randn(3, 32, device=device, requires_grad=True),
]
others = [
torch.randn(16, 48, device=device, requires_grad=True),
torch.randn(9, 42, device=device, requires_grad=True),
torch.randn(32, 64, device=device, requires_grad=True),
]
if transposed:
others_origin = [
torch.randn(48, 16, device=device, requires_grad=True),
torch.randn(42, 9, device=device, requires_grad=True),
torch.randn(64, 32, device=device, requires_grad=True),
]
others = [other.t() for other in others_origin]
else:
others = [
torch.randn(16, 48, device=device, requires_grad=True),
torch.randn(9, 42, device=device, requires_grad=True),
torch.randn(32, 64, device=device, requires_grad=True),
]

biases = [
torch.randn(48, device=device, requires_grad=True),
@@ -70,4 +79,7 @@ def test_grouped_matmul_autograd(dtype, device):

sum([out.sum() for out in outs]).backward()
for i in range(len(outs)):
assert others[i].grad.size() == others[i].size()
if transposed:
assert others_origin[i].grad.size() == others_origin[i].size()
else:
assert others[i].grad.size() == others[i].size()

0 comments on commit bbc43f6

Please sign in to comment.