Skip to content

Commit

Permalink
import, type checking, and docs cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
mjo22 committed Sep 16, 2024
1 parent df69fd1 commit 09e84e0
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 107 deletions.
6 changes: 2 additions & 4 deletions src/cryojax/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from ._dataset import AbstractDataset as AbstractDataset
from ._generate_relion_datasets import (
generate_starfile as generate_starfile,
write_simulated_image_stack_from_starfile as write_simulated_image_stack_from_starfile, # noqa
)
from ._particle_stack import (
AbstractParticleStack as AbstractParticleStack,
)
from ._relion import (
generate_starfile as generate_starfile,
HelicalRelionDataset as HelicalRelionDataset,
RelionDataset as RelionDataset,
RelionParticleStack as RelionParticleStack,
write_simulated_image_stack_from_starfile as write_simulated_image_stack_from_starfile, # noqa
)
9 changes: 9 additions & 0 deletions src/cryojax/data/_relion/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from ._starfile_reading import (
HelicalRelionDataset as HelicalRelionDataset,
RelionDataset as RelionDataset,
RelionParticleStack as RelionParticleStack,
)
from ._starfile_writing import (
generate_starfile as generate_starfile,
write_simulated_image_stack_from_starfile as write_simulated_image_stack_from_starfile, # noqa: E501
)
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,16 @@
import pandas as pd
from jaxtyping import Array, Float, Int

from cryojax.image.operators import FourierGaussian

from ..io import read_and_validate_starfile
from ..simulator import (
from ...image.operators import FourierGaussian
from ...io import read_and_validate_starfile
from ...simulator import (
ContrastTransferFunction,
ContrastTransferTheory,
EulerAnglePose,
InstrumentConfig,
)
from ._dataset import AbstractDataset
from ._particle_stack import AbstractParticleStack
from .._dataset import AbstractDataset
from .._particle_stack import AbstractParticleStack


RELION_REQUIRED_OPTICS_KEYS = [
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
import os
import pathlib
from typing import Any, Callable, Optional
from typing import Any, Callable, cast, Optional

import equinox as eqx

# jax imports
import jax
import jax.tree_util as jtu
import numpy as np
import pandas as pd
import starfile
from jaxtyping import Array, Float, PRNGKeyArray, PyTree
from jaxtyping import Array, Float, PRNGKeyArray, PyTree, Shaped

from ..data._relion import RelionDataset, RelionParticleStack
from ..image.operators import Constant, FourierGaussian
from ..io import write_image_stack_to_mrc
from ..simulator import AbstractPose, ContrastTransferTheory, InstrumentConfig
from ...image.operators import Constant, FourierGaussian
from ...io import write_image_stack_to_mrc
from ...simulator import AbstractPose, ContrastTransferTheory, InstrumentConfig
from ._starfile_reading import RelionDataset, RelionParticleStack


def _get_filename(step, n_char=6):
Expand All @@ -32,25 +30,21 @@ def generate_starfile(
filename: str | pathlib.Path,
mrc_batch_size: Optional[int] = None,
) -> None:
"""
Generate a STAR file from a RelionParticleStack object.
"""Generate a STAR file from a RelionParticleStack object.
This function does not generate particles, it merely populates the starfile.
The starfile is written to disc at the location specified by filename.
Parameters
----------
relion_particle_stack : RelionParticleStack
A RelionParticleStack object.
filename : str
The filename of the STAR file to write.
mrc_batch_size : int, optional
The number of images to write to each MRC file. If None, defaults to n_images.
**Arguments:**
Returns
-------
None
- `relion_particle_stack`:
The `RelionParticleStack` object.
- `filename`:
The filename of the STAR file to write.
- `mrc_batch_size`:
The number of images to write to each MRC file. If `None`, the number of
images in the `RelionParticleStack` is used.
"""

n_images = relion_particle_stack.pose.offset_x_in_angstroms.shape[0]
Expand Down Expand Up @@ -125,7 +119,7 @@ def generate_starfile(
n_batches = n_images // mrc_batch_size
n_remainder = n_images % mrc_batch_size

relative_mrcs_path_prefix = filename.split(".")[0]
relative_mrcs_path_prefix = str(filename).split(".")[0]
image_names = []

for step in range(n_batches):
Expand All @@ -146,72 +140,20 @@ def generate_starfile(

particles_df["rlnImageName"] = image_names
starfile_dict["particles"] = particles_df
starfile.write(starfile_dict, filename)
starfile.write(starfile_dict, pathlib.Path(filename))

return


def _filter_ctf(pytree):
    """Pick out the CTF (and optional envelope) fields of a transfer theory.

    The returned tuple always holds six attributes read from `pytree.ctf`.
    When `pytree.envelope` is a `FourierGaussian`, its `b_factor` and
    `amplitude` are appended, in that order. `_is_vmappable` passes this
    function to `eqx.tree_at` as the node selector.
    """
    # Decide on the envelope first, mirroring the order of attribute access
    # in the branch-per-case formulation.
    has_gaussian_envelope = isinstance(pytree.envelope, FourierGaussian)

    ctf = pytree.ctf
    selected = (
        ctf.defocus_in_angstroms,
        ctf.astigmatism_in_angstroms,
        ctf.astigmatism_angle,
        ctf.phase_shift,
        ctf.spherical_aberration_in_mm,
        ctf.amplitude_contrast_ratio,
    )
    if not has_gaussian_envelope:
        return selected

    envelope = pytree.envelope
    return (*selected, envelope.b_factor, envelope.amplitude)


def _filter_pose(pytree):
    """Pick out the in-plane offsets and the three view angles of a pose.

    Returns a 5-tuple ordered as x offset, y offset, phi, theta, psi.
    `_is_vmappable` passes this function to `eqx.tree_at` as the node
    selector.
    """
    selected_fields = (
        "offset_x_in_angstroms",
        "offset_y_in_angstroms",
        "view_phi",
        "view_theta",
        "view_psi",
    )
    return tuple(getattr(pytree, field) for field in selected_fields)


def _is_vmappable(pytree):
    """Build a boolean mask with the same tree structure as `pytree`.

    Every leaf starts as `False`. For an `AbstractPose` the nodes selected by
    `_filter_pose` are switched to `True`; for a `ContrastTransferTheory` the
    nodes selected by `_filter_ctf` are. Any other input yields an all-`False`
    mask.
    """
    if isinstance(pytree, AbstractPose):
        selector = _filter_pose
    elif isinstance(pytree, ContrastTransferTheory):
        selector = _filter_ctf
    else:
        selector = None

    all_false = jtu.tree_map(lambda _: False, pytree)
    if selector is None:
        return all_false
    return eqx.tree_at(selector, all_false, replace_fn=lambda _: True)


def _replace_leaf(leaf, relion_particle_stack):
    """Swap `leaf` for the matching component of `relion_particle_stack`.

    A `ContrastTransferTheory` is replaced by the stack's `transfer_theory`,
    an `InstrumentConfig` by its `instrument_config`, and an `AbstractPose`
    by its `pose`. Anything else is returned unchanged.
    """
    # Checked in order; the first matching type wins.
    substitutions = (
        (ContrastTransferTheory, "transfer_theory"),
        (InstrumentConfig, "instrument_config"),
        (AbstractPose, "pose"),
    )
    for node_type, attribute in substitutions:
        if isinstance(leaf, node_type):
            return getattr(relion_particle_stack, attribute)
    return leaf


def write_simulated_image_stack_from_starfile(
dataset: RelionDataset,
compute_image_stack: Callable[[PyTree, Any], Float[Array, "batch_dim y_dim x_dim"]]
| Callable[[PRNGKeyArray, PyTree, Any], Float[Array, "batch_dim y_dim x_dim"]],
compute_image_stack: (
Callable[[PyTree, Any], Float[Array, "batch_dim y_dim x_dim"]]
| Callable[
[Shaped[PRNGKeyArray, " batch_dim"], PyTree, Any],
Float[Array, "batch_dim y_dim x_dim"],
]
),
pytree: PyTree,
seed: Optional[int] = None, # seed for the noise
overwrite: bool = False,
Expand All @@ -228,21 +170,6 @@ def write_simulated_image_stack_from_starfile(
the two pytrees resulting from the `eqx.partition` function,
using the `vmap_filter_spec` as the filter.
**Arguments:**
- `dataset` : `RelionDataset` : The dataset initialized from the
STAR file containing the particle stack parameters.
- `compute_image_stack` : `Callable` : A function that computes
the image stack from the parameters contained in the STAR
file.
- `pytree` : `PyTree` : The pytree that is given to `compute_image_stack`
to compute the image stack (before filtering for vmapping).
- `seed` : `Optional[int]` : The seed for the random number generator.
- `overwrite` : `bool` : Whether to overwrite the MRC files if they
already exist.
- `compression` : `Optional[str]` : The compression to use when writing
the MRC files.
```python
# Example 1: Using the function with a `compute_image_stack`
Expand Down Expand Up @@ -316,6 +243,22 @@ def compute_noisy_image_stack(
)
```
**Arguments:**
- `dataset`:
The `RelionDataset` STAR file reader.
- `compute_image_stack`:
A callable that computes the image stack from the parameters contained
in the STAR file.
- `pytree`:
The pytree that is given to `compute_image_stack`
to compute the image stack (before filtering for vmapping).
- `seed`:
The seed for the random number generator.
- `overwrite`:
Whether to overwrite the MRC files if they already exist.
- `compression`:
The compression to use when writing the MRC files.
"""
# Create the directory for the MRC files if it doesn't exist
if not os.path.exists(dataset.path_to_relion_project):
Expand All @@ -324,6 +267,8 @@ def compute_noisy_image_stack(
if seed is not None:
key = jax.random.PRNGKey(seed=seed)
key, subkey = jax.random.split(key)
else:
subkey = cast(PRNGKeyArray, None)

# Function to check if given leaf is an object contained in a
# `RelionParticleStack`
Expand Down Expand Up @@ -380,3 +325,60 @@ def compute_noisy_image_stack(
)

return


def _filter_ctf(pytree):
    """Pick out the CTF (and optional envelope) fields of a transfer theory.

    The returned tuple always holds six attributes read from `pytree.ctf`.
    When `pytree.envelope` is a `FourierGaussian`, its `b_factor` and
    `amplitude` are appended, in that order. `_is_vmappable` passes this
    function to `eqx.tree_at` as the node selector.
    """
    # Decide on the envelope first, mirroring the order of attribute access
    # in the branch-per-case formulation.
    has_gaussian_envelope = isinstance(pytree.envelope, FourierGaussian)

    ctf = pytree.ctf
    selected = (
        ctf.defocus_in_angstroms,
        ctf.astigmatism_in_angstroms,
        ctf.astigmatism_angle,
        ctf.phase_shift,
        ctf.spherical_aberration_in_mm,
        ctf.amplitude_contrast_ratio,
    )
    if not has_gaussian_envelope:
        return selected

    envelope = pytree.envelope
    return (*selected, envelope.b_factor, envelope.amplitude)


def _filter_pose(pytree):
    """Pick out the in-plane offsets and the three view angles of a pose.

    Returns a 5-tuple ordered as x offset, y offset, phi, theta, psi.
    `_is_vmappable` passes this function to `eqx.tree_at` as the node
    selector.
    """
    selected_fields = (
        "offset_x_in_angstroms",
        "offset_y_in_angstroms",
        "view_phi",
        "view_theta",
        "view_psi",
    )
    return tuple(getattr(pytree, field) for field in selected_fields)


def _is_vmappable(pytree):
    """Build a boolean mask with the same tree structure as `pytree`.

    Every leaf starts as `False`. For an `AbstractPose` the nodes selected by
    `_filter_pose` are switched to `True`; for a `ContrastTransferTheory` the
    nodes selected by `_filter_ctf` are. Any other input yields an all-`False`
    mask.
    """
    if isinstance(pytree, AbstractPose):
        selector = _filter_pose
    elif isinstance(pytree, ContrastTransferTheory):
        selector = _filter_ctf
    else:
        selector = None

    all_false = jtu.tree_map(lambda _: False, pytree)
    if selector is None:
        return all_false
    return eqx.tree_at(selector, all_false, replace_fn=lambda _: True)


def _replace_leaf(leaf, relion_particle_stack):
    """Swap `leaf` for the matching component of `relion_particle_stack`.

    A `ContrastTransferTheory` is replaced by the stack's `transfer_theory`,
    an `InstrumentConfig` by its `instrument_config`, and an `AbstractPose`
    by its `pose`. Anything else is returned unchanged.
    """
    # Checked in order; the first matching type wins.
    substitutions = (
        (ContrastTransferTheory, "transfer_theory"),
        (InstrumentConfig, "instrument_config"),
        (AbstractPose, "pose"),
    )
    for node_type, attribute in substitutions:
        if isinstance(leaf, node_type):
            return getattr(relion_particle_stack, attribute)
    return leaf

0 comments on commit 09e84e0

Please sign in to comment.