Skip to content

Commit

Permalink
IndependentGaussianPixels distribution
Browse files Browse the repository at this point in the history
  • Loading branch information
mjo22 committed Dec 9, 2024
1 parent 8de2685 commit 7850bd0
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 53 deletions.
28 changes: 14 additions & 14 deletions docs/examples/simulate-image.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/cryojax/inference/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
)
from ._gaussian_distributions import (
IndependentGaussianFourierModes as IndependentGaussianFourierModes,
IndependentGaussianPixels as IndependentGaussianPixels,
)
209 changes: 170 additions & 39 deletions src/cryojax/inference/distributions/_gaussian_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
Image formation models simulated from gaussian noise distributions.
"""

from abc import abstractmethod
from typing import Optional
from typing_extensions import override

import jax.numpy as jnp
import jax.random as jr
from equinox import field
from equinox import AbstractVar, field
from jaxtyping import Array, Complex, Float, PRNGKeyArray

from ..._errors import error_if_not_positive
Expand All @@ -17,45 +18,38 @@
from ._base_distribution import AbstractDistribution


class IndependentGaussianFourierModes(AbstractDistribution, strict=True):
r"""A gaussian noise model, where each fourier mode is independent.
class AbstractGaussianDistribution(AbstractDistribution, strict=True):
r"""An `AbstractDistribution` where images are formed via additive
gaussian noise.
This computes the likelihood in Fourier space,
so that the variance to be an arbitrary noise power spectrum.
Subclasses may compute the likelihood in real or fourier space and
make different assumptions about the variance / covariance.
"""

imaging_pipeline: AbstractImagingPipeline
variance_function: FourierOperatorLike
signal_scale_factor: Float[Array, ""]
imaging_pipeline: AbstractVar[AbstractImagingPipeline]
signal_scale_factor: AbstractVar[Float[Array, ""]]

is_signal_normalized: bool = field(static=True)
is_signal_normalized: AbstractVar[bool]

def __init__(
self,
imaging_pipeline: AbstractImagingPipeline,
variance_function: Optional[FourierOperatorLike] = None,
signal_scale_factor: Optional[float | Float[Array, ""]] = None,
is_signal_normalized: bool = False,
@override
def sample(
self, rng_key: PRNGKeyArray, *, get_real: bool = True
) -> (
Float[
Array,
"{self.imaging_pipeline.instrument_config.y_dim} "
"{self.imaging_pipeline.instrument_config.x_dim}",
]
| Complex[
Array,
"{self.imaging_pipeline.instrument_config.y_dim} "
"{self.imaging_pipeline.instrument_config.x_dim//2+1}",
]
):
"""**Arguments:**
- `imaging_pipeline`: The image formation model.
- `variance_function`: The variance of each fourier mode. By default,
`cryojax.image.operators.Constant(1.0)`.
- `signal_scale_factor`: A scale factor for the underlying signal simulated from `imaging_pipeline`.
- `is_signal_normalized`:
Whether or not the signal is normalized before applying the `signal_scale_factor`.
If an `AbstractMask` is given to `imaging_pipeline.mask`, the signal is normalized
within the region where the mask is equal to `1`.
""" # noqa: E501
self.imaging_pipeline = imaging_pipeline
self.variance_function = variance_function or Constant(1.0)
if signal_scale_factor is None:
signal_scale_factor = jnp.sqrt(
jnp.asarray(imaging_pipeline.instrument_config.n_pixels, dtype=float)
)
self.signal_scale_factor = error_if_not_positive(jnp.asarray(signal_scale_factor))
self.is_signal_normalized = is_signal_normalized
"""Sample from the gaussian noise model."""
return self.compute_signal(get_real=get_real) + self.compute_noise(
rng_key, get_real=get_real
)

@override
def compute_signal(
Expand Down Expand Up @@ -92,6 +86,64 @@ def compute_signal(
)
return self.signal_scale_factor * real_or_fourier_simulated_image

@abstractmethod
def compute_noise(
self, rng_key: PRNGKeyArray, *, get_real: bool = True
) -> (
Float[
Array,
"{self.imaging_pipeline.instrument_config.y_dim} "
"{self.imaging_pipeline.instrument_config.x_dim}",
]
| Complex[
Array,
"{self.imaging_pipeline.instrument_config.y_dim} "
"{self.imaging_pipeline.instrument_config.x_dim//2+1}",
]
):
"""Draw a realization from the gaussian noise model and return either in
real or fourier space.
"""
raise NotImplementedError


class IndependentGaussianPixels(AbstractGaussianDistribution, strict=True):
r"""A gaussian noise model, where each pixel is independently drawn from
a zero-mean gaussian of fixed variance (white noise).
This computes the likelihood in real space, where the variance is a
constant value across all pixels.
"""

imaging_pipeline: AbstractImagingPipeline
variance: Float[Array, ""]
signal_scale_factor: Float[Array, ""]

is_signal_normalized: bool = field(static=True)

def __init__(
self,
imaging_pipeline: AbstractImagingPipeline,
variance: float | Float[Array, ""] = 1.0,
signal_scale_factor: float | Float[Array, ""] = 1.0,
is_signal_normalized: bool = False,
):
"""**Arguments:**
- `imaging_pipeline`: The image formation model.
- `variance`: The variance of each pixel.
- `signal_scale_factor`: A scale factor for the underlying signal simulated from `imaging_pipeline`.
- `is_signal_normalized`:
Whether or not the signal is normalized before applying the `signal_scale_factor`.
If an `AbstractMask` is given to `imaging_pipeline.mask`, the signal is normalized
within the region where the mask is equal to `1`.
""" # noqa: E501
self.imaging_pipeline = imaging_pipeline
self.variance = error_if_not_positive(variance)
self.signal_scale_factor = error_if_not_positive(signal_scale_factor)
self.is_signal_normalized = is_signal_normalized

@override
def compute_noise(
self, rng_key: PRNGKeyArray, *, get_real: bool = True
) -> (
Expand All @@ -111,7 +163,7 @@ def compute_noise(
freqs = pipeline.instrument_config.padded_frequency_grid_in_angstroms
# Compute the zero mean variance and scale up to be independent of the number of
# pixels
std = jnp.sqrt(n_pixels * self.variance_function(freqs))
std = jnp.sqrt(n_pixels * self.variance)
noise = pipeline.postprocess(
std
* jr.normal(rng_key, shape=freqs.shape[0:-1])
Expand All @@ -124,7 +176,74 @@ def compute_noise(
return noise

@override
def sample(
def log_likelihood(
self,
observed: Float[
Array,
"{self.imaging_pipeline.instrument_config.y_dim} "
"{self.imaging_pipeline.instrument_config.x_dim}",
],
) -> Float[Array, ""]:
"""Evaluate the log-likelihood of the gaussian noise model.
**Arguments:**
- `observed` : The observed data in real space.
"""
variance = self.variance
# Create simulated data
simulated = self.compute_signal(get_real=True)
# Compute residuals
residuals = simulated - observed
# Compute standard normal random variables
squared_standard_normal_per_pixel = jnp.abs(residuals) ** 2 / (2 * variance)
# Compute the log-likelihood for each pixel.
log_likelihood_per_pixel = -1.0 * (
squared_standard_normal_per_pixel - jnp.log(2 * jnp.pi * variance) / 2
)
# Compute log-likelihood, summing over pixels
log_likelihood = jnp.sum(log_likelihood_per_pixel)

return log_likelihood


class IndependentGaussianFourierModes(AbstractGaussianDistribution, strict=True):
r"""A gaussian noise model, where each fourier mode is independent.
This computes the likelihood in Fourier space,
so that the variance to be an arbitrary noise power spectrum.
"""

imaging_pipeline: AbstractImagingPipeline
variance_function: FourierOperatorLike
signal_scale_factor: Float[Array, ""]

is_signal_normalized: bool = field(static=True)

def __init__(
self,
imaging_pipeline: AbstractImagingPipeline,
variance_function: Optional[FourierOperatorLike] = None,
signal_scale_factor: float | Float[Array, ""] = 1.0,
is_signal_normalized: bool = False,
):
"""**Arguments:**
- `imaging_pipeline`: The image formation model.
- `variance_function`: The variance of each fourier mode. By default,
`cryojax.image.operators.Constant(1.0)`.
- `signal_scale_factor`: A scale factor for the underlying signal simulated from `imaging_pipeline`.
- `is_signal_normalized`:
Whether or not the signal is normalized before applying the `signal_scale_factor`.
If an `AbstractMask` is given to `imaging_pipeline.mask`, the signal is normalized
within the region where the mask is equal to `1`.
""" # noqa: E501
self.imaging_pipeline = imaging_pipeline
self.variance_function = variance_function or Constant(1.0)
self.signal_scale_factor = error_if_not_positive(jnp.asarray(signal_scale_factor))
self.is_signal_normalized = is_signal_normalized

def compute_noise(
self, rng_key: PRNGKeyArray, *, get_real: bool = True
) -> (
Float[
Expand All @@ -138,11 +257,23 @@ def sample(
"{self.imaging_pipeline.instrument_config.x_dim//2+1}",
]
):
"""Sample from the gaussian noise model."""
return self.compute_signal(get_real=get_real) + self.compute_noise(
rng_key, get_real=get_real
pipeline = self.imaging_pipeline
n_pixels = pipeline.instrument_config.padded_n_pixels
freqs = pipeline.instrument_config.padded_frequency_grid_in_angstroms
# Compute the zero mean variance and scale up to be independent of the number of
# pixels
std = jnp.sqrt(n_pixels * self.variance_function(freqs))
noise = pipeline.postprocess(
std
* jr.normal(rng_key, shape=freqs.shape[0:-1])
.at[0, 0]
.set(0.0)
.astype(complex),
get_real=get_real,
)

return noise

@override
def log_likelihood(
self,
Expand Down
22 changes: 22 additions & 0 deletions tests/test_distributions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import numpy as np
import pytest

import cryojax.simulator as cxs
from cryojax.inference import distributions as dist


@pytest.mark.parametrize(
"cls, scattering_theory, instrument_config",
[
(dist.IndependentGaussianPixels, "theory", "config"),
(dist.IndependentGaussianFourierModes, "theory", "config"),
],
)
def test_simulate_signal_from_gaussian_distributions(
cls, scattering_theory, instrument_config, request
):
scattering_theory = request.getfixturevalue(scattering_theory)
instrument_config = request.getfixturevalue(instrument_config)
imaging_pipeline = cxs.ContrastImagingPipeline(instrument_config, scattering_theory)
distribution = cls(imaging_pipeline)
np.testing.assert_allclose(imaging_pipeline.render(), distribution.compute_signal())

0 comments on commit 7850bd0

Please sign in to comment.