From 8cf8aea9388331f0c63a68d9815bfa5fd18b8690 Mon Sep 17 00:00:00 2001 From: DSilva27 Date: Tue, 3 Sep 2024 17:08:47 -0400 Subject: [PATCH] remove centering and align only one volume in preprocessing --- .../_preprocessing/align_utils.py | 72 ++++++++++++++----- .../_preprocessing/preprocessing_pipeline.py | 4 +- 2 files changed, 57 insertions(+), 19 deletions(-) diff --git a/src/cryo_challenge/_preprocessing/align_utils.py b/src/cryo_challenge/_preprocessing/align_utils.py index d2a7784..f0ae4fa 100644 --- a/src/cryo_challenge/_preprocessing/align_utils.py +++ b/src/cryo_challenge/_preprocessing/align_utils.py @@ -116,6 +116,45 @@ def center_submission(volumes: torch.Tensor, pixel_size: float) -> torch.Tensor: return volumes +# def align_submission( +# volumes: torch.Tensor, ref_volume: torch.Tensor, params: dict +# ) -> torch.Tensor: +# """ +# Align submission volumes to ground truth volume + +# Parameters: +# ----------- +# volumes (torch.Tensor): submission volumes +# shape: (n_volumes, im_x, im_y, im_z) +# ref_volume (torch.Tensor): ground truth volume +# shape: (im_x, im_y, im_z) +# params (dict): dictionary containing alignment parameters + +# Returns: +# -------- +# volumes (torch.Tensor): aligned submission volumes +# """ +# for i in range(len(volumes)): +# obj_vol = volumes[i].numpy().astype(np.float32).copy() + +# obj_vol = Volume(obj_vol / obj_vol.sum()) +# ref_vol = Volume(ref_volume.copy() / ref_volume.sum()) + +# _, R_est = align_BO( +# ref_vol, +# obj_vol, +# loss_type=params["BOT_loss"], +# downsampled_size=params["BOT_box_size"], +# max_iters=params["BOT_iter"], +# refine=params["BOT_refine"], +# ) +# R_est = Rotation(R_est.astype(np.float32)) + +# volumes[i] = torch.from_numpy(Volume(volumes[i].numpy()).rotate(R_est)._data) + +# return volumes + + def align_submission( volumes: torch.Tensor, ref_volume: torch.Tensor, params: dict ) -> torch.Tensor: @@ -134,22 +173,21 @@ def align_submission( -------- volumes (torch.Tensor): aligned submission volumes """ - for i in range(len(volumes)): - obj_vol = volumes[i].numpy().astype(np.float32).copy() - - obj_vol = Volume(obj_vol / obj_vol.sum()) - ref_vol = Volume(ref_volume.copy() / ref_volume.sum()) - - _, R_est = align_BO( - ref_vol, - obj_vol, - loss_type=params["BOT_loss"], - downsampled_size=params["BOT_box_size"], - max_iters=params["BOT_iter"], - refine=params["BOT_refine"], - ) - R_est = Rotation(R_est.astype(np.float32)) - - volumes[i] = torch.from_numpy(Volume(volumes[i].numpy()).rotate(R_est)._data) + obj_vol = volumes[0].numpy().astype(np.float32) + + obj_vol = Volume(obj_vol / obj_vol.sum()) + ref_vol = Volume(ref_volume / ref_volume.sum()) + + _, R_est = align_BO( + ref_vol, + obj_vol, + loss_type=params["BOT_loss"], + downsampled_size=params["BOT_box_size"], + max_iters=params["BOT_iter"], + refine=params["BOT_refine"], + ) + R_est = Rotation(R_est.astype(np.float32)) + + volumes = torch.from_numpy(Volume(volumes.numpy()).rotate(R_est)._data) return volumes diff --git a/src/cryo_challenge/_preprocessing/preprocessing_pipeline.py b/src/cryo_challenge/_preprocessing/preprocessing_pipeline.py index 90ccc51..b4c3e61 100644 --- a/src/cryo_challenge/_preprocessing/preprocessing_pipeline.py +++ b/src/cryo_challenge/_preprocessing/preprocessing_pipeline.py @@ -2,7 +2,7 @@ import json import os -from .align_utils import align_submission, center_submission, threshold_submissions +from .align_utils import align_submission, threshold_submissions from .crop_pad_utils import crop_pad_submission from .fourier_utils import downsample_submission @@ -80,7 +80,7 @@ def preprocess_submissions(submission_dataset, config): # center submission print(" Centering submission") - volumes = center_submission(volumes, pixel_size=pixel_size_gt) + # volumes = center_submission(volumes, pixel_size=pixel_size_gt) # flip handedness if submission_dataset.submission_config[str(idx)]["flip"] == 1: