Skip to content

Commit

Permalink
implement loading from different mrcfiles in RelionDataset
Browse files Browse the repository at this point in the history
  • Loading branch information
DSilva27 committed Sep 24, 2024
1 parent c087e60 commit cd47668
Showing 1 changed file with 48 additions and 27 deletions.
75 changes: 48 additions & 27 deletions src/cryojax/data/_relion/_starfile_reading.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,10 @@ def _get_image_stack(
)
# ... relion convention starts indexing at 1, not 0
particle_index = np.asarray(relion_particle_index, dtype=int) - 1

with mrcfile.mmap(path_to_image_stack, mode="r", permissive=True) as mrc:
image_stack = np.asarray(mrc.data[particle_index]) # type: ignore

elif isinstance(image_stack_index_and_name_series_or_str, pd.Series):
# In this block, the user most likely used fancy indexing, like
# `dataset = RelionDataset(...); particle_stack = dataset[1:10]`
Expand All @@ -398,39 +402,56 @@ def _get_image_stack(
image_stack_index_and_name_dataframe = (
image_stack_index_and_name_series.str.split("@", expand=True)
)
# ... get one pandas.Series for the particle index and one for the filename
relion_particle_index, image_stack_filename = [
image_stack_index_and_name_dataframe[column]
for column in image_stack_index_and_name_dataframe.columns
]
# ... multiple filenames in the same STAR file is not supported with
# fancy indexing
if image_stack_filename.nunique() != 1:
raise ValueError(
"Found multiple image stack filenames when reading "
"STAR file rows. This is most likely because you tried to "
"use fancy indexing with multiple image stack filenames "
"in the same STAR file. If a STAR file refers to multiple image "
"stack filenames, fancy indexing is not supported. For example, "
"this will raise an error: `dataset = RelionDataset(...); "
"particle_stack = dataset[1:10]`."
)
# ... create full path to the image stack

# ... check dtype and shape of images
path_to_image_stack = pathlib.Path(
self.path_to_relion_project,
np.asarray(image_stack_filename, dtype=object)[0],
np.asarray(image_stack_index_and_name_dataframe[1], dtype=object)[0],
)
# ... relion convention starts indexing at 1, not 0
particle_index = np.asarray(relion_particle_index.astype(int), dtype=int) - 1

with mrcfile.mmap(path_to_image_stack, mode="r", permissive=True) as mrc:
tmp_image = np.asarray(mrc.data[0])
dtype = tmp_image.dtype
image_shape = tmp_image.shape

# ... allocate memory for stack
image_stack = np.empty(
(len(image_stack_index_and_name_dataframe), *image_shape), dtype=dtype
)

# ... get unique mrc files
unique_mrc_files = image_stack_index_and_name_dataframe[1].unique()

# ... load images to image_stack
counter = 0 # need to keep count of how many particles have been loaded
for unique_mrc in unique_mrc_files:
# ... get the indices for this particular mrc file
indices_in_mrc = image_stack_index_and_name_dataframe[1] == unique_mrc

# ... relion convention starts indexing at 1, not 0
particle_index = (
image_stack_index_and_name_dataframe[indices_in_mrc][0].values.astype(
int
)
- 1
)

with mrcfile.mmap(
pathlib.Path(self.path_to_relion_project, unique_mrc),
mode="r",
permissive=True,
) as mrc:
image_stack[counter : counter + len(particle_index)] = np.asarray(
mrc.data[particle_index]
)
counter += len(particle_index)

else:
raise IOError(
"Could not read `rlnImageName` in STAR file for `RelionDataset` "
f"index equal to {index}."
)

with mrcfile.mmap(path_to_image_stack, mode="r", permissive=True) as mrc:
image_stack = np.asarray(mrc.data[particle_index]) # type: ignore

return jnp.asarray(image_stack, device=device)


Expand Down Expand Up @@ -521,7 +542,7 @@ def __getitem__(
if not isinstance(filament_index, (int, Int[np.ndarray, ""])): # type: ignore
raise IndexError(
"When indexing a `HelicalRelionDataset`, only "
f"python or numpy-like integer indices are supported, such as "
f"python or numpy-like integer particle_index are supported, such as "
"`helical_particle_stack = helical_dataset[3]`. "
f"Got index {filament_index} of type {type(filament_index)}."
)
Expand All @@ -532,12 +553,12 @@ def __getitem__(
f"bounds! The number of filaments in the dataset is {self.n_filaments}, "
f"but you tried to access the index {filament_index}."
)
# Get the particle stack indices corresponding to this filament
# Get the indices into the particle stack corresponding to this filament
particle_data_blocks_at_filament = self.get_data_blocks_at_filament_index(
filament_index
)
particle_indices = np.asarray(particle_data_blocks_at_filament.index, dtype=int)
# Access the particle stack at these indices
# Access the particle stack at these particle indices
dataset = self.dataset[particle_indices]
return dataset

Expand Down

0 comments on commit cd47668

Please sign in to comment.