diff --git a/tests/python/test_python_frontend.py b/tests/python/test_python_frontend.py index 3a465516dd5..5b11f86d0b6 100644 --- a/tests/python/test_python_frontend.py +++ b/tests/python/test_python_frontend.py @@ -1244,6 +1244,23 @@ def fusion_func(fd: FusionDefinition): for torch_dtype in list_of_dtype: test_dtype(torch_dtype) + def test_reduction_pointwise_epilogue(self): + inputs = [ + torch.randn(2, 32, device="cuda", dtype=torch.float32), + torch.randn(2, 128, device="cuda", dtype=torch.float32), + ] + + def fusion_func(fd: FusionDefinition): + t0 = fd.from_pytorch(inputs[0]) + t1 = fd.from_pytorch(inputs[1]) + t2 = fd.ops.sum(t0, [-1], True, torch_dtype_to_nvfuser_dtype(torch.float32)) + t3 = fd.ops.add(t1, t2) + fd.add_output(t3) + + nvf_out1, _ = self.exec_nvfuser(fusion_func, inputs) + eager_out = torch.sum(inputs[0], dim=-1, keepdim=True) + inputs[1] + self.assertEqual(eager_out, nvf_out1[0]) + def test_arithmetic_ops(self): inputs = [ torch.randn(3, 4, 5, device="cuda", dtype=torch.float32),