diff --git a/src/cryojax/data/_relion/_starfile_reading.py b/src/cryojax/data/_relion/_starfile_reading.py index 85ba217f..451b933a 100644 --- a/src/cryojax/data/_relion/_starfile_reading.py +++ b/src/cryojax/data/_relion/_starfile_reading.py @@ -389,6 +389,10 @@ def _get_image_stack( ) # ... relion convention starts indexing at 1, not 0 particle_index = np.asarray(relion_particle_index, dtype=int) - 1 + + with mrcfile.mmap(path_to_image_stack, mode="r", permissive=True) as mrc: + image_stack = np.asarray(mrc.data[particle_index]) # type: ignore + elif isinstance(image_stack_index_and_name_series_or_str, pd.Series): # In this block, the user most likely used fancy indexing, like # `dataset = RelionDataset(...); particle_stack = dataset[1:10]` @@ -398,39 +402,56 @@ def _get_image_stack( image_stack_index_and_name_dataframe = ( image_stack_index_and_name_series.str.split("@", expand=True) ) - # ... get a pandas.Series for each the index and the filename - relion_particle_index, image_stack_filename = [ - image_stack_index_and_name_dataframe[column] - for column in image_stack_index_and_name_dataframe.columns - ] - # ... multiple filenames in the same STAR file is not supported with - # fancy indexing - if image_stack_filename.nunique() != 1: - raise ValueError( - "Found multiple image stack filenames when reading " - "STAR file rows. This is most likely because you tried to " - "use fancy indexing with multiple image stack filenames " - "in the same STAR file. If a STAR file refers to multiple image " - "stack filenames, fancy indexing is not supported. For example, " - "this will raise an error: `dataset = RelionDataset(...); " - "particle_stack = dataset[1:10]`." - ) - # ... create full path to the image stack + + # ... check dtype and shape of images path_to_image_stack = pathlib.Path( self.path_to_relion_project, - np.asarray(image_stack_filename, dtype=object)[0], + np.asarray(image_stack_index_and_name_dataframe[1], dtype=object)[0], ) - # ... relion convention starts indexing at 1, not 0 - particle_index = np.asarray(relion_particle_index.astype(int), dtype=int) - 1 + + with mrcfile.mmap(path_to_image_stack, mode="r", permissive=True) as mrc: + tmp_image = np.asarray(mrc.data[0]) + dtype = tmp_image.dtype + image_shape = tmp_image.shape + + # ... allocate memory for stack + image_stack = np.empty( + (len(image_stack_index_and_name_dataframe), *image_shape), dtype=dtype + ) + + # ... get unique mrc files + unique_mrc_files = image_stack_index_and_name_dataframe[1].unique() + + # ... load images to image_stack + counter = 0 # need to keep count of how many particles have been loaded + for unique_mrc in unique_mrc_files: + # ... get the indices for this particular mrc file + indices_in_mrc = image_stack_index_and_name_dataframe[1] == unique_mrc + + # ... relion convention starts indexing at 1, not 0 + particle_index = ( + image_stack_index_and_name_dataframe[indices_in_mrc][0].values.astype( + int + ) + - 1 + ) + + with mrcfile.mmap( + pathlib.Path(self.path_to_relion_project, unique_mrc), + mode="r", + permissive=True, + ) as mrc: + image_stack[counter : counter + len(particle_index)] = np.asarray( + mrc.data[particle_index] + ) + counter += len(particle_index) + else: raise IOError( "Could not read `rlnImageName` in STAR file for `RelionDataset` " f"index equal to {index}." ) - with mrcfile.mmap(path_to_image_stack, mode="r", permissive=True) as mrc: - image_stack = np.asarray(mrc.data[particle_index]) # type: ignore - return jnp.asarray(image_stack, device=device) @@ -521,7 +542,7 @@ def __getitem__( if not isinstance(filament_index, (int, Int[np.ndarray, ""])): # type: ignore raise IndexError( "When indexing a `HelicalRelionDataset`, only " - f"python or numpy-like integer indices are supported, such as " + f"python or numpy-like integer particle_index are supported, such as " "`helical_particle_stack = helical_dataset[3]`. " f"Got index {filament_index} of type {type(filament_index)}." ) @@ -532,12 +553,12 @@ def __getitem__( f"bounds! The number of filaments in the dataset is {self.n_filaments}, " f"but you tried to access the index {filament_index}." ) - # Get the particle stack indices corresponding to this filament + # Get the particle stack particle_index corresponding to this filament particle_data_blocks_at_filament = self.get_data_blocks_at_filament_index( filament_index ) particle_indices = np.asarray(particle_data_blocks_at_filament.index, dtype=int) - # Access the particle stack at these indices + # Access the particle stack at these particle_index dataset = self.dataset[particle_indices] return dataset