From e33801a0f0f7f06be09c3387f980fa1efd1bb11b Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Wed, 10 Jul 2024 11:29:58 -0400 Subject: [PATCH] complted. FSCDistance class --- .../_map_to_map/map_to_map_distance_matrix.py | 36 ++++++++----------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py index 46f7e1c..75ce02a 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py @@ -66,7 +66,7 @@ def get_distance(self, map1, map2): class FSCDistance(MapToMapDistance): def __init__(self, config): super().__init__(config) - + def get_distance_matrix(self, maps1, maps2): # custom method maps_gt_flat = maps1 maps_user_flat = maps2 @@ -78,7 +78,13 @@ def get_distance_matrix(self, maps1, maps2): # custom method maps_gt_flat_cube[:, mask] = maps_gt_flat maps_user_flat_cube = torch.zeros(len(maps_user_flat), n_pix**3) maps_user_flat_cube[:, mask] = maps_user_flat - return compute_cost_fsc_chunk(maps_gt_flat_cube, maps_user_flat_cube, n_pix) + + cost_matrix, fsc_matrix = compute_cost_fsc_chunk(maps_gt_flat_cube, maps_user_flat_cube, n_pix) + self.stored_computed_assets = {'fsc_matrix': fsc_matrix} + return cost_matrix + + def get_computed_assets(self, maps1, maps2): + return self.stored_computed_assets # must run get_distance_matrix first def run(config): """ @@ -137,29 +143,17 @@ def run(config): if cost_label in config["analysis"]["metrics"]: # TODO: can remove print("cost matrix", cost_label) - if ( - cost_label == "fsc" - ): # TODO: make pydantic (include base class). type hint inputs to this (what it needs like gt volumes and populations) # noqa: E501 - - cost_matrix, fsc_matrix = map_to_map_distance.get_distance_matrix( - maps_gt_flat, maps_user_flat - ) - cost_matrix = cost_matrix.numpy() - computed_assets["fsc_matrix"] = fsc_matrix - else: - - cost_matrix = map_to_map_distance.get_distance_matrix( - maps_gt_flat, maps_user_flat - ).numpy() - computed_assets = map_to_map_distance.get_computed_assets( - maps_gt_flat, maps_user_flat - ) - computed_assets.update(computed_assets) + cost_matrix = map_to_map_distance.get_distance_matrix( + maps_gt_flat, maps_user_flat + ).numpy() + computed_assets = map_to_map_distance.get_computed_assets( + maps_gt_flat, maps_user_flat + ) + computed_assets.update(computed_assets) cost_matrix_df = pd.DataFrame( cost_matrix, columns=None, index=metadata_gt.populations.tolist() ) - print(cost_matrix_df) # output results single_distance_results_dict = {