Skip to content

Commit

Permalink
type hints for filters and masks have a bug
Browse files Browse the repository at this point in the history
  • Loading branch information
mjo22 committed Aug 21, 2024
1 parent 820d19b commit 35a666f
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 15 deletions.
2 changes: 2 additions & 0 deletions src/cryojax/image/operators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from ._filters import (
AbstractFilter as AbstractFilter,
CustomFilter as CustomFilter,
FilterLike as FilterLike,
HighpassFilter as HighpassFilter,
InverseSincFilter as InverseSincFilter,
LowpassFilter as LowpassFilter,
Expand All @@ -19,6 +20,7 @@
CircularCosineMask as CircularCosineMask,
CustomMask as CustomMask,
Cylindrical2DCosineMask as Cylindrical2DCosineMask,
MaskLike as MaskLike,
SphericalCosineMask as SphericalCosineMask,
SquareCosineMask as SquareCosineMask,
)
Expand Down
3 changes: 3 additions & 0 deletions src/cryojax/image/operators/_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ def __call__(
return image * jax.lax.stop_gradient(self.array)


FilterLike = AbstractFilter | AbstractImageMultiplier


class CustomFilter(AbstractFilter, strict=True):
"""Pass a custom filter as an array."""

Expand Down
3 changes: 3 additions & 0 deletions src/cryojax/image/operators/_masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ def __call__(
return image * jax.lax.stop_gradient(self.array)


MaskLike = AbstractMask | AbstractImageMultiplier


class CustomMask(AbstractMask, strict=True):
"""Pass a custom mask as an array."""

Expand Down
30 changes: 15 additions & 15 deletions src/cryojax/simulator/_imaging_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from jaxtyping import Array, Complex, Float, PRNGKeyArray

from ..image import irfftn, rfftn
from ..image.operators import AbstractFilter, AbstractMask
from ..image.operators import FilterLike, MaskLike
from ._detector import AbstractDetector
from ._instrument_config import InstrumentConfig
from ._scattering_theory import AbstractScatteringTheory
Expand All @@ -24,8 +24,8 @@ class AbstractImagingPipeline(Module, strict=True):
"""

instrument_config: AbstractVar[InstrumentConfig]
filter: AbstractVar[Optional[AbstractFilter]]
mask: AbstractVar[Optional[AbstractMask]]
filter: AbstractVar[Optional[FilterLike]]
mask: AbstractVar[Optional[MaskLike]]

@abstractmethod
def render(
Expand Down Expand Up @@ -169,16 +169,16 @@ class ContrastImagingPipeline(AbstractImagingPipeline, strict=True):
instrument_config: InstrumentConfig
scattering_theory: AbstractScatteringTheory

filter: Optional[AbstractFilter]
mask: Optional[AbstractMask]
filter: Optional[FilterLike]
mask: Optional[MaskLike]

def __init__(
self,
instrument_config: InstrumentConfig,
scattering_theory: AbstractScatteringTheory,
*,
filter: Optional[AbstractFilter] = None,
mask: Optional[AbstractMask] = None,
filter: Optional[FilterLike] = None,
mask: Optional[MaskLike] = None,
):
self.instrument_config = instrument_config
self.scattering_theory = scattering_theory
Expand Down Expand Up @@ -236,16 +236,16 @@ class IntensityImagingPipeline(AbstractImagingPipeline, strict=True):
instrument_config: InstrumentConfig
scattering_theory: AbstractScatteringTheory

filter: Optional[AbstractFilter]
mask: Optional[AbstractMask]
filter: Optional[FilterLike]
mask: Optional[MaskLike]

def __init__(
self,
instrument_config: InstrumentConfig,
scattering_theory: AbstractScatteringTheory,
*,
filter: Optional[AbstractFilter] = None,
mask: Optional[AbstractMask] = None,
filter: Optional[FilterLike] = None,
mask: Optional[MaskLike] = None,
):
self.instrument_config = instrument_config
self.scattering_theory = scattering_theory
Expand Down Expand Up @@ -308,17 +308,17 @@ class ElectronCountingImagingPipeline(AbstractImagingPipeline, strict=True):
scattering_theory: AbstractScatteringTheory
detector: AbstractDetector

filter: Optional[AbstractFilter]
mask: Optional[AbstractMask]
filter: Optional[FilterLike]
mask: Optional[MaskLike]

def __init__(
self,
instrument_config: InstrumentConfig,
scattering_theory: AbstractScatteringTheory,
detector: AbstractDetector,
*,
filter: Optional[AbstractFilter] = None,
mask: Optional[AbstractMask] = None,
filter: Optional[FilterLike] = None,
mask: Optional[MaskLike] = None,
):
self.instrument_config = instrument_config
self.scattering_theory = scattering_theory
Expand Down

0 comments on commit 35a666f

Please sign in to comment.