Skip to content

Commit

Permalink
further improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
mcara committed Dec 22, 2024
1 parent 456365e commit cd4c0d5
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 105 deletions.
194 changes: 93 additions & 101 deletions src/stcal/alignment/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import functools
import logging
import re
from typing import TYPE_CHECKING
import typing
from typing import TYPE_CHECKING, Tuple
import warnings

if TYPE_CHECKING:
Expand Down Expand Up @@ -143,10 +144,16 @@ def _generate_tranform(
return transform


def _get_axis_min_and_bounding_box(footprints: list[np.ndarray],
ref_wcs: gwcs.wcs.WCS) -> tuple:
def _get_bounding_box_with_offsets(
footprints: list[np.ndarray],
ref_wcs: gwcs.wcs.WCS,
fiducial: Sequence,
crpix: Sequence | None,
shape: Sequence | None
) -> Tuple[tuple, tuple, astmodels.Model]:
"""
Calculates axis minimum values and bounding box.
Calculates the offsets to the transform.
Parameters
----------
Expand All @@ -157,28 +164,84 @@ def _get_axis_min_and_bounding_box(footprints: list[np.ndarray],
ref_wcs : ~gwcs.wcs.WCS
The reference WCS object.
fiducial : tuple
A tuple containing the world coordinates of the fiducial point.
crpix : list or tuple, optional
0-indexed pixel coordinates of the reference pixel.
shape : tuple, optional
Shape (using `numpy.ndarray` convention) of the image array associated
with the ``ref_wcs``.
Returns
-------
tuple
A tuple containing the bounding box region in the format
((x0_lower, x0_upper), (x1_lower, x1_upper), ...).
tuple
A tuple containing two elements:
1 - a :py:class:`np.ndarray` with the minimum value in each axis;
2 - a tuple containing the bounding box region in the format
((x0_lower, x0_upper), (x1_lower, x1_upper)).
Shape of the image. When ``shape`` argument is `None`, shape is
determined from the upper limit of the computed bounding box, otherwise
input value of ``shape`` is returned.
~astropy.modeling.Model
A model with the offsets to be added to the WCS's transform.
"""
domain_bounds = np.hstack([ref_wcs.backward_transform(*f.T) for f in footprints])
axis_min_values = np.min(domain_bounds, axis=1)
domain_bounds = (domain_bounds.T - axis_min_values).T

output_bounding_box = []
for axis in ref_wcs.output_frame.axes_order:
axis_min, axis_max = (
0.0,
domain_bounds[axis].max(),
domain_min = np.min(domain_bounds, axis=1)
domain_max = np.max(domain_bounds, axis=1)

native_crpix = ref_wcs.backward_transform(*fiducial)

if crpix is None:
# shift the coordinates by domain_min so that all input footprints
# will project to positive coordinates in the offsetted ref_wcs
offsets = tuple(ncrp - dmin for ncrp, dmin in zip(native_crpix, domain_min))

elif crpix is None:
raise ValueError(
"If crpix is not provided, fiducial, wcs, and axis_min_values "
"must be provided."
)
# populate output_bounding_box
output_bounding_box.append((axis_min, axis_max))

return (axis_min_values, output_bounding_box)
else:
# assume 0-based CRPIX and require that fiducial would map to the user
# defined crpix value:
offsets = tuple(
crp - ncrp for ncrp, crp in zip(native_crpix, crpix)
)

# Also offset domain limits:
domain_min += offsets
domain_max += offsets

if shape is None:
shape = tuple(int(dmax + 0.5) for dmax in domain_max[::-1])
bounding_box = tuple(
(-0.5, s - 0.5) for s in shape[::-1]
)
# code to reproduce old results (for old unit test test_wcs_from_footprints)
# bounding_box = tuple(
# (0.0, dmax) for dmax in domain_max
# )
# shape = tuple(int(dmax + 0.5) for dmax in domain_max[::-1])

else:
# trim upper bounding box limits
bounding_box = tuple(
(max(0, int(dmin + 0.5)) - 0.5, min(int(dmax + 0.5), sz) - 0.5)
for dmin, dmax, sz in zip(domain_min, domain_max, shape[::-1])
)

model = astmodels.Shift(-offsets[0], name="crpix1")
for k, shift in enumerate(offsets[1:]):
model = model.__and__(astmodels.Shift(-shift, name=f"crpix{k + 2:d}"))

return bounding_box, shape, model


def _calculate_fiducial(footprints: list[np.ndarray],
Expand Down Expand Up @@ -208,53 +271,6 @@ def _calculate_fiducial(footprints: list[np.ndarray],
return _compute_fiducial_from_footprints(footprints)


def _calculate_offsets(fiducial: tuple,
wcs: gwcs.wcs.WCS | None,
axis_min_values: np.ndarray | None,
crpix: Sequence | None) -> astmodels.Model:
"""
Calculates the offsets to the transform.
Parameters
----------
fiducial : tuple
A tuple containing the world coordinates of the fiducial point.
wcs : ~gwcs.wcs.WCS
A WCS object. It will be used to determine the
axis_min_values : np.ndarray
A two-elements array containing the minimum pixel value for each axis.
crpix : list or tuple
0-indexed pixel coordinates of the reference pixel.
Returns
-------
~astropy.modeling.Model
A model with the offsets to be added to the WCS's transform.
Notes
-----
If ``crpix=None``, then ``fiducial``, ``wcs``, and ``axis_min_values`` must be
provided, in which case, the offsets will be calculated using the WCS object to
find the pixel coordinates of the fiducial point and then correct it by the minimum
pixel value for each axis.
"""
if crpix is None and fiducial is not None and wcs is not None and axis_min_values is not None:
offset1, offset2 = wcs.backward_transform(*fiducial)
offset1 -= axis_min_values[0]
offset2 -= axis_min_values[1]
elif crpix is None:
msg = "If crpix is not provided, fiducial, wcs, and axis_min_values must be provided."
raise ValueError(msg)
else:
# assume 0-based CRPIX
offset1, offset2 = crpix

return astmodels.Shift(-offset1, name="crpix1") & astmodels.Shift(-offset2, name="crpix2")


def _calculate_new_wcs(wcs: gwcs.wcs.WCS,
shape: Sequence | None,
footprints: list[np.ndarray],
Expand Down Expand Up @@ -303,45 +319,23 @@ def _calculate_new_wcs(wcs: gwcs.wcs.WCS,
transform=transform,
input_frame=wcs.input_frame,
)
axis_min_values, bbox = _get_axis_min_and_bounding_box(footprints, wcs_new)
offsets = _calculate_offsets(

bounding_box, shape, offsets = _get_bounding_box_with_offsets(
footprints,
ref_wcs=wcs_new,
fiducial=fiducial,
wcs=wcs_new,
axis_min_values=axis_min_values,
crpix=crpix,
shape=shape
)

if crpix is None:
output_bounding_box = bbox
else:
output_bounding_box = []
for axis_range, minval, shift in zip(bbox, axis_min_values, crpix):
output_bounding_box.append(
(
axis_range[0] + shift + minval,
axis_range[1] + shift + minval
)
)
if any(d < 2 for d in shape):
raise ValueError(
"Computed shape for the output image using provided "
"WCS parameters is too small."
)

wcs_new.insert_transform("detector", offsets, after=True)
wcs_new.bounding_box = output_bounding_box

if shape is None:
shape = []
for k, axs in enumerate(output_bounding_box[::-1]):
upper = int(axs[1] + 0.5)
if upper < 1:
log.warning(
"Input images do not overlap with created WCS. "
"Consider adjusting crval and/or crpix values."
)
log.warning(
"Setting minimum array dimension for axis %d to 10."
% (len(output_bounding_box) - k)
)
upper = 10
shape.append(upper)

wcs_new.bounding_box = bounding_box
wcs_new.pixel_shape = shape[::-1]
wcs_new.array_shape = shape
return wcs_new
Expand Down Expand Up @@ -852,9 +846,7 @@ def compute_s_region_imaging(wcs: gwcs.wcs.WCS,
footprint = footprint[:2, :]

# Make sure RA values are all positive
negative_ind = footprint[0] < 0
if negative_ind.any():
footprint[0][negative_ind] = 360 + footprint[0][negative_ind]
np.mod(footprint[0], 360, out=footprint[0])

footprint = footprint.T
return compute_s_region_keyword(footprint)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,10 @@ def test_wcs_from_footprints(s_regions):
# check that all elements of footprint match the *vertices* of the new
# combined WCS
footprnt = wcs.footprint()
assert all(np.isclose(footprnt[0], wcs(0, 0)))
assert all(np.isclose(footprnt[1], wcs(0, 4)))
assert all(np.isclose(footprnt[2], wcs(4, 4)))
assert all(np.isclose(footprnt[3], wcs(4, 0)))
assert all(np.isclose(footprnt[0], wcs(-0.5, -0.5)))
assert all(np.isclose(footprnt[1], wcs(-0.5, 3.5)))
assert all(np.isclose(footprnt[2], wcs(3.5, 3.5)))
assert all(np.isclose(footprnt[3], wcs(3.5, -0.5)))

# check that fiducials match their expected coords in the new combined WCS
assert all(np.isclose(wcs_1(0, 0), wcs(2.5, 1.5)))
Expand Down

0 comments on commit cd4c0d5

Please sign in to comment.