Skip to content

Commit

Permalink
old and new methods numerically agreeing
Browse files Browse the repository at this point in the history
  • Loading branch information
geoffwoollard committed Jul 10, 2024
1 parent 67ebf6f commit 648a8c9
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 10 deletions.
50 changes: 40 additions & 10 deletions src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,33 @@ def get_distance_matrix(self, maps1, maps2):

return distance_matrix

class L2Distance(MapToMapDistance):
class L2DistanceNorm(MapToMapDistance):
def __init__(self, config):
super().__init__(config)

def get_distance(self, map1, map2):
return torch.norm(map1 - map2)**2

class L2DistanceSum(MapToMapDistance):
def __init__(self, config):
super().__init__(config)

def get_distance(self, map1, map2):
return compute_cost_l2(map1, map2)

class Correlation(MapToMapDistance):
def __init__(self, config):
super().__init__(config)

def get_distance(self, map1, map2):
return compute_cost_corr(map1, map2)

class BioEM3dDistance(MapToMapDistance):
def __init__(self, config):
super().__init__(config)

def get_distance(self, map1, map2):
return compute_bioem3d_cost(map1, map2)

def run(config):
"""
Expand All @@ -69,8 +89,7 @@ def run(config):
label_key = config["data"]["submission"]["label_key"]
user_submission_label = submission[label_key]

# n_trunc = 10
metadata_gt = pd.read_csv(config["data"]["ground_truth"]["metadata"])#[:n_trunc]
metadata_gt = pd.read_csv(config["data"]["ground_truth"]["metadata"])

results_dict = {}
results_dict["config"] = config
Expand All @@ -80,9 +99,12 @@ def run(config):

cost_funcs_d = {
"fsc": compute_cost_fsc_chunk,
"corr": compute_cost_corr,
"l2": compute_cost_l2,
"bioem": compute_bioem3d_cost,
"corrold": compute_cost_corr,
"corr": Correlation(config).get_distance_matrix,
"l2old": compute_cost_l2,
"l2": L2DistanceSum(config).get_distance_matrix,
"bioemold": compute_bioem3d_cost,
"bioem": BioEM3dDistance(config).get_distance_matrix,
}

maps_user_flat = submission[submission_volume_key].reshape(
Expand Down Expand Up @@ -127,12 +149,20 @@ def run(config):
)
cost_matrix = cost_matrix.numpy()
computed_assets["fsc_matrix"] = fsc_matrix
elif cost_label == "nope":
l2_map2map_distance = L2Distance(config)
cost_matrix = l2_map2map_distance.get_distance_matrix(
elif cost_label == "l2":
cost_matrix = cost_func(
maps_gt_flat, maps_user_flat
).numpy()
elif cost_label == "corr":
corr_map2map_distance = Correlation(config)
cost_matrix = corr_map2map_distance.get_distance_matrix(
maps_gt_flat, maps_user_flat
).numpy()
elif cost_label == "bioem":
bioem_map2map_distance = BioEM3dDistance(config)
cost_matrix = bioem_map2map_distance.get_distance_matrix(
maps_gt_flat, maps_user_flat
).numpy()
print('run new method')

else:
cost_matrix = vmap_distance(
Expand Down
6 changes: 6 additions & 0 deletions src/cryo_challenge/data/_validation/output_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@ class MapToMapResultsValidator:
config: dict
user_submitted_populations: torch.Tensor
corr: Optional[dict] = None
corrold: Optional[dict] = None
l2: Optional[dict] = None
l2old: Optional[dict] = None
bioem: Optional[dict] = None
bioemold: Optional[dict] = None
fsc: Optional[dict] = None

def __post_init__(self):
Expand Down Expand Up @@ -142,8 +145,11 @@ class DistributionToDistributionResultsValidator:
id: str
fsc: Optional[dict] = None
bioem: Optional[dict] = None
bioemold: Optional[dict] = None
l2: Optional[dict] = None
l2old: Optional[dict] = None
corr: Optional[dict] = None
corrold: Optional[dict] = None

def __post_init__(self):
validate_input_config_disttodist(self.config)
Expand Down
5 changes: 5 additions & 0 deletions tests/config_files/test_config_map_to_map.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ data:
analysis:
metrics:
- l2
- l2old
- corr
- corrold
- bioem
- bioemold
chunk_size_submission: 80
chunk_size_gt: 190
normalize:
Expand Down

0 comments on commit 648a8c9

Please sign in to comment.