Skip to content

Commit

Permalink
Merge pull request #264 from mjo22/252-add-envelope-function-to-relio…
Browse files Browse the repository at this point in the history
…ndataset

252 add envelope function to reliondataset
  • Loading branch information
mjo22 authored Sep 11, 2024
2 parents 7576872 + 52aac95 commit 295dd4a
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 58 deletions.
36 changes: 22 additions & 14 deletions docs/examples/cross-correlation-search.ipynb

Large diffs are not rendered by default.

66 changes: 37 additions & 29 deletions docs/examples/read-dataset.ipynb

Large diffs are not rendered by default.

53 changes: 40 additions & 13 deletions src/cryojax/data/_relion.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,15 @@
import pandas as pd
from jaxtyping import Array, Float, Int

from cryojax.image.operators import FourierGaussian

from ..io import read_and_validate_starfile
from ..simulator import ContrastTransferFunction, EulerAnglePose, InstrumentConfig
from ..simulator import (
ContrastTransferFunction,
ContrastTransferTheory,
EulerAnglePose,
InstrumentConfig,
)
from ._dataset import AbstractDataset
from ._particle_stack import AbstractParticleStack

Expand Down Expand Up @@ -41,14 +48,14 @@ class RelionParticleStack(AbstractParticleStack):

instrument_config: InstrumentConfig
pose: EulerAnglePose
ctf: ContrastTransferFunction
transfer_theory: ContrastTransferTheory
image_stack: Optional[Float[Array, "... y_dim x_dim"]]

def __init__(
self,
instrument_config: InstrumentConfig,
pose: EulerAnglePose,
ctf: ContrastTransferFunction,
transfer_theory: ContrastTransferTheory,
image_stack: Optional[Float[Array, "... y_dim x_dim"]] = None,
):
"""**Arguments:**
Expand All @@ -60,8 +67,8 @@ def __init__(
The pose, represented by euler angles. Any subset of pytree leaves may
have a batch dimension. Upon instantiation, `pose.offset_z_in_angstroms`
is set to zero.
- `ctf`:
The contrast transfer function. Any subset of pytree leaves may
- `transfer_theory`:
The contrast transfer theory. Any subset of pytree leaves may
have a batch dimension. Upon instantiation,
`ctf.defocus_in_angstroms` is set to
`ctf.defocus_in_angstroms + pose.offset_z_in_angstroms`.
Expand All @@ -73,10 +80,10 @@ def __init__(
# Set instrument config as is
self.instrument_config = instrument_config
# Set CTF using the defocus offset in the EulerAnglePose
self.ctf = eqx.tree_at(
lambda tf: tf.defocus_in_angstroms,
ctf,
ctf.defocus_in_angstroms + pose.offset_z_in_angstroms,
self.transfer_theory = eqx.tree_at(
lambda tf: tf.ctf.defocus_in_angstroms,
transfer_theory,
transfer_theory.ctf.defocus_in_angstroms + pose.offset_z_in_angstroms,
)
# Set defocus offset to zero
self.pose = eqx.tree_at(lambda pose: pose.offset_z_in_angstroms, pose, 0.0)
Expand Down Expand Up @@ -191,22 +198,23 @@ def __getitem__(
else None
)
# ... load image parameters into cryoJAX objects
instrument_config, ctf, pose = self._get_starfile_params(
instrument_config, transfer_theory, pose = self._get_starfile_params(
particle_blocks,
optics_group,
device,
)

return RelionParticleStack(instrument_config, pose, ctf, image_stack)
return RelionParticleStack(instrument_config, pose, transfer_theory, image_stack)

@final
def __len__(self) -> int:
return len(self.data_blocks["particles"])

def _get_starfile_params(
self, particle_blocks, optics_group, device
) -> tuple[InstrumentConfig, ContrastTransferFunction, EulerAnglePose]:
) -> tuple[InstrumentConfig, ContrastTransferTheory, EulerAnglePose]:
defocus_in_angstroms = jnp.asarray(particle_blocks["rlnDefocusU"], device=device)

astigmatism_in_angstroms = jnp.asarray(
particle_blocks["rlnDefocusV"], device=device
) - jnp.asarray(particle_blocks["rlnDefocusU"], device=device)
Expand All @@ -222,6 +230,7 @@ def _get_starfile_params(
amplitude_contrast_ratio = jnp.asarray(
optics_group["rlnAmplitudeContrast"], device=device
)

# ... create cryojax objects
instrument_config = self.make_instrument_config_fn(
(int(image_size), int(image_size)),
Expand All @@ -237,6 +246,24 @@ def _get_starfile_params(
amplitude_contrast_ratio=amplitude_contrast_ratio,
phase_shift=phase_shift,
)

if "rlnCtfBfactor" in particle_blocks.keys():
bfactor = jnp.asarray(particle_blocks["rlnCtfBfactor"], device=device)

else:
bfactor = 0.0

if "rlnCtfScalefactor" in particle_blocks.keys():
env_amplitude = jnp.asarray(
particle_blocks["rlnCtfScalefactor"], device=device
)
else:
env_amplitude = 1.0

envelope = FourierGaussian(b_factor=bfactor, amplitude=env_amplitude)

transfer_theory = ContrastTransferTheory(ctf, envelope)

pose = EulerAnglePose()
# ... values for the pose are optional, so look to see if
# each key is present
Expand Down Expand Up @@ -303,7 +330,7 @@ def _get_starfile_params(
tuple([jnp.asarray(value, device=device) for value in pose_parameter_values]),
)

return instrument_config, ctf, pose
return instrument_config, transfer_theory, pose

def _get_image_stack(
self,
Expand Down
4 changes: 2 additions & 2 deletions src/cryojax/image/operators/_fourier_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from equinox import field
from jaxtyping import Array, Float, Inexact

from ..._errors import error_if_not_positive
from ..._errors import error_if_negative, error_if_not_positive
from ._operator import AbstractImageOperator


Expand Down Expand Up @@ -172,7 +172,7 @@ class FourierGaussian(AbstractFourierOperator, strict=True):
"""

amplitude: Float[Array, ""] = field(default=1.0, converter=jnp.asarray)
b_factor: Float[Array, ""] = field(default=1.0, converter=error_if_not_positive)
b_factor: Float[Array, ""] = field(default=1.0, converter=error_if_negative)

@overload
def __call__(
Expand Down

0 comments on commit 295dd4a

Please sign in to comment.