diff --git a/src/cryojax/data/_relion.py b/src/cryojax/data/_relion.py index b5f6f99c..96ce8c4a 100644 --- a/src/cryojax/data/_relion.py +++ b/src/cryojax/data/_relion.py @@ -5,6 +5,7 @@ from typing import Any, Callable, final, Optional import equinox as eqx +import jax import jax.numpy as jnp import mrcfile import numpy as np @@ -69,11 +70,7 @@ def __init__( is a leading batch dimension followed by the shape of an image in the stack. """ - # Set image stack and config as is - if image_stack is not None: - self.image_stack = jnp.asarray(image_stack) - else: - self.image_stack = None + # Set instrument config as is self.instrument_config = instrument_config # Set CTF using the defocus offset in the EulerAnglePose self.ctf = eqx.tree_at( @@ -83,6 +80,8 @@ def __init__( ) # Set defocus offset to zero self.pose = eqx.tree_at(lambda pose: pose.offset_z_in_angstroms, pose, 0.0) + # Optionally set image stack + self.image_stack = None if image_stack is None else jnp.asarray(image_stack) def _default_make_instrument_config_fn( @@ -106,6 +105,7 @@ class RelionDataset(AbstractDataset): [tuple[int, int], Float[Array, "..."], Float[Array, "..."]], InstrumentConfig ] get_image_stack: bool + get_cpu_arrays: bool @final def __init__( @@ -113,6 +113,7 @@ def __init__( path_to_starfile: str | pathlib.Path, path_to_relion_project: str | pathlib.Path, get_image_stack: bool = True, + get_cpu_arrays: bool = False, make_instrument_config_fn: Callable[ [tuple[int, int], Float[Array, "..."], Float[Array, "..."]], InstrumentConfig, @@ -125,6 +126,9 @@ def __init__( - `get_image_stack`: If `True`, read the stack of images from the STAR file. Otherwise, just read parameters. + - `get_cpu_arrays`: + If `True`, force that JAX arrays are loaded on the CPU. If `False`, + load on the default device. - `make_instrument_config_fn`: A function used for `InstrumentConfig` initialization that returns an `InstrumentConfig`. This is used to customize the metadata of the @@ -138,6 +142,7 @@ def __init__( ) object.__setattr__(self, "make_instrument_config_fn", make_instrument_config_fn) object.__setattr__(self, "get_image_stack", get_image_stack) + object.__setattr__(self, "get_cpu_arrays", get_cpu_arrays) @final def __getitem__( @@ -150,7 +155,7 @@ def __getitem__( f"The number of rows in the dataset is {n_rows}, but you tried to " f"access the index {idx}." ) - # pandas has bad error messages for its indexing + # ... pandas has bad error messages for its indexing if isinstance(index, (int, Int[np.ndarray, ""])): # type: ignore if index > n_rows - 1: raise IndexError(index_error_msg(index)) @@ -158,7 +163,7 @@ def __getitem__( if index.start is not None and index.start > n_rows - 1: raise IndexError(index_error_msg(index.start)) elif isinstance(index, np.ndarray): - pass # catch exceptions later + pass # ... catch exceptions later else: raise IndexError( f"Indexing with the type {type(index)} is not supported by " @@ -176,16 +181,20 @@ def __getitem__( "from the `starfile.read` output." ) optics_group = self.data_blocks["optics"].iloc[0] - # Load the image stack, unless otherwise specified + # Load the image stack and STAR file parameters. First, get the device + # on which to load arrays + device = jax.devices("cpu")[0] if self.get_cpu_arrays else None + # ... load stack of images, unless otherwise specified image_stack = ( - self._get_image_stack(index, particle_blocks) # type: ignore + self._get_image_stack(index, particle_blocks, device) if self.get_image_stack else None ) - # Load image parameters into cryoJAX objects + # ... load image parameters into cryoJAX objects instrument_config, ctf, pose = self._get_starfile_params( - particle_blocks, # type: ignore - optics_group, # type: ignore + particle_blocks, + optics_group, + device, ) return RelionParticleStack(instrument_config, pose, ctf, image_stack) @@ -195,25 +204,29 @@ def __len__(self) -> int: return len(self.data_blocks["particles"]) def _get_starfile_params( - self, particle_blocks: pd.DataFrame, optics_group: pd.DataFrame + self, particle_blocks, optics_group, device ) -> tuple[InstrumentConfig, ContrastTransferFunction, EulerAnglePose]: - defocus_in_angstroms = jnp.asarray(particle_blocks["rlnDefocusU"]) + defocus_in_angstroms = jnp.asarray(particle_blocks["rlnDefocusU"], device=device) astigmatism_in_angstroms = jnp.asarray( - particle_blocks["rlnDefocusV"] - ) - jnp.asarray(particle_blocks["rlnDefocusU"]) - astigmatism_angle = jnp.asarray(particle_blocks["rlnDefocusAngle"]) + particle_blocks["rlnDefocusV"], device=device + ) - jnp.asarray(particle_blocks["rlnDefocusU"], device=device) + astigmatism_angle = jnp.asarray(particle_blocks["rlnDefocusAngle"], device=device) phase_shift = jnp.asarray(particle_blocks["rlnPhaseShift"]) # ... optics group data - image_size = jnp.asarray(optics_group["rlnImageSize"]) - pixel_size = jnp.asarray(optics_group["rlnImagePixelSize"]) + image_size = jnp.asarray(optics_group["rlnImageSize"], device=device) + pixel_size = jnp.asarray(optics_group["rlnImagePixelSize"], device=device) voltage_in_kilovolts = float(optics_group["rlnVoltage"]) # type: ignore - spherical_aberration_in_mm = jnp.asarray(optics_group["rlnSphericalAberration"]) - amplitude_contrast_ratio = jnp.asarray(optics_group["rlnAmplitudeContrast"]) + spherical_aberration_in_mm = jnp.asarray( + optics_group["rlnSphericalAberration"], device=device + ) + amplitude_contrast_ratio = jnp.asarray( + optics_group["rlnAmplitudeContrast"], device=device + ) # ... create cryojax objects instrument_config = self.make_instrument_config_fn( (int(image_size), int(image_size)), pixel_size, - jnp.asarray(voltage_in_kilovolts), + jnp.asarray(voltage_in_kilovolts, device=device), ) ctf = ContrastTransferFunction( defocus_in_angstroms=defocus_in_angstroms, @@ -287,7 +300,7 @@ def _get_starfile_params( pose = eqx.tree_at( lambda p: tuple([getattr(p, name) for name in pose_parameter_names]), pose, - tuple([jnp.asarray(value) for value in pose_parameter_values]), + tuple([jnp.asarray(value, device=device) for value in pose_parameter_values]), ) return instrument_config, ctf, pose @@ -295,7 +308,8 @@ def _get_starfile_params( def _get_image_stack( self, index: int | slice | Int[np.ndarray, ""] | Int[np.ndarray, " N"], - particle_blocks: pd.DataFrame, + particle_blocks, + device, ) -> Float[Array, "... y_dim x_dim"]: # Load particle image stack rlnImageName image_stack_index_and_name_series_or_str = particle_blocks["rlnImageName"] @@ -355,7 +369,7 @@ def _get_image_stack( with mrcfile.mmap(path_to_image_stack, mode="r", permissive=True) as mrc: image_stack = np.asarray(mrc.data[particle_index]) # type: ignore - return jnp.asarray(image_stack) + return jnp.asarray(image_stack, device=device) @dataclasses.dataclass(frozen=True)