Skip to content

Commit

Permalink
First pass at restructuring the pixel <> world API
Browse files Browse the repository at this point in the history
The goal of this refactoring is to be able to remove `Frame.coordinates`
and `Frame.coordinate_to_quantity` and rely on the Astropy WCSAPI
machinery to do those conversions.
  • Loading branch information
Cadair committed Jun 20, 2023
1 parent 3f28a2d commit 217807e
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 198 deletions.
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ To convert a pixel (x, y) = (1, 2) to sky coordinates, call the WCS object as a
The :meth:`~gwcs.wcs.WCS.invert` method evaluates the :meth:`~gwcs.wcs.WCS.backward_transform`
if available, otherwise applies an iterative method to calculate the reverse coordinates.

>>> wcsobj.invert(sky)
>>> wcsobj.invert(sky, with_units=True)
(<Quantity 1. pix>, <Quantity 2. pix>)

.. _save_as_asdf:
Expand Down
90 changes: 10 additions & 80 deletions gwcs/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

from astropy.wcs.wcsapi import BaseHighLevelWCS, BaseLowLevelWCS
from astropy.wcs.wcsapi import BaseLowLevelWCS, HighLevelWCSMixin
from astropy.modeling import separable
import astropy.units as u

Expand All @@ -15,7 +15,7 @@
__all__ = ["GWCSAPIMixin"]


class GWCSAPIMixin(BaseHighLevelWCS, BaseLowLevelWCS):
class GWCSAPIMixin(BaseLowLevelWCS, HighLevelWCSMixin):
"""
A mix-in class that is intended to be inherited by the
:class:`~gwcs.wcs.WCS` class and provides the low- and high-level
Expand Down Expand Up @@ -78,19 +78,14 @@ def _remove_quantity_output(self, result, frame):
if self.output_frame.naxes == 1:
result = [result]

result = tuple(r.to_value(unit) for r, unit in zip(result, frame.unit))
result = tuple(r.to_value(unit) if isinstance(r, u.Quantity) else r
for r, unit in zip(result, frame.unit))

# If we only have one output axes, we shouldn't return a tuple.
if self.output_frame.naxes == 1 and isinstance(result, tuple):
return result[0]
return result

def _add_units_input(self, arrays, transform, frame):
if transform.uses_quantity:
return tuple(u.Quantity(array, unit) for array, unit in zip(arrays, frame.unit))

return arrays

def pixel_to_world_values(self, *pixel_arrays):
"""
Convert pixel coordinates to world coordinates.
Expand All @@ -104,8 +99,9 @@ def pixel_to_world_values(self, *pixel_arrays):
order, where for an image, ``x`` is the horizontal coordinate and ``y``
is the vertical coordinate.
"""
pixel_arrays = self._add_units_input(pixel_arrays, self.forward_transform, self.input_frame)
result = self(*pixel_arrays, with_units=False)
if self.forward_transform.uses_quantity:
pixel_arrays = self._add_units_input(pixel_arrays, self.input_frame)
result = self._call_forward(*pixel_arrays)

return self._remove_quantity_output(result, self.output_frame)

Expand All @@ -132,9 +128,10 @@ def world_to_pixel_values(self, *world_arrays):
be returned in the ``(x, y)`` order, where for an image, ``x`` is the
horizontal coordinate and ``y`` is the vertical coordinate.
"""
world_arrays = self._add_units_input(world_arrays, self.backward_transform, self.output_frame)
if self.backward_transform.uses_quantity:
world_arrays = self._add_units_input(world_arrays, self.output_frame)

result = self.invert(*world_arrays, with_units=False)
result = self._call_backward(*world_arrays)

return self._remove_quantity_output(result, self.input_frame)

Expand Down Expand Up @@ -265,73 +262,6 @@ def world_axis_object_classes(self):
def world_axis_object_components(self):
return self.output_frame._world_axis_object_components

# High level APE 14 API

@property
def low_level_wcs(self):
"""
Returns a reference to the underlying low-level WCS object.
"""
return self

def _sanitize_pixel_inputs(self, *pixel_arrays):
pixels = []
if self.forward_transform.uses_quantity:
for i, pixel in enumerate(pixel_arrays):
if not isinstance(pixel, u.Quantity):
pixel = u.Quantity(value=pixel, unit=self.input_frame.unit[i])
pixels.append(pixel)
else:
for i, pixel in enumerate(pixel_arrays):
if isinstance(pixel, u.Quantity):
if pixel.unit != self.input_frame.unit[i]:
raise ValueError('Quantity input does not match the '
'input_frame unit.')
pixel = pixel.value
pixels.append(pixel)

return pixels

def pixel_to_world(self, *pixel_arrays):
"""
Convert pixel values to world coordinates.
"""
pixels = self._sanitize_pixel_inputs(*pixel_arrays)
return self(*pixels, with_units=True)

def array_index_to_world(self, *index_arrays):
"""
Convert array indices to world coordinates (represented by Astropy
objects).
"""
pixel_arrays = index_arrays[::-1]
pixels = self._sanitize_pixel_inputs(*pixel_arrays)
return self(*pixels, with_units=True)

def world_to_pixel(self, *world_objects):
"""
Convert world coordinates to pixel values.
"""
result = self.invert(*world_objects, with_units=True)

if self.input_frame.naxes > 1:
first_res = result[0]
if not utils.isnumerical(first_res):
result = [i.value for i in result]
else:
if not utils.isnumerical(result):
result = result.value

return result

def world_to_array_index(self, *world_objects):
"""
Convert world coordinates (represented by Astropy objects) to array
indices.
"""
result = self.invert(*world_objects, with_units=True)[::-1]
return tuple([utils._toindex(r) for r in result])

@property
def pixel_axis_names(self):
"""
Expand Down
41 changes: 19 additions & 22 deletions gwcs/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def test_world_axis_units(wcs_ndim_types_units):
@pytest.mark.parametrize(("x", "y"), zip((x, xarr), (y, yarr)))
def test_pixel_to_world_values(gwcs_2d_spatial_shift, x, y):
wcsobj = gwcs_2d_spatial_shift
assert_allclose(wcsobj.pixel_to_world_values(x, y), wcsobj(x, y, with_units=False))
assert_allclose(wcsobj.pixel_to_world_values(x, y), wcsobj(x, y))


@pytest.mark.parametrize(("x", "y"), zip((x, xarr), (y, yarr)))
Expand All @@ -116,7 +116,7 @@ def test_pixel_to_world_values_units_2d(gwcs_2d_shift_scale_quantity, x, y):
call_pixel = x*u.pix, y*u.pix
api_pixel = x, y

call_world = wcsobj(*call_pixel, with_units=False)
call_world = wcsobj(*call_pixel)
api_world = wcsobj.pixel_to_world_values(*api_pixel)

# Check that call returns quantities and api dosen't
Expand All @@ -126,7 +126,7 @@ def test_pixel_to_world_values_units_2d(gwcs_2d_shift_scale_quantity, x, y):
# Check that they are the same (and implicitly in the same units)
assert_allclose(u.Quantity(call_world).value, api_world)

new_call_pixel = wcsobj.invert(*call_world, with_units=False)
new_call_pixel = wcsobj.invert(*call_world)
[assert_allclose(n, p) for n, p in zip(new_call_pixel, call_pixel)]

new_api_pixel = wcsobj.world_to_pixel_values(*api_world)
Expand All @@ -140,7 +140,7 @@ def test_pixel_to_world_values_units_1d(gwcs_1d_freq_quantity, x):
call_pixel = x * u.pix
api_pixel = x

call_world = wcsobj(call_pixel, with_units=False)
call_world = wcsobj(call_pixel)
api_world = wcsobj.pixel_to_world_values(api_pixel)

# Check that call returns quantities and api dosen't
Expand All @@ -150,7 +150,7 @@ def test_pixel_to_world_values_units_1d(gwcs_1d_freq_quantity, x):
# Check that they are the same (and implicitly in the same units)
assert_allclose(u.Quantity(call_world).value, api_world)

new_call_pixel = wcsobj.invert(call_world, with_units=False)
new_call_pixel = wcsobj.invert(call_world)
assert_allclose(new_call_pixel, call_pixel)

new_api_pixel = wcsobj.world_to_pixel_values(api_world)
Expand All @@ -160,7 +160,7 @@ def test_pixel_to_world_values_units_1d(gwcs_1d_freq_quantity, x):
@pytest.mark.parametrize(("x", "y"), zip((x, xarr), (y, yarr)))
def test_array_index_to_world_values(gwcs_2d_spatial_shift, x, y):
wcsobj = gwcs_2d_spatial_shift
assert_allclose(wcsobj.array_index_to_world_values(x, y), wcsobj(y, x, with_units=False))
assert_allclose(wcsobj.array_index_to_world_values(x, y), wcsobj(y, x))


def test_world_axis_object_components_2d(gwcs_2d_spatial_shift):
Expand Down Expand Up @@ -267,8 +267,9 @@ def test_high_level_wrapper(wcsobj, request):
if wcsobj.forward_transform.uses_quantity:
pixel_input *= u.pix

# The wrapper and the raw gwcs class can take different paths
wc1 = hlvl.pixel_to_world(*pixel_input)
wc2 = wcsobj(*pixel_input, with_units=True)
wc2 = wcsobj.pixel_to_world(*pixel_input)

assert type(wc1) is type(wc2)

Expand Down Expand Up @@ -362,24 +363,20 @@ def test_low_level_wcs(wcsobj):

@wcs_objs
def test_pixel_to_world(wcsobj):
comp = wcsobj(x, y, with_units=True)
comp = wcsobj.output_frame.coordinates(comp)
values = wcsobj(x, y)
result = wcsobj.pixel_to_world(x, y)
assert isinstance(comp, coord.SkyCoord)
assert isinstance(result, coord.SkyCoord)
assert_allclose(comp.data.lon, result.data.lon)
assert_allclose(comp.data.lat, result.data.lat)
assert_allclose(values[0] * u.deg, result.data.lon)
assert_allclose(values[1] * u.deg, result.data.lat)


@wcs_objs
def test_array_index_to_world(wcsobj):
comp = wcsobj(x, y, with_units=True)
comp = wcsobj.output_frame.coordinates(comp)
values = wcsobj(x, y)
result = wcsobj.array_index_to_world(y, x)
assert isinstance(comp, coord.SkyCoord)
assert isinstance(result, coord.SkyCoord)
assert_allclose(comp.data.lon, result.data.lon)
assert_allclose(comp.data.lat, result.data.lat)
assert_allclose(values[0] * u.deg, result.data.lon)
assert_allclose(values[1] * u.deg, result.data.lat)


def test_pixel_to_world_quantity(gwcs_2d_shift_scale, gwcs_2d_shift_scale_quantity):
Expand Down Expand Up @@ -460,28 +457,28 @@ def sky_ra_dec(request, gwcs_2d_spatial_shift):
def test_world_to_pixel(gwcs_2d_spatial_shift, sky_ra_dec):
wcsobj = gwcs_2d_spatial_shift
sky, ra, dec = sky_ra_dec
assert_allclose(wcsobj.world_to_pixel(sky), wcsobj.invert(ra, dec, with_units=False))
assert_allclose(wcsobj.world_to_pixel(sky), wcsobj.invert(ra, dec))


def test_world_to_array_index(gwcs_2d_spatial_shift, sky_ra_dec):
wcsobj = gwcs_2d_spatial_shift
sky, ra, dec = sky_ra_dec
assert_allclose(wcsobj.world_to_array_index(sky), wcsobj.invert(ra, dec, with_units=False)[::-1])
assert_allclose(wcsobj.world_to_array_index(sky), wcsobj.invert(ra, dec)[::-1])


def test_world_to_pixel_values(gwcs_2d_spatial_shift, sky_ra_dec):
wcsobj = gwcs_2d_spatial_shift
sky, ra, dec = sky_ra_dec

assert_allclose(wcsobj.world_to_pixel_values(sky), wcsobj.invert(ra, dec, with_units=False))
assert_allclose(wcsobj.world_to_pixel_values(ra, dec), wcsobj.invert(ra, dec))


def test_world_to_array_index_values(gwcs_2d_spatial_shift, sky_ra_dec):
wcsobj = gwcs_2d_spatial_shift
sky, ra, dec = sky_ra_dec

assert_allclose(wcsobj.world_to_array_index_values(sky),
wcsobj.invert(ra, dec, with_units=False)[::-1])
assert_allclose(wcsobj.world_to_array_index_values(ra, dec),
wcsobj.invert(ra, dec)[::-1])


def test_ndim_str_frames(gwcs_with_frames_strings):
Expand Down
20 changes: 12 additions & 8 deletions gwcs/tests/test_wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def test_backward_transform_has_inverse():
assert_allclose(w.backward_transform.inverse(1, 2), w(1, 2))


@pytest.mark.skip
def test_return_coordinates():
"""Test converting to coordinate objects or quantities."""
w = wcs.WCS(pipe[:])
Expand All @@ -203,7 +204,7 @@ def test_return_coordinates():
output_quant = w.output_frame.coordinate_to_quantity(num_plus_output)
assert_allclose(w(x, y), numerical_result)
assert_allclose(utils.get_values(w.unit, *output_quant), numerical_result)
assert_allclose(w.invert(num_plus_output), (x, y))
assert_allclose(w.invert(num_plus_output, with_units=True), (x, y))
assert isinstance(num_plus_output, coord.SkyCoord)

# Spectral frame
Expand Down Expand Up @@ -253,7 +254,7 @@ def test_from_fiducial_composite():
assert isinstance(w.cube_frame.frames[1].reference_frame, coord.FK5)
assert_allclose(w(1, 1, 1), (1.5, 96.52373368309931, -71.37420187296995))
# test returning coordinate objects with composite output_frame
res = w(1, 2, 2, with_units=True)
res = w.pixel_to_world(1, 2, 2)
assert_allclose(res[0], u.Quantity(1.5 * u.micron))
assert isinstance(res[1], coord.SkyCoord)
assert_allclose(res[1].ra.value, 99.329496642319)
Expand All @@ -265,7 +266,7 @@ def test_from_fiducial_composite():
assert_allclose(w(1, 1, 1), (11.5, 99.97738475762152, -72.29039139739766))
# test coordinate object output

coord_result = w(1, 1, 1, with_units=True)
coord_result = w.pixel_to_world(1, 1, 1)
assert_allclose(coord_result[0], u.Quantity(11.5 * u.micron))


Expand Down Expand Up @@ -299,13 +300,16 @@ def test_bounding_box():
with pytest.raises(ValueError):
w.bounding_box = ((1, 5), (2, 6))


def test_bounding_box_units():
# Test that bounding_box with quantities can be assigned and evaluates
bb = ((1 * u.pix, 5 * u.pix), (2 * u.pix, 6 * u.pix))
trans = models.Shift(10 * u .pix) & models.Shift(2 * u.pix)
pipeline = [('detector', trans), ('sky', None)]
w = wcs.WCS(pipeline)
w.bounding_box = bb
assert_allclose(w(-1*u.pix, -1*u.pix), (np.nan, np.nan))
world = w(-1*u.pix, -1*u.pix)
assert u.allclose(world, (np.nan*u.pix, np.nan*u.pix))


def test_compound_bounding_box():
Expand Down Expand Up @@ -627,11 +631,11 @@ def test_footprint(self):

def test_inverse(self):
sky_coord = self.wcs(10, 20, with_units=True)
assert np.allclose(self.wcs.invert(sky_coord), (10, 20))
assert np.allclose(self.wcs.invert(sky_coord, with_units=True), (10, 20))

def test_back_coordinates(self):
sky_coord = self.wcs(1, 2, with_units=True)
res = self.wcs.transform('sky', 'focal', sky_coord)
res = self.wcs.transform('sky', 'focal', sky_coord, with_units=True)
assert_allclose(res, self.wcs.get_transform('detector', 'focal')(1, 2))

def test_units(self):
Expand Down Expand Up @@ -750,7 +754,7 @@ def test_to_fits_sip_composite_frame(gwcs_cube_with_separable_spectral):
assert fw_hdr['NAXIS2'] == 64

fw = astwcs.WCS(fw_hdr)
gskyval = w(1, 60, 55, with_units=True)[0]
gskyval = w.pixel_to_world(1, 60, 55)[1]
fskyval = fw.all_pix2world(1, 60, 0)
fskyval = [float(fskyval[ra_axis - 1]), float(fskyval[dec_axis - 1])]
assert np.allclose([gskyval.ra.value, gskyval.dec.value], fskyval)
Expand All @@ -763,7 +767,7 @@ def test_to_fits_sip_composite_frame_galactic(gwcs_3d_galactic_spectral):
assert fw_hdr['CTYPE1'] == 'GLAT-TAN'

fw = astwcs.WCS(fw_hdr)
gskyval = w(7, 8, 9, with_units=True)[0]
gskyval = w.pixel_to_world(7, 8, 9)[0]
assert np.allclose([gskyval.b.value, gskyval.l.value],
fw.all_pix2world(7, 9, 0), atol=1e-3)

Expand Down
Loading

0 comments on commit 217807e

Please sign in to comment.