Skip to content

Commit

Permalink
remove extra voltage_in_kilovolts
Browse files Browse the repository at this point in the history
  • Loading branch information
mjo22 committed Dec 16, 2024
1 parent eb38967 commit 76b41f1
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 102 deletions.
20 changes: 9 additions & 11 deletions docs/examples/read-dataset.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -83,29 +83,28 @@
" shape=(100, 100),\n",
" pixel_size=f32[],\n",
" voltage_in_kilovolts=f32[],\n",
" electrons_per_angstrom_squared=f32[],\n",
" electrons_per_angstrom_squared=weak_f32[],\n",
" padded_shape=(100, 100),\n",
" pad_mode='constant'\n",
" ),\n",
" pose=EulerAnglePose(\n",
" offset_x_in_angstroms=f32[],\n",
" offset_y_in_angstroms=f32[],\n",
" offset_x_in_angstroms=weak_f32[],\n",
" offset_y_in_angstroms=weak_f32[],\n",
" offset_z_in_angstroms=0.0,\n",
" view_phi=f32[],\n",
" view_theta=f32[],\n",
" view_psi=f32[]\n",
" view_phi=weak_f32[],\n",
" view_theta=weak_f32[],\n",
" view_psi=weak_f32[]\n",
" ),\n",
" transfer_theory=ContrastTransferTheory(\n",
" ctf=ContrastTransferFunction(\n",
" defocus_in_angstroms=f32[],\n",
" astigmatism_in_angstroms=f32[],\n",
" astigmatism_angle=f32[],\n",
" voltage_in_kilovolts=300.0,\n",
" spherical_aberration_in_mm=f32[],\n",
" amplitude_contrast_ratio=f32[],\n",
" phase_shift=f32[]\n",
" ),\n",
" envelope=Constant(value=f32[])\n",
" envelope=Constant(value=weak_f32[])\n",
" ),\n",
" image_stack=f32[100,100]\n",
")\n"
Expand Down Expand Up @@ -214,10 +213,9 @@
"output_type": "stream",
"text": [
"ContrastTransferFunction(\n",
" defocus_in_angstroms=Array([10050.97, 10050.97, 10050.97], dtype=float32),\n",
" astigmatism_in_angstroms=Array([-50.970703, -50.970703, -50.970703], dtype=float32),\n",
" defocus_in_angstroms=Array([10025.484, 10025.484, 10025.484], dtype=float32),\n",
" astigmatism_in_angstroms=Array([50.970703, 50.970703, 50.970703], dtype=float32),\n",
" astigmatism_angle=Array([-54.58706, -54.58706, -54.58706], dtype=float32),\n",
" voltage_in_kilovolts=300.0,\n",
" spherical_aberration_in_mm=Array(2.7, dtype=float32),\n",
" amplitude_contrast_ratio=Array(0.1, dtype=float32),\n",
" phase_shift=Array([0., 0., 0.], dtype=float32)\n",
Expand Down
69 changes: 34 additions & 35 deletions docs/examples/simulate-relion-dataset.ipynb

Large diffs are not rendered by default.

22 changes: 9 additions & 13 deletions src/cryojax/data/_relion/_starfile_reading.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def _get_starfile_params(
# ... optics group data
image_size = jnp.asarray(optics_group["rlnImageSize"], device=device)
pixel_size = jnp.asarray(optics_group["rlnImagePixelSize"], device=device)
voltage_in_kilovolts = float(optics_group["rlnVoltage"]) # type: ignore
voltage_in_kilovolts = jnp.asarray(optics_group["rlnVoltage"]) # type: ignore
spherical_aberration_in_mm = jnp.asarray(
optics_group["rlnSphericalAberration"], device=device
)
Expand All @@ -246,30 +246,26 @@ def _get_starfile_params(
jnp.asarray(voltage_in_kilovolts, device=device),
)
# ... now the ContrastTransferTheory
make_ctf = (
lambda defocus, astig, angle, voltage, sph, ac, ps: ContrastTransferFunction(
defocus_in_angstroms=defocus,
astigmatism_in_angstroms=astig,
astigmatism_angle=angle,
voltage_in_kilovolts=voltage,
spherical_aberration_in_mm=sph,
amplitude_contrast_ratio=ac,
phase_shift=ps,
)
make_ctf = lambda defocus, astig, angle, sph, ac, ps: ContrastTransferFunction(
defocus_in_angstroms=defocus,
astigmatism_in_angstroms=astig,
astigmatism_angle=angle,
spherical_aberration_in_mm=sph,
amplitude_contrast_ratio=ac,
phase_shift=ps,
)
ctf_params = (
defocus_in_angstroms,
astigmatism_in_angstroms,
astigmatism_angle,
voltage_in_kilovolts,
spherical_aberration_in_mm,
amplitude_contrast_ratio,
phase_shift,
)
ctf = (
eqx.filter_vmap(
make_ctf,
in_axes=(0, 0, 0, None, None, None, 0),
in_axes=(0, 0, 0, None, None, 0),
out_axes=eqxi.if_mapped(0),
)(*ctf_params)
if defocus_in_angstroms.ndim == 1
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from abc import abstractmethod
from typing import Optional

from equinox import Module
from jaxtyping import Array, Complex, Float
Expand All @@ -18,7 +17,7 @@ def __call__(
self,
frequency_grid_in_angstroms: Float[Array, "y_dim x_dim 2"],
*,
wavelength_in_angstroms: Optional[Float[Array, ""] | float] = None,
voltage_in_kilovolts: Float[Array, ""] | float = 300.0,
defocus_offset: Float[Array, ""] | float = 0.0,
) -> Float[Array, "y_dim x_dim"] | Complex[Array, "y_dim x_dim"]:
raise NotImplementedError
Expand Down
25 changes: 6 additions & 19 deletions src/cryojax/simulator/_transfer_theory/contrast_transfer_theory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing_extensions import override

import jax.numpy as jnp
from equinox import field
from jaxtyping import Array, Complex, Float

from ..._errors import error_if_negative, error_if_not_fractional
Expand Down Expand Up @@ -39,7 +38,6 @@ class ContrastTransferFunction(AbstractTransferFunction, strict=True):
defocus_in_angstroms: Float[Array, ""]
astigmatism_in_angstroms: Float[Array, ""]
astigmatism_angle: Float[Array, ""]
voltage_in_kilovolts: Float[Array, ""] | float = field(static=True)
spherical_aberration_in_mm: Float[Array, ""]
amplitude_contrast_ratio: Float[Array, ""]
phase_shift: Float[Array, ""]
Expand All @@ -49,7 +47,6 @@ def __init__(
defocus_in_angstroms: float | Float[Array, ""] = 10000.0,
astigmatism_in_angstroms: float | Float[Array, ""] = 0.0,
astigmatism_angle: float | Float[Array, ""] = 0.0,
voltage_in_kilovolts: float | Float[Array, ""] = 300.0,
spherical_aberration_in_mm: float | Float[Array, ""] = 2.7,
amplitude_contrast_ratio: float | Float[Array, ""] = 0.1,
phase_shift: float | Float[Array, ""] = 0.0,
Expand All @@ -59,19 +56,13 @@ def __init__(
- `defocus_in_angstroms`: The mean defocus in Angstroms.
- `astigmatism_in_angstroms`: The amount of astigmatism in Angstroms.
- `astigmatism_angle`: The defocus angle.
- `voltage_in_kilovolts`:
The accelerating voltage in kV. This field is treated as *static*, i.e.
as part of the pytree. This is because the accelerating voltage is treated
as a traced value in the `InstrumentConfig`, since many modeling components
are interested in the accelerating voltage.
- `spherical_aberration_in_mm`: The spherical aberration coefficient in mm.
- `amplitude_contrast_ratio`: The amplitude contrast ratio.
- `phase_shift`: The additional phase shift.
"""
self.defocus_in_angstroms = error_if_negative(defocus_in_angstroms)
self.astigmatism_in_angstroms = jnp.asarray(astigmatism_in_angstroms)
self.astigmatism_angle = jnp.asarray(astigmatism_angle)
self.voltage_in_kilovolts = voltage_in_kilovolts
self.spherical_aberration_in_mm = error_if_negative(spherical_aberration_in_mm)
self.amplitude_contrast_ratio = error_if_not_fractional(amplitude_contrast_ratio)
self.phase_shift = jnp.asarray(phase_shift)
Expand All @@ -80,22 +71,18 @@ def __call__(
self,
frequency_grid_in_angstroms: Float[Array, "y_dim x_dim 2"],
*,
wavelength_in_angstroms: Optional[Float[Array, ""] | float] = None,
voltage_in_kilovolts: Float[Array, ""] | float = 300.0,
defocus_offset: Float[Array, ""] | float = 0.0,
) -> Float[Array, "y_dim x_dim"]:
# Convert degrees to radians
phase_shift = jnp.deg2rad(self.phase_shift)
astigmatism_angle = jnp.deg2rad(self.astigmatism_angle)
# Convert spherical abberation coefficient to angstroms
spherical_aberration_in_angstroms = self.spherical_aberration_in_mm * 1e7
# Get the wavelength. It can either be passed from upstream or stored in the
# CTF
if wavelength_in_angstroms is None:
wavelength_in_angstroms = convert_keV_to_angstroms(
jnp.asarray(self.voltage_in_kilovolts)
)
else:
wavelength_in_angstroms = jnp.asarray(wavelength_in_angstroms)
# Get the wavelength
wavelength_in_angstroms = convert_keV_to_angstroms(
jnp.asarray(voltage_in_kilovolts)
)
defocus_axis_1_in_angstroms = (
self.defocus_in_angstroms
+ jnp.asarray(defocus_offset)
Expand Down Expand Up @@ -160,7 +147,7 @@ def __call__(
# Compute the CTF
ctf_array = self.envelope(frequency_grid) * self.ctf(
frequency_grid,
wavelength_in_angstroms=instrument_config.wavelength_in_angstroms,
voltage_in_kilovolts=instrument_config.voltage_in_kilovolts,
defocus_offset=defocus_offset,
)
# ... compute the contrast as the CTF multiplied by the exit plane
Expand Down
26 changes: 6 additions & 20 deletions src/cryojax/simulator/_transfer_theory/wave_transfer_theory.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from typing import Optional
from typing_extensions import override

import jax.numpy as jnp
from equinox import field
from jaxtyping import Array, Complex, Float

from ..._errors import error_if_negative, error_if_not_fractional, error_if_not_positive
Expand All @@ -26,7 +24,6 @@ class WaveTransferFunction(AbstractTransferFunction, strict=True):
defocus_in_angstroms: Float[Array, ""]
astigmatism_in_angstroms: Float[Array, ""]
astigmatism_angle: Float[Array, ""]
voltage_in_kilovolts: Float[Array, ""] | float = field(static=True)
spherical_aberration_in_mm: Float[Array, ""]
amplitude_contrast_ratio: Float[Array, ""]
phase_shift: Float[Array, ""]
Expand All @@ -36,7 +33,6 @@ def __init__(
defocus_in_angstroms: float | Float[Array, ""] = 10000.0,
astigmatism_in_angstroms: float | Float[Array, ""] = 0.0,
astigmatism_angle: float | Float[Array, ""] = 0.0,
voltage_in_kilovolts: float | Float[Array, ""] = 300.0,
spherical_aberration_in_mm: float | Float[Array, ""] = 2.7,
amplitude_contrast_ratio: float | Float[Array, ""] = 0.1,
phase_shift: float | Float[Array, ""] = 0.0,
Expand All @@ -46,19 +42,13 @@ def __init__(
- `defocus_u_in_angstroms`: The major axis defocus in Angstroms.
- `defocus_v_in_angstroms`: The minor axis defocus in Angstroms.
- `astigmatism_angle`: The defocus angle.
- `voltage_in_kilovolts`:
The accelerating voltage in kV. This field is treated as *static*, i.e.
as part of the pytree. This is because the accelerating voltage is treated
as a traced value in the `InstrumentConfig`, since many modeling components
are interested in the accelerating voltage.
- `spherical_aberration_in_mm`: The spherical aberration coefficient in mm.
- `amplitude_contrast_ratio`: The amplitude contrast ratio.
- `phase_shift`: The additional phase shift.
"""
self.defocus_in_angstroms = error_if_not_positive(defocus_in_angstroms)
self.astigmatism_in_angstroms = jnp.asarray(astigmatism_in_angstroms)
self.astigmatism_angle = jnp.asarray(astigmatism_angle)
self.voltage_in_kilovolts = voltage_in_kilovolts
self.spherical_aberration_in_mm = error_if_negative(spherical_aberration_in_mm)
self.amplitude_contrast_ratio = error_if_not_fractional(amplitude_contrast_ratio)
self.phase_shift = jnp.asarray(phase_shift)
Expand All @@ -67,22 +57,18 @@ def __call__(
self,
frequency_grid_in_angstroms: Float[Array, "y_dim x_dim 2"],
*,
wavelength_in_angstroms: Optional[Float[Array, ""] | float] = None,
voltage_in_kilovolts: Float[Array, ""] | float = 300.0,
defocus_offset: Float[Array, ""] | float = 0.0,
) -> Complex[Array, "y_dim x_dim"]:
# Convert degrees to radians
phase_shift = jnp.deg2rad(self.phase_shift)
astigmatism_angle = jnp.deg2rad(self.astigmatism_angle)
# Convert spherical abberation coefficient to angstroms
spherical_aberration_in_angstroms = self.spherical_aberration_in_mm * 1e7
# Get the wavelength. It can either be passed from upstream or stored in the
# CTF
if wavelength_in_angstroms is None:
wavelength_in_angstroms = convert_keV_to_angstroms(
jnp.asarray(self.voltage_in_kilovolts)
)
else:
wavelength_in_angstroms = jnp.asarray(wavelength_in_angstroms)
# Get the wavelength
wavelength_in_angstroms = convert_keV_to_angstroms(
jnp.asarray(voltage_in_kilovolts)
)
defocus_axis_1_in_angstroms = self.defocus_in_angstroms + jnp.asarray(
defocus_offset
)
Expand Down Expand Up @@ -146,7 +132,7 @@ def __call__(
# Compute the wave transfer function
wtf_array = self.wtf(
frequency_grid,
wavelength_in_angstroms=instrument_config.wavelength_in_angstroms,
voltage_in_kilovolts=instrument_config.voltage_in_kilovolts,
defocus_offset=defocus_offset,
)
# ... compute the contrast as the CTF multiplied by the exit plane
Expand Down
3 changes: 1 addition & 2 deletions tests/test_agree_with_cistem.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,10 @@ def test_ctf_with_cistem(defocus1, defocus2, asti_angle, kV, cs, ac, pixel_size)
defocus_in_angstroms=(defocus1 + defocus2) / 2,
astigmatism_in_angstroms=defocus1 - defocus2,
astigmatism_angle=asti_angle,
voltage_in_kilovolts=kV,
spherical_aberration_in_mm=cs,
amplitude_contrast_ratio=ac,
)
ctf = jnp.array(optics(freqs))
ctf = jnp.array(optics(freqs, voltage_in_kilovolts=kV))
# Compute cisTEM CTF
cisTEM_optics = CTF(
kV=kV,
Expand Down

0 comments on commit 76b41f1

Please sign in to comment.