Skip to content

Commit

Permalink
complted. FSCDistance class
Browse files Browse the repository at this point in the history
  • Loading branch information
geoffwoollard committed Jul 10, 2024
1 parent 4dddfc0 commit e33801a
Showing 1 changed file with 15 additions and 21 deletions.
36 changes: 15 additions & 21 deletions src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def get_distance(self, map1, map2):
class FSCDistance(MapToMapDistance):
def __init__(self, config):
super().__init__(config)

def get_distance_matrix(self, maps1, maps2): # custom method
maps_gt_flat = maps1
maps_user_flat = maps2
Expand All @@ -78,7 +78,13 @@ def get_distance_matrix(self, maps1, maps2): # custom method
maps_gt_flat_cube[:, mask] = maps_gt_flat
maps_user_flat_cube = torch.zeros(len(maps_user_flat), n_pix**3)
maps_user_flat_cube[:, mask] = maps_user_flat
return compute_cost_fsc_chunk(maps_gt_flat_cube, maps_user_flat_cube, n_pix)

cost_matrix, fsc_matrix = compute_cost_fsc_chunk(maps_gt_flat_cube, maps_user_flat_cube, n_pix)
self.stored_computed_assets = {'fsc_matrix': fsc_matrix}
return cost_matrix

def get_computed_assets(self, maps1, maps2):
return self.stored_computed_assets # must run get_distance_matrix first

def run(config):
"""
Expand Down Expand Up @@ -137,29 +143,17 @@ def run(config):
if cost_label in config["analysis"]["metrics"]: # TODO: can remove
print("cost matrix", cost_label)

if (
cost_label == "fsc"
): # TODO: make pydantic (include base class). type hint inputs to this (what it needs like gt volumes and populations) # noqa: E501

cost_matrix, fsc_matrix = map_to_map_distance.get_distance_matrix(
maps_gt_flat, maps_user_flat
)
cost_matrix = cost_matrix.numpy()
computed_assets["fsc_matrix"] = fsc_matrix
else:

cost_matrix = map_to_map_distance.get_distance_matrix(
maps_gt_flat, maps_user_flat
).numpy()
computed_assets = map_to_map_distance.get_computed_assets(
maps_gt_flat, maps_user_flat
)
computed_assets.update(computed_assets)
cost_matrix = map_to_map_distance.get_distance_matrix(
maps_gt_flat, maps_user_flat
).numpy()
computed_assets = map_to_map_distance.get_computed_assets(
maps_gt_flat, maps_user_flat
)
computed_assets.update(computed_assets)

cost_matrix_df = pd.DataFrame(
cost_matrix, columns=None, index=metadata_gt.populations.tolist()
)
print(cost_matrix_df)

# output results
single_distance_results_dict = {
Expand Down

0 comments on commit e33801a

Please sign in to comment.