Skip to content

Commit

Permalink
flags for masking and not masking
Browse files Browse the repository at this point in the history
  • Loading branch information
geoffwoollard committed Sep 11, 2024
1 parent de3cddc commit f20344e
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 13 deletions.
40 changes: 27 additions & 13 deletions src/cryo_challenge/_map_to_map/map_to_map_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,22 @@ class MapToMapDistanceLowMemory(MapToMapDistance):

def __init__(self, config):
super().__init__(config)
self.config = config

def compute_cost(self, map_1, map_2):
raise NotImplementedError()

@override
def get_distance(self, map1, map2, global_store_of_running_results):
map1 = map1.flatten()
map1 -= map1.median()
map1 /= map1.std()
if self.config["analysis"]["normalize"]["do"]:
if self.config["analysis"]["normalize"]["method"] == "median_zscore":
map1 -= map1.median()
map1 /= map1.std()
else:
raise NotImplementedError(
f"Normalization method {self.config['analysis']['normalize']['method']} not implemented."
)
map1 = map1[global_store_of_running_results["mask"]]

return self.compute_cost(map1, map2)
Expand Down Expand Up @@ -308,14 +315,19 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results):
maps_user_flat = maps2
n_pix = self.config["data"]["n_pix"]
maps_gt_flat_cube = torch.zeros(len(maps_gt_flat), n_pix**3)
mask = (
mrcfile.open(self.config["data"]["mask"]["volume"])
.data.astype(bool)
.flatten()
)
maps_gt_flat_cube[:, mask] = maps_gt_flat
maps_user_flat_cube = torch.zeros(len(maps_user_flat), n_pix**3)
maps_user_flat_cube[:, mask] = maps_user_flat

if self.config["data"]["mask"]["do"]:
mask = (
mrcfile.open(self.config["data"]["mask"]["volume"])
.data.astype(bool)
.flatten()
)
maps_gt_flat_cube[:, mask] = maps_gt_flat
maps_user_flat_cube[:, mask] = maps_user_flat
else:
maps_gt_flat_cube = maps_gt_flat
maps_user_flat_cube = maps_user_flat

cost_matrix, fsc_matrix = self.compute_cost_fsc_chunk(
maps_gt_flat_cube, maps_user_flat_cube, n_pix
Expand All @@ -334,18 +346,20 @@ class FSCDistanceLowMemory(MapToMapDistance):
def __init__(self, config):
super().__init__(config)
self.n_pix = self.config["data"]["n_pix"]
self.config = config

def compute_cost(self, map_1, map_2):
raise NotImplementedError()

@override
def get_distance(self, map1, map2, global_store_of_running_results):
map_gt_flat = map1 = map1.flatten()
map1 -= map1.median()
map1 /= map1.std()
map_gt_flat_cube = torch.zeros(self.n_pix**3)
map1 = map1[global_store_of_running_results["mask"]]
map_gt_flat_cube[global_store_of_running_results["mask"]] = map_gt_flat
if self.config["data"]["mask"]["do"]:
map_gt_flat = map_gt_flat[global_store_of_running_results["mask"]]
map_gt_flat_cube[global_store_of_running_results["mask"]] = map_gt_flat
else:
map_gt_flat_cube = map_gt_flat

corr_vector = fourier_shell_correlation(
map_gt_flat_cube.reshape(self.n_pix, self.n_pix, self.n_pix),
Expand Down
4 changes: 4 additions & 0 deletions src/cryo_challenge/_map_to_map/map_to_map_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ def run(config):
maps_gt_flat /= maps_gt_flat.std(dim=1, keepdim=True)
maps_user_flat -= maps_user_flat.median(dim=1, keepdim=True).values
maps_user_flat /= maps_user_flat.std(dim=1, keepdim=True)
else:
raise NotImplementedError(
f"Normalization method {config['analysis']['normalize']['method']} not implemented."
)

computed_assets = {}
results_dict["mask"] = mask
Expand Down

0 comments on commit f20344e

Please sign in to comment.