From 4cc0ac166977a063098df47ea9fb35bed2dfbaea Mon Sep 17 00:00:00 2001 From: Zhen-Kai Gao Date: Tue, 20 Feb 2024 23:18:36 +0800 Subject: [PATCH] Add `SlicedLowLevelWCS` support to `reproject` and fix a bug (#8172) Co-authored-by: James Davies Co-authored-by: Howard Bushouse --- CHANGES.rst | 4 + jwst/resample/resample_utils.py | 31 ++------ jwst/resample/tests/test_resample_step.py | 5 +- jwst/resample/tests/test_utils.py | 90 ++++++++++++++++++++++- 4 files changed, 103 insertions(+), 27 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index f48ffac22f..5e3559690d 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -124,6 +124,10 @@ resample - Use the same ``iscale`` value for resampling science data and variance arrays. [#8159] +- Changed to use the high-level APE 14 API (``pixel_to_world_values`` and + ``world_to_pixel_values``) for reproject, which also fixed a bug, and + removed support for astropy model [#8172] + residual_fringe --------------- diff --git a/jwst/resample/resample_utils.py b/jwst/resample/resample_utils.py index 030fdc8c87..68e3287600 100644 --- a/jwst/resample/resample_utils.py +++ b/jwst/resample/resample_utils.py @@ -3,8 +3,6 @@ import warnings import numpy as np -from astropy import wcs as fitswcs -from astropy.modeling import Model from astropy import units as u import gwcs @@ -134,8 +132,9 @@ def reproject(wcs1, wcs2): Parameters ---------- - wcs1, wcs2 : `~astropy.wcs.WCS` or `~gwcs.wcs.WCS` or `~astropy.modeling.Model` - WCS objects. + wcs1, wcs2 : `~astropy.wcs.WCS` or `~gwcs.wcs.WCS` + WCS objects that have `pixel_to_world_values` and `world_to_pixel_values` + methods. Returns ------- @@ -144,25 +143,11 @@ def reproject(wcs1, wcs2): positions in ``wcs1`` and returns x, y positions in ``wcs2``. """ - if isinstance(wcs1, fitswcs.WCS): - forward_transform = wcs1.all_pix2world - elif isinstance(wcs1, gwcs.WCS): - forward_transform = wcs1.forward_transform - elif issubclass(wcs1, Model): - forward_transform = wcs1 - else: - raise TypeError("Expected input to be astropy.wcs.WCS or gwcs.WCS " - "object or astropy.modeling.Model subclass") - - if isinstance(wcs2, fitswcs.WCS): - backward_transform = wcs2.all_world2pix - elif isinstance(wcs2, gwcs.WCS): - backward_transform = wcs2.backward_transform - elif issubclass(wcs2, Model): - backward_transform = wcs2.inverse - else: - raise TypeError("Expected input to be astropy.wcs.WCS or gwcs.WCS " - "object or astropy.modeling.Model subclass") + try: + forward_transform = wcs1.pixel_to_world_values + backward_transform = wcs2.world_to_pixel_values + except AttributeError as err: + raise TypeError("Input should be a WCS") from err def _reproject(x, y): sky = forward_transform(x, y) diff --git a/jwst/resample/tests/test_resample_step.py b/jwst/resample/tests/test_resample_step.py index 3c07ff8b63..ad018e1120 100644 --- a/jwst/resample/tests/test_resample_step.py +++ b/jwst/resample/tests/test_resample_step.py @@ -476,9 +476,10 @@ def test_resample_undefined_variance(nircam_rate, shape): im.var_poisson = np.ones(shape, dtype=im.var_poisson.dtype.type) im.var_flat = np.ones(shape, dtype=im.var_flat.dtype.type) im.meta.filename = "foo.fits" - c = ModelContainer([im]) - ResampleStep.call(c, blendheaders=False) + + with pytest.warns(RuntimeWarning, match="var_rnoise array not available"): + ResampleStep.call(c, blendheaders=False) @pytest.mark.parametrize('ratio', [0.7, 1.2]) diff --git a/jwst/resample/tests/test_utils.py b/jwst/resample/tests/test_utils.py index ed5b846f59..3ef18cbc6b 100644 --- a/jwst/resample/tests/test_utils.py +++ b/jwst/resample/tests/test_utils.py @@ -1,5 +1,10 @@ """Test various utility functions""" -from numpy.testing import assert_array_equal +from astropy import coordinates as coord +from astropy import wcs as fitswcs +from astropy.modeling import models as astmodels +from gwcs import coordinate_frames as cf +from gwcs.wcstools import wcs_from_fiducial +from numpy.testing import assert_allclose, assert_array_equal import numpy as np import pytest @@ -9,7 +14,8 @@ from jwst.resample.resample_utils import ( build_mask, build_driz_weight, - decode_context + decode_context, + reproject ) @@ -25,6 +31,59 @@ JWST_NAMES_INV = '~' + JWST_NAMES +@pytest.fixture(scope='module') +def wcs_gwcs(): + crval = (150.0, 2.0) + crpix = (500.0, 500.0) + shape = (1000, 1000) + pscale = 0.06 / 3600 + + prj = astmodels.Pix2Sky_TAN() + fiducial = np.array(crval) + + pc = np.array([[-1., 0.], [0., 1.]]) + pc_matrix = astmodels.AffineTransformation2D(pc, name='pc_rotation_matrix') + scale = astmodels.Scale(pscale, name='cdelt1') & astmodels.Scale(pscale, name='cdelt2') + transform = pc_matrix | scale + + out_frame = cf.CelestialFrame(name='world', axes_names=('lon', 'lat'), reference_frame=coord.ICRS()) + input_frame = cf.Frame2D(name="detector") + wnew = wcs_from_fiducial(fiducial, coordinate_frame=out_frame, projection=prj, + transform=transform, input_frame=input_frame) + + output_bounding_box = ((0.0, float(shape[1])), (0.0, float(shape[0]))) + offset1, offset2 = crpix + offsets = astmodels.Shift(-offset1, name='crpix1') & astmodels.Shift(-offset2, name='crpix2') + + wnew.insert_transform('detector', offsets, after=True) + wnew.bounding_box = output_bounding_box + + tr = wnew.pipeline[0].transform + pix_area = ( + np.deg2rad(tr['cdelt1'].factor.value) * + np.deg2rad(tr['cdelt2'].factor.value) + ) + + wnew.pixel_area = pix_area + wnew.pixel_shape = shape[::-1] + wnew.array_shape = shape + return wnew + + +@pytest.fixture(scope='module') +def wcs_fitswcs(wcs_gwcs): + fits_wcs = fitswcs.WCS(wcs_gwcs.to_fits_sip()) + return fits_wcs + + +@pytest.fixture(scope='module') +def wcs_slicedwcs(wcs_gwcs): + xmin, xmax = 100, 500 + slices = (slice(xmin, xmax), slice(xmin, xmax)) + sliced_wcs = fitswcs.wcsapi.SlicedLowLevelWCS(wcs_gwcs, slices) + return sliced_wcs + + @pytest.mark.parametrize( 'dq, bitvalues, expected', [ (DQ, 0, np.array([1, 0, 0, 0, 0, 0, 0, 0, 0])), @@ -116,3 +175,30 @@ def test_decode_context(): assert sorted(idx1) == [9, 12, 14, 19, 21, 25, 37, 40, 46, 58, 64, 65, 67, 77] assert sorted(idx2) == [9, 20, 29, 36, 47, 49, 64, 69, 70, 79] + + +@pytest.mark.parametrize( + "wcs1, wcs2, offset", + [ + ("wcs_gwcs", "wcs_fitswcs", 0), + ("wcs_fitswcs", "wcs_gwcs", 0), + ("wcs_gwcs", "wcs_slicedwcs", 100), + ("wcs_slicedwcs", "wcs_gwcs", -100), + ("wcs_fitswcs", "wcs_slicedwcs", 100), + ("wcs_slicedwcs", "wcs_fitswcs", -100), + ] +) +def test_reproject(wcs1, wcs2, offset, request): + wcs1 = request.getfixturevalue(wcs1) + wcs2 = request.getfixturevalue(wcs2) + x = np.arange(150, 200) + + f = reproject(wcs1, wcs2) + res = f(x, x) + assert_allclose(x, res[0] + offset) + assert_allclose(x, res[1] + offset) + + +def test_reproject_with_garbage_input(): + with pytest.raises(TypeError): + reproject("foo", "bar")