Skip to content

Commit

Permalink
merge issue14
Browse files Browse the repository at this point in the history
  • Loading branch information
DSilva27 committed Aug 12, 2024
2 parents c8a8e91 + a06b96d commit 072c1ae
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 37 deletions.
36 changes: 36 additions & 0 deletions src/cryo_challenge/_preprocessing/bfactor_normalize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import torch
from ..power_spectrum_utils import _centered_fftn, _centered_ifftn


def _compute_bfactor_scaling(b_factor, box_size, voxel_size):
x = torch.fft.fftshift(torch.fft.fftfreq(box_size, d=voxel_size))
y = x.clone()
z = x.clone()
x, y, z = torch.meshgrid(x, y, z, indexing="ij")

s2 = x**2 + y**2 + z**2
bfactor_scaling_torch = torch.exp(-b_factor * s2 / 4)

return bfactor_scaling_torch


def bfactor_normalize_volumes(volumes, bfactor, voxel_size, in_place=True):
if not in_place:
volumes = volumes.clone()

b_factor_scaling = _compute_bfactor_scaling(bfactor, volumes.shape[-1], voxel_size)

if len(volumes.shape) == 3:
volumes = _centered_fftn(volumes, dim=(0, 1, 2))
volumes = volumes * b_factor_scaling
volumes = _centered_ifftn(volumes, dim=(0, 1, 2)).real

elif len(volumes.shape) == 4:
volumes = _centered_fftn(volumes, dim=(1, 2, 3))
volumes = volumes * b_factor_scaling[None, ...]
volumes = _centered_ifftn(volumes, dim=(1, 2, 3)).real

else:
raise ValueError("Input volumes must have 3 or 4 dimensions.")

return volumes
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
"""
Power spectrum normalization and required utility functions
"""

import torch


Expand Down Expand Up @@ -58,7 +54,7 @@ def _compute_power_spectrum_shell(index, volume, radii, shell_width=0.5):
inner_diameter = shell_width + index
outer_diameter = shell_width + (index + 1)
mask = (radii > inner_diameter) & (radii < outer_diameter)
return torch.norm(mask * volume) ** 2
return torch.sum(mask * volume) / torch.sum(mask)


def compute_power_spectrum(volume, shell_width=0.5):
Expand All @@ -67,36 +63,9 @@ def compute_power_spectrum(volume, shell_width=0.5):
radii = _grid_3d(L, dtype=dtype)["r"]

# Compute centered Fourier transforms.
vol_fft = _centered_fftn(volume)
vol_fft = torch.abs(_centered_fftn(volume)) ** 2

power_spectrum = torch.vmap(
_compute_power_spectrum_shell, in_dims=(0, None, None, None)
)(torch.arange(0, L // 2), vol_fft, radii, shell_width)
return power_spectrum


def normalize_power_spectrum(volumes, power_spectrum_ref, shell_width=0.5):
L = volumes.shape[-1]
dtype = torch.float32
radii = _grid_3d(L, dtype=dtype)["r"]

# Compute centered Fourier transforms.
vols_fft = _centered_fftn(volumes, dim=(1, 2, 3))

inner_diameter = shell_width
for i in range(0, L // 2):
# Compute ring mask
outer_diameter = shell_width + (i + 1)
ring_mask = (radii > inner_diameter) & (radii < outer_diameter)

power_spectrum_sqrt = torch.norm(vols_fft[:, ring_mask], dim=1)
vols_fft[:, ring_mask] = (
vols_fft[:, ring_mask]
/ (power_spectrum_sqrt[:, None] + 1e-12)
* torch.sqrt(power_spectrum_ref[i])
)

# # Update ring
inner_diameter = outer_diameter

return _centered_ifftn(vols_fft, dim=(1, 2, 3)).real
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
},
"0": {
"name": "raw_submission_in_testdata",
"align": 1,
"flavor_name": "test flavor",
"path": "tests/data/unprocessed_dataset_2_submissions/submission_x",
"populations_file": "tests/data/unprocessed_dataset_2_submissions/submission_x/populations.txt",
"submission_version": "1.0",
"box_size": 32,
"pixel_size": 15.022,
"path": "tests/data/unprocessed_dataset_2_submissions/submission_x",
"flip": 1,
"populations_file": "tests/data/unprocessed_dataset_2_submissions/submission_x/populations.txt",
"submission_version": "1.0"
"align": 1
}
}
72 changes: 72 additions & 0 deletions tests/test_power_spectrum_and_bfactor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import torch
from cryo_challenge.power_spectrum_utils import _centered_ifftn, compute_power_spectrum
from cryo_challenge._preprocessing.bfactor_normalize import (
_compute_bfactor_scaling,
bfactor_normalize_volumes,
)


def test_compute_power_spectrum():
box_size = 224
volume_shape = (box_size, box_size, box_size)
voxel_size = 1.073 * 2

freq = torch.fft.fftshift(torch.fft.fftfreq(box_size, d=voxel_size))
x = freq.clone()
y = freq.clone()
z = freq.clone()
x, y, z = torch.meshgrid(x, y, z, indexing="ij")

s2 = x**2 + y**2 + z**2

b_factor = 170

gaussian_volume = torch.exp(-b_factor / 4 * s2).reshape(volume_shape)
gaussian_volume = _centered_ifftn(gaussian_volume)

power_spectrum = compute_power_spectrum(gaussian_volume)
power_spectrum_slice = (
torch.abs(torch.fft.fftn(gaussian_volume)[: box_size // 2, 0, 0]) ** 2
)

mean_squared_error = torch.mean((power_spectrum - power_spectrum_slice) ** 2)

assert mean_squared_error < 1e-3

return


def test_bfactor_normalize_volumes():
box_size = 128
volume_shape = (box_size, box_size, box_size)
voxel_size = 1.5

freq = torch.fft.fftshift(torch.fft.fftfreq(box_size, d=voxel_size))
x = freq.clone()
y = freq.clone()
z = freq.clone()
x, y, z = torch.meshgrid(x, y, z, indexing="ij")

s2 = x**2 + y**2 + z**2

oscillatory_volume = torch.sin(300 * s2).reshape(volume_shape)
oscillatory_volume = _centered_ifftn(oscillatory_volume)
bfactor_scaling_vol = _compute_bfactor_scaling(170, box_size, voxel_size)

norm_oscillatory_vol = bfactor_normalize_volumes(
oscillatory_volume, 170, voxel_size, in_place=False
)

ps_osci = torch.fft.fftn(oscillatory_volume, dim=(-3, -2, -1), norm="backward")[
: box_size // 2, 0, 0
]
ps_norm_osci = torch.fft.fftn(
norm_oscillatory_vol, dim=(-3, -2, -1), norm="backward"
)[: box_size // 2, 0, 0]
ps_bfactor_scaling = torch.fft.fftshift(bfactor_scaling_vol)[: box_size // 2, 0, 0]

ps_osci = torch.abs(ps_osci) ** 2
ps_norm_osci = torch.abs(ps_norm_osci) ** 2
ps_bfactor_scaling = torch.abs(ps_bfactor_scaling) ** 2

assert torch.allclose(ps_norm_osci, ps_osci * ps_bfactor_scaling)

0 comments on commit 072c1ae

Please sign in to comment.