Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1D line extraction from 2D images #10

Merged
merged 9 commits into from
Nov 16, 2024
2 changes: 1 addition & 1 deletion src/torch_fourier_slice/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
__author__ = "Alister Burt"
__email__ = "[email protected]"

from .project import project_3d_to_2d
from .project import project_3d_to_2d, project_2d_to_1d
from .backproject import backproject_2d_to_3d
from .slice_insertion import insert_central_slices_rfft_3d
from .slice_extraction import extract_central_slices_rfft_3d
12 changes: 10 additions & 2 deletions src/torch_fourier_slice/dft_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ def rfft_shape(input_shape: Sequence[int]) -> Tuple[int, ...]:
return tuple(rfft_shape)


def fftshift_1d(input: torch.Tensor, rfft: bool) -> torch.Tensor:
if rfft is False:
output = torch.fft.fftshift(input, dim=(-1))
else:
output = input
return output


def fftshift_2d(input: torch.Tensor, rfft: bool) -> torch.Tensor:
if rfft is False:
output = torch.fft.fftshift(input, dim=(-2, -1))
Expand Down Expand Up @@ -88,9 +96,9 @@ def dft_center(
fft_center = torch.zeros(size=(len(image_shape),), device=device)
image_shape = torch.as_tensor(image_shape).float()
if rfft is True:
image_shape = torch.tensor(rfft_shape(image_shape))
image_shape = torch.tensor(rfft_shape(image_shape), device=device)
if fftshifted is True:
fft_center = torch.divide(image_shape, 2, rounding_mode='floor')
fft_center = torch.divide(image_shape, 2, rounding_mode="floor")
if rfft is True:
fft_center[-1] = 0
return fft_center.long()
3 changes: 2 additions & 1 deletion src/torch_fourier_slice/grids/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .central_slice_fftfreq_grid import central_slice_fftfreq_grid
from .central_line_fftfreq_grid import central_line_fftfreq_grid
from .central_slice_fftfreq_grid import central_slice_fftfreq_grid
33 changes: 33 additions & 0 deletions src/torch_fourier_slice/grids/central_line_fftfreq_grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import einops
import torch

from ..dft_utils import fftshift_1d, rfft_shape


def central_line_fftfreq_grid(
image_shape: tuple[int, int],
rfft: bool,
fftshift: bool = False,
device: torch.device | None = None,
) -> torch.Tensor:
# generate 1d grid of DFT sample frequencies, shape (w, 1)
w, = image_shape[-1:]
grid = (
torch.fft.rfftfreq(w, device=device)
if rfft
else torch.fft.fftfreq(w, device=device)
)

# get grid of same shape with all zeros, append as third coordinate
if rfft is True:
zeros = torch.zeros(size=rfft_shape((w,)), dtype=grid.dtype, device=device)
else:
zeros = torch.zeros(size=(w,), dtype=grid.dtype, device=device)
central_slice_grid, _ = einops.pack([zeros, grid], pattern="w *") # (w, 2)

# fftshift if requested
if fftshift is True:
central_slice_grid = einops.rearrange(central_slice_grid, "w freq -> freq w")
central_slice_grid = fftshift_1d(central_slice_grid, rfft=rfft)
central_slice_grid = einops.rearrange(central_slice_grid, "freq w -> w freq")
return central_slice_grid
66 changes: 65 additions & 1 deletion src/torch_fourier_slice/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch.nn.functional as F
from torch_grid_utils import fftfreq_grid

from .slice_extraction import extract_central_slices_rfft_3d
from .slice_extraction import extract_central_slices_rfft_2d, extract_central_slices_rfft_3d


def project_3d_to_2d(
Expand Down Expand Up @@ -67,3 +67,67 @@ def project_3d_to_2d(
if pad is True:
projections = projections[..., pad_length:-pad_length, pad_length:-pad_length]
return torch.real(projections)


def project_2d_to_1d(
image: torch.Tensor,
rotation_matrices: torch.Tensor,
pad: bool = True,
fftfreq_max: float | None = None,
) -> torch.Tensor:
"""Project a square image by sampling a central line through its DFT.

Parameters
----------
image: torch.Tensor
`(d, d)` image.
rotation_matrices: torch.Tensor
`(..., 2, 2)` array of rotation matrices for extraction of `lines`.
Rotation matrices left-multiply column vectors containing xy coordinates.
pad: bool
Whether to pad the volume 2x with zeros to increase sampling rate in the DFT.
fftfreq_max: float | None
Maximum frequency (cycles per pixel) included in the projection.

Returns
-------
projections: torch.Tensor
`(..., d)` array of projected lines.
"""
# padding
if pad is True:
pad_length = image.shape[-1] // 2
image = F.pad(image, pad=[pad_length] * 4, mode='constant', value=0)

# premultiply by sinc2
grid = fftfreq_grid(
image_shape=image.shape,
rfft=False,
fftshift=True,
norm=True,
device=image.device
)
image = image * torch.sinc(grid) ** 2

# calculate DFT
dft = torch.fft.fftshift(image, dim=(-2, -1)) # image center to array origin
dft = torch.fft.rfftn(dft, dim=(-2, -1))
dft = torch.fft.fftshift(dft, dim=(-2,)) # actual fftshift of 2D rfft

# make projections by taking central slices
projections = extract_central_slices_rfft_2d(
image_rfft=dft,
image_shape=image.shape,
rotation_matrices=rotation_matrices,
fftfreq_max=fftfreq_max
) # (..., w) rfft stack

# transform back to real space
# not needed for 1D: torch.fft.ifftshift(projections, dim=(-2,))
projections = torch.fft.irfftn(projections, dim=(-1))
projections = torch.fft.ifftshift(projections, dim=(-1)) # recenter line in real space

# unpad if required
if pad is True:
projections = projections[..., pad_length:-pad_length]
return torch.real(projections)
1 change: 1 addition & 0 deletions src/torch_fourier_slice/slice_extraction/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from ._extract_central_slices_rfft_2d import extract_central_slices_rfft_2d
from ._extract_central_slices_rfft_3d import extract_central_slices_rfft_3d
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import einops
import torch
from torch_image_lerp import sample_image_2d

from ..dft_utils import fftfreq_to_dft_coordinates
from ..grids.central_line_fftfreq_grid import central_line_fftfreq_grid


def extract_central_slices_rfft_2d(
image_rfft: torch.Tensor,
image_shape: tuple[int, int],
rotation_matrices: torch.Tensor, # (..., 2, 2)
fftfreq_max: float | None = None,
) -> torch.Tensor:
"""Extract central slice from an fftshifted rfft."""
# generate grid of DFT sample frequencies for a central slice spanning the x-plane
freq_grid = central_line_fftfreq_grid(
image_shape=image_shape,
rfft=True,
fftshift=True,
device=image_rfft.device,
) # (w, 2) yx coords

# keep track of some shapes
stack_shape = tuple(rotation_matrices.shape[:-2])
rfft_shape = (freq_grid.shape[-2],)
output_shape = (*stack_shape, *rfft_shape)

# get (b, 2, 1) array of yx coordinates to rotate
if fftfreq_max is not None:
freq_grid_mask = freq_grid <= fftfreq_max
valid_coords = freq_grid[freq_grid_mask, ...]
else:
valid_coords = freq_grid
valid_coords = einops.rearrange(valid_coords, "b yx -> b yx 1")

# rotation matrices rotate xyz coordinates, make them rotate zyx coordinates
# xyz:
# [a b c] [x] [ax + by + cz]
# [d e f] [y] = [dx + ey + fz]
# [g h i] [z] [gx + hy + iz]
#
# zyx:
# [i h g] [z] [gx + hy + iz]
# [f e d] [y] = [dx + ey + fz]
# [c b a] [x] [ax + by + cz]
rotation_matrices = torch.flip(rotation_matrices, dims=(-2, -1))

# add extra dim to rotation matrices for broadcasting
rotation_matrices = einops.rearrange(rotation_matrices, "... i j -> ... 1 i j")

# rotate all valid coordinates by each rotation matrix
rotated_coords = rotation_matrices @ valid_coords # (..., b, yx, 1)

# remove last dim of size 1
rotated_coords = einops.rearrange(rotated_coords, "... b yx 1 -> ... b yx")

# flip coordinates that ended up in redundant half transform after rotation
conjugate_mask = rotated_coords[..., 1] < 0
rotated_coords[conjugate_mask, ...] *= -1

# convert frequencies to array coordinates in fftshifted DFT
rotated_coords = fftfreq_to_dft_coordinates(
frequencies=rotated_coords, image_shape=image_shape, rfft=True
) # (...) rfft
samples = sample_image_2d(image=image_rfft, coordinates=rotated_coords)

# take complex conjugate of values from redundant half transform
samples[conjugate_mask] = torch.conj(samples[conjugate_mask])

# insert samples back into DFTs
projection_image_dfts = torch.zeros(
output_shape, device=image_rfft.device, dtype=image_rfft.dtype
)
if fftfreq_max is None:
freq_grid_mask = torch.ones(
size=rfft_shape, dtype=torch.bool, device=image_rfft.device
)

projection_image_dfts[..., freq_grid_mask] = samples

return projection_image_dfts
20 changes: 19 additions & 1 deletion tests/test_torch_fourier_slice.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import torch

from torch_fourier_slice import project_3d_to_2d, backproject_2d_to_3d
from torch_fourier_slice import project_3d_to_2d, project_2d_to_1d, backproject_2d_to_3d
from torch_fourier_shell_correlation import fsc
from scipy.stats import special_ortho_group

Expand All @@ -25,6 +25,24 @@ def test_project_3d_to_2d_rotation_center():
assert (i, j) == (16, 16)


def test_project_2d_to_1d_rotation_center():
# rotation center should be at position of DC in DFT
image = torch.zeros((32, 32))
image[16, 16] = 1

# make projections
rotation_matrices = torch.tensor(special_ortho_group.rvs(dim=2, size=100)).float()
projections = project_2d_to_1d(
image=image,
rotation_matrices=rotation_matrices,
)

# check max is always at (16), implying point (16) never moves
for image in projections:
i = torch.argmax(image)
assert i == 16


def test_3d_2d_projection_backprojection_cycle(cube):
# make projections
rotation_matrices = torch.tensor(special_ortho_group.rvs(dim=3, size=1500)).float()
Expand Down