Skip to content

Commit

Permalink
Merge branch 'dev' into current-preprocessing-pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
DSilva27 committed Sep 3, 2024
2 parents cb4b812 + cef000d commit 1b594a0
Show file tree
Hide file tree
Showing 6 changed files with 348 additions and 42 deletions.
33 changes: 17 additions & 16 deletions src/cryo_challenge/_preprocessing/align_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,21 +134,22 @@ def align_submission(
--------
volumes (torch.Tensor): aligned submission volumes
"""
obj_vol = volumes[0].numpy().astype(np.float32)

obj_vol = Volume(obj_vol / obj_vol.sum())
ref_vol = Volume(ref_volume / ref_volume.sum())

_, R_est = align_BO(
ref_vol,
obj_vol,
loss_type=params["BOT_loss"],
downsampled_size=params["BOT_box_size"],
max_iters=params["BOT_iter"],
refine=params["BOT_refine"],
)
R_est = Rotation(R_est.astype(np.float32))

volumes = torch.from_numpy(Volume(volumes.numpy()).rotate(R_est)._data)
for i in range(len(volumes)):
obj_vol = volumes[i].numpy().astype(np.float32).copy()

obj_vol = Volume(obj_vol / obj_vol.sum())
ref_vol = Volume(ref_volume.copy() / ref_volume.sum())

_, R_est = align_BO(
ref_vol,
obj_vol,
loss_type=params["BOT_loss"],
downsampled_size=params["BOT_box_size"],
max_iters=params["BOT_iter"],
refine=params["BOT_refine"],
)
R_est = Rotation(R_est.astype(np.float32))

volumes[i] = torch.from_numpy(Volume(volumes[i].numpy()).rotate(R_est)._data)

return volumes
89 changes: 89 additions & 0 deletions src/cryo_challenge/_preprocessing/bfactor_normalize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import torch
from ..power_spectrum_utils import _centered_fftn, _centered_ifftn


def _compute_bfactor_scaling(b_factor, box_size, voxel_size):
"""
Compute the B-factor scaling factor for a given B-factor, box size, and voxel size.
The B-factor scaling factor is computed as exp(-B * s^2 / 4), where s is the squared
distance in Fourier space.
Parameters
----------
b_factor: float
B-factor to apply.
box_size: int
Size of the box.
voxel_size: float
Voxel size of the box.
Returns
-------
bfactor_scaling_torch: torch.tensor(shape=(box_size, box_size, box_size))
B-factor scaling factor.
"""
x = torch.fft.fftshift(torch.fft.fftfreq(box_size, d=voxel_size))
y = x.clone()
z = x.clone()
x, y, z = torch.meshgrid(x, y, z, indexing="ij")

s2 = x**2 + y**2 + z**2
bfactor_scaling_torch = torch.exp(-b_factor * s2 / 4)

return bfactor_scaling_torch


def bfactor_normalize_volumes(volumes, bfactor, voxel_size, in_place=False):
"""
Normalize volumes by applying a B-factor correction. This is done by multiplying
a centered Fourier transform of the volume by the B-factor scaling factor and then
applying the inverse Fourier transform. See _compute_bfactor_scaling for details on the
computation of the B-factor scaling.
Parameters
----------
volumes: torch.tensor
Volumes to normalize. The volumes must have shape (N, N, N) or (n_volumes, N, N, N).
bfactor: float
B-factor to apply.
voxel_size: float
Voxel size of the volumes.
in_place: bool - default: False
Whether to normalize the volumes in place.
Returns
-------
volumes: torch.tensor
Normalized volumes.
"""
# assert that volumes have the correct shape
assert volumes.ndim in [
3,
4,
], "Input volumes must have shape (N, N, N) or (n_volumes, N, N, N)"

if volumes.ndim == 3:
assert (
volumes.shape[0] == volumes.shape[1] == volumes.shape[2]
), "Input volumes must have equal dimensions"
else:
assert (
volumes.shape[1] == volumes.shape[2] == volumes.shape[3]
), "Input volumes must have equal dimensions"

if not in_place:
volumes = volumes.clone()

b_factor_scaling = _compute_bfactor_scaling(bfactor, volumes.shape[-1], voxel_size)

if len(volumes.shape) == 3:
volumes = _centered_fftn(volumes, dim=(0, 1, 2))
volumes = volumes * b_factor_scaling
volumes = _centered_ifftn(volumes, dim=(0, 1, 2)).real

elif len(volumes.shape) == 4:
volumes = _centered_fftn(volumes, dim=(1, 2, 3))
volumes = volumes * b_factor_scaling[None, ...]
volumes = _centered_ifftn(volumes, dim=(1, 2, 3)).real

return volumes
22 changes: 0 additions & 22 deletions src/cryo_challenge/_preprocessing/normalize.py

This file was deleted.

153 changes: 153 additions & 0 deletions src/cryo_challenge/power_spectrum_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import torch


def _cart2sph(x, y, z):
"""
Converts a grid in cartesian coordinates to spherical coordinates.
Parameters
----------
x: torch.tensor
x-coordinate of the grid.
y: torch.tensor
y-coordinate of the grid.
z: torch.tensor
"""
hxy = torch.hypot(x, y)
r = torch.hypot(hxy, z)
el = torch.atan2(z, hxy)
az = torch.atan2(y, x)
return az, el, r


def _grid_3d(n, dtype=torch.float32):
"""
Generates a centered 3D grid. The grid is given in both cartesian and spherical coordinates.
Parameters
----------
n: int
Size of the grid.
dtype: torch.dtype
Data type of the grid.
Returns
-------
grid: dict
Dictionary containing the grid in cartesian and spherical coordinates.
keys: x, y, z, phi, theta, r
"""
start = -n // 2 + 1
end = n // 2

if n % 2 == 0:
start -= 1 / 2
end -= 1 / 2

grid = torch.linspace(start, end, n, dtype=dtype)
z, x, y = torch.meshgrid(grid, grid, grid, indexing="ij")

phi, theta, r = _cart2sph(x, y, z)

theta = torch.pi / 2 - theta

return {"x": x, "y": y, "z": z, "phi": phi, "theta": theta, "r": r}


def _centered_fftn(x, dim=None):
"""
Wrapper around torch.fft.fftn that centers the Fourier transform.
"""
x = torch.fft.fftn(x, dim=dim)
x = torch.fft.fftshift(x, dim=dim)
return x


def _centered_ifftn(x, dim=None):
"""
Wrapper around torch.fft.ifftn that centers the inverse Fourier transform.
"""
x = torch.fft.fftshift(x, dim=dim)
x = torch.fft.ifftn(x, dim=dim)
return x


def _average_over_single_shell(shell_index, volume, radii, shell_width=0.5):
"""
Given a volume in Fourier space, compute the average value of the volume over a shell.
Parameters
----------
shell_index: int
Index of the shell in Fourier space.
volume: torch.tensor
Volume in Fourier space.
radii: torch.tensor
Radii of the Fourier space grid.
shell_width: float
Width of the shell.
Returns
-------
average: float
Average value of the volume over the shell.
"""
inner_diameter = shell_width + shell_index
outer_diameter = shell_width + (shell_index + 1)
mask = (radii > inner_diameter) & (radii < outer_diameter)
return torch.sum(mask * volume) / torch.sum(mask)


def _average_over_shells(volume_in_fourier_space, shell_width=0.5):
"""
Vmap wrapper over _average_over_single_shell to compute the average value of a volume in Fourier space over all shells. The input should be a volumetric quantity in Fourier space.
Parameters
----------
volume_in_fourier_space: torch.tensor
Volume in Fourier space.
Returns
-------
radial_average: torch.tensor
Average value of the volume over all shells.
"""
L = volume_in_fourier_space.shape[0]
dtype = torch.float32
radii = _grid_3d(L, dtype=dtype)["r"]

radial_average = torch.vmap(
_average_over_single_shell, in_dims=(0, None, None, None)
)(torch.arange(0, L // 2), volume_in_fourier_space, radii, shell_width)

return radial_average


def compute_power_spectrum(volume, shell_width=0.5):
"""
Compute the power spectrum of a volume.
Parameters
----------
volume: torch.tensor
Volume for which to compute the power spectrum.
shell_width: float
Width of the shell.
Returns
-------
power_spectrum: torch.tensor
Power spectrum of the volume.
Examples
--------
volume = mrcfile.open("volume.mrc").data.copy()
volume = torch.tensor(volume, dtype=torch.float32)
power_spectrum = compute_power_spectrum(volume)
"""

# Compute centered Fourier transforms.
vol_fft = torch.abs(_centered_fftn(volume)) ** 2
power_spectrum = _average_over_shells(vol_fft, shell_width=shell_width)

return power_spectrum
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
},
"0": {
"name": "raw_submission_in_testdata",
"align": 1,
"flavor_name": "test flavor",
"path": "tests/data/unprocessed_dataset_2_submissions/submission_x",
"populations_file": "tests/data/unprocessed_dataset_2_submissions/submission_x/populations.txt",
"submission_version": "1.0",
"box_size": 32,
"pixel_size": 15.022,
"path": "tests/data/unprocessed_dataset_2_submissions/submission_x",
"flip": 1,
"populations_file": "tests/data/unprocessed_dataset_2_submissions/submission_x/populations.txt",
"submission_version": "1.0"
"align": 1
}
}
Loading

0 comments on commit 1b594a0

Please sign in to comment.