Skip to content

Commit

Permalink
overload
Browse files Browse the repository at this point in the history
  • Loading branch information
mjo22 committed Jan 17, 2025
1 parent f09df33 commit 1249e2c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 14 deletions.
13 changes: 0 additions & 13 deletions src/cryojax/simulator/_pose.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from abc import abstractmethod
from functools import cached_property
from typing import overload
from typing_extensions import override, Self

import equinox as eqx
Expand All @@ -26,18 +25,6 @@ class AbstractPose(Module, strict=True):
offset_x_in_angstroms: AbstractVar[Float[Array, ""]]
offset_y_in_angstroms: AbstractVar[Float[Array, ""]]

@overload
def rotate_coordinates(
self,
coordinate_grid_or_list: Float[Array, "z_dim y_dim x_dim 3"],
inverse: bool = False,
) -> Float[Array, "z_dim y_dim x_dim 3"]: ...

@overload
def rotate_coordinates(
self, coordinate_grid_or_list: Float[Array, "size 3"], inverse: bool = False
) -> Float[Array, "size 3"]: ...

def rotate_coordinates(
self,
coordinate_grid_or_list: (
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Literal, Optional, overload
from typing_extensions import override

import equinox as eqx
Expand Down Expand Up @@ -123,6 +123,34 @@ def __init__(
self.ctf = ctf
self.envelope = envelope

@overload
def propagate_object_to_detector_plane(
self,
object_spectrum_at_exit_plane: Complex[
Array,
"{instrument_config.padded_y_dim} {instrument_config.padded_x_dim}",
],
instrument_config: InstrumentConfig,
*,
is_projection_approximation: Literal[False],
) -> Complex[
Array, "{instrument_config.padded_y_dim} {instrument_config.padded_x_dim//2+1}"
]: ...

@overload
def propagate_object_to_detector_plane(
self,
object_spectrum_at_exit_plane: Complex[
Array,
"{instrument_config.padded_y_dim} {instrument_config.padded_x_dim//2+1}",
],
instrument_config: InstrumentConfig,
*,
is_projection_approximation: Literal[True],
) -> Complex[
Array, "{instrument_config.padded_y_dim} {instrument_config.padded_x_dim//2+1}"
]: ...

def propagate_object_to_detector_plane(
self,
object_spectrum_at_exit_plane: (
Expand Down

0 comments on commit 1249e2c

Please sign in to comment.