Skip to content

Commit

Permalink
Merge pull request #257 from mjo22/cpu-option-for-relion-dataset
Browse files Browse the repository at this point in the history
Optionally read arrays on the CPU in `RelionDataset`
  • Loading branch information
mjo22 authored Aug 9, 2024
2 parents f76e955 + 62bb5d4 commit 6c13315
Showing 1 changed file with 39 additions and 25 deletions.
64 changes: 39 additions & 25 deletions src/cryojax/data/_relion.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, Callable, final, Optional

import equinox as eqx
import jax
import jax.numpy as jnp
import mrcfile
import numpy as np
Expand Down Expand Up @@ -69,11 +70,7 @@ def __init__(
is a leading batch dimension followed by the shape
of an image in the stack.
"""
# Set image stack and config as is
if image_stack is not None:
self.image_stack = jnp.asarray(image_stack)
else:
self.image_stack = None
# Set instrument config as is
self.instrument_config = instrument_config
# Set CTF using the defocus offset in the EulerAnglePose
self.ctf = eqx.tree_at(
Expand All @@ -83,6 +80,8 @@ def __init__(
)
# Set defocus offset to zero
self.pose = eqx.tree_at(lambda pose: pose.offset_z_in_angstroms, pose, 0.0)
# Optionally set image stack
self.image_stack = None if image_stack is None else jnp.asarray(image_stack)


def _default_make_instrument_config_fn(
Expand All @@ -106,13 +105,15 @@ class RelionDataset(AbstractDataset):
[tuple[int, int], Float[Array, "..."], Float[Array, "..."]], InstrumentConfig
]
get_image_stack: bool
get_cpu_arrays: bool

@final
def __init__(
self,
path_to_starfile: str | pathlib.Path,
path_to_relion_project: str | pathlib.Path,
get_image_stack: bool = True,
get_cpu_arrays: bool = False,
make_instrument_config_fn: Callable[
[tuple[int, int], Float[Array, "..."], Float[Array, "..."]],
InstrumentConfig,
Expand All @@ -125,6 +126,9 @@ def __init__(
- `get_image_stack`:
If `True`, read the stack of images from the STAR file. Otherwise,
just read parameters.
- `get_cpu_arrays`:
If `True`, force that JAX arrays are loaded on the CPU. If `False`,
load on the default device.
- `make_instrument_config_fn`:
A function used for `InstrumentConfig` initialization that returns
an `InstrumentConfig`. This is used to customize the metadata of the
Expand All @@ -138,6 +142,7 @@ def __init__(
)
object.__setattr__(self, "make_instrument_config_fn", make_instrument_config_fn)
object.__setattr__(self, "get_image_stack", get_image_stack)
object.__setattr__(self, "get_cpu_arrays", get_cpu_arrays)

@final
def __getitem__(
Expand All @@ -150,15 +155,15 @@ def __getitem__(
f"The number of rows in the dataset is {n_rows}, but you tried to "
f"access the index {idx}."
)
# pandas has bad error messages for its indexing
# ... pandas has bad error messages for its indexing
if isinstance(index, (int, Int[np.ndarray, ""])): # type: ignore
if index > n_rows - 1:
raise IndexError(index_error_msg(index))
elif isinstance(index, slice):
if index.start is not None and index.start > n_rows - 1:
raise IndexError(index_error_msg(index.start))
elif isinstance(index, np.ndarray):
pass # catch exceptions later
pass # ... catch exceptions later
else:
raise IndexError(
f"Indexing with the type {type(index)} is not supported by "
Expand All @@ -176,16 +181,20 @@ def __getitem__(
"from the `starfile.read` output."
)
optics_group = self.data_blocks["optics"].iloc[0]
# Load the image stack, unless otherwise specified
# Load the image stack and STAR file parameters. First, get the device
# on which to load arrays
device = jax.devices("cpu")[0] if self.get_cpu_arrays else None
# ... load stack of images, unless otherwise specified
image_stack = (
self._get_image_stack(index, particle_blocks) # type: ignore
self._get_image_stack(index, particle_blocks, device)
if self.get_image_stack
else None
)
# Load image parameters into cryoJAX objects
# ... load image parameters into cryoJAX objects
instrument_config, ctf, pose = self._get_starfile_params(
particle_blocks, # type: ignore
optics_group, # type: ignore
particle_blocks,
optics_group,
device,
)

return RelionParticleStack(instrument_config, pose, ctf, image_stack)
Expand All @@ -195,25 +204,29 @@ def __len__(self) -> int:
return len(self.data_blocks["particles"])

def _get_starfile_params(
self, particle_blocks: pd.DataFrame, optics_group: pd.DataFrame
self, particle_blocks, optics_group, device
) -> tuple[InstrumentConfig, ContrastTransferFunction, EulerAnglePose]:
defocus_in_angstroms = jnp.asarray(particle_blocks["rlnDefocusU"])
defocus_in_angstroms = jnp.asarray(particle_blocks["rlnDefocusU"], device=device)
astigmatism_in_angstroms = jnp.asarray(
particle_blocks["rlnDefocusV"]
) - jnp.asarray(particle_blocks["rlnDefocusU"])
astigmatism_angle = jnp.asarray(particle_blocks["rlnDefocusAngle"])
particle_blocks["rlnDefocusV"], device=device
) - jnp.asarray(particle_blocks["rlnDefocusU"], device=device)
astigmatism_angle = jnp.asarray(particle_blocks["rlnDefocusAngle"], device=device)
phase_shift = jnp.asarray(particle_blocks["rlnPhaseShift"])
# ... optics group data
image_size = jnp.asarray(optics_group["rlnImageSize"])
pixel_size = jnp.asarray(optics_group["rlnImagePixelSize"])
image_size = jnp.asarray(optics_group["rlnImageSize"], device=device)
pixel_size = jnp.asarray(optics_group["rlnImagePixelSize"], device=device)
voltage_in_kilovolts = float(optics_group["rlnVoltage"]) # type: ignore
spherical_aberration_in_mm = jnp.asarray(optics_group["rlnSphericalAberration"])
amplitude_contrast_ratio = jnp.asarray(optics_group["rlnAmplitudeContrast"])
spherical_aberration_in_mm = jnp.asarray(
optics_group["rlnSphericalAberration"], device=device
)
amplitude_contrast_ratio = jnp.asarray(
optics_group["rlnAmplitudeContrast"], device=device
)
# ... create cryojax objects
instrument_config = self.make_instrument_config_fn(
(int(image_size), int(image_size)),
pixel_size,
jnp.asarray(voltage_in_kilovolts),
jnp.asarray(voltage_in_kilovolts, device=device),
)
ctf = ContrastTransferFunction(
defocus_in_angstroms=defocus_in_angstroms,
Expand Down Expand Up @@ -287,15 +300,16 @@ def _get_starfile_params(
pose = eqx.tree_at(
lambda p: tuple([getattr(p, name) for name in pose_parameter_names]),
pose,
tuple([jnp.asarray(value) for value in pose_parameter_values]),
tuple([jnp.asarray(value, device=device) for value in pose_parameter_values]),
)

return instrument_config, ctf, pose

def _get_image_stack(
self,
index: int | slice | Int[np.ndarray, ""] | Int[np.ndarray, " N"],
particle_blocks: pd.DataFrame,
particle_blocks,
device,
) -> Float[Array, "... y_dim x_dim"]:
# Load particle image stack rlnImageName
image_stack_index_and_name_series_or_str = particle_blocks["rlnImageName"]
Expand Down Expand Up @@ -355,7 +369,7 @@ def _get_image_stack(
with mrcfile.mmap(path_to_image_stack, mode="r", permissive=True) as mrc:
image_stack = np.asarray(mrc.data[particle_index]) # type: ignore

return jnp.asarray(image_stack)
return jnp.asarray(image_stack, device=device)


@dataclasses.dataclass(frozen=True)
Expand Down

0 comments on commit 6c13315

Please sign in to comment.