From bc8aa4c92f5259cbb02184d4849ae1a1593b852b Mon Sep 17 00:00:00 2001 From: Michael O'Brien Date: Wed, 4 Dec 2024 17:36:41 -0500 Subject: [PATCH] get rid of AbstractStructualEnsembleBatcher. this is too excessive --- src/cryojax/simulator/__init__.py | 1 - src/cryojax/simulator/_assembly/assembly.py | 8 +------ .../weak_phase_scattering_theory.py | 22 ++++++++----------- .../_structural_ensemble/__init__.py | 3 --- .../_structural_ensemble/ensemble_batcher.py | 13 ----------- tests/test_helix.py | 4 ++-- 6 files changed, 12 insertions(+), 39 deletions(-) delete mode 100644 src/cryojax/simulator/_structural_ensemble/ensemble_batcher.py diff --git a/src/cryojax/simulator/__init__.py b/src/cryojax/simulator/__init__.py index 6a5c5c8c..71920b38 100644 --- a/src/cryojax/simulator/__init__.py +++ b/src/cryojax/simulator/__init__.py @@ -59,7 +59,6 @@ from ._structural_ensemble import ( AbstractConformationalVariable as AbstractConformationalVariable, AbstractStructuralEnsemble as AbstractStructuralEnsemble, - AbstractStructuralEnsembleBatcher as AbstractStructuralEnsembleBatcher, DiscreteConformationalVariable as DiscreteConformationalVariable, DiscreteStructuralEnsemble as DiscreteStructuralEnsemble, SingleStructureEnsemble as SingleStructureEnsemble, diff --git a/src/cryojax/simulator/_assembly/assembly.py b/src/cryojax/simulator/_assembly/assembly.py index f57b98b0..c27ebacc 100644 --- a/src/cryojax/simulator/_assembly/assembly.py +++ b/src/cryojax/simulator/_assembly/assembly.py @@ -6,7 +6,6 @@ from abc import abstractmethod from functools import cached_property from typing import Optional -from typing_extensions import override import equinox as eqx import jax @@ -18,11 +17,10 @@ from .._structural_ensemble import ( AbstractConformationalVariable, AbstractStructuralEnsemble, - AbstractStructuralEnsembleBatcher, ) -class AbstractAssembly(AbstractStructuralEnsembleBatcher, strict=True): +class AbstractAssembly(eqx.Module, strict=True): """Abstraction of a biological assembly. To subclass an `AbstractAssembly`, @@ -103,7 +101,3 @@ def subunits(self) -> AbstractStructuralEnsemble: else: where = lambda s: s.pose return eqx.tree_at(where, self.subunit, self.poses) - - @override - def get_batched_structural_ensemble(self) -> AbstractStructuralEnsemble: - return self.subunits diff --git a/src/cryojax/simulator/_scattering_theory/weak_phase_scattering_theory.py b/src/cryojax/simulator/_scattering_theory/weak_phase_scattering_theory.py index fa74abd5..c1a292f7 100644 --- a/src/cryojax/simulator/_scattering_theory/weak_phase_scattering_theory.py +++ b/src/cryojax/simulator/_scattering_theory/weak_phase_scattering_theory.py @@ -7,6 +7,7 @@ import jax.numpy as jnp from jaxtyping import Array, Complex, PRNGKeyArray +from .._assembly import AbstractAssembly from .._instrument_config import InstrumentConfig from .._pose import AbstractPose from .._potential_integrator import AbstractPotentialIntegrator @@ -14,7 +15,6 @@ from .._structural_ensemble import ( AbstractConformationalVariable, AbstractStructuralEnsemble, - AbstractStructuralEnsembleBatcher, ) from .._transfer_theory import ContrastTransferTheory from .base_scattering_theory import AbstractScatteringTheory @@ -123,12 +123,12 @@ def compute_fourier_contrast_at_detector_plane( class LinearSuperpositionScatteringTheory(AbstractWeakPhaseScatteringTheory, strict=True): - """Compute the superposition of images of the structural ensemble batch returned by - the `AbstractStructuralEnsembleBatcher`. This must operate in the weak phase + """Compute the superposition of images over a batch of poses and potentials + parameterized by an `AbstractAssembly`. This must operate in the weak phase approximation. """ - structural_ensemble_batcher: AbstractStructuralEnsembleBatcher + assembly: AbstractAssembly potential_integrator: AbstractPotentialIntegrator transfer_theory: ContrastTransferTheory solvent: Optional[AbstractIce] = None @@ -163,9 +163,7 @@ def compute_image_superposition( ) # Get the batch - ensemble_batch = ( - self.structural_ensemble_batcher.get_batched_structural_ensemble() - ) + ensemble_batch = self.assembly.subunits # Setup vmap over the pose and conformation is_mapped = lambda x: isinstance( x, (AbstractPose, AbstractConformationalVariable) @@ -222,9 +220,7 @@ def compute_image_superposition( ) # Get the batch - ensemble_batch = ( - self.structural_ensemble_batcher.get_batched_structural_ensemble() - ) + ensemble_batch = self.assembly.subunits # Setup vmap over the pose and conformation is_mapped = lambda x: isinstance( x, (AbstractPose, AbstractConformationalVariable) @@ -254,9 +250,9 @@ def compute_image_superposition( LinearSuperpositionScatteringTheory.__init__.__doc__ = """**Arguments:** -- `structural_ensemble_batcher`: The batcher that computes the states that over which to - compute a superposition of images. Most commonly, this - would be an `AbstractAssembly` concrete class. +- `assembly`: An concrete class of an `AbstractAssembly`. This is used to + output a batch of states over which to + compute a superposition of images. - `potential_integrator`: The method for integrating the specimen potential. - `transfer_theory`: The contrast transfer theory. - `solvent`: The model for the solvent. diff --git a/src/cryojax/simulator/_structural_ensemble/__init__.py b/src/cryojax/simulator/_structural_ensemble/__init__.py index f6af3f15..7867c5cd 100644 --- a/src/cryojax/simulator/_structural_ensemble/__init__.py +++ b/src/cryojax/simulator/_structural_ensemble/__init__.py @@ -9,6 +9,3 @@ DiscreteConformationalVariable as DiscreteConformationalVariable, DiscreteStructuralEnsemble as DiscreteStructuralEnsemble, ) -from .ensemble_batcher import ( - AbstractStructuralEnsembleBatcher as AbstractStructuralEnsembleBatcher, -) diff --git a/src/cryojax/simulator/_structural_ensemble/ensemble_batcher.py b/src/cryojax/simulator/_structural_ensemble/ensemble_batcher.py deleted file mode 100644 index 5a59cadf..00000000 --- a/src/cryojax/simulator/_structural_ensemble/ensemble_batcher.py +++ /dev/null @@ -1,13 +0,0 @@ -from abc import abstractmethod - -import equinox as eqx - -from .base_ensemble import AbstractStructuralEnsemble - - -class AbstractStructuralEnsembleBatcher(eqx.Module, strict=True): - """A batching utility for structural ensembles.""" - - @abstractmethod - def get_batched_structural_ensemble(self) -> AbstractStructuralEnsemble: - raise NotImplementedError diff --git a/tests/test_helix.py b/tests/test_helix.py index 41f68089..86608d24 100644 --- a/tests/test_helix.py +++ b/tests/test_helix.py @@ -108,7 +108,7 @@ def test_c6_rotation( @eqx.filter_jit def compute_rotated_image(pipeline, pose): pipeline = eqx.tree_at( - lambda m: m.scattering_theory.structural_ensemble_batcher.pose, + lambda m: m.scattering_theory.assembly.pose, pipeline, pose, ) @@ -154,7 +154,7 @@ def compute_rotated_image_with_helix( pipeline: cs.ContrastImagingPipeline, pose: cs.AbstractPose ): pipeline = eqx.tree_at( - lambda m: m.scattering_theory.structural_ensemble_batcher.pose, + lambda m: m.scattering_theory.assembly.pose, pipeline, pose, )