From a4e86c04a08e359945f833ab909ce1434d6c32d4 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Tue, 2 Jul 2024 16:41:57 -0400 Subject: [PATCH] remove dependence on ModelContainer, operate on one model at a time --- src/stcal/alignment/resample_utils.py | 170 +++++---- src/stcal/alignment/util.py | 2 - src/stcal/tweakreg/astrometric_utils.py | 44 +-- src/stcal/tweakreg/tweakreg.py | 471 ++++++++++++------------ tests/test_tweakreg.py | 99 ++--- 5 files changed, 360 insertions(+), 426 deletions(-) diff --git a/src/stcal/alignment/resample_utils.py b/src/stcal/alignment/resample_utils.py index 2d1c654a..80ef0cc4 100644 --- a/src/stcal/alignment/resample_utils.py +++ b/src/stcal/alignment/resample_utils.py @@ -42,89 +42,87 @@ def calc_pixmap(in_wcs, out_wcs, shape=None): return np.dstack(transform_function(grid[0], grid[1])) -# is this allowed in stcal, since it operates on a datamodel? -# seems ok. jump step for example does use models -def make_output_wcs(input_models, ref_wcs=None, - pscale_ratio=None, pscale=None, rotation=None, shape=None, - crpix=None, crval=None): - """Generate output WCS here based on footprints of all input WCS objects. - - Parameters - ---------- - input_models : list of `DataModel objects` - Each datamodel must have a ~gwcs.WCS object. - - pscale_ratio : float, optional - Ratio of input to output pixel scale. Ignored when ``pscale`` - is provided. - - pscale : float, None, optional - Absolute pixel scale in degrees. When provided, overrides - ``pscale_ratio``. - - rotation : float, None, optional - Position angle of output image Y-axis relative to North. - A value of 0.0 would orient the final output image to be North up. - The default of `None` specifies that the images will not be rotated, - but will instead be resampled in the default orientation for the camera - with the x and y axes of the resampled image corresponding - approximately to the detector axes. - - shape : tuple of int, None, optional - Shape of the image (data array) using ``numpy.ndarray`` convention - (``ny`` first and ``nx`` second). This value will be assigned to - ``pixel_shape`` and ``array_shape`` properties of the returned - WCS object. - - crpix : tuple of float, None, optional - Position of the reference pixel in the image array. If ``crpix`` is not - specified, it will be set to the center of the bounding box of the - returned WCS object. - - crval : tuple of float, None, optional - Right ascension and declination of the reference pixel. Automatically - computed if not provided. - - Returns - ------- - output_wcs : object - WCS object, with defined domain, covering entire set of input frames - """ - if ref_wcs is None: - wcslist = [i.meta.wcs for i in input_models] - for w, i in zip(wcslist, input_models): - if w.bounding_box is None: - w.bounding_box = util.wcs_bbox_from_shape(i.data.shape) - naxes = wcslist[0].output_frame.naxes - - if naxes != 2: - msg = f"Output WCS needs 2 spatial axes \ - but the supplied WCS has {naxes} axes." - raise RuntimeError(msg) - - output_wcs = util.wcs_from_footprints( - input_models, - pscale_ratio=pscale_ratio, - pscale=pscale, - rotation=rotation, - shape=shape, - crpix=crpix, - crval=crval - ) - - else: - naxes = ref_wcs.output_frame.naxes - if naxes != 2: - msg = f"Output WCS needs 2 spatial axes \ - but the supplied WCS has {naxes} axes." - raise RuntimeError(msg) - output_wcs = deepcopy(ref_wcs) - if shape is not None: - output_wcs.array_shape = shape - - # Check that the output data shape has no zero length dimensions - if not np.prod(output_wcs.array_shape): - msg = f"Invalid output frame shape: {tuple(output_wcs.array_shape)}" - raise ValueError(msg) - - return output_wcs +# def make_output_wcs(input_models, ref_wcs=None, +# pscale_ratio=None, pscale=None, rotation=None, shape=None, +# crpix=None, crval=None): +# """Generate output WCS here based on footprints of all input WCS objects. + +# Parameters +# ---------- +# input_models : list of `DataModel objects` +# Each datamodel must have a ~gwcs.WCS object. + +# pscale_ratio : float, optional +# Ratio of input to output pixel scale. Ignored when ``pscale`` +# is provided. + +# pscale : float, None, optional +# Absolute pixel scale in degrees. When provided, overrides +# ``pscale_ratio``. + +# rotation : float, None, optional +# Position angle of output image Y-axis relative to North. +# A value of 0.0 would orient the final output image to be North up. +# The default of `None` specifies that the images will not be rotated, +# but will instead be resampled in the default orientation for the camera +# with the x and y axes of the resampled image corresponding +# approximately to the detector axes. + +# shape : tuple of int, None, optional +# Shape of the image (data array) using ``numpy.ndarray`` convention +# (``ny`` first and ``nx`` second). This value will be assigned to +# ``pixel_shape`` and ``array_shape`` properties of the returned +# WCS object. + +# crpix : tuple of float, None, optional +# Position of the reference pixel in the image array. If ``crpix`` is not +# specified, it will be set to the center of the bounding box of the +# returned WCS object. + +# crval : tuple of float, None, optional +# Right ascension and declination of the reference pixel. Automatically +# computed if not provided. + +# Returns +# ------- +# output_wcs : object +# WCS object, with defined domain, covering entire set of input frames +# """ +# if ref_wcs is None: +# wcslist = [i.meta.wcs for i in input_models] +# for w, i in zip(wcslist, input_models): +# if w.bounding_box is None: +# w.bounding_box = util.wcs_bbox_from_shape(i.data.shape) +# naxes = wcslist[0].output_frame.naxes + +# if naxes != 2: +# msg = f"Output WCS needs 2 spatial axes \ +# but the supplied WCS has {naxes} axes." +# raise RuntimeError(msg) + +# output_wcs = util.wcs_from_footprints( +# input_models, +# pscale_ratio=pscale_ratio, +# pscale=pscale, +# rotation=rotation, +# shape=shape, +# crpix=crpix, +# crval=crval +# ) + +# else: +# naxes = ref_wcs.output_frame.naxes +# if naxes != 2: +# msg = f"Output WCS needs 2 spatial axes \ +# but the supplied WCS has {naxes} axes." +# raise RuntimeError(msg) +# output_wcs = deepcopy(ref_wcs) +# if shape is not None: +# output_wcs.array_shape = shape + +# # Check that the output data shape has no zero length dimensions +# if not np.prod(output_wcs.array_shape): +# msg = f"Invalid output frame shape: {tuple(output_wcs.array_shape)}" +# raise ValueError(msg) + +# return output_wcs diff --git a/src/stcal/alignment/util.py b/src/stcal/alignment/util.py index 27fbc6ce..00121ec9 100644 --- a/src/stcal/alignment/util.py +++ b/src/stcal/alignment/util.py @@ -584,7 +584,6 @@ def update_fits_wcsinfo(datamodel, max_pix_error=0.01, degree=None, Parameters ---------- - datamodel : `ImageModel` The input data model for imaging or WFSS mode whose ``meta.wcsinfo`` field should be updated from GWCS. By default, ``datamodel.meta.wcs`` @@ -687,7 +686,6 @@ def update_fits_wcsinfo(datamodel, max_pix_error=0.01, degree=None, Notes ----- - Use of this requires a judicious choice of required accuracies. Attempts to use higher degrees (~7 or higher) will typically fail due to floating point problems that arise with high powers. diff --git a/src/stcal/tweakreg/astrometric_utils.py b/src/stcal/tweakreg/astrometric_utils.py index eb8d772f..45125060 100644 --- a/src/stcal/tweakreg/astrometric_utils.py +++ b/src/stcal/tweakreg/astrometric_utils.py @@ -5,9 +5,8 @@ from astropy import units as u from astropy.coordinates import SkyCoord from astropy.table import Table -from astropy.time import Time -from stcal.alignment import compute_fiducial, resample_utils +from stcal.alignment import compute_fiducial ASTROMETRIC_CAT_ENVVAR = "ASTROMETRIC_CATALOG_URL" DEF_CAT_URL = "http://gsss.stsci.edu/webservices" @@ -31,20 +30,27 @@ def create_astrometric_catalog( - input_models, + wcs, + epoch, catalog="GAIADR3", output="ref_cat.ecsv", gaia_only=False, table_format="ascii.ecsv", - existing_wcs=None, - num_sources=None, - epoch=None): + num_sources=None): """Create an astrometric catalog that covers the inputs' field-of-view. Parameters ---------- - input_models : list of `~jwst.datamodel.JwstDataModel` - Each datamodel must have a ~gwcs.WCS object. + wcs : `~astropy.wcs.WCS` + WCS object specified by the user as generated by + `resample.resample_utils.make_output_wcs`. This will typically + have the same plate-scale and orientation as the first member in the + list of input images to make_output_wcs. Fortunately, for alignment, + this doesn't matter since no resampling of data will be performed. + + epoch : float + Reference epoch used to update the coordinates for proper motion + (in decimal year). catalog : str, optional Name of catalog to extract astrometric positions for sources in the @@ -58,19 +64,12 @@ def create_astrometric_catalog( gaia_only : bool, optional Specify whether or not to only use sources from GAIA in output catalog - existing_wcs : model - existing WCS object specified by the user as generated by - `resample.resample_utils.make_output_wcs` - num_sources : int Maximum number of brightest/faintest sources to return in catalog. If `num_sources` is negative, return that number of the faintest sources. By default, all sources are returned. - epoch : float, optional - Reference epoch used to update the coordinates for proper motion - (in decimal year). If `None` then the epoch is obtained from - the metadata. + Notes ----- @@ -83,20 +82,9 @@ def create_astrometric_catalog( Astropy Table object of the catalog """ # start by creating a composite field-of-view for all inputs - # This default output WCS will have the same plate-scale and orientation - # as the first member in the list. - # Fortunately, for alignment, this doesn't matter since no resampling of - # data will be performed. - outwcs = existing_wcs if existing_wcs is not None \ - else resample_utils.make_output_wcs(input_models) - radius, fiducial = compute_radius(outwcs) + radius, fiducial = compute_radius(wcs) # perform query for this field-of-view - epoch = ( - epoch - if epoch is not None - else Time(input_models[0].meta.observation.date).decimalyear - ) ref_dict = get_catalog( fiducial[0], fiducial[1], diff --git a/src/stcal/tweakreg/tweakreg.py b/src/stcal/tweakreg/tweakreg.py index adf270d7..f8227f14 100644 --- a/src/stcal/tweakreg/tweakreg.py +++ b/src/stcal/tweakreg/tweakreg.py @@ -8,7 +8,6 @@ from astropy.coordinates import SkyCoord from astropy.table import Table from astropy.time import Time -from jwst.datamodels import ModelContainer from stdatamodels import DataModel from tweakwcs.correctors import JWSTWCSCorrector from tweakwcs.imalign import align_wcs @@ -40,130 +39,80 @@ class TweakregError(BaseException): pass -def tweakreg(images: ModelContainer, catalogs: list[Table], - ref_cat: Table = None, - abs_refcat: str | None = None, - save_abs_catalog: bool = False, - abs_catalog_output_dir: str | None = None, - searchrad: float = 2.0, - abs_searchrad: float = 6.0, - separation: float = 1.0, - abs_separation: float = 1.0, - use2dhist: bool = True, - abs_use2dhist: bool = True, - tolerance: float = 0.7, - abs_tolerance: float = 0.7, - xoffset: float = 0.0, - yoffset: float = 0.0, - enforce_user_order: bool = False, - expand_refcat: bool = False, - minobj: int = 15, - abs_minobj: int = 15, - fitgeometry: str = "rshift", - abs_fitgeometry: str = "rshift", - nclip: int = 3, - abs_nclip: int = 3, - abs_sigma: float = 3.0, - sigma: float = 3.0, - sip_approx: bool = True, - sip_max_pix_error: float = 0.01, - sip_degree: int | None = None, - sip_max_inv_pix_error: float = 0.01, - sip_inv_degree: int | None = None, - sip_npoints: int = 12,) -> ModelContainer: - """ - whatever. +def relative_align(correctors: list[JWSTWCSCorrector], + searchrad: float = 2.0, + separation: float = 1.0, + use2dhist: bool = True, + tolerance: float = 0.7, + xoffset: float = 0.0, + yoffset: float = 0.0, + enforce_user_order: bool = False, + expand_refcat: bool = False, + minobj: int = 15, + fitgeometry: str = "rshift", + nclip: int = 3, + sigma: float = 3.0, + align_to_abs_refcat: bool = False, + ) -> tuple[list[JWSTWCSCorrector], bool]: - Parameters - ---------- - ref_cat: only required when align_to_abs_refcat is True - """ - # perform some input validations if separation <= _SQRT2 * tolerance: msg = "Parameter 'separation' must be larger than 'tolerance' by at \ least a factor of sqrt(2) to avoid source confusion." raise TweakregError(msg) - if abs_separation <= _SQRT2 * abs_tolerance: - msg = "Parameter 'abs_separation' must be larger than 'abs_tolerance' by at \ - least a factor of sqrt(2) to avoid source confusion." - raise TweakregError(msg) + # align images: + xyxymatch = XYXYMatch( + searchrad=searchrad, + separation=separation, + use2dhist=use2dhist, + tolerance=tolerance, + xoffset=xoffset, + yoffset=yoffset + ) - if len(images) == 0: - msg = "Input must contain at least one image model." - raise ValueError(msg) - - if abs_refcat is None: - align_to_abs_refcat = False - - n_groups = len(images.group_names) - - # pre-allocate correctors (same length and order as images) - correctors = [None] * len(images) - for (model_index, image_model) in enumerate(images): - catalog = _filter_catalog_by_bounding_box( - catalogs[model_index], - image_model.meta.wcs.bounding_box - ) - corrector = _construct_wcs_corrector(image_model, catalog) - correctors[model_index] = corrector - - # relative alignment of images to each other (if more than one group) - if n_groups > 1: - - # align images: - xyxymatch = XYXYMatch( - searchrad=searchrad, - separation=separation, - use2dhist=use2dhist, - tolerance=tolerance, - xoffset=xoffset, - yoffset=yoffset + try: + align_wcs( + correctors, + refcat=None, + enforce_user_order=enforce_user_order, + expand_refcat=expand_refcat, + minobj=minobj, + match=xyxymatch, + fitgeom=fitgeometry, + nclip=nclip, + sigma=(sigma, "rmse") ) - - try: - align_wcs( - correctors, - refcat=None, - enforce_user_order=enforce_user_order, - expand_refcat=expand_refcat, - minobj=minobj, - match=xyxymatch, - fitgeom=fitgeometry, - nclip=nclip, - sigma=(sigma, "rmse") - ) - local_align_failed = False - - except ValueError as e: - msg = e.args[0] - if (msg == "Too few input images (or groups of images) with " - "non-empty catalogs."): - local_align_failed = True - if not align_to_abs_refcat: - msg += "At least two exposures are required for image alignment. Nothing to do." - raise TweakregError(msg) from None - else: - raise - - except RuntimeError as e: - msg = e.args[0] - if msg.startswith("Number of output coordinates exceeded allocation"): - # we need at least two exposures to perform image alignment - msg += "Multiple sources within specified tolerance \ - matched to a single reference source. Try to \ - adjust 'tolerance' and/or 'separation' parameters." + local_align_failed = False + + except ValueError as e: + msg = e.args[0] + if (msg == "Too few input images (or groups of images) with " + "non-empty catalogs."): + local_align_failed = True + if not align_to_abs_refcat: + msg += "At least two exposures are required for image alignment. Nothing to do." raise TweakregError(msg) from None + else: raise - with warnings.catch_warnings(record=True) as w: - is_small = _is_wcs_correction_small(correctors, - use2dhist, - searchrad, - tolerance, - xoffset, - yoffset) - warning_msg = "".join([str(mess.message) for mess in w]) + except RuntimeError as e: + msg = e.args[0] + if msg.startswith("Number of output coordinates exceeded allocation"): + # we need at least two exposures to perform image alignment + msg += "Multiple sources within specified tolerance \ + matched to a single reference source. Try to \ + adjust 'tolerance' and/or 'separation' parameters." + raise TweakregError(msg) from None + raise + + with warnings.catch_warnings(record=True) as w: + is_small = _is_wcs_correction_small(correctors, + use2dhist, + searchrad, + tolerance, + xoffset, + yoffset) + warning_msg = "".join([str(mess.message) for mess in w]) if not local_align_failed and not is_small: if align_to_abs_refcat: warning_msg += " Skipping relative alignment (stage 1)..." @@ -171,140 +120,170 @@ def tweakreg(images: ModelContainer, catalogs: list[Table], else: raise TweakregError(warning_msg) - # absolute alignment to the reference catalog - # can (and does) occur after alignment between groups - if align_to_abs_refcat: - - ref_cat = _parse_refcat(abs_refcat, - images, - correctors, - save_abs_catalog=save_abs_catalog, - output_dir=abs_catalog_output_dir) - - # Check that there are enough GAIA sources for a reliable/valid fit - num_ref = len(ref_cat) - if num_ref < abs_minobj: - msg = f"Not enough sources ({num_ref}) in the reference catalog \ - for the single-group alignment step to perform a fit. \ - Skipping alignment to the input reference catalog!" - raise TweakregError(msg) - - # align images: - # Update to separation needed to prevent confusion of sources - # from overlapping images where centering is not consistent or - # for the possibility that errors still exist in relative overlap. - xyxymatch_gaia = XYXYMatch( - searchrad=abs_searchrad, - separation=abs_separation, - use2dhist=abs_use2dhist, - tolerance=abs_tolerance, - xoffset=0.0, - yoffset=0.0 - ) + return correctors, local_align_failed + + +def absolute_align(correctors: list[JWSTWCSCorrector], + abs_refcat: str, + ref_image: DataModel, + save_abs_catalog: bool = False, + abs_catalog_output_dir: str | None = None, + abs_searchrad: float = 6.0, + abs_separation: float = 1.0, + abs_use2dhist: bool = True, + abs_tolerance: float = 0.7, + abs_minobj: int = 15, + abs_fitgeometry: str = "rshift", + abs_nclip: int = 3, + abs_sigma: float = 3.0, + n_groups: int = 1, + local_align_failed: bool = False,) -> list[JWSTWCSCorrector]: - # Set group_id to same value so all get fit as one observation - # The assigned value, 987654, has been hard-coded to make it - # easy to recognize when alignment to GAIA was being performed - # as opposed to the group_id values used for relative alignment - # earlier in this step. - for corrector in correctors: - corrector.meta["group_id"] = 987654 - if ("fit_info" in corrector.meta and - "REFERENCE" in corrector.meta["fit_info"]["status"]): - del corrector.meta["fit_info"] - - # Perform fit - try: - align_wcs( - correctors, - refcat=ref_cat, - enforce_user_order=True, - expand_refcat=False, - minobj=abs_minobj, - match=xyxymatch_gaia, - fitgeom=abs_fitgeometry, - nclip=abs_nclip, - sigma=(abs_sigma, "rmse") - ) - except ValueError as e: - msg = e.args[0] - if (msg == "Too few input images (or groups of images) with " - "non-empty catalogs."): - msg += "At least one exposure is required to align images \ - to an absolute reference catalog. Alignment to an \ - absolute reference catalog will not be performed." - if local_align_failed or n_groups == 1: - msg += " Nothing to do. Skipping 'TweakRegStep'..." - raise TweakregError(msg) from None - warnings.warn(msg) - else: - raise e - - except RuntimeError as e: - msg = e.args[0] - if msg.startswith("Number of output coordinates exceeded allocation"): - # we need at least two exposures to perform image alignment - msg += "Multiple sources within specified tolerance \ - matched to a single reference source. Try to \ - adjust 'tolerance' and/or 'separation' parameters. \ - Alignment to an absolute reference catalog will \ - not be performed." - if local_align_failed or n_groups == 1: - msg += "Skipping 'TweakRegStep'..." - raise TweakregError(msg) from None - else: - warnings.warn(msg) - else: - raise e + if abs_separation <= _SQRT2 * abs_tolerance: + msg = "Parameter 'abs_separation' must be larger than 'abs_tolerance' by at \ + least a factor of sqrt(2) to avoid source confusion." + raise TweakregError(msg) + + ref_cat = _parse_refcat(abs_refcat, + ref_image, + correctors, + save_abs_catalog=save_abs_catalog, + output_dir=abs_catalog_output_dir) + + # Check that there are enough GAIA sources for a reliable/valid fit + num_ref = len(ref_cat) + if num_ref < abs_minobj: + msg = f"Not enough sources ({num_ref}) in the reference catalog \ + for the single-group alignment step to perform a fit. \ + Skipping alignment to the input reference catalog!" + raise TweakregError(msg) - # one final pass through all the models to update them based - # on the results of this step - for (image_model, corrector) in zip(images, correctors): - image_model.meta.cal_step.tweakreg = "COMPLETE" + # align images: + # Update to separation needed to prevent confusion of sources + # from overlapping images where centering is not consistent or + # for the possibility that errors still exist in relative overlap. + xyxymatch_gaia = XYXYMatch( + searchrad=abs_searchrad, + separation=abs_separation, + use2dhist=abs_use2dhist, + tolerance=abs_tolerance, + xoffset=0.0, + yoffset=0.0 + ) - # retrieve fit status and update wcs if fit is successful: + # Set group_id to same value so all get fit as one observation + # The assigned value, 987654, has been hard-coded to make it + # easy to recognize when alignment to GAIA was being performed + # as opposed to the group_id values used for relative alignment + # earlier in this step. + for corrector in correctors: + corrector.meta["group_id"] = 987654 if ("fit_info" in corrector.meta and - "SUCCESS" in corrector.meta["fit_info"]["status"]): + "REFERENCE" in corrector.meta["fit_info"]["status"]): + del corrector.meta["fit_info"] + + # Perform fit + try: + align_wcs( + correctors, + refcat=ref_cat, + enforce_user_order=True, + expand_refcat=False, + minobj=abs_minobj, + match=xyxymatch_gaia, + fitgeom=abs_fitgeometry, + nclip=abs_nclip, + sigma=(abs_sigma, "rmse") + ) + except ValueError as e: + msg = e.args[0] + if (msg == "Too few input images (or groups of images) with " + "non-empty catalogs."): + msg += "At least one exposure is required to align images \ + to an absolute reference catalog. Alignment to an \ + absolute reference catalog will not be performed." + if local_align_failed or n_groups == 1: + msg += " Nothing to do. Skipping 'TweakRegStep'..." + raise TweakregError(msg) from None + warnings.warn(msg) + else: + raise e + + except RuntimeError as e: + msg = e.args[0] + if msg.startswith("Number of output coordinates exceeded allocation"): + # we need at least two exposures to perform image alignment + msg += "Multiple sources within specified tolerance \ + matched to a single reference source. Try to \ + adjust 'tolerance' and/or 'separation' parameters. \ + Alignment to an absolute reference catalog will \ + not be performed." + if local_align_failed or n_groups == 1: + msg += "Skipping 'TweakRegStep'..." + raise TweakregError(msg) from None + else: + warnings.warn(msg) + else: + raise e + + return correctors + + +def apply_tweakreg_solution(image_model: DataModel, + corrector: JWSTWCSCorrector, + abs_refcat: str, + align_to_abs_refcat: bool = False, + sip_approx: bool = True, + sip_max_pix_error: float = 0.01, + sip_degree: int | None = None, + sip_max_inv_pix_error: float = 0.01, + sip_inv_degree: int | None = None, + sip_npoints: int = 12, + ) -> DataModel: + + # retrieve fit status and update wcs if fit is successful: + if ("fit_info" in corrector.meta and + "SUCCESS" in corrector.meta["fit_info"]["status"]): + + # Update/create the WCS .name attribute with information + # on this astrometric fit as the only record that it was + # successful: + if align_to_abs_refcat: + # NOTE: This .name attrib agreed upon by the JWST Cal + # Working Group. + # Current value is merely a place-holder based + # on HST conventions. This value should also be + # translated to the FITS WCSNAME keyword + # IF that is what gets recorded in the archive + # for end-user searches. + corrector.wcs.name = f"FIT-LVL3-{abs_refcat}" + + image_model.meta.wcs = corrector.wcs + update_s_region_imaging(image_model) + + # Also update FITS representation in input exposures for + # subsequent reprocessing by the end-user. + if sip_approx: + try: + update_fits_wcsinfo( + image_model, + max_pix_error=sip_max_pix_error, + degree=sip_degree, + max_inv_pix_error=sip_max_inv_pix_error, + inv_degree=sip_inv_degree, + npoints=sip_npoints, + crpix=None + ) + except (ValueError, RuntimeError) as e: + msg = f"Failed to update 'meta.wcsinfo' with FITS SIP \ + approximation. Reported error is: \n {e.args[0]}" + warnings.warn(msg) - # Update/create the WCS .name attribute with information - # on this astrometric fit as the only record that it was - # successful: - if align_to_abs_refcat: - # NOTE: This .name attrib agreed upon by the JWST Cal - # Working Group. - # Current value is merely a place-holder based - # on HST conventions. This value should also be - # translated to the FITS WCSNAME keyword - # IF that is what gets recorded in the archive - # for end-user searches. - corrector.wcs.name = f"FIT-LVL3-{abs_refcat}" - - image_model.meta.wcs = corrector.wcs - update_s_region_imaging(image_model) - - # Also update FITS representation in input exposures for - # subsequent reprocessing by the end-user. - if sip_approx: - try: - update_fits_wcsinfo( - image_model, - max_pix_error=sip_max_pix_error, - degree=sip_degree, - max_inv_pix_error=sip_max_inv_pix_error, - inv_degree=sip_inv_degree, - npoints=sip_npoints, - crpix=None - ) - except (ValueError, RuntimeError) as e: - msg = f"Failed to update 'meta.wcsinfo' with FITS SIP \ - approximation. Reported error is: \n {e.args[0]}" - warnings.warn(msg) - - return images + return image_model def _parse_refcat(abs_refcat: str, - images: ModelContainer, + ref_model: DataModel, correctors: list, save_abs_catalog: bool = False, output_dir: str | None = None) -> Table: @@ -324,7 +303,6 @@ def _parse_refcat(abs_refcat: str, abs_refcat = abs_refcat.strip() gaia_cat_name = abs_refcat.upper() if gaia_cat_name in SINGLE_GROUP_REFCAT: - ref_model = images[0] epoch = Time(ref_model.meta.observation.date).decimalyear @@ -337,11 +315,9 @@ def _parse_refcat(abs_refcat: str, ) return create_astrometric_catalog( - None, - gaia_cat_name, - existing_wcs=combined_wcs, + combined_wcs, epoch, + catalog=gaia_cat_name, output=output_name, - epoch=epoch, ) if Path.isfile(abs_refcat): @@ -405,6 +381,11 @@ def _construct_wcs_corrector(image_model: DataModel, pre-compute skycoord here so we can later use it to check for a small wcs correction. """ + catalog = _filter_catalog_by_bounding_box( + catalog, + image_model.meta.wcs.bounding_box + ) + wcs = image_model.meta.wcs refang = image_model.meta.wcsinfo.instance return JWSTWCSCorrector( diff --git a/tests/test_tweakreg.py b/tests/test_tweakreg.py index 7d1a1d9c..8747b113 100644 --- a/tests/test_tweakreg.py +++ b/tests/test_tweakreg.py @@ -8,12 +8,12 @@ from astropy.modeling.models import Shift from astropy.wcs import WCS from gwcs.wcstools import grid_from_bounding_box -from jwst.datamodels import ModelContainer from photutils.segmentation import SourceCatalog, SourceFinder from stdatamodels.jwst.datamodels import ImageModel from stcal.tweakreg import astrometric_utils as amutils -from stcal.tweakreg.tweakreg import _is_wcs_correction_small, _parse_refcat, _wcs_to_skycoord, tweakreg +from stcal.tweakreg.tweakreg import _is_wcs_correction_small, _parse_refcat, _wcs_to_skycoord, _construct_wcs_corrector, \ + relative_align, absolute_align, apply_tweakreg_solution from stcal.tweakreg.utils import _wcsinfo_from_wcs_transform # Define input GWCS specification to be used for these tests @@ -26,7 +26,6 @@ # something BKG_LEVEL = 0.001 N_EXAMPLE_SOURCES = 21 -N_CUSTOM_SOURCES = 15 @pytest.fixture(scope="module") @@ -58,11 +57,9 @@ def test_get_catalog(wcsobj): def test_create_catalog(wcsobj): # Create catalog gcat = amutils.create_astrometric_catalog( - None, - existing_wcs=wcsobj, + wcsobj, "2016.0", catalog=TEST_CATALOG, output=None, - epoch="2016.0", ) # check that we got expected number of sources assert len(gcat) == EXPECTED_NUM_SOURCES @@ -77,11 +74,9 @@ def test_create_catalog_graceful_failure(wcsobj): # Create catalog gcat = amutils.create_astrometric_catalog( - None, - existing_wcs=wcsobj, + wcsobj, "2016.0", catalog=TEST_CATALOG, output=None, - epoch="2016.0", ) # check that we got expected number of sources assert len(gcat) == 0 @@ -157,12 +152,8 @@ def example_input(example_wcs): [0.1, 0.6, 0.1], ] m0.meta.observation.date = "2019-01-01T00:00:00" - - m1 = m0.copy() - # give each a unique filename m0.meta.filename = "some_file_0.fits" - m1.meta.filename = "some_file_1.fits" - return ModelContainer([m0, m1]) + return m0 @pytest.mark.usefixtures("_jail") @@ -199,87 +190,65 @@ def make_source_catalog(data): @pytest.mark.parametrize("with_shift", [True, False]) -def test_tweakreg_main(example_input, with_shift): +def test_relative_align(example_input, with_shift): """ A simplified unit test for basic operation of the TweakRegStep when run with or without a small shift in the input image sources """ + shifted = example_input.copy() + shifted.meta.filename = "some_file_1.fits" if with_shift: # shift 9 pixels so that the sources in one of the 2 images # appear at different locations (resulting in a correct wcs update) - example_input[1].data[:-9] = example_input[1].data[9:] - example_input[1].data[-9:] = BKG_LEVEL + shifted.data[:-9] = example_input.data[9:] + shifted.data[-9:] = BKG_LEVEL # assign images to different groups (so they are aligned to each other) - example_input[0].meta.group_id = "a" - example_input[1].meta.group_id = "b" + example_input.meta.group_id = "a" + shifted.meta.group_id = "b" # create source catalogs - source_catalogs = [make_source_catalog(m.data) for m in example_input] + models = [example_input, shifted] + source_catalogs = [make_source_catalog(m.data) for m in models] - # run the step on the example input modified above - result = tweakreg(example_input, source_catalogs) + # construct correctors from the catalogs + correctors = [_construct_wcs_corrector(m, cat) for m, cat in zip(models, source_catalogs)] - # check that step completed - for model in result: - assert model.meta.cal_step.tweakreg == "COMPLETE" + # relative alignment of images to each other (if more than one group) + correctors, local_align_failed = relative_align(correctors) - # and that the wcses differ by a small amount due to the shift above - # by projecting one point through each wcs and comparing the difference - abs_delta = abs(result[1].meta.wcs(0, 0)[0] - result[0].meta.wcs(0, 0)[0]) - if with_shift: - assert abs_delta > 1E-5 - else: - assert abs_delta < 1E-12 + # update the wcs in the models + for (model, corrector) in zip(models, correctors): + apply_tweakreg_solution(model, corrector, TEST_CATALOG, + sip_approx=True, sip_degree=3, sip_max_pix_error=0.1, + sip_max_inv_pix_error=0.1, sip_inv_degree=3, + sip_npoints=12) -@pytest.mark.parametrize("with_shift", [True, False]) -def test_sip_approx(example_input, with_shift): - """ - Test the output FITS WCS. - """ - if with_shift: - # shift 9 pixels so that the sources in one of the 2 images - # appear at different locations (resulting in a correct wcs update) - example_input[1].data[:-9] = example_input[1].data[9:] - example_input[1].data[-9:] = BKG_LEVEL - - # assign images to different groups (so they are aligned to each other) - example_input[0].meta.group_id = "a" - example_input[1].meta.group_id = "b" - - # create source catalogs - source_catalogs = [make_source_catalog(m.data) for m in example_input] - - # run the step on the example input modified above - result = tweakreg(example_input, source_catalogs, - sip_approx=True, sip_degree=3, sip_max_pix_error=0.1, - sip_max_inv_pix_error=0.1, sip_inv_degree=3, - sip_npoints=12) - - # output wcs differs by a small amount due to the shift above: - # project one point through each wcs and compare the difference - abs_delta = abs(result[1].meta.wcs(0, 0)[0] - result[0].meta.wcs(0, 0)[0]) + # and that the wcses differ by a small amount due to the shift above + # by projecting one point through each wcs and comparing the difference + abs_delta = abs(models[1].meta.wcs(0, 0)[0] - models[0].meta.wcs(0, 0)[0]) if with_shift: assert abs_delta > 1E-5 else: assert abs_delta < 1E-12 + # also test SIP approximation keywords # the first wcs is identical to the input and # does not have SIP approximation keywords -- # they are normally set by assign_wcs - assert np.allclose(result[0].meta.wcs(0, 0)[0], example_input[0].meta.wcs(0, 0)[0]) + assert np.allclose(models[0].meta.wcs(0, 0)[0], example_input.meta.wcs(0, 0)[0]) for key in ["ap_order", "bp_order"]: - assert key not in result[0].meta.wcsinfo.instance + assert key not in models[0].meta.wcsinfo.instance # for the second, SIP approximation should be present for key in ["ap_order", "bp_order"]: - assert result[1].meta.wcsinfo.instance[key] == 3 + assert models[1].meta.wcsinfo.instance[key] == 3 # evaluate fits wcs and gwcs for the approximation, make sure they agree - wcs_info = result[1].meta.wcsinfo.instance - grid = grid_from_bounding_box(result[1].meta.wcs.bounding_box) - gwcs_ra, gwcs_dec = result[1].meta.wcs(*grid) + wcs_info = models[1].meta.wcsinfo.instance + grid = grid_from_bounding_box(models[1].meta.wcs.bounding_box) + gwcs_ra, gwcs_dec = models[1].meta.wcs(*grid) fits_wcs = WCS(wcs_info) fitswcs_res = fits_wcs.pixel_to_world(*grid)