Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

updates including tests #6

Merged
merged 5 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11", "3.12"]
platform: [ubuntu-latest, ] # macos-latest, windows-latest]

steps:
Expand Down
12 changes: 8 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,14 @@ name = "torch-fourier-slice"
dynamic = ["version"]
description = "Fourier slice extraction/insertion in PyTorch."
readme = "README.md"
requires-python = ">=3.8"
requires-python = ">=3.10"
license = { text = "BSD-3-Clause" }
authors = [{ name = "Alister Burt", email = "[email protected]" }]
# https://pypi.org/classifiers/
classifiers = [
"Development Status :: 3 - Alpha",
"License :: OSI Approved :: BSD License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
Expand All @@ -40,13 +38,19 @@ dependencies = [
"numpy",
"einops",
"torch_image_lerp",
"torch_grid_utils",
]

# https://peps.python.org/pep-0621/#dependencies-optional-dependencies
# "extras" (e.g. for `pip install .[test]`)
[project.optional-dependencies]
# add dependencies used for testing here
test = ["pytest", "pytest-cov"]
test = [
"pytest",
"pytest-cov",
"torch-fourier-shell-correlation",
"scipy"
]
# add anything else you like to have in your dev environment here
dev = [
"ipython",
Expand Down
2 changes: 1 addition & 1 deletion src/torch_fourier_slice/backproject.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import torch.nn.functional as F
from torch_grid_utils import fftfreq_grid

from .grids import fftfreq_grid
from .slice_insertion import insert_central_slices_rfft_3d


Expand Down
2 changes: 1 addition & 1 deletion src/torch_fourier_slice/grids/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .fftfreq_grid import fftfreq_grid
from .central_slice_fftfreq_grid import central_slice_fftfreq_grid
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import einops
import torch

from .fftfreq_grid import _construct_fftfreq_grid_2d
from torch_grid_utils import fftfreq_grid
from ..dft_utils import rfft_shape, fftshift_2d


Expand All @@ -13,7 +13,7 @@ def central_slice_fftfreq_grid(
) -> torch.Tensor:
# generate 2d grid of DFT sample frequencies, shape (h, w, 2)
h, w = volume_shape[-2:]
grid = _construct_fftfreq_grid_2d(
grid = fftfreq_grid(
image_shape=(h, w),
rfft=rfft,
device=device
Expand Down
158 changes: 0 additions & 158 deletions src/torch_fourier_slice/grids/fftfreq_grid.py

This file was deleted.

4 changes: 2 additions & 2 deletions src/torch_fourier_slice/project.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import torch.nn.functional as F
from torch_grid_utils import fftfreq_grid

from .grids import fftfreq_grid
from .slice_extraction import extract_central_slices_rfft_3d


Expand All @@ -18,7 +18,7 @@ def project_3d_to_2d(
volume: torch.Tensor
`(d, d, d)` volume.
rotation_matrices: torch.Tensor
`(..., 3, 3)` array of rotation matrices for insert of `images`.
`(..., 3, 3)` array of rotation matrices for insertion of `images`.
Rotation matrices left-multiply column vectors containing xyz coordinates.
pad: bool
Whether to pad the volume 2x with zeros to increase sampling rate in the DFT.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from torch_image_lerp import sample_image_3d

from ..dft_utils import fftfreq_to_dft_coordinates
from ..grids.central_slice_grid import central_slice_fftfreq_grid
from ..grids.central_slice_fftfreq_grid import central_slice_fftfreq_grid


def extract_central_slices_rfft_3d(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from torch_image_lerp import insert_into_image_3d

from ..dft_utils import fftfreq_to_dft_coordinates, rfft_shape
from ..grids.central_slice_grid import central_slice_fftfreq_grid
from ..grids.central_slice_fftfreq_grid import central_slice_fftfreq_grid


def insert_central_slices_rfft_3d(
Expand Down
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import torch
from pytest import fixture


@fixture
def cube() -> torch.Tensor:
volume = torch.zeros((32, 32, 32))
volume[8:24, 8:24, 8:24] = 1
volume[16, 16, 16] = 32
return volume
61 changes: 58 additions & 3 deletions tests/test_torch_fourier_slice.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,59 @@
# temporary stub
import pytest
import torch

def test_something():
pass
from torch_fourier_slice import project_3d_to_2d, backproject_2d_to_3d
from torch_fourier_shell_correlation import fsc
from scipy.stats import special_ortho_group


def test_project_3d_to_2d_rotation_center():
# rotation center should be at position of DC in DFT
volume = torch.zeros((32, 32, 32))
volume[16, 16, 16] = 1

# make projections
rotation_matrices = torch.tensor(special_ortho_group.rvs(dim=3, size=100)).float()
projections = project_3d_to_2d(
volume=volume,
rotation_matrices=rotation_matrices,
)

# check max is always at (16, 16), implying point (16, 16) never moves
for image in projections:
max = torch.argmax(image)
i, j = divmod(max.item(), 32)
assert (i, j) == (16, 16)


def test_3d_2d_projection_backprojection_cycle(cube):
# make projections
rotation_matrices = torch.tensor(special_ortho_group.rvs(dim=3, size=1500)).float()
projections = project_3d_to_2d(
volume=cube,
rotation_matrices=rotation_matrices,
)

# reconstruct
volume = backproject_2d_to_3d(
images=projections,
rotation_matrices=rotation_matrices,
)

# calculate FSC between the projections and the reconstructions
_fsc = fsc(cube, volume.float())

assert torch.all(_fsc[-5:] > 0.99) # few low res shells at 0.98...


@pytest.mark.parametrize(
"images, rotation_matrices",
[
(
torch.rand((10, 28, 28)).float(),
torch.tensor(special_ortho_group.rvs(dim=3, size=10)).float()
),
]
)
def test_dtypes_slice_insertion(images, rotation_matrices):
result = backproject_2d_to_3d(images, rotation_matrices)
assert result.dtype == torch.float64
Loading