Skip to content

Commit

Permalink
remove old reference to l2, corr, and bioem3d
Browse files Browse the repository at this point in the history
  • Loading branch information
geoffwoollard committed Jul 10, 2024
1 parent 648a8c9 commit 4438b92
Show file tree
Hide file tree
Showing 3 changed files with 1 addition and 48 deletions.
40 changes: 1 addition & 39 deletions src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,6 @@
from ..data._validation.output_validators import MapToMapResultsValidator


def vmap_distance(
maps_gt,
maps_submission,
map_to_map_distance,
chunk_size_gt=None,
chunk_size_submission=None,
):
return torch.vmap(
lambda maps_gt: torch.vmap(
lambda maps_submission: map_to_map_distance(maps_gt, maps_submission),
chunk_size=chunk_size_submission,
)(maps_submission),
chunk_size=chunk_size_gt,
)(maps_gt)


class MapToMapDistance:
def __init__(self, config):
self.config = config
Expand Down Expand Up @@ -99,11 +83,8 @@ def run(config):

cost_funcs_d = {
"fsc": compute_cost_fsc_chunk,
"corrold": compute_cost_corr,
"corr": Correlation(config).get_distance_matrix,
"l2old": compute_cost_l2,
"l2": L2DistanceSum(config).get_distance_matrix,
"bioemold": compute_bioem3d_cost,
"bioem": BioEM3dDistance(config).get_distance_matrix,
}

Expand Down Expand Up @@ -149,29 +130,10 @@ def run(config):
)
cost_matrix = cost_matrix.numpy()
computed_assets["fsc_matrix"] = fsc_matrix
elif cost_label == "l2":
else:
cost_matrix = cost_func(
maps_gt_flat, maps_user_flat
).numpy()
elif cost_label == "corr":
corr_map2map_distance = Correlation(config)
cost_matrix = corr_map2map_distance.get_distance_matrix(
maps_gt_flat, maps_user_flat
).numpy()
elif cost_label == "bioem":
bioem_map2map_distance = BioEM3dDistance(config)
cost_matrix = bioem_map2map_distance.get_distance_matrix(
maps_gt_flat, maps_user_flat
).numpy()

else:
cost_matrix = vmap_distance(
maps_gt_flat,
maps_user_flat,
cost_func,
chunk_size_gt=config["analysis"]["chunk_size_gt"],
chunk_size_submission=config["analysis"]["chunk_size_submission"],
).numpy()

cost_matrix_df = pd.DataFrame(
cost_matrix, columns=None, index=metadata_gt.populations.tolist()
Expand Down
6 changes: 0 additions & 6 deletions src/cryo_challenge/data/_validation/output_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,8 @@ class MapToMapResultsValidator:
config: dict
user_submitted_populations: torch.Tensor
corr: Optional[dict] = None
corrold: Optional[dict] = None
l2: Optional[dict] = None
l2old: Optional[dict] = None
bioem: Optional[dict] = None
bioemold: Optional[dict] = None
fsc: Optional[dict] = None

def __post_init__(self):
Expand Down Expand Up @@ -145,11 +142,8 @@ class DistributionToDistributionResultsValidator:
id: str
fsc: Optional[dict] = None
bioem: Optional[dict] = None
bioemold: Optional[dict] = None
l2: Optional[dict] = None
l2old: Optional[dict] = None
corr: Optional[dict] = None
corrold: Optional[dict] = None

def __post_init__(self):
validate_input_config_disttodist(self.config)
Expand Down
3 changes: 0 additions & 3 deletions tests/config_files/test_config_map_to_map.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,8 @@ data:
analysis:
metrics:
- l2
- l2old
- corr
- corrold
- bioem
- bioemold
chunk_size_submission: 80
chunk_size_gt: 190
normalize:
Expand Down

0 comments on commit 4438b92

Please sign in to comment.