Skip to content

Commit

Permalink
add test_reduction_pointwise_epilogue
Browse files Browse the repository at this point in the history
  • Loading branch information
rdspring1 committed Nov 2, 2024
1 parent a56a31b commit 6cde1ca
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions tests/python/test_python_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1244,6 +1244,23 @@ def fusion_func(fd: FusionDefinition):
for torch_dtype in list_of_dtype:
test_dtype(torch_dtype)

def test_reduction_pointwise_epilogue(self):
inputs = [
torch.randn(2, 32, device="cuda", dtype=torch.float32),
torch.randn(2, 128, device="cuda", dtype=torch.float32),
]

def fusion_func(fd: FusionDefinition):
t0 = fd.from_pytorch(inputs[0])
t1 = fd.from_pytorch(inputs[1])
t2 = fd.ops.sum(t0, [-1], True, torch_dtype_to_nvfuser_dtype(torch.float32))
t3 = fd.ops.add(t1, t2)
fd.add_output(t3)

nvf_out1, _ = self.exec_nvfuser(fusion_func, inputs)
eager_out = torch.sum(inputs[0], dim=-1, keepdim=True) + inputs[1]
self.assertEqual(eager_out, nvf_out1[0])

def test_arithmetic_ops(self):
inputs = [
torch.randn(3, 4, 5, device="cuda", dtype=torch.float32),
Expand Down

0 comments on commit 6cde1ca

Please sign in to comment.