Skip to content

Commit

Permalink
remove centering and align only one volume in preprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
DSilva27 committed Sep 3, 2024
1 parent 1b594a0 commit 8cf8aea
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 19 deletions.
72 changes: 55 additions & 17 deletions src/cryo_challenge/_preprocessing/align_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,45 @@ def center_submission(volumes: torch.Tensor, pixel_size: float) -> torch.Tensor:
return volumes


# def align_submission(
# volumes: torch.Tensor, ref_volume: torch.Tensor, params: dict
# ) -> torch.Tensor:
# """
# Align submission volumes to ground truth volume

# Parameters:
# -----------
# volumes (torch.Tensor): submission volumes
# shape: (n_volumes, im_x, im_y, im_z)
# ref_volume (torch.Tensor): ground truth volume
# shape: (im_x, im_y, im_z)
# params (dict): dictionary containing alignment parameters

# Returns:
# --------
# volumes (torch.Tensor): aligned submission volumes
# """
# for i in range(len(volumes)):
# obj_vol = volumes[i].numpy().astype(np.float32).copy()

# obj_vol = Volume(obj_vol / obj_vol.sum())
# ref_vol = Volume(ref_volume.copy() / ref_volume.sum())

# _, R_est = align_BO(
# ref_vol,
# obj_vol,
# loss_type=params["BOT_loss"],
# downsampled_size=params["BOT_box_size"],
# max_iters=params["BOT_iter"],
# refine=params["BOT_refine"],
# )
# R_est = Rotation(R_est.astype(np.float32))

# volumes[i] = torch.from_numpy(Volume(volumes[i].numpy()).rotate(R_est)._data)

# return volumes


def align_submission(
volumes: torch.Tensor, ref_volume: torch.Tensor, params: dict
) -> torch.Tensor:
Expand All @@ -134,22 +173,21 @@ def align_submission(
--------
volumes (torch.Tensor): aligned submission volumes
"""
for i in range(len(volumes)):
obj_vol = volumes[i].numpy().astype(np.float32).copy()

obj_vol = Volume(obj_vol / obj_vol.sum())
ref_vol = Volume(ref_volume.copy() / ref_volume.sum())

_, R_est = align_BO(
ref_vol,
obj_vol,
loss_type=params["BOT_loss"],
downsampled_size=params["BOT_box_size"],
max_iters=params["BOT_iter"],
refine=params["BOT_refine"],
)
R_est = Rotation(R_est.astype(np.float32))

volumes[i] = torch.from_numpy(Volume(volumes[i].numpy()).rotate(R_est)._data)
obj_vol = volumes[0].numpy().astype(np.float32)

obj_vol = Volume(obj_vol / obj_vol.sum())
ref_vol = Volume(ref_volume / ref_volume.sum())

_, R_est = align_BO(
ref_vol,
obj_vol,
loss_type=params["BOT_loss"],
downsampled_size=params["BOT_box_size"],
max_iters=params["BOT_iter"],
refine=params["BOT_refine"],
)
R_est = Rotation(R_est.astype(np.float32))

volumes = torch.from_numpy(Volume(volumes.numpy()).rotate(R_est)._data)

return volumes
4 changes: 2 additions & 2 deletions src/cryo_challenge/_preprocessing/preprocessing_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import os

from .align_utils import align_submission, center_submission, threshold_submissions
from .align_utils import align_submission, threshold_submissions
from .crop_pad_utils import crop_pad_submission
from .fourier_utils import downsample_submission

Expand Down Expand Up @@ -80,7 +80,7 @@ def preprocess_submissions(submission_dataset, config):

# center submission
print(" Centering submission")
volumes = center_submission(volumes, pixel_size=pixel_size_gt)
# volumes = center_submission(volumes, pixel_size=pixel_size_gt)

# flip handedness
if submission_dataset.submission_config[str(idx)]["flip"] == 1:
Expand Down

0 comments on commit 8cf8aea

Please sign in to comment.