Skip to content

Commit

Permalink
FIX tighten the forward tests according to observed error wrt array size
Browse files Browse the repository at this point in the history
  • Loading branch information
eickenberg committed Oct 5, 2023
1 parent d8a8e4b commit 83ab0ba
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 30 deletions.
10 changes: 10 additions & 0 deletions tests/test_1d/test_forward_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,16 @@ def test_1d_t1_forward_CPU(values: torch.Tensor) -> None:
) == pytest.approx(0, abs=1e-06)


abs_errors = torch.abs(finufft1D1_out - against_torch)
l_inf_error = abs_errors.max()
l_2_error = torch.sqrt(torch.sum(abs_errors**2))
l_1_error = torch.sum(abs_errors)

assert l_inf_error < 3.5e-3 * N ** .6
assert l_2_error < 7.5e-4 * N ** 1.1
assert l_1_error < 5e-4 * N ** 1.6


@pytest.mark.parametrize("targets", cases)
def test_1d_t2_forward_CPU(targets: torch.Tensor):
"""
Expand Down
39 changes: 15 additions & 24 deletions tests/test_2d/test_forward_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,28 +45,14 @@ def test_2d_t1_forward_CPU(N: int) -> None:

against_torch = torch.fft.fft2(values.reshape(g[0].shape))

assert abs((finufft_out - against_torch).sum()) / (N**3) == pytest.approx(
0, abs=1e-6
)
abs_errors = torch.abs(finufft_out - against_torch)
l_inf_error = abs_errors.max()
l_2_error = torch.sqrt(torch.sum(abs_errors**2))
l_1_error = torch.sum(abs_errors)

values = torch.randn(*x.shape, dtype=torch.complex64)

finufft_out = pytorch_finufft.functional.finufft2D1.apply(
torch.from_numpy(x).to(torch.float32),
torch.from_numpy(y).to(torch.float32),
values,
N,
)

against_torch = torch.fft.fft2(values.reshape(g[0].shape))

# NOTE -- the below tolerance is set to 1e-5 instead of -6 due
# to the occasional failing case that seems to be caused by
# the randomness of the test cases in addition to the expected
# accruation of numerical inaccuracies
assert abs((finufft_out - against_torch).sum()) / (N**3) == pytest.approx(
0, abs=1e-5
)
assert l_inf_error < 5e-5 * N
assert l_2_error < 1e-5 * N ** 2
assert l_1_error < 1e-5 * N ** 3


@pytest.mark.parametrize("N", Ns)
Expand Down Expand Up @@ -109,9 +95,14 @@ def test_2d_t2_forward_CPU(N: int) -> None:

against_torch = torch.fft.ifft2(values)

assert abs((finufft_out - against_torch).sum()) / (N**3) == pytest.approx(
0, abs=1e-6
)
abs_errors = torch.abs(finufft_out - against_torch)
l_inf_error = abs_errors.max()
l_2_error = torch.sqrt(torch.sum(abs_errors**2))
l_1_error = torch.sum(abs_errors)

assert l_inf_error < 1e-5 * N
assert l_2_error < 1e-5 * N ** 2
assert l_1_error < 1e-5 * N ** 3


# @pytest.mark.parametrize("N", Ns)
Expand Down
23 changes: 17 additions & 6 deletions tests/test_3d/test_forward_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,15 @@ def test_3d_t1_forward_CPU(N: int) -> None:

against_torch = torch.fft.fftn(values.reshape(g[0].shape))

assert abs((finufft_out - against_torch).sum()) / (N**4) == pytest.approx(
0, abs=1e-6
)
abs_errors = torch.abs(finufft_out - against_torch)
l_inf_error = abs_errors.max()
l_2_error = torch.sqrt(torch.sum(abs_errors**2))
l_1_error = torch.sum(abs_errors)

assert l_inf_error < 2e-5 * N ** 1.5
assert l_2_error < 1e-5 * N ** 3
assert l_1_error < 1e-5 * N ** 4.5



@pytest.mark.parametrize("N", Ns)
Expand All @@ -69,6 +75,11 @@ def test_3d_t2_forward_CPU(N: int) -> None:

against_torch = torch.fft.ifftn(values)

assert (abs((finufft_out - against_torch).sum())) / (N**4) == pytest.approx(
0, abs=1e-6
)
abs_errors = torch.abs(finufft_out - against_torch)
l_inf_error = abs_errors.max()
l_2_error = torch.sqrt(torch.sum(abs_errors**2))
l_1_error = torch.sum(abs_errors)

assert l_inf_error < 1e-5 * N ** 1.5
assert l_2_error < 1e-5 * N ** 3
assert l_1_error < 1e-5 * N ** 4.5

0 comments on commit 83ab0ba

Please sign in to comment.