From 362dc95d254299231b315d67ed9aafa8549c6f92 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Thu, 15 Jun 2023 14:29:24 +0100 Subject: [PATCH 01/49] Define and document the core CoordinateFrame API This creates a `BaseCoordinateFrame` class, definining the minimal API for a coordinate frame with descriptive docstrings. This is done mainly as an exercise to easily review and document the API. Also add significant docstring to the module describing how and why coordinate frames work. --- gwcs/coordinate_frames.py | 240 +++++++++++++++++++++++++++++++++++++- 1 file changed, 236 insertions(+), 4 deletions(-) diff --git a/gwcs/coordinate_frames.py b/gwcs/coordinate_frames.py index 5c36c253..2c360259 100644 --- a/gwcs/coordinate_frames.py +++ b/gwcs/coordinate_frames.py @@ -1,7 +1,107 @@ # Licensed under a 3-clause BSD style license - see LICENSE.rst """ -Defines coordinate frames and ties them to data axes. +This module defines coordinate frames for describing the inputs and/or outputs of a transform. + +In the following example, we have a two stage transform, with an input frame, an +output frame and an intermediate frame. + +.. code-block:: + + ┌───────────────┐ + │ │ + │ Input │ + │ Frame │ + │ │ + └───────┬───────┘ + │ + ┌─────▼─────┐ + │ Transform │ + └─────┬─────┘ + │ + ┌───────▼───────┐ + │ │ + │ Intermediate │ + │ Frame │ + │ │ + └───────┬───────┘ + │ + ┌─────▼─────┐ + │ Transform │ + └─────┬─────┘ + │ + ┌───────▼───────┐ + │ │ + │ Output │ + │ Frame │ + │ │ + └───────────────┘ + + +Each frame instance is both metadata for the inputs/outputs of a transform and +also a converter between those inputs/outputs and richer coordinate +representations of those inputs/ouputs. + +For example, an output frame of type `~astropy.coordinates.SpectralCoord` +provides metadata to the `.WCS` object such as the ``axes_type`` being +``"SPECTRAL"`` and the unit of the output etc. The output frame also provides a +converter of the numeric output of the transform to a +`~astropy.coordinates.SpectralCoord` object, by combining this metadata with the +numerical values. + +``axes_order`` and conversion between objects and arguments +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +One of the key concepts regarding coordinate frames is the ``axes_order`` argument. +This argument is used to map from the components of the frame to the inputs/outputs of the transform. +To illustrate this consider this situation where you have a forward transform +which outputs three coordinates ``[lat, lambda, lon]``. These would be +represented as a `.SpectralFrame` and a `.CelestialFrame`, however, the axes of +a `.CelestialFrame` are always ``[lon, lat]``, so by specifying two frames as + +.. code-block:: python + + [SpectralCoord(axes_order=(1,)), CelestialCoord(axes_order=(2, 0))] + +we would map the outputs of this transform into the correct positions in the +frames. As shown below, this is also used when constructing the inputs to the inverse transform. + +.. code-block:: + + lat, lambda, lon + │ │ │ + └──────┼─────┼────────┐ + ┌───────────┘ └──┐ │ + │ │ │ + ┌─────────▼────────┐ ┌──────▼─────▼─────┐ + │ │ │ │ + │ SpectralFrame │ │ CelestialFrame │ + │ │ │ │ + │ (1,) │ │ (2, 0) │ + │ │ │ │ + └─────────┬────────┘ └──────────┬────┬──┘ + │ │ │ + │ │ │ + ▼ ▼ ▼ + SpectralCoord(lambda) SkyCoord((lon, lat)) + │ │ │ + └─────┐ ┌────────────┘ │ + │ │ ┌────────────┘ + ▼ ▼ ▼ + [lambda, lon, lat] + │ │ │ + │ │ │ + ┌──────▼─────▼────▼────┐ + │ │ + │ Sort by axes_order │ + │ │ + └────┬──────┬─────┬────┘ + │ │ │ + ▼ ▼ ▼ + lat, lambda, lon + """ + +import abc from collections import defaultdict import logging import numpy as np @@ -10,14 +110,15 @@ from astropy import time from astropy import units as u from astropy import utils as astutil +from astropy.utils.decorators import deprecated from astropy import coordinates as coord from astropy.wcs.wcsapi.low_level_api import (validate_physical_types, VALID_UCDS) from astropy.wcs.wcsapi.fitswcs import CTYPE_TO_UCD1 from astropy.coordinates import StokesCoord -__all__ = ['Frame2D', 'CelestialFrame', 'SpectralFrame', 'CompositeFrame', - 'CoordinateFrame', 'TemporalFrame', 'StokesFrame'] +__all__ = ['BaseCoordinateFrame', 'Frame2D', 'CelestialFrame', 'SpectralFrame', 'CompositeFrame', + 'CoordinateFrame', 'TemporalFrame', 'StokesFrame', 'PixelFrame'] def _ucd1_to_ctype_name_mapping(ctype_to_ucd, allowed_ucd_duplicates): @@ -80,7 +181,137 @@ def get_ctype_from_ucd(ucd): return UCD1_TO_CTYPE.get(ucd, "") -class CoordinateFrame: + +class BaseCoordinateFrame(abc.ABC): + """ + API Definition for a Coordinate frame + """ + @property + @abc.abstractmethod + def naxes(self) -> int: + """ + The number of axes described by this frame. + """ + + @property + @abc.abstractmethod + def name(self) -> str: + """ + The name of the coordinate frame. + """ + + # TODO: Why is this not `units`? + @property + @abc.abstractmethod + def unit(self) -> tuple[u.Unit, ...]: + """ + The units of the axes in this frame. + """ + + @property + @abc.abstractmethod + def axes_names(self) -> tuple[str, ...]: + """ + Names describing the axes of the frame. + """ + + @property + @abc.abstractmethod + def axes_order(self) -> tuple[int, ...]: + """ + The position of the axes in the frame in the transform. + """ + + @property + @abc.abstractmethod + def reference_frame(self): + """ + The reference frame of the coordinates described by this frame. + + This is usually an Astropy object such as ``SkyCoord`` or ``Time``. + """ + + @property + @abc.abstractmethod + def axes_type(self): + """ + An upcase string describing the type of the axis. + + Known values are ``"SPATIAL", "TEMPORAL", "STOKES", "SPECTRAL", "PIXEL"``. + """ + + @property + @abc.abstractmethod + def axis_physical_types(self): + """ + The UCD 1+ physical types for the axes, in frame order. + """ + + @property + @abc.abstractmethod + def _world_axis_object_classes(self): + """ + The APE 14 object classes for this frame. + + See Also + -------- + `astropy.wcs.wcsapi.BaseLowLevelWCS.world_axis_object_classes` + """ + + @property + @abc.abstractmethod + def _world_axis_object_components(self): + """ + The APE 14 object components for this frame. + + See Also + -------- + `astropy.wcs.wcsapi.BaseLowLevelWCS.world_axis_object_components` + """ + + # TODO: What order are args in here? + # Are they in transform order so we should use `axes_order` to reorder them? + @abc.abstractmethod + def coordinates(self, *args) -> object: + """ + Construct a rich coordinate object from the output of a transform. + + The object returned by this method must match the structures returned by + ``_world_axis_object_classes`` and ``_world_axis_object_components``. + + Parameters + ---------- + args : `~numbers.Number` or .Quantity` + The numberical objects returned by a transform (or user input). + + Returns + ``````` + coord : `object` + A single "rich" coordinate object such as `~astropy.coordinates.SkyCoord`. + """ + + # TODO: What are inputs, shouldn't this only be one single rich coordinate + # object? It seems the current implementation also handles not-that? i.e. + # CelestialCoord also accepts two quantity objects? + # TODO: Again what are the order of the return values here? + @abc.abstractmethod + def coordinate_to_quantity(self, *coords) -> tuple[u.Quantity, ...]: + """ + Construct `~astropy.units.Quantity` objects from the rich coordinate objects. + + Parameters + ---------- + coord : `object` + The rich coordinate object to convert. + + Returns + ------- + args: iterable of `~astropy.units.Quantity` objects. + The numerical values to pass to the transform. + """ + + +class CoordinateFrame(BaseCoordinateFrame): """ Base class for Coordinate Frames. @@ -225,6 +456,7 @@ def reference_frame(self): """ Reference frame, used to convert to world coordinate objects. """ return self._reference_frame + # TODO: This seems to be spectral specific, should it be moved to SpectralFrame? @property def reference_position(self): """ Reference Position. """ From bff84699b4a754da8cf257d411064a980ffed674 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Mon, 19 Jun 2023 16:41:18 +0100 Subject: [PATCH 02/49] First pass at restructuring the pixel <> world API The goal of this refactoring is to be able to remove `Frame.coordinates` and `Frame.coordinate_to_quantity` and rely on the Astropy WCSAPI machinery to do those conversions. --- docs/index.rst | 5 -- gwcs/api.py | 90 +++-------------------- gwcs/tests/test_api.py | 41 +++++------ gwcs/tests/test_wcs.py | 20 +++-- gwcs/wcs.py | 162 ++++++++++++++++++++--------------------- gwcs/wcstools.py | 2 +- 6 files changed, 121 insertions(+), 199 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index 481f2f76..58b94db7 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -240,11 +240,6 @@ To convert a pixel (x, y) = (1, 2) to sky coordinates, call the WCS object as a The :meth:`~gwcs.wcs.WCS.invert` method evaluates the :meth:`~gwcs.wcs.WCS.backward_transform` if available, otherwise applies an iterative method to calculate the reverse coordinates. -.. doctest-skip:: - - >>> wcsobj.invert(*sky) - (0.9999999996185807, 1.999999999186798) - GWCS supports the common WCS interface which defines several methods to work with high level Astropy objects: diff --git a/gwcs/api.py b/gwcs/api.py index 2b1ca22e..24363dc9 100644 --- a/gwcs/api.py +++ b/gwcs/api.py @@ -5,7 +5,7 @@ """ -from astropy.wcs.wcsapi import BaseHighLevelWCS, BaseLowLevelWCS +from astropy.wcs.wcsapi import BaseLowLevelWCS, HighLevelWCSMixin from astropy.modeling import separable import astropy.units as u @@ -15,7 +15,7 @@ __all__ = ["GWCSAPIMixin"] -class GWCSAPIMixin(BaseHighLevelWCS, BaseLowLevelWCS): +class GWCSAPIMixin(BaseLowLevelWCS, HighLevelWCSMixin): """ A mix-in class that is intended to be inherited by the :class:`~gwcs.wcs.WCS` class and provides the low- and high-level @@ -78,19 +78,14 @@ def _remove_quantity_output(self, result, frame): if self.output_frame.naxes == 1: result = [result] - result = tuple(r.to_value(unit) for r, unit in zip(result, frame.unit)) + result = tuple(r.to_value(unit) if isinstance(r, u.Quantity) else r + for r, unit in zip(result, frame.unit)) # If we only have one output axes, we shouldn't return a tuple. if self.output_frame.naxes == 1 and isinstance(result, tuple): return result[0] return result - def _add_units_input(self, arrays, transform, frame): - if transform.uses_quantity: - return tuple(u.Quantity(array, unit) for array, unit in zip(arrays, frame.unit)) - - return arrays - def pixel_to_world_values(self, *pixel_arrays): """ Convert pixel coordinates to world coordinates. @@ -104,8 +99,9 @@ def pixel_to_world_values(self, *pixel_arrays): order, where for an image, ``x`` is the horizontal coordinate and ``y`` is the vertical coordinate. """ - pixel_arrays = self._add_units_input(pixel_arrays, self.forward_transform, self.input_frame) - result = self(*pixel_arrays, with_units=False) + if self.forward_transform.uses_quantity: + pixel_arrays = self._add_units_input(pixel_arrays, self.input_frame) + result = self._call_forward(*pixel_arrays) return self._remove_quantity_output(result, self.output_frame) @@ -132,9 +128,10 @@ def world_to_pixel_values(self, *world_arrays): be returned in the ``(x, y)`` order, where for an image, ``x`` is the horizontal coordinate and ``y`` is the vertical coordinate. """ - world_arrays = self._add_units_input(world_arrays, self.backward_transform, self.output_frame) + if self.backward_transform.uses_quantity: + world_arrays = self._add_units_input(world_arrays, self.output_frame) - result = self.invert(*world_arrays, with_units=False) + result = self._call_backward(*world_arrays) return self._remove_quantity_output(result, self.input_frame) @@ -269,73 +266,6 @@ def world_axis_object_classes(self): def world_axis_object_components(self): return self.output_frame._world_axis_object_components - # High level APE 14 API - - @property - def low_level_wcs(self): - """ - Returns a reference to the underlying low-level WCS object. - """ - return self - - def _sanitize_pixel_inputs(self, *pixel_arrays): - pixels = [] - if self.forward_transform.uses_quantity: - for i, pixel in enumerate(pixel_arrays): - if not isinstance(pixel, u.Quantity): - pixel = u.Quantity(value=pixel, unit=self.input_frame.unit[i]) - pixels.append(pixel) - else: - for i, pixel in enumerate(pixel_arrays): - if isinstance(pixel, u.Quantity): - if pixel.unit != self.input_frame.unit[i]: - raise ValueError('Quantity input does not match the ' - 'input_frame unit.') - pixel = pixel.value - pixels.append(pixel) - - return pixels - - def pixel_to_world(self, *pixel_arrays): - """ - Convert pixel values to world coordinates. - """ - pixels = self._sanitize_pixel_inputs(*pixel_arrays) - return self(*pixels, with_units=True) - - def array_index_to_world(self, *index_arrays): - """ - Convert array indices to world coordinates (represented by Astropy - objects). - """ - pixel_arrays = index_arrays[::-1] - pixels = self._sanitize_pixel_inputs(*pixel_arrays) - return self(*pixels, with_units=True) - - def world_to_pixel(self, *world_objects): - """ - Convert world coordinates to pixel values. - """ - result = self.invert(*world_objects, with_units=True) - - if self.input_frame.naxes > 1: - first_res = result[0] - if not utils.isnumerical(first_res): - result = [i.value for i in result] - else: - if not utils.isnumerical(result): - result = result.value - - return result - - def world_to_array_index(self, *world_objects): - """ - Convert world coordinates (represented by Astropy objects) to array - indices. - """ - result = self.invert(*world_objects, with_units=True)[::-1] - return tuple([utils._toindex(r) for r in result]) - @property def pixel_axis_names(self): """ diff --git a/gwcs/tests/test_api.py b/gwcs/tests/test_api.py index f80c8d8a..aa40a4fe 100644 --- a/gwcs/tests/test_api.py +++ b/gwcs/tests/test_api.py @@ -106,7 +106,7 @@ def test_world_axis_units(wcs_ndim_types_units): @pytest.mark.parametrize(("x", "y"), zip((x, xarr), (y, yarr))) def test_pixel_to_world_values(gwcs_2d_spatial_shift, x, y): wcsobj = gwcs_2d_spatial_shift - assert_allclose(wcsobj.pixel_to_world_values(x, y), wcsobj(x, y, with_units=False)) + assert_allclose(wcsobj.pixel_to_world_values(x, y), wcsobj(x, y)) @pytest.mark.parametrize(("x", "y"), zip((x, xarr), (y, yarr))) @@ -116,7 +116,7 @@ def test_pixel_to_world_values_units_2d(gwcs_2d_shift_scale_quantity, x, y): call_pixel = x*u.pix, y*u.pix api_pixel = x, y - call_world = wcsobj(*call_pixel, with_units=False) + call_world = wcsobj(*call_pixel) api_world = wcsobj.pixel_to_world_values(*api_pixel) # Check that call returns quantities and api dosen't @@ -126,7 +126,7 @@ def test_pixel_to_world_values_units_2d(gwcs_2d_shift_scale_quantity, x, y): # Check that they are the same (and implicitly in the same units) assert_allclose(u.Quantity(call_world).value, api_world) - new_call_pixel = wcsobj.invert(*call_world, with_units=False) + new_call_pixel = wcsobj.invert(*call_world) [assert_allclose(n, p) for n, p in zip(new_call_pixel, call_pixel)] new_api_pixel = wcsobj.world_to_pixel_values(*api_world) @@ -140,7 +140,7 @@ def test_pixel_to_world_values_units_1d(gwcs_1d_freq_quantity, x): call_pixel = x * u.pix api_pixel = x - call_world = wcsobj(call_pixel, with_units=False) + call_world = wcsobj(call_pixel) api_world = wcsobj.pixel_to_world_values(api_pixel) # Check that call returns quantities and api dosen't @@ -150,7 +150,7 @@ def test_pixel_to_world_values_units_1d(gwcs_1d_freq_quantity, x): # Check that they are the same (and implicitly in the same units) assert_allclose(u.Quantity(call_world).value, api_world) - new_call_pixel = wcsobj.invert(call_world, with_units=False) + new_call_pixel = wcsobj.invert(call_world) assert_allclose(new_call_pixel, call_pixel) new_api_pixel = wcsobj.world_to_pixel_values(api_world) @@ -160,7 +160,7 @@ def test_pixel_to_world_values_units_1d(gwcs_1d_freq_quantity, x): @pytest.mark.parametrize(("x", "y"), zip((x, xarr), (y, yarr))) def test_array_index_to_world_values(gwcs_2d_spatial_shift, x, y): wcsobj = gwcs_2d_spatial_shift - assert_allclose(wcsobj.array_index_to_world_values(x, y), wcsobj(y, x, with_units=False)) + assert_allclose(wcsobj.array_index_to_world_values(x, y), wcsobj(y, x)) def test_world_axis_object_components_2d(gwcs_2d_spatial_shift): @@ -273,8 +273,9 @@ def test_high_level_wrapper(wcsobj, request): if wcsobj.forward_transform.uses_quantity: pixel_input *= u.pix + # The wrapper and the raw gwcs class can take different paths wc1 = hlvl.pixel_to_world(*pixel_input) - wc2 = wcsobj(*pixel_input, with_units=True) + wc2 = wcsobj.pixel_to_world(*pixel_input) assert type(wc1) is type(wc2) @@ -368,24 +369,20 @@ def test_low_level_wcs(wcsobj): @wcs_objs def test_pixel_to_world(wcsobj): - comp = wcsobj(x, y, with_units=True) - comp = wcsobj.output_frame.coordinates(comp) + values = wcsobj(x, y) result = wcsobj.pixel_to_world(x, y) - assert isinstance(comp, coord.SkyCoord) assert isinstance(result, coord.SkyCoord) - assert_allclose(comp.data.lon, result.data.lon) - assert_allclose(comp.data.lat, result.data.lat) + assert_allclose(values[0] * u.deg, result.data.lon) + assert_allclose(values[1] * u.deg, result.data.lat) @wcs_objs def test_array_index_to_world(wcsobj): - comp = wcsobj(x, y, with_units=True) - comp = wcsobj.output_frame.coordinates(comp) + values = wcsobj(x, y) result = wcsobj.array_index_to_world(y, x) - assert isinstance(comp, coord.SkyCoord) assert isinstance(result, coord.SkyCoord) - assert_allclose(comp.data.lon, result.data.lon) - assert_allclose(comp.data.lat, result.data.lat) + assert_allclose(values[0] * u.deg, result.data.lon) + assert_allclose(values[1] * u.deg, result.data.lat) def test_pixel_to_world_quantity(gwcs_2d_shift_scale, gwcs_2d_shift_scale_quantity): @@ -466,28 +463,28 @@ def sky_ra_dec(request, gwcs_2d_spatial_shift): def test_world_to_pixel(gwcs_2d_spatial_shift, sky_ra_dec): wcsobj = gwcs_2d_spatial_shift sky, ra, dec = sky_ra_dec - assert_allclose(wcsobj.world_to_pixel(sky), wcsobj.invert(ra, dec, with_units=False)) + assert_allclose(wcsobj.world_to_pixel(sky), wcsobj.invert(ra, dec)) def test_world_to_array_index(gwcs_2d_spatial_shift, sky_ra_dec): wcsobj = gwcs_2d_spatial_shift sky, ra, dec = sky_ra_dec - assert_allclose(wcsobj.world_to_array_index(sky), wcsobj.invert(ra, dec, with_units=False)[::-1]) + assert_allclose(wcsobj.world_to_array_index(sky), wcsobj.invert(ra, dec)[::-1]) def test_world_to_pixel_values(gwcs_2d_spatial_shift, sky_ra_dec): wcsobj = gwcs_2d_spatial_shift sky, ra, dec = sky_ra_dec - assert_allclose(wcsobj.world_to_pixel_values(sky), wcsobj.invert(ra, dec, with_units=False)) + assert_allclose(wcsobj.world_to_pixel_values(ra, dec), wcsobj.invert(ra, dec)) def test_world_to_array_index_values(gwcs_2d_spatial_shift, sky_ra_dec): wcsobj = gwcs_2d_spatial_shift sky, ra, dec = sky_ra_dec - assert_allclose(wcsobj.world_to_array_index_values(sky), - wcsobj.invert(ra, dec, with_units=False)[::-1]) + assert_allclose(wcsobj.world_to_array_index_values(ra, dec), + wcsobj.invert(ra, dec)[::-1]) def test_ndim_str_frames(gwcs_with_frames_strings): diff --git a/gwcs/tests/test_wcs.py b/gwcs/tests/test_wcs.py index 5c9af093..0220bb46 100644 --- a/gwcs/tests/test_wcs.py +++ b/gwcs/tests/test_wcs.py @@ -208,6 +208,7 @@ def test_backward_transform_has_inverse(): assert_allclose(w.backward_transform.inverse(1, 2), w(1, 2)) +@pytest.mark.skip def test_return_coordinates(): """Test converting to coordinate objects or quantities.""" w = wcs.WCS(pipe[:]) @@ -219,7 +220,7 @@ def test_return_coordinates(): output_quant = w.output_frame.coordinate_to_quantity(num_plus_output) assert_allclose(w(x, y), numerical_result) assert_allclose(utils.get_values(w.unit, *output_quant), numerical_result) - assert_allclose(w.invert(num_plus_output), (x, y)) + assert_allclose(w.invert(num_plus_output, with_units=True), (x, y)) assert isinstance(num_plus_output, coord.SkyCoord) # Spectral frame @@ -269,7 +270,7 @@ def test_from_fiducial_composite(): assert isinstance(w.cube_frame.frames[1].reference_frame, coord.FK5) assert_allclose(w(1, 1, 1), (1.5, 96.52373368309931, -71.37420187296995)) # test returning coordinate objects with composite output_frame - res = w(1, 2, 2, with_units=True) + res = w.pixel_to_world(1, 2, 2) assert_allclose(res[0], u.Quantity(1.5 * u.micron)) assert isinstance(res[1], coord.SkyCoord) assert_allclose(res[1].ra.value, 99.329496642319) @@ -281,7 +282,7 @@ def test_from_fiducial_composite(): assert_allclose(w(1, 1, 1), (11.5, 99.97738475762152, -72.29039139739766)) # test coordinate object output - coord_result = w(1, 1, 1, with_units=True) + coord_result = w.pixel_to_world(1, 1, 1) assert_allclose(coord_result[0], u.Quantity(11.5 * u.micron)) @@ -312,13 +313,16 @@ def test_bounding_box(): with pytest.raises(ValueError): w.bounding_box = ((1, 5), (2, 6)) + +def test_bounding_box_units(): # Test that bounding_box with quantities can be assigned and evaluates bb = ((1 * u.pix, 5 * u.pix), (2 * u.pix, 6 * u.pix)) trans = models.Shift(10 * u .pix) & models.Shift(2 * u.pix) pipeline = [('detector', trans), ('sky', None)] w = wcs.WCS(pipeline) w.bounding_box = bb - assert_allclose(w(-1*u.pix, -1*u.pix), (np.nan, np.nan)) + world = w(-1*u.pix, -1*u.pix) + assert_allclose(world, (np.nan, np.nan)) def test_compound_bounding_box(): @@ -694,11 +698,11 @@ def test_footprint(self): def test_inverse(self): sky_coord = self.wcs(10, 20, with_units=True) - assert np.allclose(self.wcs.invert(sky_coord), (10, 20)) + assert np.allclose(self.wcs.invert(sky_coord, with_units=True), (10, 20)) def test_back_coordinates(self): sky_coord = self.wcs(1, 2, with_units=True) - res = self.wcs.transform('sky', 'focal', sky_coord) + res = self.wcs.transform('sky', 'focal', sky_coord, with_units=True) assert_allclose(res, self.wcs.get_transform('detector', 'focal')(1, 2)) def test_units(self): @@ -818,7 +822,7 @@ def test_to_fits_sip_composite_frame(gwcs_cube_with_separable_spectral): assert fw_hdr['NAXIS2'] == 64 fw = astwcs.WCS(fw_hdr) - gskyval = w(1, 60, 55, with_units=True)[0] + gskyval = w.pixel_to_world(1, 60, 55)[1] fskyval = fw.all_pix2world(1, 60, 0) fskyval = [float(fskyval[ra_axis - 1]), float(fskyval[dec_axis - 1])] assert np.allclose([gskyval.ra.value, gskyval.dec.value], fskyval) @@ -831,7 +835,7 @@ def test_to_fits_sip_composite_frame_galactic(gwcs_3d_galactic_spectral): assert fw_hdr['CTYPE1'] == 'GLAT-TAN' fw = astwcs.WCS(fw_hdr) - gskyval = w(7, 8, 9, with_units=True)[0] + gskyval = w.pixel_to_world(7, 8, 9)[0] assert np.allclose([gskyval.b.value, gskyval.l.value], fw.all_pix2world(7, 9, 0), atol=1e-3) diff --git a/gwcs/wcs.py b/gwcs/wcs.py index 2ae714b2..c2cf9d49 100644 --- a/gwcs/wcs.py +++ b/gwcs/wcs.py @@ -15,6 +15,7 @@ Sky2Pix_TAN) from astropy.wcs.utils import celestial_frame_to_wcs, proj_plane_pixel_scales from scipy import linalg, optimize +from astropy.wcs.wcsapi.high_level_api import high_level_objects_to_values, values_to_high_level_objects from . import coordinate_frames as cf from . import utils @@ -134,7 +135,6 @@ class WCS(GWCSAPIMixin): def __init__(self, forward_transform=None, input_frame='detector', output_frame=None, name=""): - #self.low_level_wcs = self self._approx_inverse = None self._available_frames = [] self._pipeline = [] @@ -329,6 +329,19 @@ def _get_frame_name(self, frame): frame_obj = frame return name, frame_obj + def _add_units_input(self, arrays, frame): + if frame is not None: + return tuple(u.Quantity(array, unit) for array, unit in zip(arrays, frame.unit)) + + return arrays + + def _remove_units_input(self, arrays, frame): + if frame is not None: + return tuple(array.to_value(unit) if isinstance(array, u.Quantity) else array + for array, unit in zip(arrays, frame.unit)) + + return arrays + def __call__(self, *args, **kwargs): """ Executes the forward transform. @@ -336,11 +349,6 @@ def __call__(self, *args, **kwargs): args : float or array-like Inputs in the input coordinate system, separate inputs for each dimension. - with_units : bool - If ``True`` returns a `~astropy.coordinates.SkyCoord` or - `~astropy.coordinates.SpectralCoord` object, by using the units of - the output cooridnate frame. - Optional, default=False. with_bounding_box : bool, optional If True(default) values in the result which correspond to any of the inputs being outside the bounding_box are set @@ -348,26 +356,45 @@ def __call__(self, *args, **kwargs): fill_value : float, optional Output value for inputs outside the bounding_box (default is np.nan). + with_units : bool, optional + If ``True`` then high level Astropy objects will be returned. + Optional, default=False. """ - transform = self.forward_transform - if transform is None: - raise NotImplementedError("WCS.forward_transform is not implemented.") - with_units = kwargs.pop("with_units", False) - if 'with_bounding_box' not in kwargs: - kwargs['with_bounding_box'] = True - if 'fill_value' not in kwargs: - kwargs['fill_value'] = np.nan - result = transform(*args, **kwargs) + results = self._call_forward(*args, **kwargs) if with_units: - if self.output_frame.naxes == 1: - result = self.output_frame.coordinates(result) - else: - result = self.output_frame.coordinates(*result) + high_level = values_to_high_level_objects(*results, low_level_wcs=self) + if len(high_level) == 1: + high_level = high_level[0] + return high_level + return results + + def _call_forward(self, *args, from_frame=None, to_frame=None, + with_bounding_box=True, fill_value=np.nan, **kwargs): + """ + Executes the forward transform, but values only. + """ + if from_frame is None and to_frame is None: + transform = self.forward_transform + else: + transform = self.get_transform(from_frame, to_frame) - return result + if transform is None: + raise NotImplementedError("WCS.forward_transform is not implemented.") + + # Validate that the input type matches what the transform expects + input_is_quantity = any((isinstance(a, u.Quantity) for a in args)) + if not input_is_quantity and transform.uses_quantity: + args = self._add_units_input(args, self.input_frame) + if not transform.uses_quantity and input_is_quantity: + args = self._remove_units_input(args, self.input_frame) + + return transform(*args, + with_bounding_box=with_bounding_box, + fill_value=fill_value, + **kwargs) def in_image(self, *args, **kwargs): """ @@ -451,9 +478,8 @@ def invert(self, *args, **kwargs): Output value for inputs outside the bounding_box (default is ``np.nan``). with_units : bool, optional - If ``True`` returns a `~astropy.coordinates.SkyCoord` or - `~astropy.coordinates.SpectralCoord` object, by using the units of - the output cooridnate frame. Default is `False`. + If ``True`` then high level Astropy objects will be accepted. + Optional, default=False. Other Parameters ---------------- @@ -466,40 +492,35 @@ def invert(self, *args, **kwargs): result : tuple or value Returns a tuple of scalar or array values for each axis. Unless ``input_frame.naxes == 1`` when it shall return the value. + The return type will be `~astropy.unit.Quantity` objects if the + transform returns ``Quantity`` objects, else values. """ with_units = kwargs.pop('with_units', False) + if with_units: + args = high_level_objects_to_values(*args, low_level_wcs=self) - if not utils.isnumerical(args[0]): - args = self.output_frame.coordinate_to_quantity(*args) - if self.output_frame.naxes == 1: - args = [args] - try: - if not self.backward_transform.uses_quantity: - args = utils.get_values(self.output_frame.unit, *args) - except (NotImplementedError, KeyError): - args = utils.get_values(self.output_frame.unit, *args) - - if 'with_bounding_box' not in kwargs: - kwargs['with_bounding_box'] = True - - if 'fill_value' not in kwargs: - kwargs['fill_value'] = np.nan + return self._call_backward(*args, **kwargs) + def _call_backward(self, *args, with_bounding_box=True, fill_value=np.nan, **kwargs): try: + transform = self.backward_transform + # Validate that the input type matches what the transform expects + input_is_quantity = any((isinstance(a, u.Quantity) for a in args)) + if not input_is_quantity and transform.uses_quantity: + args = self._add_units_input(args, self.output_frame) + if not transform.uses_quantity and input_is_quantity: + args = self._remove_units_input(args, self.output_frame) + # remove iterative inverse-specific keyword arguments: akwargs = {k: v for k, v in kwargs.items() if k not in _ITER_INV_KWARGS} - result = self.backward_transform(*args, **akwargs) + result = transform(*args, with_bounding_box=with_bounding_box, fill_value=fill_value, **akwargs) except (NotImplementedError, KeyError): - result = self.numerical_inverse(*args, **kwargs, with_units=with_units) + # Always strip units for numerical inverse + args = self._remove_units_input(args, self.output_frame) + result = self.numerical_inverse(*args, with_bounding_box=with_bounding_box, fill_value=fill_value, **kwargs) - if with_units and self.input_frame: - if self.input_frame.naxes == 1: - return self.input_frame.coordinates(result) - else: - return self.input_frame.coordinates(*result) - else: - return result + return result def numerical_inverse(self, *args, tolerance=1e-5, maxiter=50, adaptive=True, detect_divergence=True, quiet=True, with_bounding_box=True, @@ -739,12 +760,6 @@ def numerical_inverse(self, *args, tolerance=1e-5, maxiter=50, adaptive=True, [2.76552923e-05 1.14789013e-05]] """ - if not utils.isnumerical(args[0]): - args = self.output_frame.coordinate_to_quantity(*args) - if self.output_frame.naxes == 1: - args = [args] - args = utils.get_values(self.output_frame.unit, *args) - args_shape = np.shape(args) nargs = args_shape[0] arg_dim = len(args_shape) - 1 @@ -813,13 +828,7 @@ def numerical_inverse(self, *args, tolerance=1e-5, maxiter=50, adaptive=True, result = tuple(np.reshape(result, args_shape)) - if with_units and self.input_frame: - if self.input_frame.naxes == 1: - return self.input_frame.coordinates(result) - else: - return self.input_frame.coordinates(*result) - else: - return result + return result def _vectorized_fixed_point(self, pix0, world, tolerance, maxiter, adaptive, detect_divergence, quiet, @@ -1103,33 +1112,20 @@ def transform(self, from_frame, to_frame, *args, **kwargs): fill_value : float, optional Output value for inputs outside the bounding_box (default is np.nan). """ - transform = self.get_transform(from_frame, to_frame) - if not utils.isnumerical(args[0]): - inp_frame = getattr(self, from_frame) - args = inp_frame.coordinate_to_quantity(*args) - if not transform.uses_quantity: - args = utils.get_values(inp_frame.unit, *args) + # Determine if the transform is actually an inverse + from_ind = self._get_frame_index(from_frame) + to_ind = self._get_frame_index(to_frame) + backward = to_ind < from_ind with_units = kwargs.pop("with_units", False) - if 'with_bounding_box' not in kwargs: - kwargs['with_bounding_box'] = True - if 'fill_value' not in kwargs: - kwargs['fill_value'] = np.nan + if with_units and backward: + args = high_level_objects_to_values(*args, low_level_wcs=self) - result = transform(*args, **kwargs) + results = self._call_forward(*args, from_frame=from_frame, to_frame=to_frame, **kwargs) - if with_units: - to_frame_name, to_frame_obj = self._get_frame_name(to_frame) - if to_frame_obj is not None: - if to_frame_obj.naxes == 1: - result = to_frame_obj.coordinates(result) - else: - result = to_frame_obj.coordinates(*result) - else: - raise TypeError("Coordinate objects could not be created because" - "frame {0} is not defined.".format(to_frame_name)) - - return result + if with_units and not backward: + return values_to_high_level_objects(*results, low_level_wcs=self) + return results @property def available_frames(self): diff --git a/gwcs/wcstools.py b/gwcs/wcstools.py index 179f7773..10716679 100644 --- a/gwcs/wcstools.py +++ b/gwcs/wcstools.py @@ -327,7 +327,7 @@ def wcs_from_points(xy, world_coords, proj_point='center', "Only one of {} is supported.".format(polynomial_type, supported_poly_types.keys())) - skyrot = models.RotateCelestial2Native(crval[0].deg, crval[1].deg, 180) + skyrot = models.RotateCelestial2Native(crval[0].to_value(u.deg), crval[1].to_value(u.deg), 180) trans = (skyrot | projection) projection_x, projection_y = trans(lon, lat) poly = supported_poly_types[polynomial_type](poly_degree) From e5b17461301ce77e061e67aee459fb516e7df1a6 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Tue, 20 Jun 2023 14:30:30 +0100 Subject: [PATCH 03/49] Remove now unused methods coordinates() and coordinate_to_quantity() are replaced by APE 14 methods --- gwcs/coordinate_frames.py | 208 -------------------------------------- 1 file changed, 208 deletions(-) diff --git a/gwcs/coordinate_frames.py b/gwcs/coordinate_frames.py index 2c360259..14d09df4 100644 --- a/gwcs/coordinate_frames.py +++ b/gwcs/coordinate_frames.py @@ -269,47 +269,6 @@ def _world_axis_object_components(self): `astropy.wcs.wcsapi.BaseLowLevelWCS.world_axis_object_components` """ - # TODO: What order are args in here? - # Are they in transform order so we should use `axes_order` to reorder them? - @abc.abstractmethod - def coordinates(self, *args) -> object: - """ - Construct a rich coordinate object from the output of a transform. - - The object returned by this method must match the structures returned by - ``_world_axis_object_classes`` and ``_world_axis_object_components``. - - Parameters - ---------- - args : `~numbers.Number` or .Quantity` - The numberical objects returned by a transform (or user input). - - Returns - ``````` - coord : `object` - A single "rich" coordinate object such as `~astropy.coordinates.SkyCoord`. - """ - - # TODO: What are inputs, shouldn't this only be one single rich coordinate - # object? It seems the current implementation also handles not-that? i.e. - # CelestialCoord also accepts two quantity objects? - # TODO: Again what are the order of the return values here? - @abc.abstractmethod - def coordinate_to_quantity(self, *coords) -> tuple[u.Quantity, ...]: - """ - Construct `~astropy.units.Quantity` objects from the rich coordinate objects. - - Parameters - ---------- - coord : `object` - The rich coordinate object to convert. - - Returns - ------- - args: iterable of `~astropy.units.Quantity` objects. - The numerical values to pass to the transform. - """ - class CoordinateFrame(BaseCoordinateFrame): """ @@ -467,24 +426,6 @@ def axes_type(self): """ Type of this frame : 'SPATIAL', 'SPECTRAL', 'TIME'. """ return self._axes_type - def coordinates(self, *args): - """ Create world coordinates object""" - coo = tuple([arg * un if not hasattr(arg, "to") else arg.to(un) for arg, un in zip(args, self.unit)]) - if self.naxes == 1: - return coo[0] - return coo - - def coordinate_to_quantity(self, *coords): - """ - Given a rich coordinate object return an astropy quantity object. - """ - # NoOp leaves it to the model to handle - # If coords is a 1-tuple of quantity then return the element of the tuple - # This aligns the behavior with the other implementations - if not hasattr(coords, 'unit') and len(coords) == 1: - return coords[0] - return coords - @property def _default_axis_physical_types(self): """ @@ -590,47 +531,6 @@ def _world_axis_object_components(self): return [('celestial', 0, lambda sc: sc.spherical.lon.to_value(self.unit[0])), ('celestial', 1, lambda sc: sc.spherical.lat.to_value(self.unit[1]))] - def coordinates(self, *args): - """ - Create a SkyCoord object. - - Parameters - ---------- - args : float - inputs to wcs.input_frame - """ - if isinstance(args[0], coord.SkyCoord): - return args[0].transform_to(self.reference_frame) - return coord.SkyCoord(*args, unit=self.unit, frame=self.reference_frame) - - def coordinate_to_quantity(self, *coords): - """ Convert a ``SkyCoord`` object to quantities.""" - if len(coords) == 2: - arg = coords - elif len(coords) == 1: - arg = coords[0] - else: - raise ValueError("Unexpected number of coordinates in " - "input to frame {} : " - "expected 2, got {}".format(self.name, len(coords))) - - if isinstance(arg, coord.SkyCoord): - arg = arg.transform_to(self._reference_frame) - try: - lon = arg.data.lon - lat = arg.data.lat - except AttributeError: - lon = arg.spherical.lon - lat = arg.spherical.lat - - return lon, lat - - elif all(isinstance(a, u.Quantity) for a in arg): - return tuple(arg) - - else: - raise ValueError("Could not convert input {} to lon and lat quantities.".format(arg)) - class SpectralFrame(CoordinateFrame): """ @@ -691,21 +591,6 @@ def _world_axis_object_classes(self): def _world_axis_object_components(self): return [('spectral', 0, lambda sc: sc.to_value(self.unit[0]))] - def coordinates(self, *args): - # using SpectralCoord - if isinstance(args[0], coord.SpectralCoord): - return args[0].to(self.unit[0]) - else: - if hasattr(args[0], 'unit'): - return coord.SpectralCoord(*args).to(self.unit[0]) - else: - return coord.SpectralCoord(*args, self.unit[0]) - - def coordinate_to_quantity(self, *coords): - if hasattr(coords[0], 'unit'): - return coords[0] - return coords[0] * self.unit[0] - class TemporalFrame(CoordinateFrame): """ @@ -767,14 +652,6 @@ def offset_from_time_and_reference(time): return (time - self.reference_frame).sec return [('temporal', 0, offset_from_time_and_reference)] - def coordinates(self, *args): - if np.isscalar(args): - dt = args - else: - dt = args[0] - - return self._convert_to_time(dt, unit=self.unit[0], **self._attrs) - def _convert_to_time(self, dt, *, unit, **kwargs): if (not isinstance(dt, time.TimeDelta) and isinstance(dt, time.Time) or @@ -786,23 +663,6 @@ def _convert_to_time(self, dt, *, unit, **kwargs): return self.reference_frame + dt - def coordinate_to_quantity(self, *coords): - if isinstance(coords[0], time.Time): - ref_value = self.reference_frame.value - if not isinstance(ref_value, np.ndarray): - return (coords[0] - self.reference_frame).to(self.unit[0]) - else: - # If we can't convert to a quantity just drop the object out - # and hope the transform can cope. - return coords[0] - # Is already a quantity - elif hasattr(coords[0], 'unit'): - return coords[0] - if isinstance(coords[0], np.ndarray): - return coords[0] * self.unit[0] - else: - raise ValueError("Can not convert {} to Quantity".format(coords[0])) - class CompositeFrame(CoordinateFrame): """ @@ -852,40 +712,6 @@ def frames(self): def __repr__(self): return repr(self.frames) - def coordinates(self, *args): - coo = [] - if len(args) == len(self.frames): - for frame, arg in zip(self.frames, args): - coo.append(frame.coordinates(arg)) - else: - for frame in self.frames: - fargs = [args[i] for i in frame.axes_order] - coo.append(frame.coordinates(*fargs)) - return coo - - def coordinate_to_quantity(self, *coords): - if len(coords) == len(self.frames): - args = coords - elif len(coords) == self.naxes: - args = [] - for _frame in self.frames: - if _frame.naxes > 1: - # Collect the arguments for this frame based on axes_order - args.append([coords[i] for i in _frame.axes_order]) - else: - args.append(coords[_frame.axes_order[0]]) - else: - raise ValueError("Incorrect number of arguments") - - qs = [] - for _frame, arg in zip(self.frames, args): - ret = _frame.coordinate_to_quantity(arg) - if isinstance(ret, tuple): - qs += list(ret) - else: - qs.append(ret) - return qs - @property def _wao_classes_rename_map(self): mapper = defaultdict(dict) @@ -975,19 +801,6 @@ def _world_axis_object_classes(self): def _world_axis_object_components(self): return [('stokes', 0, 'value')] - def coordinates(self, *args): - if isinstance(args[0], u.Quantity): - arg = args[0].value - else: - arg = args[0] - - return StokesCoord(arg) - - def coordinate_to_quantity(self, *coords): - if isinstance(coords[0], StokesCoord): - return coords[0].value << u.one - return coords[0] - class Frame2D(CoordinateFrame): """ @@ -1020,24 +833,3 @@ def _default_axis_physical_types(self): else: ph_type = self.axes_type return tuple("custom:{}".format(t) for t in ph_type) - - def coordinates(self, *args): - args = [args[i] for i in self.axes_order] - coo = tuple([arg * un for arg, un in zip(args, self.unit)]) - return coo - - def coordinate_to_quantity(self, *coords): - # list or tuple - if len(coords) == 1 and astutil.isiterable(coords[0]): - coords = list(coords[0]) - elif len(coords) == 2: - coords = list(coords) - else: - raise ValueError("Unexpected number of coordinates in " - "input to frame {} : " - "expected 2, got {}".format(self.name, len(coords))) - - for i in range(2): - if not hasattr(coords[i], 'unit'): - coords[i] = coords[i] * self.unit[i] - return tuple(coords) From 3a946d012d9cdaaedbf094718ccfad52834b2804 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Tue, 20 Jun 2023 15:24:42 +0100 Subject: [PATCH 04/49] Rewrite coordinate systems tests for APE 14 This highlights some API changes here. --- gwcs/api.py | 4 +- gwcs/coordinate_frames.py | 60 ++++----- gwcs/tests/test_api.py | 4 +- gwcs/tests/test_coordinate_systems.py | 167 +++++++++++++++----------- gwcs/tests/test_wcs.py | 29 +++++ 5 files changed, 158 insertions(+), 106 deletions(-) diff --git a/gwcs/api.py b/gwcs/api.py index 24363dc9..08253298 100644 --- a/gwcs/api.py +++ b/gwcs/api.py @@ -260,11 +260,11 @@ def serialized_classes(self): @property def world_axis_object_classes(self): - return self.output_frame._world_axis_object_classes + return self.output_frame.world_axis_object_classes @property def world_axis_object_components(self): - return self.output_frame._world_axis_object_components + return self.output_frame.world_axis_object_components @property def pixel_axis_names(self): diff --git a/gwcs/coordinate_frames.py b/gwcs/coordinate_frames.py index 14d09df4..a8a1b05b 100644 --- a/gwcs/coordinate_frames.py +++ b/gwcs/coordinate_frames.py @@ -249,7 +249,7 @@ def axis_physical_types(self): @property @abc.abstractmethod - def _world_axis_object_classes(self): + def world_axis_object_classes(self): """ The APE 14 object classes for this frame. @@ -260,7 +260,7 @@ def _world_axis_object_classes(self): @property @abc.abstractmethod - def _world_axis_object_components(self): + def world_axis_object_components(self): """ The APE 14 object components for this frame. @@ -444,14 +444,14 @@ def axis_physical_types(self): return self._axis_physical_types or self._default_axis_physical_types @property - def _world_axis_object_classes(self): + def world_axis_object_classes(self): return {f"{at}{i}" if i != 0 else at: (u.Quantity, (), {'unit': unit}) for i, (at, unit) in enumerate(zip(self._axes_type, self.unit))} @property - def _world_axis_object_components(self): + def world_axis_object_components(self): return [(f"{at}{i}" if i != 0 else at, 0, 'value') for i, at in enumerate(self._axes_type)] @@ -519,7 +519,7 @@ def _default_axis_physical_types(self): return tuple("custom:{}".format(t) for t in self.axes_names) @property - def _world_axis_object_classes(self): + def world_axis_object_classes(self): return {'celestial': ( coord.SkyCoord, (), @@ -581,7 +581,7 @@ def _default_axis_physical_types(self): return ("custom:{}".format(self.unit[0].physical_type),) @property - def _world_axis_object_classes(self): + def world_axis_object_classes(self): return {'spectral': ( coord.SpectralCoord, (), @@ -633,8 +633,19 @@ def __init__(self, reference_frame, unit=None, axes_order=(0,), def _default_axis_physical_types(self): return ("time",) + def _convert_to_time(self, dt, *, unit, **kwargs): + if (not isinstance(dt, time.TimeDelta) and + isinstance(dt, time.Time) or + isinstance(self.reference_frame.value, np.ndarray)): + return time.Time(dt, **kwargs) + + if not hasattr(dt, 'unit'): + dt = dt * unit + + return self.reference_frame + dt + @property - def _world_axis_object_classes(self): + def world_axis_object_classes(self): comp = ( time.Time, (), @@ -644,7 +655,7 @@ def _world_axis_object_classes(self): return {'temporal': comp} @property - def _world_axis_object_components(self): + def world_axis_object_components(self): if isinstance(self.reference_frame.value, np.ndarray): return [('temporal', 0, 'value')] @@ -652,17 +663,6 @@ def offset_from_time_and_reference(time): return (time - self.reference_frame).sec return [('temporal', 0, offset_from_time_and_reference)] - def _convert_to_time(self, dt, *, unit, **kwargs): - if (not isinstance(dt, time.TimeDelta) and - isinstance(dt, time.Time) or - isinstance(self.reference_frame.value, np.ndarray)): - return time.Time(dt, **kwargs) - - if not hasattr(dt, 'unit'): - dt = dt * unit - - return self.reference_frame + dt - class CompositeFrame(CoordinateFrame): """ @@ -699,10 +699,10 @@ def __init__(self, frames, name=None): "axes_order should contain unique numbers, " "got {}.".format(axes_order)) - super(CompositeFrame, self).__init__(naxes, axes_type=axes_type, - axes_order=axes_order, - unit=unit, axes_names=axes_names, - name=name) + super().__init__(naxes, axes_type=axes_type, + axes_order=axes_order, + unit=unit, axes_names=axes_names, + name=name) self._axis_physical_types = tuple(ph_type) @property @@ -719,7 +719,7 @@ def _wao_classes_rename_map(self): for frame in self.frames: # ensure the frame is in the mapper mapper[frame] - for key in frame._world_axis_object_classes.keys(): + for key in frame.world_axis_object_classes.keys(): if key in seen_names: new_key = f"{key}{seen_names.count(key)}" mapper[frame][key] = new_key @@ -731,7 +731,7 @@ def _wao_renamed_components_iter(self): mapper = self._wao_classes_rename_map for frame in self.frames: renamed_components = [] - for comp in frame._world_axis_object_components: + for comp in frame.world_axis_object_components: comp = list(comp) rename = mapper[frame].get(comp[0]) if rename: @@ -743,14 +743,14 @@ def _wao_renamed_components_iter(self): def _wao_renamed_classes_iter(self): mapper = self._wao_classes_rename_map for frame in self.frames: - for key, value in frame._world_axis_object_classes.items(): + for key, value in frame.world_axis_object_classes.items(): rename = mapper[frame].get(key) if rename: key = rename yield key, value @property - def _world_axis_object_components(self): + def world_axis_object_components(self): """ We need to generate the components respecting the axes_order. """ @@ -764,7 +764,7 @@ def _world_axis_object_components(self): return out @property - def _world_axis_object_classes(self): + def world_axis_object_classes(self): return dict(self._wao_renamed_classes_iter) @@ -790,7 +790,7 @@ def _default_axis_physical_types(self): return ("phys.polarization.stokes",) @property - def _world_axis_object_classes(self): + def world_axis_object_classes(self): return {'stokes': ( StokesCoord, (), @@ -798,7 +798,7 @@ def _world_axis_object_classes(self): )} @property - def _world_axis_object_components(self): + def world_axis_object_components(self): return [('stokes', 0, 'value')] diff --git a/gwcs/tests/test_api.py b/gwcs/tests/test_api.py index aa40a4fe..21b573ed 100644 --- a/gwcs/tests/test_api.py +++ b/gwcs/tests/test_api.py @@ -497,12 +497,12 @@ def test_composite_many_base_frame(): q_frame_2 = cf.CoordinateFrame(name='distance', axes_order=(1,), naxes=1, axes_type="SPATIAL", unit=(u.m,)) frame = cf.CompositeFrame([q_frame_1, q_frame_2]) - wao_classes = frame._world_axis_object_classes + wao_classes = frame.world_axis_object_classes assert len(wao_classes) == 2 assert not set(wao_classes.keys()).difference({"SPATIAL", "SPATIAL1"}) - wao_components = frame._world_axis_object_components + wao_components = frame.world_axis_object_components assert len(wao_components) == 2 assert not {c[0] for c in wao_components}.difference({"SPATIAL", "SPATIAL1"}) diff --git a/gwcs/tests/test_coordinate_systems.py b/gwcs/tests/test_coordinate_systems.py index 967657f8..88035b10 100644 --- a/gwcs/tests/test_coordinate_systems.py +++ b/gwcs/tests/test_coordinate_systems.py @@ -10,11 +10,12 @@ from astropy.tests.helper import assert_quantity_allclose from astropy.modeling import models as m from astropy.wcs.wcsapi.fitswcs import CTYPE_TO_UCD1 -from astropy.coordinates import StokesCoord +from astropy.coordinates import StokesCoord, SpectralCoord from .. import WCS from .. import coordinate_frames as cf +from astropy.wcs.wcsapi.high_level_api import values_to_high_level_objects, high_level_objects_to_values import astropy astropy_version = astropy.__version__ @@ -33,7 +34,7 @@ focal = cf.Frame2D(name='focal', axes_order=(0, 1), unit=(u.m, u.m)) spec1 = cf.SpectralFrame(name='freq', unit=[u.Hz, ], axes_order=(2, )) -spec2 = cf.SpectralFrame(name='wave', unit=[u.m, ], axes_order=(2, ), axes_names=('lambda', )) +spec2 = cf.SpectralFrame(name='wave', unit=[u.m, ], axes_order=(2, ), axes_names=('lambda',)) spec3 = cf.SpectralFrame(name='energy', unit=[u.J, ], axes_order=(2, )) spec4 = cf.SpectralFrame(name='pixel', unit=[u.pix, ], axes_order=(2, )) spec5 = cf.SpectralFrame(name='speed', unit=[u.m/u.s, ], axes_order=(2, )) @@ -55,6 +56,19 @@ inputs3 = [(xscalar, yscalar, xscalar), (xarr, yarr, xarr)] +@pytest.fixture(autouse=True, scope="module") +def serialized_classes(): + """ + In the rest of this test file we are passing the CoordinateFrame object to + astropy helper functions as if they were a low level WCS object. + + This little patch means that this works. + """ + cf.CoordinateFrame.serialized_classes = False + yield + del cf.CoordinateFrame.serialized_classes + + def test_units(): assert(comp1.unit == (u.deg, u.deg, u.Hz)) assert(comp2.unit == (u.m, u.m, u.m)) @@ -64,19 +78,34 @@ def test_units(): assert(comp.unit == (u.deg, u.deg, u.Hz, u.m)) +# These two functions fake the old methods on CoordinateFrame to reduce the +# amount of refactoring that needed doing in these tests. +def coordinates(*inputs, frame): + results = values_to_high_level_objects(*inputs, low_level_wcs=frame) + if isinstance(results, list) and len(results) == 1: + return results[0] + return results + + +def coordinate_to_quantity(*inputs, frame): + results = high_level_objects_to_values(*inputs, low_level_wcs=frame) + results = [r< Date: Mon, 3 Jul 2023 15:12:14 +0100 Subject: [PATCH 05/49] cleanup --- gwcs/coordinate_frames.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gwcs/coordinate_frames.py b/gwcs/coordinate_frames.py index a8a1b05b..3abb2208 100644 --- a/gwcs/coordinate_frames.py +++ b/gwcs/coordinate_frames.py @@ -118,7 +118,7 @@ from astropy.coordinates import StokesCoord __all__ = ['BaseCoordinateFrame', 'Frame2D', 'CelestialFrame', 'SpectralFrame', 'CompositeFrame', - 'CoordinateFrame', 'TemporalFrame', 'StokesFrame', 'PixelFrame'] + 'CoordinateFrame', 'TemporalFrame', 'StokesFrame'] def _ucd1_to_ctype_name_mapping(ctype_to_ucd, allowed_ucd_duplicates): From 76fed910c3a23b3955d04e139f3894d9e656c16a Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Thu, 12 Oct 2023 14:31:46 -0400 Subject: [PATCH 06/49] Revert changes to inverse w/r/t with_units --- gwcs/tests/test_wcs.py | 4 ++-- gwcs/wcs.py | 14 +++++++++++--- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/gwcs/tests/test_wcs.py b/gwcs/tests/test_wcs.py index 5fea9f96..b1fdd24e 100644 --- a/gwcs/tests/test_wcs.py +++ b/gwcs/tests/test_wcs.py @@ -220,7 +220,7 @@ def test_return_coordinates(): output_quant = w.output_frame.coordinate_to_quantity(num_plus_output) assert_allclose(w(x, y), numerical_result) assert_allclose(utils.get_values(w.unit, *output_quant), numerical_result) - assert_allclose(w.invert(num_plus_output, with_units=True), (x, y)) + assert_allclose(w.invert(num_plus_output), (x, y)) assert isinstance(num_plus_output, coord.SkyCoord) # Spectral frame @@ -698,7 +698,7 @@ def test_footprint(self): def test_inverse(self): sky_coord = self.wcs(10, 20, with_units=True) - assert np.allclose(self.wcs.invert(sky_coord, with_units=True), (10, 20)) + assert np.allclose(self.wcs.invert(sky_coord), (10, 20)) def test_back_coordinates(self): sky_coord = self.wcs(1, 2, with_units=True) diff --git a/gwcs/wcs.py b/gwcs/wcs.py index c2cf9d49..8f73e248 100644 --- a/gwcs/wcs.py +++ b/gwcs/wcs.py @@ -478,7 +478,7 @@ def invert(self, *args, **kwargs): Output value for inputs outside the bounding_box (default is ``np.nan``). with_units : bool, optional - If ``True`` then high level Astropy objects will be accepted. + If ``True`` then high level astropy object (i.e. ``Quantity``) will be returned. Optional, default=False. Other Parameters @@ -496,11 +496,19 @@ def invert(self, *args, **kwargs): transform returns ``Quantity`` objects, else values. """ + if not utils.isnumerical(args[0]): + args = high_level_objects_to_values(*args, low_level_wcs=self) + + results = self._call_backward(*args, **kwargs) + with_units = kwargs.pop('with_units', False) if with_units: - args = high_level_objects_to_values(*args, low_level_wcs=self) + high_level = values_to_high_level_objects(*results, low_level_wcs=self) + if len(high_level) == 1: + high_level = high_level[0] + return high_level - return self._call_backward(*args, **kwargs) + return results def _call_backward(self, *args, with_bounding_box=True, fill_value=np.nan, **kwargs): try: From 102ba588f316677d08983746211d076eaca776af Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Thu, 12 Oct 2023 14:53:23 -0400 Subject: [PATCH 07/49] Use is_high_level to detect input to inverse --- gwcs/utils.py | 12 ++++++++++++ gwcs/wcs.py | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/gwcs/utils.py b/gwcs/utils.py index 104558cf..0d4ded87 100644 --- a/gwcs/utils.py +++ b/gwcs/utils.py @@ -481,3 +481,15 @@ def isnumerical(val): and not np.issubdtype(val.dtype, np.integer)): isnum = False return isnum + + +def is_high_level(*args, low_level_wcs): + """ + Determine if args matches the high level classes as defined by + ``low_level_wcs``. + """ + if len(args) != len(low_level_wcs.world_axis_object_classes): + return False + + return all([type(arg) is waoc[0] + for arg, waoc in zip(args, low_level_wcs.world_axis_object_classes.values())]) diff --git a/gwcs/wcs.py b/gwcs/wcs.py index 8f73e248..fda7aa65 100644 --- a/gwcs/wcs.py +++ b/gwcs/wcs.py @@ -496,7 +496,7 @@ def invert(self, *args, **kwargs): transform returns ``Quantity`` objects, else values. """ - if not utils.isnumerical(args[0]): + if utils.is_high_level(*args, low_level_wcs=self): args = high_level_objects_to_values(*args, low_level_wcs=self) results = self._call_backward(*args, **kwargs) From 9f22da6aaad7b38571afe85c2c3d8c074629440d Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Thu, 12 Oct 2023 14:54:01 -0400 Subject: [PATCH 08/49] Remove isnumerical --- gwcs/tests/test_utils.py | 15 --------------- gwcs/utils.py | 18 ------------------ 2 files changed, 33 deletions(-) diff --git a/gwcs/tests/test_utils.py b/gwcs/tests/test_utils.py index e69ec536..7f880e65 100644 --- a/gwcs/tests/test_utils.py +++ b/gwcs/tests/test_utils.py @@ -90,21 +90,6 @@ def test_get_axes(): assert not other -def test_isnumerical(): - sky = coord.SkyCoord(1 * u.deg, 2 * u.deg) - assert not gwutils.isnumerical(sky) - - assert not gwutils.isnumerical(2 * u.m) - - assert gwutils.isnumerical(float(0)) - assert gwutils.isnumerical(np.array(0)) - - assert not gwutils.isnumerical(np.array(['s200', '234'])) - - assert gwutils.isnumerical(np.array(0, dtype='>f8')) - assert gwutils.isnumerical(np.array(0, dtype='>i4')) - - def test_get_values(): args = 2 * u.cm units=(u.m, ) diff --git a/gwcs/utils.py b/gwcs/utils.py index 0d4ded87..603c5552 100644 --- a/gwcs/utils.py +++ b/gwcs/utils.py @@ -465,24 +465,6 @@ def create_projection_transform(projcode): return projklass(**projparams) -def isnumerical(val): - """ - Determine if a value is numerical (number or np.array of numbers). - """ - isnum = True - if isinstance(val, coords.SkyCoord): - isnum = False - elif isinstance(val, u.Quantity): - isnum = False - elif isinstance(val, (Time, TimeDelta)): - isnum = False - elif (isinstance(val, np.ndarray) - and not np.issubdtype(val.dtype, np.floating) - and not np.issubdtype(val.dtype, np.integer)): - isnum = False - return isnum - - def is_high_level(*args, low_level_wcs): """ Determine if args matches the high level classes as defined by From e74535218f592aebdec281882dfd3e0b46351aaf Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Thu, 12 Oct 2023 15:04:13 -0400 Subject: [PATCH 09/49] Remove old test --- gwcs/tests/test_wcs.py | 43 ------------------------------------------ 1 file changed, 43 deletions(-) diff --git a/gwcs/tests/test_wcs.py b/gwcs/tests/test_wcs.py index b1fdd24e..a9507fa9 100644 --- a/gwcs/tests/test_wcs.py +++ b/gwcs/tests/test_wcs.py @@ -208,49 +208,6 @@ def test_backward_transform_has_inverse(): assert_allclose(w.backward_transform.inverse(1, 2), w(1, 2)) -@pytest.mark.skip -def test_return_coordinates(): - """Test converting to coordinate objects or quantities.""" - w = wcs.WCS(pipe[:]) - x = 1 - y = 2.3 - numerical_result = (26.8, -0.6) - # Celestial frame - num_plus_output = w(x, y, with_units=True) - output_quant = w.output_frame.coordinate_to_quantity(num_plus_output) - assert_allclose(w(x, y), numerical_result) - assert_allclose(utils.get_values(w.unit, *output_quant), numerical_result) - assert_allclose(w.invert(num_plus_output), (x, y)) - assert isinstance(num_plus_output, coord.SkyCoord) - - # Spectral frame - poly = models.Polynomial1D(1, c0=1, c1=2) - w = wcs.WCS(forward_transform=poly, output_frame=spec) - numerical_result = poly(y) - num_plus_output = w(y, with_units=True) - output_quant = w.output_frame.coordinate_to_quantity(num_plus_output) - assert_allclose(utils.get_values(w.unit, output_quant), numerical_result) - assert isinstance(num_plus_output, u.Quantity) - - # CompositeFrame - [celestial, spectral] - output_frame = cf.CompositeFrame(frames=[icrs, spec]) - transform = m1 & poly - w = wcs.WCS(forward_transform=transform, output_frame=output_frame) - numerical_result = transform(x, y, y) - num_plus_output = w(x, y, y, with_units=True) - output_quant = w.output_frame.coordinate_to_quantity(*num_plus_output) - assert_allclose(utils.get_values(w.unit, *output_quant), numerical_result) - - # CompositeFrame - [celestial, Stokes] - output_frame = cf.CompositeFrame(frames=[icrs, stokes]) - transform = m1 & models.Identity(1) - w = wcs.WCS(forward_transform=transform, output_frame=output_frame) - numerical_result = transform(x, y, y) - num_plus_output = w(x, y, y, with_units=True) - output_quant = w.output_frame.coordinate_to_quantity(*num_plus_output) - assert_allclose(utils.get_values(w.unit, *output_quant), numerical_result) - - def test_from_fiducial_sky(): sky = coord.SkyCoord(1.63 * u.radian, -72.4 * u.deg, frame='fk5') tan = models.Pix2Sky_TAN() From 262a317e7e5981bb95584385c58a673b206d971d Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Thu, 12 Oct 2023 15:06:20 -0400 Subject: [PATCH 10/49] Remove unused imports --- gwcs/api.py | 1 - gwcs/coordinate_frames.py | 1 - gwcs/tests/test_wcs.py | 1 - gwcs/utils.py | 1 - 4 files changed, 4 deletions(-) diff --git a/gwcs/api.py b/gwcs/api.py index 08253298..e808c4fd 100644 --- a/gwcs/api.py +++ b/gwcs/api.py @@ -9,7 +9,6 @@ from astropy.modeling import separable import astropy.units as u -from . import utils from . import coordinate_frames as cf __all__ = ["GWCSAPIMixin"] diff --git a/gwcs/coordinate_frames.py b/gwcs/coordinate_frames.py index 3abb2208..01bc92c8 100644 --- a/gwcs/coordinate_frames.py +++ b/gwcs/coordinate_frames.py @@ -110,7 +110,6 @@ from astropy import time from astropy import units as u from astropy import utils as astutil -from astropy.utils.decorators import deprecated from astropy import coordinates as coord from astropy.wcs.wcsapi.low_level_api import (validate_physical_types, VALID_UCDS) diff --git a/gwcs/tests/test_wcs.py b/gwcs/tests/test_wcs.py index a9507fa9..308dcc32 100644 --- a/gwcs/tests/test_wcs.py +++ b/gwcs/tests/test_wcs.py @@ -21,7 +21,6 @@ from .. import wcs from ..wcstools import (wcs_from_fiducial, grid_from_bounding_box, wcs_from_points) from .. import coordinate_frames as cf -from .. import utils from ..utils import CoordinateFrameError from .utils import _gwcs_from_hst_fits_wcs from . import data diff --git a/gwcs/utils.py b/gwcs/utils.py index 603c5552..c04c105d 100644 --- a/gwcs/utils.py +++ b/gwcs/utils.py @@ -11,7 +11,6 @@ from astropy.io import fits from astropy import coordinates as coords from astropy import units as u -from astropy.time import Time, TimeDelta from astropy.wcs import Celprm From 882c6b2ec47961bf7a22f727307844973ce27820 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Thu, 12 Oct 2023 15:17:53 -0400 Subject: [PATCH 11/49] Fix doc build --- gwcs/coordinate_frames.py | 4 ++-- gwcs/wcs.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/gwcs/coordinate_frames.py b/gwcs/coordinate_frames.py index 01bc92c8..344b4f19 100644 --- a/gwcs/coordinate_frames.py +++ b/gwcs/coordinate_frames.py @@ -254,7 +254,7 @@ def world_axis_object_classes(self): See Also -------- - `astropy.wcs.wcsapi.BaseLowLevelWCS.world_axis_object_classes` + astropy.wcs.wcsapi.BaseLowLevelWCS.world_axis_object_classes """ @property @@ -265,7 +265,7 @@ def world_axis_object_components(self): See Also -------- - `astropy.wcs.wcsapi.BaseLowLevelWCS.world_axis_object_components` + astropy.wcs.wcsapi.BaseLowLevelWCS.world_axis_object_components """ diff --git a/gwcs/wcs.py b/gwcs/wcs.py index fda7aa65..04a440ed 100644 --- a/gwcs/wcs.py +++ b/gwcs/wcs.py @@ -492,7 +492,7 @@ def invert(self, *args, **kwargs): result : tuple or value Returns a tuple of scalar or array values for each axis. Unless ``input_frame.naxes == 1`` when it shall return the value. - The return type will be `~astropy.unit.Quantity` objects if the + The return type will be `~astropy.units.Quantity` objects if the transform returns ``Quantity`` objects, else values. """ From 8c48a7ced0b48e7655f4f6fd80e5df2b1b3fe735 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Thu, 12 Oct 2023 15:32:26 -0400 Subject: [PATCH 12/49] Delete reference_position Closes #436 --- gwcs/converters/wcs.py | 11 ----------- gwcs/coordinate_frames.py | 25 ++----------------------- 2 files changed, 2 insertions(+), 34 deletions(-) diff --git a/gwcs/converters/wcs.py b/gwcs/converters/wcs.py index 313d6d10..6f2cbd78 100644 --- a/gwcs/converters/wcs.py +++ b/gwcs/converters/wcs.py @@ -147,19 +147,8 @@ def from_yaml_tree(self, node, tag, ctx): from ..coordinate_frames import SpectralFrame node = self._from_yaml_tree(node, tag, ctx) - if 'reference_position' in node: - node['reference_position'] = node['reference_position'].upper() - return SpectralFrame(**node) - def to_yaml_tree(self, frame, tag, ctx): - node = self._to_yaml_tree(frame, tag, ctx) - - if frame.reference_position is not None: - node['reference_position'] = frame.reference_position.lower() - - return node - class CompositeFrameConverter(FrameConverter): tags = ["tag:stsci.edu:gwcs/composite_frame-*"] diff --git a/gwcs/coordinate_frames.py b/gwcs/coordinate_frames.py index 344b4f19..ee2ddf90 100644 --- a/gwcs/coordinate_frames.py +++ b/gwcs/coordinate_frames.py @@ -158,10 +158,6 @@ def _ucd1_to_ctype_name_mapping(ctype_to_ucd, allowed_ucd_duplicates): STANDARD_REFERENCE_FRAMES = [frame.upper() for frame in coord.builtin_frames.__all__] -STANDARD_REFERENCE_POSITION = ["GEOCENTER", "BARYCENTER", "HELIOCENTER", - "TOPOCENTER", "LSR", "LSRK", "LSRD", - "GALACTIC_CENTER", "LOCAL_GROUP_CENTER"] - def get_ctype_from_ucd(ucd): """ @@ -199,7 +195,6 @@ def name(self) -> str: The name of the coordinate frame. """ - # TODO: Why is this not `units`? @property @abc.abstractmethod def unit(self) -> tuple[u.Unit, ...]: @@ -283,8 +278,6 @@ class CoordinateFrame(BaseCoordinateFrame): A dimension in the input data that corresponds to this axis. reference_frame : astropy.coordinates.builtin_frames Reference frame (usually used with output_frame to convert to world coordinate objects). - reference_position : str - Reference position - one of ``STANDARD_REFERENCE_POSITION`` unit : list of astropy.units.Unit Unit for each axis. axes_names : list @@ -294,7 +287,7 @@ class CoordinateFrame(BaseCoordinateFrame): """ def __init__(self, naxes, axes_type, axes_order, reference_frame=None, - reference_position=None, unit=None, axes_names=None, + unit=None, axes_names=None, name=None, axis_physical_types=None): self._naxes = naxes self._axes_order = tuple(axes_order) @@ -331,8 +324,6 @@ def __init__(self, naxes, axes_type, axes_order, reference_frame=None, else: self._name = name - self._reference_position = reference_position - if len(self._axes_type) != naxes: raise ValueError("Length of axes_type does not match number of axes.") if len(self._axes_order) != naxes: @@ -367,8 +358,6 @@ def __repr__(self): fmt = '<{0}(name="{1}", unit={2}, axes_names={3}, axes_order={4}'.format( self.__class__.__name__, self.name, self.unit, self.axes_names, self.axes_order) - if self.reference_position is not None: - fmt += ', reference_position="{0}"'.format(self.reference_position) if self.reference_frame is not None: fmt += ", reference_frame={0}".format(self.reference_frame) fmt += ")>" @@ -414,12 +403,6 @@ def reference_frame(self): """ Reference frame, used to convert to world coordinate objects. """ return self._reference_frame - # TODO: This seems to be spectral specific, should it be moved to SpectralFrame? - @property - def reference_position(self): - """ Reference Position. """ - return getattr(self, "_reference_position", None) - @property def axes_type(self): """ Type of this frame : 'SPATIAL', 'SPECTRAL', 'TIME'. """ @@ -547,19 +530,15 @@ class SpectralFrame(CoordinateFrame): Spectral axis name. name : str Name for this frame. - reference_position : str - Reference position - one of ``STANDARD_REFERENCE_POSITION`` """ def __init__(self, axes_order=(0,), reference_frame=None, unit=None, - axes_names=None, name=None, axis_physical_types=None, - reference_position=None): + axes_names=None, name=None, axis_physical_types=None): super(SpectralFrame, self).__init__(naxes=1, axes_type="SPECTRAL", axes_order=axes_order, axes_names=axes_names, reference_frame=reference_frame, unit=unit, name=name, - reference_position=reference_position, axis_physical_types=axis_physical_types) @property From 001eb07b6ae871a922fdc3f13956fd170fcdccca Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Thu, 12 Oct 2023 16:47:35 -0400 Subject: [PATCH 13/49] lint --- gwcs/coordinate_frames.py | 42 ++++++++++++++++++++------------------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/gwcs/coordinate_frames.py b/gwcs/coordinate_frames.py index ee2ddf90..ec8a1882 100644 --- a/gwcs/coordinate_frames.py +++ b/gwcs/coordinate_frames.py @@ -41,7 +41,7 @@ also a converter between those inputs/outputs and richer coordinate representations of those inputs/ouputs. -For example, an output frame of type `~astropy.coordinates.SpectralCoord` +For example, an output frame of type `~gwcs.coordinate_frames.SpectralFrame` provides metadata to the `.WCS` object such as the ``axes_type`` being ``"SPECTRAL"`` and the unit of the output etc. The output frame also provides a converter of the numeric output of the transform to a @@ -478,20 +478,22 @@ def __init__(self, axes_order=None, reference_frame=None, unit = tuple([u.degree] * naxes) axes_type = ['SPATIAL'] * naxes - super(CelestialFrame, self).__init__(naxes=naxes, axes_type=axes_type, - axes_order=axes_order, - reference_frame=reference_frame, - unit=unit, - axes_names=axes_names, - name=name, axis_physical_types=axis_physical_types) + super().__init__(naxes=naxes, + axes_type=axes_type, + axes_order=axes_order, + reference_frame=reference_frame, + unit=unit, + axes_names=axes_names, + name=name, + axis_physical_types=axis_physical_types) @property def _default_axis_physical_types(self): if isinstance(self.reference_frame, coord.Galactic): return "pos.galactic.lon", "pos.galactic.lat" elif isinstance(self.reference_frame, (coord.GeocentricTrueEcliptic, - coord.GCRS, - coord.PrecessedGeocentric)): + coord.GCRS, + coord.PrecessedGeocentric)): return "pos.bodyrc.lon", "pos.bodyrc.lat" elif isinstance(self.reference_frame, coord.builtin_frames.BaseRADecFrame): return "pos.eq.ra", "pos.eq.dec" @@ -536,10 +538,10 @@ class SpectralFrame(CoordinateFrame): def __init__(self, axes_order=(0,), reference_frame=None, unit=None, axes_names=None, name=None, axis_physical_types=None): - super(SpectralFrame, self).__init__(naxes=1, axes_type="SPECTRAL", axes_order=axes_order, - axes_names=axes_names, reference_frame=reference_frame, - unit=unit, name=name, - axis_physical_types=axis_physical_types) + super().__init__(naxes=1, axes_type="SPECTRAL", axes_order=axes_order, + axes_names=axes_names, reference_frame=reference_frame, + unit=unit, name=name, + axis_physical_types=axis_physical_types) @property def _default_axis_physical_types(self): @@ -759,9 +761,9 @@ class StokesFrame(CoordinateFrame): """ def __init__(self, axes_order=(0,), axes_names=("stokes",), name=None, axis_physical_types=None): - super(StokesFrame, self).__init__(1, ["STOKES"], axes_order, name=name, - axes_names=axes_names, unit=u.one, - axis_physical_types=axis_physical_types) + super().__init__(1, ["STOKES"], axes_order, name=name, + axes_names=axes_names, unit=u.one, + axis_physical_types=axis_physical_types) @property def _default_axis_physical_types(self): @@ -799,10 +801,10 @@ class Frame2D(CoordinateFrame): def __init__(self, axes_order=(0, 1), unit=(u.pix, u.pix), axes_names=('x', 'y'), name=None, axis_physical_types=None): - super(Frame2D, self).__init__(naxes=2, axes_type=["SPATIAL", "SPATIAL"], - axes_order=axes_order, name=name, - axes_names=axes_names, unit=unit, - axis_physical_types=axis_physical_types) + super().__init__(naxes=2, axes_type=["SPATIAL", "SPATIAL"], + axes_order=axes_order, name=name, + axes_names=axes_names, unit=unit, + axis_physical_types=axis_physical_types) @property def _default_axis_physical_types(self): From 0cec3eeb236325f86c50bc72035503d86620d85e Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Thu, 12 Oct 2023 17:04:44 -0400 Subject: [PATCH 14/49] Add a test for axes ordering with CelestialFrame fixes #269 --- gwcs/coordinate_frames.py | 15 +++++++++++- gwcs/tests/test_wcs.py | 49 ++++++++++++++++++++++++++++++++++++--- 2 files changed, 60 insertions(+), 4 deletions(-) diff --git a/gwcs/coordinate_frames.py b/gwcs/coordinate_frames.py index ec8a1882..aac4ff13 100644 --- a/gwcs/coordinate_frames.py +++ b/gwcs/coordinate_frames.py @@ -436,6 +436,11 @@ def world_axis_object_classes(self): def world_axis_object_components(self): return [(f"{at}{i}" if i != 0 else at, 0, 'value') for i, at in enumerate(self._axes_type)] + @property + def _native_world_axis_object_components(self): + """Defines the target component ordering (i.e. not taking into account axes_order)""" + return self.world_axis_object_components + class CelestialFrame(CoordinateFrame): """ @@ -515,6 +520,14 @@ def _world_axis_object_components(self): return [('celestial', 0, lambda sc: sc.spherical.lon.to_value(self.unit[0])), ('celestial', 1, lambda sc: sc.spherical.lat.to_value(self.unit[1]))] + @property + def world_axis_object_components(self): + # Sort the native waoc by the axes order. The axes order may have jumps + # in it if there are other frames in between the components. + ordered = np.array(self._native_world_axis_object_components, + dtype=object)[np.argsort(self.axes_order)] + return list(map(tuple, ordered)) + class SpectralFrame(CoordinateFrame): """ @@ -711,7 +724,7 @@ def _wao_renamed_components_iter(self): mapper = self._wao_classes_rename_map for frame in self.frames: renamed_components = [] - for comp in frame.world_axis_object_components: + for comp in frame._native_world_axis_object_components: comp = list(comp) rename = mapper[frame].get(comp[0]) if rename: diff --git a/gwcs/tests/test_wcs.py b/gwcs/tests/test_wcs.py index 308dcc32..abd588f9 100644 --- a/gwcs/tests/test_wcs.py +++ b/gwcs/tests/test_wcs.py @@ -1433,9 +1433,14 @@ def test_no_bounding_box_if_read_from_file(tmp_path): def test_split_frame_wcs(): - # Setup a model where the pixel & world axes are (lat, wave, lon) - spatial = models.Multiply(10*u.arcsec/u.pix) & models.Multiply(15*u.arcsec/u.pix) # pretend this is a spatial model + # Setup a WCS where the pixel & world axes are (lat, wave, lon) + + # We setup a model which is pretending to be a celestial transform. Note + # that we are pretending that this model is ordered lon, lat because that's + # what the projections require in astropy. + spatial = models.Multiply(20*u.arcsec/u.pix) & models.Multiply(15*u.arcsec/u.pix) compound = models.Linear1D(intercept=0*u.nm, slope=10*u.nm/u.pix) & spatial + # This forward transforms uses mappings to be (lat, wave, lon) forward = models.Mapping((1, 2, 0)) | compound | models.Mapping((2, 0, 1)) # Setup the output frame @@ -1448,14 +1453,52 @@ def test_split_frame_wcs(): axes_order=list(range(3)), unit=[u.pix]*3) iwcs = wcs.WCS(forward, input_frame, output_frame) - input_pixel = [0*u.pix, 1*u.pix, 2*u.pix] + input_pixel = [1*u.pix, 2*u.pix, 3*u.pix] output_world = iwcs.pixel_to_world_values(*input_pixel) output_pixel = iwcs.world_to_pixel_values(*output_world) assert_allclose(output_pixel, u.Quantity(input_pixel).to_value(u.pix)) + expected_world = [15*u.arcsec, 20*u.nm, 60*u.arcsec] + for expected, output in zip(expected_world, output_world): + assert_allclose(output, expected.value) + world_obj = iwcs.pixel_to_world(*input_pixel) assert isinstance(world_obj[0], coord.SkyCoord) assert isinstance(world_obj[1], coord.SpectralCoord) + assert u.allclose(world_obj[0].spherical.lat, expected_world[0]) + assert u.allclose(world_obj[0].spherical.lon, expected_world[2]) + assert u.allclose(world_obj[1], expected_world[1]) + obj_pixel = iwcs.world_to_pixel(*world_obj) assert_allclose(obj_pixel, u.Quantity(input_pixel).to_value(u.pix)) + + +def test_reordered_celestial(): + # This is a spatial model which is ordered lat, lon for the purposes of this test. + spatial = models.Multiply(20*u.deg/u.pix) & models.Multiply(15*u.deg/u.pix) + + celestial_frame = cf.CelestialFrame(axes_order=(1, 0), unit=(u.deg, u.deg), + reference_frame=coord.ICRS()) + + input_frame = cf.CoordinateFrame(2, ["PIXEL"]*2, + axes_order=list(range(2)), unit=[u.pix]*2) + + iwcs = wcs.WCS(spatial, input_frame, celestial_frame) + + input_pixel = [1*u.pix, 3*u.pix] + output_world = iwcs.pixel_to_world_values(*input_pixel) + output_pixel = iwcs.world_to_pixel_values(*output_world) + assert_allclose(output_pixel, u.Quantity(input_pixel).to_value(u.pix)) + + expected_world = [20*u.deg, 45*u.deg] + assert_allclose(output_world, [e.value for e in expected_world]) + + world_obj = iwcs.pixel_to_world(*input_pixel) + assert isinstance(world_obj, coord.SkyCoord) + + assert u.allclose(world_obj.spherical.lat, expected_world[0]) + assert u.allclose(world_obj.spherical.lon, expected_world[1]) + + obj_pixel = iwcs.world_to_pixel(world_obj) + assert_allclose(obj_pixel, u.Quantity(input_pixel).to_value(u.pix)) From 45fed688a567363f297c7f1150e475a562ddc9a2 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Thu, 12 Oct 2023 17:37:55 -0400 Subject: [PATCH 15/49] =?UTF-8?q?Test=20different=20Units=20said=20Nadia?= =?UTF-8?q?=20=F0=9F=92=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- gwcs/tests/test_wcs.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/gwcs/tests/test_wcs.py b/gwcs/tests/test_wcs.py index abd588f9..42e6ef53 100644 --- a/gwcs/tests/test_wcs.py +++ b/gwcs/tests/test_wcs.py @@ -1438,13 +1438,13 @@ def test_split_frame_wcs(): # We setup a model which is pretending to be a celestial transform. Note # that we are pretending that this model is ordered lon, lat because that's # what the projections require in astropy. - spatial = models.Multiply(20*u.arcsec/u.pix) & models.Multiply(15*u.arcsec/u.pix) + spatial = models.Multiply(20*u.arcsec/u.pix) & models.Multiply(15*u.deg/u.pix) compound = models.Linear1D(intercept=0*u.nm, slope=10*u.nm/u.pix) & spatial # This forward transforms uses mappings to be (lat, wave, lon) forward = models.Mapping((1, 2, 0)) | compound | models.Mapping((2, 0, 1)) # Setup the output frame - celestial_frame = cf.CelestialFrame(axes_order=(2, 0), unit=(u.arcsec, u.arcsec), + celestial_frame = cf.CelestialFrame(axes_order=(2, 0), unit=(u.arcsec, u.deg), reference_frame=coord.ICRS()) spectral_frame = cf.SpectralFrame(axes_order=(1,), unit=u.nm) output_frame = cf.CompositeFrame([spectral_frame, celestial_frame]) @@ -1458,7 +1458,7 @@ def test_split_frame_wcs(): output_pixel = iwcs.world_to_pixel_values(*output_world) assert_allclose(output_pixel, u.Quantity(input_pixel).to_value(u.pix)) - expected_world = [15*u.arcsec, 20*u.nm, 60*u.arcsec] + expected_world = [15*u.deg, 20*u.nm, 60*u.arcsec] for expected, output in zip(expected_world, output_world): assert_allclose(output, expected.value) @@ -1476,9 +1476,9 @@ def test_split_frame_wcs(): def test_reordered_celestial(): # This is a spatial model which is ordered lat, lon for the purposes of this test. - spatial = models.Multiply(20*u.deg/u.pix) & models.Multiply(15*u.deg/u.pix) + spatial = models.Multiply(20*u.arcsec/u.pix) & models.Multiply(15*u.deg/u.pix) - celestial_frame = cf.CelestialFrame(axes_order=(1, 0), unit=(u.deg, u.deg), + celestial_frame = cf.CelestialFrame(axes_order=(1, 0), unit=(u.arcsec, u.deg), reference_frame=coord.ICRS()) input_frame = cf.CoordinateFrame(2, ["PIXEL"]*2, @@ -1491,7 +1491,7 @@ def test_reordered_celestial(): output_pixel = iwcs.world_to_pixel_values(*output_world) assert_allclose(output_pixel, u.Quantity(input_pixel).to_value(u.pix)) - expected_world = [20*u.deg, 45*u.deg] + expected_world = [20*u.arcsec, 45*u.deg] assert_allclose(output_world, [e.value for e in expected_world]) world_obj = iwcs.pixel_to_world(*input_pixel) From 7a8006ae84cfa45c562a33db344b5a1b0b125a85 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Mon, 16 Oct 2023 13:50:46 +0100 Subject: [PATCH 16/49] First attempt at keeping a sorted and unsorted list of frame props --- gwcs/coordinate_frames.py | 153 +++++++++++++++++++++----------------- 1 file changed, 86 insertions(+), 67 deletions(-) diff --git a/gwcs/coordinate_frames.py b/gwcs/coordinate_frames.py index aac4ff13..0bf2c5c2 100644 --- a/gwcs/coordinate_frames.py +++ b/gwcs/coordinate_frames.py @@ -105,6 +105,7 @@ from collections import defaultdict import logging import numpy as np +from dataclasses import dataclass, InitVar from astropy.utils.misc import isiterable from astropy import time @@ -176,7 +177,6 @@ def get_ctype_from_ucd(ucd): return UCD1_TO_CTYPE.get(ucd, "") - class BaseCoordinateFrame(abc.ABC): """ API Definition for a Coordinate frame @@ -264,6 +264,63 @@ def world_axis_object_components(self): """ +@dataclass +class FrameProperties: + naxes: InitVar[int] + axes_type: tuple[str] + unit: tuple[u.Unit] = None + axes_names: tuple[str] = None + axis_physical_types: list[str] = None + + def __post_init__(self, naxes): + if isinstance(self.axes_type, str): + self.axes_type = (self.axes_type,) + else: + self.axes_type = tuple(self.axes_type) + + if len(self.axes_type) != naxes: + raise ValueError("Length of axes_type does not match number of axes.") + + if self.unit is not None: + if astutil.isiterable(self.unit): + unit = tuple(self.unit) + else: + unit = (self.unit,) + if len(unit) != naxes: + raise ValueError("Number of units does not match number of axes.") + else: + self.unit = tuple(u.Unit(au) for au in unit) + else: + self.unit = tuple(u.dimensionless_unscaled for na in range(naxes)) + + if self.axes_names is not None: + if isinstance(self.axes_names, str): + self.axes_names = (self.axes_names,) + else: + self.axes_names = tuple(self.axes_names) + if len(self.axes_names) != naxes: + raise ValueError("Number of axes names does not match number of axes.") + else: + self.axes_names = tuple([""] * naxes) + + if self.axis_physical_types is not None: + if isinstance(self.axis_physical_types, str): + self.axis_physical_types = (self.axis_physical_types,) + elif not isiterable(self.axis_physical_types): + raise TypeError("axis_physical_types must be of type string or iterable of strings") + if len(self.axis_physical_types) != naxes: + raise ValueError(f'"axis_physical_types" must be of length {naxes}') + ph_type = [] + for axt in self.axis_physical_types: + if axt not in VALID_UCDS and not axt.startswith("custom:"): + ph_type.append("custom:{axt}") + else: + ph_type.append(axt) + + validate_physical_types(ph_type) + self.axes_physical_types = tuple(ph_type) + + class CoordinateFrame(BaseCoordinateFrame): """ Base class for Coordinate Frames. @@ -291,68 +348,28 @@ def __init__(self, naxes, axes_type, axes_order, reference_frame=None, name=None, axis_physical_types=None): self._naxes = naxes self._axes_order = tuple(axes_order) - if isinstance(axes_type, str): - self._axes_type = (axes_type,) - else: - self._axes_type = tuple(axes_type) - self._reference_frame = reference_frame - if unit is not None: - if astutil.isiterable(unit): - unit = tuple(unit) - else: - unit = (unit,) - if len(unit) != naxes: - raise ValueError("Number of units does not match number of axes.") - else: - self._unit = tuple([u.Unit(au) for au in unit]) - else: - self._unit = tuple(u.Unit("") for na in range(naxes)) - if axes_names is not None: - if isinstance(axes_names, str): - axes_names = (axes_names,) - else: - axes_names = tuple(axes_names) - if len(axes_names) != naxes: - raise ValueError("Number of axes names does not match number of axes.") - else: - axes_names = tuple([""] * naxes) - self._axes_names = axes_names if name is None: self._name = self.__class__.__name__ else: self._name = name - if len(self._axes_type) != naxes: - raise ValueError("Length of axes_type does not match number of axes.") if len(self._axes_order) != naxes: raise ValueError("Length of axes_order does not match number of axes.") - super(CoordinateFrame, self).__init__() - # _axis_physical_types holds any user supplied physical types - self._axis_physical_types = self._set_axis_physical_types(axis_physical_types) - - def _set_axis_physical_types(self, pht): - """ - Set the physical type of the coordinate axes using VO UCD1+ v1.23 definitions. - """ - if pht is not None: - if isinstance(pht, str): - pht = (pht,) - elif not isiterable(pht): - raise TypeError("axis_physical_types must be of type string or iterable of strings") - if len(pht) != self.naxes: - raise ValueError('"axis_physical_types" must be of length {}'.format(self.naxes)) - ph_type = [] - for axt in pht: - if axt not in VALID_UCDS and not axt.startswith("custom:"): - ph_type.append("custom:{}".format(axt)) - else: - ph_type.append(axt) + if isinstance(axes_type, str): + axes_type = (axes_type,) + default_apt = tuple([f"custom:{t}" for t in axes_type]) + self._prop = FrameProperties( + naxes, + axes_type, + unit, + axes_names, + axis_physical_types or default_apt, + ) - validate_physical_types(ph_type) - return tuple(ph_type) + super().__init__() def __repr__(self): fmt = '<{0}(name="{1}", unit={2}, axes_names={3}, axes_order={4}'.format( @@ -368,6 +385,10 @@ def __str__(self): return self._name return self.__class__.__name__ + def _sort_property(self, property): + return tuple(dict(sorted(zip(property, self.axes_order), + key=lambda x: x[1])).keys()) + @property def name(self): """ A custom name of this frame.""" @@ -386,12 +407,12 @@ def naxes(self): @property def unit(self): """The unit of this frame.""" - return self._unit + return self._sort_property(self._prop.unit) @property def axes_names(self): """ Names of axes in the frame.""" - return self._axes_names + return self._sort_property(self._prop.axes_names) @property def axes_order(self): @@ -406,15 +427,7 @@ def reference_frame(self): @property def axes_type(self): """ Type of this frame : 'SPATIAL', 'SPECTRAL', 'TIME'. """ - return self._axes_type - - @property - def _default_axis_physical_types(self): - """ - The default physical types to use for this frame if none are specified - by the user. - """ - return tuple("custom:{}".format(t) for t in self.axes_type) + return self._sort_property(self._prop.axes_type) @property def axis_physical_types(self): @@ -423,7 +436,7 @@ def axis_physical_types(self): These physical types are the types in frame order, not transform order. """ - return self._axis_physical_types or self._default_axis_physical_types + return self._sort_property(self._prop.axis_physical_types) @property def world_axis_object_classes(self): @@ -434,7 +447,7 @@ def world_axis_object_classes(self): @property def world_axis_object_components(self): - return [(f"{at}{i}" if i != 0 else at, 0, 'value') for i, at in enumerate(self._axes_type)] + return [(f"{at}{i}" if i != 0 else at, 0, 'value') for i, at in enumerate(self._prop.axes_type)] @property def _native_world_axis_object_components(self): @@ -681,8 +694,14 @@ def __init__(self, frames, name=None): for frame in frames: axes_order.extend(frame.axes_order) for frame in frames: - for ind, axtype, un, n, pht in zip(frame.axes_order, frame.axes_type, - frame.unit, frame.axes_names, frame.axis_physical_types): + unsorted_prop = zip( + frame.axes_order, + frame._prop.axes_type, + frame._prop.unit, + frame._prop.axes_names, + frame._prop.axis_physical_types + ) + for ind, axtype, un, n, pht in unsorted_prop: axes_type[ind] = axtype axes_names[ind] = n unit[ind] = un From d9cae7d8c89780a857c5bf537a10054ce50bbb65 Mon Sep 17 00:00:00 2001 From: Nadia Dencheva Date: Mon, 20 Nov 2023 09:45:36 -0500 Subject: [PATCH 17/49] make tests pass, ecept slicing --- .github/workflows/ci.yml | 2 +- gwcs/coordinate_frames.py | 142 ++++++++++++++++++-------- gwcs/examples.py | 4 +- gwcs/tests/test_api.py | 9 +- gwcs/tests/test_api_slicing.py | 22 ++-- gwcs/tests/test_coordinate_systems.py | 2 +- gwcs/tests/test_wcs.py | 40 +++++--- gwcs/wcs.py | 1 + 8 files changed, 147 insertions(+), 75 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 766823c9..1392d5ae 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -3,7 +3,7 @@ name: test on: push: branches: - - 'master' + - '*' tags: - '*' pull_request: diff --git a/gwcs/coordinate_frames.py b/gwcs/coordinate_frames.py index 0bf2c5c2..ffd4d221 100644 --- a/gwcs/coordinate_frames.py +++ b/gwcs/coordinate_frames.py @@ -60,7 +60,7 @@ .. code-block:: python - [SpectralCoord(axes_order=(1,)), CelestialCoord(axes_order=(2, 0))] + [SpectralFrame(axes_order=(1,)), CelestialFrame(axes_order=(2, 0))] we would map the outputs of this transform into the correct positions in the frames. As shown below, this is also used when constructing the inputs to the inverse transform. @@ -313,12 +313,20 @@ def __post_init__(self, naxes): ph_type = [] for axt in self.axis_physical_types: if axt not in VALID_UCDS and not axt.startswith("custom:"): - ph_type.append("custom:{axt}") + ph_type.append(f"custom:{axt}") else: ph_type.append(axt) validate_physical_types(ph_type) - self.axes_physical_types = tuple(ph_type) + self.axis_physical_types = tuple(ph_type) + + @property + def _default_axis_physical_type(self): + """ + The default physical types to use for this frame if none are specified + by the user. + """ + return tuple("custom:{}".format(t) for t in self.axes_type) class CoordinateFrame(BaseCoordinateFrame): @@ -360,17 +368,24 @@ def __init__(self, naxes, axes_type, axes_order, reference_frame=None, if isinstance(axes_type, str): axes_type = (axes_type,) - default_apt = tuple([f"custom:{t}" for t in axes_type]) + self._prop = FrameProperties( naxes, axes_type, unit, axes_names, - axis_physical_types or default_apt, + axis_physical_types or self._default_axis_physical_type(axes_type) ) super().__init__() + def _default_axis_physical_type(self, axes_type): + """ + The default physical types to use for this frame if none are specified + by the user. + """ + return tuple("custom:{}".format(t) for t in axes_type) + def __repr__(self): fmt = '<{0}(name="{1}", unit={2}, axes_names={3}, axes_order={4}'.format( self.__class__.__name__, self.name, @@ -386,8 +401,11 @@ def __str__(self): return self.__class__.__name__ def _sort_property(self, property): - return tuple(dict(sorted(zip(property, self.axes_order), - key=lambda x: x[1])).keys()) + #return tuple(dict(sorted(zip(property, self.axes_order), + # key=lambda x: x[1])).keys()) + sorted_prop = sorted(zip(property, self.axes_order), + key=lambda x: x[1]) + return tuple([t[0] for t in sorted_prop]) @property def name(self): @@ -408,6 +426,7 @@ def naxes(self): def unit(self): """The unit of this frame.""" return self._sort_property(self._prop.unit) + #return self._prop.unit @property def axes_names(self): @@ -436,14 +455,14 @@ def axis_physical_types(self): These physical types are the types in frame order, not transform order. """ - return self._sort_property(self._prop.axis_physical_types) + return self._prop.axis_physical_types or self._default_axis_physical_types @property def world_axis_object_classes(self): return {f"{at}{i}" if i != 0 else at: (u.Quantity, (), {'unit': unit}) - for i, (at, unit) in enumerate(zip(self._axes_type, self.unit))} + for i, (at, unit) in enumerate(zip(self.axes_type, self.unit))} @property def world_axis_object_components(self): @@ -486,16 +505,15 @@ def __init__(self, axes_order=None, reference_frame=None, if axes_names is None: axes_names = _axes_names naxes = len(_axes_names) - _unit = list(reference_frame.representation_component_units.values()) - if unit is None and _unit: - unit = _unit + self.native_axes_order = tuple(range(naxes)) if axes_order is None: - axes_order = tuple(range(naxes)) + axes_order = self.native_axes_order if unit is None: unit = tuple([u.degree] * naxes) axes_type = ['SPATIAL'] * naxes + pht = axis_physical_types or self._default_axis_physical_types(reference_frame, axes_names) super().__init__(naxes=naxes, axes_type=axes_type, axes_order=axes_order, @@ -503,30 +521,30 @@ def __init__(self, axes_order=None, reference_frame=None, unit=unit, axes_names=axes_names, name=name, - axis_physical_types=axis_physical_types) + axis_physical_types=pht) - @property - def _default_axis_physical_types(self): - if isinstance(self.reference_frame, coord.Galactic): + def _default_axis_physical_types(self, reference_frame, axes_names): + if isinstance(reference_frame, coord.Galactic): return "pos.galactic.lon", "pos.galactic.lat" - elif isinstance(self.reference_frame, (coord.GeocentricTrueEcliptic, - coord.GCRS, - coord.PrecessedGeocentric)): + elif isinstance(reference_frame, (coord.GeocentricTrueEcliptic, + coord.GCRS, + coord.PrecessedGeocentric)): return "pos.bodyrc.lon", "pos.bodyrc.lat" - elif isinstance(self.reference_frame, coord.builtin_frames.BaseRADecFrame): + elif isinstance(reference_frame, coord.builtin_frames.BaseRADecFrame): return "pos.eq.ra", "pos.eq.dec" - elif isinstance(self.reference_frame, coord.builtin_frames.BaseEclipticFrame): + elif isinstance(reference_frame, coord.builtin_frames.BaseEclipticFrame): return "pos.ecliptic.lon", "pos.ecliptic.lat" else: - return tuple("custom:{}".format(t) for t in self.axes_names) + return tuple("custom:{}".format(t) for t in axes_names) @property def world_axis_object_classes(self): + unit = np.array(self.unit)[np.argsort(self.axes_order)] return {'celestial': ( coord.SkyCoord, (), {'frame': self.reference_frame, - 'unit': self.unit})} + 'unit': unit})} @property def _world_axis_object_components(self): @@ -564,27 +582,32 @@ class SpectralFrame(CoordinateFrame): def __init__(self, axes_order=(0,), reference_frame=None, unit=None, axes_names=None, name=None, axis_physical_types=None): + if not isiterable(unit): + unit = (unit,) + + pht = axis_physical_types or self._default_axis_physical_types(unit) + super().__init__(naxes=1, axes_type="SPECTRAL", axes_order=axes_order, axes_names=axes_names, reference_frame=reference_frame, unit=unit, name=name, - axis_physical_types=axis_physical_types) + #axis_physical_types="em.wl") + axis_physical_types=pht) - @property - def _default_axis_physical_types(self): - if self.unit[0].physical_type == "frequency": + def _default_axis_physical_types(self, unit): + if unit[0].physical_type == "frequency": return ("em.freq",) - elif self.unit[0].physical_type == "length": + elif unit[0].physical_type == "length": return ("em.wl",) - elif self.unit[0].physical_type == "energy": + elif unit[0].physical_type == "energy": return ("em.energy",) - elif self.unit[0].physical_type == "speed": + elif unit[0].physical_type == "speed": return ("spect.dopplerVeloc",) logging.warning("Physical type may be ambiguous. Consider " "setting the physical type explicitly as " "either 'spect.dopplerVeloc.optical' or " "'spect.dopplerVeloc.radio'.") else: - return ("custom:{}".format(self.unit[0].physical_type),) + return ("custom:{}".format(unit[0].physical_type),) @property def world_axis_object_classes(self): @@ -625,9 +648,11 @@ def __init__(self, reference_frame, unit=None, axes_order=(0,), reference_frame.scale, reference_frame.location) + pht = axis_physical_types or self._default_axis_physical_types() + super().__init__(naxes=1, axes_type="TIME", axes_order=axes_order, axes_names=axes_names, reference_frame=reference_frame, - unit=unit, name=name, axis_physical_types=axis_physical_types) + unit=unit, name=name, axis_physical_types=pht) self._attrs = {} for a in self.reference_frame.info._represent_as_dict_extra_attrs: try: @@ -635,7 +660,7 @@ def __init__(self, reference_frame, unit=None, axes_order=(0,), except AttributeError: pass - @property + #@property def _default_axis_physical_types(self): return ("time",) @@ -686,13 +711,16 @@ class CompositeFrame(CoordinateFrame): def __init__(self, frames, name=None): self._frames = frames[:] naxes = sum([frame._naxes for frame in self._frames]) + axes_type = list(range(naxes)) unit = list(range(naxes)) axes_names = list(range(naxes)) - axes_order = [] ph_type = list(range(naxes)) + axes_order = [] + for frame in frames: axes_order.extend(frame.axes_order) + for frame in frames: unsorted_prop = zip( frame.axes_order, @@ -706,6 +734,7 @@ def __init__(self, frames, name=None): axes_names[ind] = n unit[ind] = un ph_type[ind] = pht + if len(np.unique(axes_order)) != len(axes_order): raise ValueError("Incorrect numbering of axes, " "axes_order should contain unique numbers, " @@ -714,6 +743,7 @@ def __init__(self, frames, name=None): super().__init__(naxes, axes_type=axes_type, axes_order=axes_order, unit=unit, axes_names=axes_names, + axis_physical_types=tuple(ph_type), name=name) self._axis_physical_types = tuple(ph_type) @@ -721,6 +751,22 @@ def __init__(self, frames, name=None): def frames(self): return self._frames + @property + def unit(self): + return self._prop.unit + + @property + def axes_names(self): + return self._prop.axes_names + + @property + def axes_type(self): + return self._prop.axes_type + + @property + def axis_physical_types(self): + return self._prop.axis_physical_types + def __repr__(self): return repr(self.frames) @@ -767,12 +813,14 @@ def world_axis_object_components(self): We need to generate the components respecting the axes_order. """ out = [None] * self.naxes + for frame, components in self._wao_renamed_components_iter: for i, ao in enumerate(frame.axes_order): out[ao] = components[i] if any([o is None for o in out]): raise ValueError("axes_order leads to incomplete world_axis_object_components") + return out @property @@ -793,11 +841,13 @@ class StokesFrame(CoordinateFrame): """ def __init__(self, axes_order=(0,), axes_names=("stokes",), name=None, axis_physical_types=None): + + pht = axis_physical_types or self._default_axis_physical_types() + super().__init__(1, ["STOKES"], axes_order, name=name, axes_names=axes_names, unit=u.one, - axis_physical_types=axis_physical_types) + axis_physical_types=pht) - @property def _default_axis_physical_types(self): return ("phys.polarization.stokes",) @@ -831,17 +881,19 @@ class Frame2D(CoordinateFrame): """ def __init__(self, axes_order=(0, 1), unit=(u.pix, u.pix), axes_names=('x', 'y'), - name=None, axis_physical_types=None): + name=None, axes_type=["SPATIAL", "SPATIAL"], axis_physical_types=None): + + pht = axis_physical_types or self._default_axis_physical_types(axes_names, axes_type) - super().__init__(naxes=2, axes_type=["SPATIAL", "SPATIAL"], + super().__init__(naxes=2, axes_type=axes_type, axes_order=axes_order, name=name, axes_names=axes_names, unit=unit, - axis_physical_types=axis_physical_types) + axis_physical_types=pht) - @property - def _default_axis_physical_types(self): - if all(self.axes_names): - ph_type = self.axes_names + def _default_axis_physical_types(self, axes_names, axes_type): + if axes_names is not None and all(axes_names): + ph_type = axes_names else: - ph_type = self.axes_type + ph_type = axes_type + return tuple("custom:{}".format(t) for t in ph_type) diff --git a/gwcs/examples.py b/gwcs/examples.py index 61fc9387..510b99db 100644 --- a/gwcs/examples.py +++ b/gwcs/examples.py @@ -45,7 +45,7 @@ def gwcs_2d_spatial_reordered(): A simple one step spatial WCS, in ICRS with a 1 and 2 px shift. """ out_frame = cf.CelestialFrame(reference_frame=coord.ICRS(), - axes_order=(1, 0)) + axes_order=(1, 0)) return wcs.WCS(MODEL_2D_SHIFT | models.Mapping((1, 0)), input_frame=DETECTOR_2D_FRAME, output_frame=out_frame) @@ -243,7 +243,7 @@ def gwcs_3d_galactic_spectral(): shift = models.Shift(-crpix3) & models.Shift(-crpix1) scale = models.Multiply(cdelt3) & models.Multiply(cdelt1) - proj = models.Pix2Sky_CAR() + proj = models.Pix2Sky_TAN() skyrot = models.RotateNative2Celestial(crval3, 90 + crval1, 180) celestial = shift | scale | proj | skyrot diff --git a/gwcs/tests/test_api.py b/gwcs/tests/test_api.py index 21b573ed..f4326abf 100644 --- a/gwcs/tests/test_api.py +++ b/gwcs/tests/test_api.py @@ -79,6 +79,11 @@ def test_names(wcsobj): assert wcsobj.pixel_axis_names == wcsobj.input_frame.axes_names +def test_names_split(gwcs_3d_galactic_spectral): + wcs = gwcs_3d_galactic_spectral + assert wcs.world_axis_names == wcs.output_frame.axes_names == ("Latitude", "Frequency", "Longitude") + + @fixture_wcs_ndim_types_units def test_pixel_n_dim(wcs_ndim_types_units): wcsobj, ndims, *_ = wcs_ndim_types_units @@ -201,7 +206,7 @@ def test_world_axis_object_classes_2d(gwcs_2d_spatial_shift): assert 'frame' in waoc['celestial'][2] assert 'unit' in waoc['celestial'][2] assert isinstance(waoc['celestial'][2]['frame'], coord.ICRS) - assert waoc['celestial'][2]['unit'] == (u.deg, u.deg) + assert tuple(waoc['celestial'][2]['unit']) == (u.deg, u.deg) def test_world_axis_object_classes_2d_generic(gwcs_2d_quantity_shift): @@ -223,7 +228,7 @@ def test_world_axis_object_classes_4d(gwcs_4d_identity_units): assert 'frame' in waoc['celestial'][2] assert 'unit' in waoc['celestial'][2] assert isinstance(waoc['celestial'][2]['frame'], coord.ICRS) - assert waoc['celestial'][2]['unit'] == (u.deg, u.deg) + assert tuple(waoc['celestial'][2]['unit']) == (u.deg, u.deg) temporal = waoc['temporal'] assert temporal[0] is time.Time diff --git a/gwcs/tests/test_api_slicing.py b/gwcs/tests/test_api_slicing.py index ba0a9b54..87510ae1 100644 --- a/gwcs/tests/test_api_slicing.py +++ b/gwcs/tests/test_api_slicing.py @@ -1,6 +1,7 @@ import astropy.units as u from astropy.coordinates import Galactic, SkyCoord, SpectralCoord +from astropy.wcs.wcsapi import wcs_info_str from astropy.wcs.wcsapi.wrappers import SlicedLowLevelWCS from numpy.testing import assert_allclose, assert_equal @@ -31,6 +32,11 @@ """ +def test_no_ellipsis(gwcs_3d_galactic_spectral): + expected_repr = EXPECTED_ELLIPSIS_REPR.replace("SlicedLowLevel", "") + assert wcs_info_str(gwcs_3d_galactic_spectral) == expected_repr.strip() + + def test_ellipsis(gwcs_3d_galactic_spectral): wcs = SlicedLowLevelWCS(gwcs_3d_galactic_spectral, Ellipsis) @@ -55,7 +61,7 @@ def test_ellipsis(gwcs_3d_galactic_spectral): assert wcs.world_axis_object_classes['celestial'][0] is SkyCoord assert wcs.world_axis_object_classes['celestial'][1] == () assert isinstance(wcs.world_axis_object_classes['celestial'][2]['frame'], Galactic) - assert wcs.world_axis_object_classes['celestial'][2]['unit'] == (u.deg, u.deg) + assert tuple(wcs.world_axis_object_classes['celestial'][2]['unit']) == (u.deg, u.deg) assert wcs.world_axis_object_classes['spectral'][0] is SpectralCoord assert wcs.world_axis_object_classes['spectral'][1] == () @@ -120,7 +126,7 @@ def test_spectral_slice(gwcs_3d_galactic_spectral): assert wcs.world_axis_object_classes['celestial'][0] is SkyCoord assert wcs.world_axis_object_classes['celestial'][1] == () assert isinstance(wcs.world_axis_object_classes['celestial'][2]['frame'], Galactic) - assert wcs.world_axis_object_classes['celestial'][2]['unit'] == (u.deg, u.deg) + assert tuple(wcs.world_axis_object_classes['celestial'][2]['unit']) == (u.deg, u.deg) assert_allclose(wcs.pixel_to_world_values(29, 44), (10, 25)) assert_allclose(wcs.array_index_to_world_values(44, 29), (10, 25)) @@ -185,7 +191,7 @@ def test_spectral_range(gwcs_3d_galactic_spectral): assert wcs.world_axis_object_classes['celestial'][0] is SkyCoord assert wcs.world_axis_object_classes['celestial'][1] == () assert isinstance(wcs.world_axis_object_classes['celestial'][2]['frame'], Galactic) - assert wcs.world_axis_object_classes['celestial'][2]['unit'] == (u.deg, u.deg) + assert tuple(wcs.world_axis_object_classes['celestial'][2]['unit']) == (u.deg, u.deg) assert wcs.world_axis_object_classes['spectral'][0] is SpectralCoord assert wcs.world_axis_object_classes['spectral'][1] == () @@ -253,7 +259,7 @@ def test_celestial_slice(gwcs_3d_galactic_spectral): assert wcs.world_axis_object_classes['celestial'][0] is SkyCoord assert wcs.world_axis_object_classes['celestial'][1] == () assert isinstance(wcs.world_axis_object_classes['celestial'][2]['frame'], Galactic) - assert wcs.world_axis_object_classes['celestial'][2]['unit'] == (u.deg, u.deg) + assert tuple(wcs.world_axis_object_classes['celestial'][2]['unit']) == (u.deg, u.deg) assert wcs.world_axis_object_classes['spectral'][0] is SpectralCoord assert wcs.world_axis_object_classes['spectral'][1] == () @@ -322,7 +328,7 @@ def test_celestial_range(gwcs_3d_galactic_spectral): assert wcs.world_axis_object_classes['celestial'][0] is SkyCoord assert wcs.world_axis_object_classes['celestial'][1] == () assert isinstance(wcs.world_axis_object_classes['celestial'][2]['frame'], Galactic) - assert wcs.world_axis_object_classes['celestial'][2]['unit'] == (u.deg, u.deg) + assert tuple(wcs.world_axis_object_classes['celestial'][2]['unit']) == (u.deg, u.deg) assert wcs.world_axis_object_classes['spectral'][0] is SpectralCoord assert wcs.world_axis_object_classes['spectral'][1] == () @@ -394,7 +400,7 @@ def test_no_array_shape(gwcs_3d_galactic_spectral): assert wcs.world_axis_object_classes['celestial'][0] is SkyCoord assert wcs.world_axis_object_classes['celestial'][1] == () assert isinstance(wcs.world_axis_object_classes['celestial'][2]['frame'], Galactic) - assert wcs.world_axis_object_classes['celestial'][2]['unit'] == (u.deg, u.deg) + assert tuple(wcs.world_axis_object_classes['celestial'][2]['unit']) == (u.deg, u.deg) assert wcs.world_axis_object_classes['spectral'][0] is SpectralCoord assert wcs.world_axis_object_classes['spectral'][1] == () @@ -441,7 +447,7 @@ def test_no_array_shape(gwcs_3d_galactic_spectral): def test_ellipsis_none_types(gwcs_3d_galactic_spectral): pht = list(gwcs_3d_galactic_spectral.output_frame._axis_physical_types) pht[1] = None - gwcs_3d_galactic_spectral.output_frame._axis_physical_types = tuple(pht) + gwcs_3d_galactic_spectral.output_frame._prop.axis_physical_types = tuple(pht) wcs = SlicedLowLevelWCS(gwcs_3d_galactic_spectral, Ellipsis) @@ -466,7 +472,7 @@ def test_ellipsis_none_types(gwcs_3d_galactic_spectral): assert wcs.world_axis_object_classes['celestial'][0] is SkyCoord assert wcs.world_axis_object_classes['celestial'][1] == () assert isinstance(wcs.world_axis_object_classes['celestial'][2]['frame'], Galactic) - assert wcs.world_axis_object_classes['celestial'][2]['unit'] == (u.deg, u.deg) + assert tuple(wcs.world_axis_object_classes['celestial'][2]['unit']) == (u.deg, u.deg) assert_allclose(wcs.pixel_to_world_values(29, 39, 44), (10, 20, 25)) assert_allclose(wcs.array_index_to_world_values(44, 39, 29), (10, 20, 25)) diff --git a/gwcs/tests/test_coordinate_systems.py b/gwcs/tests/test_coordinate_systems.py index 88035b10..6f556c4a 100644 --- a/gwcs/tests/test_coordinate_systems.py +++ b/gwcs/tests/test_coordinate_systems.py @@ -375,7 +375,7 @@ def test_axis_physical_type(): assert comp.axis_physical_types == ('pos.eq.ra', 'pos.eq.dec', 'em.freq', 'em.wl') spec6 = cf.SpectralFrame(name='waven', axes_order=(1,), - axis_physical_types='em.wavenumber') + axis_physical_types='em.wavenumber', unit=u.Unit(1)) assert spec6.axis_physical_types == ('em.wavenumber',) t = cf.TemporalFrame(reference_frame=Time("2018-01-01T00:00:00"), unit=u.s) diff --git a/gwcs/tests/test_wcs.py b/gwcs/tests/test_wcs.py index 42e6ef53..5f017946 100644 --- a/gwcs/tests/test_wcs.py +++ b/gwcs/tests/test_wcs.py @@ -18,13 +18,12 @@ from astropy.utils.introspection import minversion import asdf -from .. import wcs -from ..wcstools import (wcs_from_fiducial, grid_from_bounding_box, wcs_from_points) -from .. import coordinate_frames as cf -from ..utils import CoordinateFrameError -from .utils import _gwcs_from_hst_fits_wcs -from . import data -from ..examples import gwcs_2d_bad_bounding_box_order +from gwcs import wcs +from gwcs.wcstools import (wcs_from_fiducial, grid_from_bounding_box, wcs_from_points) +from gwcs import coordinate_frames as cf +from gwcs.utils import CoordinateFrameError +from . utils import _gwcs_from_hst_fits_wcs +from gwcs.examples import gwcs_2d_bad_bounding_box_order data_path = os.path.split(os.path.abspath(data.__file__))[0] @@ -34,7 +33,7 @@ m2 = models.Scale(2) & models.Scale(-2) m = m1 | m2 -icrs = cf.CelestialFrame(reference_frame=coord.ICRS(), name='icrs') +icrs = cf.CelestialFrame(reference_frame=coord.ICRS(), name='icrs', unit=(u.deg, u.deg)) detector = cf.Frame2D(name='detector', axes_order=(0, 1)) focal = cf.Frame2D(name='focal', axes_order=(0, 1), unit=(u.m, u.m)) spec = cf.SpectralFrame(name='wave', unit=[u.m, ], axes_order=(2, ), axes_names=('lambda', )) @@ -1438,27 +1437,35 @@ def test_split_frame_wcs(): # We setup a model which is pretending to be a celestial transform. Note # that we are pretending that this model is ordered lon, lat because that's # what the projections require in astropy. + + # Input is (lat, wave, lon) + # lat: multuply by 20 arcsec, lon: multiply by 15 deg + # result should be 20 arcsec, 10nm, 45 deg spatial = models.Multiply(20*u.arcsec/u.pix) & models.Multiply(15*u.deg/u.pix) compound = models.Linear1D(intercept=0*u.nm, slope=10*u.nm/u.pix) & spatial # This forward transforms uses mappings to be (lat, wave, lon) - forward = models.Mapping((1, 2, 0)) | compound | models.Mapping((2, 0, 1)) + forward = models.Mapping((1, 0, 2)) | compound | models.Mapping((1, 0, 2)) # Setup the output frame - celestial_frame = cf.CelestialFrame(axes_order=(2, 0), unit=(u.arcsec, u.deg), - reference_frame=coord.ICRS()) - spectral_frame = cf.SpectralFrame(axes_order=(1,), unit=u.nm) + celestial_frame = cf.CelestialFrame(axes_order=(2, 0), unit=(u.deg, u.arcsec), + reference_frame=coord.ICRS(), axes_names=('lon', 'lat')) + #celestial_frame = cf.CelestialFrame(axes_order=(2, 0), unit=(u.arcsec, u.deg), + # reference_frame=coord.ICRS()) + spectral_frame = cf.SpectralFrame(axes_order=(1,), unit=u.nm, axes_names='wave') output_frame = cf.CompositeFrame([spectral_frame, celestial_frame]) + #output_frame = cf.CompositeFrame([celestial_frame, spectral_frame]) input_frame = cf.CoordinateFrame(3, ["PIXEL"]*3, axes_order=list(range(3)), unit=[u.pix]*3) iwcs = wcs.WCS(forward, input_frame, output_frame) - input_pixel = [1*u.pix, 2*u.pix, 3*u.pix] + input_pixel = [1*u.pix, 1*u.pix, 3*u.pix] output_world = iwcs.pixel_to_world_values(*input_pixel) output_pixel = iwcs.world_to_pixel_values(*output_world) assert_allclose(output_pixel, u.Quantity(input_pixel).to_value(u.pix)) - expected_world = [15*u.deg, 20*u.nm, 60*u.arcsec] + expected_world = [20*u.arcsec, 10*u.nm, 45*u.deg] + #expected_world = [15*u.deg, 20*u.nm, 60*u.arcsec] for expected, output in zip(expected_world, output_world): assert_allclose(output, expected.value) @@ -1476,7 +1483,8 @@ def test_split_frame_wcs(): def test_reordered_celestial(): # This is a spatial model which is ordered lat, lon for the purposes of this test. - spatial = models.Multiply(20*u.arcsec/u.pix) & models.Multiply(15*u.deg/u.pix) + # Expected lat=45 deg, lon=20 arcsec + spatial = models.Multiply(20*u.arcsec/u.pix) & models.Multiply(15*u.deg/u.pix) | models.Mapping((1,0)) celestial_frame = cf.CelestialFrame(axes_order=(1, 0), unit=(u.arcsec, u.deg), reference_frame=coord.ICRS()) @@ -1491,7 +1499,7 @@ def test_reordered_celestial(): output_pixel = iwcs.world_to_pixel_values(*output_world) assert_allclose(output_pixel, u.Quantity(input_pixel).to_value(u.pix)) - expected_world = [20*u.arcsec, 45*u.deg] + expected_world = [45*u.deg, 20*u.arcsec]#, 45*u.deg] assert_allclose(output_world, [e.value for e in expected_world]) world_obj = iwcs.pixel_to_world(*input_pixel) diff --git a/gwcs/wcs.py b/gwcs/wcs.py index 04a440ed..c0b22470 100644 --- a/gwcs/wcs.py +++ b/gwcs/wcs.py @@ -3,6 +3,7 @@ import itertools import warnings +import astropy.units as u import astropy.io.fits as fits import numpy as np import numpy.linalg as npla From 6618d1d31a63edc79a1c141af80252b30b7c8292 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Wed, 25 Sep 2024 14:26:53 +0100 Subject: [PATCH 18/49] ensure units are units --- gwcs/coordinate_frames.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/gwcs/coordinate_frames.py b/gwcs/coordinate_frames.py index ffd4d221..25d59f03 100644 --- a/gwcs/coordinate_frames.py +++ b/gwcs/coordinate_frames.py @@ -584,13 +584,12 @@ def __init__(self, axes_order=(0,), reference_frame=None, unit=None, if not isiterable(unit): unit = (unit,) - + unit = [u.Unit(un) for un in unit] pht = axis_physical_types or self._default_axis_physical_types(unit) super().__init__(naxes=1, axes_type="SPECTRAL", axes_order=axes_order, axes_names=axes_names, reference_frame=reference_frame, unit=unit, name=name, - #axis_physical_types="em.wl") axis_physical_types=pht) def _default_axis_physical_types(self, unit): From 0edcb378f8169e46616b30e22560335f1c1e3212 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Thu, 26 Sep 2024 12:24:56 +0100 Subject: [PATCH 19/49] Raise an error if with_units is used in numerical inverse --- gwcs/wcs.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/gwcs/wcs.py b/gwcs/wcs.py index c0b22470..73491fef 100644 --- a/gwcs/wcs.py +++ b/gwcs/wcs.py @@ -533,7 +533,7 @@ def _call_backward(self, *args, with_bounding_box=True, fill_value=np.nan, **kwa def numerical_inverse(self, *args, tolerance=1e-5, maxiter=50, adaptive=True, detect_divergence=True, quiet=True, with_bounding_box=True, - fill_value=np.nan, with_units=False, **kwargs): + fill_value=np.nan, **kwargs): """ Invert coordinates from output frame to input frame using numerical inverse. @@ -560,11 +560,6 @@ def numerical_inverse(self, *args, tolerance=1e-5, maxiter=50, adaptive=True, fill_value : float, optional Output value for inputs outside the bounding_box (default is ``np.nan``). - with_units : bool, optional - If ``True`` returns a `~astropy.coordinates.SkyCoord` or - `~astropy.coordinates.SpectralCoord` object, by using the units of - the output cooridnate frame. Default is `False`. - tolerance : float, optional *Absolute tolerance* of solution. Iteration terminates when the iterative solver estimates that the "true solution" is @@ -769,6 +764,9 @@ def numerical_inverse(self, *args, tolerance=1e-5, maxiter=50, adaptive=True, [2.76552923e-05 1.14789013e-05]] """ + if kwargs.pop("with_units", False): + raise ValueError("Support for with_units in numerical_inverse has been removed, use inverse") + args_shape = np.shape(args) nargs = args_shape[0] arg_dim = len(args_shape) - 1 From 065709729e4b023c2515b82160c5153f2c237020 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Thu, 26 Sep 2024 12:25:17 +0100 Subject: [PATCH 20/49] Refactor Frames to require _native_world_axis_object_components This means that world_axis_object_components can be automatically sorted for all frames. --- gwcs/coordinate_frames.py | 179 +++++++++++++++++++++----------------- 1 file changed, 98 insertions(+), 81 deletions(-) diff --git a/gwcs/coordinate_frames.py b/gwcs/coordinate_frames.py index 25d59f03..cfaf9635 100644 --- a/gwcs/coordinate_frames.py +++ b/gwcs/coordinate_frames.py @@ -177,10 +177,81 @@ def get_ctype_from_ucd(ucd): return UCD1_TO_CTYPE.get(ucd, "") +@dataclass +class FrameProperties: + naxes: InitVar[int] + axes_type: tuple[str] + unit: tuple[u.Unit] = None + axes_names: tuple[str] = None + axis_physical_types: list[str] = None + + def __post_init__(self, naxes): + if isinstance(self.axes_type, str): + self.axes_type = (self.axes_type,) + else: + self.axes_type = tuple(self.axes_type) + + if len(self.axes_type) != naxes: + raise ValueError("Length of axes_type does not match number of axes.") + + if self.unit is not None: + if astutil.isiterable(self.unit): + unit = tuple(self.unit) + else: + unit = (self.unit,) + if len(unit) != naxes: + raise ValueError("Number of units does not match number of axes.") + else: + self.unit = tuple(u.Unit(au) for au in unit) + else: + self.unit = tuple(u.dimensionless_unscaled for na in range(naxes)) + + if self.axes_names is not None: + if isinstance(self.axes_names, str): + self.axes_names = (self.axes_names,) + else: + self.axes_names = tuple(self.axes_names) + if len(self.axes_names) != naxes: + raise ValueError("Number of axes names does not match number of axes.") + else: + self.axes_names = tuple([""] * naxes) + + if self.axis_physical_types is not None: + if isinstance(self.axis_physical_types, str): + self.axis_physical_types = (self.axis_physical_types,) + elif not isiterable(self.axis_physical_types): + raise TypeError("axis_physical_types must be of type string or iterable of strings") + if len(self.axis_physical_types) != naxes: + raise ValueError(f'"axis_physical_types" must be of length {naxes}') + ph_type = [] + for axt in self.axis_physical_types: + if axt not in VALID_UCDS and not axt.startswith("custom:"): + ph_type.append(f"custom:{axt}") + else: + ph_type.append(axt) + + validate_physical_types(ph_type) + self.axis_physical_types = tuple(ph_type) + + @property + def _default_axis_physical_type(self): + """ + The default physical types to use for this frame if none are specified + by the user. + """ + return tuple("custom:{}".format(t) for t in self.axes_type) + + class BaseCoordinateFrame(abc.ABC): """ API Definition for a Coordinate frame """ + + _prop: FrameProperties + """ + The FrameProperties object holding properties in native frame order. + """ + @property @abc.abstractmethod def naxes(self) -> int: @@ -253,7 +324,6 @@ def world_axis_object_classes(self): """ @property - @abc.abstractmethod def world_axis_object_components(self): """ The APE 14 object components for this frame. @@ -262,71 +332,30 @@ def world_axis_object_components(self): -------- astropy.wcs.wcsapi.BaseLowLevelWCS.world_axis_object_components """ + if self.naxes == 1: + return self._native_world_axis_object_components - -@dataclass -class FrameProperties: - naxes: InitVar[int] - axes_type: tuple[str] - unit: tuple[u.Unit] = None - axes_names: tuple[str] = None - axis_physical_types: list[str] = None - - def __post_init__(self, naxes): - if isinstance(self.axes_type, str): - self.axes_type = (self.axes_type,) - else: - self.axes_type = tuple(self.axes_type) - - if len(self.axes_type) != naxes: - raise ValueError("Length of axes_type does not match number of axes.") - - if self.unit is not None: - if astutil.isiterable(self.unit): - unit = tuple(self.unit) - else: - unit = (self.unit,) - if len(unit) != naxes: - raise ValueError("Number of units does not match number of axes.") - else: - self.unit = tuple(u.Unit(au) for au in unit) - else: - self.unit = tuple(u.dimensionless_unscaled for na in range(naxes)) - - if self.axes_names is not None: - if isinstance(self.axes_names, str): - self.axes_names = (self.axes_names,) - else: - self.axes_names = tuple(self.axes_names) - if len(self.axes_names) != naxes: - raise ValueError("Number of axes names does not match number of axes.") - else: - self.axes_names = tuple([""] * naxes) - - if self.axis_physical_types is not None: - if isinstance(self.axis_physical_types, str): - self.axis_physical_types = (self.axis_physical_types,) - elif not isiterable(self.axis_physical_types): - raise TypeError("axis_physical_types must be of type string or iterable of strings") - if len(self.axis_physical_types) != naxes: - raise ValueError(f'"axis_physical_types" must be of length {naxes}') - ph_type = [] - for axt in self.axis_physical_types: - if axt not in VALID_UCDS and not axt.startswith("custom:"): - ph_type.append(f"custom:{axt}") - else: - ph_type.append(axt) - - validate_physical_types(ph_type) - self.axis_physical_types = tuple(ph_type) + # If we have more than one axis then we should sort the native + # components by the axes_order. + ordered = np.array(self._native_world_axis_object_components, + dtype=object)[np.argsort(self.axes_order)] + return list(map(tuple, ordered)) @property - def _default_axis_physical_type(self): + @abc.abstractmethod + def _native_world_axis_object_components(self): """ - The default physical types to use for this frame if none are specified - by the user. + This property holds the "native" frame order of the components. + + The native order of the componets is the order the frame assumes the + axes are in when creating the high level objects, for example + ``CelestialFrame`` creates ``SkyCoord`` objects which are in lon, lat + order (in their positional args). + + This property is used both to construct the ordered + ``world_axis_object_components`` property as well as by `CompositeFrame` + to be able to get the components in their native order. """ - return tuple("custom:{}".format(t) for t in self.axes_type) class CoordinateFrame(BaseCoordinateFrame): @@ -401,8 +430,6 @@ def __str__(self): return self.__class__.__name__ def _sort_property(self, property): - #return tuple(dict(sorted(zip(property, self.axes_order), - # key=lambda x: x[1])).keys()) sorted_prop = sorted(zip(property, self.axes_order), key=lambda x: x[1]) return tuple([t[0] for t in sorted_prop]) @@ -426,7 +453,6 @@ def naxes(self): def unit(self): """The unit of this frame.""" return self._sort_property(self._prop.unit) - #return self._prop.unit @property def axes_names(self): @@ -464,19 +490,18 @@ def world_axis_object_classes(self): {'unit': unit}) for i, (at, unit) in enumerate(zip(self.axes_type, self.unit))} - @property - def world_axis_object_components(self): - return [(f"{at}{i}" if i != 0 else at, 0, 'value') for i, at in enumerate(self._prop.axes_type)] - @property def _native_world_axis_object_components(self): - """Defines the target component ordering (i.e. not taking into account axes_order)""" - return self.world_axis_object_components + return [(f"{at}{i}" if i != 0 else at, 0, 'value') for i, at in enumerate(self._prop.axes_type)] class CelestialFrame(CoordinateFrame): """ - Celestial Frame Representation + Representation of a Celesital coordinate system. + + This class has a native order of longitude then latitude, meaning + ``axes_names``, ``unit`` should be lon, lat ordered. If your transform is + in a different order this should be specified with ``axes_order``. Parameters ---------- @@ -551,14 +576,6 @@ def _world_axis_object_components(self): return [('celestial', 0, lambda sc: sc.spherical.lon.to_value(self.unit[0])), ('celestial', 1, lambda sc: sc.spherical.lat.to_value(self.unit[1]))] - @property - def world_axis_object_components(self): - # Sort the native waoc by the axes order. The axes order may have jumps - # in it if there are other frames in between the components. - ordered = np.array(self._native_world_axis_object_components, - dtype=object)[np.argsort(self.axes_order)] - return list(map(tuple, ordered)) - class SpectralFrame(CoordinateFrame): """ @@ -685,7 +702,7 @@ def world_axis_object_classes(self): return {'temporal': comp} @property - def world_axis_object_components(self): + def _native_world_axis_object_components(self): if isinstance(self.reference_frame.value, np.ndarray): return [('temporal', 0, 'value')] @@ -859,7 +876,7 @@ def world_axis_object_classes(self): )} @property - def world_axis_object_components(self): + def _native_world_axis_object_components(self): return [('stokes', 0, 'value')] From 9eb7df85dee957602220ef9d6c6df90e7c4d51b0 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Thu, 26 Sep 2024 12:47:45 +0100 Subject: [PATCH 21/49] Apply suggestions from code review --- gwcs/coordinate_frames.py | 1 - gwcs/tests/test_api.py | 2 +- gwcs/tests/test_coordinate_systems.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/gwcs/coordinate_frames.py b/gwcs/coordinate_frames.py index cfaf9635..dae7ffe5 100644 --- a/gwcs/coordinate_frames.py +++ b/gwcs/coordinate_frames.py @@ -676,7 +676,6 @@ def __init__(self, reference_frame, unit=None, axes_order=(0,), except AttributeError: pass - #@property def _default_axis_physical_types(self): return ("time",) diff --git a/gwcs/tests/test_api.py b/gwcs/tests/test_api.py index f4326abf..0734eabb 100644 --- a/gwcs/tests/test_api.py +++ b/gwcs/tests/test_api.py @@ -280,7 +280,7 @@ def test_high_level_wrapper(wcsobj, request): # The wrapper and the raw gwcs class can take different paths wc1 = hlvl.pixel_to_world(*pixel_input) - wc2 = wcsobj.pixel_to_world(*pixel_input) + wc2 = wcsobj(*pixel_input, with_units=True) assert type(wc1) is type(wc2) diff --git a/gwcs/tests/test_coordinate_systems.py b/gwcs/tests/test_coordinate_systems.py index 6f556c4a..2d31603b 100644 --- a/gwcs/tests/test_coordinate_systems.py +++ b/gwcs/tests/test_coordinate_systems.py @@ -89,7 +89,7 @@ def coordinates(*inputs, frame): def coordinate_to_quantity(*inputs, frame): results = high_level_objects_to_values(*inputs, low_level_wcs=frame) - results = [r< Date: Thu, 26 Sep 2024 14:19:39 +0100 Subject: [PATCH 22/49] Ensure we call values_to_high correctly --- gwcs/wcs.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/gwcs/wcs.py b/gwcs/wcs.py index 73491fef..5cbd7282 100644 --- a/gwcs/wcs.py +++ b/gwcs/wcs.py @@ -7,6 +7,7 @@ import astropy.io.fits as fits import numpy as np import numpy.linalg as npla +from astropy import utils as astutil from astropy.modeling import fix_inputs, projections from astropy.modeling.bounding_box import CompoundBoundingBox from astropy.modeling.bounding_box import ModelBoundingBox as Bbox @@ -366,6 +367,8 @@ def __call__(self, *args, **kwargs): results = self._call_forward(*args, **kwargs) if with_units: + if not astutil.isiterable(results): + results = (results,) high_level = values_to_high_level_objects(*results, low_level_wcs=self) if len(high_level) == 1: high_level = high_level[0] From 3a203d26fd43ca21a7017b45d22e1945f603147f Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Thu, 26 Sep 2024 14:20:08 +0100 Subject: [PATCH 23/49] We don't need to unit convert in API It's done by call_forward/backward --- gwcs/api.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/gwcs/api.py b/gwcs/api.py index e808c4fd..a620c591 100644 --- a/gwcs/api.py +++ b/gwcs/api.py @@ -98,8 +98,6 @@ def pixel_to_world_values(self, *pixel_arrays): order, where for an image, ``x`` is the horizontal coordinate and ``y`` is the vertical coordinate. """ - if self.forward_transform.uses_quantity: - pixel_arrays = self._add_units_input(pixel_arrays, self.input_frame) result = self._call_forward(*pixel_arrays) return self._remove_quantity_output(result, self.output_frame) @@ -127,9 +125,6 @@ def world_to_pixel_values(self, *world_arrays): be returned in the ``(x, y)`` order, where for an image, ``x`` is the horizontal coordinate and ``y`` is the vertical coordinate. """ - if self.backward_transform.uses_quantity: - world_arrays = self._add_units_input(world_arrays, self.output_frame) - result = self._call_backward(*world_arrays) return self._remove_quantity_output(result, self.input_frame) From 77c1c604cf2fe27563cc6c197cb77a188d7d1838 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Thu, 26 Sep 2024 15:42:07 +0100 Subject: [PATCH 24/49] Fix CelestialFrame units --- gwcs/coordinate_frames.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/gwcs/coordinate_frames.py b/gwcs/coordinate_frames.py index dae7ffe5..cc8f414a 100644 --- a/gwcs/coordinate_frames.py +++ b/gwcs/coordinate_frames.py @@ -564,17 +564,16 @@ def _default_axis_physical_types(self, reference_frame, axes_names): @property def world_axis_object_classes(self): - unit = np.array(self.unit)[np.argsort(self.axes_order)] return {'celestial': ( coord.SkyCoord, (), {'frame': self.reference_frame, - 'unit': unit})} + 'unit': self._prop.unit})} @property - def _world_axis_object_components(self): - return [('celestial', 0, lambda sc: sc.spherical.lon.to_value(self.unit[0])), - ('celestial', 1, lambda sc: sc.spherical.lat.to_value(self.unit[1]))] + def _native_world_axis_object_components(self): + return [('celestial', 0, lambda sc: sc.spherical.lon.to_value(self._prop.unit[0])), + ('celestial', 1, lambda sc: sc.spherical.lat.to_value(self._prop.unit[1]))] class SpectralFrame(CoordinateFrame): From 0404517b5745b237fcb0aa912976b57eece21a09 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Thu, 26 Sep 2024 16:05:59 +0100 Subject: [PATCH 25/49] More roundtip test fixing Seems this leaves one troublesome test --- gwcs/tests/test_api.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/gwcs/tests/test_api.py b/gwcs/tests/test_api.py index 0734eabb..a3d22bc9 100644 --- a/gwcs/tests/test_api.py +++ b/gwcs/tests/test_api.py @@ -274,11 +274,8 @@ def test_high_level_wrapper(wcsobj, request): pixel_input = [3] * wcsobj.pixel_n_dim - # If the model expects units we have to pass in units - if wcsobj.forward_transform.uses_quantity: - pixel_input *= u.pix - - # The wrapper and the raw gwcs class can take different paths + # Assert that both APE 14 API and GWCS give the same answer The APE 14 API + # uses the mixin class and __call__ calls values_to_high_level_objects wc1 = hlvl.pixel_to_world(*pixel_input) wc2 = wcsobj(*pixel_input, with_units=True) @@ -290,6 +287,22 @@ def test_high_level_wrapper(wcsobj, request): else: _compare_frame_output(wc1, wc2) + # we have just asserted that wc1 and wc2 are equal + if not isinstance(wc1, (list, tuple)): + wc1 = (wc1,) + + pix_out1 = hlvl.world_to_pixel(*wc1) + pix_out2 = wcsobj.invert(*wc1) + + if not isinstance(pix_out2, (list, tuple)): + pix_out2 = (pix_out2,) + + if wcsobj.forward_transform.uses_quantity: + pix_out2 = tuple(p.to_value(unit) for p, unit in zip(pix_out2, wcsobj.input_frame.unit)) + + np.testing.assert_allclose(pix_out1, pixel_input) + np.testing.assert_allclose(pix_out2, pixel_input) + def test_stokes_wrapper(gwcs_stokes_lookup): pytest.importorskip("astropy", minversion="4.0dev0") From 5300a7e4943f83a1b16f9f801ddff1242f4f4dd9 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Thu, 26 Sep 2024 20:02:48 +0100 Subject: [PATCH 26/49] Fix roundtrip test by changing projection type --- gwcs/tests/test_api_slicing.py | 56 +++++++++++++++++----------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/gwcs/tests/test_api_slicing.py b/gwcs/tests/test_api_slicing.py index 87510ae1..89197903 100644 --- a/gwcs/tests/test_api_slicing.py +++ b/gwcs/tests/test_api_slicing.py @@ -67,11 +67,11 @@ def test_ellipsis(gwcs_3d_galactic_spectral): assert wcs.world_axis_object_classes['spectral'][1] == () assert wcs.world_axis_object_classes['spectral'][2] == {'unit': 'Hz'} - assert_allclose(wcs.pixel_to_world_values(29, 39, 44), (10, 20, 25)) - assert_allclose(wcs.array_index_to_world_values(44, 39, 29), (10, 20, 25)) + assert_allclose(wcs.pixel_to_world_values(29, 39, 44), (80, 20, 205)) + assert_allclose(wcs.array_index_to_world_values(44, 39, 29), (80, 20, 205)) - assert_allclose(wcs.world_to_pixel_values(10, 20, 25), (29., 39., 44.)) - assert_equal(wcs.world_to_array_index_values(10, 20, 25), (44, 39, 29)) + assert_allclose(wcs.world_to_pixel_values(80, 20, 205), (29., 39., 44.)) + assert_equal(wcs.world_to_array_index_values(80, 20, 205), (44, 39, 29)) assert str(wcs) == EXPECTED_ELLIPSIS_REPR.strip() assert EXPECTED_ELLIPSIS_REPR.strip() in repr(wcs) @@ -128,11 +128,11 @@ def test_spectral_slice(gwcs_3d_galactic_spectral): assert isinstance(wcs.world_axis_object_classes['celestial'][2]['frame'], Galactic) assert tuple(wcs.world_axis_object_classes['celestial'][2]['unit']) == (u.deg, u.deg) - assert_allclose(wcs.pixel_to_world_values(29, 44), (10, 25)) - assert_allclose(wcs.array_index_to_world_values(44, 29), (10, 25)) + assert_allclose(wcs.pixel_to_world_values(29, 44), (80, 205)) + assert_allclose(wcs.array_index_to_world_values(44, 29), (80, 205)) - assert_allclose(wcs.world_to_pixel_values(10, 25), (29., 44.)) - assert_equal(wcs.world_to_array_index_values(10, 25), (44, 29)) + assert_allclose(wcs.world_to_pixel_values(80, 205), (29., 44.)) + assert_equal(wcs.world_to_array_index_values(80, 205), (44, 29)) assert_equal(wcs.pixel_bounds, [(-1, 35), (5, 50)]) @@ -197,11 +197,11 @@ def test_spectral_range(gwcs_3d_galactic_spectral): assert wcs.world_axis_object_classes['spectral'][1] == () assert wcs.world_axis_object_classes['spectral'][2] == {'unit': 'Hz'} - assert_allclose(wcs.pixel_to_world_values(29, 35, 44), (10, 20, 25)) - assert_allclose(wcs.array_index_to_world_values(44, 35, 29), (10, 20, 25)) + assert_allclose(wcs.pixel_to_world_values(29, 35, 44), (80, 20, 205)) + assert_allclose(wcs.array_index_to_world_values(44, 35, 29), (80, 20, 205)) - assert_allclose(wcs.world_to_pixel_values(10, 20, 25), (29., 35., 44.)) - assert_equal(wcs.world_to_array_index_values(10, 20, 25), (44, 35, 29)) + assert_allclose(wcs.world_to_pixel_values(80, 20, 205), (29., 35., 44.)) + assert_equal(wcs.world_to_array_index_values(80, 20, 205), (44, 35, 29)) assert_equal(wcs.pixel_bounds, [(-1, 35), (-6, 41), (5, 50)]) @@ -265,11 +265,11 @@ def test_celestial_slice(gwcs_3d_galactic_spectral): assert wcs.world_axis_object_classes['spectral'][1] == () assert wcs.world_axis_object_classes['spectral'][2] == {'unit': 'Hz'} - assert_allclose(wcs.pixel_to_world_values(39, 44), (10.24, 20, 25)) - assert_allclose(wcs.array_index_to_world_values(44, 39), (10.24, 20, 25)) + assert_allclose(wcs.pixel_to_world_values(39, 44), (79.76, 20, 205)) + assert_allclose(wcs.array_index_to_world_values(44, 39), (79.76, 20, 205)) - assert_allclose(wcs.world_to_pixel_values(12.4, 20, 25), (39., 44.)) - assert_equal(wcs.world_to_array_index_values(12.4, 20, 25), (44, 39)) + assert_allclose(wcs.world_to_pixel_values(79.76, 20, 205), (39., 44.)) + assert_equal(wcs.world_to_array_index_values(79.76, 20, 205), (44, 39)) assert_equal(wcs.pixel_bounds, [(-2, 45), (5, 50)]) @@ -334,11 +334,11 @@ def test_celestial_range(gwcs_3d_galactic_spectral): assert wcs.world_axis_object_classes['spectral'][1] == () assert wcs.world_axis_object_classes['spectral'][2] == {'unit': 'Hz'} - assert_allclose(wcs.pixel_to_world_values(24, 39, 44), (10, 20, 25)) - assert_allclose(wcs.array_index_to_world_values(44, 39, 24), (10, 20, 25)) + assert_allclose(wcs.pixel_to_world_values(24, 39, 44), (80, 20, 205)) + assert_allclose(wcs.array_index_to_world_values(44, 39, 24), (80, 20, 205)) - assert_allclose(wcs.world_to_pixel_values(10, 20, 25), (24., 39., 44.)) - assert_equal(wcs.world_to_array_index_values(10, 20, 25), (44, 39, 24)) + assert_allclose(wcs.world_to_pixel_values(80, 20, 205), (24., 39., 44.)) + assert_equal(wcs.world_to_array_index_values(80, 20, 205), (44, 39, 24)) assert_equal(wcs.pixel_bounds, [(-6, 30), (-2, 45), (5, 50)]) @@ -406,11 +406,11 @@ def test_no_array_shape(gwcs_3d_galactic_spectral): assert wcs.world_axis_object_classes['spectral'][1] == () assert wcs.world_axis_object_classes['spectral'][2] == {'unit': 'Hz'} - assert_allclose(wcs.pixel_to_world_values(29, 39, 44), (10, 20, 25)) - assert_allclose(wcs.array_index_to_world_values(44, 39, 29), (10, 20, 25)) + assert_allclose(wcs.pixel_to_world_values(29, 39, 44), (80, 20, 205)) + assert_allclose(wcs.array_index_to_world_values(44, 39, 29), (80, 20, 205)) - assert_allclose(wcs.world_to_pixel_values(10, 20, 25), (29., 39., 44.)) - assert_equal(wcs.world_to_array_index_values(10, 20, 25), (44, 39, 29)) + assert_allclose(wcs.world_to_pixel_values(80, 20, 205), (29., 39., 44.)) + assert_equal(wcs.world_to_array_index_values(80, 20, 205), (44, 39, 29)) assert str(wcs) == EXPECTED_NO_SHAPE_REPR.strip() assert EXPECTED_NO_SHAPE_REPR.strip() in repr(wcs) @@ -474,11 +474,11 @@ def test_ellipsis_none_types(gwcs_3d_galactic_spectral): assert isinstance(wcs.world_axis_object_classes['celestial'][2]['frame'], Galactic) assert tuple(wcs.world_axis_object_classes['celestial'][2]['unit']) == (u.deg, u.deg) - assert_allclose(wcs.pixel_to_world_values(29, 39, 44), (10, 20, 25)) - assert_allclose(wcs.array_index_to_world_values(44, 39, 29), (10, 20, 25)) + assert_allclose(wcs.pixel_to_world_values(29, 39, 44), (80, 20, 205)) + assert_allclose(wcs.array_index_to_world_values(44, 39, 29), (80, 20, 205)) - assert_allclose(wcs.world_to_pixel_values(10, 20, 25), (29., 39., 44.)) - assert_equal(wcs.world_to_array_index_values(10, 20, 25), (44, 39, 29)) + assert_allclose(wcs.world_to_pixel_values(80, 20, 205), (29., 39., 44.)) + assert_equal(wcs.world_to_array_index_values(80, 20, 205), (44, 39, 29)) assert_equal(wcs.pixel_bounds, [(-1, 35), (-2, 45), (5, 50)]) From 669f29294c00df49d539d726e8e8c6e8968d7a68 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Fri, 27 Sep 2024 14:18:25 +0100 Subject: [PATCH 27/49] Test and polish more ordering --- gwcs/api.py | 9 +-------- gwcs/coordinate_frames.py | 7 +++++-- gwcs/tests/test_coordinate_systems.py | 26 ++++++++++++++++++++++++++ 3 files changed, 32 insertions(+), 10 deletions(-) diff --git a/gwcs/api.py b/gwcs/api.py index a620c591..2839d363 100644 --- a/gwcs/api.py +++ b/gwcs/api.py @@ -51,14 +51,7 @@ def world_axis_physical_types(self): arbitrary string. Alternatively, if the physical type is unknown/undefined, an element can be `None`. """ - # A CompositeFrame orders the output correctly based on axes_order. - if isinstance(self.output_frame, cf.CompositeFrame): - return self.output_frame.axis_physical_types - - # If we don't have a CompositeFrame, where this is taken care of for us, - # we need to make sure we re-order the output to match the transform. - # The underlying frames don't reorder themselves because axes_order is global. - return tuple(self.output_frame.axis_physical_types[i] for i in self.output_frame.axes_order) + return self.output_frame.axis_physical_types @property def world_axis_units(self): diff --git a/gwcs/coordinate_frames.py b/gwcs/coordinate_frames.py index cc8f414a..5c5155fe 100644 --- a/gwcs/coordinate_frames.py +++ b/gwcs/coordinate_frames.py @@ -481,7 +481,8 @@ def axis_physical_types(self): These physical types are the types in frame order, not transform order. """ - return self._prop.axis_physical_types or self._default_axis_physical_types + apt = self._prop.axis_physical_types or self._default_axis_physical_types + return self._sort_property(apt) @property def world_axis_object_classes(self): @@ -500,7 +501,7 @@ class CelestialFrame(CoordinateFrame): Representation of a Celesital coordinate system. This class has a native order of longitude then latitude, meaning - ``axes_names``, ``unit`` should be lon, lat ordered. If your transform is + ``axes_names``, ``unit`` and ``axis_physical_types`` should be lon, lat ordered. If your transform is in a different order this should be specified with ``axes_order``. Parameters @@ -515,6 +516,8 @@ class CelestialFrame(CoordinateFrame): Names of the axes in this frame. name : str Name of this frame. + axis_physical_types : list + The UCD 1+ physical types for the axes, in frame order (lon, lat). """ def __init__(self, axes_order=None, reference_frame=None, diff --git a/gwcs/tests/test_coordinate_systems.py b/gwcs/tests/test_coordinate_systems.py index 2d31603b..a54091e7 100644 --- a/gwcs/tests/test_coordinate_systems.py +++ b/gwcs/tests/test_coordinate_systems.py @@ -475,3 +475,29 @@ def test_ucd1_to_ctype(caplog): assert ctype_to_ucd[v] == k assert inv_map['new.repeated.type'] in new_ctype_to_ucd + + +def test_celestial_ordering(): + c1 = cf.CelestialFrame( + reference_frame=coord.ICRS(), + axes_order=(0, 1), + axes_names=("lon", "lat"), + unit=(u.deg, u.arcsec), + axis_physical_types=("custom:lon", "custom:lat"), + ) + c2 = cf.CelestialFrame( + reference_frame=coord.ICRS(), + axes_order=(1, 0), + axes_names=("lon", "lat"), + unit=(u.deg, u.arcsec), + axis_physical_types=("custom:lon", "custom:lat"), + ) + + assert c1.axes_names == ("lon", "lat") + assert c2.axes_names == ("lat", "lon") + + assert c1.unit == (u.deg, u.arcsec) + assert c2.unit == (u.arcsec, u.deg) + + assert c1.axis_physical_types == ("custom:lon", "custom:lat") + assert c2.axis_physical_types == ("custom:lat", "custom:lon") From d32dd24b5d9f445f53c0bddf1f59142ab3bdfa44 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Thu, 3 Oct 2024 14:21:06 +0100 Subject: [PATCH 28/49] Make it so CompositeFrame follows the same ordering This means that _prop is in native order and is sorted into axes order --- gwcs/coordinate_frames.py | 52 ++++++--------------- gwcs/tests/test_api_slicing.py | 5 +- gwcs/tests/test_coordinate_systems.py | 67 ++++++++++++++++++--------- 3 files changed, 61 insertions(+), 63 deletions(-) diff --git a/gwcs/coordinate_frames.py b/gwcs/coordinate_frames.py index 5c5155fe..81bf76cc 100644 --- a/gwcs/coordinate_frames.py +++ b/gwcs/coordinate_frames.py @@ -719,43 +719,35 @@ class CompositeFrame(CoordinateFrame): Parameters ---------- frames : list - List of frames (TemporalFrame, CelestialFrame, SpectralFrame, CoordinateFrame). + List of constituient frames. name : str Name for this frame. - """ def __init__(self, frames, name=None): self._frames = frames[:] naxes = sum([frame._naxes for frame in self._frames]) - axes_type = list(range(naxes)) - unit = list(range(naxes)) - axes_names = list(range(naxes)) - ph_type = list(range(naxes)) axes_order = [] + axes_type = [] + axes_names = [] + unit = [] + ph_type = [] for frame in frames: axes_order.extend(frame.axes_order) + # Stack the raw (not-native) ordered properties for frame in frames: - unsorted_prop = zip( - frame.axes_order, - frame._prop.axes_type, - frame._prop.unit, - frame._prop.axes_names, - frame._prop.axis_physical_types - ) - for ind, axtype, un, n, pht in unsorted_prop: - axes_type[ind] = axtype - axes_names[ind] = n - unit[ind] = un - ph_type[ind] = pht + axes_type += list(frame._prop.axes_type) + axes_names += list(frame._prop.axes_names) + unit += list(frame._prop.unit) + ph_type += list(frame._prop.axis_physical_types) if len(np.unique(axes_order)) != len(axes_order): raise ValueError("Incorrect numbering of axes, " "axes_order should contain unique numbers, " - "got {}.".format(axes_order)) + f"got {axes_order}.") super().__init__(naxes, axes_type=axes_type, axes_order=axes_order, @@ -766,24 +758,11 @@ def __init__(self, frames, name=None): @property def frames(self): + """ + The constituient frames that comprise this `CompositeFrame`. + """ return self._frames - @property - def unit(self): - return self._prop.unit - - @property - def axes_names(self): - return self._prop.axes_names - - @property - def axes_type(self): - return self._prop.axes_type - - @property - def axis_physical_types(self): - return self._prop.axis_physical_types - def __repr__(self): return repr(self.frames) @@ -826,9 +805,6 @@ def _wao_renamed_classes_iter(self): @property def world_axis_object_components(self): - """ - We need to generate the components respecting the axes_order. - """ out = [None] * self.naxes for frame, components in self._wao_renamed_components_iter: diff --git a/gwcs/tests/test_api_slicing.py b/gwcs/tests/test_api_slicing.py index 89197903..3d38a9e9 100644 --- a/gwcs/tests/test_api_slicing.py +++ b/gwcs/tests/test_api_slicing.py @@ -446,7 +446,8 @@ def test_no_array_shape(gwcs_3d_galactic_spectral): def test_ellipsis_none_types(gwcs_3d_galactic_spectral): pht = list(gwcs_3d_galactic_spectral.output_frame._axis_physical_types) - pht[1] = None + # This index is in "axes_order" ordering + pht[2] = None gwcs_3d_galactic_spectral.output_frame._prop.axis_physical_types = tuple(pht) wcs = SlicedLowLevelWCS(gwcs_3d_galactic_spectral, Ellipsis) @@ -467,7 +468,7 @@ def test_ellipsis_none_types(gwcs_3d_galactic_spectral): ('spectral', 0), ('celestial', 0)] - assert all([callable(l) for l in last_one]) + assert all([callable(last) for last in last_one]) assert wcs.world_axis_object_classes['celestial'][0] is SkyCoord assert wcs.world_axis_object_classes['celestial'][1] == () diff --git a/gwcs/tests/test_coordinate_systems.py b/gwcs/tests/test_coordinate_systems.py index a54091e7..75f2895f 100644 --- a/gwcs/tests/test_coordinate_systems.py +++ b/gwcs/tests/test_coordinate_systems.py @@ -478,26 +478,47 @@ def test_ucd1_to_ctype(caplog): def test_celestial_ordering(): - c1 = cf.CelestialFrame( - reference_frame=coord.ICRS(), - axes_order=(0, 1), - axes_names=("lon", "lat"), - unit=(u.deg, u.arcsec), - axis_physical_types=("custom:lon", "custom:lat"), - ) - c2 = cf.CelestialFrame( - reference_frame=coord.ICRS(), - axes_order=(1, 0), - axes_names=("lon", "lat"), - unit=(u.deg, u.arcsec), - axis_physical_types=("custom:lon", "custom:lat"), - ) - - assert c1.axes_names == ("lon", "lat") - assert c2.axes_names == ("lat", "lon") - - assert c1.unit == (u.deg, u.arcsec) - assert c2.unit == (u.arcsec, u.deg) - - assert c1.axis_physical_types == ("custom:lon", "custom:lat") - assert c2.axis_physical_types == ("custom:lat", "custom:lon") + c1 = cf.CelestialFrame( + reference_frame=coord.ICRS(), + axes_order=(0, 1), + axes_names=("lon", "lat"), + unit=(u.deg, u.arcsec), + axis_physical_types=("custom:lon", "custom:lat"), + ) + c2 = cf.CelestialFrame( + reference_frame=coord.ICRS(), + axes_order=(1, 0), + axes_names=("lon", "lat"), + unit=(u.deg, u.arcsec), + axis_physical_types=("custom:lon", "custom:lat"), + ) + + assert c1.axes_names == ("lon", "lat") + assert c2.axes_names == ("lat", "lon") + + assert c1.unit == (u.deg, u.arcsec) + assert c2.unit == (u.arcsec, u.deg) + + assert c1.axis_physical_types == ("custom:lon", "custom:lat") + assert c2.axis_physical_types == ("custom:lat", "custom:lon") + + +def test_composite_ordering(): + print("boo") + c1 = cf.CelestialFrame( + reference_frame=coord.ICRS(), + axes_order=(1, 0), + axes_names=("lon", "lat"), + unit=(u.deg, u.arcsec), + axis_physical_types=("custom:lon", "custom:lat"), + ) + spec = cf.SpectralFrame( + axes_order=(2,), + axes_names=("spectral",), + unit=u.AA, + ) + comp = cf.CompositeFrame([c1, spec]) + assert comp.axes_names == ("lat", "lon", "spectral") + assert comp.axis_physical_types == ("custom:lat", "custom:lon", "em.wl") + assert comp.unit == (u.arcsec, u.deg, u.AA) + assert comp.axes_order == (1, 0, 2) From 33d1f7b5794f2355a7a13130037a3a1cffc3cdd6 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Thu, 3 Oct 2024 15:05:52 +0100 Subject: [PATCH 29/49] Test and fix RegionSelector doctest fail --- gwcs/selector.py | 15 ++++++++++++++- gwcs/tests/test_region.py | 9 +++++++-- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/gwcs/selector.py b/gwcs/selector.py index 12ac7914..8b4d88b4 100644 --- a/gwcs/selector.py +++ b/gwcs/selector.py @@ -531,7 +531,20 @@ def __init__(self, inputs, outputs, selector, label_mapper, undefined_transform_ raise ValueError('"0" and " " are not allowed as keys.') self._input_units_strict = {key: False for key in self._inputs} self._input_units_allow_dimensionless = {key: False for key in self._inputs} - super(RegionsSelector, self).__init__(n_models=1, name=name, **kwargs) + super().__init__(n_models=1, name=name, **kwargs) + # Validate uses_quantity at init time for nicer error message + self.uses_quantity # noqa + + @property + def uses_quantity(self): + all_uses_quantity = [t.uses_quantity for t in self._selector.values()] + not_all_uses_quantity = [not uq for uq in all_uses_quantity] + if all(all_uses_quantity): + return True + elif not_all_uses_quantity: + return False + else: + raise ValueError("You can not mix models which use quantity and do not use quantity inside a RegionSelector") def set_input(self, rid): """ diff --git a/gwcs/tests/test_region.py b/gwcs/tests/test_region.py index 5fe69f90..6304b786 100644 --- a/gwcs/tests/test_region.py +++ b/gwcs/tests/test_region.py @@ -8,8 +8,9 @@ from numpy.testing import assert_equal, assert_allclose from astropy.modeling import models import pytest -from .. import region, selector -from .. import utils as gwutils +from gwcs import region, selector, WCS +from gwcs import utils as gwutils +from gwcs import coordinate_frames as cf def test_LabelMapperArray_from_vertices_int(): @@ -237,6 +238,10 @@ def test_RegionsSelector(): reg_selector.undefined_transform_value = -100 assert_equal(reg_selector(0, 0), [-100, -100]) + wcs = WCS(forward_transform=reg_selector, output_frame=cf.Frame2D()) + out = wcs(1, 1) + assert out == (-100, -100) + def test_overalpping_ranges(): """ From 3cb00bbfad24b11afbca29c6f4901c3dda291af8 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Thu, 3 Oct 2024 15:14:30 +0100 Subject: [PATCH 30/49] lint --- gwcs/api.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/gwcs/api.py b/gwcs/api.py index 2839d363..6fbaba97 100644 --- a/gwcs/api.py +++ b/gwcs/api.py @@ -9,8 +9,6 @@ from astropy.modeling import separable import astropy.units as u -from . import coordinate_frames as cf - __all__ = ["GWCSAPIMixin"] From 12262de7ed83ba485a7292266d4e69de2ce767e1 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Thu, 3 Oct 2024 15:35:38 +0100 Subject: [PATCH 31/49] Fix duplicated pass_env / passenv config in tox Also add some useful env vars --- tox.ini | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tox.ini b/tox.ini index 17021f3a..6f9c6531 100644 --- a/tox.ini +++ b/tox.ini @@ -66,6 +66,9 @@ pass_env = CI CODECOV_* DISPLAY + CC + LOCALE_ARCHIVE + LC_ALL jwst,romancal: CRDS_* romanisim,romancal: WEBBPSF_PATH romanisim: GALSIM_CAT_PATH From b2e6d69787c02608586ab2c102fe630e1abf2471 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Thu, 3 Oct 2024 15:54:32 +0100 Subject: [PATCH 32/49] Some doc polish --- docs/index.rst | 5 ++--- gwcs/coordinate_frames.py | 22 +++++++++++++++++----- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index 58b94db7..d90f59a7 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -100,7 +100,7 @@ To install the latest release:: pip install gwcs -The latest release of GWCS is also available as part of `astroconda `__. +The latest release of GWCS is also available as a conda package via `conda-forge `__. .. _getting-started: @@ -240,8 +240,7 @@ To convert a pixel (x, y) = (1, 2) to sky coordinates, call the WCS object as a The :meth:`~gwcs.wcs.WCS.invert` method evaluates the :meth:`~gwcs.wcs.WCS.backward_transform` if available, otherwise applies an iterative method to calculate the reverse coordinates. -GWCS supports the common WCS interface which defines several methods -to work with high level Astropy objects: +GWCS supports the :ref:`wcsapi` which defines several methods to work with high level Astropy objects: .. doctest-skip:: diff --git a/gwcs/coordinate_frames.py b/gwcs/coordinate_frames.py index 81bf76cc..dda68ac2 100644 --- a/gwcs/coordinate_frames.py +++ b/gwcs/coordinate_frames.py @@ -1,9 +1,11 @@ # Licensed under a 3-clause BSD style license - see LICENSE.rst """ -This module defines coordinate frames for describing the inputs and/or outputs of a transform. +This module defines coordinate frames for describing the inputs and/or outputs +of a transform. -In the following example, we have a two stage transform, with an input frame, an -output frame and an intermediate frame. +In the block diagram, the WCS pipeline has a two stage transformation (two +astropy Model instances), with an input frame, an output frame and an +intermediate frame. .. code-block:: @@ -62,8 +64,11 @@ [SpectralFrame(axes_order=(1,)), CelestialFrame(axes_order=(2, 0))] -we would map the outputs of this transform into the correct positions in the -frames. As shown below, this is also used when constructing the inputs to the inverse transform. +we would map the outputs of this transform into the correct positions in the frames. + As shown below, this is also used when constructing the inputs to the inverse transform. + + +When taking the output from the forward transform the following transformation is performed by the coordinate frames: .. code-block:: @@ -82,6 +87,13 @@ │ │ │ │ │ │ ▼ ▼ ▼ + SpectralCoord(lambda) SkyCoord((lon, lat)) + + +When considering the backward transform the following transformations take place in the coordinate frames before the transform is called: + +.. code-block:: + SpectralCoord(lambda) SkyCoord((lon, lat)) │ │ │ └─────┐ ┌────────────┘ │ From bc2beb9244b7bd436102d162b0d289a2fccd9ac7 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Thu, 17 Oct 2024 12:05:10 +0100 Subject: [PATCH 33/49] Fix rebase --- gwcs/coordinate_frames.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gwcs/coordinate_frames.py b/gwcs/coordinate_frames.py index dda68ac2..e549c9e4 100644 --- a/gwcs/coordinate_frames.py +++ b/gwcs/coordinate_frames.py @@ -647,7 +647,7 @@ def world_axis_object_classes(self): {'unit': self.unit[0]})} @property - def _world_axis_object_components(self): + def _native_world_axis_object_components(self): return [('spectral', 0, lambda sc: sc.to_value(self.unit[0]))] From 876b94a24c324408f0dd44a0ef5d049fdc023fd0 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Wed, 20 Nov 2024 13:09:47 +0000 Subject: [PATCH 34/49] Add high level <> values converters to frames This adds back a more sane equivalent of coordinates and coordinates_to_quantity. --- gwcs/coordinate_frames.py | 56 +++++++++++++++++++++++++++ gwcs/tests/test_coordinate_systems.py | 23 ++--------- 2 files changed, 60 insertions(+), 19 deletions(-) diff --git a/gwcs/coordinate_frames.py b/gwcs/coordinate_frames.py index e549c9e4..dae9b586 100644 --- a/gwcs/coordinate_frames.py +++ b/gwcs/coordinate_frames.py @@ -127,6 +127,7 @@ from astropy.wcs.wcsapi.low_level_api import (validate_physical_types, VALID_UCDS) from astropy.wcs.wcsapi.fitswcs import CTYPE_TO_UCD1 +from astropy.wcs.wcsapi.high_level_api import high_level_objects_to_values, values_to_high_level_objects from astropy.coordinates import StokesCoord __all__ = ['BaseCoordinateFrame', 'Frame2D', 'CelestialFrame', 'SpectralFrame', 'CompositeFrame', @@ -507,6 +508,61 @@ def world_axis_object_classes(self): def _native_world_axis_object_components(self): return [(f"{at}{i}" if i != 0 else at, 0, 'value') for i, at in enumerate(self._prop.axes_type)] + @property + def serialized_classes(self): + """ + This property is used by the low level WCS API in Astropy. + + By providing it we can duck type as a low level WCS object. + """ + return False + + def to_high_level_coordinates(self, *values): + """ + Convert "values" to high level coordinate objects described by this frame. + + "values" are the coordinates in array or scalar form, and high level + objects are things such as ``SkyCoord`` or ``Quantity``. See + :ref:`wcsapi` for details. + + Parameters + ---------- + values : `numbers.Number` or `numpy.ndarray` + ``naxis`` number of coordinates as scalars or arrays. + + Returns + ------- + high_level_coordinates + One (or more) high level object describing the coordinate. + """ + high_level = values_to_high_level_objects(*values, low_level_wcs=self) + if len(high_level) == 1: + high_level = high_level[0] + return high_level + + def from_high_level_coordinates(self, *high_level_coords): + """ + Convert high level coordinate objects to "values" as described by this frame. + + "values" are the coordinates in array or scalar form, and high level + objects are things such as ``SkyCoord`` or ``Quantity``. See + :ref:`wcsapi` for details. + + Parameters + ---------- + high_level_coordinates + One (or more) high level object describing the coordinate. + + Returns + ------- + values : `numbers.Number` or `numpy.ndarray` + ``naxis`` number of coordinates as scalars or arrays. + """ + values = high_level_objects_to_values(*high_level_coords, low_level_wcs=self) + if len(values) == 1: + values = values[0] + return values + class CelestialFrame(CoordinateFrame): """ diff --git a/gwcs/tests/test_coordinate_systems.py b/gwcs/tests/test_coordinate_systems.py index 75f2895f..41040906 100644 --- a/gwcs/tests/test_coordinate_systems.py +++ b/gwcs/tests/test_coordinate_systems.py @@ -15,7 +15,6 @@ from .. import WCS from .. import coordinate_frames as cf -from astropy.wcs.wcsapi.high_level_api import values_to_high_level_objects, high_level_objects_to_values import astropy astropy_version = astropy.__version__ @@ -56,19 +55,6 @@ inputs3 = [(xscalar, yscalar, xscalar), (xarr, yarr, xarr)] -@pytest.fixture(autouse=True, scope="module") -def serialized_classes(): - """ - In the rest of this test file we are passing the CoordinateFrame object to - astropy helper functions as if they were a low level WCS object. - - This little patch means that this works. - """ - cf.CoordinateFrame.serialized_classes = False - yield - del cf.CoordinateFrame.serialized_classes - - def test_units(): assert(comp1.unit == (u.deg, u.deg, u.Hz)) assert(comp2.unit == (u.m, u.m, u.m)) @@ -81,14 +67,13 @@ def test_units(): # These two functions fake the old methods on CoordinateFrame to reduce the # amount of refactoring that needed doing in these tests. def coordinates(*inputs, frame): - results = values_to_high_level_objects(*inputs, low_level_wcs=frame) - if isinstance(results, list) and len(results) == 1: - return results[0] - return results + return frame.to_high_level_coordinates(*inputs) def coordinate_to_quantity(*inputs, frame): - results = high_level_objects_to_values(*inputs, low_level_wcs=frame) + results = frame.from_high_level_coordinates(*inputs) + if not isinstance(results, list): + results = [results] results = [r << unit for r, unit in zip(results, frame.unit)] return results From 991674a8849f4453d10e001c821bbbe3b2e029fb Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Wed, 20 Nov 2024 15:43:12 +0000 Subject: [PATCH 35/49] Fix a apt bug --- gwcs/coordinate_frames.py | 9 ++++----- gwcs/tests/test_coordinate_systems.py | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/gwcs/coordinate_frames.py b/gwcs/coordinate_frames.py index dae9b586..b7581a69 100644 --- a/gwcs/coordinate_frames.py +++ b/gwcs/coordinate_frames.py @@ -247,7 +247,7 @@ def __post_init__(self, naxes): self.axis_physical_types = tuple(ph_type) @property - def _default_axis_physical_type(self): + def _default_axis_physical_types(self): """ The default physical types to use for this frame if none are specified by the user. @@ -416,12 +416,12 @@ def __init__(self, naxes, axes_type, axes_order, reference_frame=None, axes_type, unit, axes_names, - axis_physical_types or self._default_axis_physical_type(axes_type) + axis_physical_types or self._default_axis_physical_types(axes_type) ) super().__init__() - def _default_axis_physical_type(self, axes_type): + def _default_axis_physical_types(self, axes_type): """ The default physical types to use for this frame if none are specified by the user. @@ -494,8 +494,7 @@ def axis_physical_types(self): These physical types are the types in frame order, not transform order. """ - apt = self._prop.axis_physical_types or self._default_axis_physical_types - return self._sort_property(apt) + return self._sort_property(self._prop.axis_physical_types) @property def world_axis_object_classes(self): diff --git a/gwcs/tests/test_coordinate_systems.py b/gwcs/tests/test_coordinate_systems.py index 41040906..61c55f41 100644 --- a/gwcs/tests/test_coordinate_systems.py +++ b/gwcs/tests/test_coordinate_systems.py @@ -347,7 +347,7 @@ def test_coordinate_to_quantity_error(): coordinate_to_quantity(1, frame=frame) -def test_axis_physical_type(): +def test_axis_physical_types(): assert icrs.axis_physical_types == ("pos.eq.ra", "pos.eq.dec") assert spec1.axis_physical_types == ("em.freq",) assert spec2.axis_physical_types == ("em.wl",) From 683fd748c847a8864377e377e6820235abfbd3b1 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Wed, 20 Nov 2024 16:11:02 +0000 Subject: [PATCH 36/49] Fix a naxes bug --- gwcs/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gwcs/api.py b/gwcs/api.py index 6fbaba97..472357dd 100644 --- a/gwcs/api.py +++ b/gwcs/api.py @@ -65,7 +65,7 @@ def world_axis_units(self): def _remove_quantity_output(self, result, frame): if self.forward_transform.uses_quantity: - if self.output_frame.naxes == 1: + if frame.naxes == 1: result = [result] result = tuple(r.to_value(unit) if isinstance(r, u.Quantity) else r From 8ea2ff487399e3f820022bee51f5883ee3af7600 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Thu, 21 Nov 2024 11:06:01 +0000 Subject: [PATCH 37/49] Compat with astropy 7.0 for high_level methods --- gwcs/coordinate_frames.py | 3 ++- gwcs/tests/test_coordinate_systems.py | 16 ++-------------- gwcs/wcs.py | 2 ++ 3 files changed, 6 insertions(+), 15 deletions(-) diff --git a/gwcs/coordinate_frames.py b/gwcs/coordinate_frames.py index b7581a69..c4994757 100644 --- a/gwcs/coordinate_frames.py +++ b/gwcs/coordinate_frames.py @@ -534,6 +534,7 @@ def to_high_level_coordinates(self, *values): high_level_coordinates One (or more) high level object describing the coordinate. """ + values = [v.to_value(unit) if hasattr(v, "to_value") else v for v, unit in zip(values, self.unit)] high_level = values_to_high_level_objects(*values, low_level_wcs=self) if len(high_level) == 1: high_level = high_level[0] @@ -727,7 +728,7 @@ class TemporalFrame(CoordinateFrame): Name for this frame. """ - def __init__(self, reference_frame, unit=None, axes_order=(0,), + def __init__(self, reference_frame, unit=u.s, axes_order=(0,), axes_names=None, name=None, axis_physical_types=None): axes_names = axes_names or "{}({}; {}".format(reference_frame.format, reference_frame.scale, diff --git a/gwcs/tests/test_coordinate_systems.py b/gwcs/tests/test_coordinate_systems.py index 61c55f41..aa572096 100644 --- a/gwcs/tests/test_coordinate_systems.py +++ b/gwcs/tests/test_coordinate_systems.py @@ -193,7 +193,7 @@ def test_temporal_relative(): assert coordinates(10, frame=t) == Time("2018-01-01T00:00:00") + 10 * u.s assert coordinates(10 * u.s, frame=t) == Time("2018-01-01T00:00:00") + 10 * u.s - a = coordinates((10, 20), frame=t) + a = coordinates(np.array((10, 20)), frame=t) assert a[0] == Time("2018-01-01T00:00:00") + 10 * u.s assert a[1] == Time("2018-01-01T00:00:00") + 20 * u.s @@ -201,23 +201,11 @@ def test_temporal_relative(): assert coordinates(10 * u.s, frame=t) == Time("2018-01-01T00:00:00") + 10 * u.s assert coordinates(TimeDelta(10, format='sec'), frame=t) == Time("2018-01-01T00:00:00") + 10 * u.s - a = coordinates((10, 20) * u.s, frame=t) + a = coordinates(np.array((10, 20)) * u.s, frame=t) assert a[0] == Time("2018-01-01T00:00:00") + 10 * u.s assert a[1] == Time("2018-01-01T00:00:00") + 20 * u.s -def test_temporal_absolute(): - t = cf.TemporalFrame(reference_frame=Time([], format='isot')) - assert coordinates("2018-01-01T00:00:00", frame=t) == Time("2018-01-01T00:00:00") - - a = coordinates(("2018-01-01T00:00:00", "2018-01-01T00:10:00"), frame=t) - assert a[0] == Time("2018-01-01T00:00:00") - assert a[1] == Time("2018-01-01T00:10:00") - - t = cf.TemporalFrame(reference_frame=Time([], scale='tai', format='isot')) - assert coordinates("2018-01-01T00:00:00", frame=t) == Time("2018-01-01T00:00:00", scale='tai') - - @pytest.mark.parametrize('inp', [ (coord.SkyCoord(10 * u.deg, 20 * u.deg, frame=coord.ICRS),), # This is the same as 10,20 in ICRS diff --git a/gwcs/wcs.py b/gwcs/wcs.py index 5cbd7282..3c32c88b 100644 --- a/gwcs/wcs.py +++ b/gwcs/wcs.py @@ -369,6 +369,8 @@ def __call__(self, *args, **kwargs): if with_units: if not astutil.isiterable(results): results = (results,) + # values are always expected to be arrays or scalars not quantities + results = self._remove_units_input(results, self.output_frame) high_level = values_to_high_level_objects(*results, low_level_wcs=self) if len(high_level) == 1: high_level = high_level[0] From 8c63751377b1edfa4af627be3efe00ea83cc5487 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Thu, 21 Nov 2024 11:08:10 +0000 Subject: [PATCH 38/49] Explicitly error if invalid values --- gwcs/coordinate_frames.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/gwcs/coordinate_frames.py b/gwcs/coordinate_frames.py index c4994757..3f93c5d0 100644 --- a/gwcs/coordinate_frames.py +++ b/gwcs/coordinate_frames.py @@ -116,6 +116,7 @@ import abc from collections import defaultdict import logging +import numbers import numpy as np from dataclasses import dataclass, InitVar @@ -526,7 +527,7 @@ def to_high_level_coordinates(self, *values): Parameters ---------- - values : `numbers.Number` or `numpy.ndarray` + values : `numbers.Number`, `numpy.ndarray`, or `~astropy.units.Quantity` ``naxis`` number of coordinates as scalars or arrays. Returns @@ -534,7 +535,12 @@ def to_high_level_coordinates(self, *values): high_level_coordinates One (or more) high level object describing the coordinate. """ + # We allow Quantity-like objects here which values_to_high_level_objects does not. values = [v.to_value(unit) if hasattr(v, "to_value") else v for v, unit in zip(values, self.unit)] + + if not all([isinstance(v, numbers.Number) or type(v) is np.ndarray for v in values]): + raise TypeError("All values should be a scalar number or a numpy array.") + high_level = values_to_high_level_objects(*values, low_level_wcs=self) if len(high_level) == 1: high_level = high_level[0] From ad850f7af66851d82c264304edf2ff91a3053652 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Thu, 21 Nov 2024 11:26:57 +0000 Subject: [PATCH 39/49] Raise an explicit error if mixed high level types are passed --- gwcs/tests/test_api.py | 11 +++++++++++ gwcs/utils.py | 17 +++++++++++++++-- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/gwcs/tests/test_api.py b/gwcs/tests/test_api.py index a3d22bc9..959371b3 100644 --- a/gwcs/tests/test_api.py +++ b/gwcs/tests/test_api.py @@ -558,3 +558,14 @@ def test_world_axis_object_components_units(gwcs_3d_identity_units): assert not any([isinstance(o, u.Quantity) for o in values]) np.testing.assert_allclose(values, expected_values) + + +def test_mismatched_high_level_types(gwcs_3d_identity_units): + wcs = gwcs_3d_identity_units + + with pytest.raises(TypeError, match="Invalid types were passed.*(tuple, SpectralCoord).*(SkyCoord, SpectralCoord).*"): + wcs.invert((1*u.deg, 2*u.deg), coord.SpectralCoord(10*u.nm)) + + # Oh astropy why do you make us do this + with pytest.raises(TypeError, match="Invalid types were passed.*got.*Quantity.*expected.*SpectralCoord.*"): + wcs.invert(coord.SkyCoord(1*u.deg, 2*u.deg), 10*u.nm) diff --git a/gwcs/utils.py b/gwcs/utils.py index c04c105d..44d7e4c0 100644 --- a/gwcs/utils.py +++ b/gwcs/utils.py @@ -472,5 +472,18 @@ def is_high_level(*args, low_level_wcs): if len(args) != len(low_level_wcs.world_axis_object_classes): return False - return all([type(arg) is waoc[0] - for arg, waoc in zip(args, low_level_wcs.world_axis_object_classes.values())]) + type_match = [(type(arg), waoc[0]) + for arg, waoc in zip(args, low_level_wcs.world_axis_object_classes.values())] + + types_are_high_level = [argt is t for argt, t in type_match] + + if all(types_are_high_level): + return True + + if any(types_are_high_level): + raise TypeError( + "Invalid types were passed, got " + f"({', '.join(tm[0].__name__ for tm in type_match)}) expected " + f"({', '.join(tm[1].__name__ for tm in type_match)}).") + + return False From 0810abe3a20d12096057cfd7bc678f173f7461a8 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Thu, 21 Nov 2024 17:58:42 +0000 Subject: [PATCH 40/49] Test intermediate high_level conversion --- gwcs/tests/test_wcs.py | 74 +++++++++++++++++++++++++++++++++++++++++- gwcs/wcs.py | 38 ++++++++++++++++++---- 2 files changed, 104 insertions(+), 8 deletions(-) diff --git a/gwcs/tests/test_wcs.py b/gwcs/tests/test_wcs.py index 5f017946..146b2889 100644 --- a/gwcs/tests/test_wcs.py +++ b/gwcs/tests/test_wcs.py @@ -1499,7 +1499,7 @@ def test_reordered_celestial(): output_pixel = iwcs.world_to_pixel_values(*output_world) assert_allclose(output_pixel, u.Quantity(input_pixel).to_value(u.pix)) - expected_world = [45*u.deg, 20*u.arcsec]#, 45*u.deg] + expected_world = [45*u.deg, 20*u.arcsec] assert_allclose(output_world, [e.value for e in expected_world]) world_obj = iwcs.pixel_to_world(*input_pixel) @@ -1510,3 +1510,75 @@ def test_reordered_celestial(): obj_pixel = iwcs.world_to_pixel(world_obj) assert_allclose(obj_pixel, u.Quantity(input_pixel).to_value(u.pix)) + + +@pytest.fixture +def gwcs_with_pipeline_celestial(): + input_frame = cf.CoordinateFrame(2, ["PIXEL"]*2, + axes_order=list(range(2)), + unit=[u.pix]*2, + name="input") + + spatial = models.Multiply(20*u.arcsec/u.pix) & models.Multiply(15*u.deg/u.pix) + + celestial_frame = cf.CelestialFrame(axes_order=(0, 1), unit=(u.arcsec, u.deg), + reference_frame=coord.ICRS(), name="celestial") + + custom = models.Shift(1*u.deg) & models.Shift(2*u.deg) + + output_frame = cf.CoordinateFrame(2, ["CUSTOM"]*2, + axes_order=list(range(2)), unit=[u.arcsec]*2, name="output") + + pipeline = [ + (input_frame, spatial), + (celestial_frame, custom), + (output_frame, None), + ] + + return wcs.WCS(pipeline) + + +def test_high_level_objects_in_pipeline_forward(gwcs_with_pipeline_celestial): + """ + This test checks that high level objects still work with a multi-stage + pipeline when doing forward transforms. + """ + iwcs = gwcs_with_pipeline_celestial + + input_pixel = [1*u.pix, 1*u.pix] + + output_world = iwcs(*input_pixel) + + assert output_world[0].unit == u.deg + assert output_world[1].unit == u.deg + assert u.allclose(output_world[0], 20*u.arcsec + 1*u.deg) + assert u.allclose(output_world[1], 15*u.deg + 2*u.deg) + + # with_units=True puts the result in the frame units rather than in the + # model units. + output_world_with_units = iwcs(*input_pixel, with_units=True) + assert output_world_with_units[0].unit is u.arcsec + assert output_world_with_units[1].unit is u.arcsec + + # This should be in model units of the spatial model + intermediate_world = iwcs.transform( + "input", + "celestial", + *input_pixel, + ) + assert intermediate_world[0].unit == u.arcsec + assert intermediate_world[1].unit == u.deg + assert u.allclose(intermediate_world[0], 20*u.arcsec) + assert u.allclose(intermediate_world[1], 15*u.deg) + + intermediate_world_with_units = iwcs.transform( + "input", + "celestial", + *input_pixel, + with_units=True + ) + assert len(intermediate_world_with_units) == 1 + assert isinstance(intermediate_world_with_units[0], coord.SkyCoord) + sc = intermediate_world_with_units[0] + assert u.allclose(sc.ra, 20*u.arcsec) + assert u.allclose(sc.dec, 15*u.deg) diff --git a/gwcs/wcs.py b/gwcs/wcs.py index 3c32c88b..9fc772f8 100644 --- a/gwcs/wcs.py +++ b/gwcs/wcs.py @@ -173,7 +173,6 @@ def _initialize_wcs(self, forward_transform, input_frame, output_frame): else: name, frame_obj = self._get_frame_name(item[0]) super(WCS, self).__setattr__(name, frame_obj) - #self._pipeline.append((name, item[1])) self._pipeline = forward_transform else: raise TypeError("Expected forward_transform to be a model or a " @@ -295,6 +294,17 @@ def backward_transform(self): backward.inverse = self.forward_transform return backward + def _get_frame_by_name(self, frame_name): + """ + Return the frame object by name. + """ + frames = [step.frame for step in self._pipeline if step.frame.name == frame_name] + if len(frames) > 1: + raise ValueError(f"There is more than one frame named {frame_name}") + if len(frames) == 0: + return ValueError(f"No frame found matching {frame_name}") + return frames[0] + def _get_frame_index(self, frame): """ Return the index in the pipeline where this frame is locate. @@ -386,16 +396,21 @@ def _call_forward(self, *args, from_frame=None, to_frame=None, transform = self.forward_transform else: transform = self.get_transform(from_frame, to_frame) + if from_frame is None: + from_frame = self.input_frame + if to_frame is None: + to_frame = self.output_frame if transform is None: raise NotImplementedError("WCS.forward_transform is not implemented.") + breakpoint() # Validate that the input type matches what the transform expects input_is_quantity = any((isinstance(a, u.Quantity) for a in args)) if not input_is_quantity and transform.uses_quantity: - args = self._add_units_input(args, self.input_frame) + args = self._add_units_input(args, from_frame) if not transform.uses_quantity and input_is_quantity: - args = self._remove_units_input(args, self.input_frame) + args = self._remove_units_input(args, from_frame) return transform(*args, with_bounding_box=with_bounding_box, @@ -502,14 +517,18 @@ def invert(self, *args, **kwargs): transform returns ``Quantity`` objects, else values. """ + # must pop before calling the model + with_units = kwargs.pop('with_units', False) + if utils.is_high_level(*args, low_level_wcs=self): args = high_level_objects_to_values(*args, low_level_wcs=self) results = self._call_backward(*args, **kwargs) - with_units = kwargs.pop('with_units', False) if with_units: - high_level = values_to_high_level_objects(*results, low_level_wcs=self) + # values are always expected to be arrays or scalars not quantities + results = self._remove_units_input(results, self.input_frame) + high_level = values_to_high_level_objects(*results, low_level_wcs=self.input_frame) if len(high_level) == 1: high_level = high_level[0] return high_level @@ -1136,7 +1155,13 @@ def transform(self, from_frame, to_frame, *args, **kwargs): results = self._call_forward(*args, from_frame=from_frame, to_frame=to_frame, **kwargs) if with_units and not backward: - return values_to_high_level_objects(*results, low_level_wcs=self) + # TODO: Apparently you can have frames which are just strings + # We need an actual frame object for this to work + if isinstance(to_frame, str): + to_frame = self._get_frame_by_name(to_frame) + # values are always expected to be arrays or scalars not quantities + results = self._remove_units_input(results, to_frame) + return values_to_high_level_objects(*results, low_level_wcs=to_frame) return results @property @@ -1150,7 +1175,6 @@ def available_frames(self): {frame_name: frame_object or None} """ if self._pipeline: - #return [getattr(frame[0], "name", frame[0]) for frame in self._pipeline] return [step.frame if isinstance(step.frame, str) else step.frame.name for step in self._pipeline ] else: return None From 2d90bb0701298b4a9c6f4251e534e70e6009fbc0 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Mon, 25 Nov 2024 09:30:45 +0000 Subject: [PATCH 41/49] Whoops --- gwcs/wcs.py | 1 - 1 file changed, 1 deletion(-) diff --git a/gwcs/wcs.py b/gwcs/wcs.py index 9fc772f8..e94498d0 100644 --- a/gwcs/wcs.py +++ b/gwcs/wcs.py @@ -404,7 +404,6 @@ def _call_forward(self, *args, from_frame=None, to_frame=None, if transform is None: raise NotImplementedError("WCS.forward_transform is not implemented.") - breakpoint() # Validate that the input type matches what the transform expects input_is_quantity = any((isinstance(a, u.Quantity) for a in args)) if not input_is_quantity and transform.uses_quantity: From 7df9fa8f0ba8da61f65f917fa549276cc8cf267e Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Tue, 3 Dec 2024 16:21:28 +0000 Subject: [PATCH 42/49] Fix high level output of intermediate frames --- gwcs/tests/test_wcs.py | 41 +++++++++++++++++++++++++++++++++++++++++ gwcs/wcs.py | 17 +++++++++++------ 2 files changed, 52 insertions(+), 6 deletions(-) diff --git a/gwcs/tests/test_wcs.py b/gwcs/tests/test_wcs.py index 146b2889..e07fdb92 100644 --- a/gwcs/tests/test_wcs.py +++ b/gwcs/tests/test_wcs.py @@ -1582,3 +1582,44 @@ def test_high_level_objects_in_pipeline_forward(gwcs_with_pipeline_celestial): sc = intermediate_world_with_units[0] assert u.allclose(sc.ra, 20*u.arcsec) assert u.allclose(sc.dec, 15*u.deg) + + +def test_high_level_objects_in_pipeline_backward(gwcs_with_pipeline_celestial): + """ + This test checks that high level objects still work with a multi-stage + pipeline when doing backward transforms. + """ + iwcs = gwcs_with_pipeline_celestial + + input_world = [ + 20*u.arcsec + 1*u.deg, + 15*u.deg + 2*u.deg, + ] + pixel = iwcs.invert(*input_world) + + assert all(isinstance(p, u.Quantity) for p in pixel) + assert u.allclose(pixel, [1, 1]*u.pix) + + pixel = iwcs.invert( + *input_world, + with_units=True, + ) + + assert all(isinstance(p, u.Quantity) for p in pixel) + assert u.allclose(pixel, [1, 1]*u.pix) + + intermediate_world = iwcs.transform( + "output", + "celestial", + *input_world, + ) + assert all(isinstance(p, u.Quantity) for p in intermediate_world) + assert u.allclose(intermediate_world, [20*u.arcsec, 15*u.deg]) + + intermediate_world = iwcs.transform( + "output", + "celestial", + *input_world, + with_units=True, + ) + assert isinstance(intermediate_world, coord.SkyCoord) diff --git a/gwcs/wcs.py b/gwcs/wcs.py index e94498d0..1ed9044c 100644 --- a/gwcs/wcs.py +++ b/gwcs/wcs.py @@ -298,6 +298,9 @@ def _get_frame_by_name(self, frame_name): """ Return the frame object by name. """ + if not isinstance(frame_name, str): + return frame_name + frames = [step.frame for step in self._pipeline if step.frame.name == frame_name] if len(frames) > 1: raise ValueError(f"There is more than one frame named {frame_name}") @@ -1146,6 +1149,9 @@ def transform(self, from_frame, to_frame, *args, **kwargs): from_ind = self._get_frame_index(from_frame) to_ind = self._get_frame_index(to_frame) backward = to_ind < from_ind + # Convert from strings to frame objects + from_frame = self._get_frame_by_name(from_frame) + to_frame = self._get_frame_by_name(to_frame) with_units = kwargs.pop("with_units", False) if with_units and backward: @@ -1153,14 +1159,13 @@ def transform(self, from_frame, to_frame, *args, **kwargs): results = self._call_forward(*args, from_frame=from_frame, to_frame=to_frame, **kwargs) - if with_units and not backward: - # TODO: Apparently you can have frames which are just strings - # We need an actual frame object for this to work - if isinstance(to_frame, str): - to_frame = self._get_frame_by_name(to_frame) + if with_units: # values are always expected to be arrays or scalars not quantities results = self._remove_units_input(results, to_frame) - return values_to_high_level_objects(*results, low_level_wcs=to_frame) + high_level = values_to_high_level_objects(*results, low_level_wcs=to_frame) + if len(high_level) == 1: + high_level = high_level[0] + return high_level return results @property From cbdbe8a23b780e5a60f07c1f4b8d0f9647aea8a2 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Wed, 4 Dec 2024 09:42:24 +0000 Subject: [PATCH 43/49] More test fixes --- gwcs/api.py | 2 ++ gwcs/tests/test_wcs.py | 23 +++++++++++++---------- gwcs/wcs.py | 4 ++-- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/gwcs/api.py b/gwcs/api.py index 472357dd..0c32482e 100644 --- a/gwcs/api.py +++ b/gwcs/api.py @@ -9,6 +9,8 @@ from astropy.modeling import separable import astropy.units as u +from gwcs import utils + __all__ = ["GWCSAPIMixin"] diff --git a/gwcs/tests/test_wcs.py b/gwcs/tests/test_wcs.py index e07fdb92..f3144708 100644 --- a/gwcs/tests/test_wcs.py +++ b/gwcs/tests/test_wcs.py @@ -22,7 +22,8 @@ from gwcs.wcstools import (wcs_from_fiducial, grid_from_bounding_box, wcs_from_points) from gwcs import coordinate_frames as cf from gwcs.utils import CoordinateFrameError -from . utils import _gwcs_from_hst_fits_wcs +from gwcs.tests.utils import _gwcs_from_hst_fits_wcs +from gwcs.tests import data from gwcs.examples import gwcs_2d_bad_bounding_box_order @@ -465,7 +466,8 @@ def test_bounding_box_eval(): Tests evaluation with and without respecting the bounding_box. """ trans3 = models.Shift(10) & models.Scale(2) & models.Shift(-1) - pipeline = [('detector', trans3), ('sky', None)] + pipeline = [(cf.CoordinateFrame(naxes=1, axes_type=("PIXEL",), axes_order=(0,), name='detector'), trans3), + (cf.CoordinateFrame(naxes=1, axes_type=("SPATIAL",), axes_order=(0,), name='sky'), None)] w = wcs.WCS(pipeline) w.bounding_box = ((-1, 10), (6, 15), (4.3, 6.9)) @@ -606,11 +608,13 @@ def setup_class(self): tan = models.Pix2Sky_TAN(name='tangent_projection') sky_cs = cf.CelestialFrame(reference_frame=coord.ICRS(), name='sky') det = cf.Frame2D(name='detector') + focal = cf.Frame2D(name='focal') wcs_forward = wcslin | tan | n2c - pipeline = [wcs.Step('detector', distortion), - wcs.Step('focal', wcs_forward), - wcs.Step(sky_cs, None) - ] + pipeline = [ + wcs.Step(det, distortion), + wcs.Step(focal, wcs_forward), + wcs.Step(sky_cs, None) + ] self.wcs = wcs.WCS(input_frame=det, output_frame=sky_cs, @@ -657,7 +661,7 @@ def test_inverse(self): def test_back_coordinates(self): sky_coord = self.wcs(1, 2, with_units=True) - res = self.wcs.transform('sky', 'focal', sky_coord, with_units=True) + res = self.wcs.transform('sky', 'focal', sky_coord, with_units=False) assert_allclose(res, self.wcs.get_transform('detector', 'focal')(1, 2)) def test_units(self): @@ -1577,9 +1581,8 @@ def test_high_level_objects_in_pipeline_forward(gwcs_with_pipeline_celestial): *input_pixel, with_units=True ) - assert len(intermediate_world_with_units) == 1 - assert isinstance(intermediate_world_with_units[0], coord.SkyCoord) - sc = intermediate_world_with_units[0] + assert isinstance(intermediate_world_with_units, coord.SkyCoord) + sc = intermediate_world_with_units assert u.allclose(sc.ra, 20*u.arcsec) assert u.allclose(sc.dec, 15*u.deg) diff --git a/gwcs/wcs.py b/gwcs/wcs.py index 1ed9044c..c4884ec0 100644 --- a/gwcs/wcs.py +++ b/gwcs/wcs.py @@ -1154,8 +1154,8 @@ def transform(self, from_frame, to_frame, *args, **kwargs): to_frame = self._get_frame_by_name(to_frame) with_units = kwargs.pop("with_units", False) - if with_units and backward: - args = high_level_objects_to_values(*args, low_level_wcs=self) + if backward and utils.is_high_level(*args, low_level_wcs=from_frame): + args = high_level_objects_to_values(*args, low_level_wcs=from_frame) results = self._call_forward(*args, from_frame=from_frame, to_frame=to_frame, **kwargs) From 1ac7aea39233c266ef13b9be5311bbd791897719 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Wed, 4 Dec 2024 09:47:40 +0000 Subject: [PATCH 44/49] Move fixture to examples --- gwcs/examples.py | 25 +++++++++++++++++++++++++ gwcs/tests/conftest.py | 5 +++++ gwcs/tests/test_wcs.py | 26 -------------------------- 3 files changed, 30 insertions(+), 26 deletions(-) diff --git a/gwcs/examples.py b/gwcs/examples.py index 510b99db..842629a2 100644 --- a/gwcs/examples.py +++ b/gwcs/examples.py @@ -488,3 +488,28 @@ def gwcs_7d_complex_mapping(): w.pixel_shape = (16, 32, 21, 11, 11, 2) return w + + +def gwcs_with_pipeline_celestial(): + input_frame = cf.CoordinateFrame(2, ["PIXEL"]*2, + axes_order=list(range(2)), + unit=[u.pix]*2, + name="input") + + spatial = models.Multiply(20*u.arcsec/u.pix) & models.Multiply(15*u.deg/u.pix) + + celestial_frame = cf.CelestialFrame(axes_order=(0, 1), unit=(u.arcsec, u.deg), + reference_frame=coord.ICRS(), name="celestial") + + custom = models.Shift(1*u.deg) & models.Shift(2*u.deg) + + output_frame = cf.CoordinateFrame(2, ["CUSTOM"]*2, + axes_order=list(range(2)), unit=[u.arcsec]*2, name="output") + + pipeline = [ + (input_frame, spatial), + (celestial_frame, custom), + (output_frame, None), + ] + + return wcs.WCS(pipeline) diff --git a/gwcs/tests/conftest.py b/gwcs/tests/conftest.py index 83fb216b..f9cf3b7b 100644 --- a/gwcs/tests/conftest.py +++ b/gwcs/tests/conftest.py @@ -141,3 +141,8 @@ def spher_to_cart(): @pytest.fixture def cart_to_spher(): return geometry.CartesianToSpherical() + + +@pytest.fixture +def gwcs_with_pipeline_celestial(): + return examples.gwcs_with_pipeline_celestial() diff --git a/gwcs/tests/test_wcs.py b/gwcs/tests/test_wcs.py index f3144708..ca7ba09a 100644 --- a/gwcs/tests/test_wcs.py +++ b/gwcs/tests/test_wcs.py @@ -1516,32 +1516,6 @@ def test_reordered_celestial(): assert_allclose(obj_pixel, u.Quantity(input_pixel).to_value(u.pix)) -@pytest.fixture -def gwcs_with_pipeline_celestial(): - input_frame = cf.CoordinateFrame(2, ["PIXEL"]*2, - axes_order=list(range(2)), - unit=[u.pix]*2, - name="input") - - spatial = models.Multiply(20*u.arcsec/u.pix) & models.Multiply(15*u.deg/u.pix) - - celestial_frame = cf.CelestialFrame(axes_order=(0, 1), unit=(u.arcsec, u.deg), - reference_frame=coord.ICRS(), name="celestial") - - custom = models.Shift(1*u.deg) & models.Shift(2*u.deg) - - output_frame = cf.CoordinateFrame(2, ["CUSTOM"]*2, - axes_order=list(range(2)), unit=[u.arcsec]*2, name="output") - - pipeline = [ - (input_frame, spatial), - (celestial_frame, custom), - (output_frame, None), - ] - - return wcs.WCS(pipeline) - - def test_high_level_objects_in_pipeline_forward(gwcs_with_pipeline_celestial): """ This test checks that high level objects still work with a multi-stage From 1b7871bd553aa4941916b6639f7a6c4fe6f48199 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Wed, 4 Dec 2024 09:58:52 +0000 Subject: [PATCH 45/49] Add changelog --- CHANGES.rst | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/CHANGES.rst b/CHANGES.rst index cca5803e..e17ffcbf 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,6 +1,16 @@ 0.22.0 (unreleased) ------------------- +- Coordinate frames now have a "native" order and then are sorted based on ``axes_order``. [#457] + +- ``WCS.numerical_inverse`` no longer accepts high level objects (``with_units=`` is not supported) use ``WCS.inverse``. [#457] + +- ``CoordinateFrame.coordinates`` has been replaced by ``CoordinateFrame.to_high_level_coordinates`` [#457] + +- ``CoordinateFrame.to_quantity`` has been replaced by ``CoordinateFrame.from_high_level_coordinates``. [#457] + +- Inputs to ``CelestialFrame``, such as ``axes_names`` are now explicitly in lon, lat order and will re sorted based on ``axes_order=``. [#457] + - Replace usages of ``copy_arrays`` with ``memmap`` [#503] - Fix an issue with units in ``wcs_from_points``. [#507] From 25131a364bd1f260f24c714ffa3f4a40a196fff6 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Tue, 17 Dec 2024 11:35:18 +0000 Subject: [PATCH 46/49] Fix api test for bbox --- gwcs/tests/test_api.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/gwcs/tests/test_api.py b/gwcs/tests/test_api.py index eee4385b..94777d17 100644 --- a/gwcs/tests/test_api.py +++ b/gwcs/tests/test_api.py @@ -265,7 +265,11 @@ def test_high_level_wrapper(wcsobj, request): hlvl = HighLevelWCSWrapper(wcsobj) - pixel_input = [6] * wcsobj.pixel_n_dim + pixel_input = [3] * wcsobj.pixel_n_dim + if wcsobj.bounding_box is not None: + for i, interval in wcsobj.bounding_box.intervals.items(): + bbox_min = u.Quantity(interval.lower).value + pixel_input[i] = max(bbox_min + 1, pixel_input[i]) # Assert that both APE 14 API and GWCS give the same answer The APE 14 API # uses the mixin class and __call__ calls values_to_high_level_objects From 62af1950cfe11d6363e53482d084977e18fdce78 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Tue, 17 Dec 2024 11:35:30 +0000 Subject: [PATCH 47/49] lint --- gwcs/tests/test_utils.py | 1 - gwcs/utils.py | 2 -- gwcs/wcs.py | 2 -- 3 files changed, 5 deletions(-) diff --git a/gwcs/tests/test_utils.py b/gwcs/tests/test_utils.py index 73938866..b33d742b 100644 --- a/gwcs/tests/test_utils.py +++ b/gwcs/tests/test_utils.py @@ -6,7 +6,6 @@ from astropy import units as u from astropy import coordinates as coord from astropy.modeling import models -from astropy import table from astropy.tests.helper import assert_quantity_allclose import pytest diff --git a/gwcs/utils.py b/gwcs/utils.py index 2d7bbbea..e50c6486 100644 --- a/gwcs/utils.py +++ b/gwcs/utils.py @@ -11,8 +11,6 @@ from astropy.io import fits from astropy import coordinates as coords import astropy.units as u -from astropy.time import Time, TimeDelta -from astropy import table from astropy.wcs import Celprm diff --git a/gwcs/wcs.py b/gwcs/wcs.py index 9effa949..e5bd8893 100644 --- a/gwcs/wcs.py +++ b/gwcs/wcs.py @@ -19,9 +19,7 @@ from astropy.wcs.utils import celestial_frame_to_wcs, proj_plane_pixel_scales from astropy.wcs.wcsapi.high_level_api import high_level_objects_to_values, values_to_high_level_objects -from astropy import units as u from scipy import linalg, optimize -from astropy.wcs.wcsapi.high_level_api import high_level_objects_to_values, values_to_high_level_objects from . import coordinate_frames as cf from . import utils From 8c3142694185f9738bd922d330f5d06f6a03df46 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Thu, 19 Dec 2024 14:15:26 +0000 Subject: [PATCH 48/49] Fix tests --- gwcs/tests/test_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gwcs/tests/test_api.py b/gwcs/tests/test_api.py index 6b96dcc1..5d196905 100644 --- a/gwcs/tests/test_api.py +++ b/gwcs/tests/test_api.py @@ -497,7 +497,7 @@ def test_world_to_array_index_values(gwcs_simple_imaging, sky_ra_dec): wcsobj = gwcs_simple_imaging sky, ra, dec = sky_ra_dec - assert_allclose(wcsobj.world_to_array_index_values(sky), + assert_allclose(wcsobj.world_to_array_index_values(ra, dec), wcsobj.invert(ra * u.deg, dec * u.deg, with_units=False)[::-1]) From f97dcd8059a8c1616e1a41d78011ad661d385028 Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Thu, 19 Dec 2024 13:30:05 -0500 Subject: [PATCH 49/49] Bugfix uncovered by JWST --- gwcs/wcs.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/gwcs/wcs.py b/gwcs/wcs.py index e44bde42..ab97c996 100644 --- a/gwcs/wcs.py +++ b/gwcs/wcs.py @@ -523,8 +523,6 @@ def _call_backward(self, *args, with_bounding_box=True, fill_value=np.nan, **kwa except NotImplementedError: transform = None - with_bounding_box = kwargs.pop('with_bounding_box', True) - fill_value = kwargs.pop('fill_value', np.nan) if with_bounding_box and self.bounding_box is not None: args = self.outside_footprint(args)