Skip to content

Commit

Permalink
Merge pull request #283 from mjo22/update-to-assembly-stuff
Browse files Browse the repository at this point in the history
get rid of AbstractStructualEnsembleBatcher
  • Loading branch information
mjo22 authored Dec 4, 2024
2 parents 88ac323 + bc8aa4c commit b4d678d
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 39 deletions.
1 change: 0 additions & 1 deletion src/cryojax/simulator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
from ._structural_ensemble import (
AbstractConformationalVariable as AbstractConformationalVariable,
AbstractStructuralEnsemble as AbstractStructuralEnsemble,
AbstractStructuralEnsembleBatcher as AbstractStructuralEnsembleBatcher,
DiscreteConformationalVariable as DiscreteConformationalVariable,
DiscreteStructuralEnsemble as DiscreteStructuralEnsemble,
SingleStructureEnsemble as SingleStructureEnsemble,
Expand Down
8 changes: 1 addition & 7 deletions src/cryojax/simulator/_assembly/assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from abc import abstractmethod
from functools import cached_property
from typing import Optional
from typing_extensions import override

import equinox as eqx
import jax
Expand All @@ -18,11 +17,10 @@
from .._structural_ensemble import (
AbstractConformationalVariable,
AbstractStructuralEnsemble,
AbstractStructuralEnsembleBatcher,
)


class AbstractAssembly(AbstractStructuralEnsembleBatcher, strict=True):
class AbstractAssembly(eqx.Module, strict=True):
"""Abstraction of a biological assembly.
To subclass an `AbstractAssembly`,
Expand Down Expand Up @@ -103,7 +101,3 @@ def subunits(self) -> AbstractStructuralEnsemble:
else:
where = lambda s: s.pose
return eqx.tree_at(where, self.subunit, self.poses)

@override
def get_batched_structural_ensemble(self) -> AbstractStructuralEnsemble:
return self.subunits
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
import jax.numpy as jnp
from jaxtyping import Array, Complex, PRNGKeyArray

from .._assembly import AbstractAssembly
from .._instrument_config import InstrumentConfig
from .._pose import AbstractPose
from .._potential_integrator import AbstractPotentialIntegrator
from .._solvent import AbstractIce
from .._structural_ensemble import (
AbstractConformationalVariable,
AbstractStructuralEnsemble,
AbstractStructuralEnsembleBatcher,
)
from .._transfer_theory import ContrastTransferTheory
from .base_scattering_theory import AbstractScatteringTheory
Expand Down Expand Up @@ -123,12 +123,12 @@ def compute_fourier_contrast_at_detector_plane(


class LinearSuperpositionScatteringTheory(AbstractWeakPhaseScatteringTheory, strict=True):
"""Compute the superposition of images of the structural ensemble batch returned by
the `AbstractStructuralEnsembleBatcher`. This must operate in the weak phase
"""Compute the superposition of images over a batch of poses and potentials
parameterized by an `AbstractAssembly`. This must operate in the weak phase
approximation.
"""

structural_ensemble_batcher: AbstractStructuralEnsembleBatcher
assembly: AbstractAssembly
potential_integrator: AbstractPotentialIntegrator
transfer_theory: ContrastTransferTheory
solvent: Optional[AbstractIce] = None
Expand Down Expand Up @@ -163,9 +163,7 @@ def compute_image_superposition(
)

# Get the batch
ensemble_batch = (
self.structural_ensemble_batcher.get_batched_structural_ensemble()
)
ensemble_batch = self.assembly.subunits
# Setup vmap over the pose and conformation
is_mapped = lambda x: isinstance(
x, (AbstractPose, AbstractConformationalVariable)
Expand Down Expand Up @@ -222,9 +220,7 @@ def compute_image_superposition(
)

# Get the batch
ensemble_batch = (
self.structural_ensemble_batcher.get_batched_structural_ensemble()
)
ensemble_batch = self.assembly.subunits
# Setup vmap over the pose and conformation
is_mapped = lambda x: isinstance(
x, (AbstractPose, AbstractConformationalVariable)
Expand Down Expand Up @@ -254,9 +250,9 @@ def compute_image_superposition(

LinearSuperpositionScatteringTheory.__init__.__doc__ = """**Arguments:**
- `structural_ensemble_batcher`: The batcher that computes the states that over which to
compute a superposition of images. Most commonly, this
would be an `AbstractAssembly` concrete class.
- `assembly`: An concrete class of an `AbstractAssembly`. This is used to
output a batch of states over which to
compute a superposition of images.
- `potential_integrator`: The method for integrating the specimen potential.
- `transfer_theory`: The contrast transfer theory.
- `solvent`: The model for the solvent.
Expand Down
3 changes: 0 additions & 3 deletions src/cryojax/simulator/_structural_ensemble/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,3 @@
DiscreteConformationalVariable as DiscreteConformationalVariable,
DiscreteStructuralEnsemble as DiscreteStructuralEnsemble,
)
from .ensemble_batcher import (
AbstractStructuralEnsembleBatcher as AbstractStructuralEnsembleBatcher,
)
13 changes: 0 additions & 13 deletions src/cryojax/simulator/_structural_ensemble/ensemble_batcher.py

This file was deleted.

4 changes: 2 additions & 2 deletions tests/test_helix.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def test_c6_rotation(
@eqx.filter_jit
def compute_rotated_image(pipeline, pose):
pipeline = eqx.tree_at(
lambda m: m.scattering_theory.structural_ensemble_batcher.pose,
lambda m: m.scattering_theory.assembly.pose,
pipeline,
pose,
)
Expand Down Expand Up @@ -154,7 +154,7 @@ def compute_rotated_image_with_helix(
pipeline: cs.ContrastImagingPipeline, pose: cs.AbstractPose
):
pipeline = eqx.tree_at(
lambda m: m.scattering_theory.structural_ensemble_batcher.pose,
lambda m: m.scattering_theory.assembly.pose,
pipeline,
pose,
)
Expand Down

0 comments on commit b4d678d

Please sign in to comment.