Rsg devel #54 (Open)

wants to merge 21 commits into base: main

Changes from 6 commits
20 changes: 11 additions & 9 deletions src/libtilt/ctf/ctf_2d.py
@@ -22,6 +22,7 @@ def calculate_ctf(
     image_shape: Tuple[int, int],
     rfft: bool,
     fftshift: bool,
+    device
 ):
     """

@@ -56,20 +57,21 @@ def calculate_ctf(
         Whether to apply fftshift on the resulting CTF images.
     """
     # to torch.Tensor and unit conversions
-    defocus = torch.atleast_1d(torch.as_tensor(defocus, dtype=torch.float))
+    assert bool(rfft) + bool(fftshift) <= 1, "Error, only one of `rfft` and `fftshift` may be `True`."
+    defocus = torch.atleast_1d(torch.as_tensor(defocus, dtype=torch.float, device=device))
     defocus *= 1e4  # micrometers -> angstroms
-    astigmatism = torch.atleast_1d(torch.as_tensor(astigmatism, dtype=torch.float))
+    astigmatism = torch.atleast_1d(torch.as_tensor(astigmatism, dtype=torch.float, device=device))
     astigmatism *= 1e4  # micrometers -> angstroms
-    astigmatism_angle = torch.atleast_1d(torch.as_tensor(astigmatism_angle, dtype=torch.float))
+    astigmatism_angle = torch.atleast_1d(torch.as_tensor(astigmatism_angle, dtype=torch.float, device=device))
     astigmatism_angle *= (C.pi / 180)  # degrees -> radians
-    pixel_size = torch.atleast_1d(torch.as_tensor(pixel_size))
-    voltage = torch.atleast_1d(torch.as_tensor(voltage, dtype=torch.float))
+    pixel_size = torch.atleast_1d(torch.as_tensor(pixel_size, device=device))
+    voltage = torch.atleast_1d(torch.as_tensor(voltage, dtype=torch.float, device=device))
     voltage *= 1e3  # kV -> V
     spherical_aberration = torch.atleast_1d(
-        torch.as_tensor(spherical_aberration, dtype=torch.float)
+        torch.as_tensor(spherical_aberration, dtype=torch.float, device=device)
     )
     spherical_aberration *= 1e7  # mm -> angstroms
-    image_shape = torch.as_tensor(image_shape)
+    image_shape = torch.as_tensor(image_shape, device=device)

     # derived quantities used in CTF calculation
     defocus_u = defocus + astigmatism
@@ -79,10 +81,10 @@ def calculate_ctf(
     k2 = C.pi / 2 * spherical_aberration * _lambda ** 3
     k3 = torch.tensor(np.deg2rad(phase_shift))
     k4 = -b_factor / 4
-    k5 = np.arctan(amplitude_contrast / np.sqrt(1 - amplitude_contrast ** 2))
+    k5 = torch.arctan(amplitude_contrast / torch.sqrt(1 - amplitude_contrast ** 2))

     # construct 2D frequency grids and rescale cycles / px -> cycles / Å
-    fftfreq_grid = _construct_fftfreq_grid_2d(image_shape=image_shape, rfft=rfft)  # (h, w, 2)
+    fftfreq_grid = _construct_fftfreq_grid_2d(image_shape=image_shape, rfft=rfft, device=device)  # (h, w, 2)
     fftfreq_grid = fftfreq_grid / einops.rearrange(pixel_size, 'b -> b 1 1 1')
     fftfreq_grid_squared = fftfreq_grid ** 2
11 changes: 7 additions & 4 deletions src/libtilt/fft_utils.py
@@ -22,10 +22,11 @@ def dft_center(
     device: torch.device | None = None,
 ) -> torch.LongTensor:
     """Return the position of the DFT center for a given input shape."""
+    _rfft_shape = rfft_shape(image_shape)
     fft_center = torch.zeros(size=(len(image_shape),), device=device)
     image_shape = torch.as_tensor(image_shape).float()
     if rfft is True:
-        image_shape = torch.tensor(rfft_shape(image_shape))
+        image_shape = torch.tensor(_rfft_shape, device=device)
     if fftshifted is True:
         fft_center = torch.divide(image_shape, 2, rounding_mode='floor')
     if rfft is True:
@@ -438,17 +439,19 @@ def fftfreq_to_dft_coordinates(
     coordinates: torch.Tensor
         `(..., d)` array of coordinates into a fftshifted DFT.
     """
+    _image_shape = image_shape
     image_shape = torch.as_tensor(
-        image_shape, device=frequencies.device, dtype=frequencies.dtype
+        _image_shape, device=frequencies.device, dtype=frequencies.dtype
     )
+    _rfft_shape = rfft_shape(_image_shape)
     _rfft_shape = torch.as_tensor(
-        rfft_shape(image_shape), device=frequencies.device, dtype=frequencies.dtype
+        _rfft_shape, device=frequencies.device, dtype=frequencies.dtype
     )
     coordinates = torch.empty_like(frequencies)
     coordinates[..., :-1] = frequencies[..., :-1] * image_shape[:-1]
     if rfft is True:
         coordinates[..., -1] = frequencies[..., -1] * 2 * (_rfft_shape[-1] - 1)
     else:
         coordinates[..., -1] = frequencies[..., -1] * image_shape[-1]
-    dc = dft_center(image_shape, rfft=rfft, fftshifted=True, device=frequencies.device)
+    dc = dft_center(_image_shape, rfft=rfft, fftshifted=True, device=frequencies.device)
     return coordinates + dc
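A small sanity check of this mapping (a sketch with made-up values, assuming the signature `fftfreq_to_dft_coordinates(frequencies, image_shape, rfft)`): frequency (0, 0) should land on the fftshifted DC position, `floor(n / 2)` in each dimension.

```python
import torch

freqs = torch.zeros(1, 2)  # a single 2-D frequency at the origin
coords = fftfreq_to_dft_coordinates(freqs, image_shape=(8, 8), rfft=False)
print(coords)  # tensor([[4., 4.]]) - the DC index of an fftshifted 8x8 DFT
```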
2 changes: 1 addition & 1 deletion src/libtilt/grids/central_slice_grid.py
@@ -43,7 +43,7 @@ def rotated_central_slice_grid(
         device=device,
     )  # (h, w, 3)
     if rotation_matrix_zyx is False:
-        grid = torch.flip(grid, dims=(-1,))
+        grid = torch.flip(grid, dims=(-1,))  # TODO: This operation is slow since it is copying the full tensor
     rotation_matrices = einops.rearrange(rotation_matrices, '... i j -> ... 1 1 i j')
     grid = einops.rearrange(grid, 'h w coords -> h w coords 1')
     grid = rotation_matrices @ grid
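An aside on that TODO (my understanding, not something stated in the PR): `torch.flip` always copies because PyTorch tensors don't support negative strides. Since `R @ flip(v) == flip_columns(R) @ v`, one alternative would be to reverse the columns of the small `(..., 3, 3)` rotation matrices once instead of the last axis of the full `(h, w, 3)` grid, assuming downstream code can accept the grid in its original order:

```python
import torch

# hypothetical: flip the tiny rotation matrices' columns instead of the big grid
rotation_matrices = torch.flip(rotation_matrices, dims=(-1,))
```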
3 changes: 2 additions & 1 deletion src/libtilt/interpolation/interpolate_dft_3d.py
@@ -49,7 +49,8 @@ def sample_dft_3d(
     samples = torch.view_as_complex(samples.contiguous())  # (b, )

     # pack data back up and return
-    [samples] = einops.unpack(samples, pattern='*', packed_shapes=ps)
+    # [samples] = einops.unpack(samples, pattern='*', packed_shapes=ps)
+    samples = samples.reshape(*ps)  # ask Alister if this will work in every situation
Member: Interesting - is this a significant performance gain? I don't think this works in the general case, but it should always work here as far as I can tell.

Member: Ohhh I see - torch.compile couldn't go through unpack?

Collaborator (author): Exactly. This is a requirement for the compiler.
     return samples  # (...)


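As an aside, a minimal sketch of why `reshape(*ps)` matches `einops.unpack` in this single-tensor case (illustrative shapes, not taken from the PR):

```python
import einops
import torch

x = torch.randn(2, 3, 4)
packed, ps = einops.pack([x], '* c')  # packed: (6, 4); ps holds the batch shape (2, 3)
flat = packed.sum(dim=-1)             # (6,) stand-in for per-element results

# unpack restores the packed batch shape...
[a] = einops.unpack(flat, pattern='*', packed_shapes=ps)
# ...which, with exactly one packed tensor, is a plain reshape
b = flat.reshape(*ps)
assert torch.equal(a, b) and a.shape == (2, 3)
```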
73 changes: 52 additions & 21 deletions src/libtilt/projection/project_fourier.py
@@ -1,3 +1,5 @@
+from typing import Tuple
+
 import torch
 import torch.nn.functional as F
 import einops
@@ -33,30 +35,12 @@ def project_fourier(
     projections: torch.Tensor
         `(..., d, d)` array of projection images.
     """
-    # padding
-    if pad is True:
-        pad_length = volume.shape[-1] // 2
-        volume = F.pad(volume, pad=[pad_length] * 6, mode='constant', value=0)
-
-    # premultiply by sinc2
-    grid = fftfreq_grid(
-        image_shape=volume.shape,
-        rfft=False,
-        fftshift=True,
-        norm=True,
-        device=volume.device
-    )
-    volume = volume * torch.sinc(grid) ** 2
-
-    # calculate DFT
-    dft = torch.fft.fftshift(volume, dim=(-3, -2, -1))  # volume center to array origin
-    dft = torch.fft.rfftn(dft, dim=(-3, -2, -1))
-    dft = torch.fft.fftshift(dft, dim=(-3, -2,))  # actual fftshift of rfft
+    dft, vol_shape, pad_length = compute_vol_dtf(volume, pad)

     # make projections by taking central slices
     projections = extract_central_slices_rfft(
         dft=dft,
-        image_shape=volume.shape,
+        image_shape=vol_shape,
         rotation_matrices=rotation_matrices,
         rotation_matrix_zyx=rotation_matrix_zyx
     )  # (..., h, w) rfft
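For background, this relies on the Fourier slice theorem (a standard result, not specific to this PR): the 2-D transform of a projection along $z$ is a central slice of the 3-D transform,

$$\mathcal{F}_2\Big[\int f(R\mathbf{r})\,\mathrm{d}z\Big](k_x, k_y) = \mathcal{F}_3[f]\big(R\,(k_x, k_y, 0)^\top\big),$$

which is why sampling rotated central slices of the volume's rfft yields projection images.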
@@ -92,7 +76,8 @@ def extract_central_slices_rfft(

     # flip coordinates in redundant half transform
     conjugate_mask = grid[..., 2] < 0
-    conjugate_mask = einops.repeat(conjugate_mask, '... -> ... 3')
+    # conjugate_mask = einops.repeat(conjugate_mask, '... -> ... 3')
+    conjugate_mask = conjugate_mask.unsqueeze(-1).repeat(1, 1, 1, 3)
Member: repeat should be creating a view here and shouldn't be memory intensive even though the tensor is huge - is this not the case?

Collaborator (author): If I'm correct, it is expand, and not repeat, that is a view. This change was again a requirement for the compiler.

Member: I think you're right for the torch API, but in einops it creates a view where possible - regardless, compilation is super important.

I'm a little hesitant to lose the rank polymorphism here, and it looks like this unsqueeze/repeat is specific to `b h w 3` rather than `... h w 3` -> could you try adding some code to interpret the current shape and unsqueeze/repeat according to that? This should allow us to maintain the current flexibility and have compatibility with the compiler.

Collaborator (author): This should work: `conjugate_mask.unsqueeze(-1).expand(*[-1] * len(conjugate_mask.shape), 3)`, and it is more memory efficient, since it is a view.

     grid[conjugate_mask] *= -1
     conjugate_mask = conjugate_mask[..., 0]  # un-repeat
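For completeness, a tiny check of the rank-polymorphic expression proposed above (illustrative shapes; `expand` returns a stride-0 view, so the large mask is never copied):

```python
import torch

conjugate_mask = torch.rand(5, 7, 9) < 0.5  # stand-in (..., h, w) boolean mask
expanded = conjugate_mask.unsqueeze(-1).expand(*[-1] * len(conjugate_mask.shape), 3)
assert expanded.shape == (5, 7, 9, 3)
assert expanded.data_ptr() == conjugate_mask.data_ptr()  # a view, not a copy
```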

@@ -107,3 +92,49 @@
     # take complex conjugate of values from redundant half transform
     projections[conjugate_mask] = torch.conj(projections[conjugate_mask])
     return projections
+
+def compute_vol_dtf(  # TODO: Is this the best place to have this?
+    volume: torch.Tensor,
+    pad: bool = True,
+    pad_length: int | None = None
+) -> Tuple[torch.Tensor, Tuple[int, int, int], int]:
+    """Compute the fftshifted rfft of a cubic volume for central slice extraction.
Member: this docstring needs fixing

+
+    Parameters
+    ----------
+    volume: torch.Tensor
+        `(d, d, d)` volume.
+    pad: bool
+        Whether to pad the volume with zeros to increase sampling in the DFT.
+    pad_length: int | None
+        The length used for padding. If None, volume.shape[-1] // 2 is used instead.
+
+    Returns
+    -------
+    Tuple[torch.Tensor, Tuple[int, int, int], int]
+        `(d, d, d)` fftshifted rfft of the volume,
+        the shape of the volume after padding,
+        and the padding length.
+    """
+    # padding
+    if pad is True:
+        if pad_length is None:
+            pad_length = volume.shape[-1] // 2
+        volume = F.pad(volume, pad=[pad_length] * 6, mode='constant', value=0)
+
+    # premultiply by sinc2
+    grid = fftfreq_grid(
+        image_shape=volume.shape,
+        rfft=False,
+        fftshift=True,
+        norm=True,
+        device=volume.device
+    )
+    volume = volume * torch.sinc(grid) ** 2
+
+    # calculate DFT
+    dft = torch.fft.fftshift(volume, dim=(-3, -2, -1))  # volume center to array origin
+    dft = torch.fft.rfftn(dft, dim=(-3, -2, -1))
+    dft = torch.fft.fftshift(dft, dim=(-3, -2,))  # actual fftshift of rfft
+
+    return dft, volume.shape, pad_length
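A quick usage sketch for the new helper (shapes worked out by hand under the defaults above; not taken from the PR):

```python
import torch

volume = torch.rand(32, 32, 32)
dft, padded_shape, pad_length = compute_vol_dtf(volume, pad=True)

# pad_length defaults to d // 2 = 16; padding all six faces gives (64, 64, 64)
assert pad_length == 16 and tuple(padded_shape) == (64, 64, 64)
# rfftn keeps roughly half of the last dim: 64 // 2 + 1 = 33
assert dft.shape == (64, 64, 33)
```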
4 changes: 2 additions & 2 deletions src/libtilt/projection/project_real.py
@@ -49,7 +49,7 @@ def project_real(volume: torch.Tensor, rotation_matrices: torch.Tensor) -> torch.Tensor:
     torch_padding = einops.rearrange(torch_padding, 'whd pad -> (whd pad)')
     volume = F.pad(volume, pad=tuple(torch_padding), mode='constant', value=0)
     padded_volume_shape = (ps, ps, ps)
-    volume_coordinates = coordinate_grid(image_shape=padded_volume_shape)
+    volume_coordinates = coordinate_grid(image_shape=padded_volume_shape, device=volume.device)
     volume_coordinates -= padded_sidelength // 2  # (d, h, w, zyx)
     volume_coordinates = torch.flip(volume_coordinates, dims=(-1,))  # (d, h, w, zyx)
     volume_coordinates = einops.rearrange(volume_coordinates, 'd h w zyx -> d h w zyx 1')
@@ -73,5 +73,5 @@ def _project_volume(rotation_matrix) -> torch.Tensor:

     yl, yh = padding[1, 0], -padding[1, 1]
     xl, xh = padding[2, 0], -padding[2, 1]
-    images = [_project_volume(matrix)[yl:yh, xl:xh] for matrix in rotation_matrices]
+    images = [_project_volume(matrix)[yl:yh, xl:xh] for matrix in rotation_matrices]  # TODO: This can probably be optimized using vmap
     return torch.stack(images, dim=0)
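A hedged sketch of the vmap idea from that TODO (hypothetical: whether the sampling inside `_project_volume` composes with `torch.vmap` depends on the torch version, so this is a candidate, not the adopted fix):

```python
import torch

# batch _project_volume over rotation matrices, then crop the padding once
batched_project = torch.vmap(_project_volume)
images = batched_project(rotation_matrices)[:, yl:yh, xl:xh]
```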
2 changes: 1 addition & 1 deletion src/libtilt/shapes/soft_edge.py
@@ -8,7 +8,7 @@ def _add_soft_edge_single_binary_image(
 ) -> torch.FloatTensor:
     if smoothing_radius == 0:
         return image.float()
-    distances = ndi.distance_transform_edt(torch.logical_not(image))
+    distances = ndi.distance_transform_edt(torch.logical_not(image))  # TODO: This breaks if the input device is cuda
     distances = torch.as_tensor(distances, device=image.device).float()
     idx = torch.logical_and(distances > 0, distances <= smoothing_radius)
     output = torch.clone(image).float()
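One possible workaround for that TODO (a sketch, not the fix adopted in this PR; the helper name is made up): scipy's `distance_transform_edt` only accepts host arrays, so round-trip through the CPU explicitly and rebuild the tensor on the original device:

```python
import torch
from scipy import ndimage as ndi

def _distances_to_foreground(image: torch.Tensor) -> torch.Tensor:
    # EDT of the background runs on CPU; the result returns to image.device
    mask = torch.logical_not(image).cpu().numpy()
    distances = ndi.distance_transform_edt(mask)
    return torch.as_tensor(distances, device=image.device).float()
```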