Skip to content

Commit

Permalink
Merge branch 'dev' into 96-alignment-invariant-map-to-map-distance
Browse files Browse the repository at this point in the history
  • Loading branch information
geoffwoollard committed Dec 19, 2024
2 parents 59a1ba8 + db1ead3 commit 2f5b9ad
Show file tree
Hide file tree
Showing 29 changed files with 1,921 additions and 669 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .
pip install pytest omegaconf
pip install ".[dev]"
- name: Test with pytest
run: |
Expand Down
35 changes: 21 additions & 14 deletions config_files/config_svd.yaml
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
path_to_volumes: /path/to/volumes
box_size_ds: 32
submission_list: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
experiment_mode: "all_vs_ref" # options are "all_vs_all", "all_vs_ref"
# optional unless experiment_mode is "all_vs_ref"
path_to_reference: /path/to/reference/volumes.pt
dtype: "float32" # options are "float32", "float64"
output_options:
# path will be created if it does not exist
output_path: /path/to/output
# whether or not to save the processed volumes (downsampled, normalized, etc.)
save_volumes: True
# whether or not to save the SVD matrices (U, S, V)
save_svd_matrices: True
path_to_submissions: path/to/preprocessed/submissions/ # where all the submission_i.pt files are
#excluded_submissions: # you can exclude some submissions by filename, default = []
# - "submission_0.pt"
# - "submission_1.pt"
voxel_size: 1.0 # voxel size of the input maps (will probably be removed soon)

dtype: float32 # optional, default = float32
svd_max_rank: 5 # optional, default = full rank svd
normalize_params: # optional, if not given there will be no normalization
mask_path: path/to/mask.mrc # default = None, no masking applied
bfactor: 170 # default = None, no bfactor applied
box_size_ds: 16 # default = None, no downsampling applied

gt_params: # optional, if provided there will be extra results
gt_vols_file: path/to/gt_volumes.npy # volumes must be in .npy format (memory stuff)
skip_vols: 1 # default = 1, no volumes skipped. Equivalent to volumes[::skip_vols]

output_params:
output_file: path/to/output_file.pt # where the results will be saved
save_svd_data: True # optional, default = False
generate_plots: True # optional, default = False
50 changes: 50 additions & 0 deletions docs/setup_zernike3d_distance.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
<h1 align='center'>How to setup Zernike3D distance?</h1>

<p align="center">

<img alt="Supported Python versions" src="https://img.shields.io/badge/Supported_Python_Versions-3.8_%7C_3.9_%7C_3.10_%7C_3.11_%7C_3.12-blue">
<img alt="GitHub Downloads (all assets, all releases)" src="https://img.shields.io/github/downloads/I2PC/Flexutils-Toolkit/total">
<img alt="GitHub License" src="https://img.shields.io/github/license/I2PC/Flexutils-Toolkit">

</p>

<p align="center">

<img alt="Flexutils" src="https://github.com/scipion-em/scipion-em-flexutils/raw/devel/flexutils/icon.png" width="200" height="200">

</p>



Zernike3D distance relies on the external software **[Flexutils](https://github.com/I2PC/Flexutils-Toolkit)**. The following document includes the installation guide to setup this software in your machine, as well as some guidelines on the parameters and characteristics of the Zernike3D distance.

# Flexutils installation
**Flexutils** can be installed in your system with the following commands:

```bash
git clone https://github.com/I2PC/Flexutils-Toolkit.git
cd Flexutils-Toolkit
bash install.sh
```

Any errors raised during the installation of the software or the computation of the Zernike3D distance can be reported through Flexutils GitHub issue [webpage](https://github.com/I2PC/Flexutils-Toolkit/issues).

# Defining the config file parameters
Zernike3D distance relies on the approximation of a deformation field between two volumes to measure their similarity metric. A detailed explanation on the theory behind the computation of these deformation fields is provided in the following publications: [Zernike3D-IUCRJ](https://journals.iucr.org/m/issues/2021/06/00/eh5012/) and [Zernike3D-NatComm](https://www.nature.com/articles/s41467-023-35791-y).

The software follows a neural network approximation, so the usage of a GPU is strongly recommended.

The Zernike3D distance requires a set of additional execution parameters that need to be supplied through the `config_map_to_map.yaml` file passed to the distance compution step. These additional parameters are presented below:

- **gpuID**: An integer larger than 0 determining the GPU to be used to train the Zernike3Deep neural network.
- **tmpDir**: A path to a folder needed to store the intermediate files generated by the software. This folder is **NOT** emptied once the execution finishes.
- **thr**: An integer larger than 0 determining the number of processes to use during the execution of the software.

```yaml
metrics:
- zernike3d
zernike3d_extra_params:
gpuID: 0
tmpDir: where/to/save/intermediate/files/folder
thr: 20
```
110 changes: 110 additions & 0 deletions figure_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""
Some random code that I have found to be useful for plotting figures.
This should become part of the main repo at some point, I will leave it out for now.
- David
"""

from natsort import natsorted

# Here is how I generate the general dictionary parameter for plots
COLORS = {
"Coffee": "#97b4ff",
"Salted Caramel": "#97b4ff",
"Neapolitan": "#648fff",
"Peanut Butter": "#1858ff",
"Cherry": "#b3a4f7",
"Pina Colada": "#8c75f2",
"Chocolate": "#785ef0",
"Cookie Dough": "#512fec",
"Chocolate Chip": "#3d18e9",
"Vanilla": "#e35299",
"Mango": "#dc267f",
"Black Raspberry": "#ff8032",
"Rocky Road": "#fe6100",
"Ground Truth": "#ffb000",
"Mint Chocolate Chip": "#ffb000",
"Bubble Gum": "#ffb000",
}

PLOT_SETUP = {
"Salted Caramel": {"category": "1", "marker": "o"},
"Neapolitan": {"category": "1", "marker": "v"},
"Peanut Butter": {"category": "1", "marker": "^"},
"Coffee": {"category": "1", "marker": "<"},
"Cherry": {"category": "2", "marker": "o"},
"Pina Colada": {"category": "2", "marker": "v"},
"Cookie Dough": {"category": "2", "marker": "^"},
"Chocolate Chip": {"category": "2", "marker": "<"},
"Chocolate": {"category": "2", "marker": ">"},
"Vanilla": {"category": "3", "marker": "o"},
"Mango": {"category": "3", "marker": "v"},
"Rocky Road": {"category": "4", "marker": "o"},
"Black Raspberry": {"category": "4", "marker": "v"},
"Ground Truth": {"category": "5", "marker": "o"},
"Bubble Gum": {"category": "5", "marker": "v"},
"Mint Chocolate Chip": {"category": "5", "marker": "^"},
}

for key in list(PLOT_SETUP.keys()):
# PLOT_SETUP[key]["color"] = COLORS[PLOT_SETUP[key]["category"]]
PLOT_SETUP[key]["color"] = COLORS[key]


# These two functions are useful when setting the order of how to plot figures
def compare_strings(fixed_string, other_string):
return other_string.startswith(fixed_string)


def sort_labels_category(labels, plot_setup):
labels_sorted = []
for i in range(5): # there are 5 categories
for label in labels:
if plot_setup[label]["category"] == str(i + 1):
labels_sorted.append(label)

return labels_sorted


labels = ... # get labels from somwhere (pipeline results for example)

# This is the particular plot_setup for your data
plot_setup = {}
for i, label in enumerate(labels):
for (
possible_label
) in PLOT_SETUP.keys(): # generalized for labels like FLAVOR 1, FLAVOR 2, etc.
# print(label, possible_label)
if compare_strings(possible_label, label):
plot_setup[label] = PLOT_SETUP[possible_label]

for label in labels:
if label not in plot_setup.keys():
raise ValueError(f"Label {label} not found in PLOT_SETUP")

labels = sort_labels_category(natsorted(labels), plot_setup)


# Then I do something like this, which let's me configure how the
# labels will be displayed in the plot
labels_for_plot = {
"Neapolitan": "Neapolitan R1",
"Neapolitan 2": "Neapolitan R2",
"Peanut Butter": "Peanut Butter R1",
"Peanut Butter 2": "Peanut Butter R2",
"Salted Caramel": "Salted Caramel R1",
"Salted Caramel 2": "Salted Caramel R2 1",
"Salted Caramel 3": "Salted Caramel R2 2",
"Chocolate": "Chocolate R1",
"Chocolate 2": "Chocolate R2",
"Chocolate Chip": "Chocolate Chip R1",
"Cookie Dough": "Cookie Dough R1",
"Cookie Dough 2": "Cookie Dough R2",
"Pina Colada 1": "Piña Colada R2",
"Mango": "Mango R1",
"Vanilla": "Vanilla R1",
"Vanilla 2": "Vanilla R2",
"Black Raspberry": "Black Raspberry R1",
"Black Raspberry 2": "Black Raspberry R2",
"Rocky Road": "Rocky Road R1",
}
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ dependencies = [
"osfclient",
"seaborn",
"ipyfilechooser",
"omegaconf"
"omegaconf",
"pydantic",
"ecos"
]

[project.optional-dependencies]
Expand Down
24 changes: 15 additions & 9 deletions src/cryo_challenge/_commands/run_svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import os
import yaml

from .._svd.svd_pipeline import run_all_vs_all_pipeline, run_all_vs_ref_pipeline
from ..data._validation.config_validators import validate_config_svd
from .._svd.svd_pipeline import run_svd_noref, run_svd_with_ref
from ..data._validation.config_validators import SVDConfig


def add_args(parser):
Expand Down Expand Up @@ -35,15 +35,21 @@ def main(args):
with open(args.config, "r") as file:
config = yaml.safe_load(file)

validate_config_svd(config)
warnexists(config["output_options"]["output_path"])
mkbasedir(config["output_options"]["output_path"])
config = SVDConfig(**config).model_dump()

if config["experiment_mode"] == "all_vs_all":
run_all_vs_all_pipeline(config)
warnexists(config["output_params"]["output_file"])
mkbasedir(os.path.dirname(config["output_params"]["output_file"]))

elif config["experiment_mode"] == "all_vs_ref":
run_all_vs_ref_pipeline(config)
output_path = os.path.dirname(config["output_params"]["output_file"])

with open(os.path.join(output_path, "config.yaml"), "w") as file:
yaml.dump(config, file)

if config["gt_params"] is None:
run_svd_noref(config)

else:
run_svd_with_ref(config)

return

Expand Down
82 changes: 82 additions & 0 deletions src/cryo_challenge/_map_to_map/map_to_map_distance.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os
import subprocess
import math
import torch
from typing import Optional, Sequence
Expand Down Expand Up @@ -55,6 +57,7 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results):
"""Compute the distance matrix between two sets of maps."""
if self.config["data"]["mask"]["do"]:
maps2 = maps2[:, self.mask]

else:
maps2 = maps2.reshape(len(maps2), -1)

Expand Down Expand Up @@ -87,6 +90,8 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results):

else:
maps1 = maps1.reshape(len(maps1), -1)
if self.config["data"]["mask"]["do"]:
maps1 = maps1.reshape(len(maps1), -1)[:, self.mask]
maps2 = maps2.reshape(len(maps2), -1)
distance_matrix = torch.vmap(
lambda maps1: torch.vmap(
Expand Down Expand Up @@ -398,3 +403,80 @@ def res_at_fsc_threshold(fscs, threshold=0.5):
res_fsc_half, fraction_nyquist = res_at_fsc_threshold(fsc_matrix)
self.stored_computed_assets = {"fraction_nyquist": fraction_nyquist}
return units_Angstroms[res_fsc_half]


class Zernike3DDistance(MapToMapDistance):
"""Zernike3D based distance.
Zernike3D distance relies on the estimation of the non-linear transformation needed to align two different maps.
The RMSD of the associated non-linear alignment represented as a deformation field is then used as the distance
between two maps
"""

@override
def get_distance_matrix(self, maps1, maps2, global_store_of_running_results):
gpuID = self.config["analysis"]["zernike3d_extra_params"]["gpuID"]
outputPath = self.config["analysis"]["zernike3d_extra_params"]["tmpDir"]
thr = self.config["analysis"]["zernike3d_extra_params"]["thr"]
numProjections = self.config["analysis"]["zernike3d_extra_params"][
"numProjections"
]

# Create output directory
if not os.path.isdir(outputPath):
os.mkdir(outputPath)

# Prepare data to call external
targets_paths = os.path.join(outputPath, "target_maps.npy")
references_path = os.path.join(outputPath, "reference_maps.npy")
if not os.path.isfile(targets_paths):
np.save(targets_paths, maps1)
if not os.path.isfile(references_path):
np.save(references_path, maps2)

# Check conda is in PATH (otherwise abort as external software is not installed)
try:
subprocess.check_call("conda", shell=True, stdout=subprocess.PIPE)
except FileNotFoundError:
raise Exception("Conda not found in PATH... Aborting")

# Check if conda env is installed
env_installed = subprocess.run(
r"conda env list | grep 'flexutils-tensorflow '",
shell=True,
check=False,
stdout=subprocess.PIPE,
).stdout
env_installed = bool(
env_installed.decode("utf-8").replace("\n", "").replace("*", "")
)
if not env_installed:
raise Exception("External software not found... Aborting")

# Find conda executable (needed to activate conda envs in a subprocess)
condabin_path = subprocess.run(
r"which conda | sed 's: ::g'",
shell=True,
check=False,
stdout=subprocess.PIPE,
).stdout
condabin_path = condabin_path.decode("utf-8").replace("\n", "").replace("*", "")

# Call external program
subprocess.check_call(
f'eval "$({condabin_path} shell.bash hook)" &&'
f" conda activate flexutils-tensorflow && "
f"compute_distance_matrix_zernike3deep.py --references_file {references_path} "
f"--targets_file {targets_paths} --out_path {outputPath} --gpu {gpuID} --num_projections {numProjections} "
f"--thr {thr}",
shell=True,
)

# Read distance matrix
dists = np.load(os.path.join(outputPath, "dist_mat.npy")).T
self.stored_computed_assets = {"zernike3d": dists}
return dists

@override
def get_computed_assets(self, maps1, maps2, global_store_of_running_results):
return self.stored_computed_assets # must run get_distance_matrix first
3 changes: 3 additions & 0 deletions src/cryo_challenge/_map_to_map/map_to_map_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
L2DistanceNorm,
BioEM3dDistance,
FSCResDistance,
Zernike3DDistance,
)


Expand All @@ -18,6 +19,7 @@
"l2": L2DistanceNorm,
"bioem": BioEM3dDistance,
"res": FSCResDistance,
"zernike3d": Zernike3DDistance,
}


Expand Down Expand Up @@ -51,6 +53,7 @@ def run(config):
maps_user_flat = submission[submission_volume_key].reshape(
len(submission["volumes"]), -1
)

maps_gt_flat = torch.load(
config["data"]["ground_truth"]["volumes"], mmap=do_low_memory_mode
)
Expand Down
Loading

0 comments on commit 2f5b9ad

Please sign in to comment.