Skip to content

Commit

Permalink
Update ifft tests for batching
Browse files Browse the repository at this point in the history
  • Loading branch information
WardBrian committed Oct 25, 2023
1 parent 0867c79 commit a6cac23
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 15 deletions.
5 changes: 3 additions & 2 deletions pytorch_finufft/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ def finuifft_type1(
warnings.warn("finuifft_type1 recieved isign, which will be overwritten to 1")
finufftkwargs["isign"] = 1
res: torch.Tensor = finufft_type1(points, values, output_shape, **finufftkwargs)
res = res / res.numel()
res = res / torch.tensor(output_shape).prod()
return res


Expand All @@ -560,5 +560,6 @@ def finuifft_type2(
warnings.warn("finuifft_type2 recieved isign, which will be overwritten to 1")
finufftkwargs["isign"] = 1
res: torch.Tensor = finufft_type2(points, targets, **finufftkwargs)
res = res / res.numel()
ndim = torch.atleast_2d(points).shape[0]
res = res / torch.tensor(targets.shape[-ndim:]).prod()
return res
26 changes: 13 additions & 13 deletions tests/test_inverses.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@

import pytorch_finufft

Ns = [
5,
10,
15,
100,
]

dims = [1, 2, 3]


def check_t2_ifft_undoes_t1(N: int, dim: int, device: str) -> None:
"""
Expand All @@ -13,7 +22,8 @@ def check_t2_ifft_undoes_t1(N: int, dim: int, device: str) -> None:
g = np.mgrid[slices] * 2 * np.pi / N
points = torch.from_numpy(g.reshape(dim, -1)).to(device)

values = torch.randn(*points[0].shape, dtype=torch.complex128).to(device)
# batched values to test that functionality for these as well
values = torch.randn(3, *points[0].shape, dtype=torch.complex128).to(device)

print("N is " + str(N))
print("dim is " + str(dim))
Expand All @@ -34,17 +44,6 @@ def check_t2_ifft_undoes_t1(N: int, dim: int, device: str) -> None:
np.testing.assert_allclose(values.cpu().numpy(), back.cpu().numpy(), atol=1e-4)


Ns = [
5,
10,
15,
100,
101,
]

dims = [1, 2, 3]


@pytest.mark.parametrize("N", Ns)
@pytest.mark.parametrize("dim", dims)
def test_t2_ifft_undoes_t1_forward_CPU(N, dim):
Expand All @@ -59,7 +58,8 @@ def check_t1_ifft_undoes_t2(N: int, dim: int, device: str) -> None:
g = np.mgrid[slices] * 2 * np.pi / N
points = torch.from_numpy(g.reshape(g.shape[0], -1)).to(device)

targets = torch.randn(*g[0].shape, dtype=torch.complex128).to(device)
# batched targets to test that functionality for these as well
targets = torch.randn(3, *g[0].shape, dtype=torch.complex128).to(device)

print("N is " + str(N))
print("dim is " + str(dim))
Expand Down

0 comments on commit a6cac23

Please sign in to comment.