From f20344e8b6ac36384f34baa8fa96f8fc69d9adad Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 10 Sep 2024 21:10:08 -0400 Subject: [PATCH] flags for masking and not masking --- .../_map_to_map/map_to_map_distance.py | 40 +++++++++++++------ .../_map_to_map/map_to_map_pipeline.py | 4 ++ 2 files changed, 31 insertions(+), 13 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index 70df449..4769528 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -60,6 +60,7 @@ class MapToMapDistanceLowMemory(MapToMapDistance): def __init__(self, config): super().__init__(config) + self.config = config def compute_cost(self, map_1, map_2): raise NotImplementedError() @@ -67,8 +68,14 @@ def compute_cost(self, map_1, map_2): @override def get_distance(self, map1, map2, global_store_of_running_results): map1 = map1.flatten() - map1 -= map1.median() - map1 /= map1.std() + if self.config["analysis"]["normalize"]["do"]: + if self.config["analysis"]["normalize"]["method"] == "median_zscore": + map1 -= map1.median() + map1 /= map1.std() + else: + raise NotImplementedError( + f"Normalization method {self.config['analysis']['normalize']['method']} not implemented." + ) map1 = map1[global_store_of_running_results["mask"]] return self.compute_cost(map1, map2) @@ -308,14 +315,19 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): maps_user_flat = maps2 n_pix = self.config["data"]["n_pix"] maps_gt_flat_cube = torch.zeros(len(maps_gt_flat), n_pix**3) - mask = ( - mrcfile.open(self.config["data"]["mask"]["volume"]) - .data.astype(bool) - .flatten() - ) - maps_gt_flat_cube[:, mask] = maps_gt_flat maps_user_flat_cube = torch.zeros(len(maps_user_flat), n_pix**3) - maps_user_flat_cube[:, mask] = maps_user_flat + + if self.config["data"]["mask"]["do"]: + mask = ( + mrcfile.open(self.config["data"]["mask"]["volume"]) + .data.astype(bool) + .flatten() + ) + maps_gt_flat_cube[:, mask] = maps_gt_flat + maps_user_flat_cube[:, mask] = maps_user_flat + else: + maps_gt_flat_cube = maps_gt_flat + maps_user_flat_cube = maps_user_flat cost_matrix, fsc_matrix = self.compute_cost_fsc_chunk( maps_gt_flat_cube, maps_user_flat_cube, n_pix @@ -334,6 +346,7 @@ class FSCDistanceLowMemory(MapToMapDistance): def __init__(self, config): super().__init__(config) self.n_pix = self.config["data"]["n_pix"] + self.config = config def compute_cost(self, map_1, map_2): raise NotImplementedError() @@ -341,11 +354,12 @@ def compute_cost(self, map_1, map_2): @override def get_distance(self, map1, map2, global_store_of_running_results): map_gt_flat = map1 = map1.flatten() - map1 -= map1.median() - map1 /= map1.std() map_gt_flat_cube = torch.zeros(self.n_pix**3) - map1 = map1[global_store_of_running_results["mask"]] - map_gt_flat_cube[global_store_of_running_results["mask"]] = map_gt_flat + if self.config["data"]["mask"]["do"]: + map_gt_flat = map_gt_flat[global_store_of_running_results["mask"]] + map_gt_flat_cube[global_store_of_running_results["mask"]] = map_gt_flat + else: + map_gt_flat_cube = map_gt_flat corr_vector = fourier_shell_correlation( map_gt_flat_cube.reshape(self.n_pix, self.n_pix, self.n_pix), diff --git a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py index 60d3f34..d7db04b 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py @@ -108,6 +108,10 @@ def run(config): maps_gt_flat /= maps_gt_flat.std(dim=1, keepdim=True) maps_user_flat -= maps_user_flat.median(dim=1, keepdim=True).values maps_user_flat /= maps_user_flat.std(dim=1, keepdim=True) + else: + raise NotImplementedError( + f"Normalization method {config['analysis']['normalize']['method']} not implemented." + ) computed_assets = {} results_dict["mask"] = mask