Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/mjo21/cryojax into main
Browse files Browse the repository at this point in the history
  • Loading branch information
mjo22 committed Aug 14, 2024
2 parents 5f63f6b + 011d54e commit feff0a0
Showing 1 changed file with 153 additions and 103 deletions.
256 changes: 153 additions & 103 deletions src/cryojax/data/_relion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import dataclasses
import pathlib
from typing import Any, Callable, final
from typing import Any, Callable, final, Optional

import equinox as eqx
import jax
import jax.numpy as jnp
import mrcfile
import numpy as np
Expand Down Expand Up @@ -38,20 +39,38 @@ class RelionParticleStack(AbstractParticleStack):
[RELION](https://relion.readthedocs.io/en/release-5.0/).
"""

image_stack: Float[Array, "... y_dim x_dim"]
instrument_config: InstrumentConfig
pose: EulerAnglePose
ctf: ContrastTransferFunction
image_stack: Optional[Float[Array, "... y_dim x_dim"]]

def __init__(
self,
image_stack: Float[Array, "... y_dim x_dim"],
instrument_config: InstrumentConfig,
pose: EulerAnglePose,
ctf: ContrastTransferFunction,
image_stack: Optional[Float[Array, "... y_dim x_dim"]] = None,
):
# Set image stack and config as is
self.image_stack = jnp.asarray(image_stack)
"""**Arguments:**
- `instrument_config`:
The instrument configuration. Any subset of pytree leaves may
have a batch dimension.
- `pose`:
The pose, represented by euler angles. Any subset of pytree leaves may
have a batch dimension. Upon instantiation, `pose.offset_z_in_angstroms`
is set to zero.
- `ctf`:
The contrast transfer function. Any subset of pytree leaves may
have a batch dimension. Upon instantiation,
`ctf.defocus_in_angstroms` is set to
`ctf.defocus_in_angstroms + pose.offset_z_in_angstroms`.
- `image_stack`:
The stack of images. The shape of this array
is a leading batch dimension followed by the shape
of an image in the stack.
"""
# Set instrument config as is
self.instrument_config = instrument_config
# Set CTF using the defocus offset in the EulerAnglePose
self.ctf = eqx.tree_at(
Expand All @@ -60,26 +79,9 @@ def __init__(
ctf.defocus_in_angstroms + pose.offset_z_in_angstroms,
)
# Set defocus offset to zero
self.pose = eqx.tree_at(
lambda pose: pose.offset_z_in_angstroms, pose, jnp.asarray(0.0)
)


RelionParticleStack.__init__.__doc__ = """**Arguments:**
- `image_stack`: The stack of images. The shape of this array
is a leading batch dimension followed by the shape
of an image in the stack.
- `instrument_config`: The instrument configuration. Any subset of pytree leaves may
have a batch dimension.
- `pose`: The pose, represented by euler angles. Any subset of pytree leaves may
have a batch dimension. Upon instantiation, `pose.offset_z_in_angstroms`
is set to zero.
- `ctf`: The contrast transfer function. Any subset of pytree leaves may
have a batch dimension. Upon instantiation,
`ctf.defocus_in_angstroms` is set to
`ctf.defocus_in_angstroms + pose.offset_z_in_angstroms`.
""" # noqa: E501
self.pose = eqx.tree_at(lambda pose: pose.offset_z_in_angstroms, pose, 0.0)
# Optionally set image stack
self.image_stack = None if image_stack is None else jnp.asarray(image_stack)


def _default_make_instrument_config_fn(
Expand All @@ -99,16 +101,19 @@ class RelionDataset(AbstractDataset):

path_to_relion_project: pathlib.Path
data_blocks: dict[str, pd.DataFrame]

make_instrument_config_fn: Callable[
[tuple[int, int], Float[Array, "..."], Float[Array, "..."]], InstrumentConfig
]
get_image_stack: bool
get_cpu_arrays: bool

@final
def __init__(
self,
path_to_starfile: str | pathlib.Path,
path_to_relion_project: str | pathlib.Path,
get_image_stack: bool = True,
get_cpu_arrays: bool = False,
make_instrument_config_fn: Callable[
[tuple[int, int], Float[Array, "..."], Float[Array, "..."]],
InstrumentConfig,
Expand All @@ -118,6 +123,16 @@ def __init__(
- `path_to_starfile`: The path to the Relion STAR file.
- `path_to_relion_project`: The path to the Relion project directory.
- `get_image_stack`:
If `True`, read the stack of images from the STAR file. Otherwise,
just read parameters.
- `get_cpu_arrays`:
If `True`, force that JAX arrays are loaded on the CPU. If `False`,
load on the default device.
- `make_instrument_config_fn`:
A function used for `InstrumentConfig` initialization that returns
an `InstrumentConfig`. This is used to customize the metadata of the
read object.
"""
data_blocks = read_and_validate_starfile(path_to_starfile)
_validate_relion_data_blocks(data_blocks)
Expand All @@ -126,6 +141,8 @@ def __init__(
self, "path_to_relion_project", pathlib.Path(path_to_relion_project)
)
object.__setattr__(self, "make_instrument_config_fn", make_instrument_config_fn)
object.__setattr__(self, "get_image_stack", get_image_stack)
object.__setattr__(self, "get_cpu_arrays", get_cpu_arrays)

@final
def __getitem__(
Expand All @@ -138,15 +155,15 @@ def __getitem__(
f"The number of rows in the dataset is {n_rows}, but you tried to "
f"access the index {idx}."
)
# pandas has bad error messages for its indexing
if isinstance(index, (int, Int[np.ndarray, ""])):
# ... pandas has bad error messages for its indexing
if isinstance(index, (int, Int[np.ndarray, ""])): # type: ignore
if index > n_rows - 1:
raise IndexError(index_error_msg(index))
elif isinstance(index, slice):
if index.start is not None and index.start > n_rows - 1:
raise IndexError(index_error_msg(index.start))
elif isinstance(index, np.ndarray):
pass # catch exceptions later
pass # ... catch exceptions later
else:
raise IndexError(
f"Indexing with the type {type(index)} is not supported by "
Expand All @@ -164,81 +181,52 @@ def __getitem__(
"from the `starfile.read` output."
)
optics_group = self.data_blocks["optics"].iloc[0]
# Load particle image stack rlnImageName
image_stack_index_and_name_series_or_str = particle_blocks["rlnImageName"]
if isinstance(image_stack_index_and_name_series_or_str, str):
# In this block, the user most likely used standard indexing, like
# `dataset = RelionDataset(...); particle_stack = dataset[1]`
image_stack_index_and_name_str = image_stack_index_and_name_series_or_str
# ... split the whole string into its image index and filename
relion_particle_index, image_stack_filename = (
image_stack_index_and_name_str.split("@")
)
# ... create full path to the image stack
path_to_image_stack = pathlib.Path(
self.path_to_relion_project, image_stack_filename
)
# ... relion convention starts indexing at 1, not 0
particle_index = np.asarray(relion_particle_index, dtype=int) - 1
elif isinstance(image_stack_index_and_name_series_or_str, pd.Series):
# In this block, the user most likely used fancy indexing, like
# `dataset = RelionDataset(...); particle_stack = dataset[1:10]`
image_stack_index_and_name_series = image_stack_index_and_name_series_or_str
# ... split the pandas.Series into a pandas.DataFrame with two columns:
# one for the image index and another for the filename
image_stack_index_and_name_dataframe = (
image_stack_index_and_name_series.str.split("@", expand=True)
)
# ... get a pandas.Series for each the index and the filename
relion_particle_index, image_stack_filename = [
image_stack_index_and_name_dataframe[column]
for column in image_stack_index_and_name_dataframe.columns
]
# ... multiple filenames in the same STAR file is not supported with
# fancy indexing
if image_stack_filename.nunique() != 1:
raise ValueError(
"Found multiple image stack filenames when reading "
"STAR file rows. This is most likely because you tried to "
"use fancy indexing with multiple image stack filenames "
"in the same STAR file. If a STAR file refers to multiple image "
"stack filenames, fancy indexing is not supported. For example, "
"this will raise an error: `dataset = RelionDataset(...); "
"particle_stack = dataset[1:10]`."
)
# ... create full path to the image stack
path_to_image_stack = pathlib.Path(
self.path_to_relion_project,
np.asarray(image_stack_filename, dtype=object)[0],
)
# ... relion convention starts indexing at 1, not 0
particle_index = np.asarray(relion_particle_index.astype(int), dtype=int) - 1
else:
raise IOError(
"Could not read `rlnImageName` in STAR file for `RelionDataset` "
f"index equal to {index}."
)
with mrcfile.mmap(path_to_image_stack, mode="r", permissive=True) as mrc:
image_stack = np.asarray(mrc.data[particle_index]) # type: ignore
# Read metadata into a RelionParticleStack
# ... particle data
defocus_in_angstroms = jnp.asarray(particle_blocks["rlnDefocusU"])
# Load the image stack and STAR file parameters. First, get the device
# on which to load arrays
device = jax.devices("cpu")[0] if self.get_cpu_arrays else None
# ... load stack of images, unless otherwise specified
image_stack = (
self._get_image_stack(index, particle_blocks, device)
if self.get_image_stack
else None
)
# ... load image parameters into cryoJAX objects
instrument_config, ctf, pose = self._get_starfile_params(
particle_blocks,
optics_group,
device,
)

return RelionParticleStack(instrument_config, pose, ctf, image_stack)

@final
def __len__(self) -> int:
return len(self.data_blocks["particles"])

def _get_starfile_params(
self, particle_blocks, optics_group, device
) -> tuple[InstrumentConfig, ContrastTransferFunction, EulerAnglePose]:
defocus_in_angstroms = jnp.asarray(particle_blocks["rlnDefocusU"], device=device)
astigmatism_in_angstroms = jnp.asarray(
particle_blocks["rlnDefocusV"]
) - jnp.asarray(particle_blocks["rlnDefocusU"])
astigmatism_angle = jnp.asarray(particle_blocks["rlnDefocusAngle"])
particle_blocks["rlnDefocusV"], device=device
) - jnp.asarray(particle_blocks["rlnDefocusU"], device=device)
astigmatism_angle = jnp.asarray(particle_blocks["rlnDefocusAngle"], device=device)
phase_shift = jnp.asarray(particle_blocks["rlnPhaseShift"])
# ... optics group data
image_size = jnp.asarray(optics_group["rlnImageSize"])
pixel_size = jnp.asarray(optics_group["rlnImagePixelSize"])
voltage_in_kilovolts = float(optics_group["rlnVoltage"])
spherical_aberration_in_mm = jnp.asarray(optics_group["rlnSphericalAberration"])
amplitude_contrast_ratio = jnp.asarray(optics_group["rlnAmplitudeContrast"])
image_size = jnp.asarray(optics_group["rlnImageSize"], device=device)
pixel_size = jnp.asarray(optics_group["rlnImagePixelSize"], device=device)
voltage_in_kilovolts = float(optics_group["rlnVoltage"]) # type: ignore
spherical_aberration_in_mm = jnp.asarray(
optics_group["rlnSphericalAberration"], device=device
)
amplitude_contrast_ratio = jnp.asarray(
optics_group["rlnAmplitudeContrast"], device=device
)
# ... create cryojax objects
instrument_config = self.make_instrument_config_fn(
(int(image_size), int(image_size)),
pixel_size,
jnp.asarray(voltage_in_kilovolts),
jnp.asarray(voltage_in_kilovolts, device=device),
)
ctf = ContrastTransferFunction(
defocus_in_angstroms=defocus_in_angstroms,
Expand Down Expand Up @@ -312,14 +300,76 @@ def __getitem__(
pose = eqx.tree_at(
lambda p: tuple([getattr(p, name) for name in pose_parameter_names]),
pose,
tuple([jnp.asarray(value) for value in pose_parameter_values]),
tuple([jnp.asarray(value, device=device) for value in pose_parameter_values]),
)

return RelionParticleStack(jnp.asarray(image_stack), instrument_config, pose, ctf)
return instrument_config, ctf, pose

@final
def __len__(self) -> int:
return len(self.data_blocks["particles"])
def _get_image_stack(
self,
index: int | slice | Int[np.ndarray, ""] | Int[np.ndarray, " N"],
particle_blocks,
device,
) -> Float[Array, "... y_dim x_dim"]:
# Load particle image stack rlnImageName
image_stack_index_and_name_series_or_str = particle_blocks["rlnImageName"]
if isinstance(image_stack_index_and_name_series_or_str, str):
# In this block, the user most likely used standard indexing, like
# `dataset = RelionDataset(...); particle_stack = dataset[1]`
image_stack_index_and_name_str = image_stack_index_and_name_series_or_str
# ... split the whole string into its image index and filename
relion_particle_index, image_stack_filename = (
image_stack_index_and_name_str.split("@")
)
# ... create full path to the image stack
path_to_image_stack = pathlib.Path(
self.path_to_relion_project, image_stack_filename
)
# ... relion convention starts indexing at 1, not 0
particle_index = np.asarray(relion_particle_index, dtype=int) - 1
elif isinstance(image_stack_index_and_name_series_or_str, pd.Series):
# In this block, the user most likely used fancy indexing, like
# `dataset = RelionDataset(...); particle_stack = dataset[1:10]`
image_stack_index_and_name_series = image_stack_index_and_name_series_or_str
# ... split the pandas.Series into a pandas.DataFrame with two columns:
# one for the image index and another for the filename
image_stack_index_and_name_dataframe = (
image_stack_index_and_name_series.str.split("@", expand=True)
)
# ... get a pandas.Series for each the index and the filename
relion_particle_index, image_stack_filename = [
image_stack_index_and_name_dataframe[column]
for column in image_stack_index_and_name_dataframe.columns
]
# ... multiple filenames in the same STAR file is not supported with
# fancy indexing
if image_stack_filename.nunique() != 1:
raise ValueError(
"Found multiple image stack filenames when reading "
"STAR file rows. This is most likely because you tried to "
"use fancy indexing with multiple image stack filenames "
"in the same STAR file. If a STAR file refers to multiple image "
"stack filenames, fancy indexing is not supported. For example, "
"this will raise an error: `dataset = RelionDataset(...); "
"particle_stack = dataset[1:10]`."
)
# ... create full path to the image stack
path_to_image_stack = pathlib.Path(
self.path_to_relion_project,
np.asarray(image_stack_filename, dtype=object)[0],
)
# ... relion convention starts indexing at 1, not 0
particle_index = np.asarray(relion_particle_index.astype(int), dtype=int) - 1
else:
raise IOError(
"Could not read `rlnImageName` in STAR file for `RelionDataset` "
f"index equal to {index}."
)

with mrcfile.mmap(path_to_image_stack, mode="r", permissive=True) as mrc:
image_stack = np.asarray(mrc.data[particle_index]) # type: ignore

return jnp.asarray(image_stack, device=device)


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -406,7 +456,7 @@ def get_data_blocks_at_filament_index(
def __getitem__(
self, filament_index: int | Int[np.ndarray, ""]
) -> RelionParticleStack:
if not isinstance(filament_index, (int, Int[np.ndarray, ""])):
if not isinstance(filament_index, (int, Int[np.ndarray, ""])): # type: ignore
raise IndexError(
"When indexing a `HelicalRelionDataset`, only "
f"python or numpy-like integer indices are supported, such as "
Expand Down

0 comments on commit feff0a0

Please sign in to comment.