Skip to content

Commit

Permalink
Some cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Cadair committed Jun 20, 2023
1 parent 3efbc20 commit 3f28a2d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 34 deletions.
13 changes: 4 additions & 9 deletions gwcs/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,8 @@ def array_index_to_world_values(self, *index_arrays):
``i`` is the row and ``j`` is the column (i.e. the opposite order to
`~BaseLowLevelWCS.pixel_to_world_values`).
"""
index_arrays = self._add_units_input(index_arrays[::-1], self.forward_transform, self.input_frame)

result = self(*index_arrays, with_units=False)

return self._remove_quantity_output(result, self.output_frame)
pixel_arrays = index_arrays[::-1]
return self.pixel_to_world_values(*pixel_arrays)

def world_to_pixel_values(self, *world_arrays):
"""
Expand Down Expand Up @@ -150,12 +147,10 @@ def world_to_array_index_values(self, *world_arrays):
`~BaseLowLevelWCS.pixel_to_world_values`). The indices should be
returned as rounded integers.
"""
world_arrays = self._add_units_input(world_arrays, self.backward_transform, self.output_frame)
result = self.invert(*world_arrays, with_units=False)
result = self.world_to_pixel_values(*world_arrays)
if self.pixel_n_dim != 1:
result = result[::-1]

return self._remove_quantity_output(result, self.input_frame)
return result

@property
def array_shape(self):
Expand Down
38 changes: 13 additions & 25 deletions gwcs/wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from . import utils
from .wcstools import grid_from_bounding_box

# TODO: We now rely on a very new version of astropy all this can go
try:
from astropy.modeling.bounding_box import ModelBoundingBox as Bbox
from astropy.modeling.bounding_box import CompoundBoundingBox
Expand Down Expand Up @@ -209,23 +210,15 @@ def get_transform(self, from_frame, to_frame):
"""
if not self._pipeline:
return None
try:
from_ind = self._get_frame_index(from_frame)
except ValueError:
raise CoordinateFrameError("Frame {0} is not in the available "
"frames".format(from_frame))
try:
to_ind = self._get_frame_index(to_frame)
except ValueError:
raise CoordinateFrameError("Frame {0} is not in the available frames".format(to_frame))

from_ind = self._get_frame_index(from_frame)
to_ind = self._get_frame_index(to_frame)
if to_ind < from_ind:
#transforms = np.array(self._pipeline[to_ind: from_ind], dtype="object")[:, 1].tolist()
transforms = [step.transform for step in self._pipeline[to_ind: from_ind]]
transforms = [tr.inverse for tr in transforms[::-1]]
elif to_ind == from_ind:
return None
else:
#transforms = np.array(self._pipeline[from_ind: to_ind], dtype="object")[:, 1].copy()
transforms = [step.transform for step in self._pipeline[from_ind: to_ind]]
return functools.reduce(lambda x, y: x | y, transforms)

Expand Down Expand Up @@ -304,9 +297,11 @@ def _get_frame_index(self, frame):
"""
if isinstance(frame, cf.CoordinateFrame):
frame = frame.name
#frame_names = [getattr(item[0], "name", item[0]) for item in self._pipeline]
frame_names = [step.frame if isinstance(step.frame, str) else step.frame.name for step in self._pipeline]
return frame_names.index(frame)
try:
return frame_names.index(frame)
except ValueError as e:
raise CoordinateFrameError(f"Frame {frame} is not in the available frames") from e

def _get_frame_name(self, frame):
"""
Expand Down Expand Up @@ -520,7 +515,9 @@ def invert(self, *args, **kwargs):
else:
return result

def numerical_inverse(self, *args, **kwargs):
def numerical_inverse(self, *args, tolerance=1e-5, maxiter=50, adaptive=True,
detect_divergence=True, quiet=True, with_bounding_box=True,
fill_value=np.nan, with_units=False, **kwargs):
"""
Invert coordinates from output frame to input frame using numerical
inverse.
Expand Down Expand Up @@ -757,15 +754,6 @@ def numerical_inverse(self, *args, **kwargs):
[2.76552923e-05 1.14789013e-05]]
"""
tolerance = kwargs.get('tolerance', 1e-5)
maxiter = kwargs.get('maxiter', 50)
adaptive = kwargs.get('adaptive', True)
detect_divergence = kwargs.get('detect_divergence', True)
quiet = kwargs.get('quiet', True)
with_bounding_box = kwargs.get('with_bounding_box', True)
fill_value = kwargs.get('fill_value', np.nan)
with_units = kwargs.pop('with_units', False)

if not utils.isnumerical(args[0]):
args = self.output_frame.coordinate_to_quantity(*args)
if self.output_frame.naxes == 1:
Expand Down Expand Up @@ -1221,14 +1209,14 @@ def insert_frame(self, input_frame, transform, output_frame):
output_name, output_frame_obj = self._get_frame_name(output_frame)
try:
input_index = self._get_frame_index(input_frame)
except ValueError:
except CoordinateFrameError:
input_index = None
if input_frame_obj is None:
raise ValueError(f"New coordinate frame {input_name} must "
"be defined")
try:
output_index = self._get_frame_index(output_frame)
except ValueError:
except CoordinateFrameError:
output_index = None
if output_frame_obj is None:
raise ValueError(f"New coordinate frame {output_name} must "
Expand Down

0 comments on commit 3f28a2d

Please sign in to comment.