Skip to content

Commit

Permalink
Merge pull request #271 from mjo22/idea-for-writer
Browse files Browse the repository at this point in the history
Fix bug in new writer function
  • Loading branch information
mjo22 authored Sep 16, 2024
2 parents 37244dc + fae07b3 commit dc6b075
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 103 deletions.
115 changes: 62 additions & 53 deletions docs/examples/simulate-relion-dataset.ipynb

Large diffs are not rendered by default.

106 changes: 56 additions & 50 deletions src/cryojax/data/_relion/_starfile_writing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
import pathlib
from typing import Any, Callable, cast, Optional

import equinox as eqx
import jax
import numpy as np
import pandas as pd
import starfile
from jaxtyping import Array, Float, PRNGKeyArray

from ... import filter_vmap_with_spec, get_filter_spec
from ..._filter_specs import get_filter_spec
from ...image.operators import Constant, FourierGaussian
from ...io import write_image_stack_to_mrc
from ._starfile_reading import RelionDataset, RelionParticleStack
Expand Down Expand Up @@ -261,52 +262,57 @@ def compute_noisy_image_stack(
# Create the directory for the MRC files if it doesn't exist
if not os.path.exists(dataset.path_to_relion_project):
os.makedirs(dataset.path_to_relion_project)

# Create RNG key, along with a subkey for subsequent use
if seed is not None:
key = jax.random.PRNGKey(seed=seed)
key, subkey = jax.random.split(key)
else:
subkey = cast(PRNGKeyArray, None)

# Create vmapped `compute_image` kernel
test_particle_stack = dataset[0]
filter_spec_for_vmap = _get_particle_stack_filter_spec(test_particle_stack)
compute_image_stack = filter_vmap_with_spec(
compute_image, filter_spec=filter_spec_for_vmap
compute_image_stack = (
eqx.filter_vmap(
lambda vmap, novmap, args: compute_image(eqx.combine(vmap, novmap), args), # type: ignore
in_axes=(0, None, None),
)
if seed is None
else eqx.filter_vmap(
lambda key, vmap, novmap, args: compute_image(
key, eqx.combine(vmap, novmap), args
), # type: ignore
in_axes=(0, 0, None, None),
)
)

# First let's check how many unique MRC files we have in the starfile
# Now, let's preparing the simulation loop. First check how many unique MRC
# files we have in the starfile
particles_fnames = dataset.data_blocks["particles"]["rlnImageName"].str.split(
"@", expand=True
)
mrc_fnames = particles_fnames[1].unique()

# Generate images for each mrcfile
# ... now, generate images for each mrcfile
for mrc_fname in mrc_fnames:
# Check which indices in the starfile correspond to this mrc file
# ... check which indices in the starfile correspond to this mrc file
# and load the particle stack parameters
indices = np.array(
[0, 1]
) # particles_fnames[particles_fnames[1] == mrc_fname].index.to_numpy()
indices = particles_fnames[particles_fnames[1] == mrc_fname].index.to_numpy()
relion_particle_stack = dataset[indices]

# Generate keys for each image in the mrcfile, and a subkey for the next iteration
if seed is not None:
# ... split the particle stack based on parameters to vmap over
vmap, novmap = eqx.partition(relion_particle_stack, filter_spec_for_vmap)
# ... simulate images in the image stack
if seed is None:
image_stack = compute_image_stack(vmap, novmap, args)
else:
# ... generate keys for each image in the mrcfile,
# and a subkey for the next iteration
keys = jax.random.split(subkey, len(indices) + 1)
image_stack = compute_image_stack(
keys[:-1],
vmap,
novmap,
args, # type: ignore
)
subkey = keys[-1]

# Generate the noisy image stack
image_stack = (
compute_image_stack(relion_particle_stack, args)
# if seed is None
# else compute_image_stack(
# keys[:-1],
# relion_particle_stack,
# args,
# )
)

# Write the image stack to an MRC file
# ... write the image stack to an MRC file
filename = os.path.join(dataset.path_to_relion_project, mrc_fname)
write_image_stack_to_mrc(
image_stack,
Expand All @@ -325,30 +331,30 @@ def compute_noisy_image_stack(


def _pointer_to_vmapped_parameters(particle_stack):
if isinstance(particle_stack.envelope, FourierGaussian):
if isinstance(particle_stack.transfer_theory.envelope, FourierGaussian):
output = (
particle_stack.ctf.defocus_in_angstroms,
particle_stack.ctf.astigmatism_in_angstroms,
particle_stack.ctf.astigmatism_angle,
particle_stack.ctf.phase_shift,
particle_stack.envelope.b_factor,
particle_stack.envelope.amplitude,
particle_stack.offset_x_in_angstroms,
particle_stack.offset_y_in_angstroms,
particle_stack.view_phi,
particle_stack.view_theta,
particle_stack.view_psi,
particle_stack.transfer_theory.ctf.defocus_in_angstroms,
particle_stack.transfer_theory.ctf.astigmatism_in_angstroms,
particle_stack.transfer_theory.ctf.astigmatism_angle,
particle_stack.transfer_theory.ctf.phase_shift,
particle_stack.transfer_theory.envelope.b_factor,
particle_stack.transfer_theory.envelope.amplitude,
particle_stack.pose.offset_x_in_angstroms,
particle_stack.pose.offset_y_in_angstroms,
particle_stack.pose.view_phi,
particle_stack.pose.view_theta,
particle_stack.pose.view_psi,
)
else:
output = (
particle_stack.ctf.defocus_in_angstroms,
particle_stack.ctf.astigmatism_in_angstroms,
particle_stack.ctf.astigmatism_angle,
particle_stack.ctf.phase_shift,
particle_stack.offset_x_in_angstroms,
particle_stack.offset_y_in_angstroms,
particle_stack.view_phi,
particle_stack.view_theta,
particle_stack.view_psi,
particle_stack.transfer_theory.ctf.defocus_in_angstroms,
particle_stack.transfer_theory.ctf.astigmatism_in_angstroms,
particle_stack.transfer_theory.ctf.astigmatism_angle,
particle_stack.transfer_theory.ctf.phase_shift,
particle_stack.pose.offset_x_in_angstroms,
particle_stack.pose.offset_y_in_angstroms,
particle_stack.pose.view_phi,
particle_stack.pose.view_theta,
particle_stack.pose.view_psi,
)
return output

0 comments on commit dc6b075

Please sign in to comment.