Skip to content

Commit

Permalink
Merge pull request #70 from mjo22/helix
Browse files Browse the repository at this point in the history
updates, such helical bug fix, arbitrary image shapes, and reorganizing projection normalization
  • Loading branch information
mjo22 authored Dec 20, 2023
2 parents 822be1b + a922b41 commit c6cfd73
Show file tree
Hide file tree
Showing 12 changed files with 166 additions and 70 deletions.
4 changes: 2 additions & 2 deletions src/cryojax/io/load_voxels.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def load_voxel_cloud(filename: str, **kwargs: Any) -> dict[str, Array]:
Read a 3D template on a cartesian grid
to a point cloud.
This is used to instantiate ``cryojax.simulator.ElectronCloud``.
This is used to instantiate ``cryojax.simulator.VoxelCloud``.
Parameters
----------
Expand Down Expand Up @@ -59,7 +59,7 @@ def load_fourier_grid(filename: str, pad_scale: float = 1.0) -> dict[str, Any]:
"""
Read a 3D template in Fourier space on a cartesian grid.
This is used to instantiate ``cryojax.simulator.ElectronGrid``.
This is used to instantiate ``cryojax.simulator.VoxelGrid``.
Parameters
----------
Expand Down
5 changes: 2 additions & 3 deletions src/cryojax/simulator/assembly/_helix.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,14 @@

__all__ = ["Helix", "compute_lattice_positions", "compute_lattice_rotations"]

from typing import Union, Optional, Any
from typing import Union, Optional
from jaxtyping import Array, Float
from functools import cached_property

import jax
import jax.numpy as jnp

from ._assembly import Assembly, _Positions, _Rotations
from ..specimen import Specimen

from ...core import field
from ...typing import Real_, RealVector
Expand Down Expand Up @@ -180,7 +179,7 @@ def f(carry, x):
R_n = jnp.array(
((c_n, s_n, 0), (-s_n, c_n, 0), (0, 0, 1)), dtype=float
)
return (R_n @ r.T).T
return (R_n.T @ r.T).T

# The helical coordinates for all sub-helices
positions = jax.vmap(compute_helix_coordinates)(symmetry_angles)
Expand Down
15 changes: 7 additions & 8 deletions src/cryojax/simulator/density/_voxel_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,10 @@
from ..pose import Pose
from ...io import load_voxel_cloud, load_fourier_grid
from ...core import field
from ...typing import (
ComplexVolume,
RealCloud,
CloudCoords3D,
)
from ...typing import RealCloud, CloudCoords3D

_VolumeSliceCoords = Float[Array, "N1 N2 1 3"]
_CubicVolume = Float[Array, "N N N"]
_VolumeSliceCoords = Float[Array, "N N 1 3"]


class Voxels(ElectronDensity):
Expand Down Expand Up @@ -71,11 +68,11 @@ class VoxelGrid(Voxels):
----------
weights :
3D electron density grid in Fourier space.
coordinates : shape `(N1, N2, 1, 3)`
coordinates : shape `(N, N, 1, 3)`
Central slice of cartesian coordinate system.
"""

weights: ComplexVolume = field()
weights: _CubicVolume = field()
coordinates: _VolumeSliceCoords = field()

real: bool = field(default=False, static=True)
Expand All @@ -85,6 +82,8 @@ def __check_init__(self):
raise NotImplementedError(
"Real voxel grid densities are not supported."
)
if self.weights.shape != tuple(3 * [self.weights.shape[0]]):
raise ValueError("Only cubic voxel grids are supported.")

def rotate_to(self, pose: Pose) -> "VoxelGrid":
"""
Expand Down
4 changes: 2 additions & 2 deletions src/cryojax/simulator/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from .noise import GaussianNoise
from .kernel import Kernel, Constant
from ..utils import scale, irfftn
from ..utils import scale, ifftn
from ..core import field, Module
from ..typing import Real_, RealImage, ImageCoords

Expand Down Expand Up @@ -101,7 +101,7 @@ def sample(
freqs: ImageCoords,
image: Optional[RealImage] = None,
) -> RealImage:
return irfftn(super().sample(key, freqs))
return ifftn(super().sample(key, freqs)).real


@partial(jax.jit, static_argnames=["method", "antialias"])
Expand Down
14 changes: 8 additions & 6 deletions src/cryojax/simulator/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .instrument import Instrument
from .detector import NullDetector
from .ice import Ice, NullIce
from ..utils import fftn, irfftn
from ..utils import fftn, ifftn
from ..core import field, Module
from ..typing import RealImage, Image, Real_

Expand Down Expand Up @@ -134,10 +134,10 @@ def sample(
image = self.render(view=False)
if not isinstance(self.solvent, NullIce):
# The image of the solvent
ice_image = irfftn(
ice_image = ifftn(
self.instrument.optics(freqs)
* self.solvent.sample(key[idx], freqs)
)
).real
image = image + ice_image
idx += 1
if not isinstance(self.instrument.detector, NullDetector):
Expand Down Expand Up @@ -174,7 +174,7 @@ def _view(self, image: Image, is_real: bool = True) -> RealImage:
if self.filter is not None:
if is_real:
image = fftn(image)
image = irfftn(self.filter(image))
image = ifftn(self.filter(image)).real
# Crop the image
image = self.manager.crop(image)
# Mask the image
Expand All @@ -188,8 +188,10 @@ def _render_specimen(self) -> RealImage:
freqs = self.manager.padded_freqs / resolution
# Draw the electron density at a particular conformation and pose
density = self.specimen.realization
# Compute the scattering image
# Compute the scattering image in fourier space
image = self.scattering.scatter(density, resolution=resolution)
# Normalize to cisTEM conventions
image = self.manager.normalize_to_cistem(image, is_real=False)
# Apply translation
image *= self.specimen.pose.shifts(freqs)
# Compute and apply CTF
Expand All @@ -202,7 +204,7 @@ def _render_specimen(self) -> RealImage:
image = scaling * image + offset
# Measure at the detector pixel size
image = self.instrument.detector.pixelize(
irfftn(image), resolution=resolution
ifftn(image).real, resolution=resolution
)

return image
Expand Down
38 changes: 35 additions & 3 deletions src/cryojax/simulator/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@

__all__ = ["ImageManager"]

from typing import Any
from typing import Any, Union, Callable

import jax.numpy as jnp

from ..core import field, Buffer
from ..typing import (
Expand All @@ -18,6 +20,7 @@
make_coordinates,
crop,
pad,
crop_or_pad,
resize,
)

Expand All @@ -37,6 +40,9 @@ class ImageManager(Buffer):
when computing it in the object plane. This
should be a floating point number greater than
or equal to 1. By default, it is 1 (no padding).
pad_mode :
The method of image padding. By default, ``"edge"``.
For all options, see ``jax.numpy.pad``.
freqs :
The fourier wavevectors in the imaging plane.
padded_freqs :
Expand All @@ -51,6 +57,7 @@ class ImageManager(Buffer):

shape: tuple[int, int] = field(static=True)
pad_scale: float = field(static=True, default=1.0)
pad_mode: Union[str, Callable] = field(static=True, default="edge")

padded_shape: tuple[int, int] = field(static=True, init=False)

Expand All @@ -75,12 +82,37 @@ def crop(self, image: Image) -> Image:

def pad(self, image: Image, **kwargs: Any) -> Image:
"""Pad an image."""
return pad(image, self.padded_shape, **kwargs)
return pad(image, self.padded_shape, mode=self.pad_mode, **kwargs)

def crop_or_pad(self, image: Image, **kwargs: Any) -> Image:
"""Reshape an image using cropping or padding."""
return crop_or_pad(
image, self.padded_shape, mode=self.pad_mode, **kwargs
)

def downsample(
self, image: Image, method="lanczos5", **kwargs: Any
) -> Image:
"""Downsample an image."""
"""Downsample an image using interpolation."""
return resize(
image, self.shape, antialias=False, method=method, **kwargs
)

def upsample(self, image: Image, method="bicubic", **kwargs: Any) -> Image:
"""Upsample an image using interpolation."""
return resize(image, self.padded_shape, method=method, **kwargs)

def normalize_to_cistem(
self, image: Image, is_real: bool = False
) -> Image:
"""Normalize images on the exit plane according to cisTEM conventions."""
M1, M2 = image.shape
if is_real:
raise NotImplementedError(
"Normalization to cisTEM conventions not supported for real input."
)
else:
# Set zero frequency component to zero
image = image.at[0, 0].set(0.0 + 0.0j)
# cisTEM normalization convention for projections
return image / jnp.sqrt(M1 * M2)
41 changes: 16 additions & 25 deletions src/cryojax/simulator/scattering/_fourier_slice_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@
VolumeCoords,
)
from ...utils import (
ifftn,
fftn,
crop,
pad,
map_coordinates,
)

Expand All @@ -46,22 +45,26 @@ def scatter(
rotated fourier transform and interpolating onto
a uniform grid in the object plane.
"""
return extract_slice(
fourier_projection = extract_slice(
density.weights,
density.coordinates,
resolution,
self.manager.padded_shape,
order=self.order,
mode=self.mode,
cval=self.cval,
)
if self.manager.padded_shape != fourier_projection.shape:
fourier_projection = fftn(
self.manager.crop_or_pad(ifftn(fourier_projection).real)
)

return fourier_projection


def extract_slice(
weights: ComplexVolume,
coordinates: VolumeCoords,
resolution: float,
shape: tuple[int, int],
**kwargs: Any,
) -> ComplexImage:
"""
Expand All @@ -76,10 +79,6 @@ def extract_slice(
Frequency central slice coordinate system.
resolution :
The rasterization resolution.
shape :
Shape of the imaging plane in pixels.
``width, height = shape[0], shape[1]``
is the size of the desired imaging plane.
kwargs:
Passed to ``cryojax.utils.interpolate.map_coordinates``.
Expand All @@ -98,20 +97,12 @@ def extract_slice(
# Make coordinates dimensionless
coordinates *= box_size
# Interpolate on the upper half plane get the slice
z = N2 // 2 + 1
projection = map_coordinates(weights, coordinates[:, :z], **kwargs)[..., 0]
# Set zero frequency component to zero
projection = projection.at[0, 0].set(0.0 + 0.0j)
# z = N2 // 2 + 1
# fourier_projection = map_coordinates(
# weights, coordinates[:, :z], **kwargs
# )[..., 0]
# Transform back to real space
projection = jnp.fft.fftshift(jnp.fft.irfftn(projection, s=(N1, N2)))
# Crop or pad to desired image size
M1, M2 = shape
if N1 >= M1 and N2 >= M2:
projection = crop(projection, shape)
elif N1 <= M1 and N2 <= M2:
projection = pad(projection, shape, mode="edge")
else:
raise NotImplementedError(
"Voxel density shape must be larger or smaller than image shape in all dimensions"
)
return fftn(projection) / jnp.sqrt(M1 * M2)
# fourier_projection = irfftn(fourier_projection, s=(N1, N2))
#

return map_coordinates(weights, coordinates, **kwargs)[..., 0]
13 changes: 7 additions & 6 deletions src/cryojax/simulator/scattering/_nufft_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@ def scatter(
) -> ComplexImage:
"""Rasterize image with non-uniform FFTs."""
if isinstance(density, VoxelCloud):
return project_with_nufft(
fourier_projection = project_with_nufft(
density.weights,
density.coordinates,
resolution,
self.manager.padded_shape,
eps=self.eps,
)
elif isinstance(density, AtomCloud):
return project_atoms_with_nufft(
fourier_projection = project_atoms_with_nufft(
density.weights,
density.coordinates,
density.variances,
Expand All @@ -62,6 +62,7 @@ def scatter(
raise NotImplementedError(
"Supported density representations are VoxelCloud and AtomCloud"
)
return fourier_projection


def project_atoms_with_nufft(
Expand Down Expand Up @@ -127,11 +128,11 @@ def project_with_nufft(
M1, M2 = shape
image_size = jnp.array(np.array([M1, M2]) * resolution)
coordinates = jnp.flip(coordinates[:, :2], axis=-1)
projection = nufft(weights, coordinates, image_size, shape, **kwargs)
# Set zero frequency component to zero
projection = projection.at[0, 0].set(0.0 + 0.0j)
fourier_projection = nufft(
weights, coordinates, image_size, shape, **kwargs
)

return projection / jnp.sqrt(M1 * M2)
return fourier_projection


"""
Expand Down
2 changes: 2 additions & 0 deletions src/cryojax/simulator/scattering/_scattering_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

from abc import abstractmethod

import jax.numpy as jnp

from ..density import ElectronDensity
from ..manager import ImageManager

Expand Down
Loading

0 comments on commit c6cfd73

Please sign in to comment.