-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'dev' into current-preprocessing-pipeline
- Loading branch information
Showing
6 changed files
with
348 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
import torch | ||
from ..power_spectrum_utils import _centered_fftn, _centered_ifftn | ||
|
||
|
||
def _compute_bfactor_scaling(b_factor, box_size, voxel_size): | ||
""" | ||
Compute the B-factor scaling factor for a given B-factor, box size, and voxel size. | ||
The B-factor scaling factor is computed as exp(-B * s^2 / 4), where s is the squared | ||
distance in Fourier space. | ||
Parameters | ||
---------- | ||
b_factor: float | ||
B-factor to apply. | ||
box_size: int | ||
Size of the box. | ||
voxel_size: float | ||
Voxel size of the box. | ||
Returns | ||
------- | ||
bfactor_scaling_torch: torch.tensor(shape=(box_size, box_size, box_size)) | ||
B-factor scaling factor. | ||
""" | ||
x = torch.fft.fftshift(torch.fft.fftfreq(box_size, d=voxel_size)) | ||
y = x.clone() | ||
z = x.clone() | ||
x, y, z = torch.meshgrid(x, y, z, indexing="ij") | ||
|
||
s2 = x**2 + y**2 + z**2 | ||
bfactor_scaling_torch = torch.exp(-b_factor * s2 / 4) | ||
|
||
return bfactor_scaling_torch | ||
|
||
|
||
def bfactor_normalize_volumes(volumes, bfactor, voxel_size, in_place=False): | ||
""" | ||
Normalize volumes by applying a B-factor correction. This is done by multiplying | ||
a centered Fourier transform of the volume by the B-factor scaling factor and then | ||
applying the inverse Fourier transform. See _compute_bfactor_scaling for details on the | ||
computation of the B-factor scaling. | ||
Parameters | ||
---------- | ||
volumes: torch.tensor | ||
Volumes to normalize. The volumes must have shape (N, N, N) or (n_volumes, N, N, N). | ||
bfactor: float | ||
B-factor to apply. | ||
voxel_size: float | ||
Voxel size of the volumes. | ||
in_place: bool - default: False | ||
Whether to normalize the volumes in place. | ||
Returns | ||
------- | ||
volumes: torch.tensor | ||
Normalized volumes. | ||
""" | ||
# assert that volumes have the correct shape | ||
assert volumes.ndim in [ | ||
3, | ||
4, | ||
], "Input volumes must have shape (N, N, N) or (n_volumes, N, N, N)" | ||
|
||
if volumes.ndim == 3: | ||
assert ( | ||
volumes.shape[0] == volumes.shape[1] == volumes.shape[2] | ||
), "Input volumes must have equal dimensions" | ||
else: | ||
assert ( | ||
volumes.shape[1] == volumes.shape[2] == volumes.shape[3] | ||
), "Input volumes must have equal dimensions" | ||
|
||
if not in_place: | ||
volumes = volumes.clone() | ||
|
||
b_factor_scaling = _compute_bfactor_scaling(bfactor, volumes.shape[-1], voxel_size) | ||
|
||
if len(volumes.shape) == 3: | ||
volumes = _centered_fftn(volumes, dim=(0, 1, 2)) | ||
volumes = volumes * b_factor_scaling | ||
volumes = _centered_ifftn(volumes, dim=(0, 1, 2)).real | ||
|
||
elif len(volumes.shape) == 4: | ||
volumes = _centered_fftn(volumes, dim=(1, 2, 3)) | ||
volumes = volumes * b_factor_scaling[None, ...] | ||
volumes = _centered_ifftn(volumes, dim=(1, 2, 3)).real | ||
|
||
return volumes |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
import torch | ||
|
||
|
||
def _cart2sph(x, y, z): | ||
""" | ||
Converts a grid in cartesian coordinates to spherical coordinates. | ||
Parameters | ||
---------- | ||
x: torch.tensor | ||
x-coordinate of the grid. | ||
y: torch.tensor | ||
y-coordinate of the grid. | ||
z: torch.tensor | ||
""" | ||
hxy = torch.hypot(x, y) | ||
r = torch.hypot(hxy, z) | ||
el = torch.atan2(z, hxy) | ||
az = torch.atan2(y, x) | ||
return az, el, r | ||
|
||
|
||
def _grid_3d(n, dtype=torch.float32): | ||
""" | ||
Generates a centered 3D grid. The grid is given in both cartesian and spherical coordinates. | ||
Parameters | ||
---------- | ||
n: int | ||
Size of the grid. | ||
dtype: torch.dtype | ||
Data type of the grid. | ||
Returns | ||
------- | ||
grid: dict | ||
Dictionary containing the grid in cartesian and spherical coordinates. | ||
keys: x, y, z, phi, theta, r | ||
""" | ||
start = -n // 2 + 1 | ||
end = n // 2 | ||
|
||
if n % 2 == 0: | ||
start -= 1 / 2 | ||
end -= 1 / 2 | ||
|
||
grid = torch.linspace(start, end, n, dtype=dtype) | ||
z, x, y = torch.meshgrid(grid, grid, grid, indexing="ij") | ||
|
||
phi, theta, r = _cart2sph(x, y, z) | ||
|
||
theta = torch.pi / 2 - theta | ||
|
||
return {"x": x, "y": y, "z": z, "phi": phi, "theta": theta, "r": r} | ||
|
||
|
||
def _centered_fftn(x, dim=None): | ||
""" | ||
Wrapper around torch.fft.fftn that centers the Fourier transform. | ||
""" | ||
x = torch.fft.fftn(x, dim=dim) | ||
x = torch.fft.fftshift(x, dim=dim) | ||
return x | ||
|
||
|
||
def _centered_ifftn(x, dim=None): | ||
""" | ||
Wrapper around torch.fft.ifftn that centers the inverse Fourier transform. | ||
""" | ||
x = torch.fft.fftshift(x, dim=dim) | ||
x = torch.fft.ifftn(x, dim=dim) | ||
return x | ||
|
||
|
||
def _average_over_single_shell(shell_index, volume, radii, shell_width=0.5): | ||
""" | ||
Given a volume in Fourier space, compute the average value of the volume over a shell. | ||
Parameters | ||
---------- | ||
shell_index: int | ||
Index of the shell in Fourier space. | ||
volume: torch.tensor | ||
Volume in Fourier space. | ||
radii: torch.tensor | ||
Radii of the Fourier space grid. | ||
shell_width: float | ||
Width of the shell. | ||
Returns | ||
------- | ||
average: float | ||
Average value of the volume over the shell. | ||
""" | ||
inner_diameter = shell_width + shell_index | ||
outer_diameter = shell_width + (shell_index + 1) | ||
mask = (radii > inner_diameter) & (radii < outer_diameter) | ||
return torch.sum(mask * volume) / torch.sum(mask) | ||
|
||
|
||
def _average_over_shells(volume_in_fourier_space, shell_width=0.5): | ||
""" | ||
Vmap wrapper over _average_over_single_shell to compute the average value of a volume in Fourier space over all shells. The input should be a volumetric quantity in Fourier space. | ||
Parameters | ||
---------- | ||
volume_in_fourier_space: torch.tensor | ||
Volume in Fourier space. | ||
Returns | ||
------- | ||
radial_average: torch.tensor | ||
Average value of the volume over all shells. | ||
""" | ||
L = volume_in_fourier_space.shape[0] | ||
dtype = torch.float32 | ||
radii = _grid_3d(L, dtype=dtype)["r"] | ||
|
||
radial_average = torch.vmap( | ||
_average_over_single_shell, in_dims=(0, None, None, None) | ||
)(torch.arange(0, L // 2), volume_in_fourier_space, radii, shell_width) | ||
|
||
return radial_average | ||
|
||
|
||
def compute_power_spectrum(volume, shell_width=0.5): | ||
""" | ||
Compute the power spectrum of a volume. | ||
Parameters | ||
---------- | ||
volume: torch.tensor | ||
Volume for which to compute the power spectrum. | ||
shell_width: float | ||
Width of the shell. | ||
Returns | ||
------- | ||
power_spectrum: torch.tensor | ||
Power spectrum of the volume. | ||
Examples | ||
-------- | ||
volume = mrcfile.open("volume.mrc").data.copy() | ||
volume = torch.tensor(volume, dtype=torch.float32) | ||
power_spectrum = compute_power_spectrum(volume) | ||
""" | ||
|
||
# Compute centered Fourier transforms. | ||
vol_fft = torch.abs(_centered_fftn(volume)) ** 2 | ||
power_spectrum = _average_over_shells(vol_fft, shell_width=shell_width) | ||
|
||
return power_spectrum |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.