Skip to content

Commit

Permalink
Revert "Implementation of external Zernike3D distance"
Browse files Browse the repository at this point in the history
  • Loading branch information
DSilva27 authored Dec 20, 2024
1 parent db1ead3 commit c913f60
Show file tree
Hide file tree
Showing 10 changed files with 3 additions and 183 deletions.
50 changes: 0 additions & 50 deletions docs/setup_zernike3d_distance.md

This file was deleted.

82 changes: 0 additions & 82 deletions src/cryo_challenge/_map_to_map/map_to_map_distance.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import os
import subprocess
import math
import torch
from typing import Optional, Sequence
Expand Down Expand Up @@ -57,7 +55,6 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results):
"""Compute the distance matrix between two sets of maps."""
if self.config["data"]["mask"]["do"]:
maps2 = maps2[:, self.mask]

else:
maps2 = maps2.reshape(len(maps2), -1)

Expand Down Expand Up @@ -90,8 +87,6 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results):

else:
maps1 = maps1.reshape(len(maps1), -1)
if self.config["data"]["mask"]["do"]:
maps1 = maps1.reshape(len(maps1), -1)[:, self.mask]
maps2 = maps2.reshape(len(maps2), -1)
distance_matrix = torch.vmap(
lambda maps1: torch.vmap(
Expand Down Expand Up @@ -403,80 +398,3 @@ def res_at_fsc_threshold(fscs, threshold=0.5):
res_fsc_half, fraction_nyquist = res_at_fsc_threshold(fsc_matrix)
self.stored_computed_assets = {"fraction_nyquist": fraction_nyquist}
return units_Angstroms[res_fsc_half]


class Zernike3DDistance(MapToMapDistance):
"""Zernike3D based distance.
Zernike3D distance relies on the estimation of the non-linear transformation needed to align two different maps.
The RMSD of the associated non-linear alignment represented as a deformation field is then used as the distance
between two maps
"""

@override
def get_distance_matrix(self, maps1, maps2, global_store_of_running_results):
gpuID = self.config["analysis"]["zernike3d_extra_params"]["gpuID"]
outputPath = self.config["analysis"]["zernike3d_extra_params"]["tmpDir"]
thr = self.config["analysis"]["zernike3d_extra_params"]["thr"]
numProjections = self.config["analysis"]["zernike3d_extra_params"][
"numProjections"
]

# Create output directory
if not os.path.isdir(outputPath):
os.mkdir(outputPath)

# Prepare data to call external
targets_paths = os.path.join(outputPath, "target_maps.npy")
references_path = os.path.join(outputPath, "reference_maps.npy")
if not os.path.isfile(targets_paths):
np.save(targets_paths, maps1)
if not os.path.isfile(references_path):
np.save(references_path, maps2)

# Check conda is in PATH (otherwise abort as external software is not installed)
try:
subprocess.check_call("conda", shell=True, stdout=subprocess.PIPE)
except FileNotFoundError:
raise Exception("Conda not found in PATH... Aborting")

# Check if conda env is installed
env_installed = subprocess.run(
r"conda env list | grep 'flexutils-tensorflow '",
shell=True,
check=False,
stdout=subprocess.PIPE,
).stdout
env_installed = bool(
env_installed.decode("utf-8").replace("\n", "").replace("*", "")
)
if not env_installed:
raise Exception("External software not found... Aborting")

# Find conda executable (needed to activate conda envs in a subprocess)
condabin_path = subprocess.run(
r"which conda | sed 's: ::g'",
shell=True,
check=False,
stdout=subprocess.PIPE,
).stdout
condabin_path = condabin_path.decode("utf-8").replace("\n", "").replace("*", "")

# Call external program
subprocess.check_call(
f'eval "$({condabin_path} shell.bash hook)" &&'
f" conda activate flexutils-tensorflow && "
f"compute_distance_matrix_zernike3deep.py --references_file {references_path} "
f"--targets_file {targets_paths} --out_path {outputPath} --gpu {gpuID} --num_projections {numProjections} "
f"--thr {thr}",
shell=True,
)

# Read distance matrix
dists = np.load(os.path.join(outputPath, "dist_mat.npy")).T
self.stored_computed_assets = {"zernike3d": dists}
return dists

@override
def get_computed_assets(self, maps1, maps2, global_store_of_running_results):
return self.stored_computed_assets # must run get_distance_matrix first
3 changes: 0 additions & 3 deletions src/cryo_challenge/_map_to_map/map_to_map_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
L2DistanceNorm,
BioEM3dDistance,
FSCResDistance,
Zernike3DDistance,
)


Expand All @@ -19,7 +18,6 @@
"l2": L2DistanceNorm,
"bioem": BioEM3dDistance,
"res": FSCResDistance,
"zernike3d": Zernike3DDistance,
}


Expand Down Expand Up @@ -53,7 +51,6 @@ def run(config):
maps_user_flat = submission[submission_volume_key].reshape(
len(submission["volumes"]), -1
)

maps_gt_flat = torch.load(
config["data"]["ground_truth"]["volumes"], mmap=do_low_memory_mode
)
Expand Down
2 changes: 0 additions & 2 deletions src/cryo_challenge/data/_validation/output_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ class MapToMapResultsValidator:
bioem: Optional[dict] = None
fsc: Optional[dict] = None
res: Optional[dict] = None
zernike3d: Optional[dict] = None

def __post_init__(self):
validate_input_config_mtm(self.config)
Expand Down Expand Up @@ -152,7 +151,6 @@ class DistributionToDistributionResultsValidator:
res: Optional[dict] = None
l2: Optional[dict] = None
corr: Optional[dict] = None
zernike3d: Optional[dict] = None

def __post_init__(self):
validate_input_config_disttodist(self.config)
Expand Down
2 changes: 1 addition & 1 deletion tests/config_files/test_config_map_to_map.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ data:
metadata: tests/data/Ground_truth/test_metadata_10.csv
mask:
do: true
volume: tests/data/Ground_truth/test_mask_bool.mrc
volume: tests/data/Ground_truth/test_mask_dilated_wide.mrc
analysis:
metrics:
- l2
Expand Down
31 changes: 0 additions & 31 deletions tests/config_files/test_config_map_to_map_external.yaml

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ data:
metadata: tests/data/Ground_truth/test_metadata_10.csv
mask:
do: true
volume: tests/data/Ground_truth/test_mask_bool.mrc
volume: tests/data/Ground_truth/test_mask_dilated_wide.mrc
analysis:
metrics:
- l2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ data:
metadata: tests/data/Ground_truth/test_metadata_10.csv
mask:
do: false
volume: tests/data/Ground_truth/test_mask_bool.mrc
volume: tests/data/Ground_truth/test_mask_dilated_wide.mrc
analysis:
metrics:
- l2
Expand Down
Binary file removed tests/data/Ground_truth/test_mask_bool.mrc
Binary file not shown.
12 changes: 0 additions & 12 deletions tests/test_map_to_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,6 @@


def test_run_map2map_pipeline():
try:
args = OmegaConf.create(
{"config": "tests/config_files/test_config_map_to_map_external.yaml"}
)
results_dict = run_map2map_pipeline.main(args)
assert "zernike3d" in results_dict.keys()
except Exception as e:
print(e)
print(
"External test failed. Skipping test. Fails when running in CI if external dependencies are not installed."
)

for config_fname, config_fname_low_memory in zip(
[
"tests/config_files/test_config_map_to_map.yaml",
Expand Down

0 comments on commit c913f60

Please sign in to comment.