Skip to content

Commit

Permalink
new general map to map distance class
Browse files Browse the repository at this point in the history
  • Loading branch information
geoffwoollard committed Jul 10, 2024
1 parent 8d0ae69 commit 67ebf6f
Showing 1 changed file with 36 additions and 0 deletions.
36 changes: 36 additions & 0 deletions src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,34 @@ def vmap_distance(
)(maps_gt)


class MapToMapDistance:
def __init__(self, config):
self.config = config

def get_distance(self, map1, map2):
raise NotImplementedError()

def get_distance_matrix(self, maps1, maps2):
chunk_size_submission = self.config["analysis"]["chunk_size_submission"]
chunk_size_gt = self.config["analysis"]["chunk_size_gt"]
distance_matrix = torch.vmap(
lambda maps1: torch.vmap(
lambda maps2: self.get_distance(maps1, maps2),
chunk_size=chunk_size_submission,
)(maps2),
chunk_size=chunk_size_gt,
)(maps1)

return distance_matrix

class L2Distance(MapToMapDistance):
def __init__(self, config):
super().__init__(config)

def get_distance(self, map1, map2):
return torch.norm(map1 - map2)**2


def run(config):
"""
Compare a submission to ground truth.
Expand Down Expand Up @@ -99,6 +127,13 @@ def run(config):
)
cost_matrix = cost_matrix.numpy()
computed_assets["fsc_matrix"] = fsc_matrix
elif cost_label == "nope":
l2_map2map_distance = L2Distance(config)
cost_matrix = l2_map2map_distance.get_distance_matrix(
maps_gt_flat, maps_user_flat
).numpy()
print('run new method')

else:
cost_matrix = vmap_distance(
maps_gt_flat,
Expand All @@ -111,6 +146,7 @@ def run(config):
cost_matrix_df = pd.DataFrame(
cost_matrix, columns=None, index=metadata_gt.populations.tolist()
)
print(cost_matrix_df)

# output results
single_distance_results_dict = {
Expand Down

0 comments on commit 67ebf6f

Please sign in to comment.