diff --git a/src/cryojax/image/operators/__init__.py b/src/cryojax/image/operators/__init__.py index 48123cef..a157fec8 100644 --- a/src/cryojax/image/operators/__init__.py +++ b/src/cryojax/image/operators/__init__.py @@ -1,6 +1,7 @@ from ._filters import ( AbstractFilter as AbstractFilter, CustomFilter as CustomFilter, + FilterLike as FilterLike, HighpassFilter as HighpassFilter, InverseSincFilter as InverseSincFilter, LowpassFilter as LowpassFilter, @@ -19,6 +20,7 @@ CircularCosineMask as CircularCosineMask, CustomMask as CustomMask, Cylindrical2DCosineMask as Cylindrical2DCosineMask, + MaskLike as MaskLike, SphericalCosineMask as SphericalCosineMask, SquareCosineMask as SquareCosineMask, ) diff --git a/src/cryojax/image/operators/_filters.py b/src/cryojax/image/operators/_filters.py index 76d7d4b7..01a34c4c 100644 --- a/src/cryojax/image/operators/_filters.py +++ b/src/cryojax/image/operators/_filters.py @@ -38,6 +38,9 @@ def __call__( return image * jax.lax.stop_gradient(self.array) +FilterLike = AbstractFilter | AbstractImageMultiplier + + class CustomFilter(AbstractFilter, strict=True): """Pass a custom filter as an array.""" diff --git a/src/cryojax/image/operators/_masks.py b/src/cryojax/image/operators/_masks.py index 2ccd6252..0b65c798 100644 --- a/src/cryojax/image/operators/_masks.py +++ b/src/cryojax/image/operators/_masks.py @@ -30,6 +30,9 @@ def __call__( return image * jax.lax.stop_gradient(self.array) +MaskLike = AbstractMask | AbstractImageMultiplier + + class CustomMask(AbstractMask, strict=True): """Pass a custom mask as an array.""" diff --git a/src/cryojax/simulator/_imaging_pipeline.py b/src/cryojax/simulator/_imaging_pipeline.py index 098add37..07460583 100644 --- a/src/cryojax/simulator/_imaging_pipeline.py +++ b/src/cryojax/simulator/_imaging_pipeline.py @@ -11,7 +11,7 @@ from jaxtyping import Array, Complex, Float, PRNGKeyArray from ..image import irfftn, rfftn -from ..image.operators import AbstractFilter, AbstractMask +from ..image.operators import FilterLike, MaskLike from ._detector import AbstractDetector from ._instrument_config import InstrumentConfig from ._scattering_theory import AbstractScatteringTheory @@ -24,8 +24,8 @@ class AbstractImagingPipeline(Module, strict=True): """ instrument_config: AbstractVar[InstrumentConfig] - filter: AbstractVar[Optional[AbstractFilter]] - mask: AbstractVar[Optional[AbstractMask]] + filter: AbstractVar[Optional[FilterLike]] + mask: AbstractVar[Optional[MaskLike]] @abstractmethod def render( @@ -169,16 +169,16 @@ class ContrastImagingPipeline(AbstractImagingPipeline, strict=True): instrument_config: InstrumentConfig scattering_theory: AbstractScatteringTheory - filter: Optional[AbstractFilter] - mask: Optional[AbstractMask] + filter: Optional[FilterLike] + mask: Optional[MaskLike] def __init__( self, instrument_config: InstrumentConfig, scattering_theory: AbstractScatteringTheory, *, - filter: Optional[AbstractFilter] = None, - mask: Optional[AbstractMask] = None, + filter: Optional[FilterLike] = None, + mask: Optional[MaskLike] = None, ): self.instrument_config = instrument_config self.scattering_theory = scattering_theory @@ -236,16 +236,16 @@ class IntensityImagingPipeline(AbstractImagingPipeline, strict=True): instrument_config: InstrumentConfig scattering_theory: AbstractScatteringTheory - filter: Optional[AbstractFilter] - mask: Optional[AbstractMask] + filter: Optional[FilterLike] + mask: Optional[MaskLike] def __init__( self, instrument_config: InstrumentConfig, scattering_theory: AbstractScatteringTheory, *, - filter: Optional[AbstractFilter] = None, - mask: Optional[AbstractMask] = None, + filter: Optional[FilterLike] = None, + mask: Optional[MaskLike] = None, ): self.instrument_config = instrument_config self.scattering_theory = scattering_theory @@ -308,8 +308,8 @@ class ElectronCountingImagingPipeline(AbstractImagingPipeline, strict=True): scattering_theory: AbstractScatteringTheory detector: AbstractDetector - filter: Optional[AbstractFilter] - mask: Optional[AbstractMask] + filter: Optional[FilterLike] + mask: Optional[MaskLike] def __init__( self, @@ -317,8 +317,8 @@ def __init__( scattering_theory: AbstractScatteringTheory, detector: AbstractDetector, *, - filter: Optional[AbstractFilter] = None, - mask: Optional[AbstractMask] = None, + filter: Optional[FilterLike] = None, + mask: Optional[MaskLike] = None, ): self.instrument_config = instrument_config self.scattering_theory = scattering_theory