Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

get rid of AbstractStructualEnsembleBatcher #283

Merged
merged 1 commit into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/cryojax/simulator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
from ._structural_ensemble import (
AbstractConformationalVariable as AbstractConformationalVariable,
AbstractStructuralEnsemble as AbstractStructuralEnsemble,
AbstractStructuralEnsembleBatcher as AbstractStructuralEnsembleBatcher,
DiscreteConformationalVariable as DiscreteConformationalVariable,
DiscreteStructuralEnsemble as DiscreteStructuralEnsemble,
SingleStructureEnsemble as SingleStructureEnsemble,
Expand Down
8 changes: 1 addition & 7 deletions src/cryojax/simulator/_assembly/assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from abc import abstractmethod
from functools import cached_property
from typing import Optional
from typing_extensions import override

import equinox as eqx
import jax
Expand All @@ -18,11 +17,10 @@
from .._structural_ensemble import (
AbstractConformationalVariable,
AbstractStructuralEnsemble,
AbstractStructuralEnsembleBatcher,
)


class AbstractAssembly(AbstractStructuralEnsembleBatcher, strict=True):
class AbstractAssembly(eqx.Module, strict=True):
"""Abstraction of a biological assembly.

To subclass an `AbstractAssembly`,
Expand Down Expand Up @@ -103,7 +101,3 @@ def subunits(self) -> AbstractStructuralEnsemble:
else:
where = lambda s: s.pose
return eqx.tree_at(where, self.subunit, self.poses)

@override
def get_batched_structural_ensemble(self) -> AbstractStructuralEnsemble:
return self.subunits
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
import jax.numpy as jnp
from jaxtyping import Array, Complex, PRNGKeyArray

from .._assembly import AbstractAssembly
from .._instrument_config import InstrumentConfig
from .._pose import AbstractPose
from .._potential_integrator import AbstractPotentialIntegrator
from .._solvent import AbstractIce
from .._structural_ensemble import (
AbstractConformationalVariable,
AbstractStructuralEnsemble,
AbstractStructuralEnsembleBatcher,
)
from .._transfer_theory import ContrastTransferTheory
from .base_scattering_theory import AbstractScatteringTheory
Expand Down Expand Up @@ -123,12 +123,12 @@ def compute_fourier_contrast_at_detector_plane(


class LinearSuperpositionScatteringTheory(AbstractWeakPhaseScatteringTheory, strict=True):
"""Compute the superposition of images of the structural ensemble batch returned by
the `AbstractStructuralEnsembleBatcher`. This must operate in the weak phase
"""Compute the superposition of images over a batch of poses and potentials
parameterized by an `AbstractAssembly`. This must operate in the weak phase
approximation.
"""

structural_ensemble_batcher: AbstractStructuralEnsembleBatcher
assembly: AbstractAssembly
potential_integrator: AbstractPotentialIntegrator
transfer_theory: ContrastTransferTheory
solvent: Optional[AbstractIce] = None
Expand Down Expand Up @@ -163,9 +163,7 @@ def compute_image_superposition(
)

# Get the batch
ensemble_batch = (
self.structural_ensemble_batcher.get_batched_structural_ensemble()
)
ensemble_batch = self.assembly.subunits
# Setup vmap over the pose and conformation
is_mapped = lambda x: isinstance(
x, (AbstractPose, AbstractConformationalVariable)
Expand Down Expand Up @@ -222,9 +220,7 @@ def compute_image_superposition(
)

# Get the batch
ensemble_batch = (
self.structural_ensemble_batcher.get_batched_structural_ensemble()
)
ensemble_batch = self.assembly.subunits
# Setup vmap over the pose and conformation
is_mapped = lambda x: isinstance(
x, (AbstractPose, AbstractConformationalVariable)
Expand Down Expand Up @@ -254,9 +250,9 @@ def compute_image_superposition(

LinearSuperpositionScatteringTheory.__init__.__doc__ = """**Arguments:**

- `structural_ensemble_batcher`: The batcher that computes the states that over which to
compute a superposition of images. Most commonly, this
would be an `AbstractAssembly` concrete class.
- `assembly`: An concrete class of an `AbstractAssembly`. This is used to
output a batch of states over which to
compute a superposition of images.
- `potential_integrator`: The method for integrating the specimen potential.
- `transfer_theory`: The contrast transfer theory.
- `solvent`: The model for the solvent.
Expand Down
3 changes: 0 additions & 3 deletions src/cryojax/simulator/_structural_ensemble/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,3 @@
DiscreteConformationalVariable as DiscreteConformationalVariable,
DiscreteStructuralEnsemble as DiscreteStructuralEnsemble,
)
from .ensemble_batcher import (
AbstractStructuralEnsembleBatcher as AbstractStructuralEnsembleBatcher,
)
13 changes: 0 additions & 13 deletions src/cryojax/simulator/_structural_ensemble/ensemble_batcher.py

This file was deleted.

4 changes: 2 additions & 2 deletions tests/test_helix.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def test_c6_rotation(
@eqx.filter_jit
def compute_rotated_image(pipeline, pose):
pipeline = eqx.tree_at(
lambda m: m.scattering_theory.structural_ensemble_batcher.pose,
lambda m: m.scattering_theory.assembly.pose,
pipeline,
pose,
)
Expand Down Expand Up @@ -154,7 +154,7 @@ def compute_rotated_image_with_helix(
pipeline: cs.ContrastImagingPipeline, pose: cs.AbstractPose
):
pipeline = eqx.tree_at(
lambda m: m.scattering_theory.structural_ensemble_batcher.pose,
lambda m: m.scattering_theory.assembly.pose,
pipeline,
pose,
)
Expand Down
Loading