Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rsg devel #54

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ ENV/

# IDE settings
.vscode/
.idea/

libtilt/_version.py
src/libtilt/_version.py
21 changes: 12 additions & 9 deletions src/libtilt/ctf/ctf_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def calculate_ctf(
image_shape: Tuple[int, int],
rfft: bool,
fftshift: bool,
device: torch.device | None = None
):
"""

Expand Down Expand Up @@ -56,20 +57,22 @@ def calculate_ctf(
Whether to apply fftshift on the resulting CTF images.
"""
# to torch.Tensor and unit conversions
defocus = torch.atleast_1d(torch.as_tensor(defocus, dtype=torch.float))
if bool(rfft) + bool(fftshift) > 1:
raise ValueError("Only one of `rfft` and `fftshift` may be `True`.")
defocus = torch.atleast_1d(torch.as_tensor(defocus, dtype=torch.float, device=device))
defocus *= 1e4 # micrometers -> angstroms
astigmatism = torch.atleast_1d(torch.as_tensor(astigmatism, dtype=torch.float))
astigmatism = torch.atleast_1d(torch.as_tensor(astigmatism, dtype=torch.float, device=device))
astigmatism *= 1e4 # micrometers -> angstroms
astigmatism_angle = torch.atleast_1d(torch.as_tensor(astigmatism_angle, dtype=torch.float))
astigmatism_angle = torch.atleast_1d(torch.as_tensor(astigmatism_angle, dtype=torch.float, device=device))
astigmatism_angle *= (C.pi / 180) # degrees -> radians
pixel_size = torch.atleast_1d(torch.as_tensor(pixel_size))
voltage = torch.atleast_1d(torch.as_tensor(voltage, dtype=torch.float))
pixel_size = torch.atleast_1d(torch.as_tensor(pixel_size, device=device))
voltage = torch.atleast_1d(torch.as_tensor(voltage, dtype=torch.float, device=device))
voltage *= 1e3 # kV -> V
spherical_aberration = torch.atleast_1d(
torch.as_tensor(spherical_aberration, dtype=torch.float)
torch.as_tensor(spherical_aberration, dtype=torch.float, device=device)
)
spherical_aberration *= 1e7 # mm -> angstroms
image_shape = torch.as_tensor(image_shape)
image_shape = torch.as_tensor(image_shape, device=device)

# derived quantities used in CTF calculation
defocus_u = defocus + astigmatism
Expand All @@ -79,10 +82,10 @@ def calculate_ctf(
k2 = C.pi / 2 * spherical_aberration * _lambda ** 3
k3 = torch.tensor(np.deg2rad(phase_shift))
k4 = -b_factor / 4
k5 = np.arctan(amplitude_contrast / np.sqrt(1 - amplitude_contrast ** 2))
k5 = torch.arctan(amplitude_contrast / torch.sqrt(1 - amplitude_contrast ** 2))

# construct 2D frequency grids and rescale cycles / px -> cycles / Å
fftfreq_grid = _construct_fftfreq_grid_2d(image_shape=image_shape, rfft=rfft) # (h, w, 2)
fftfreq_grid = _construct_fftfreq_grid_2d(image_shape=image_shape, rfft=rfft, device=device) # (h, w, 2)
fftfreq_grid = fftfreq_grid / einops.rearrange(pixel_size, 'b -> b 1 1 1')
fftfreq_grid_squared = fftfreq_grid ** 2

Expand Down
11 changes: 7 additions & 4 deletions src/libtilt/fft_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ def dft_center(
device: torch.device | None = None,
) -> torch.LongTensor:
"""Return the position of the DFT center for a given input shape."""
_rfft_shape = rfft_shape(image_shape)
fft_center = torch.zeros(size=(len(image_shape),), device=device)
image_shape = torch.as_tensor(image_shape).float()
if rfft is True:
image_shape = torch.tensor(rfft_shape(image_shape))
image_shape = torch.tensor(_rfft_shape, device=device)
if fftshifted is True:
fft_center = torch.divide(image_shape, 2, rounding_mode='floor')
if rfft is True:
Expand Down Expand Up @@ -438,17 +439,19 @@ def fftfreq_to_dft_coordinates(
coordinates: torch.Tensor
`(..., d)` array of coordinates into a fftshifted DFT.
"""
_image_shape = image_shape
image_shape = torch.as_tensor(
image_shape, device=frequencies.device, dtype=frequencies.dtype
_image_shape, device=frequencies.device, dtype=frequencies.dtype
)
_rfft_shape = rfft_shape(_image_shape)
_rfft_shape = torch.as_tensor(
rfft_shape(image_shape), device=frequencies.device, dtype=frequencies.dtype
_rfft_shape, device=frequencies.device, dtype=frequencies.dtype
)
coordinates = torch.empty_like(frequencies)
coordinates[..., :-1] = frequencies[..., :-1] * image_shape[:-1]
if rfft is True:
coordinates[..., -1] = frequencies[..., -1] * 2 * (_rfft_shape[-1] - 1)
else:
coordinates[..., -1] = frequencies[..., -1] * image_shape[-1]
dc = dft_center(image_shape, rfft=rfft, fftshifted=True, device=frequencies.device)
dc = dft_center(_image_shape, rfft=rfft, fftshifted=True, device=frequencies.device)
return coordinates + dc
3 changes: 2 additions & 1 deletion src/libtilt/interpolation/interpolate_dft_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ def sample_dft_3d(
samples = torch.view_as_complex(samples.contiguous()) # (b, )

# pack data back up and return
[samples] = einops.unpack(samples, pattern='*', packed_shapes=ps)
# [samples] = einops.unpack(samples, pattern='*', packed_shapes=ps)
samples = samples.reshape(*ps) # replaces commented line above, for performance
return samples # (...)


Expand Down
73 changes: 52 additions & 21 deletions src/libtilt/projection/project_fourier.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Tuple

import torch
import torch.nn.functional as F
import einops
Expand Down Expand Up @@ -33,30 +35,12 @@ def project_fourier(
projections: torch.Tensor
`(..., d, d)` array of projection images.
"""
# padding
if pad is True:
pad_length = volume.shape[-1] // 2
volume = F.pad(volume, pad=[pad_length] * 6, mode='constant', value=0)

# premultiply by sinc2
grid = fftfreq_grid(
image_shape=volume.shape,
rfft=False,
fftshift=True,
norm=True,
device=volume.device
)
volume = volume * torch.sinc(grid) ** 2

# calculate DFT
dft = torch.fft.fftshift(volume, dim=(-3, -2, -1)) # volume center to array origin
dft = torch.fft.rfftn(dft, dim=(-3, -2, -1))
dft = torch.fft.fftshift(dft, dim=(-3, -2,)) # actual fftshift of rfft
dft, vol_shape, pad_length = _compute_dft(volume, pad)

# make projections by taking central slices
projections = extract_central_slices_rfft(
dft=dft,
image_shape=volume.shape,
image_shape=vol_shape,
rotation_matrices=rotation_matrices,
rotation_matrix_zyx=rotation_matrix_zyx
) # (..., h, w) rfft
Expand Down Expand Up @@ -92,7 +76,8 @@ def extract_central_slices_rfft(

# flip coordinates in redundant half transform
conjugate_mask = grid[..., 2] < 0
conjugate_mask = einops.repeat(conjugate_mask, '... -> ... 3')
# conjugate_mask = einops.repeat(conjugate_mask, '... -> ... 3') #This operation does not compile
conjugate_mask = conjugate_mask.unsqueeze(-1).expand(*[-1] * len(conjugate_mask.shape), 3) #This does
grid[conjugate_mask] *= -1
conjugate_mask = conjugate_mask[..., 0] # un-repeat

Expand All @@ -107,3 +92,49 @@ def extract_central_slices_rfft(
# take complex conjugate of values from redundant half transform
projections[conjugate_mask] = torch.conj(projections[conjugate_mask])
return projections

def _compute_dft(
volume: torch.Tensor,
pad: bool = True,
pad_length: int | None = None
) -> Tuple[torch.Tensor, Tuple[int,int,int], int]:
"""Computes the DFT of a volume. Intended to be used as a preprocessing before using extract_central_slices_rfft.

Parameters
----------
volume: torch.Tensor
`(d, d, d)` volume.
pad: bool
Whether to pad the volume with zeros to increase sampling in the DFT.
pad_length: int | None
The length used for padding each side of each dimension. If pad_length=None, and pad=True then volume.shape[-1] // 2 is used instead

Returns
-------
projections: Tuple[torch.Tensor, torch.Tensor, int]
`(..., d, d, d)` dft of the volume. fftshifted rfft
Tuple[int,int,int] the shape of the volume after padding
int with the padding length
"""
# padding
if pad is True:
if pad_length is None:
pad_length = volume.shape[-1] // 2
volume = F.pad(volume, pad=[pad_length] * 6, mode='constant', value=0)

# premultiply by sinc2
grid = fftfreq_grid(
image_shape=volume.shape,
rfft=False,
fftshift=True,
norm=True,
device=volume.device
)
volume = volume * torch.sinc(grid) ** 2

# calculate DFT
dft = torch.fft.fftshift(volume, dim=(-3, -2, -1)) # volume center to array origin
dft = torch.fft.rfftn(dft, dim=(-3, -2, -1))
dft = torch.fft.fftshift(dft, dim=(-3, -2,)) # actual fftshift of rfft

return dft, volume.shape, pad_length
4 changes: 2 additions & 2 deletions src/libtilt/projection/project_real.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def project_real(volume: torch.Tensor, rotation_matrices: torch.Tensor) -> torch
torch_padding = einops.rearrange(torch_padding, 'whd pad -> (whd pad)')
volume = F.pad(volume, pad=tuple(torch_padding), mode='constant', value=0)
padded_volume_shape = (ps, ps, ps)
volume_coordinates = coordinate_grid(image_shape=padded_volume_shape)
volume_coordinates = coordinate_grid(image_shape=padded_volume_shape, device=volume.device)
volume_coordinates -= padded_sidelength // 2 # (d, h, w, zyx)
volume_coordinates = torch.flip(volume_coordinates, dims=(-1,)) # (d, h, w, zyx)
volume_coordinates = einops.rearrange(volume_coordinates, 'd h w zyx -> d h w zyx 1')
Expand All @@ -73,5 +73,5 @@ def _project_volume(rotation_matrix) -> torch.Tensor:

yl, yh = padding[1, 0], -padding[1, 1]
xl, xh = padding[2, 0], -padding[2, 1]
images = [_project_volume(matrix)[yl:yh, xl:xh] for matrix in rotation_matrices]
images = [_project_volume(matrix)[yl:yh, xl:xh] for matrix in rotation_matrices] #TODO: This can probabaly optimized using vmap
alisterburt marked this conversation as resolved.
Show resolved Hide resolved
return torch.stack(images, dim=0)