diff --git a/jwst/resample/gwcs_drizzle.py b/jwst/resample/gwcs_drizzle.py new file mode 100644 index 0000000000..b2d5c25d6d --- /dev/null +++ b/jwst/resample/gwcs_drizzle.py @@ -0,0 +1,406 @@ +import numpy as np + +from drizzle import cdrizzle +from . import resample_utils + +import logging +log = logging.getLogger(__name__) +log.setLevel(logging.DEBUG) + + +class GWCSDrizzle: + """ + Combine images using the drizzle algorithm + """ + + def __init__(self, product, outwcs=None, wt_scl=None, + pixfrac=1.0, kernel="square", fillval="NAN"): + """ + Create a new Drizzle output object and set the drizzle parameters. + + Parameters + ---------- + + product : DataModel + A data model containing results from a previous run. The three + extensions SCI, WHT, and CTX contain the combined image, total counts + and image id bitmap, respectively. The WCS of the combined image is + also read from the SCI extension. + + outwcs : `gwcs.WCS` + The world coordinate system (WCS) of the resampled image. If not + provided, the WCS is taken from product. + + wt_scl : str, optional + How each input image should be scaled. The choices are `exptime`, + which scales each image by its exposure time, or `expsq`, which scales + each image by the exposure time squared. If not set, then each + input image is scaled by its own weight map. + + pixfrac : float, optional + The fraction of a pixel that the pixel flux is confined to. The + default value of 1 has the pixel flux evenly spread across the image. + A value of 0.5 confines it to half a pixel in the linear dimension, + so the flux is confined to a quarter of the pixel area when the square + kernel is used. + + kernel : str, optional + The name of the kernel used to combine the inputs. The choice of + kernel controls the distribution of flux over the kernel. The kernel + names are: "square", "gaussian", "point", "turbo", "lanczos2", + and "lanczos3". The square kernel is the default. 
+ + fillval : str, optional + The value a pixel is set to in the output if the input image does + not overlap it. The default value of NAN sets NaN values. + """ + + # Initialize the object fields + self._product = product + self.outsci = None + self.outwht = None + self.outcon = None + self.uniqid = 0 + + if wt_scl is None: + self.wt_scl = "" + else: + self.wt_scl = wt_scl + self.kernel = kernel + self.fillval = fillval + self.pixfrac = pixfrac + + self.sciext = "SCI" + self.whtext = "WHT" + self.conext = "CON" + + out_units = "cps" + + self.outexptime = product.meta.exposure.measurement_time or 0.0 + + self.outsci = product.data + if outwcs: + self.outwcs = outwcs + else: + self.outwcs = product.meta.wcs + + self.outwht = product.wht + self.outcon = product.con + + if self.outcon.ndim == 2: + self.outcon = np.reshape(self.outcon, (1, + self.outcon.shape[0], + self.outcon.shape[1])) + elif self.outcon.ndim != 3: + raise ValueError("Drizzle context image has wrong dimensions: \ + {0}".format(product)) + + # Check field values + if not self.outwcs: + raise ValueError("Either an existing file or wcs must be supplied") + + if out_units == "counts": + np.divide(self.outsci, self.outexptime, self.outsci) + elif out_units != "cps": + raise ValueError("Illegal value for out_units: %s" % out_units) + + # Since the context array is dynamic, it must be re-assigned + # back to the product's `con` attribute. + @property + def outcon(self): + """Return the context array""" + return self._product.con + + @outcon.setter + def outcon(self, value): + """Set new context array""" + self._product.con = value + + def add_image(self, insci, inwcs, inwht=None, xmin=0, xmax=0, ymin=0, ymax=0, + expin=1.0, in_units="cps", wt_scl=1.0, iscale=1.0): + """ + Combine an input image with the output drizzled image. + + Instead of reading the parameters from a fits file, you can set + them by calling this lower level method. `Add_fits_file` calls + this method after doing its setup. 
+ + Parameters + ---------- + + insci : array + A 2d numpy array containing the input image to be drizzled. + It is an error to not supply an image. + + inwcs : wcs + The world coordinate system of the input image. This is + used to convert the pixels to the output coordinate system. + + inwht : array, optional + A 2d numpy array containing the pixel by pixel weighting. + Must have the same dimensions as insci. If none is supplied, + the weighting is set to one. + + xmin : int, optional + This and the following three parameters set a bounding rectangle + on the input image. Only pixels on the input image inside this + rectangle will have their flux added to the output image. Xmin + sets the minimum value of the x dimension. The x dimension is the + dimension that varies quickest on the image. All four parameters + are zero based, counting starts at zero. + + xmax : int, optional + Sets the maximum value of the x dimension on the bounding box + of the input image. If ``xmax = 0``, no maximum will + be set in the x dimension (all pixels in a row of the input image + will be resampled). + + ymin : int, optional + Sets the minimum value in the y dimension on the bounding box. The + y dimension varies less rapidly than the x and represents the line + index on the input image. + + ymax : int, optional + Sets the maximum value in the y dimension. If ``ymax = 0``, + no maximum will be set in the y dimension (all pixels in a column + of the input image will be resampled). + + expin : float, optional + The exposure time of the input image, a positive number. The + exposure time is used to scale the image if the units are counts and + to scale the image weighting if the drizzle was initialized with + wt_scl equal to "exptime" or "expsq". + + in_units : str, optional + The units of the input image. The units can either be "counts" + or "cps" (counts per second). If the value is counts, before using + the input image it is scaled by dividing it by the exposure time.
+ + wt_scl : float, optional + If drizzle was initialized with wt_scl left blank, this value will + set a scaling factor for the pixel weighting. If drizzle was + initialized with wt_scl set to "exptime" or "expsq", the exposure time + will be used to set the weight scaling and the value of this parameter + will be ignored. + + iscale : float, optional + A scale factor to be applied to pixel intensities of the + input image before resampling. + + """ + if self.wt_scl == "exptime": + wt_scl = expin + elif self.wt_scl == "expsq": + wt_scl = expin * expin + + wt_scl = 1.0 # hard-coded for JWST count-rate data + self.increment_id() + + dodrizzle(insci, inwcs, inwht, self.outwcs, self.outsci, self.outwht, + self.outcon, expin, in_units, wt_scl, uniqid=self.uniqid, + xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax, + iscale=iscale, pixfrac=self.pixfrac, kernel=self.kernel, + fillval=self.fillval) + + def increment_id(self): + """ + Increment the id count and add a plane to the context image if needed + + Drizzle tracks which input images contribute to the output image + by setting a bit in the corresponding pixel in the context image. + The uniqid indicates which bit. So it must be incremented each time + a new image is added. Each plane in the context image can hold 32 bits, + so after each 32 images, a new plane is added to the context. 
+ """ + + # Compute what plane of the context image this input would + # correspond to: + planeid = int(self.uniqid / 32) + + # Add a new plane to the context image if planeid overflows + + if self.outcon.shape[0] == planeid: + plane = np.zeros_like(self.outcon[0]) + plane = plane.reshape((1, plane.shape[0], plane.shape[1])) + self.outcon = np.concatenate((self.outcon, plane)) + + # Increment the id + self.uniqid += 1 + + +def dodrizzle(insci, input_wcs, inwht, output_wcs, outsci, outwht, outcon, + expin, in_units, wt_scl, uniqid=1, xmin=0, xmax=0, ymin=0, ymax=0, + iscale=1.0, pixfrac=1.0, kernel='square', fillval="NAN"): + """ + Low level routine for performing 'drizzle' operation on one image. + + Parameters + ---------- + + insci : 2d array + A 2d numpy array containing the input image to be drizzled. + + input_wcs : gwcs.WCS object + The world coordinate system of the input image. + + inwht : 2d array + A 2d numpy array containing the pixel by pixel weighting. + Must have the same dimensions as insci. If none is supplied, + the weighting is set to one. + + output_wcs : gwcs.WCS object + The world coordinate system of the output image. + + outsci : 2d array + A 2d numpy array containing the output image produced by + drizzling. On the first call it should be set to zero. + Subsequent calls it will hold the intermediate results + + outwht : 2d array + A 2d numpy array containing the output counts. On the first + call it should be set to zero. On subsequent calls it will + hold the intermediate results. + + outcon : 2d or 3d array, optional + A 2d or 3d numpy array holding a bitmap of which image was an input + for each output pixel. Should be integer zero on first call. + Subsequent calls hold intermediate results. + + expin : float + The exposure time of the input image, a positive number. The + exposure time is used to scale the image if the units are counts. + + in_units : str + The units of the input image. 
The units can either be "counts" + or "cps" (counts per second). + + wt_scl : float + A scaling factor applied to the pixel by pixel weighting. + + uniqid : int, optional + The id number of the input image. Should be 1 the first time + this function is called and incremented by one on each subsequent + call. + + xmin : int, optional + This and the following three parameters set a bounding rectangle + on the input image. Only pixels on the input image inside this + rectangle will have their flux added to the output image. Xmin + sets the minimum value of the x dimension. The x dimension is the + dimension that varies quickest on the image. All four parameters + are zero based, counting starts at zero. + + xmax : int, optional + Sets the maximum value of the x dimension on the bounding box + of the input image. If ``xmax = 0``, no maximum will + be set in the x dimension (all pixels in a row of the input image + will be resampled). + + ymin : int, optional + Sets the minimum value in the y dimension on the bounding box. The + y dimension varies less rapidly than the x and represents the line + index on the input image. + + ymax : int, optional + Sets the maximum value in the y dimension. If ``ymax = 0``, + no maximum will be set in the y dimension (all pixels in a column + of the input image will be resampled). + + iscale : float, optional + A scale factor to be applied to pixel intensities of the + input image before resampling. + + pixfrac : float, optional + The fraction of a pixel that the pixel flux is confined to. The + default value of 1 has the pixel flux evenly spread across the image. + A value of 0.5 confines it to half a pixel in the linear dimension, + so the flux is confined to a quarter of the pixel area when the square + kernel is used. + + kernel : str, optional + The name of the kernel used to combine the input. The choice of + kernel controls the distribution of flux over the kernel.
The kernel + names are: "square", "gaussian", "point", "turbo", "lanczos2", + and "lanczos3". The square kernel is the default. + + fillval: str, optional + The value a pixel is set to in the output if the input image does + not overlap it. The default value of NAN sets NaN values. + + Returns + ------- + A tuple with three values: a version string, the number of pixels + on the input image that do not overlap the output image, and the + number of complete lines on the input image that do not overlap the + output input image. + + """ + # Insure that the fillval parameter gets properly interpreted for use with tdriz + if resample_utils.is_blank(str(fillval)): + fillval = 'NAN' + else: + fillval = str(fillval) + + if in_units == 'cps': + expscale = 1.0 + else: + expscale = expin + + if insci.dtype > np.float32: + insci = insci.astype(np.float32) + + # Add input weight image if it was not passed in + if inwht is None: + inwht = np.ones_like(insci) + + if xmax is None or xmax == xmin: + xmax = insci.shape[1] + if ymax is None or ymax == ymin: + ymax = insci.shape[0] + + # Compute what plane of the context image this input would + # correspond to: + planeid = int((uniqid - 1) / 32) + + # Check if the context image has this many planes + if outcon.ndim == 3: + nplanes = outcon.shape[0] + elif outcon.ndim == 2: + nplanes = 1 + else: + nplanes = 0 + + if nplanes <= planeid: + raise IndexError("Not enough planes in drizzle context image") + + # Alias context image to the requested plane if 3d + if outcon.ndim == 3: + outcon = outcon[planeid] + + # Compute the mapping between the input and output pixel coordinates + # for use in drizzle.cdrizzle.tdriz + pixmap = resample_utils.calc_gwcs_pixmap(input_wcs, output_wcs, insci.shape) + # inwht[np.isnan(pixmap[:,:,0])] = 0. 
+ + log.debug(f"Pixmap shape: {pixmap[:,:,0].shape}") + log.debug(f"Input Sci shape: {insci.shape}") + log.debug(f"Output Sci shape: {outsci.shape}") + + # Call 'drizzle' to perform image combination + log.info(f"Drizzling {insci.shape} --> {outsci.shape}") + + _vers, nmiss, nskip = cdrizzle.tdriz( + insci.astype(np.float32), inwht.astype(np.float32), pixmap, + outsci, outwht, outcon, + uniqid=uniqid, + xmin=xmin, xmax=xmax, + ymin=ymin, ymax=ymax, + scale=iscale, + pixfrac=pixfrac, + kernel=kernel, + in_units=in_units, + expscale=expscale, + wtscale=wt_scl, + fillstr=fillval + ) + return _vers, nmiss, nskip diff --git a/jwst/resample/resample.py b/jwst/resample/resample.py index 72b325cbd3..6b9cd58037 100644 --- a/jwst/resample/resample.py +++ b/jwst/resample/resample.py @@ -1,31 +1,313 @@ import logging -import os import warnings import json +import re +import os +from typing import Any import numpy as np import psutil from drizzle.resample import Drizzle from spherical_geometry.polygon import SphericalPolygon +from astropy.io import fits from stdatamodels.jwst import datamodels -from stdatamodels.jwst.library.basic_utils import bytes2human +from stcal.resample import LibModelAccessBase, Resample, OutputTooLargeError + +from stdatamodels.jwst.datamodels.dqflags import pixel +from stdatamodels.properties import ObjectNode from jwst.datamodels import ModelLibrary from jwst.associations.asn_from_list import asn_from_list from jwst.model_blender.blender import ModelBlender +from jwst.assign_wcs import util from jwst.resample import resample_utils log = logging.getLogger(__name__) log.setLevel(logging.DEBUG) -__all__ = ["OutputTooLargeError", "ResampleData"] + +__all__ = [ + "ResampleData", + "OutputTooLargeError", + "LibModelAccess", + "ResampleImage", + "is_imaging_wcs", +] + + +class LibModelAccess(LibModelAccessBase): + attributes_path = { + "data": "data", + "dq": "dq", + "var_rnoise": "var_rnoise", + "var_poisson": "var_poisson", + "var_flat": "var_flat", + + 
"filename": "meta.filename", + "group_id": "meta.group_id", + "s_region": "meta.wcsinfo.s_region", + "wcsinfo": "meta.wcsinfo", + "wcs": "meta.wcs", + + "exposure_time": "meta.exposure.exposure_time", + "start_time": "meta.exposure.start_time", + "end_time": "meta.exposure.end_time", + "duration": "meta.exposure.duration", + "measurement_time": "meta.exposure.measurement_time", + "effective_exposure_time": "meta.exposure.effective_exposure_time", + "elapsed_exposure_time": "meta.exposure.elapsed_exposure_time", + + "pixelarea_steradians": "meta.photometry.pixelarea_steradians", + + "level": "meta.background.level", + "subtracted": "meta.background.subtracted", + + "weight_type": "meta.resample.weight_type", + "pointings": "meta.resample.pointings", + "n_coadds": "meta.resample.n_coadds", + + # spectroscopy-specific: + "instrument_name": "meta.instrument.name", + "exposure_type": "meta.exposure.type", + } + + def __new__(cls, *args, **kwargs): + assert set(cls.attributes_path).issuperset(cls.min_supported_attributes) + return super().__new__(cls) + + @classmethod + def get_model_attr_value(cls, model, attr_path): + """ Retrieve a single attribute from the data model. """ + fields = attr_path.strip().split(".") + while fields: + model = getattr(model, fields.pop(0)) + if isinstance(model, ObjectNode): + return model.instance + return model + + @classmethod + def get_model_attributes(cls, model, attributes=None, quiet=False): + """ Retrieve all attributes (data and meta) from the data model. 
""" + model_attrib = {} + if attributes is None: + attributes = cls.attributes_path + else: + attributes = {a: cls.attributes_path[a] for a in attributes} + + for k, v in attributes.items(): + try: + model_attrib[k] = cls.get_model_attr_value(model, v) + except AttributeError as e: + # TODO: add n_coadds to model's schema + if k == "n_coadds": + model_attrib["n_coadds"] = 0 + continue + if quiet: + continue + else: + raise e + + return model_attrib + + def __init__(self, model_library): + self._mlib = model_library + self.set_active_group(None) + + def iter_model(self, attributes=None): + with self._mlib: + for model in self._mlib: + model_attrib = self.get_model_attributes(model, attributes) + yield model_attrib, model + self._mlib.shelve(model) + + @property + def n_models(self): + return len(self._mlib) + + @property + def n_groups(self): + return len(self._mlib.group_indices) + + @property + def group_indices(self): + return self._mlib.group_indices + + @property + def asn(self): + return self._mlib.asn + + def set_active_group(self, group_id=None): + self._active_group = group_id + + +class ResampleImage(Resample): + dq_flag_name_map = pixel + + def __init__(self, input_models, *args, blendheaders=True, + output_model=None, **kwargs): + if output_model is None: + self.resampled_model = datamodels.ImageModel() + self._update_output_meta_with_first_model = True + else: + self.resampled_model = output_model + self._update_output_meta_with_first_model = False + # convert output_model to dictionary: + attributes = Resample.output_model_attributes( + accumulate=kwargs.get("accumulate", False), + enable_ctx=kwargs.get("enable_ctx", True), + enable_var=kwargs.get("enable_var", True), + ) + output_model = LibModelAccess.get_model_attributes( + output_model, + attributes=attributes, + ) + + if blendheaders: + self.blender = ModelBlender( + blend_ignore_attrs=[ + 'meta.photometry.pixelarea_steradians', + 'meta.photometry.pixelarea_arcsecsq', + 'meta.filename', + ] + ) + 
+ super().__init__( + input_models, + *args, + output_model=output_model, + **kwargs + ) + + # initialize blendheaders if needed + self._blendheaders = blendheaders + + def add_model(self, model_info, image_model): + super().add_model(model_info, image_model) + if self._update_output_meta_with_first_model: + self.resampled_model.update(image_model) + self._update_output_meta_with_first_model = False + + # blend headers if needed: + if self._blendheaders: + self.blender.accumulate(image_model) + + def update_output_model_data(self): + # update data and meta for the output model: + # * arrays: + if self._blendheaders: + self.blender.finalize_model(self.resampled_model) + + self.resampled_model.data = self.output_model["data"] + self.resampled_model.wht = self.output_model["wht"] + + if self._enable_ctx: + self.resampled_model.con = self.output_model["con"] + + if self._enable_var: + self.resampled_model.var_rnoise = self.output_model["var_rnoise"] + self.resampled_model.var_poisson = self.output_model["var_poisson"] + self.resampled_model.var_flat = self.output_model["var_flat"] + self.resampled_model.err = self.output_model["err"] + + # * meta: + self.resampled_model.meta.wcs = self.output_model["wcs"] + self.resampled_model.meta.cal_step.resample = 'COMPLETE' + self.resampled_model.meta.resample.pixel_scale_ratio = self._pixel_scale_ratio + self.resampled_model.meta.resample.pixfrac = self.pixfrac + + if is_imaging_wcs(self.resampled_model.meta.wcs): + # only for an imaging WCS: + self.update_fits_wcsinfo(self.resampled_model) + util.update_s_region_imaging(self.resampled_model) + else: + util.update_s_region_spectral(self.resampled_model) + + self.resampled_model.meta.asn.pool_name = self._input_models.asn.get( + "pool_name", + None + ) + self.resampled_model.meta.asn.table_name = self._input_models.asn.get( + "table_name", + None + ) + + # Update some basic exposure time values based on output_model + self.resampled_model.meta.exposure.exposure_time = 
self.output_model["exposure_time"] + self.resampled_model.meta.exposure.start_time = self.output_model["start_time"] + self.resampled_model.meta.exposure.end_time = self.output_model["end_time"] + if "measurement_time" in self.output_model: + self.resampled_model.meta.exposure.measurement_time = self.output_model["measurement_time"] + + # Update other exposure time keywords: + # XPOSURE (identical to the total effective exposure time, EFFEXPTM) + xposure = self.output_model["exposure_time"] + self.resampled_model.meta.exposure.effective_exposure_time = xposure + # DURATION (identical to TELAPSE, elapsed time) + self.resampled_model.meta.exposure.duration = self.output_model["duration"] + self.resampled_model.meta.exposure.elapsed_exposure_time = self.output_model["duration"] + + # TODO: finalize blend headers if needed + + def run(self): + super().run() + self.update_output_model_data() + return self.resampled_model + + @staticmethod + def update_fits_wcsinfo(model): + """ Update FITS WCS keywords of the resampled image. 
""" + # Delete any SIP-related keywords first + pattern = r"^(cd[12]_[12]|[ab]p?_\d_\d|[ab]p?_order)$" + regex = re.compile(pattern) + + keys = list(model.meta.wcsinfo.instance.keys()) + for key in keys: + if regex.match(key): + del model.meta.wcsinfo.instance[key] + + # Write new PC-matrix-based WCS based on GWCS model + transform = model.meta.wcs.forward_transform + model.meta.wcsinfo.crpix1 = -transform[0].offset.value + 1 + model.meta.wcsinfo.crpix2 = -transform[1].offset.value + 1 + model.meta.wcsinfo.cdelt1 = transform[3].factor.value + model.meta.wcsinfo.cdelt2 = transform[4].factor.value + model.meta.wcsinfo.ra_ref = transform[6].lon.value + model.meta.wcsinfo.dec_ref = transform[6].lat.value + model.meta.wcsinfo.crval1 = model.meta.wcsinfo.ra_ref + model.meta.wcsinfo.crval2 = model.meta.wcsinfo.dec_ref + model.meta.wcsinfo.pc1_1 = transform[2].matrix.value[0][0] + model.meta.wcsinfo.pc1_2 = transform[2].matrix.value[0][1] + model.meta.wcsinfo.pc2_1 = transform[2].matrix.value[1][0] + model.meta.wcsinfo.pc2_2 = transform[2].matrix.value[1][1] + model.meta.wcsinfo.ctype1 = "RA---TAN" + model.meta.wcsinfo.ctype2 = "DEC--TAN" + + # Remove no longer relevant WCS keywords + rm_keys = [ + 'v2_ref', + 'v3_ref', + 'ra_ref', + 'dec_ref', + 'roll_ref', + 'v3yangle', + 'vparity', + ] + for key in rm_keys: + if key in model.meta.wcsinfo.instance: + del model.meta.wcsinfo.instance[key] -class OutputTooLargeError(RuntimeError): - """Raised when the output is too large for in-memory instantiation""" +def is_imaging_wcs(wcs): + imaging = all( + ax == 'SPATIAL' for ax in wcs.output_frame.axes_type + ) + return imaging +#################################################### +# Code below was left for spectral data for now # +#################################################### class ResampleData: """ diff --git a/jwst/resample/resample_spec_step.py b/jwst/resample/resample_spec_step.py index 7aa406091b..97adc9a27c 100755 --- a/jwst/resample/resample_spec_step.py +++ 
b/jwst/resample/resample_spec_step.py @@ -6,11 +6,12 @@ from jwst.datamodels import ModelContainer, ModelLibrary from jwst.lib.pipe_utils import match_nans_and_flags from jwst.lib.wcs_utils import get_wavelengths +from stcal.resample.utils import load_custom_wcs -from . import resample_spec, ResampleStep -from ..exp_to_source import multislit_to_container -from ..assign_wcs.util import update_s_region_spectral -from ..stpipe import Step +from jwst.resample import resample_spec, ResampleStep +from jwst.exp_to_source import multislit_to_container +from jwst.assign_wcs.util import update_s_region_spectral +from jwst.stpipe import Step # Force use of all DQ flagged data except for DO_NOT_USE and NON_SCIENCE @@ -39,9 +40,6 @@ class ResampleSpecStep(Step): pixel_scale_ratio = float(default=1.0) # Ratio of input to output spatial pixel scale pixel_scale = float(default=None) # Spatial pixel scale in arcsec output_wcs = string(default='') # Custom output WCS - single = boolean(default=False) # Resample each input to its own output grid - blendheaders = boolean(default=True) # Blend metadata from inputs into output - in_memory = boolean(default=True) # Keep images in memory """ def process(self, input): @@ -173,9 +171,6 @@ def get_drizpars(self): fillval=self.fillval, wht_type=self.weight_type, good_bits=GOOD_BITS, - single=self.single, - blendheaders=self.blendheaders, - in_memory=self.in_memory ) # Custom output WCS parameters @@ -184,8 +179,10 @@ def get_drizpars(self): 'output_shape', min_vals=[1, 1] ) - kwargs['output_wcs'] = ResampleStep.load_custom_wcs( - self.output_wcs, kwargs['output_shape']) + kwargs['output_wcs'] = load_custom_wcs( + self.output_wcs, + kwargs['output_shape'] + ) kwargs['pscale'] = self.pixel_scale kwargs['pscale_ratio'] = self.pixel_scale_ratio diff --git a/jwst/resample/resample_step.py b/jwst/resample/resample_step.py index 4132850918..a744daf7bf 100755 --- a/jwst/resample/resample_step.py +++ b/jwst/resample/resample_step.py @@ -1,24 +1,32 @@ 
+import json import logging +import os import re -from copy import deepcopy -import asdf - -from jwst.datamodels import ModelLibrary, ImageModel # type: ignore[attr-defined] -from jwst.lib.pipe_utils import match_nans_and_flags +from jwst.datamodels import ModelLibrary, ImageModel +import gwcs +from stcal.resample import resampled_wcs_from_models, OutputTooLargeError +from stcal.resample.utils import load_custom_wcs from . import resample +from ..associations.asn_from_list import asn_from_list from ..stpipe import Step -from ..assign_wcs import util + log = logging.getLogger(__name__) log.setLevel(logging.DEBUG) -__all__ = ["ResampleStep"] - +__all__ = ["ResampleStep", "MissingFileName"] # Force use of all DQ flagged data except for DO_NOT_USE and NON_SCIENCE GOOD_BITS = '~DO_NOT_USE+NON_SCIENCE' +_OUPUT_EXT = ".fits" + + +class MissingFileName(ValueError): + """ Raised when in_memory is False but no output file name has been + provided. + """ class ResampleStep(Step): @@ -59,71 +67,186 @@ class ResampleStep(Step): reference_file_types: list = [] def process(self, input): - - if isinstance(input, ModelLibrary): - input_models = input - elif isinstance(input, (str, dict, list)): - input_models = ModelLibrary(input, on_disk=not self.in_memory) - elif isinstance(input, ImageModel): - input_models = ModelLibrary([input], on_disk=not self.in_memory) - output = input.meta.filename + output_file_name = None + + if isinstance(input, (str, dict, list)): + input = ModelLibrary(input, on_disk=not self.in_memory) + elif isinstance(input, ImageModel): # TODO: do we need to support this? + input = ModelLibrary([input], on_disk=not self.in_memory) + # TODO: I don't get the purpose of this: + output_file_name = input.meta.filename # <-- ????? self.blendheaders = False + elif not isinstance(input, ModelLibrary): + raise RuntimeError(f"Input {repr(input)} is not a 2D image.") + + input_models = resample.LibModelAccess(input) + + # try to figure output file name. 
+ # TODO: review whether this is the intended action - not sure this + # code reproduces what's currently in the pipeline but also, + # not sure that code makes sense. + if output_file_name is not None: + output_file_name = output_file_name.strip() + + if output_file_name and output_file_name.endswith(_OUPUT_EXT): + self._output_dir = '' + self._output_file_name = output_file_name else: - raise RuntimeError(f"Input {input} is not a 2D image.") + self._output_dir = output_file_name + self._output_file_name = None try: - output = input_models.asn["products"][0]["name"] + output_file_name = input_models.asn["products"][0]["name"] except KeyError: # coron data goes through this path by the time it gets to # resampling. # TODO: figure out why and make sure asn_table is carried along - output = None - - # Check that input models are 2D images - with input_models: - example_model = input_models.borrow(0) - data_shape = example_model.data.shape - input_models.shelve(example_model, 0, modify=False) - if len(data_shape) != 2: - # resample can only handle 2D images, not 3D cubes, etc - raise RuntimeError(f"Input {example_model} is not a 2D image.") - del example_model - - # Make sure all input models have consistent NaN and DO_NOT_USE values - for model in input_models: - match_nans_and_flags(model) - input_models.shelve(model) - del model + pass + + resampled_models = [] # Setup drizzle-related parameters kwargs = self.get_drizpars() - # Call the resampling routine - resamp = resample.ResampleData(input_models, output=output, **kwargs) - result = resamp.do_drizzle(input_models) - - with result: - for model in result: - model.meta.cal_step.resample = 'COMPLETE' - self.update_fits_wcs(model) - util.update_s_region_imaging(model) - - # if pixel_scale exists, it will override pixel_scale_ratio. - # calculate the actual value of pixel_scale_ratio based on pixel_scale - # because source_catalog uses this value from the header. 
- if self.pixel_scale is None: - model.meta.resample.pixel_scale_ratio = self.pixel_scale_ratio + if self.single: + output_suffix = "outlier_i2d" + + # define output WCS if needed using + if kwargs.pop("output_wcs", None) is None: + output_shape = (None if self.output_shape is None + else self.output_shape[::-1]) + kwargs["output_wcs"], *_ = resampled_wcs_from_models( + input_models, + pixel_scale_ratio=self.pixel_scale_ratio, + pixel_scale=self.pixel_scale, + output_shape=output_shape, + rotation=self.rotation, + crpix=self.crpix, + crval=self.crval, + ) + + group_ids = input_models.group_indices + + # Call the resampling routine for each group of images + for group_id in group_ids: + input_models.set_active_group(group_id) + log.info(f"Resampling images in group {group_id}") + + try: + resampler = resample.ResampleImage( + input_models, + enable_ctx=False, + enable_var=False, + **kwargs, + ) + model = resampler.run() + + except OutputTooLargeError as e: + log.error("Not enough available memory for resample.") + log.error(e.msg) + return input + + except Exception as e: + log.error( + "The following exception occured while resampling." + ) + log.error(e.msg) + return input + + # output file name for the resampled model: + while resampler.input_file_names: + ref_file_name = resampler.input_file_names.pop() + if ref_file_name: + output_file_name = self.resampled_file_name_from_input( + ref_file_name, + suffix=output_suffix, + ) + model.meta.filename = output_file_name + break + else: + ref_file_name = None + if not self.in_memory: + raise MissingFileName( + "Unable to determine output file name which is " + "required when in_memory=False." 
+ ) + + if self.in_memory: + resampled_models.append(model) else: - model.meta.resample.pixel_scale_ratio = resamp.pscale_ratio - model.meta.resample.pixfrac = kwargs['pixfrac'] - result.shelve(model) + # save model to file and append its file name to the output + # list of resampled models: + model.save( + os.path.join(self._output_dir, model.meta.filename) + ) + resampled_models.append(model.meta.filename) + log.info( + f"Resampled image model saved to {model.meta.filename}" + ) + + resampled_models.append(model.meta.filename) + del model - if len(result) == 1: - model = result.borrow(0) - result.shelve(model, 0, modify=False) - return model + else: + if not self.in_memory and output_file_name is None: + raise MissingFileName( + "Unable to determine output file name which is " + "required when in_memory=False." + ) + + if output_file_name and not output_file_name.endswith(_OUPUT_EXT): + sep = '' if output_file_name[-1] == '_' else '_' + output_file_name = output_file_name + f"{sep}i2d{_OUPUT_EXT}" + + try: + resampler = resample.ResampleImage( + input_models, + enable_ctx=True, + enable_var=True, + **kwargs, + ) + model = resampler.run() + except OutputTooLargeError as e: + log.error("Not enough available memory for resample.") + log.error(e.msg) + return input + except Exception as e: + log.error("The following exception occured while resampling.") + log.error(e.msg) + return input + + model.meta.filename = output_file_name + + if self.in_memory: + resampled_models.append(model) + else: + model.save(model.meta.filename) + resampled_models.append(model.meta.filename) + log.info( + f"Resampled image model saved to {model.meta.filename}" + ) - return result + del model + + # make a ModelLibrary obj and save it to asn if requested: + if self.in_memory: + return ModelLibrary(resampled_models, on_disk=False) + else: + # build ModelLibrary as an association from the output file names + asn = asn_from_list(resampled_models, product_name=output_file_name) + # serializes 
def resampled_file_name_from_input(self, input_file_name, suffix):
    """
    Form the output file name from an input image name.

    The text after the last ``_`` of the input name's stem is replaced
    by ``suffix``; the file extension (everything from the last ``.``
    onward) is carried over unchanged.

    Parameters
    ----------
    input_file_name : str
        Name of the input file, e.g. ``"jw001_nrca1_cal.fits"``.

    suffix : str
        New product suffix, e.g. ``"i2d"``.

    Returns
    -------
    output_file_name : str
        Name of the form ``f"{root}_{suffix}{extension}"``. If the stem
        contains no ``_``, the whole stem is kept as the root. If the
        input name contains no ``.``, the extension is empty.
    """
    dot = input_file_name.rfind('.')
    if dot < 0:
        # No extension at all: do not mistake the last character for one.
        stem, extension = input_file_name, ''
    else:
        # Slice instead of str.replace(): replace() would also strip any
        # earlier occurrence of the extension text inside the name.
        stem, extension = input_file_name[:dot], input_file_name[dot:]

    root, sep, _ = stem.rpartition('_')
    if not sep:
        # No suffix to strip; keep the stem so that distinct inputs do not
        # all collapse to the same "_<suffix><ext>" output name.
        root = stem

    return f'{root}_{suffix}{extension}'
- """ - if not asdf_wcs_file: - return None - - with asdf.open(asdf_wcs_file) as af: - wcs = deepcopy(af.tree["wcs"]) - pixel_area = af.tree.get("pixel_area", None) - pixel_shape = af.tree.get("pixel_shape", None) - array_shape = af.tree.get("array_shape", None) - - if not hasattr(wcs, "pixel_area") or wcs.pixel_area is None: - wcs.pixel_area = pixel_area - if not hasattr(wcs, "pixel_shape") or wcs.pixel_shape is None: - wcs.pixel_shape = pixel_shape - if not hasattr(wcs, "array_shape") or wcs.array_shape is None: - wcs.array_shape = array_shape - - if output_shape is not None: - wcs.array_shape = output_shape[::-1] - wcs.pixel_shape = output_shape - elif wcs.pixel_shape is not None: - wcs.array_shape = wcs.pixel_shape[::-1] - elif wcs.array_shape is not None: - wcs.pixel_shape = wcs.array_shape[::-1] - elif wcs.bounding_box is not None: - wcs.array_shape = tuple( - int(axs[1] + 0.5) - for axs in wcs.bounding_box.bounding_box(order="C") - ) - wcs.pixel_shape = wcs.array_shape[::-1] - else: - raise ValueError( - "Step argument 'output_shape' is required when custom WCS " - "does not have 'array_shape', 'pixel_shape', or " - "'bounding_box' attributes set." - ) - - return wcs - def get_drizpars(self): """ Load all drizzle-related parameter values into kwargs list. @@ -232,27 +298,34 @@ def get_drizpars(self): fillval=self.fillval, wht_type=self.weight_type, good_bits=GOOD_BITS, - single=self.single, blendheaders=self.blendheaders, allowed_memory=self.allowed_memory, - in_memory=self.in_memory ) # Custom output WCS parameters. 
def is_blank(val):
    """
    Report whether a value is one of the placeholders treated as 'blank'.

    The placeholders are ``None``, the empty string, a single space, and
    the literal strings ``'None'`` and ``'INDEF'``. Matching is exact, so
    it is case sensitive and no whitespace is stripped.

    Parameters
    ----------
    val : object
        The value to test.

    Returns
    -------
    bool
        `True` if ``val`` equals one of the blank placeholders,
        `False` otherwise.
    """
    blank_values = (None, '', ' ', 'None', 'INDEF')
    return val in blank_values