From f082418275d44f0d9187f2838e17abaac692c842 Mon Sep 17 00:00:00 2001 From: Mihai Cara Date: Thu, 8 Aug 2024 09:28:24 -0400 Subject: [PATCH 01/10] Move common resample code to stcal --- src/stcal/resample/__init__.py | 9 + src/stcal/resample/resample.py | 1122 ++++++++++++++++++++++++++++++++ src/stcal/resample/utils.py | 36 + 3 files changed, 1167 insertions(+) create mode 100644 src/stcal/resample/__init__.py create mode 100644 src/stcal/resample/resample.py create mode 100644 src/stcal/resample/utils.py diff --git a/src/stcal/resample/__init__.py b/src/stcal/resample/__init__.py new file mode 100644 index 000000000..1ae898af0 --- /dev/null +++ b/src/stcal/resample/__init__.py @@ -0,0 +1,9 @@ +from .resample import * + +__all__ = [ + "OutputTooLargeError", + "ResampleModelIO", + "ResampleBase", + "ResampleCoAdd", + "ResampleSingle" +] diff --git a/src/stcal/resample/resample.py b/src/stcal/resample/resample.py new file mode 100644 index 000000000..af055ccaf --- /dev/null +++ b/src/stcal/resample/resample.py @@ -0,0 +1,1122 @@ +import logging +import os +import warnings +from copy import deepcopy +import sys +import abc +from pathlib import Path, PurePath + +import numpy as np +from scipy.ndimage import median_filter + +from drizzle.resample import Drizzle +from drizzle.utils import calc_pixmap + +import psutil +from spherical_geometry.polygon import SphericalPolygon + +from stdatamodels.dqflags import interpret_bit_flags +from stdatamodels.jwst.datamodels.dqflags import pixel + +from stdatamodels.jwst import datamodels +from stdatamodels.jwst.library.basic_utils import bytes2human + +from .utils import get_tmeasure, build_mask + + +log = logging.getLogger(__name__) +log.setLevel(logging.DEBUG) + +__all__ = [ + "OutputTooLargeError", + "ResampleModelIO", + "ResampleBase", + "ResampleCoAdd", + "ResampleSingle" +] + + +_SUPPORTED_CUSTOM_WCS_PARS = [ + 'pixel_scale_ratio', + 'pixel_scale', + 'output_shape', + 'crpix', + 'crval', + 'rotation', +] + + +def 
_resample_range(data_shape, bbox=None): + # Find range of input pixels to resample: + if bbox is None: + xmin = ymin = 0 + xmax = data_shape[1] - 1 + ymax = data_shape[0] - 1 + else: + ((x1, x2), (y1, y2)) = bbox + xmin = max(0, int(x1 + 0.5)) + ymin = max(0, int(y1 + 0.5)) + xmax = min(data_shape[1] - 1, int(x2 + 0.5)) + ymax = min(data_shape[0] - 1, int(y2 + 0.5)) + + return xmin, xmax, ymin, ymax + + +class ResampleModelIO(abc.ABC): + @abc.abstractmethod + def open_model(self, file_name): + ... + + @abc.abstractmethod + def get_model_meta(self, file_name, fields): + ... + + @abc.abstractmethod + def close_model(self, model): + ... + + @abc.abstractmethod + def save_model(self, model): + ... + + @abc.abstractmethod + def write_model(self, model, file_name): + ... + + @abc.abstractmethod + def new_model(self, image_shape=None, file_name=None): + """ Return a new model for the resampled output """ + ... + + +class OutputTooLargeError(RuntimeError): + """Raised when the output is too large for in-memory instantiation""" + + +def output_wcs_from_input_wcs(input_wcs_list, pixel_scale_ratio=1.0, + pixel_scale=None, output_shape=None, + crpix=None, crval=None, rotation=None): + # TODO: should be replaced with a version that lives in stcal and + # uses s_region + w = deepcopy(input_wcs_list[0]) # this is bad + return { + 'output_wcs': w, + 'pscale': np.rad2deg(np.sqrt(_compute_image_pixel_area(w))), + 'pscale_ratio': 1.0, + 'crpix': None + } + + +class ResampleBase(abc.ABC): + """ + This is the controlling routine for the resampling process. + + Notes + ----- + This routine performs the following operations:: + + 1. Extracts parameter settings from input model, such as pixfrac, + weight type, exposure time (if relevant), and kernel, and merges + them with any user-provided values. + 2. Creates output WCS based on input images and define mapping function + between all input arrays and the output array. + 3. 
Updates output data model with output arrays from drizzle, including + a record of metadata from all input models. + """ + resample_suffix = 'i2d' + resample_file_ext = '.fits' + + def __init__(self, input_models, + pixfrac=1.0, kernel="square", fillval=0.0, wht_type="ivm", + good_bits=0, output_wcs=None, wcs_pars=None, + in_memory=True, allowed_memory=None, **kwargs): + """ + Parameters + ---------- + input_models : list of objects + list of data models, one for each input image + + output : str + filename for output + + kwargs : dict + Other parameters. + + .. note:: + ``output_shape`` is in the ``x, y`` order. + + .. note:: + ``in_memory`` controls whether or not the resampled + array from ``resample_many_to_many()`` + should be kept in memory or written out to disk and + deleted from memory. Default value is `True` to keep + all products in memory. + """ + self._output_model = None + self._output_filename = None + self._output_wcs = None + self._output_array_shape = None + self._close_output = False + self._output_pixel_scale = None + self._template_output_model = None + + # input models + self._input_models = input_models + # a lightweight data model with meta from first input model but no data. + # it will be updated by 'load_input_meta()' below + self._first_model_meta = None + + # resample parameters + self.pixfrac = pixfrac + self.kernel = kernel + self.fillval = fillval + self.weight_type = wht_type + self.good_bits = good_bits + self.in_memory = in_memory + + self._user_output_wcs = output_wcs + + # check wcs_pars has supported keywords: + if wcs_pars is None: + wcs_pars = {} + elif wcs_pars: + unsup = [] + unsup = set(wcs_pars.keys()).difference(_SUPPORTED_CUSTOM_WCS_PARS) + if unsup: + raise KeyError( + "Unsupported custom WCS parameters: " + f"{','.join(map(repr, unsup))}." 
+ ) + # WCS parameters (should be deleted once not needed; + # once an output WCS was created) + self._wcs_pars = wcs_pars + + # process additional kwargs specific to subclasses and store + # unprocessed/unrecognized kwargs in ukwargs and warn about these + # unrecognized kwargs + ukwargs = self.process_kwargs(kwargs) + self._warn_extra_kwargs(ukwargs) + + # load meta necessary for output WCS (and other) computations: + self.load_input_meta( + all=self._output_model is None and output_wcs is None + ) + + # computed average pixel scale of the first input image: + input_pscale0 = np.rad2deg( + np.sqrt(_compute_image_pixel_area(self._input_wcs_list[0])) + ) + + # compute output pixel scale, WCS, set-up output model + if self._output_model: + self._output_wcs = deepcopy(self._output_model.meta.wcs) + self._output_array_shape = self._output_model.data.shape + # TODO: extract any useful info from the output image before we close it: + # if meta has pixel scale, populate it from there, if not: + self._output_pixel_scale = np.rad2deg( + np.sqrt(_compute_image_pixel_area(self._output_wcs)) + ) + self._pixel_scale_ratio = self._output_pixel_scale / input_pscale0 + log.info(f'Computed output pixel scale: {self._output_pixel_scale} arcsec.') + + self._create_output_template_model() # create template before possibly closing output + if self._close_output and not self.in_memory: + self.close_model(self._output_model) + self._output_model = None + + elif output_wcs: + naxes = output_wcs.output_frame.naxes + if naxes != 2: + raise RuntimeError( + "Output WCS needs 2 spatial axes but the " + f"supplied WCS has {naxes} axes." 
+ ) + self._output_wcs = deepcopy(output_wcs) + if wcs_pars and "output_shape" in wcs_pars: + self._output_array_shape = wcs_pars["output_shape"] + else: + self._output_array_shape = self._output_wcs.array_shape + if not self._output_array_shape and output_wcs.bounding_box: + halfpix = 0.5 + sys.float_info.epsilon + self._output_array_shape = ( + int(output_wcs.bounding_box[1][1] + halfpix), + int(output_wcs.bounding_box[0][1] + halfpix), + ) + else: + raise ValueError( + "Unable to infer output image size from provided inputs." + ) + self._output_wcs.array_shape = self._output_array_shape + + self._output_pixel_scale = np.rad2deg( + np.sqrt(_compute_image_pixel_area(self._output_wcs)) + ) + self._pixel_scale_ratio = self._output_pixel_scale / input_pscale0 + log.info(f'Computed output pixel scale: {self._output_pixel_scale} arcsec.') + self._create_output_template_model() + + else: + # build output WCS and calculate ouput image shape + if "pixel_scale" in wcs_pars and wcs_pars['pixel_scale'] is not None: + self._pixel_scale_ratio = wcs_pars["pixel_scale"] / input_pscale0 + log.info(f'Output pixel scale: {wcs_pars["pixel_scale"]} arcsec.') + log.info(f'Computed output pixel scale ratio: {self._pixel_scale_ratio}.') + else: + self._pixel_scale_ratio = wcs_pars.get("pixel_scale_ratio", 1.0) + log.info(f'Output pixel scale ratio: {self._pixel_scale_ratio}') + self._output_pixel_scale = input_pscale0 * self._pixel_scale_ratio + wcs_pars = wcs_pars.copy() + wcs_pars["pixel_scale"] = self._output_pixel_scale + log.info(f'Computed output pixel scale: {self._output_pixel_scale} arcsec.') + + w, ps = self._compute_output_wcs(**wcs_pars) + self._output_wcs = w + self._output_pixel_scale = ps + self._output_array_shape = self._output_wcs.array_shape + self._create_output_template_model() + + # Check that the output data shape has no zero length dimensions + npix = np.prod(self._output_array_shape) + if not npix: + raise ValueError( + f"Invalid output frame shape: 
{tuple(self._output_array_shape)}" + ) + + assert self._pixel_scale_ratio + log.info(f"Driz parameter kernel: {self.kernel}") + log.info(f"Driz parameter pixfrac: {self.pixfrac}") + log.info(f"Driz parameter fillval: {self.fillval}") + log.info(f"Driz parameter weight_type: {self.weight_type}") + + self.check_memory_requirements(allowed_memory) + + log.debug('Output mosaic size: {}'.format(self._output_wcs.pixel_shape)) + + @property + def output_model(self): + return self._output_model + + def process_kwargs(self, kwargs): + """ A method called by ``__init__`` to process input kwargs before + output WCS is created and before output model template is created. + + Returns + ------- + kwargs : dict + Unrecognized/not processed ``kwargs``. + + """ + return {k : v for k, v in kwargs.items()} + + def _warn_extra_kwargs(self, kwargs): + for k in kwargs: + log.warning(f"Unrecognized argument '{k}' will be ignored.") + + def check_memory_requirements(self, allowed_memory): + """ Called just before '_pre_run_callback()' is called to verify + that there is enough memory to hold the output. """ + if allowed_memory is None and "DMODEL_ALLOWED_MEMORY" not in os.environ: + return + + allowed_memory = float(allowed_memory) + # make a small image model to get the dtype + dtype = datamodels.ImageModel((1, 1)).data.dtype + + # get the available memory + available_memory = psutil.virtual_memory().available + psutil.swap_memory().total + + # compute the output array size + npix = npix = np.prod(self._output_array_shape) + nmodels = len(self._input_models) + nconpl = nmodels // 32 + (1 if nmodels % 32 else 0) + required_memory = npix * (3 * dtype.itemsize + nconpl * 4) + + # compare used to available + used_fraction = required_memory / available_memory + if used_fraction > allowed_memory: + raise OutputTooLargeError( + f'Combined ImageModel size {self._output_wcs.array_shape} ' + f'requires {bytes2human(required_memory)}. ' + f'Model cannot be instantiated.' 
+ ) + + def _compute_output_wcs(self, **wcs_pars): + """ returns a distortion-free WCS object and its pixel scale """ + owcs = output_wcs_from_input_wcs(self._input_wcs_list, **wcs_pars) + return owcs['output_wcs'], owcs['pscale'] + + def load_input_meta(self, all=True): + # if 'all=False', load meta from the first image only + + # set-up list for WCS + self._input_wcs_list = [] + self._input_s_region = [] + self._input_file_names = [] + self._close_input_models = [] + + for k, model in enumerate(self._input_models): + close_model = isinstance(model, str) + self._close_input_models.append(close_model) + if close_model: + self._input_file_names.append(model) + model = self.open_model(model) + if self.in_memory: + self._input_models[k] = model + else: + self._input_file_names.append(model.meta.filename) + # extract all info needed from *this* model: + w = deepcopy(model.meta.wcs) + w.array_shape = model.data.shape + self._input_wcs_list.append(w) + + # extract other useful data + # - S_REGION: + self._input_s_region.append(model.meta.wcsinfo.s_region) + + # store first model's entire meta (except for WCS and data): + if self._first_model_meta is None: + self._first_model_meta = self.new_model() + self._first_model_meta.update(model) + + if close_model and not self.in_memory: + self.close_model(model) + + if not all: + break + + def blend_output_metadata(self, output_model): + # TODO: Not sure about funct. 
signature and also I don't like it needs + # to open input files again + pass + + def build_driz_weight(self, model, weight_type=None, good_bits=None): + """Create a weight map for use by drizzle + """ + dqmask = build_mask(model.dq, good_bits) + + if weight_type and weight_type.startswith('ivm'): + weight_type = weight_type.strip() + selective_median = weight_type.startswith('ivm-smed') + + bitvalue = interpret_bit_flags(good_bits, mnemonic_map=pixel) + if bitvalue is None: + bitvalue = 0 + saturation = pixel['SATURATED'] + + if selective_median and not (bitvalue & saturation): + selective_median = False + weight_type = 'ivm' + + if (model.hasattr("var_rnoise") and model.var_rnoise is not None and + model.var_rnoise.shape == model.data.shape): + with np.errstate(divide="ignore", invalid="ignore"): + inv_variance = model.var_rnoise**-1 + + inv_variance[~np.isfinite(inv_variance)] = 1 + + if weight_type != 'ivm': + ny, nx = inv_variance.shape + + # apply a median filter to smooth the weight at saturated + # (or high read-out noise) single pixels. keep kernel size + # small to still give lower weight to extended CRs, etc. + ksz = weight_type[8 if selective_median else 7 :] + if ksz: + kernel_size = int(ksz) + if not (kernel_size % 2): + raise ValueError( + 'Kernel size of the median filter in IVM weighting' + ' must be an odd integer.' 
+ ) + else: + kernel_size = 3 + + ivm_copy = inv_variance.copy() + + if selective_median: + # apply median filter selectively only at + # points of partially saturated sources: + jumps = np.where( + np.logical_and(model.dq & saturation, dqmask) + ) + w2 = kernel_size // 2 + for r, c in zip(*jumps): + x1 = max(0, c - w2) + x2 = min(nx, c + w2 + 1) + y1 = max(0, r - w2) + y2 = min(ny, r + w2 + 1) + data = ivm_copy[y1:y2, x1:x2][dqmask[y1:y2, x1:x2]] + if data.size: + inv_variance[r, c] = np.median(data) + # else: leave it as is + + else: + # apply median to the entire inv-var array: + inv_variance = median_filter( + inv_variance, + size=kernel_size + ) + bad_dqmask = np.logical_not(dqmask) + inv_variance[bad_dqmask] = ivm_copy[bad_dqmask] + + else: + warnings.warn( + "var_rnoise array not available. " + "Setting drizzle weight map to 1", + RuntimeWarning + ) + inv_variance = 1.0 + + result = inv_variance * dqmask + + elif weight_type == 'exptime': + exptime = model.meta.exposure.exposure_time + result = exptime * dqmask + + else: + result = np.ones(model.data.shape, dtype=model.data.dtype) * dqmask + + return result.astype(np.float32) + + @abc.abstractmethod + def run(self): + ... + + def _create_output_template_model(self): + pass + + def update_exposure_times(self): + """Modify exposure time metadata in-place""" + total_exposure_time = 0. 
+ exptime_start = [] + exptime_end = [] + duration = 0.0 + total_exptime = 0.0 + measurement_time_success = [] + for exposure in self._input_models.models_grouped: + total_exposure_time += exposure[0].meta.exposure.exposure_time + t, success = get_tmeasure(exposure[0]) + measurement_time_success.append(success) + total_exptime += t + exptime_start.append(exposure[0].meta.exposure.start_time) + exptime_end.append(exposure[0].meta.exposure.end_time) + duration += exposure[0].meta.exposure.duration + + # Update some basic exposure time values based on output_model + self._output_model.meta.exposure.exposure_time = total_exposure_time + if not all(measurement_time_success): + self._output_model.meta.exposure.measurement_time = total_exptime + self._output_model.meta.exposure.start_time = min(exptime_start) + self._output_model.meta.exposure.end_time = max(exptime_end) + + # Update other exposure time keywords: + # XPOSURE (identical to the total effective exposure time, EFFEXPTM) + self._output_model.meta.exposure.effective_exposure_time = total_exptime + # DURATION (identical to TELAPSE, elapsed time) + self._output_model.meta.exposure.duration = duration + self._output_model.meta.exposure.elapsed_exposure_time = duration + + +class ResampleCoAdd(ResampleBase): + """ + This is the controlling routine for the resampling process. + + Notes + ----- + This routine performs the following operations:: + + 1. Extracts parameter settings from input model, such as pixfrac, + weight type, exposure time (if relevant), and kernel, and merges + them with any user-provided values. + 2. Creates output WCS based on input images and define mapping function + between all input arrays and the output array. + 3. Updates output data model with output arrays from drizzle, including + a record of metadata from all input models. 
+ """ + + def __init__(self, input_models, output, accum=False, + pixfrac=1.0, kernel="square", fillval=0.0, wht_type="ivm", + good_bits=0, output_wcs=None, wcs_pars=None, + in_memory=True, allowed_memory=None): + """ + Parameters + ---------- + input_models : list of objects + list of data models, one for each input image + + output : DataModel, str + filename for output + + kwargs : dict + Other parameters. + + .. note:: + ``output_shape`` is in the ``x, y`` order. + + .. note:: + ``in_memory`` controls whether or not the resampled + array from ``resample_many_to_many()`` + should be kept in memory or written out to disk and + deleted from memory. Default value is `True` to keep + all products in memory. + """ + self._accum = accum + + super().__init__( + input_models, + pixfrac, kernel, fillval, wht_type, + good_bits, output_wcs, wcs_pars, + in_memory, allowed_memory, output=output + ) + + def process_kwargs(self, kwargs): + """ A method called by ``__init__`` to process input kwargs before + output WCS is created and before output model template is created. + """ + kwargs = super().process_kwargs(kwargs) + output = kwargs.pop("output", None) + accum = kwargs.pop("accum", False) + + # Load the model if accum is True + if isinstance(output, str): + self._output_filename = output + if accum: + try: + self._output_model = self.open_model(output) + self._close_output = True + log.info( + "Output model has been loaded and it will be used to " + "accumulate new data." 
+ ) + if self._user_output_wcs: + log.info( + "'output_wcs' will be ignored when 'output' is " + "provided and accum=True" + ) + if self._wcs_pars: + log.info( + "'wcs_pars' will be ignored when 'output' is " + "provided and accum=True" + ) + except FileNotFoundError: + pass + + elif output is not None: + self._output_filename = output.meta.filename + self._output_model = output + self._close_output = False + + return kwargs + + def blend_output_metadata(self, output_model): + pass + + def extra_pre_resample_setup(self): + pass + + def post_process_resample_model(self, data_model, driz_init_kwargs, add_image_kwargs): + pass + + def finalize_resample(self): + pass + + def _create_new_output_model(self): + # this probably needs to be an abstract class. + # also this is mostly needed for "single" drizzle. + output_model = self.new_model(None) + + # update meta data and wcs + + # TODO: don't like this as it means reloading first image (again) + output_model.update(self._first_model_meta) + output_model.meta.wcs = deepcopy(self._output_wcs) + + pix_area = self._output_pixel_scale**2 + output_model.meta.photometry.pixelarea_steradians = pix_area + output_model.meta.photometry.pixelarea_arcsecsq = ( + pix_area * np.rad2deg(3600)**2 + ) + + return output_model + + def build_output_model_name(self): + fnames = {f for f in self._input_file_names if f is not None} + + if not fnames: + return "resampled_data_{resample_suffix}{resample_file_ext}" + + # TODO: maybe remove ending suffix for single file names? + prefix = os.path.commonprefix( + [PurePath(f).stem.strip('_- ') for f in fnames] + ) + + return prefix + "{resample_suffix}{resample_file_ext}" + + def create_output_model(self, resample_results): + # this probably needs to be an abstract class (different telescopes + # may want to save different arrays and ignore others). 
+ + if not self._output_model and self._output_filename: + if self._accum and Path(self._output_filename).is_file(): + self._output_model = self.open_model(self._output_filename) + else: + self._output_model = self._create_new_output_model() + self._close_output = not self.in_memory + + if self._output_filename is None: + self._output_filename = self.build_output_model_name() + + self._output_model.data = resample_results.out_img + + self.update_exposure_times() + self.finalize_resample() + + self._output_model.meta.resample.weight_type = self.weight_type + self._output_model.meta.resample.pointings = len(self._input_models.group_names) + # TODO: also store the number of images added in total: ncoadds? + + self.blend_output_metadata(self._output_model) + + self._output_model.write(self._output_filename, overwrite=True) + + if self._close_output and not self.in_memory: + self.close_model(self._output_model) + self._output_model = None + + def run(self): + """Resample and coadd many inputs to a single output. + + Used for stage 3 resampling + """ + + # TODO: repetitiveness of code below should be compactified via using + # getattr as in orig code and maybe making an alternative method to + # the original resample_variance_array + ninputs = len(self._input_models) + + do_accum = ( + self._accum and + ( + self._output_model or + (self._output_filename and Path(self._output_filename).is_file()) + ) + ) + + if do_accum and self._output_model is None: + self._output_model = self.open_model(self._output_filename) + + # get old data: + data = self._output_model.data # use .copy()? + wht = self._output_model.wht # use .copy()? + ctx = self._output_model.con # use .copy()? + t_exptime = self._output_model.meta.exptime + # TODO: we need something to store total number of images that + # have been used to create the resampled output, something + # similar to output_model.meta.resample.pointings + ncoadds = self._output_model.meta.resample.ncoadds # ???? 
(not sure about name) + self.accum_output_arrays = True + + else: + ncoadds = 0 + data = None + wht = None + ctx = None + t_exptime = 0.0 + self.accum_output_arrays = False + + driz_data = Drizzle( + kernel=self.kernel, + fillval=self.fillval, + out_shape=self._output_array_shape, + out_img=data, + out_wht=wht, + out_ctx=ctx, + exptime=t_exptime, + begin_ctx_id=ncoadds, + max_ctx_id=ncoadds + ninputs, + ) + + self.extra_pre_resample_setup() + + log.info("Resampling science data") + for img in self._input_models: + input_pixflux_area = img.meta.photometry.pixelarea_steradians + if (input_pixflux_area and + 'SPECTRAL' not in img.meta.wcs.output_frame.axes_type): + img.meta.wcs.array_shape = img.data.shape + input_pixel_area = _compute_image_pixel_area(img.meta.wcs) + if input_pixel_area is None: + raise ValueError( + "Unable to compute input pixel area from WCS of input " + f"image {repr(img.meta.filename)}." + ) + iscale = np.sqrt(input_pixflux_area / input_pixel_area) + else: + iscale = 1.0 + + img.meta.iscale = iscale + + # TODO: should weight_type=None here? 
+ in_wht = self.build_driz_weight( + img, + weight_type=self.weight_type, + good_bits=self.good_bits + ) + + # apply sky subtraction + blevel = img.meta.background.level + if not img.meta.background.subtracted and blevel is not None: + in_data = img.data - blevel + else: + in_data = img.data + + xmin, xmax, ymin, ymax = _resample_range( + in_data.shape, + img.meta.wcs.bounding_box + ) + + pixmap = calc_pixmap(wcs_from=img.meta.wcs, wcs_to=self._output_wcs) + + add_image_kwargs = { + 'exptime': img.meta.exposure.exposure_time, + 'pixmap': pixmap, + 'scale': iscale, + 'weight_map': in_wht, + 'wht_scale': 1.0, + 'pixfrac': self.pixfrac, + 'in_units': 'cps', # TODO: get units from data model + 'xmin': xmin, + 'xmax': xmax, + 'ymin': ymin, + 'ymax': ymax, + } + + driz_data.add_image(in_data, **add_image_kwargs) + + self.post_process_resample_model(img, None, add_image_kwargs) + + # TODO: see what to do about original update_exposure_times() + + return self.create_output_model(driz_data) + + +class ResampleSingle(ResampleBase): + """ + This is the controlling routine for the resampling process. + + Notes + ----- + This routine performs the following operations:: + + 1. Extracts parameter settings from input model, such as pixfrac, + weight type, exposure time (if relevant), and kernel, and merges + them with any user-provided values. + 2. Creates output WCS based on input images and define mapping function + between all input arrays and the output array. + 3. Updates output data model with output arrays from drizzle, including + a record of metadata from all input models. + """ + + def __init__(self, input_models, + pixfrac=1.0, kernel="square", fillval=0.0, wht_type="ivm", + good_bits=0, output_wcs=None, wcs_pars=None, + in_memory=True, allowed_memory=None): + """ + Parameters + ---------- + input_models : list of objects + list of data models, one for each input image + + output : DataModel, str + filename for output + + kwargs : dict + Other parameters. + + .. 
note:: + ``output_shape`` is in the ``x, y`` order. + + .. note:: + ``in_memory`` controls whether or not the resampled + array from ``resample_many_to_many()`` + should be kept in memory or written out to disk and + deleted from memory. Default value is `True` to keep + all products in memory. + + """ + super().__init__( + input_models, + pixfrac=pixfrac, + kernel=kernel, + fillval=fillval, + wht_type=wht_type, + good_bits=good_bits, + output_wcs=output_wcs, + wcs_pars=wcs_pars, + in_memory=in_memory, + allowed_memory=allowed_memory, + ) + + def build_output_name_from_input_name(self, input_file_name): + """ Form output file name from input image name """ + indx = input_file_name.rfind('.') + output_type = input_file_name[indx:] + output_root = '_'.join( + input_file_name.replace(output_type, '').split('_')[:-1] + ) + output_file_name = f'{output_root}_outlier_i2d{output_type}' + return output_file_name + + def _create_output_template_model(self): + # this probably needs to be an abstract class. + # also this is mostly needed for "single" drizzle. + self._template_output_model = self.new_model() + self._template_output_model.update(self._first_model_meta) + self._template_output_model.meta.wcs = deepcopy(self._output_wcs) + + pix_area = self._output_pixel_scale**2 + self._template_output_model.meta.photometry.pixelarea_steradians = pix_area + self._template_output_model.meta.photometry.pixelarea_arcsecsq = ( + pix_area * np.rad2deg(3600)**2 + ) + + def create_output_model_single(self, file_name, resample_results): + # this probably needs to be an abstract class + output_model = self._template_output_model.copy() + output_model.data = resample_results.out_img + if self.in_memory: + return output_model + else: + output_model.write(file_name, overwrite=True) + self.close_model(output_model.close) + log.info(f"Saved resampled model to {file_name}") + return file_name + + def resample(self): + """Resample many inputs to many outputs where outputs have a common frame. 
+ + Coadd only different detectors of the same exposure, i.e. map NRCA5 and + NRCB5 onto the same output image, as they image different areas of the + sky. + + Used for outlier detection + """ + output_models = [] # ModelContainer() + + for exposure in self._input_models.models_grouped: + driz = Drizzle( + kernel=self.kernel, + fillval=self.fillval, + out_shape=self._output_array_shape, + max_ctx_id=0 + ) + + # Determine output file type from input exposure filenames + # Use this for defining the output filename + output_filename = self.resampled_output_name_from_input_name( + exposure[0].meta.filename + ) + + log.info(f"{len(exposure)} exposures to drizzle together") + + exptime = None + + for img in exposure: + img = self.open_model(img) + if exptime is None: + exptime = exposure[0].meta.exposure.exposure_time + + # compute image intensity correction due to the difference + # between where in the input image + # img.meta.photometry.pixelarea_steradians was computed and + # the average input pixel area. + + input_pixflux_area = img.meta.photometry.pixelarea_steradians + if (input_pixflux_area and + 'SPECTRAL' not in img.meta.wcs.output_frame.axes_type): + img.meta.wcs.array_shape = img.data.shape + input_pixel_area = _compute_image_pixel_area(img.meta.wcs) + if input_pixel_area is None: + raise ValueError( + "Unable to compute input pixel area from WCS of input " + f"image {repr(img.meta.filename)}." + ) + iscale = np.sqrt(input_pixflux_area / input_pixel_area) + else: + iscale = 1.0 + + # TODO: should weight_type=None here? 
+ in_wht = self.build_driz_weight( + img, + weight_type=self.weight_type, + good_bits=self.good_bits + ) + + # apply sky subtraction + blevel = img.meta.background.level + if not img.meta.background.subtracted and blevel is not None: + data = img.data - blevel + else: + data = img.data + + xmin, xmax, ymin, ymax = _resample_range( + data.shape, + img.meta.wcs.bounding_box + ) + + pixmap = calc_pixmap(wcs_from=img.meta.wcs, wcs_to=self._output_wcs) + + driz.add_image( + data, + exptime=exposure[0].meta.exposure.exposure_time, + pixmap=pixmap, + scale=iscale, + weight_map=in_wht, + wht_scale=1.0, + pixfrac=self.pixfrac, + in_units='cps', # TODO: get units from data model + xmin=xmin, + xmax=xmax, + ymin=ymin, + ymax=ymax, + ) + + del data + self.close_model(img) + + output_models.append( + self.create_output_model_single( + output_filename, + driz + ) + ) + + return output_models # or maybe just a plain list - not ModelContainer? + + +def _get_boundary_points(xmin, xmax, ymin, ymax, dx=None, dy=None, shrink=0): + """ + xmin, xmax, ymin, ymax - integer coordinates of pixel boundaries + step - distance between points along an edge + shrink - number of pixels by which to reduce `shape` + + Returns a list of points and the area of the rectangle + """ + nx = xmax - xmin + 1 + ny = ymax - ymin + 1 + + if dx is None: + dx = nx + if dy is None: + dy = ny + + if nx - 2 * shrink < 1 or ny - 2 * shrink < 1: + raise ValueError("Image size is too small.") + + sx = max(1, int(np.ceil(nx / dx))) + sy = max(1, int(np.ceil(ny / dy))) + + xmin += shrink + xmax -= shrink + ymin += shrink + ymax -= shrink + + size = 2 * sx + 2 * sy + x = np.empty(size) + y = np.empty(size) + + b = np.s_[0:sx] # bottom edge + r = np.s_[sx:sx + sy] # right edge + t = np.s_[sx + sy:2 * sx + sy] # top edge + l = np.s_[2 * sx + sy:2 * sx + 2 * sy] # left + + x[b] = np.linspace(xmin, xmax, sx, False) + y[b] = ymin + x[r] = xmax + y[r] = np.linspace(ymin, ymax, sy, False) + x[t] = np.linspace(xmax, xmin, sx, 
False) + y[t] = ymax + x[l] = xmin + y[l] = np.linspace(ymax, ymin, sy, False) + + area = (xmax - xmin) * (ymax - ymin) + center = (0.5 * (xmin + xmax), 0.5 * (ymin + ymax)) + + return x, y, area, center, b, r, t, l + + +def _compute_image_pixel_area(wcs): + """ Computes pixel area in steradians. + """ + if wcs.array_shape is None: + raise ValueError("WCS must have array_shape attribute set.") + + valid_polygon = False + spatial_idx = np.where(np.array(wcs.output_frame.axes_type) == 'SPATIAL')[0] + + ny, nx = wcs.array_shape + + ((xmin, xmax), (ymin, ymax)) = wcs.bounding_box + + xmin = max(0, int(xmin + 0.5)) + xmax = min(nx - 1, int(xmax - 0.5)) + ymin = max(0, int(ymin + 0.5)) + ymax = min(ny - 1, int(ymax - 0.5)) + if xmin > xmax: + (xmin, xmax) = (xmax, xmin) + if ymin > ymax: + (ymin, ymax) = (ymax, ymin) + + k = 0 + dxy = [1, -1, -1, 1] + + while xmin < xmax and ymin < ymax: + try: + x, y, image_area, center, b, r, t, l = _get_boundary_points( + xmin=xmin, + xmax=xmax, + ymin=ymin, + ymax=ymax, + dx=min((xmax - xmin) // 4, 15), + dy=min((ymax - ymin) // 4, 15) + ) + except ValueError: + return None + + world = wcs(x, y) + ra = world[spatial_idx[0]] + dec = world[spatial_idx[1]] + + limits = [ymin, xmax, ymax, xmin] + + for j in range(4): + sl = [b, r, t, l][k] + if not (np.all(np.isfinite(ra[sl])) and + np.all(np.isfinite(dec[sl]))): + limits[k] += dxy[k] + ymin, xmax, ymax, xmin = limits + k = (k + 1) % 4 + break + k = (k + 1) % 4 + else: + valid_polygon = True + break + + ymin, xmax, ymax, xmin = limits + + if not valid_polygon: + return None + + world = wcs(*center) + wcenter = (world[spatial_idx[0]], world[spatial_idx[1]]) + + sky_area = SphericalPolygon.from_radec(ra, dec, center=wcenter).area() + if sky_area > 2 * np.pi: + log.warning( + "Unexpectedly large computed sky area for an image. 
" + "Setting area to: 4*Pi - area" + ) + sky_area = 4 * np.pi - sky_area + pix_area = sky_area / image_area + + return pix_area diff --git a/src/stcal/resample/utils.py b/src/stcal/resample/utils.py new file mode 100644 index 000000000..b48d19b0f --- /dev/null +++ b/src/stcal/resample/utils.py @@ -0,0 +1,36 @@ +import numpy as np +from stdatamodels.dqflags import interpret_bit_flags +from stdatamodels.jwst.datamodels.dqflags import pixel + +__all__ = [ + "build_mask", "get_tmeasure", +] + + +def build_mask(dqarr, bitvalue): + """Build a bit mask from an input DQ array and a bitvalue flag + + In the returned bit mask, 1 is good, 0 is bad + """ + bitvalue = interpret_bit_flags(bitvalue, mnemonic_map=pixel) + + if bitvalue is None: + return np.ones(dqarr.shape, dtype=np.uint8) + return np.logical_not(np.bitwise_and(dqarr, ~bitvalue)).astype(np.uint8) + + +def get_tmeasure(model): + """ + Check if the measurement_time keyword is present in the datamodel + for use in exptime weighting. If not, revert to using exposure_time. 
+ + Returns a tuple of (exptime, is_measurement_time) + """ + try: + tmeasure = model.meta.exposure.measurement_time + except AttributeError: + return model.meta.exposure.exposure_time, False + if tmeasure is None: + return model.meta.exposure.exposure_time, False + else: + return tmeasure, True From 45a49cc34af438e001b68a20903af409af34e9ad Mon Sep 17 00:00:00 2001 From: Mihai Cara Date: Thu, 8 Aug 2024 10:26:27 -0400 Subject: [PATCH 02/10] fix method names --- src/stcal/resample/resample.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/stcal/resample/resample.py b/src/stcal/resample/resample.py index af055ccaf..14d2b0fb1 100644 --- a/src/stcal/resample/resample.py +++ b/src/stcal/resample/resample.py @@ -892,7 +892,7 @@ def create_output_model_single(self, file_name, resample_results): log.info(f"Saved resampled model to {file_name}") return file_name - def resample(self): + def run(self): """Resample many inputs to many outputs where outputs have a common frame. Coadd only different detectors of the same exposure, i.e. map NRCA5 and @@ -913,7 +913,7 @@ def resample(self): # Determine output file type from input exposure filenames # Use this for defining the output filename - output_filename = self.resampled_output_name_from_input_name( + output_filename = self.build_output_name_from_input_name( exposure[0].meta.filename ) From a3a586709780850521429854aadc2cc8f210ced1 Mon Sep 17 00:00:00 2001 From: Mihai Cara Date: Wed, 14 Aug 2024 10:02:02 -0400 Subject: [PATCH 03/10] Remove dependencies on stdatamodels, etc. 
--- src/stcal/resample/resample.py | 261 ++++++++++++++++++++++++++++----- 1 file changed, 225 insertions(+), 36 deletions(-) diff --git a/src/stcal/resample/resample.py b/src/stcal/resample/resample.py index 14d2b0fb1..489bcd5b9 100644 --- a/src/stcal/resample/resample.py +++ b/src/stcal/resample/resample.py @@ -15,11 +15,7 @@ import psutil from spherical_geometry.polygon import SphericalPolygon -from stdatamodels.dqflags import interpret_bit_flags -from stdatamodels.jwst.datamodels.dqflags import pixel - -from stdatamodels.jwst import datamodels -from stdatamodels.jwst.library.basic_utils import bytes2human +from astropy.nddata.bitmask import interpret_bit_flags from .utils import get_tmeasure, build_mask @@ -46,6 +42,43 @@ ] +# FIXME: temporarily copied here to avoid this import: +# from stdatamodels.jwst.library.basic_utils import bytes2human +def bytes2human(n): + """Convert bytes to human-readable format + + Taken from the `psutil` library which references + http://code.activestate.com/recipes/578019 + + Parameters + ---------- + n : int + Number to convert + + Returns + ------- + readable : str + A string with units attached. + + Examples + -------- + >>> bytes2human(10000) + '9.8K' + + >>> bytes2human(100001221) + '95.4M' + """ + symbols = ('K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y') + prefix = {} + for i, s in enumerate(symbols): + prefix[s] = 1 << (i + 1) * 10 + for s in reversed(symbols): + if n >= prefix[s]: + value = float(n) / prefix[s] + return '%.1f%s' % (value, s) + return "%sB" % n + + def _resample_range(data_shape, bbox=None): # Find range of input pixels to resample: if bbox is None: @@ -125,10 +158,13 @@ class ResampleBase(abc.ABC): """ resample_suffix = 'i2d' resample_file_ext = '.fits' + n_arrays_per_output = 2 # #flt-point arrays in the output (data, weight, var, err, etc.) 
+ dq_flag_name_map = {} def __init__(self, input_models, pixfrac=1.0, kernel="square", fillval=0.0, wht_type="ivm", good_bits=0, output_wcs=None, wcs_pars=None, + data_type=np.float64, enable_ctx=True, in_memory=True, allowed_memory=None, **kwargs): """ Parameters @@ -152,6 +188,9 @@ def __init__(self, input_models, deleted from memory. Default value is `True` to keep all products in memory. """ + self._data_type = data_type + self._enable_ctx = enable_ctx + self._output_model = None self._output_filename = None self._output_wcs = None @@ -319,8 +358,6 @@ def check_memory_requirements(self, allowed_memory): return allowed_memory = float(allowed_memory) - # make a small image model to get the dtype - dtype = datamodels.ImageModel((1, 1)).data.dtype # get the available memory available_memory = psutil.virtual_memory().available + psutil.swap_memory().total @@ -329,7 +366,8 @@ def check_memory_requirements(self, allowed_memory): npix = npix = np.prod(self._output_array_shape) nmodels = len(self._input_models) nconpl = nmodels // 32 + (1 if nmodels % 32 else 0) - required_memory = npix * (3 * dtype.itemsize + nconpl * 4) + n_arr = self.n_arrays_per_output + 2 # 2 comes from pixmap + required_memory = npix * (n_arr * self._data_type.itemsize + nconpl * 4) # compare used to available used_fraction = required_memory / available_memory @@ -384,11 +422,6 @@ def load_input_meta(self, all=True): if not all: break - def blend_output_metadata(self, output_model): - # TODO: Not sure about funct. 
signature and also I don't like it needs - # to open input files again - pass - def build_driz_weight(self, model, weight_type=None, good_bits=None): """Create a weight map for use by drizzle """ @@ -397,15 +430,22 @@ def build_driz_weight(self, model, weight_type=None, good_bits=None): if weight_type and weight_type.startswith('ivm'): weight_type = weight_type.strip() selective_median = weight_type.startswith('ivm-smed') - - bitvalue = interpret_bit_flags(good_bits, mnemonic_map=pixel) + bitvalue = interpret_bit_flags( + good_bits, + flag_name_map=self.dq_flag_name_map + ) if bitvalue is None: bitvalue = 0 - saturation = pixel['SATURATED'] - if selective_median and not (bitvalue & saturation): - selective_median = False - weight_type = 'ivm' + # disable selective median if SATURATED flag is included + # in good_bits: + try: + saturation = self.dq_flag_name_map['SATURATED'] + if selective_median and not (bitvalue & saturation): + selective_median = False + weight_type = 'ivm' + except AttributeError: + pass if (model.hasattr("var_rnoise") and model.var_rnoise is not None and model.var_rnoise.shape == model.data.shape): @@ -608,18 +648,6 @@ def process_kwargs(self, kwargs): return kwargs - def blend_output_metadata(self, output_model): - pass - - def extra_pre_resample_setup(self): - pass - - def post_process_resample_model(self, data_model, driz_init_kwargs, add_image_kwargs): - pass - - def finalize_resample(self): - pass - def _create_new_output_model(self): # this probably needs to be an abstract class. # also this is mostly needed for "single" drizzle. 
@@ -669,13 +697,13 @@ def create_output_model(self, resample_results): self._output_model.data = resample_results.out_img self.update_exposure_times() - self.finalize_resample() + self._finish_variance_processing() self._output_model.meta.resample.weight_type = self.weight_type self._output_model.meta.resample.pointings = len(self._input_models.group_names) # TODO: also store the number of images added in total: ncoadds? - self.blend_output_metadata(self._output_model) + self.final_post_processing() self._output_model.write(self._output_filename, overwrite=True) @@ -683,6 +711,167 @@ def create_output_model(self, resample_results): self.close_model(self._output_model) self._output_model = None + def _setup_variance_data(self): + self._var_rnoise_sum = np.full(self._output_array_shape, np.nan) + self._var_poisson_sum = np.full(self._output_array_shape, np.nan) + self._var_flat_sum = np.full(self._output_array_shape, np.nan) + # self._total_weight_var_rnoise = np.zeros(self._output_array_shape) + self._total_weight_var_poisson = np.zeros(self._output_array_shape) + self._total_weight_var_flat = np.zeros(self._output_array_shape) + + def _resample_variance_data(self, data_model, driz_init_kwargs, add_image_kwargs): + log.info("Resampling variance components") + + # create resample objects for the three variance arrays: + driz_init_kwargs = { + 'kernel': self.kernel, + 'fillval': np.nan, + 'out_shape': self._output_array_shape, + # 'exptime': 1.0, + 'no_ctx': True, + } + driz_rnoise = Drizzle(**driz_init_kwargs) + driz_poisson = Drizzle(**driz_init_kwargs) + driz_flat = Drizzle(**driz_init_kwargs) + + # Resample read-out noise and compute weight map for variance arrays + if self._check_var_array(data_model, 'var_rnoise'): + data = np.sqrt(data_model.var_rnoise) + driz_rnoise.add_image(data, **add_image_kwargs) + var = driz_rnoise.out_img + np.square(var, out=var) + + weight_mask = var > 0 + + # Set the weight for the image from the weight type + if self.weight_type == 
"ivm": + weight_mask = var > 0 + weight = np.ones(self._output_array_shape) + weight[weight_mask] = np.reciprocal(var[weight_mask]) + weight_mask &= (weight > 0.0) + # Add the inverse of the resampled variance to a running sum. + # Update only pixels (in the running sum) with valid new values: + self._var_rnoise_sum[weight_mask] = np.nansum( + [ + self._var_rnoise_sum[weight_mask], + weight[weight_mask] + ], + axis=0 + ) + elif self.weight_type == "exptime": + weight = np.full( + self._output_array_shape, + get_tmeasure(data_model)[0], + ) + weight_mask = np.ones(self._output_array_shape, dtype=bool) + self._var_rnoise_sum = np.nansum( + [self._var_rnoise_sum, weight], + axis=0 + ) + else: + weight = np.ones(self._output_array_shape) + weight_mask = np.ones(self._output_array_shape, dtype=bool) + self._var_rnoise_sum = np.nansum( + [self._var_rnoise_sum, weight], + axis=0 + ) + else: + weight = np.ones(self._output_array_shape) + weight_mask = np.ones(self._output_array_shape, dtype=bool) + + if self._check_var_array(data_model, 'var_poisson'): + data = np.sqrt(data_model.var_poisson) + driz_poisson.add_image(data, **add_image_kwargs) + var = driz_poisson.out_img + np.square(var, out=var) + + mask = (var > 0) & weight_mask + + # Add the inverse of the resampled variance to a running sum. + # Update only pixels (in the running sum) with valid new values: + self._var_poisson_sum[mask] = np.nansum( + [ + self._var_poisson_sum[mask], + var[mask] * weight[mask] * weight[mask] + ], + axis=0 + ) + self._total_weight_var_poisson[mask] += weight[mask] + + if self._check_var_array(data_model, 'var_flat'): + data = np.sqrt(data_model.var_flat) + driz_flat.add_image(data, **add_image_kwargs) + var = driz_flat.out_img + np.square(var, out=var) + + mask = (var > 0) & weight_mask + + # Add the inverse of the resampled variance to a running sum. 
+ # Update only pixels (in the running sum) with valid new values: + self._var_flat_sum[mask] = np.nansum( + [ + self._var_flat_sum[mask], + var[mask] * weight[mask] * weight[mask] + ], + axis=0 + ) + self._total_weight_var_flat[mask] += weight[mask] + + def final_post_processing(self): + pass + + def _finish_variance_processing(self): + # We now have a sum of the weighted resampled variances. + # Divide by the total weights, squared, and set in the output model. + # Zero weight and missing values are NaN in the output. + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "invalid value*", RuntimeWarning) + warnings.filterwarnings("ignore", "divide by zero*", RuntimeWarning) + + odt = self._output_model.data.dtype + + # readout noise + np.reciprocal(self._var_rnoise_sum, out=self._var_rnoise_sum) + self._output_model.var_rnoise = self._var_rnoise_sum.astype(dtype=odt) + + # Poisson noise + for _ in range(2): + np.divide( + self._var_poisson_sum, + self._total_weight_var_poisson, + out=self._var_poisson_sum + ) + self._output_model.var_poisson = self._var_poisson_sum.astype(dtype=odt) + + # flat's noise + for _ in range(2): + np.divide( + self._var_flat_sum, + self._total_weight_var_flat, + out=self._var_flat_sum + ) + self._output_model.var_flat = self._var_flat_sum.astype(dtype=odt) + + # compute total error: + vars = np.array( + [ + self._var_rnoise_sum, + self._var_poisson_sum, + self._var_flat_sum, + ] + ) + all_nan_mask = np.any(np.isnan(vars), axis=0) + self._output_model.err = np.sqrt(np.nansum(vars, axis=0)).astype(dtype=odt) + self._output_model.err[all_nan_mask] = np.nan + + del vars + del self._var_rnoise_sum + del self._var_poisson_sum + del self._var_flat_sum + # del self._total_weight_var_rnoise + del self._total_weight_var_poisson + del self._total_weight_var_flat + def run(self): """Resample and coadd many inputs to a single output. 
@@ -736,7 +925,7 @@ def run(self): max_ctx_id=ncoadds + ninputs, ) - self.extra_pre_resample_setup() + self._setup_variance_data() log.info("Resampling science data") for img in self._input_models: @@ -793,7 +982,7 @@ def run(self): driz_data.add_image(in_data, **add_image_kwargs) - self.post_process_resample_model(img, None, add_image_kwargs) + self._resample_variance_data(img, None, add_image_kwargs) # TODO: see what to do about original update_exposure_times() @@ -888,7 +1077,7 @@ def create_output_model_single(self, file_name, resample_results): return output_model else: output_model.write(file_name, overwrite=True) - self.close_model(output_model.close) + self.close_model(output_model) log.info(f"Saved resampled model to {file_name}") return file_name From e1d4b8c93cbb2b886bcbd67357d502848f2fc32d Mon Sep 17 00:00:00 2001 From: Mihai Cara Date: Thu, 22 Aug 2024 09:34:49 -0400 Subject: [PATCH 04/10] do not access model attributes directly --- src/stcal/resample/resample.py | 658 +++++++++++++++++++++------------ 1 file changed, 425 insertions(+), 233 deletions(-) diff --git a/src/stcal/resample/resample.py b/src/stcal/resample/resample.py index 489bcd5b9..20030d84f 100644 --- a/src/stcal/resample/resample.py +++ b/src/stcal/resample/resample.py @@ -101,7 +101,27 @@ def open_model(self, file_name): ... @abc.abstractmethod - def get_model_meta(self, file_name, fields): + def get_model_attr_value(self, model, attribute_name): + ... + + @abc.abstractmethod + def set_model_attr_value(self, model, attribute_name, value): + ... + + @abc.abstractmethod + def get_model_meta(self, model, attributes): + ... + + @abc.abstractmethod + def set_model_meta(self, model, attributes): + ... + + @abc.abstractmethod + def get_model_array(self, model, array_name): + ... + + @abc.abstractmethod + def set_model_array(self, model, array_name, data): ... @abc.abstractmethod @@ -113,11 +133,11 @@ def save_model(self, model): ... 
@abc.abstractmethod - def write_model(self, model, file_name): + def write_model(self, model, file_name, **kwargs): ... @abc.abstractmethod - def new_model(self, image_shape=None, file_name=None): + def new_model(self, image_shape=None, file_name=None, copy_meta_from=None): """ Return a new model for the resampled output """ ... @@ -164,7 +184,7 @@ class ResampleBase(abc.ABC): def __init__(self, input_models, pixfrac=1.0, kernel="square", fillval=0.0, wht_type="ivm", good_bits=0, output_wcs=None, wcs_pars=None, - data_type=np.float64, enable_ctx=True, + enable_ctx=True, in_memory=True, allowed_memory=None, **kwargs): """ Parameters @@ -188,7 +208,6 @@ def __init__(self, input_models, deleted from memory. Default value is `True` to keep all products in memory. """ - self._data_type = data_type self._enable_ctx = enable_ctx self._output_model = None @@ -202,7 +221,7 @@ def __init__(self, input_models, # input models self._input_models = input_models # a lightweight data model with meta from first input model but no data. 
- # it will be updated by 'load_input_meta()' below + # it will be updated by 'prload_input_meta()' below self._first_model_meta = None # resample parameters @@ -237,19 +256,26 @@ def __init__(self, input_models, self._warn_extra_kwargs(ukwargs) # load meta necessary for output WCS (and other) computations: - self.load_input_meta( - all=self._output_model is None and output_wcs is None + self.preload_input_meta( + wcs1=True, + filename=self._output_model is None, + s_region=output_wcs is None, ) # computed average pixel scale of the first input image: input_pscale0 = np.rad2deg( - np.sqrt(_compute_image_pixel_area(self._input_wcs_list[0])) + np.sqrt(_compute_image_pixel_area(self._input_img1_wcs)) ) # compute output pixel scale, WCS, set-up output model if self._output_model: - self._output_wcs = deepcopy(self._output_model.meta.wcs) - self._output_array_shape = self._output_model.data.shape + self._output_wcs = deepcopy( + self.get_model_attr_value(self._output_model, "wcs") + ) + self._output_array_shape = self.get_model_array( + self._output_model, + "data" + ).shape # TODO: extract any useful info from the output image before we close it: # if meta has pixel scale, populate it from there, if not: self._output_pixel_scale = np.rad2deg( @@ -362,12 +388,18 @@ def check_memory_requirements(self, allowed_memory): # get the available memory available_memory = psutil.virtual_memory().available + psutil.swap_memory().total + # determine data type of the output model: + out_model = self.new_model((2, 2)) + data = self.get_model_array(out_model) + data_type = data.dtype + del data, out_model + # compute the output array size npix = npix = np.prod(self._output_array_shape) nmodels = len(self._input_models) nconpl = nmodels // 32 + (1 if nmodels % 32 else 0) n_arr = self.n_arrays_per_output + 2 # 2 comes from pixmap - required_memory = npix * (n_arr * self._data_type.itemsize + nconpl * 4) + required_memory = npix * (n_arr * data_type.itemsize + nconpl * 4) # compare used 
to available used_fraction = required_memory / available_memory @@ -380,52 +412,62 @@ def check_memory_requirements(self, allowed_memory): def _compute_output_wcs(self, **wcs_pars): """ returns a diustortion-free WCS object and its pixel scale """ - owcs = output_wcs_from_input_wcs(self._input_wcs_list, **wcs_pars) +# owcs = output_wcs_from_input_wcs(self._input_s_region_list, **wcs_pars) + owcs = output_wcs_from_input_wcs([self._input_img1_wcs], **wcs_pars) return owcs['output_wcs'], owcs['pscale'] - def load_input_meta(self, all=True): - # if 'all=False', load meta from the first image only - - # set-up list for WCS - self._input_wcs_list = [] - self._input_s_region = [] - self._input_file_names = [] - self._close_input_models = [] - - for k, model in enumerate(self._input_models): - close_model = isinstance(model, str) - self._close_input_models.append(close_model) - if close_model: - self._input_file_names.append(model) - model = self.open_model(model) - if self.in_memory: - self._input_models[k] = model - else: - self._input_file_names.append(model.meta.filename) - # extract all info needed from *this* model: - w = deepcopy(model.meta.wcs) - w.array_shape = model.data.shape - self._input_wcs_list.append(w) + def preload_input_meta(self, wcs1, filename, s_region): + # set-up lists for WCS and file names + self._input_img1_wcs = None + self._input_s_region_list = [] + self._input_filename_list = [] - # extract other useful data - # - S_REGION: - self._input_s_region.append(model.meta.wcsinfo.s_region) + # loop over only science exposures in the ModelLibrary + # sci_indices = self._input_models.ind_asn_type("science") + with self._input_models: + for model in self._input_models: + # model = self._input_models.borrow(idx) - # store first model's entire meta (except for WCS and data): - if self._first_model_meta is None: - self._first_model_meta = self.new_model() - self._first_model_meta.update(model) + try: + if self.get_model_attr_value(model, "exptype").upper() 
!= "SCIENCE": + self._input_models.shelve(model, modify=False) + continue + except AttributeError: + pass - if close_model and not self.in_memory: - self.close_model(model) + if self._input_img1_wcs is None and wcs1: + # extract all info needed from *this* model: + self._input_img1_wcs = deepcopy( + self.get_model_attr_value(model, "wcs") + ) + self._input_img1_wcs.array_shape = self.get_model_array( + model, + "data" + ).shape + + if filename: + self._input_filename_list.append( + self.get_model_attr_value(model, "filename") + ) - if not all: - break + if s_region: + self._input_s_region_list.append( + self.get_model_attr_value(model, "s_region") + ) + + self._input_models.shelve(model, modify=False) + + # store first model's entire meta (except for WCS and data): + if self._first_model_meta is None: + self._first_model_meta = self.new_model(copy_meta_from=model) def build_driz_weight(self, model, weight_type=None, good_bits=None): """Create a weight map for use by drizzle """ - dqmask = build_mask(model.dq, good_bits) + data = self.get_model_array(model, "data") + dq = self.get_model_array(model, "dq") + + dqmask = build_mask(dq, good_bits) if weight_type and weight_type.startswith('ivm'): weight_type = weight_type.strip() @@ -447,15 +489,15 @@ def build_driz_weight(self, model, weight_type=None, good_bits=None): except AttributeError: pass - if (model.hasattr("var_rnoise") and model.var_rnoise is not None and - model.var_rnoise.shape == model.data.shape): + var_rnoise = self.get_model_array(model, "var_rnoise", default=None) + if (var_rnoise is not None and var_rnoise.shape == data.shape): with np.errstate(divide="ignore", invalid="ignore"): - inv_variance = model.var_rnoise**-1 + inv_variance = var_rnoise**-1 inv_variance[~np.isfinite(inv_variance)] = 1 if weight_type != 'ivm': - ny, nx = inv_variance.shape + ny, nx = data.shape # apply a median filter to smooth the weight at saturated # (or high read-out noise) single pixels. 
keep kernel size @@ -465,8 +507,8 @@ def build_driz_weight(self, model, weight_type=None, good_bits=None): kernel_size = int(ksz) if not (kernel_size % 2): raise ValueError( - 'Kernel size of the median filter in IVM weighting' - ' must be an odd integer.' + 'Kernel size of the median filter in IVM ' + 'weighting must be an odd integer.' ) else: kernel_size = 3 @@ -477,7 +519,7 @@ def build_driz_weight(self, model, weight_type=None, good_bits=None): # apply median filter selectively only at # points of partially saturated sources: jumps = np.where( - np.logical_and(model.dq & saturation, dqmask) + np.logical_and(dq & saturation, dqmask) ) w2 = kernel_size // 2 for r, c in zip(*jumps): @@ -510,11 +552,11 @@ def build_driz_weight(self, model, weight_type=None, good_bits=None): result = inv_variance * dqmask elif weight_type == 'exptime': - exptime = model.meta.exposure.exposure_time + exptime = self.get_model_attr_value(model, "exposure_time") result = exptime * dqmask else: - result = np.ones(model.data.shape, dtype=model.data.dtype) * dqmask + result = np.ones(data.shape, dtype=data.dtype) * dqmask return result.astype(np.float32) @@ -533,28 +575,42 @@ def update_exposure_times(self): duration = 0.0 total_exptime = 0.0 measurement_time_success = [] - for exposure in self._input_models.models_grouped: - total_exposure_time += exposure[0].meta.exposure.exposure_time - t, success = get_tmeasure(exposure[0]) + + for exposure in self._input_models.group_indices.values(): + with self._input_models: + model0 = self._input_models.borrow(exposure[0]) + attrs = self.get_model_meta( + model0, + ["exposure_time", "start_time", "end_time", "duration"] + ) + + t, success = get_tmeasure(model0) + self._input_models.shelve(model0, exposure[0]) + + total_exposure_time += attrs["exposure_time"] measurement_time_success.append(success) total_exptime += t - exptime_start.append(exposure[0].meta.exposure.start_time) - exptime_end.append(exposure[0].meta.exposure.end_time) - duration += 
exposure[0].meta.exposure.duration + exptime_start.append(attrs["start_time"]) + exptime_end.append(attrs["end_time"]) + duration += attrs["duration"] + + attrs = { + # basic exposure time attributes: + "exposure_time": total_exposure_time, + "start_time": min(exptime_start), + "end_time": max(exptime_end), + # Update other exposure time keywords: + # XPOSURE (identical to the total effective exposure time, EFFEXPTM) + "effective_exposure_time": total_exptime, + # DURATION (identical to TELAPSE, elapsed time) + "duration": duration, + "elapsed_exposure_time": duration, + } - # Update some basic exposure time values based on output_model - self._output_model.meta.exposure.exposure_time = total_exposure_time if not all(measurement_time_success): - self._output_model.meta.exposure.measurement_time = total_exptime - self._output_model.meta.exposure.start_time = min(exptime_start) - self._output_model.meta.exposure.end_time = max(exptime_end) + attrs["measurement_time"] = total_exptime - # Update other exposure time keywords: - # XPOSURE (identical to the total effective exposure time, EFFEXPTM) - self._output_model.meta.exposure.effective_exposure_time = total_exptime - # DURATION (identical to TELAPSE, elapsed time) - self._output_model.meta.exposure.duration = duration - self._output_model.meta.exposure.elapsed_exposure_time = duration + self.set_model_meta(self._output_model, attrs) class ResampleCoAdd(ResampleBase): @@ -577,6 +633,7 @@ class ResampleCoAdd(ResampleBase): def __init__(self, input_models, output, accum=False, pixfrac=1.0, kernel="square", fillval=0.0, wht_type="ivm", good_bits=0, output_wcs=None, wcs_pars=None, + enable_ctx=True, enable_err=True, in_memory=True, allowed_memory=None): """ Parameters @@ -603,11 +660,20 @@ def __init__(self, input_models, output, accum=False, self._accum = accum super().__init__( - input_models, - pixfrac, kernel, fillval, wht_type, - good_bits, output_wcs, wcs_pars, - in_memory, allowed_memory, output=output + 
input_models=input_models, + pixfrac=pixfrac, + kernel=kernel, + fillval=fillval, + wht_type=wht_type, + good_bits=good_bits, + output_wcs=output_wcs, + wcs_pars=wcs_pars, + enable_ctx=enable_ctx, + in_memory=in_memory, + allowed_memory=allowed_memory, + output=output, ) + self._enable_err = enable_err def process_kwargs(self, kwargs): """ A method called by ``__init__`` to process input kwargs before @@ -642,7 +708,10 @@ def process_kwargs(self, kwargs): pass elif output is not None: - self._output_filename = output.meta.filename + self._output_filename = self.get_model_attr_value( + output, + "filename" + ) self._output_model = output self._close_output = False @@ -651,24 +720,26 @@ def process_kwargs(self, kwargs): def _create_new_output_model(self): # this probably needs to be an abstract class. # also this is mostly needed for "single" drizzle. - output_model = self.new_model(None) + output_model = self.new_model( + None, + copy_meta_from=self._first_model_meta + ) # update meta data and wcs - - # TODO: don't like this as it means reloading first image (again) - output_model.update(self._first_model_meta) - output_model.meta.wcs = deepcopy(self._output_wcs) - pix_area = self._output_pixel_scale**2 - output_model.meta.photometry.pixelarea_steradians = pix_area - output_model.meta.photometry.pixelarea_arcsecsq = ( - pix_area * np.rad2deg(3600)**2 + self.set_model_meta( + output_model, + { + "wcs": deepcopy(self._output_wcs), + "pixelarea_steradians": pix_area, + "pixelarea_arcsecsq": pix_area * np.rad2deg(3600)**2, + } ) return output_model def build_output_model_name(self): - fnames = {f for f in self._input_file_names if f is not None} + fnames = {f for f in self._input_filename_list if f is not None} if not fnames: return "resampled_data_{resample_suffix}{resample_file_ext}" @@ -694,13 +765,19 @@ def create_output_model(self, resample_results): if self._output_filename is None: self._output_filename = self.build_output_model_name() - self._output_model.data = 
resample_results.out_img + self.set_model_array(self._output_model, "data", resample_results.out_img) self.update_exposure_times() - self._finish_variance_processing() - - self._output_model.meta.resample.weight_type = self.weight_type - self._output_model.meta.resample.pointings = len(self._input_models.group_names) + if self._enable_err: + self._finish_variance_processing() + + self.set_model_meta( + self._output_model, + { + "weight_type": self.weight_type, + "pointings": len(self._input_models.group_names), + } + ) # TODO: also store the number of images added in total: ncoadds? self.final_post_processing() @@ -710,6 +787,9 @@ def create_output_model(self, resample_results): if self._close_output and not self.in_memory: self.close_model(self._output_model) self._output_model = None + return self._output_filename + + return self._output_model def _setup_variance_data(self): self._var_rnoise_sum = np.full(self._output_array_shape, np.nan) @@ -719,6 +799,27 @@ def _setup_variance_data(self): self._total_weight_var_poisson = np.zeros(self._output_array_shape) self._total_weight_var_flat = np.zeros(self._output_array_shape) + def _check_var_array(self, data_model, array_name): + array_data = self.get_model_array(data_model, array_name, default=None) + sci_data = self.get_model_array(data_model, "data", default=None) + filename = self.get_model_meta(data_model, "filename") + + if array_data is None or array_data.size == 0: + log.debug( + f"No data for '{array_name}' for model " + f"{repr(filename)}. Skipping ..." + ) + return False + + elif array_data.shape != sci_data.shape: + log.warning( + f"Data shape mismatch for '{array_name}' for model " + f"{repr(filename)}. Skipping ..." 
+ ) + return False + + return True + def _resample_variance_data(self, data_model, driz_init_kwargs, add_image_kwargs): log.info("Resampling variance components") @@ -828,11 +929,15 @@ def _finish_variance_processing(self): warnings.filterwarnings("ignore", "invalid value*", RuntimeWarning) warnings.filterwarnings("ignore", "divide by zero*", RuntimeWarning) - odt = self._output_model.data.dtype + odt = self.get_model_array(self._output_model, "data").dtype # readout noise np.reciprocal(self._var_rnoise_sum, out=self._var_rnoise_sum) - self._output_model.var_rnoise = self._var_rnoise_sum.astype(dtype=odt) + self.set_model_array( + self._output_model, + "var_rnoise", + self._var_rnoise_sum.astype(dtype=odt) + ) # Poisson noise for _ in range(2): @@ -841,7 +946,11 @@ def _finish_variance_processing(self): self._total_weight_var_poisson, out=self._var_poisson_sum ) - self._output_model.var_poisson = self._var_poisson_sum.astype(dtype=odt) + self.set_model_array( + self._output_model, + "var_poisson", + self._var_poisson_sum.astype(dtype=odt) + ) # flat's noise for _ in range(2): @@ -850,7 +959,11 @@ def _finish_variance_processing(self): self._total_weight_var_flat, out=self._var_flat_sum ) - self._output_model.var_flat = self._var_flat_sum.astype(dtype=odt) + self.set_model_array( + self._output_model, + "var_flat", + self._var_flat_sum.astype(dtype=odt) + ) # compute total error: vars = np.array( @@ -861,8 +974,13 @@ def _finish_variance_processing(self): ] ) all_nan_mask = np.any(np.isnan(vars), axis=0) - self._output_model.err = np.sqrt(np.nansum(vars, axis=0)).astype(dtype=odt) - self._output_model.err[all_nan_mask] = np.nan + + err = np.sqrt(np.nansum(vars, axis=0)).astype(dtype=odt) + err[all_nan_mask] = np.nan + self.set_model_array(self._output_model, "err", err) + self.set_model_array(self._output_model, "var_rnoise", self._var_rnoise_sum) + self.set_model_array(self._output_model, "var_poisson", self._var_poisson_sum) + 
self.set_model_array(self._output_model, "var_flat", self._var_flat_sum) del vars del self._var_rnoise_sum @@ -895,14 +1013,25 @@ def run(self): self._output_model = self.open_model(self._output_filename) # get old data: - data = self._output_model.data # use .copy()? - wht = self._output_model.wht # use .copy()? - ctx = self._output_model.con # use .copy()? - t_exptime = self._output_model.meta.exptime + data = self.get_model_array(self._output_model, "data") + wht = self.get_model_array(self._output_model, "wht") + if self._enable_ctx: + ctx = self.get_model_array(self._output_model, "con") + else: + ctx = None + + t_exptime = self.get_model_attr_value( + self._output_model, + "exptime" + ) # TODO: we need something to store total number of images that # have been used to create the resampled output, something - # similar to output_model.meta.resample.pointings - ncoadds = self._output_model.meta.resample.ncoadds # ???? (not sure about name) + # similar to output_model.meta.resample.pointings. + # For now I will call it "ncoadds" + ncoadds = self.get_model_attr_value( + self._output_model, + "ncoadds" + ) self.accum_output_arrays = True else: @@ -925,64 +1054,98 @@ def run(self): max_ctx_id=ncoadds + ninputs, ) - self._setup_variance_data() + if self._enable_err: + self._setup_variance_data() log.info("Resampling science data") - for img in self._input_models: - input_pixflux_area = img.meta.photometry.pixelarea_steradians - if (input_pixflux_area and - 'SPECTRAL' not in img.meta.wcs.output_frame.axes_type): - img.meta.wcs.array_shape = img.data.shape - input_pixel_area = _compute_image_pixel_area(img.meta.wcs) - if input_pixel_area is None: - raise ValueError( - "Unable to compute input pixel area from WCS of input " - f"image {repr(img.meta.filename)}." 
+ + # loop over only science exposures in the ModelLibrary + # sci_indices = self._input_models.ind_asn_type("science") + with self._input_models: + for model in self._input_models: + # model = self._input_models.borrow(idx) + + try: + if self.get_model_attr_value(model, "exptype").upper() != "SCIENCE": + self._input_models.shelve(model, modify=False) + continue + except AttributeError: + pass + + in_data = self.get_model_array(model, "data") + + attrs = self.get_model_meta( + model, + [ + "wcs", + "pixelarea_steradians", + "filename", + "level", + "subtracted", + "exposure_time", + ] + ) + + # Check that input models are 2D images + if in_data.ndim != 2: + raise RuntimeError( + f"Input {attrs['filename']} is not a 2D image." ) - iscale = np.sqrt(input_pixflux_area / input_pixel_area) - else: - iscale = 1.0 - img.meta.iscale = iscale + input_pixflux_area = attrs["pixelarea_steradians"] + imwcs = attrs["wcs"] + if (input_pixflux_area and + 'SPECTRAL' not in imwcs.output_frame.axes_type): + imwcs.array_shape = in_data.shape + input_pixel_area = _compute_image_pixel_area(imwcs) + if input_pixel_area is None: + raise ValueError( + "Unable to compute input pixel area from WCS of input " + f"image {repr(attrs['filename'])}." + ) + iscale = np.sqrt(input_pixflux_area / input_pixel_area) + else: + iscale = 1.0 - # TODO: should weight_type=None here? - in_wht = self.build_driz_weight( - img, - weight_type=self.weight_type, - good_bits=self.good_bits - ) + # TODO: should weight_type=None here? 
+ in_wht = self.build_driz_weight( + model, + weight_type=self.weight_type, + good_bits=self.good_bits + ) - # apply sky subtraction - blevel = img.meta.background.level - if not img.meta.background.subtracted and blevel is not None: - in_data = img.data - blevel - else: - in_data = img.data + # apply sky subtraction + blevel = attrs["level"] + if not attrs["subtracted"] and blevel is not None: + in_data = in_data - blevel - xmin, xmax, ymin, ymax = _resample_range( - in_data.shape, - img.meta.wcs.bounding_box - ) + xmin, xmax, ymin, ymax = _resample_range( + in_data.shape, + imwcs.bounding_box + ) - pixmap = calc_pixmap(wcs_from=img.meta.wcs, wcs_to=self._output_wcs) - - add_image_kwargs = { - 'exptime': img.meta.exposure.exposure_time, - 'pixmap': pixmap, - 'scale': iscale, - 'weight_map': in_wht, - 'wht_scale': 1.0, - 'pixfrac': self.pixfrac, - 'in_units': 'cps', # TODO: get units from data model - 'xmin': xmin, - 'xmax': xmax, - 'ymin': ymin, - 'ymax': ymax, - } + pixmap = calc_pixmap(wcs_from=imwcs, wcs_to=self._output_wcs) - driz_data.add_image(in_data, **add_image_kwargs) + add_image_kwargs = { + 'exptime': attrs["exposure_time"], + 'pixmap': pixmap, + 'scale': iscale, + 'weight_map': in_wht, + 'wht_scale': 1.0, + 'pixfrac': self.pixfrac, + 'in_units': 'cps', # TODO: get units from data model + 'xmin': xmin, + 'xmax': xmax, + 'ymin': ymin, + 'ymax': ymax, + } - self._resample_variance_data(img, None, add_image_kwargs) + driz_data.add_image(in_data, **add_image_kwargs) + + if self._enable_err: + self._resample_variance_data(model, None, add_image_kwargs) + + self._input_models.shelve(model, modify=False) # TODO: see what to do about original update_exposure_times() @@ -1059,24 +1222,27 @@ def build_output_name_from_input_name(self, input_file_name): def _create_output_template_model(self): # this probably needs to be an abstract class. # also this is mostly needed for "single" drizzle. 
- self._template_output_model = self.new_model() - self._template_output_model.update(self._first_model_meta) - self._template_output_model.meta.wcs = deepcopy(self._output_wcs) - + self._template_output_model = self.new_model( + copy_meta_from=self._first_model_meta, + ) pix_area = self._output_pixel_scale**2 - self._template_output_model.meta.photometry.pixelarea_steradians = pix_area - self._template_output_model.meta.photometry.pixelarea_arcsecsq = ( - pix_area * np.rad2deg(3600)**2 + self.set_model_meta( + self._template_output_model, + { + "wcs": deepcopy(self._output_wcs), + "pixelarea_steradians": pix_area, + "pixelarea_arcsecsq": pix_area * np.rad2deg(3600)**2, + } ) def create_output_model_single(self, file_name, resample_results): # this probably needs to be an abstract class - output_model = self._template_output_model.copy() - output_model.data = resample_results.out_img + output_model = deepcopy(self._template_output_model) + self.set_model_array(output_model, "data", resample_results.out_img) if self.in_memory: return output_model else: - output_model.write(file_name, overwrite=True) + self.write_model(output_model, file_name, overwrite=True) self.close_model(output_model) log.info(f"Saved resampled model to {file_name}") return file_name @@ -1090,9 +1256,10 @@ def run(self): Used for outlier detection """ - output_models = [] # ModelContainer() + output_models = [] + + for exposure_indices in self._input_models.group_indices.values(): - for exposure in self._input_models.models_grouped: driz = Drizzle( kernel=self.kernel, fillval=self.fillval, @@ -1100,87 +1267,112 @@ def run(self): max_ctx_id=0 ) - # Determine output file type from input exposure filenames - # Use this for defining the output filename - output_filename = self.build_output_name_from_input_name( - exposure[0].meta.filename - ) - - log.info(f"{len(exposure)} exposures to drizzle together") + log.info(f"{len(exposure_indices)} exposures to drizzle together") exptime = None - for img in 
exposure: - img = self.open_model(img) - if exptime is None: - exptime = exposure[0].meta.exposure.exposure_time + meta_fields = [ + "wcs", + "pixelarea_steradians", + "filename", + "level", + "subtracted", + ] - # compute image intensity correction due to the difference - # between where in the input image - # img.meta.photometry.pixelarea_steradians was computed and - # the average input pixel area. + for idx in exposure_indices: - input_pixflux_area = img.meta.photometry.pixelarea_steradians - if (input_pixflux_area and - 'SPECTRAL' not in img.meta.wcs.output_frame.axes_type): - img.meta.wcs.array_shape = img.data.shape - input_pixel_area = _compute_image_pixel_area(img.meta.wcs) - if input_pixel_area is None: - raise ValueError( - "Unable to compute input pixel area from WCS of input " - f"image {repr(img.meta.filename)}." + with self._input_models: + model = self._input_models.borrow(idx) + + in_data = self.get_model_array(model, "data") + + if exptime is None: + attrs = self.get_model_meta( + model, + meta_fields + ["exposure_time", "filename"] ) - iscale = np.sqrt(input_pixflux_area / input_pixel_area) - else: - iscale = 1.0 + else: + attrs = self.get_model_meta(model, meta_fields) - # TODO: should weight_type=None here? - in_wht = self.build_driz_weight( - img, - weight_type=self.weight_type, - good_bits=self.good_bits - ) + # Check that input models are 2D images + if in_data.ndim != 2: + raise RuntimeError( + f"Input {attrs['filename']} is not a 2D image." 
+ ) - # apply sky subtraction - blevel = img.meta.background.level - if not img.meta.background.subtracted and blevel is not None: - data = img.data - blevel - else: - data = img.data + input_pixflux_area = attrs["pixelarea_steradians"] + imwcs = attrs["wcs"] - xmin, xmax, ymin, ymax = _resample_range( - data.shape, - img.meta.wcs.bounding_box - ) + if exptime is None: + exptime = attrs["exposure_time"] + # Determine output file type from input exposure filenames + # Use this for defining the output filename + output_filename = self.build_output_name_from_input_name( + attrs["filename"] + ) - pixmap = calc_pixmap(wcs_from=img.meta.wcs, wcs_to=self._output_wcs) - - driz.add_image( - data, - exptime=exposure[0].meta.exposure.exposure_time, - pixmap=pixmap, - scale=iscale, - weight_map=in_wht, - wht_scale=1.0, - pixfrac=self.pixfrac, - in_units='cps', # TODO: get units from data model - xmin=xmin, - xmax=xmax, - ymin=ymin, - ymax=ymax, - ) + # compute image intensity correction due to the difference + # between where in the input image + # img.meta.photometry.pixelarea_steradians was computed and + # the average input pixel area. + if (input_pixflux_area and + 'SPECTRAL' not in imwcs.output_frame.axes_type): + imwcs.array_shape = in_data.shape + input_pixel_area = _compute_image_pixel_area(imwcs) + if input_pixel_area is None: + raise ValueError( + "Unable to compute input pixel area from WCS " + f"of input image {repr(attrs['filename'])}." + ) + iscale = np.sqrt(input_pixflux_area / input_pixel_area) + else: + iscale = 1.0 - del data - self.close_model(img) + # TODO: should weight_type=None here? 
+ in_wht = self.build_driz_weight( + model, + weight_type=self.weight_type, + good_bits=self.good_bits + ) - output_models.append( - self.create_output_model_single( - output_filename, - driz + # apply sky subtraction + blevel = attrs["level"] + if not attrs["subtracted"] and blevel is not None: + in_data = in_data - blevel + + xmin, xmax, ymin, ymax = _resample_range( + in_data.shape, + imwcs.bounding_box + ) + + pixmap = calc_pixmap(wcs_from=imwcs, wcs_to=self._output_wcs) + + driz.add_image( + in_data, + exptime=exptime, + pixmap=pixmap, + scale=iscale, + weight_map=in_wht, + wht_scale=1.0, + pixfrac=self.pixfrac, + in_units='cps', # TODO: get units from data model + xmin=xmin, + xmax=xmax, + ymin=ymin, + ymax=ymax, + ) + + self._input_models.shelve(model, idx, modify=False) + del in_data + + output_models.append( + self.create_output_model_single( + output_filename, + driz + ) ) - ) - return output_models # or maybe just a plain list - not ModelContainer? + return output_models def _get_boundary_points(xmin, xmax, ymin, ymax, dx=None, dy=None, shrink=0): From b8b41322eae31693e994f836dec7e91b6e24767f Mon Sep 17 00:00:00 2001 From: Mihai Cara Date: Thu, 22 Aug 2024 09:54:53 -0400 Subject: [PATCH 05/10] fix use of model attributes --- src/stcal/resample/resample.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/stcal/resample/resample.py b/src/stcal/resample/resample.py index 20030d84f..90fbd137f 100644 --- a/src/stcal/resample/resample.py +++ b/src/stcal/resample/resample.py @@ -802,7 +802,7 @@ def _setup_variance_data(self): def _check_var_array(self, data_model, array_name): array_data = self.get_model_array(data_model, array_name, default=None) sci_data = self.get_model_array(data_model, "data", default=None) - filename = self.get_model_meta(data_model, "filename") + filename = self.get_model_attr_value(data_model, "filename") if array_data is None or array_data.size == 0: log.debug( @@ -837,7 +837,8 @@ def 
_resample_variance_data(self, data_model, driz_init_kwargs, add_image_kwargs # Resample read-out noise and compute weight map for variance arrays if self._check_var_array(data_model, 'var_rnoise'): - data = np.sqrt(data_model.var_rnoise) + data = self.get_model_array(data_model, "var_rnoise") + data = np.sqrt(data) driz_rnoise.add_image(data, **add_image_kwargs) var = driz_rnoise.out_img np.square(var, out=var) @@ -881,7 +882,8 @@ def _resample_variance_data(self, data_model, driz_init_kwargs, add_image_kwargs weight_mask = np.ones(self._output_array_shape, dtype=bool) if self._check_var_array(data_model, 'var_poisson'): - data = np.sqrt(data_model.var_poisson) + data = self.get_model_array(data_model, "var_poisson") + data = np.sqrt(data) driz_poisson.add_image(data, **add_image_kwargs) var = driz_poisson.out_img np.square(var, out=var) @@ -900,7 +902,8 @@ def _resample_variance_data(self, data_model, driz_init_kwargs, add_image_kwargs self._total_weight_var_poisson[mask] += weight[mask] if self._check_var_array(data_model, 'var_flat'): - data = np.sqrt(data_model.var_flat) + data = self.get_model_array(data_model, "var_flat") + data = np.sqrt(data) driz_flat.add_image(data, **add_image_kwargs) var = driz_flat.out_img np.square(var, out=var) From 4b99b4a5af32537204fa7fafa2687653d2b76381 Mon Sep 17 00:00:00 2001 From: Mihai Cara Date: Mon, 26 Aug 2024 02:01:32 -0400 Subject: [PATCH 06/10] Address reviewer comments --- src/stcal/resample/resample.py | 32 ++++++++++---------------------- src/stcal/resample/utils.py | 7 +++---- 2 files changed, 13 insertions(+), 26 deletions(-) diff --git a/src/stcal/resample/resample.py b/src/stcal/resample/resample.py index 90fbd137f..d5e10eb91 100644 --- a/src/stcal/resample/resample.py +++ b/src/stcal/resample/resample.py @@ -146,20 +146,6 @@ class OutputTooLargeError(RuntimeError): """Raised when the output is too large for in-memory instantiation""" -def output_wcs_from_input_wcs(input_wcs_list, pixel_scale_ratio=1.0, - 
pixel_scale=None, output_shape=None, - crpix=None, crval=None, rotation=None): - # TODO: should be replaced with a version that lives in stcal and - # uses s_region - w = deepcopy(input_wcs_list[0]) # this is bad - return { - 'output_wcs': w, - 'pscale': np.rad2deg(np.sqrt(_compute_image_pixel_area(w))), - 'pscale_ratio': 1.0, - 'crpix': None - } - - class ResampleBase(abc.ABC): """ This is the controlling routine for the resampling process. @@ -334,7 +320,7 @@ def __init__(self, input_models, wcs_pars["pixel_scale"] = self._output_pixel_scale log.info(f'Computed output pixel scale: {self._output_pixel_scale} arcsec.') - w, ps = self._compute_output_wcs(**wcs_pars) + w, ps = self.compute_output_wcs(**wcs_pars) self._output_wcs = w self._output_pixel_scale = ps self._output_array_shape = self._output_wcs.array_shape @@ -410,11 +396,9 @@ def check_memory_requirements(self, allowed_memory): f'Model cannot be instantiated.' ) - def _compute_output_wcs(self, **wcs_pars): - """ returns a diustortion-free WCS object and its pixel scale """ -# owcs = output_wcs_from_input_wcs(self._input_s_region_list, **wcs_pars) - owcs = output_wcs_from_input_wcs([self._input_img1_wcs], **wcs_pars) - return owcs['output_wcs'], owcs['pscale'] + def compute_output_wcs(self, **wcs_pars): + """ returns a tuple of distortion-free WCS object and its pixel scale """ + ... 
def preload_input_meta(self, wcs1, filename, s_region): # set-up lists for WCS and file names @@ -467,7 +451,7 @@ def build_driz_weight(self, model, weight_type=None, good_bits=None): data = self.get_model_array(model, "data") dq = self.get_model_array(model, "dq") - dqmask = build_mask(dq, good_bits) + dqmask = build_mask(dq, good_bits, flag_name_map=self.dq_flag_name_map) if weight_type and weight_type.startswith('ivm'): weight_type = weight_type.strip() @@ -782,7 +766,11 @@ def create_output_model(self, resample_results): self.final_post_processing() - self._output_model.write(self._output_filename, overwrite=True) + self.write_model( + self._output_model, + self._output_filename, + overwrite=True + ) if self._close_output and not self.in_memory: self.close_model(self._output_model) diff --git a/src/stcal/resample/utils.py b/src/stcal/resample/utils.py index b48d19b0f..c9623b1b0 100644 --- a/src/stcal/resample/utils.py +++ b/src/stcal/resample/utils.py @@ -1,18 +1,17 @@ import numpy as np -from stdatamodels.dqflags import interpret_bit_flags -from stdatamodels.jwst.datamodels.dqflags import pixel +from astropy.nddata.bitmask import interpret_bit_flags __all__ = [ "build_mask", "get_tmeasure", ] -def build_mask(dqarr, bitvalue): +def build_mask(dqarr, bitvalue, flag_name_map=None): """Build a bit mask from an input DQ array and a bitvalue flag In the returned bit mask, 1 is good, 0 is bad """ - bitvalue = interpret_bit_flags(bitvalue, mnemonic_map=pixel) + bitvalue = interpret_bit_flags(bitvalue, flag_name_map=flag_name_map) if bitvalue is None: return np.ones(dqarr.shape, dtype=np.uint8) From 4864e126f37774442648b1b101dec9933b2806e9 Mon Sep 17 00:00:00 2001 From: Mihai Cara Date: Tue, 1 Oct 2024 03:30:05 -0400 Subject: [PATCH 07/10] Refactor previous code to work with arrays only --- src/stcal/resample/__init__.py | 7 +- src/stcal/resample/resample.py | 1731 ++++++++++++++------------------ src/stcal/resample/utils.py | 64 +- 3 files changed, 832 insertions(+), 
970 deletions(-) diff --git a/src/stcal/resample/__init__.py b/src/stcal/resample/__init__.py index 1ae898af0..2baa9834a 100644 --- a/src/stcal/resample/__init__.py +++ b/src/stcal/resample/__init__.py @@ -1,9 +1,8 @@ from .resample import * __all__ = [ + "LibModelAccess", "OutputTooLargeError", - "ResampleModelIO", - "ResampleBase", - "ResampleCoAdd", - "ResampleSingle" + "Resample", + "resampled_wcs_from_models", ] diff --git a/src/stcal/resample/resample.py b/src/stcal/resample/resample.py index d5e10eb91..bffb13020 100644 --- a/src/stcal/resample/resample.py +++ b/src/stcal/resample/resample.py @@ -1,10 +1,9 @@ +import abc +from copy import deepcopy import logging import os -import warnings -from copy import deepcopy import sys -import abc -from pathlib import Path, PurePath +import warnings import numpy as np from scipy.ndimage import median_filter @@ -14,24 +13,29 @@ import psutil from spherical_geometry.polygon import SphericalPolygon +from astropy.nddata.bitmask import ( + bitfield_to_boolean_mask, + interpret_bit_flags, +) -from astropy.nddata.bitmask import interpret_bit_flags - -from .utils import get_tmeasure, build_mask +from .utils import bytes2human, get_tmeasure +from ..alignment.util import ( + compute_scale, + wcs_bbox_from_shape, + wcs_from_footprints, +) log = logging.getLogger(__name__) log.setLevel(logging.DEBUG) __all__ = [ + "LibModelAccess", "OutputTooLargeError", - "ResampleModelIO", - "ResampleBase", - "ResampleCoAdd", - "ResampleSingle" + "Resample", + "resampled_wcs_from_models", ] - _SUPPORTED_CUSTOM_WCS_PARS = [ 'pixel_scale_ratio', 'pixel_scale', @@ -42,43 +46,6 @@ ] -# FIXME: temporarily copied here to avoid this import: -# from stdatamodels.jwst.library.basic_utils import bytes2human -def bytes2human(n): - """Convert bytes to human-readable format - - Taken from the `psutil` library which references - http://code.activestate.com/recipes/578019 - - Parameters - ---------- - n : int - Number to convert - - Returns - ------- - 
readable : str - A string with units attached. - - Examples - -------- - >>> bytes2human(10000) - '9.8K' - - >>> bytes2human(100001221) - '95.4M' - """ - symbols = ('K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y') - prefix = {} - for i, s in enumerate(symbols): - prefix[s] = 1 << (i + 1) * 10 - for s in reversed(symbols): - if n >= prefix[s]: - value = float(n) / prefix[s] - return '%.1f%s' % (value, s) - return "%sB" % n - - def _resample_range(data_shape, bbox=None): # Find range of input pixels to resample: if bbox is None: @@ -95,58 +62,186 @@ def _resample_range(data_shape, bbox=None): return xmin, xmax, ymin, ymax -class ResampleModelIO(abc.ABC): - @abc.abstractmethod - def open_model(self, file_name): - ... +class LibModelAccess(abc.ABC): + # list of model attributes needed by this module. While this is not + # required, it is helpful for subclasses to check they know how to + # access these attributes. + min_supported_attributes = [ + # arrays: + "data", + "dq", + "var_rnoise", + "var_poisson", + "var_flat", + + # meta: + "filename", + "group_id", + "s_region", + "wcsinfo", + "wcs", + + "exposure_time", + "start_time", + "end_time", + "duration", + "measurement_time", + "effective_exposure_time", + "elapsed_exposure_time", + + "pixelarea_steradians", +# "pixelarea_arcsecsq", + + "level", # sky level + "subtracted", + + "weight_type", + "pointings", + "n_coadds", + ] @abc.abstractmethod - def get_model_attr_value(self, model, attribute_name): + def iter_model(self, attributes=None): ... + @property @abc.abstractmethod - def set_model_attr_value(self, model, attribute_name, value): + def n_models(self): ... + @property @abc.abstractmethod - def get_model_meta(self, model, attributes): + def n_groups(self): ... - @abc.abstractmethod - def set_model_meta(self, model, attributes): - ... - @abc.abstractmethod - def get_model_array(self, model, array_name): - ... 
+def resampled_wcs_from_models( + input_models, + pixel_scale_ratio=1.0, + pixel_scale=None, + output_shape=None, + rotation=None, + crpix=None, + crval=None, +): + """ + Computes the WCS of the resampled image from input models and + specified WCS parameters. - @abc.abstractmethod - def set_model_array(self, model, array_name, data): - ... + Parameters + ---------- - @abc.abstractmethod - def close_model(self, model): - ... + input_models : LibModelAccess + An object of `LibModelAccess`-derived type. + + pixel_scale_ratio : float, optional + Desired pixel scale ratio defined as the ratio of the first model's + pixel scale computed from this model's WCS at the fiducial point + (taken as the ``ref_ra`` and ``ref_dec`` from the ``wcsinfo`` meta + attribute of the first input image) to the desired output pixel + scale. Ignored when ``pixel_scale`` is specified. + + pixel_scale : float, None, optional + Desired pixel scale (in degrees) of the output WCS. When provided, + overrides ``pixel_scale_ratio``. + + output_shape : tuple of two integers (int, int), None, optional + Shape of the image (data array) using ``np.ndarray`` convention + (``ny`` first and ``nx`` second). This value will be assigned to + ``pixel_shape`` and ``array_shape`` properties of the returned + WCS object. + + rotation : float, None, optional + Position angle of output image's Y-axis relative to North. + A value of 0.0 would orient the final output image to be North up. + The default of `None` specifies that the images will not be rotated, + but will instead be resampled in the default orientation for the + camera with the x and y axes of the resampled image corresponding + approximately to the detector axes. Ignored when ``transform`` is + provided. + + crpix : tuple of float, None, optional + Position of the reference pixel in the resampled image array. + If ``crpix`` is not specified, it will be set to the center of the + bounding box of the returned WCS object. 
+ + crval : tuple of float, None, optional + Right ascension and declination of the reference pixel. + Automatically computed if not provided. - @abc.abstractmethod - def save_model(self, model): - ... + Returns + ------- + wcs : ~gwcs.wcs.WCS + The WCS object corresponding to the combined input footprints. - @abc.abstractmethod - def write_model(self, model, file_name, **kwargs): - ... + pscale_in : float + Computed pixel scale (in degrees) of the first input image. - @abc.abstractmethod - def new_model(self, image_shape=None, file_name=None, copy_meta_from=None): - """ Return a new model for the resampled output """ - ... + pscale_out : float + Computed pixel scale (in degrees) of the output image. + + """ + # build a list of WCS of all input models: + wcs_list = [] + ref_wcsinfo = None + for model_info, _ in input_models.iter_model( + attributes=["data", "wcs", "wcsinfo"] + ): + # TODO: is deepcopy necessary? Is ModelLibrary read-only by default? + w = deepcopy(model_info["wcs"]) + if ref_wcsinfo is None: + ref_wcsinfo = model_info["wcsinfo"] + # make sure all WCS objects have the bounding_box defined: + if w.bounding_box is None: + bbox = wcs_bbox_from_shape(model_info["data"].shape) + w.bounding_box = bbox + wcs_list.append(w) + + if output_shape is None: + bounding_box = None + else: + bounding_box = wcs_bbox_from_shape(output_shape) + + pscale_in0 = compute_scale( + wcs_list[0], + fiducial=np.array([ref_wcsinfo["ra_ref"], ref_wcsinfo["dec_ref"]]) + ) + + if pixel_scale is None: + pixel_scale = pscale_in0 / pixel_scale_ratio + log.info( + f"Pixel scale ratio (pscale_in / pscale_out): {pixel_scale_ratio}" + ) + log.info(f"Computed output pixel scale: {3600 * pixel_scale} arcsec.") + else: + pixel_scale_ratio = pscale_in0 / pixel_scale + log.info(f"Output pixel scale: {3600 * pixel_scale} arcsec.") + log.info( + "Computed pixel scale ratio (pscale_in / pscale_out): " + f"{pixel_scale_ratio}." 
+ ) + + wcs = wcs_from_footprints( + wcs_list=wcs_list, + ref_wcs=wcs_list[0], + ref_wcsinfo=ref_wcsinfo, + pscale_ratio=pixel_scale_ratio, + pscale=pixel_scale, + rotation=rotation, + bounding_box=bounding_box, + shape=output_shape, + crpix=crpix, + crval=crval, + ) + + return wcs, pscale_in0, pixel_scale, pixel_scale_ratio class OutputTooLargeError(RuntimeError): """Raised when the output is too large for in-memory instantiation""" -class ResampleBase(abc.ABC): +class Resample: """ This is the controlling routine for the resampling process. @@ -165,21 +260,31 @@ class ResampleBase(abc.ABC): resample_suffix = 'i2d' resample_file_ext = '.fits' n_arrays_per_output = 2 # #flt-point arrays in the output (data, weight, var, err, etc.) + + # supported output arrays (subclasses can add more): + output_array_types = { + "data": np.float32, + "wht": np.float32, + "con": np.int32, + "var_rnoise": np.float32, + "var_flat": np.float32, + "var_poisson": np.float32, + "err": np.float32, + } + dq_flag_name_map = {} - def __init__(self, input_models, - pixfrac=1.0, kernel="square", fillval=0.0, wht_type="ivm", - good_bits=0, output_wcs=None, wcs_pars=None, - enable_ctx=True, - in_memory=True, allowed_memory=None, **kwargs): + def __init__(self, input_models, pixfrac=1.0, kernel="square", + fillval=0.0, wht_type="ivm", good_bits=0, + output_wcs=None, wcs_pars=None, output_model=None, + accumulate=False, enable_ctx=True, enable_var=True, + allowed_memory=None): """ Parameters ---------- - input_models : list of objects - list of data models, one for each input image - - output : str - filename for output + input_models : LibModelAccess + A `LibModelAccess` object allowing iterating over all contained + models of interest. kwargs : dict Other parameters. @@ -187,28 +292,21 @@ def __init__(self, input_models, .. note:: ``output_shape`` is in the ``x, y`` order. - .. 
note:: - ``in_memory`` controls whether or not the resampled - array from ``resample_many_to_many()`` - should be kept in memory or written out to disk and - deleted from memory. Default value is `True` to keep - all products in memory. """ - self._enable_ctx = enable_ctx + # input models + self._input_models = input_models self._output_model = None - self._output_filename = None self._output_wcs = None - self._output_array_shape = None - self._close_output = False - self._output_pixel_scale = None - self._template_output_model = None + self._enable_ctx = enable_ctx + self._enable_var = enable_var + self._accumulate = accumulate - # input models - self._input_models = input_models - # a lightweight data model with meta from first input model but no data. - # it will be updated by 'prload_input_meta()' below - self._first_model_meta = None + # these are attributes that are used only for information purpose + # and are added to created the output_model only if they are not already + # present there: + self._pixel_scale_ratio = None + self._output_pixel_scale = None # resample parameters self.pixfrac = pixfrac @@ -216,9 +314,10 @@ def __init__(self, input_models, self.fillval = fillval self.weight_type = wht_type self.good_bits = good_bits - self.in_memory = in_memory - self._user_output_wcs = output_wcs + self._output_wcs = output_wcs + + self.input_file_names = [] # check wcs_pars has supported keywords: if wcs_pars is None: @@ -231,161 +330,308 @@ def __init__(self, input_models, "Unsupported custom WCS parameters: " f"{','.join(map(repr, unsup))}." 
) - # WCS parameters (should be deleted once not needed; - # once an output WCS was created) - self._wcs_pars = wcs_pars - - # process additional kwags specific to subclasses and store - # unprocessed/unrecognized kwargs in ukwargs and warn about these - # unrecognized kwargs - ukwargs = self.process_kwargs(kwargs) - self._warn_extra_kwargs(ukwargs) - - # load meta necessary for output WCS (and other) computations: - self.preload_input_meta( - wcs1=True, - filename=self._output_model is None, - s_region=output_wcs is None, - ) - # computed average pixel scale of the first input image: - input_pscale0 = np.rad2deg( - np.sqrt(_compute_image_pixel_area(self._input_img1_wcs)) - ) + # determine output WCS and set-up output model if needed: + if output_model is None: + if output_wcs is None: + output_wcs, _, ps, ps_ratio = resampled_wcs_from_models( + input_models, + pixel_scale_ratio=wcs_pars.get("pixel_scale_ratio", 1.0), + pixel_scale=wcs_pars.get("pixel_scale"), + output_shape=wcs_pars.get("output_shape"), + rotation=wcs_pars.get("rotation"), + crpix=wcs_pars.get("crpix"), + crval=wcs_pars.get("crval"), + ) + self._output_pixel_scale = ps # degrees + self._pixel_scale_ratio = ps_ratio + else: + self.check_output_wcs(output_wcs, wcs_pars) + self._output_pixel_scale = np.rad2deg( + np.sqrt(_compute_image_pixel_area(output_wcs)) + ) + log.info( + "Computed output pixel scale: " + f"{3600 * self._output_pixel_scale} arcsec." 
+ ) + + self._output_wcs = output_wcs - # compute output pixel scale, WCS, set-up output model - if self._output_model: - self._output_wcs = deepcopy( - self.get_model_attr_value(self._output_model, "wcs") + else: + self.validate_output_model( + output_model=output_model, + output_wcs=output_wcs, + accumulate=accumulate, + enable_ctx=enable_ctx, + enable_var=enable_var, ) - self._output_array_shape = self.get_model_array( - self._output_model, - "data" - ).shape - # TODO: extract any useful info from the output image before we close it: - # if meta has pixel scale, populate it from there, if not: + self._output_model = output_model + self._output_wcs = output_model["wcs"] + if output_wcs: + log.warning( + "'output_wcs' will be ignored. Using the 'wcs' supplied " + "by the 'output_model' instead." + ) self._output_pixel_scale = np.rad2deg( - np.sqrt(_compute_image_pixel_area(self._output_wcs)) + np.sqrt(_compute_image_pixel_area(output_wcs)) ) - self._pixel_scale_ratio = self._output_pixel_scale / input_pscale0 - log.info(f'Computed output pixel scale: {self._output_pixel_scale} arcsec.') - - self._create_output_template_model() # create template before possibly closing output - if self._close_output and not self.in_memory: - self.close_model(self._output_model) - self._output_model = None - - elif output_wcs: - naxes = output_wcs.output_frame.naxes - if naxes != 2: - raise RuntimeError( - "Output WCS needs 2 spatial axes but the " - f"supplied WCS has {naxes} axes." - ) - self._output_wcs = deepcopy(output_wcs) + self._pixel_scale_ratio = output_model.get("wcs", None) + log.info( + "Computed output pixel scale: " + f"{3600 * self._output_pixel_scale} arcsec." 
+ ) + + self._output_array_shape = self._output_wcs.array_shape + + # Check that the output data shape has no zero length dimensions + npix = np.prod(self._output_array_shape) + if not npix: + raise ValueError( + f"Invalid output frame shape: {tuple(self._output_array_shape)}" + ) + + # set up output model (arrays, etc.) + if self._output_model is None: + self._output_model = self.create_output_model( + allowed_memory=allowed_memory + ) + + self._group_ids = [] + + log.info(f"Driz parameter kernel: {self.kernel}") + log.info(f"Driz parameter pixfrac: {self.pixfrac}") + log.info(f"Driz parameter fillval: {self.fillval}") + log.info(f"Driz parameter weight_type: {self.weight_type}") + + log.debug(f"Output mosaic size: {self._output_wcs.pixel_shape}") + + def check_output_wcs(self, output_wcs, wcs_pars, + estimate_output_shape=True): + """ + Check that provided WCS has expected properties and that its + ``array_shape`` property is defined. + + """ + naxes = output_wcs.output_frame.naxes + if naxes != 2: + raise RuntimeError( + "Output WCS needs 2 spatial axes but the " + f"supplied WCS has {naxes} axes." + ) + + # make sure array_shape and pixel_shape are set: + if output_wcs.array_shape is None and estimate_output_shape: if wcs_pars and "output_shape" in wcs_pars: - self._output_array_shape = wcs_pars["output_shape"] + output_wcs.array_shape = wcs_pars["output_shape"] else: - self._output_array_shape = self._output_wcs.array_shape - if not self._output_array_shape and output_wcs.bounding_box: + if output_wcs.bounding_box: halfpix = 0.5 + sys.float_info.epsilon - self._output_array_shape = ( + output_wcs.array_shape = ( int(output_wcs.bounding_box[1][1] + halfpix), int(output_wcs.bounding_box[0][1] + halfpix), ) else: + # TODO: In principle, we could compute footprints of all + # input models, convert them to image coordinates using + # `output_wcs`, and then take max(x_i), max(y_i) as + # output image size. 
raise ValueError( - "Unable to infer output image size from provided inputs." + "Unable to infer output image size from provided " + "inputs." ) - self._output_wcs.array_shape = self._output_array_shape - self._output_pixel_scale = np.rad2deg( - np.sqrt(_compute_image_pixel_area(self._output_wcs)) + @classmethod + def output_model_attributes(cls, accumulate, enable_ctx, enable_var): + """ + Returns a set of string keywords that must be present in an + 'output_model' that is provided as input at the class initialization. + + """ + # always required: + attributes = { + "data", + "wcs", + "wht", + } + + if enable_ctx: + attributes.add("con") + if enable_var: + attributes.update( + ["var_rnoise", "var_poisson", "var_flat", "err"] + ) + if accumulate: + if enable_ctx: + attributes.add("n_coadds") + + # additional attributes required for input parameter 'output_model' + # when data and weight arrays are not None: + attributes.update( + { + "pixfrac", + "kernel", + "fillval", + "weight_type", + "pointings", + "exposure_time", + "measurement_time", + "start_time", + "end_time", + "duration", + } ) - self._pixel_scale_ratio = self._output_pixel_scale / input_pscale0 - log.info(f'Computed output pixel scale: {self._output_pixel_scale} arcsec.') - self._create_output_template_model() - else: - # build output WCS and calculate ouput image shape - if "pixel_scale" in wcs_pars and wcs_pars['pixel_scale'] is not None: - self._pixel_scale_ratio = wcs_pars["pixel_scale"] / input_pscale0 - log.info(f'Output pixel scale: {wcs_pars["pixel_scale"]} arcsec.') - log.info(f'Computed output pixel scale ratio: {self._pixel_scale_ratio}.') - else: - self._pixel_scale_ratio = wcs_pars.get("pixel_scale_ratio", 1.0) - log.info(f'Output pixel scale ratio: {self._pixel_scale_ratio}') - self._output_pixel_scale = input_pscale0 * self._pixel_scale_ratio - wcs_pars = wcs_pars.copy() - wcs_pars["pixel_scale"] = self._output_pixel_scale - log.info(f'Computed output pixel scale: 
{self._output_pixel_scale} arcsec.') - - w, ps = self.compute_output_wcs(**wcs_pars) - self._output_wcs = w - self._output_pixel_scale = ps - self._output_array_shape = self._output_wcs.array_shape - self._create_output_template_model() + return attributes - # Check that the output data shape has no zero length dimensions - npix = np.prod(self._output_array_shape) - if not npix: + def validate_output_model(self, output_model, accumulate, + enable_ctx, enable_var): + if output_model is None: + if accumulate: + raise ValueError( + "'output_model' must be defined when 'accumulate' is True." + ) + return + + required_attributes = self.output_model_attributes( + accumulate=accumulate, + enable_ctx=enable_ctx, + enable_var=enable_var, + ) + + for attr in required_attributes: + if attr not in output_model: + raise ValueError( + f"'output_model' dictionary must have '{attr}' set." + ) + + model_wcs = output_model["wcs"] + self.check_output_wcs(model_wcs, estimate_output_shape=False) + wcs_shape = model_wcs.array_shape + ref_shape = output_model["data"].shape + if accumulate and wcs_shape is None: raise ValueError( - f"Invalid output frame shape: {tuple(self._output_array_shape)}" + "Output model's 'wcs' must have 'array_shape' attribute " + "set when 'accumulate' parameter is True." ) - assert self._pixel_scale_ratio - log.info(f"Driz parameter kernel: {self.kernel}") - log.info(f"Driz parameter pixfrac: {self.pixfrac}") - log.info(f"Driz parameter fillval: {self.fillval}") - log.info(f"Driz parameter weight_type: {self.weight_type}") + if not np.array_equiv(wcs_shape, ref_shape): + raise ValueError( + "Output model's 'wcs.array_shape' value is not consistent " + "with the shape of the data array." 
+ ) + + for attr in required_attributes.difference(["data", "wcs"]): + if (isinstance(output_model[attr], np.ndarray) and + not np.array_equiv(output_model[attr].shape, ref_shape)): + raise ValueError( + "'output_wcs.array_shape' value is not consistent " + f"with the shape of the '{attr}' array." + ) + + # TODO: also check "pixfrac", "kernel", "fillval", "weight_type" + # with initializer parameters. log a warning if different. - self.check_memory_requirements(allowed_memory) + def create_output_model(self, allowed_memory): + """ Create a new "output model": a dictionary of data and meta fields. + Check that there is enough memory to hold all arrays. + """ + assert self._output_wcs is not None + assert np.array_equiv( + self._output_wcs.array_shape, + self._output_array_shape + ) + assert self._output_pixel_scale - log.debug('Output mosaic size: {}'.format(self._output_wcs.pixel_shape)) + pix_area = self._output_pixel_scale**2 + + output_model = { + # WCS: + "wcs": deepcopy(self._output_wcs), + + # main arrays: + "data": None, + "wht": None, + "con": None, + + # resample parameters: + "pixfrac": self.pixfrac, + "kernel": self.kernel, + "fillval": self.fillval, + "weight_type": self.weight_type, + + # accumulate-specific: + "n_coadds": 0, + + # pixel scale: + "pixelarea_steradians": pix_area, + "pixelarea_arcsecsq": pix_area * np.rad2deg(3600)**2, + "pixel_scale_ratio": self._pixel_scale_ratio, + + # drizzle info: + "pointings": 0, + + # exposure time: + "exposure_time": 0.0, + "measurement_time": None, + "start_time": None, + "end_time": None, + "duration": 0.0, + } + + if self._enable_var: + output_model.update( + { + "var_rnoise": None, + "var_flat": None, + "var_poisson": None, + "err": None, + } + ) + + if allowed_memory: + self.check_memory_requirements(list(output_model), allowed_memory) + + return output_model @property def output_model(self): return self._output_model - def process_kwargs(self, kwargs): - """ A method called by ``__init__`` to process 
input kwargs before - output WCS is created and before output model template is created. - - Returns - ------- - kwargs : dict - Unrecognized/not processed ``kwargs``. + @property + def output_array_shape(self): + return self._output_array_shape - """ - return {k : v for k, v in kwargs.items()} + @property + def group_ids(self): + return self._group_ids - def _warn_extra_kwargs(self, kwargs): - for k in kwargs: - log.warning(f"Unrecognized argument '{k}' will be ignored.") + def check_memory_requirements(self, arrays, allowed_memory): + """ Called just before `create_output_model` returns to verify + that there is enough memory to hold the output. - def check_memory_requirements(self, allowed_memory): - """ Called just before '_pre_run_callback()' is called to verify - that there is enough memory to hold the output. """ + """ if allowed_memory is None and "DMODEL_ALLOWED_MEMORY" not in os.environ: return allowed_memory = float(allowed_memory) # get the available memory - available_memory = psutil.virtual_memory().available + psutil.swap_memory().total - - # determine data type of the output model: - out_model = self.new_model((2, 2)) - data = self.get_model_array(out_model) - data_type = data.dtype - del data, out_model + available_memory = ( + psutil.virtual_memory().available + psutil.swap_memory().total + ) # compute the output array size npix = npix = np.prod(self._output_array_shape) nmodels = len(self._input_models) - nconpl = nmodels // 32 + (1 if nmodels % 32 else 0) - n_arr = self.n_arrays_per_output + 2 # 2 comes from pixmap - required_memory = npix * (n_arr * data_type.itemsize + nconpl * 4) + nconpl = nmodels // 32 + (1 if nmodels % 32 else 0) # #context planes + required_memory = 0 + for arr in arrays: + if arr in self.output_array_types: + f = nconpl if arr == "con" else 1 + required_memory += f * self.output_array_types[arr].itemsize + # add pixmap itemsize: + required_memory += 2 * np.dtype(float).itemsize + required_memory *= npix # compare used to 
available used_fraction = required_memory / available_memory @@ -396,62 +642,18 @@ def check_memory_requirements(self, allowed_memory): f'Model cannot be instantiated.' ) - def compute_output_wcs(self, **wcs_pars): - """ returns a tuple of distortion-free WCS object and its pixel scale """ - ... - - def preload_input_meta(self, wcs1, filename, s_region): - # set-up lists for WCS and file names - self._input_img1_wcs = None - self._input_s_region_list = [] - self._input_filename_list = [] - - # loop over only science exposures in the ModelLibrary - # sci_indices = self._input_models.ind_asn_type("science") - with self._input_models: - for model in self._input_models: - # model = self._input_models.borrow(idx) - - try: - if self.get_model_attr_value(model, "exptype").upper() != "SCIENCE": - self._input_models.shelve(model, modify=False) - continue - except AttributeError: - pass - - if self._input_img1_wcs is None and wcs1: - # extract all info needed from *this* model: - self._input_img1_wcs = deepcopy( - self.get_model_attr_value(model, "wcs") - ) - self._input_img1_wcs.array_shape = self.get_model_array( - model, - "data" - ).shape - - if filename: - self._input_filename_list.append( - self.get_model_attr_value(model, "filename") - ) - - if s_region: - self._input_s_region_list.append( - self.get_model_attr_value(model, "s_region") - ) - - self._input_models.shelve(model, modify=False) - - # store first model's entire meta (except for WCS and data): - if self._first_model_meta is None: - self._first_model_meta = self.new_model(copy_meta_from=model) - - def build_driz_weight(self, model, weight_type=None, good_bits=None): - """Create a weight map for use by drizzle - """ - data = self.get_model_array(model, "data") - dq = self.get_model_array(model, "dq") - - dqmask = build_mask(dq, good_bits, flag_name_map=self.dq_flag_name_map) + def build_driz_weight(self, model_info, weight_type=None, good_bits=None): + """Create a weight map for use by drizzle. 
""" + data = model_info["data"] + dq = model_info["dq"] + + dqmask = bitfield_to_boolean_mask( + dq, + good_bits, + good_mask_value=1, + dtype=np.uint8, + flag_name_map=self.dq_flag_name_map, + ) if weight_type and weight_type.startswith('ivm'): weight_type = weight_type.strip() @@ -466,14 +668,14 @@ def build_driz_weight(self, model, weight_type=None, good_bits=None): # disable selective median if SATURATED flag is included # in good_bits: try: - saturation = self.dq_flag_name_map['SATURATED'] + saturation = self.dq_flag_name_map["SATURATED"] if selective_median and not (bitvalue & saturation): selective_median = False weight_type = 'ivm' except AttributeError: pass - var_rnoise = self.get_model_array(model, "var_rnoise", default=None) + var_rnoise = model_info["var_rnoise"] if (var_rnoise is not None and var_rnoise.shape == data.shape): with np.errstate(divide="ignore", invalid="ignore"): inv_variance = var_rnoise**-1 @@ -535,8 +737,8 @@ def build_driz_weight(self, model, weight_type=None, good_bits=None): result = inv_variance * dqmask - elif weight_type == 'exptime': - exptime = self.get_model_attr_value(model, "exposure_time") + elif weight_type == "exptime": + exptime = model_info["exposure_time"] result = exptime * dqmask else: @@ -544,291 +746,276 @@ def build_driz_weight(self, model, weight_type=None, good_bits=None): return result.astype(np.float32) - @abc.abstractmethod - def run(self): - ... + def init_time_info(self): + """ Initialize variables/arrays needed to process exposure time. """ + self._t_used_group_id = [] - def _create_output_template_model(self): - pass - - def update_exposure_times(self): - """Modify exposure time metadata in-place""" - total_exposure_time = 0. 
- exptime_start = [] - exptime_end = [] - duration = 0.0 - total_exptime = 0.0 - measurement_time_success = [] - - for exposure in self._input_models.group_indices.values(): - with self._input_models: - model0 = self._input_models.borrow(exposure[0]) - attrs = self.get_model_meta( - model0, - ["exposure_time", "start_time", "end_time", "duration"] - ) + self._total_exposure_time = self.output_model["exposure_time"] + self._duration = self.output_model["duration"] + self._total_measurement_time = self.output_model["measurement_time"] + if self._total_measurement_time is None: + self._total_measurement_time = 0.0 - t, success = get_tmeasure(model0) - self._input_models.shelve(model0, exposure[0]) + if (start_time := self.output_model.get("start_time", None)) is None: + self._exptime_start = [] + else: + self._exptime_start[start_time] - total_exposure_time += attrs["exposure_time"] - measurement_time_success.append(success) - total_exptime += t - exptime_start.append(attrs["start_time"]) - exptime_end.append(attrs["end_time"]) - duration += attrs["duration"] + if (end_time := self.output_model.get("end_time", None)) is None: + self._exptime_end = [] + else: + self._exptime_end[end_time] - attrs = { - # basic exposure time attributes: - "exposure_time": total_exposure_time, - "start_time": min(exptime_start), - "end_time": max(exptime_end), - # Update other exposure time keywords: - # XPOSURE (identical to the total effective exposure time, EFFEXPTM) - "effective_exposure_time": total_exptime, - # DURATION (identical to TELAPSE, elapsed time) - "duration": duration, - "elapsed_exposure_time": duration, - } + self._measurement_time_success = [] - if not all(measurement_time_success): - attrs["measurement_time"] = total_exptime + def update_total_time(self, model_info): + """ A method called by the `~ResampleBase.run` method to process each + image's time attributes. 
- self.set_model_meta(self._output_model, attrs) + """ + if (group_id := model_info["group_id"]) in self._t_used_group_id: + return + self._t_used_group_id.append(group_id) + self._exptime_start.append(model_info["start_time"]) + self._exptime_end.append(model_info["end_time"]) -class ResampleCoAdd(ResampleBase): - """ - This is the controlling routine for the resampling process. + t, success = get_tmeasure(model_info) + self._total_exposure_time += model_info["exposure_time"] + self._measurement_time_success.append(success) + self._total_measurement_time += t - Notes - ----- - This routine performs the following operations:: + self._duration += model_info["duration"] - 1. Extracts parameter settings from input model, such as pixfrac, - weight type, exposure time (if relevant), and kernel, and merges - them with any user-provided values. - 2. Creates output WCS based on input images and define mapping function - between all input arrays and the output array. - 3. Updates output data model with output arrays from drizzle, including - a record of metadata from all input models. - """ + def finalize_time_info(self): + """ Perform final computations for the total time and update relevant + fields of the output model.
- def __init__(self, input_models, output, accum=False, - pixfrac=1.0, kernel="square", fillval=0.0, wht_type="ivm", - good_bits=0, output_wcs=None, wcs_pars=None, - enable_ctx=True, enable_err=True, - in_memory=True, allowed_memory=None): """ - Parameters - ---------- - input_models : list of objects - list of data models, one for each input image + attrs = { + # basic exposure time attributes: + "exposure_time": self._total_exposure_time, + "start_time": min(self._exptime_start), + "end_time": max(self._exptime_end), + # Update other exposure time keywords: + # XPOSURE (identical to the total effective exposure time, EFFEXPTM) + "effective_exposure_time": self._total_exposure_time, + # DURATION (identical to TELAPSE, elapsed time) + "duration": self._duration, + "elapsed_exposure_time": self._duration, + } - output : DataModel, str - filename for output + if all(self._measurement_time_success): + attrs["measurement_time"] = self._total_measurement_time - kwargs : dict - Other parameters. + self._output_model.update(attrs) - .. note:: - ``output_shape`` is in the ``x, y`` order. + def init_resample_data(self): + """ Create a `Drizzle` object to process image data. """ + om = self._output_model - .. note:: - ``in_memory`` controls whether or not the resampled - array from ``resample_many_to_many()`` - should be kept in memory or written out to disk and - deleted from memory. Default value is `True` to keep - all products in memory. 
- """ - self._accum = accum - - super().__init__( - input_models=input_models, - pixfrac=pixfrac, - kernel=kernel, - fillval=fillval, - wht_type=wht_type, - good_bits=good_bits, - output_wcs=output_wcs, - wcs_pars=wcs_pars, - enable_ctx=enable_ctx, - in_memory=in_memory, - allowed_memory=allowed_memory, - output=output, + self.driz_data = Drizzle( + kernel=self.kernel, + fillval=self.fillval, + out_shape=self._output_array_shape, + out_img=om["data"], + out_wht=om["wht"], + out_ctx=om["con"], + exptime=om["exposure_time"], + begin_ctx_id=om["n_coadds"], + max_ctx_id=om["n_coadds"] + self._input_models.n_models, ) - self._enable_err = enable_err - def process_kwargs(self, kwargs): - """ A method called by ``__init__`` to process input kwargs before - output WCS is created and before output model template is created. + def init_resample_variance(self): + """ Create a `Drizzle` objects to process image variance arrays. """ + self._var_rnoise_sum = np.full(self._output_array_shape, np.nan) + self._var_poisson_sum = np.full(self._output_array_shape, np.nan) + self._var_flat_sum = np.full(self._output_array_shape, np.nan) + # self._total_weight_var_rnoise = np.zeros(self._output_array_shape) + self._total_weight_var_poisson = np.zeros(self._output_array_shape) + self._total_weight_var_flat = np.zeros(self._output_array_shape) + + # create resample objects for the three variance arrays: + driz_init_kwargs = { + 'kernel': self.kernel, + 'fillval': np.nan, + 'out_shape': self._output_array_shape, + # 'exptime': 1.0, + 'no_ctx': True, + } + self.driz_rnoise = Drizzle(**driz_init_kwargs) + self.driz_poisson = Drizzle(**driz_init_kwargs) + self.driz_flat = Drizzle(**driz_init_kwargs) + + def _check_var_array(self, model_info, array_name): + """ Check that a variance array has the same shape as the model's + data array. 
+ """ - kwargs = super().process_kwargs(kwargs) - output = kwargs.pop("output", None) - accum = kwargs.pop("accum", False) - - # Load the model if accum is True - if isinstance(output, str): - self._output_filename = output - if accum: - try: - self._output_model = self.open_model(output) - self._close_output = True - log.info( - "Output model has been loaded and it will be used to " - "accumulate new data." - ) - if self._user_output_wcs: - log.info( - "'output_wcs' will be ignored when 'output' is " - "provided and accum=True" - ) - if self._wcs_pars: - log.info( - "'wcs_pars' will be ignored when 'output' is " - "provided and accum=True" - ) - except FileNotFoundError: - pass + array_data = model_info.get(array_name, None) + sci_data = model_info["data"] + model_name = _get_model_name(model_info) - elif output is not None: - self._output_filename = self.get_model_attr_value( - output, - "filename" + if array_data is None or array_data.size == 0: + log.debug( + f"No data for '{array_name}' for model " + f"{repr(model_name)}. Skipping ..." ) - self._output_model = output - self._close_output = False + return False - return kwargs + elif array_data.shape != sci_data.shape: + log.warning( + f"Data shape mismatch for '{array_name}' for model " + f"{repr(model_name)}. Skipping ..." + ) + return False - def _create_new_output_model(self): - # this probably needs to be an abstract class. - # also this is mostly needed for "single" drizzle. - output_model = self.new_model( - None, - copy_meta_from=self._first_model_meta - ) + return True - # update meta data and wcs - pix_area = self._output_pixel_scale**2 - self.set_model_meta( - output_model, - { - "wcs": deepcopy(self._output_wcs), - "pixelarea_steradians": pix_area, - "pixelarea_arcsecsq": pix_area * np.rad2deg(3600)**2, - } - ) + def add_model(self, model_info, image_model): + """ Resample and add data (variance, etc.) arrays to the output arrays. 
- return output_model + Parameters + ---------- - def build_output_model_name(self): - fnames = {f for f in self._input_filename_list if f is not None} + model_info : dict + A dictionary with data extracted from an image model needed for + `Resample` to successfully process this model. - if not fnames: - return "resampled_data_{resample_suffix}{resample_file_ext}" + image_model : object + The original data model from which ``model`` data was extracted. + It is not used by this method in this class but can be used + by pipeline-specific subclasses to perform additional processing + such as blend headers. - # TODO: maybe remove ending suffix for single file names? - prefix = os.path.commonprefix( - [PurePath(f).stem.strip('_- ') for f in fnames] - ) + """ + in_data = model_info["data"] - return prefix + "{resample_suffix}{resample_file_ext}" + if (group_id := model_info["group_id"]) not in self.group_ids: + self.group_ids.append(group_id) + self.output_model["pointings"] += 1 - def create_output_model(self, resample_results): - # this probably needs to be an abstract class (different telescopes - # may want to save different arrays and ignore others). + self.input_file_names.append(model_info["filename"]) - if not self._output_model and self._output_filename: - if self._accum and Path(self._output_filename).is_file(): - self._output_model = self.open_model(self._output_filename) - else: - self._output_model = self._create_new_output_model() - self._close_output = not self.in_memory + # Check that input models are 2D images + if in_data.ndim != 2: + raise RuntimeError( + f"Input model {_get_model_name(model_info)} " + "is not a 2D image." 
+ ) - if self._output_filename is None: - self._output_filename = self.build_output_model_name() + input_pixflux_area = model_info["pixelarea_steradians"] + imwcs = model_info["wcs"] + if (input_pixflux_area and + 'SPECTRAL' not in imwcs.output_frame.axes_type): + if not np.array_equiv(imwcs.array_shape, in_data.shape): + imwcs.array_shape = in_data.shape + input_pixel_area = _compute_image_pixel_area(imwcs) + if input_pixel_area is None: + model_name = model_info["filename"] + if not model_name: + model_name = "Unknown" + raise ValueError( + "Unable to compute input pixel area from WCS of input " + f"image {repr(model_name)}." + ) + iscale = np.sqrt(input_pixflux_area / input_pixel_area) + else: + iscale = 1.0 - self.set_model_array(self._output_model, "data", resample_results.out_img) + # TODO: should weight_type=None here? + in_wht = self.build_driz_weight( + model_info, + weight_type=self.weight_type, + good_bits=self.good_bits + ) - self.update_exposure_times() - if self._enable_err: - self._finish_variance_processing() + # apply sky subtraction + blevel = model_info["level"] + if not model_info["subtracted"] and blevel is not None: + in_data = in_data - blevel - self.set_model_meta( - self._output_model, - { - "weight_type": self.weight_type, - "pointings": len(self._input_models.group_names), - } + xmin, xmax, ymin, ymax = _resample_range( + in_data.shape, + imwcs.bounding_box ) - # TODO: also store the number of images added in total: ncoadds? 
- self.final_post_processing() + pixmap = calc_pixmap(wcs_from=imwcs, wcs_to=self._output_wcs) + + add_image_kwargs = { + 'exptime': model_info["exposure_time"], + 'pixmap': pixmap, + 'scale': iscale, + 'weight_map': in_wht, + 'wht_scale': 1.0, + 'pixfrac': self.pixfrac, + 'in_units': 'cps', # TODO: get units from data model + 'xmin': xmin, + 'xmax': xmax, + 'ymin': ymin, + 'ymax': ymax, + } - self.write_model( - self._output_model, - self._output_filename, - overwrite=True - ) + self.driz_data.add_image(in_data, **add_image_kwargs) - if self._close_output and not self.in_memory: - self.close_model(self._output_model) - self._output_model = None - return self._output_filename + if self._enable_var: + self.resample_variance_data(model_info, add_image_kwargs) - return self._output_model + def run(self): + """ Resample and coadd many inputs to a single output. - def _setup_variance_data(self): - self._var_rnoise_sum = np.full(self._output_array_shape, np.nan) - self._var_poisson_sum = np.full(self._output_array_shape, np.nan) - self._var_flat_sum = np.full(self._output_array_shape, np.nan) - # self._total_weight_var_rnoise = np.zeros(self._output_array_shape) - self._total_weight_var_poisson = np.zeros(self._output_array_shape) - self._total_weight_var_flat = np.zeros(self._output_array_shape) + 1. Call methods that initialize data, variance, and time computations. + 2. Add input images (data, variances, etc) to output arrays. + 3. Perform final computations to compute variance and error + arrays and total exposure time information for the resampled image.
- def _check_var_array(self, data_model, array_name): - array_data = self.get_model_array(data_model, array_name, default=None) - sci_data = self.get_model_array(data_model, "data", default=None) - filename = self.get_model_attr_value(data_model, "filename") + """ + self.init_time_info() + self.init_resample_data() + if self._enable_var: + self.init_resample_variance() + + for model_info, image_model in self._input_models.iter_model(): + self.add_model(model_info, image_model) + self.update_total_time(model_info) + + # assign resampled arrays to the output model dictionary: + self._output_model["data"] = self.driz_data.out_img.astype( + dtype=self.output_array_types["data"] + ) + self._output_model["wht"] = self.driz_data.out_wht.astype( + dtype=self.output_array_types["wht"] + ) - if array_data is None or array_data.size == 0: - log.debug( - f"No data for '{array_name}' for model " - f"{repr(filename)}. Skipping ..." + if self._enable_ctx: + self._output_model["con"] = self.driz_data.out_ctx.astype( + dtype=self.output_array_types["con"] ) - return False - elif array_data.shape != sci_data.shape: - log.warning( - f"Data shape mismatch for '{array_name}' for model " - f"{repr(filename)}. Skipping ..." - ) - return False + if self._enable_var: + self.finalize_variance_processing() + self.compute_errors() - return True + self.finalize_time_info() - def _resample_variance_data(self, data_model, driz_init_kwargs, add_image_kwargs): - log.info("Resampling variance components") + def resample_variance_data(self, data_model, add_image_kwargs): + """ Resample and add input model's variance arrays to the output + variance arrays.
- # create resample objects for the three variance arrays: - driz_init_kwargs = { - 'kernel': self.kernel, - 'fillval': np.nan, - 'out_shape': self._output_array_shape, - # 'exptime': 1.0, - 'no_ctx': True, - } - driz_rnoise = Drizzle(**driz_init_kwargs) - driz_poisson = Drizzle(**driz_init_kwargs) - driz_flat = Drizzle(**driz_init_kwargs) + """ + log.info("Resampling variance components") # Resample read-out noise and compute weight map for variance arrays if self._check_var_array(data_model, 'var_rnoise'): - data = self.get_model_array(data_model, "var_rnoise") + data = data_model["var_rnoise"] data = np.sqrt(data) - driz_rnoise.add_image(data, **add_image_kwargs) - var = driz_rnoise.out_img + + # reset driz output arrays: + self.driz_rnoise.out_img[:, :] = self.driz_rnoise.fillval + self.driz_rnoise.out_wht[:, :] = 0.0 + + self.driz_rnoise.add_image(data, **add_image_kwargs) + var = self.driz_rnoise.out_img np.square(var, out=var) weight_mask = var > 0 @@ -869,50 +1056,38 @@ def _resample_variance_data(self, data_model, driz_init_kwargs, add_image_kwargs weight = np.ones(self._output_array_shape) weight_mask = np.ones(self._output_array_shape, dtype=bool) - if self._check_var_array(data_model, 'var_poisson'): - data = self.get_model_array(data_model, "var_poisson") + for var_name in ["var_flat", "var_poisson"]: + if not self._check_var_array(data_model, var_name): + continue + data = data_model[var_name] data = np.sqrt(data) - driz_poisson.add_image(data, **add_image_kwargs) - var = driz_poisson.out_img - np.square(var, out=var) - mask = (var > 0) & weight_mask + driz = getattr(self, var_name.replace("var", "driz")) + var_sum = getattr(self, f"_{var_name}_sum") + t_var_weight = getattr(self, f"_total_weight_{var_name}") - # Add the inverse of the resampled variance to a running sum. 
- # Update only pixels (in the running sum) with valid new values: - self._var_poisson_sum[mask] = np.nansum( - [ - self._var_poisson_sum[mask], - var[mask] * weight[mask] * weight[mask] - ], - axis=0 - ) - self._total_weight_var_poisson[mask] += weight[mask] + # reset driz output arrays: + driz.out_img[:, :] = driz.fillval + driz.out_wht[:, :] = 0.0 - if self._check_var_array(data_model, 'var_flat'): - data = self.get_model_array(data_model, "var_flat") - data = np.sqrt(data) - driz_flat.add_image(data, **add_image_kwargs) - var = driz_flat.out_img + driz.add_image(data, **add_image_kwargs) + var = driz.out_img np.square(var, out=var) mask = (var > 0) & weight_mask # Add the inverse of the resampled variance to a running sum. # Update only pixels (in the running sum) with valid new values: - self._var_flat_sum[mask] = np.nansum( + var_sum[mask] = np.nansum( [ - self._var_flat_sum[mask], + var_sum[mask], var[mask] * weight[mask] * weight[mask] ], axis=0 ) - self._total_weight_var_flat[mask] += weight[mask] - - def final_post_processing(self): - pass + t_var_weight[mask] += weight[mask] - def _finish_variance_processing(self): + def finalize_variance_processing(self): # We now have a sum of the weighted resampled variances. # Divide by the total weights, squared, and set in the output model. # Zero weight and missing values are NaN in the output. 
@@ -920,15 +1095,16 @@ def _finish_variance_processing(self): warnings.filterwarnings("ignore", "invalid value*", RuntimeWarning) warnings.filterwarnings("ignore", "divide by zero*", RuntimeWarning) - odt = self.get_model_array(self._output_model, "data").dtype - # readout noise np.reciprocal(self._var_rnoise_sum, out=self._var_rnoise_sum) - self.set_model_array( - self._output_model, - "var_rnoise", - self._var_rnoise_sum.astype(dtype=odt) - ) + if self._accumulate and self._output_model["var_rnoise"]: + self._output_model["var_rnoise"] += self._var_rnoise_sum.astype( + dtype=self.output_array_types["var_rnoise"] + ) + else: + self._output_model["var_rnoise"] = self._var_rnoise_sum.astype( + dtype=self.output_array_types["var_rnoise"] + ) # Poisson noise for _ in range(2): @@ -937,11 +1113,15 @@ def _finish_variance_processing(self): self._total_weight_var_poisson, out=self._var_poisson_sum ) - self.set_model_array( - self._output_model, - "var_poisson", - self._var_poisson_sum.astype(dtype=odt) - ) + + if self._accumulate and self._output_model["var_poisson"]: + self._output_model["var_poisson"] += self._var_rnoise_sum.astype( + dtype=self.output_array_types["var_poisson"] + ) + else: + self._output_model["var_poisson"] = self._var_rnoise_sum.astype( + dtype=self.output_array_types["var_poisson"] + ) # flat's noise for _ in range(2): @@ -950,420 +1130,42 @@ def _finish_variance_processing(self): self._total_weight_var_flat, out=self._var_flat_sum ) - self.set_model_array( - self._output_model, - "var_flat", - self._var_flat_sum.astype(dtype=odt) - ) - - # compute total error: - vars = np.array( - [ - self._var_rnoise_sum, - self._var_poisson_sum, - self._var_flat_sum, - ] - ) - all_nan_mask = np.any(np.isnan(vars), axis=0) - - err = np.sqrt(np.nansum(vars, axis=0)).astype(dtype=odt) - err[all_nan_mask] = np.nan - self.set_model_array(self._output_model, "err", err) - self.set_model_array(self._output_model, "var_rnoise", self._var_rnoise_sum) - 
self.set_model_array(self._output_model, "var_poisson", self._var_poisson_sum) - self.set_model_array(self._output_model, "var_flat", self._var_flat_sum) - - del vars - del self._var_rnoise_sum - del self._var_poisson_sum - del self._var_flat_sum - # del self._total_weight_var_rnoise - del self._total_weight_var_poisson - del self._total_weight_var_flat - - def run(self): - """Resample and coadd many inputs to a single output. - - Used for stage 3 resampling - """ - - # TODO: repetiveness of code below should be compactified via using - # getattr as in orig code and maybe making an alternative method to - # the original resample_variance_array - ninputs = len(self._input_models) - - do_accum = ( - self._accum and - ( - self._output_model or - (self._output_filename and Path(self._output_filename).is_file()) - ) - ) - - if do_accum and self._output_model is None: - self._output_model = self.open_model(self._output_filename) - - # get old data: - data = self.get_model_array(self._output_model, "data") - wht = self.get_model_array(self._output_model, "wht") - if self._enable_ctx: - ctx = self.get_model_array(self._output_model, "con") - else: - ctx = None - - t_exptime = self.get_model_attr_value( - self._output_model, - "exptime" - ) - # TODO: we need something to store total number of images that - # have been used to create the resampled output, something - # similar to output_model.meta.resample.pointings. 
- # For now I will call it "ncoadds" - ncoadds = self.get_model_attr_value( - self._output_model, - "ncoadds" - ) - self.accum_output_arrays = True - - else: - ncoadds = 0 - data = None - wht = None - ctx = None - t_exptime = 0.0 - self.accum_output_arrays = False - - driz_data = Drizzle( - kernel=self.kernel, - fillval=self.fillval, - out_shape=self._output_array_shape, - out_img=data, - out_wht=wht, - out_ctx=ctx, - exptime=t_exptime, - begin_ctx_id=ncoadds, - max_ctx_id=ncoadds + ninputs, - ) - - if self._enable_err: - self._setup_variance_data() - - log.info("Resampling science data") - - # loop over only science exposures in the ModelLibrary - # sci_indices = self._input_models.ind_asn_type("science") - with self._input_models: - for model in self._input_models: - # model = self._input_models.borrow(idx) - - try: - if self.get_model_attr_value(model, "exptype").upper() != "SCIENCE": - self._input_models.shelve(model, modify=False) - continue - except AttributeError: - pass - - in_data = self.get_model_array(model, "data") - - attrs = self.get_model_meta( - model, - [ - "wcs", - "pixelarea_steradians", - "filename", - "level", - "subtracted", - "exposure_time", - ] - ) - - # Check that input models are 2D images - if in_data.ndim != 2: - raise RuntimeError( - f"Input {attrs['filename']} is not a 2D image." - ) - - input_pixflux_area = attrs["pixelarea_steradians"] - imwcs = attrs["wcs"] - if (input_pixflux_area and - 'SPECTRAL' not in imwcs.output_frame.axes_type): - imwcs.array_shape = in_data.shape - input_pixel_area = _compute_image_pixel_area(imwcs) - if input_pixel_area is None: - raise ValueError( - "Unable to compute input pixel area from WCS of input " - f"image {repr(attrs['filename'])}." - ) - iscale = np.sqrt(input_pixflux_area / input_pixel_area) - else: - iscale = 1.0 - # TODO: should weight_type=None here? 
- in_wht = self.build_driz_weight( - model, - weight_type=self.weight_type, - good_bits=self.good_bits + if self._accumulate and self._output_model["var_flat"]: + self._output_model["var_flat"] += self._var_rnoise_sum.astype( + dtype=self.output_array_types["var_flat"] ) - - # apply sky subtraction - blevel = attrs["level"] - if not attrs["subtracted"] and blevel is not None: - in_data = in_data - blevel - - xmin, xmax, ymin, ymax = _resample_range( - in_data.shape, - imwcs.bounding_box + else: + self._output_model["var_flat"] = self._var_rnoise_sum.astype( + dtype=self.output_array_types["var_flat"] ) - pixmap = calc_pixmap(wcs_from=imwcs, wcs_to=self._output_wcs) - - add_image_kwargs = { - 'exptime': attrs["exposure_time"], - 'pixmap': pixmap, - 'scale': iscale, - 'weight_map': in_wht, - 'wht_scale': 1.0, - 'pixfrac': self.pixfrac, - 'in_units': 'cps', # TODO: get units from data model - 'xmin': xmin, - 'xmax': xmax, - 'ymin': ymin, - 'ymax': ymax, - } - - driz_data.add_image(in_data, **add_image_kwargs) - - if self._enable_err: - self._resample_variance_data(model, None, add_image_kwargs) - - self._input_models.shelve(model, modify=False) - - # TODO: see what to do about original update_exposure_times() - - return self.create_output_model(driz_data) - - -class ResampleSingle(ResampleBase): - """ - This is the controlling routine for the resampling process. - - Notes - ----- - This routine performs the following operations:: - - 1. Extracts parameter settings from input model, such as pixfrac, - weight type, exposure time (if relevant), and kernel, and merges - them with any user-provided values. - 2. Creates output WCS based on input images and define mapping function - between all input arrays and the output array. - 3. Updates output data model with output arrays from drizzle, including - a record of metadata from all input models. 
- """ - - def __init__(self, input_models, - pixfrac=1.0, kernel="square", fillval=0.0, wht_type="ivm", - good_bits=0, output_wcs=None, wcs_pars=None, - in_memory=True, allowed_memory=None): - """ - Parameters - ---------- - input_models : list of objects - list of data models, one for each input image - - output : DataModel, str - filename for output - - kwargs : dict - Other parameters. - - .. note:: - ``output_shape`` is in the ``x, y`` order. - - .. note:: - ``in_memory`` controls whether or not the resampled - array from ``resample_many_to_many()`` - should be kept in memory or written out to disk and - deleted from memory. Default value is `True` to keep - all products in memory. - - """ - super().__init__( - input_models, - pixfrac=pixfrac, - kernel=kernel, - fillval=fillval, - wht_type=wht_type, - good_bits=good_bits, - output_wcs=output_wcs, - wcs_pars=wcs_pars, - in_memory=in_memory, - allowed_memory=allowed_memory, + # free arrays: + del self._var_rnoise_sum + del self._var_poisson_sum + del self._var_flat_sum + # del self._total_weight_var_rnoise + del self._total_weight_var_poisson + del self._total_weight_var_flat + + def compute_errors(self): + """ Computes total error of the resampled image. """ + vars = np.array( + [ + self._output_model["var_rnoise"], + self._output_model["var_poisson"], + self._output_model["var_flat"], + ] ) + all_nan_mask = np.any(np.isnan(vars), axis=0) - def build_output_name_from_input_name(self, input_file_name): - """ Form output file name from input image name """ - indx = input_file_name.rfind('.') - output_type = input_file_name[indx:] - output_root = '_'.join( - input_file_name.replace(output_type, '').split('_')[:-1] - ) - output_file_name = f'{output_root}_outlier_i2d{output_type}' - return output_file_name - - def _create_output_template_model(self): - # this probably needs to be an abstract class. - # also this is mostly needed for "single" drizzle. 
- self._template_output_model = self.new_model( - copy_meta_from=self._first_model_meta, - ) - pix_area = self._output_pixel_scale**2 - self.set_model_meta( - self._template_output_model, - { - "wcs": deepcopy(self._output_wcs), - "pixelarea_steradians": pix_area, - "pixelarea_arcsecsq": pix_area * np.rad2deg(3600)**2, - } + err = np.sqrt(np.nansum(vars, axis=0)).astype( + dtype=self.output_array_types["err"] ) + del vars + err[all_nan_mask] = np.nan - def create_output_model_single(self, file_name, resample_results): - # this probably needs to be an abstract class - output_model = deepcopy(self._template_output_model) - self.set_model_array(output_model, "data", resample_results.out_img) - if self.in_memory: - return output_model - else: - self.write_model(output_model, file_name, overwrite=True) - self.close_model(output_model) - log.info(f"Saved resampled model to {file_name}") - return file_name - - def run(self): - """Resample many inputs to many outputs where outputs have a common frame. - - Coadd only different detectors of the same exposure, i.e. map NRCA5 and - NRCB5 onto the same output image, as they image different areas of the - sky. 
- - Used for outlier detection - """ - output_models = [] - - for exposure_indices in self._input_models.group_indices.values(): - - driz = Drizzle( - kernel=self.kernel, - fillval=self.fillval, - out_shape=self._output_array_shape, - max_ctx_id=0 - ) - - log.info(f"{len(exposure_indices)} exposures to drizzle together") - - exptime = None - - meta_fields = [ - "wcs", - "pixelarea_steradians", - "filename", - "level", - "subtracted", - ] - - for idx in exposure_indices: - - with self._input_models: - model = self._input_models.borrow(idx) - - in_data = self.get_model_array(model, "data") - - if exptime is None: - attrs = self.get_model_meta( - model, - meta_fields + ["exposure_time", "filename"] - ) - else: - attrs = self.get_model_meta(model, meta_fields) - - # Check that input models are 2D images - if in_data.ndim != 2: - raise RuntimeError( - f"Input {attrs['filename']} is not a 2D image." - ) - - input_pixflux_area = attrs["pixelarea_steradians"] - imwcs = attrs["wcs"] - - if exptime is None: - exptime = attrs["exposure_time"] - # Determine output file type from input exposure filenames - # Use this for defining the output filename - output_filename = self.build_output_name_from_input_name( - attrs["filename"] - ) - - # compute image intensity correction due to the difference - # between where in the input image - # img.meta.photometry.pixelarea_steradians was computed and - # the average input pixel area. - if (input_pixflux_area and - 'SPECTRAL' not in imwcs.output_frame.axes_type): - imwcs.array_shape = in_data.shape - input_pixel_area = _compute_image_pixel_area(imwcs) - if input_pixel_area is None: - raise ValueError( - "Unable to compute input pixel area from WCS " - f"of input image {repr(attrs['filename'])}." - ) - iscale = np.sqrt(input_pixflux_area / input_pixel_area) - else: - iscale = 1.0 - - # TODO: should weight_type=None here? 
- in_wht = self.build_driz_weight( - model, - weight_type=self.weight_type, - good_bits=self.good_bits - ) - - # apply sky subtraction - blevel = attrs["level"] - if not attrs["subtracted"] and blevel is not None: - in_data = in_data - blevel - - xmin, xmax, ymin, ymax = _resample_range( - in_data.shape, - imwcs.bounding_box - ) - - pixmap = calc_pixmap(wcs_from=imwcs, wcs_to=self._output_wcs) - - driz.add_image( - in_data, - exptime=exptime, - pixmap=pixmap, - scale=iscale, - weight_map=in_wht, - wht_scale=1.0, - pixfrac=self.pixfrac, - in_units='cps', # TODO: get units from data model - xmin=xmin, - xmax=xmax, - ymin=ymin, - ymax=ymax, - ) - - self._input_models.shelve(model, idx, modify=False) - del in_data - - output_models.append( - self.create_output_model_single( - output_filename, - driz - ) - ) - - return output_models + self._output_model["err"] = err def _get_boundary_points(xmin, xmax, ymin, ymax, dx=None, dy=None, shrink=0): @@ -1492,3 +1294,10 @@ def _compute_image_pixel_area(wcs): pix_area = sky_area / image_area return pix_area + + +def _get_model_name(model_info): + model_name = model_info["filename"] + if model_name is None or not model_name.strip(): + model_name = "Unknown" + return model_name diff --git a/src/stcal/resample/utils.py b/src/stcal/resample/utils.py index c9623b1b0..71d32b123 100644 --- a/src/stcal/resample/utils.py +++ b/src/stcal/resample/utils.py @@ -1,11 +1,28 @@ +import os +from pathlib import Path, PurePath + import numpy as np from astropy.nddata.bitmask import interpret_bit_flags __all__ = [ - "build_mask", "get_tmeasure", + "build_mask", "build_output_model_name", "get_tmeasure", "bytes2human" ] +def build_output_model_name(input_filename_list): + fnames = {f for f in input_filename_list if f is not None} + + if not fnames: + return "resampled_data_{resample_suffix}{resample_file_ext}" + + # TODO: maybe remove ending suffix for single file names? 
+ prefix = os.path.commonprefix( + [PurePath(f).stem.strip('_- ') for f in fnames] + ) + + return prefix + "{resample_suffix}{resample_file_ext}" + + def build_mask(dqarr, bitvalue, flag_name_map=None): """Build a bit mask from an input DQ array and a bitvalue flag @@ -26,10 +43,47 @@ def get_tmeasure(model): Returns a tuple of (exptime, is_measurement_time) """ try: - tmeasure = model.meta.exposure.measurement_time - except AttributeError: - return model.meta.exposure.exposure_time, False + tmeasure = model["measurement_time"] + except KeyError: + return model["exposure_time"], False if tmeasure is None: - return model.meta.exposure.exposure_time, False + return model["exposure_time"], False else: return tmeasure, True + + +# FIXME: temporarily copied here to avoid this import: +# from stdatamodels.jwst.library.basic_utils import bytes2human +def bytes2human(n): + """Convert bytes to human-readable format + + Taken from the `psutil` library which references + http://code.activestate.com/recipes/578019 + + Parameters + ---------- + n : int + Number to convert + + Returns + ------- + readable : str + A string with units attached. 
+ + Examples + -------- + >>> bytes2human(10000) + '9.8K' + + >>> bytes2human(100001221) + '95.4M' + """ + symbols = ('K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y') + prefix = {} + for i, s in enumerate(symbols): + prefix[s] = 1 << (i + 1) * 10 + for s in reversed(symbols): + if n >= prefix[s]: + value = float(n) / prefix[s] + return '%.1f%s' % (value, s) + return "%sB" % n \ No newline at end of file From 7a7bd80f540db460c29fc15c6f9627f9f930b5d3 Mon Sep 17 00:00:00 2001 From: Mihai Cara Date: Wed, 2 Oct 2024 00:47:41 -0400 Subject: [PATCH 08/10] refactor --- src/stcal/resample/__init__.py | 3 +- src/stcal/resample/resample.py | 31 +++++++++++------- src/stcal/resample/utils.py | 60 +++++++++++++++++++++++++++------- 3 files changed, 69 insertions(+), 25 deletions(-) diff --git a/src/stcal/resample/__init__.py b/src/stcal/resample/__init__.py index 2baa9834a..991b8abbc 100644 --- a/src/stcal/resample/__init__.py +++ b/src/stcal/resample/__init__.py @@ -1,8 +1,9 @@ from .resample import * __all__ = [ - "LibModelAccess", + "LibModelAccessBase", "OutputTooLargeError", "Resample", "resampled_wcs_from_models", + "UnsupportedWCSError", ] diff --git a/src/stcal/resample/resample.py b/src/stcal/resample/resample.py index bffb13020..88644dbcd 100644 --- a/src/stcal/resample/resample.py +++ b/src/stcal/resample/resample.py @@ -18,8 +18,8 @@ interpret_bit_flags, ) -from .utils import bytes2human, get_tmeasure -from ..alignment.util import ( +from stcal.resample.utils import bytes2human, get_tmeasure +from stcal.alignment.util import ( compute_scale, wcs_bbox_from_shape, wcs_from_footprints, @@ -30,10 +30,11 @@ log.setLevel(logging.DEBUG) __all__ = [ - "LibModelAccess", + "LibModelAccessBase", "OutputTooLargeError", "Resample", "resampled_wcs_from_models", + "UnsupportedWCSError", ] _SUPPORTED_CUSTOM_WCS_PARS = [ @@ -62,7 +63,7 @@ def _resample_range(data_shape, bbox=None): return xmin, xmax, ymin, ymax -class LibModelAccess(abc.ABC): +class LibModelAccessBase(abc.ABC): # list of 
model attributes needed by this module. While this is not # required, it is helpful for subclasses to check they know how to # access these attributes. @@ -131,8 +132,8 @@ def resampled_wcs_from_models( Parameters ---------- - input_models : LibModelAccess - An object of `LibModelAccess`-derived type. + input_models : LibModelAccessBase + An object of `LibModelAccessBase`-derived type. pixel_scale_ratio : float, optional Desired pixel scale ratio defined as the ratio of the first model's @@ -241,6 +242,12 @@ class OutputTooLargeError(RuntimeError): """Raised when the output is too large for in-memory instantiation""" +class UnsupportedWCSError(RuntimeError): + """ Raised when provided output WCS has an unexpected number of axes + or has an unsupported structure. + """ + + class Resample: """ This is the controlling routine for the resampling process. @@ -282,9 +289,9 @@ def __init__(self, input_models, pixfrac=1.0, kernel="square", """ Parameters ---------- - input_models : LibModelAccess - A `LibModelAccess` object allowing iterating over all contained - models of interest. + input_models : LibModelAccessBase + A `LibModelAccessBase`-based object allowing iterating over + all contained models of interest. kwargs : dict Other parameters. @@ -414,8 +421,8 @@ def check_output_wcs(self, output_wcs, wcs_pars, """ naxes = output_wcs.output_frame.naxes if naxes != 2: - raise RuntimeError( - "Output WCS needs 2 spatial axes but the " + raise UnsupportedWCSError( + "Output WCS needs 2 coordinate axes but the " f"supplied WCS has {naxes} axes." 
) @@ -841,7 +848,7 @@ def init_resample_variance(self): 'fillval': np.nan, 'out_shape': self._output_array_shape, # 'exptime': 1.0, - 'no_ctx': True, + 'disable_ctx': True, } self.driz_rnoise = Drizzle(**driz_init_kwargs) self.driz_poisson = Drizzle(**driz_init_kwargs) diff --git a/src/stcal/resample/utils.py b/src/stcal/resample/utils.py index 71d32b123..976010b06 100644 --- a/src/stcal/resample/utils.py +++ b/src/stcal/resample/utils.py @@ -1,26 +1,62 @@ -import os -from pathlib import Path, PurePath +from copy import deepcopy +import asdf import numpy as np from astropy.nddata.bitmask import interpret_bit_flags __all__ = [ - "build_mask", "build_output_model_name", "get_tmeasure", "bytes2human" + "build_mask", "get_tmeasure", "bytes2human", "load_custom_wcs" ] -def build_output_model_name(input_filename_list): - fnames = {f for f in input_filename_list if f is not None} +def load_custom_wcs(asdf_wcs_file, output_shape=None): + """ Load a custom output WCS from an ASDF file. - if not fnames: - return "resampled_data_{resample_suffix}{resample_file_ext}" + Parameters + ---------- + asdf_wcs_file : str + Path to an ASDF file containing a GWCS structure. + + output_shape : tuple of int, optional + Array shape (in ``[x, y]`` order) for the output data. If not provided, + the custom WCS must specify one of: pixel_shape, + array_shape, or bounding_box. + + Returns + ------- + wcs : WCS + The output WCS to resample into. - # TODO: maybe remove ending suffix for single file names? 
- prefix = os.path.commonprefix( - [PurePath(f).stem.strip('_- ') for f in fnames] - ) + """ + if not asdf_wcs_file: + return None + + with asdf.open(asdf_wcs_file) as af: + wcs = deepcopy(af.tree["wcs"]) + wcs.pixel_area = af.tree.get("pixel_area", None) + wcs.pixel_shape = af.tree.get("pixel_shape", None) + wcs.array_shape = af.tree.get("array_shape", None) + + if output_shape is not None: + wcs.array_shape = output_shape[::-1] + wcs.pixel_shape = output_shape + elif wcs.pixel_shape is not None: + wcs.array_shape = wcs.pixel_shape[::-1] + elif wcs.array_shape is not None: + wcs.pixel_shape = wcs.array_shape[::-1] + elif wcs.bounding_box is not None: + wcs.array_shape = tuple( + int(axs[1] + 0.5) + for axs in wcs.bounding_box.bounding_box(order="C") + ) + else: + raise ValueError( + "Step argument 'output_shape' is required when custom WCS " + "does not have neither of 'array_shape', 'pixel_shape', or " + "'bounding_box' attributes set." + ) - return prefix + "{resample_suffix}{resample_file_ext}" + return wcs def build_mask(dqarr, bitvalue, flag_name_map=None): From 2da57ba14e977f7d854f95d974a37d2496b8fa55 Mon Sep 17 00:00:00 2001 From: Mihai Cara Date: Fri, 11 Oct 2024 00:05:16 -0400 Subject: [PATCH 09/10] flake8 --- src/stcal/resample/__init__.py | 10 ++++- src/stcal/resample/resample.py | 70 ++++++++++++++++++---------------- src/stcal/resample/utils.py | 2 +- 3 files changed, 47 insertions(+), 35 deletions(-) diff --git a/src/stcal/resample/__init__.py b/src/stcal/resample/__init__.py index 991b8abbc..f2b24de45 100644 --- a/src/stcal/resample/__init__.py +++ b/src/stcal/resample/__init__.py @@ -1,4 +1,10 @@ -from .resample import * +from .resample import ( + LibModelAccessBase, + OutputTooLargeError, + Resample, + resampled_wcs_from_models, + UnsupportedWCSError, +) __all__ = [ "LibModelAccessBase", @@ -6,4 +12,4 @@ "Resample", "resampled_wcs_from_models", "UnsupportedWCSError", -] +] \ No newline at end of file diff --git a/src/stcal/resample/resample.py 
b/src/stcal/resample/resample.py index 88644dbcd..56e167921 100644 --- a/src/stcal/resample/resample.py +++ b/src/stcal/resample/resample.py @@ -91,7 +91,7 @@ class LibModelAccessBase(abc.ABC): "elapsed_exposure_time", "pixelarea_steradians", -# "pixelarea_arcsecsq", + # "pixelarea_arcsecsq", "level", # sky level "subtracted", @@ -266,7 +266,6 @@ class Resample: """ resample_suffix = 'i2d' resample_file_ext = '.fits' - n_arrays_per_output = 2 # #flt-point arrays in the output (data, weight, var, err, etc.) # supported output arrays (subclasses can add more): output_array_types = { @@ -310,8 +309,8 @@ def __init__(self, input_models, pixfrac=1.0, kernel="square", self._accumulate = accumulate # these are attributes that are used only for information purpose - # and are added to created the output_model only if they are not already - # present there: + # and are added to created the output_model only if they are + # not already present there: self._pixel_scale_ratio = None self._output_pixel_scale = None @@ -394,7 +393,8 @@ def __init__(self, input_models, pixfrac=1.0, kernel="square", npix = np.prod(self._output_array_shape) if not npix: raise ValueError( - f"Invalid output frame shape: {tuple(self._output_array_shape)}" + "Invalid output frame shape: " + f"{tuple(self._output_array_shape)}" ) # set up output model (arrays, etc.) @@ -617,7 +617,8 @@ def check_memory_requirements(self, arrays, allowed_memory): that there is enough memory to hold the output. """ - if allowed_memory is None and "DMODEL_ALLOWED_MEMORY" not in os.environ: + if (allowed_memory is None and + "DMODEL_ALLOWED_MEMORY" not in os.environ): return allowed_memory = float(allowed_memory) @@ -695,7 +696,7 @@ def build_driz_weight(self, model_info, weight_type=None, good_bits=None): # apply a median filter to smooth the weight at saturated # (or high read-out noise) single pixels. keep kernel size # small to still give lower weight to extended CRs, etc. 
- ksz = weight_type[8 if selective_median else 7 :] + ksz = weight_type[8 if selective_median else 7:] if ksz: kernel_size = int(ksz) if not (kernel_size % 2): @@ -805,7 +806,7 @@ def finalize_time_info(self): "start_time": min(self._exptime_start), "end_time": max(self._exptime_end), # Update other exposure time keywords: - # XPOSURE (identical to the total effective exposure time, EFFEXPTM) + # XPOSURE (identical to the total effective exposure time,EFFEXPTM) "effective_exposure_time": self._total_exposure_time, # DURATION (identical to TELAPSE, elapsed time) "duration": self._duration, @@ -1034,7 +1035,8 @@ def resample_variance_data(self, data_model, add_image_kwargs): weight[weight_mask] = np.reciprocal(var[weight_mask]) weight_mask &= (weight > 0.0) # Add the inverse of the resampled variance to a running sum. - # Update only pixels (in the running sum) with valid new values: + # Update only pixels (in the running sum) with + # valid new values: self._var_rnoise_sum[weight_mask] = np.nansum( [ self._var_rnoise_sum[weight_mask], @@ -1100,18 +1102,21 @@ def finalize_variance_processing(self): # Zero weight and missing values are NaN in the output. 
with warnings.catch_warnings(): warnings.filterwarnings("ignore", "invalid value*", RuntimeWarning) - warnings.filterwarnings("ignore", "divide by zero*", RuntimeWarning) + warnings.filterwarnings( + "ignore", + "divide by zero*", + RuntimeWarning, + ) # readout noise np.reciprocal(self._var_rnoise_sum, out=self._var_rnoise_sum) + v = self._var_rnoise_sum.astype( + dtype=self.output_array_types["var_rnoise"] + ) if self._accumulate and self._output_model["var_rnoise"]: - self._output_model["var_rnoise"] += self._var_rnoise_sum.astype( - dtype=self.output_array_types["var_rnoise"] - ) + self._output_model["var_rnoise"] += v else: - self._output_model["var_rnoise"] = self._var_rnoise_sum.astype( - dtype=self.output_array_types["var_rnoise"] - ) + self._output_model["var_rnoise"] = v # Poisson noise for _ in range(2): @@ -1121,14 +1126,13 @@ def finalize_variance_processing(self): out=self._var_poisson_sum ) + v = self._var_poisson_sum.astype( + dtype=self.output_array_types["var_poisson"] + ) if self._accumulate and self._output_model["var_poisson"]: - self._output_model["var_poisson"] += self._var_rnoise_sum.astype( - dtype=self.output_array_types["var_poisson"] - ) + self._output_model["var_poisson"] += v else: - self._output_model["var_poisson"] = self._var_rnoise_sum.astype( - dtype=self.output_array_types["var_poisson"] - ) + self._output_model["var_poisson"] = v # flat's noise for _ in range(2): @@ -1138,14 +1142,13 @@ def finalize_variance_processing(self): out=self._var_flat_sum ) + v = self._var_flat_sum.astype( + dtype=self.output_array_types["var_flat"] + ) if self._accumulate and self._output_model["var_flat"]: - self._output_model["var_flat"] += self._var_rnoise_sum.astype( - dtype=self.output_array_types["var_flat"] - ) + self._output_model["var_flat"] += v else: - self._output_model["var_flat"] = self._var_rnoise_sum.astype( - dtype=self.output_array_types["var_flat"] - ) + self._output_model["var_flat"] = v # free arrays: del self._var_rnoise_sum 
@@
-1175,7 +1178,8 @@ def compute_errors(self): self._output_model["err"] = err -def _get_boundary_points(xmin, xmax, ymin, ymax, dx=None, dy=None, shrink=0): +def _get_boundary_points(xmin, xmax, ymin, ymax, dx=None, dy=None, + shrink=0): # noqa: E741 """ xmin, xmax, ymin, ymax - integer coordinates of pixel boundaries step - distance between points along an edge @@ -1209,7 +1213,7 @@ def _get_boundary_points(xmin, xmax, ymin, ymax, dx=None, dy=None, shrink=0): b = np.s_[0:sx] # bottom edge r = np.s_[sx:sx + sy] # right edge t = np.s_[sx + sy:2 * sx + sy] # top edge - l = np.s_[2 * sx + sy:2 * sx + 2 * sy] # left + l = np.s_[2 * sx + sy:2 * sx + 2 * sy] # noqa: E741 left edge x[b] = np.linspace(xmin, xmax, sx, False) y[b] = ymin @@ -1233,7 +1237,9 @@ def _compute_image_pixel_area(wcs): raise ValueError("WCS must have array_shape attribute set.") valid_polygon = False - spatial_idx = np.where(np.array(wcs.output_frame.axes_type) == 'SPATIAL')[0] + spatial_idx = np.where( + np.array(wcs.output_frame.axes_type) == 'SPATIAL' + )[0] ny, nx = wcs.array_shape @@ -1253,7 +1259,7 @@ def _compute_image_pixel_area(wcs): while xmin < xmax and ymin < ymax: try: - x, y, image_area, center, b, r, t, l = _get_boundary_points( + (x, y, image_area, center, b, r, t, l) = _get_boundary_points( xmin=xmin, xmax=xmax, ymin=ymin, diff --git a/src/stcal/resample/utils.py b/src/stcal/resample/utils.py index 976010b06..c6310f25f 100644 --- a/src/stcal/resample/utils.py +++ b/src/stcal/resample/utils.py @@ -122,4 +122,4 @@ def bytes2human(n): if n >= prefix[s]: value = float(n) / prefix[s] return '%.1f%s' % (value, s) - return "%sB" % n \ No newline at end of file + return "%sB" % n From 55294b0b0f433af379fc64984ea2a700038c39f0 Mon Sep 17 00:00:00 2001 From: Mihai Cara Date: Mon, 28 Oct 2024 21:40:42 -0400 Subject: [PATCH 10/10] fix incorrect definition of pix ratio --- src/stcal/resample/resample.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git 
a/src/stcal/resample/resample.py b/src/stcal/resample/resample.py index 56e167921..531eb516d 100644 --- a/src/stcal/resample/resample.py +++ b/src/stcal/resample/resample.py @@ -136,11 +136,11 @@ def resampled_wcs_from_models( An object of `LibModelAccessBase`-derived type. pixel_scale_ratio : float, optional - Desired pixel scale ratio defined as the ratio of the first model's - pixel scale computed from this model's WCS at the fiducial point - (taken as the ``ref_ra`` and ``ref_dec`` from the ``wcsinfo`` meta - attribute of the first input image) to the desired output pixel - scale. Ignored when ``pixel_scale`` is specified. + Desired pixel scale ratio defined as the ratio of the desired output + pixel scale to the first input model's pixel scale computed from this + model's WCS at the fiducial point (taken as the ``ref_ra`` and + ``ref_dec`` from the ``wcsinfo`` meta attribute of the first input + image). Ignored when ``pixel_scale`` is specified. pixel_scale : float, None, optional Desired pixel scale (in degrees) of the output WCS. When provided, @@ -209,16 +209,16 @@ def resampled_wcs_from_models( ) if pixel_scale is None: - pixel_scale = pscale_in0 / pixel_scale_ratio + pixel_scale = pscale_in0 * pixel_scale_ratio log.info( - f"Pixel scale ratio (pscale_in / pscale_out): {pixel_scale_ratio}" + f"Pixel scale ratio (pscale_out/pscale_in): {pixel_scale_ratio}" ) log.info(f"Computed output pixel scale: {3600 * pixel_scale} arcsec.") else: - pixel_scale_ratio = pscale_in0 / pixel_scale + pixel_scale_ratio = pixel_scale / pscale_in0 log.info(f"Output pixel scale: {3600 * pixel_scale} arcsec.") log.info( - "Computed pixel scale ratio (pscale_in / pscale_out): " + "Computed pixel scale ratio (pscale_out/pscale_in): " f"{pixel_scale_ratio}." 
) @@ -381,7 +381,6 @@ def __init__(self, input_models, pixfrac=1.0, kernel="square", self._output_pixel_scale = np.rad2deg( np.sqrt(_compute_image_pixel_area(output_wcs)) ) - self._pixel_scale_ratio = output_model.get("wcs", None) log.info( "Computed output pixel scale: " f"{3600 * self._output_pixel_scale} arcsec."