-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
114 additions
and
37 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import torch | ||
from ..power_spectrum_utils import _centered_fftn, _centered_ifftn | ||
|
||
|
||
def _compute_bfactor_scaling(b_factor, box_size, voxel_size): | ||
x = torch.fft.fftshift(torch.fft.fftfreq(box_size, d=voxel_size)) | ||
y = x.clone() | ||
z = x.clone() | ||
x, y, z = torch.meshgrid(x, y, z, indexing="ij") | ||
|
||
s2 = x**2 + y**2 + z**2 | ||
bfactor_scaling_torch = torch.exp(-b_factor * s2 / 4) | ||
|
||
return bfactor_scaling_torch | ||
|
||
|
||
def bfactor_normalize_volumes(volumes, bfactor, voxel_size, in_place=True): | ||
if not in_place: | ||
volumes = volumes.clone() | ||
|
||
b_factor_scaling = _compute_bfactor_scaling(bfactor, volumes.shape[-1], voxel_size) | ||
|
||
if len(volumes.shape) == 3: | ||
volumes = _centered_fftn(volumes, dim=(0, 1, 2)) | ||
volumes = volumes * b_factor_scaling | ||
volumes = _centered_ifftn(volumes, dim=(0, 1, 2)).real | ||
|
||
elif len(volumes.shape) == 4: | ||
volumes = _centered_fftn(volumes, dim=(1, 2, 3)) | ||
volumes = volumes * b_factor_scaling[None, ...] | ||
volumes = _centered_ifftn(volumes, dim=(1, 2, 3)).real | ||
|
||
else: | ||
raise ValueError("Input volumes must have 3 or 4 dimensions.") | ||
|
||
return volumes |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import torch | ||
from cryo_challenge.power_spectrum_utils import _centered_ifftn, compute_power_spectrum | ||
from cryo_challenge._preprocessing.bfactor_normalize import ( | ||
_compute_bfactor_scaling, | ||
bfactor_normalize_volumes, | ||
) | ||
|
||
|
||
def test_compute_power_spectrum(): | ||
box_size = 224 | ||
volume_shape = (box_size, box_size, box_size) | ||
voxel_size = 1.073 * 2 | ||
|
||
freq = torch.fft.fftshift(torch.fft.fftfreq(box_size, d=voxel_size)) | ||
x = freq.clone() | ||
y = freq.clone() | ||
z = freq.clone() | ||
x, y, z = torch.meshgrid(x, y, z, indexing="ij") | ||
|
||
s2 = x**2 + y**2 + z**2 | ||
|
||
b_factor = 170 | ||
|
||
gaussian_volume = torch.exp(-b_factor / 4 * s2).reshape(volume_shape) | ||
gaussian_volume = _centered_ifftn(gaussian_volume) | ||
|
||
power_spectrum = compute_power_spectrum(gaussian_volume) | ||
power_spectrum_slice = ( | ||
torch.abs(torch.fft.fftn(gaussian_volume)[: box_size // 2, 0, 0]) ** 2 | ||
) | ||
|
||
mean_squared_error = torch.mean((power_spectrum - power_spectrum_slice) ** 2) | ||
|
||
assert mean_squared_error < 1e-3 | ||
|
||
return | ||
|
||
|
||
def test_bfactor_normalize_volumes(): | ||
box_size = 128 | ||
volume_shape = (box_size, box_size, box_size) | ||
voxel_size = 1.5 | ||
|
||
freq = torch.fft.fftshift(torch.fft.fftfreq(box_size, d=voxel_size)) | ||
x = freq.clone() | ||
y = freq.clone() | ||
z = freq.clone() | ||
x, y, z = torch.meshgrid(x, y, z, indexing="ij") | ||
|
||
s2 = x**2 + y**2 + z**2 | ||
|
||
oscillatory_volume = torch.sin(300 * s2).reshape(volume_shape) | ||
oscillatory_volume = _centered_ifftn(oscillatory_volume) | ||
bfactor_scaling_vol = _compute_bfactor_scaling(170, box_size, voxel_size) | ||
|
||
norm_oscillatory_vol = bfactor_normalize_volumes( | ||
oscillatory_volume, 170, voxel_size, in_place=False | ||
) | ||
|
||
ps_osci = torch.fft.fftn(oscillatory_volume, dim=(-3, -2, -1), norm="backward")[ | ||
: box_size // 2, 0, 0 | ||
] | ||
ps_norm_osci = torch.fft.fftn( | ||
norm_oscillatory_vol, dim=(-3, -2, -1), norm="backward" | ||
)[: box_size // 2, 0, 0] | ||
ps_bfactor_scaling = torch.fft.fftshift(bfactor_scaling_vol)[: box_size // 2, 0, 0] | ||
|
||
ps_osci = torch.abs(ps_osci) ** 2 | ||
ps_norm_osci = torch.abs(ps_norm_osci) ** 2 | ||
ps_bfactor_scaling = torch.abs(ps_bfactor_scaling) ** 2 | ||
|
||
assert torch.allclose(ps_norm_osci, ps_osci * ps_bfactor_scaling) |