From fc41e18c323596588ac9cb914ef41a4fbf7c434c Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Mon, 16 Dec 2024 18:12:41 -0800 Subject: [PATCH] JP-1136: Compute scaling for WFSS background subtraction using error-weighted mean (#8990) --- changes/8990.background.rst | 1 + docs/jwst/background_step/arguments.rst | 18 + docs/jwst/background_step/description.rst | 19 +- jwst/background/background_step.py | 26 +- jwst/background/background_sub.py | 194 ---------- jwst/background/background_sub_wfss.py | 284 ++++++++++++++ jwst/background/tests/conftest.py | 7 + jwst/background/tests/test_background.py | 205 +--------- jwst/background/tests/test_background_wfss.py | 361 ++++++++++++++++++ 9 files changed, 700 insertions(+), 415 deletions(-) create mode 100644 changes/8990.background.rst create mode 100644 jwst/background/background_sub_wfss.py create mode 100644 jwst/background/tests/conftest.py create mode 100644 jwst/background/tests/test_background_wfss.py diff --git a/changes/8990.background.rst b/changes/8990.background.rst new file mode 100644 index 0000000000..ed656009a4 --- /dev/null +++ b/changes/8990.background.rst @@ -0,0 +1 @@ +Compute scaling for WFSS background subtraction using error-weighted mean diff --git a/docs/jwst/background_step/arguments.rst b/docs/jwst/background_step/arguments.rst index 576260ef66..fbee951e6f 100644 --- a/docs/jwst/background_step/arguments.rst +++ b/docs/jwst/background_step/arguments.rst @@ -27,3 +27,21 @@ control the sigma clipping, and are passed as arguments to the astropy Sets the minimum (faintest) magnitude limit to use when selecting sources from the WFSS source catalog, based on the value of `isophotal_abmag` in the source catalog. Defaults to ``None``. + +``--wfss_maxiter`` + Only applies to Wide Field Slitless Spectroscopy (WFSS) exposures. + Sets the maximum number of iterations allowed for iterative outlier rejection + during determination of the reference background scaling factor. Defaults to 5. + +``--wfss_rms_stop`` + Only applies to Wide Field Slitless Spectroscopy (WFSS) exposures. + If the percentage difference in the RMS of the background-subtracted image + between iterations is smaller than this value, stop the iterative outlier + rejection process. + Defaults to 0, i.e., do all iterations up to ``wfss_maxiter``. + +``--wfss_outlier_percent`` + Only applies to Wide Field Slitless Spectroscopy (WFSS) exposures. + Sets the percentile of outliers in the data to reject on both the low and high end + per iteration during determination of the reference background scaling factor + Defaults to 1, i.e., keep the middle 98 percent of the data each iteration. diff --git a/docs/jwst/background_step/description.rst b/docs/jwst/background_step/description.rst index 183c8e1951..3934b1db4c 100644 --- a/docs/jwst/background_step/description.rst +++ b/docs/jwst/background_step/description.rst @@ -87,7 +87,15 @@ For Wide-Field Slitless Spectroscopy expsoures (NIS_WFSS and NRC_WFSS), a background reference image is subtracted from the target exposure. Before being subtracted, the background reference image is scaled to match the signal level of the WFSS image within background (source-free) regions of the -image. +image. The scaling factor is determined based on the variance-weighted mean +of the science data, i.e., ``factor = sum(sci*bkg/var) / sum(bkg*bkg/var)``. +This factor is equivalent to solving for the scaling constant applied to the +reference background that gives the maximum likelihood of matching +the science data. +Outliers are rejected iteratively during determination of the scaling factor +in order to avoid biasing the scaling factor based on outliers. The iterative +rejection process is controlled by the +``wfss_outlier_percent``, ``wfss_rms_stop``, and ``wfss_maxiter`` step arguments. The locations of source spectra are determined from a source catalog (specified by the primary header keyword SCATFILE), in conjunction with a reference file @@ -99,15 +107,6 @@ abmag of the source catalog objects used to define the background regions. The default is to use all source catalog entries that result in a spectrum falling within the WFSS image. -Robust mean values are obtained for the background regions in the WFSS image and -for the same regions in the background reference image, and the ratio of those two -mean values is used to scale the background reference image. The robust mean is -computed by excluding the lowest 25% and highest 25% of the data (using the -numpy.percentile function), and taking a simple arithmetic mean of the -remaining values. Note that NaN values (if any) in the background -reference image are currently set to zero. If there are a lot of NaNs, -it may be that more than 25% of the lowest values will need to be excluded. - For both background methods the output results are always returned in a new data model, leaving the original input model unchanged. diff --git a/jwst/background/background_step.py b/jwst/background/background_step.py index 4f5946195d..a63d55ac85 100755 --- a/jwst/background/background_step.py +++ b/jwst/background/background_step.py @@ -2,7 +2,8 @@ from stdatamodels.jwst import datamodels from ..stpipe import Step -from . import background_sub +from .background_sub import background_sub +from .background_sub_wfss import subtract_wfss_bkg import numpy as np __all__ = ["BackgroundStep"] @@ -19,6 +20,9 @@ class BackgroundStep(Step): sigma = float(default=3.0) # Clipping threshold maxiters = integer(default=None) # Number of clipping iterations wfss_mmag_extract = float(default=None) # WFSS minimum abmag to extract + wfss_maxiter = integer(default=5) # WFSS iterative outlier rejection max iterations + wfss_rms_stop = float(default=0) # WFSS iterative outlier rejection RMS improvement threshold (percent) + wfss_outlier_percent = float(default=1) # WFSS outlier percentile to reject per iteration """ # These reference files are only used for WFSS/GRISM data. @@ -60,8 +64,16 @@ def process(self, input, bkg_list): wlrange_name) # Do the background subtraction for WFSS/GRISM data - result = background_sub.subtract_wfss_bkg( - input_model, bkg_name, wlrange_name, self.wfss_mmag_extract) + rescaler_kwargs = {"p": self.wfss_outlier_percent, + "maxiter": self.wfss_maxiter, + "delta_rms_thresh": self.wfss_rms_stop/100, + } + result = subtract_wfss_bkg( + input_model, + bkg_name, + wlrange_name, + self.wfss_mmag_extract, + rescaler_kwargs=rescaler_kwargs) if result is None: result = input_model result.meta.cal_step.back_sub = 'SKIPPED' @@ -88,10 +100,10 @@ def process(self, input, bkg_list): break # Do the background subtraction if do_sub: - bkg_model, result = background_sub.background_sub(input_model, - bkg_list, - self.sigma, - self.maxiters) + bkg_model, result = background_sub(input_model, + bkg_list, + self.sigma, + self.maxiters) result.meta.cal_step.back_sub = 'COMPLETE' if self.save_combined_background: comb_bkg_path = self.save_model(bkg_model, suffix=self.bkg_suffix, force=True) diff --git a/jwst/background/background_sub.py b/jwst/background/background_sub.py index 9edd641c23..a04a03ea9c 100755 --- a/jwst/background/background_sub.py +++ b/jwst/background/background_sub.py @@ -1,13 +1,10 @@ import copy -import math import numpy as np import warnings from stdatamodels.jwst import datamodels -from stdatamodels.jwst.datamodels.dqflags import pixel from . import subtract_images -from ..assign_wcs.util import create_grism_bbox from astropy.stats import sigma_clip from astropy.utils.exceptions import AstropyUserWarning @@ -270,194 +267,3 @@ def average_background(input_model, bkg_list, sigma, maxiters): avg_bkg.err = (np.sqrt(merr.sum(axis=0)) / (num_bkg - merr.mask.sum(axis=0))).data return avg_bkg - - -def sufficient_background_pixels(dq_array, bkg_mask, min_pixels=100): - """Count number of good pixels for background use. - - Check DQ flags of pixels selected for bkg use - XOR the DQ values with - the DO_NOT_USE flag to flip the DO_NOT_USE bit. Then count the number - of pixels that AND with the DO_NOT_USE flag, i.e. initially did not have - the DO_NOT_USE bit set. - """ - return np.count_nonzero((dq_array[bkg_mask] - ^ pixel['DO_NOT_USE']) - & pixel['DO_NOT_USE'] - ) > min_pixels - - -def subtract_wfss_bkg(input_model, bkg_filename, wl_range_name, mmag_extract=None): - """Scale and subtract a background reference image from WFSS/GRISM data. - - Parameters - ---------- - input_model : JWST data model - input target exposure data model - - bkg_filename : str - name of master background file for WFSS/GRISM - - wl_range_name : str - name of wavelengthrange reference file - - mmag_extract : float - minimum abmag of grism objects to extract - - Returns - ------- - result : JWST data model - background-subtracted target data model - """ - - bkg_ref = datamodels.open(bkg_filename) - - if hasattr(input_model.meta, "source_catalog"): - got_catalog = True - else: - log.warning("No source_catalog found in input.meta.") - got_catalog = False - - # If there are NaNs, we have to replace them with something harmless. - bkg_ref = no_NaN(bkg_ref) - - # Create a mask from the source catalog, True where there are no sources, - # i.e. in regions we can use as background. - if got_catalog: - bkg_mask = mask_from_source_cat(input_model, wl_range_name, mmag_extract) - if not sufficient_background_pixels(input_model.dq, bkg_mask): - log.warning("Not enough background pixels to work with.") - log.warning("Step will be SKIPPED.") - return None - else: - bkg_mask = np.ones(input_model.data.shape, dtype=bool) - # Compute the mean values of science image and background reference - # image, including only regions where there are no identified sources. - # Exclude pixel values in the lower and upper 25% of the histogram. - lowlim = 25. - highlim = 75. - sci_mean = robust_mean(input_model.data[bkg_mask], - lowlim=lowlim, highlim=highlim) - bkg_mean = robust_mean(bkg_ref.data[bkg_mask], - lowlim=lowlim, highlim=highlim) - - log.debug("mean of [{}, {}] percentile grism image = {}" - .format(lowlim, highlim, sci_mean)) - log.debug("mean of [{}, {}] percentile background image = {}" - .format(lowlim, highlim, bkg_mean)) - - result = input_model.copy() - if bkg_mean != 0.: - subtract_this = (sci_mean / bkg_mean) * bkg_ref.data - result.data = input_model.data - subtract_this - log.info(f"Average of background image subtracted = {subtract_this.mean(dtype=float)}") - else: - log.warning("Background image has zero mean; nothing will be subtracted.") - result.dq = np.bitwise_or(input_model.dq, bkg_ref.dq) - - bkg_ref.close() - - return result - - -def no_NaN(model, fill_value=0.): - """Replace NaNs with a harmless value. - - Parameters - ---------- - model : JWST data model - Reference file model. - - fill_value : float - NaNs will be replaced with this value. - - Returns - ------- - result : JWST data model - Reference file model without NaNs in data array. - """ - - mask = np.isnan(model.data) - if mask.sum(dtype=np.intp) == 0: - return model - else: - temp = model.copy() - temp.data[mask] = fill_value - return temp - - -def mask_from_source_cat(input_model, wl_range_name, mmag_extract=None): - """Create a mask that is False within bounding boxes of sources. - - Parameters - ---------- - input_model : JWST data model - input target exposure data model - - wl_range_name : str - Name of the wavelengthrange reference file - - mmag_extract : float - Minimum abmag of grism objects to extract - - Returns - ------- - bkg_mask : ndarray - Boolean mask: True for background, False for pixels that are - inside at least one of the source regions defined in the source - catalog. - """ - - shape = input_model.data.shape - bkg_mask = np.ones(shape, dtype=bool) - - reference_files = {"wavelengthrange": wl_range_name} - grism_obj_list = create_grism_bbox(input_model, reference_files, mmag_extract) - - for obj in grism_obj_list: - order_bounding = obj.order_bounding - for order in order_bounding.keys(): - ((ymin, ymax), (xmin, xmax)) = order_bounding[order] - xmin = int(math.floor(xmin)) - xmax = int(math.ceil(xmax)) + 1 # convert to slice limit - ymin = int(math.floor(ymin)) - ymax = int(math.ceil(ymax)) + 1 - xmin = max(xmin, 0) - xmax = min(xmax, shape[-1]) - ymin = max(ymin, 0) - ymax = min(ymax, shape[-2]) - bkg_mask[..., ymin:ymax, xmin:xmax] = False - - return bkg_mask - - -def robust_mean(x, lowlim=25., highlim=75.): - """Compute a mean value, excluding outliers. - - Parameters - ---------- - x : ndarray - The array for which we want a mean value. - - lowlim : float - The lower `lowlim` percent of the data will not be used when - computing the mean. - - highlim : float - The upper `highlim` percent of the data will not be used when - computing the mean. - - Returns - ------- - mean_value : float - The mean of `x`, excluding data outside `lowlim` to `highlim` - percentile limits. - """ - - nan_mask = np.isnan(x) - cleaned_x = x[~nan_mask] - limits = np.percentile(cleaned_x, (lowlim, highlim)) - mask = np.logical_and(cleaned_x >= limits[0], cleaned_x <= limits[1]) - - mean_value = np.mean(cleaned_x[mask], dtype=float) - - return mean_value diff --git a/jwst/background/background_sub_wfss.py b/jwst/background/background_sub_wfss.py new file mode 100644 index 0000000000..dcfc7baee5 --- /dev/null +++ b/jwst/background/background_sub_wfss.py @@ -0,0 +1,284 @@ +import math +import numpy as np +import warnings + +from stdatamodels.jwst import datamodels +from stdatamodels.jwst.datamodels.dqflags import pixel + +from jwst.assign_wcs.util import create_grism_bbox + +import logging +log = logging.getLogger(__name__) +log.setLevel(logging.DEBUG) + + +def subtract_wfss_bkg( + input_model, + bkg_filename, + wl_range_name, + mmag_extract=None, + rescaler_kwargs={}, +): + """Scale and subtract a background reference image from WFSS/GRISM data. + + Parameters + ---------- + input_model : JWST data model + input target exposure data model + + bkg_filename : str + name of master background file for WFSS/GRISM + + wl_range_name : str + name of wavelengthrange reference file + + mmag_extract : float, optional, default None + minimum abmag of grism objects to extract + + rescaler_kwargs : dict, optional, default {} + Keyword arguments to pass to ScalingFactorComputer + + Returns + ------- + result : JWST data model + background-subtracted target data model + """ + bkg_ref = datamodels.open(bkg_filename) + + # get the dispersion axis + try: + dispaxis = input_model.meta.wcsinfo.dispersion_direction + except AttributeError: + log.warning("Dispersion axis not found in input science image metadata. " + "Variance stopping criterion will have no effect for iterative " + "outlier rejection (will run until maxiter).") + dispaxis = None + rescaler_kwargs["dispersion_axis"] = dispaxis + + # get the source catalog for masking + if hasattr(input_model.meta, "source_catalog"): + got_catalog = True + else: + log.warning("No source_catalog found in input.meta.") + got_catalog = False + + # Create a mask from the source catalog, True where there are no sources, + # i.e. in regions we can use as background. + if got_catalog: + bkg_mask = _mask_from_source_cat(input_model, wl_range_name, mmag_extract) + if not _sufficient_background_pixels(input_model.dq, bkg_mask): + log.warning("Not enough background pixels to work with.") + log.warning("Step will be SKIPPED.") + return None + else: + bkg_mask = np.ones(input_model.data.shape, dtype=bool) + + # compute scaling factor for the reference background image + log.info("Starting iterative outlier rejection for background subtraction.") + rescaler = _ScalingFactorComputer(**rescaler_kwargs) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", + category=RuntimeWarning, + message="All-NaN slice encountered") + # copy to avoid propagating NaNs from iterative clipping into final product + sci = input_model.data.copy() + var = input_model.err.copy()**2 + bkg = bkg_ref.data.copy() + factor, _ = rescaler(sci, bkg, var, mask=~bkg_mask) + + # extract the derived factor and apply it to the unmasked, non-outlier-rejected data + subtract_this = factor * bkg_ref.data + result = input_model.copy() + result.data = input_model.data - subtract_this + result.dq = np.bitwise_or(input_model.dq, bkg_ref.dq) + + log.info(f"Average of scaled background image = {np.nanmean(subtract_this):.3e}") + log.info(f"Scaling factor = {factor:.5e}") + + bkg_ref.close() + + return result + + +class _ScalingFactorComputer: + + def __init__(self, p=1.0, maxiter=5, delta_rms_thresh=0, dispersion_axis=None): + """ + Parameters + ---------- + p: float, optional + Percentile for sigma clipping on both low and high ends per iteration, default 1.0. + For example, with p=2.0, the middle 96% of the data is kept. + maxiter: int, optional + Maximum number of iterations for outlier rejection. Default 5. + delta_rms_thresh: float, optional + Stopping criterion for outlier rejection; stops when the rms residuals + change by less than this fractional threshold in a single iteration. + For example, assuming delta_rms_thresh=0.1 and a residual RMS of 100 + in iteration 1, the iteration will stop if the RMS residual in iteration + 2 is greater than 90. + Default 0.0, i.e., ignore this and only stop at maxiter. + dispersion_axis: int, optional + The index to select the along-dispersion axis. Used to compute the RMS + residual, so must be set if rms_thresh > 0. Default None. + """ + if (delta_rms_thresh > 0) and (dispersion_axis not in [1,2]): + msg = (f"Unrecognized dispersion axis {dispersion_axis}. " + "Dispersion axis must be specified if delta_rms_thresh " + "is used as a stopping criterion.") + raise ValueError(msg) + + self.p = p + self.maxiter = maxiter + self.delta_rms_thresh = delta_rms_thresh + self.dispersion_axis = dispersion_axis + + + def __call__(self, sci, bkg, var, mask=None): + """ + Parameters + ---------- + sci: ndarray + The science data. + bkg: ndarray + The reference background model. + var: ndarray + Total variance (error squared) of the science data. + mask: ndarray[bool], optional + Initial mask to be applied to the data, True where bad. + Typically this would mask out the real sources in the data. + + Returns + ------- + float + Scaling factor that minimizes sci - factor*bkg, + taking into account residuals and outliers. + ndarray[bool] + Outlier mask generated by the iterative clipping procedure. + """ + if mask is None: + mask = np.zeros(sci.shape, dtype="bool") + self._update_nans(sci, bkg, var, mask) + + # iteratively reject more and more outliers + i = 0 + last_rms_resid = np.inf + while (i < self.maxiter): + + # compute the factor that minimizes the residuals + factor = self.err_weighted_mean(sci, bkg, var) + sci_sub = sci-factor*bkg + + # Check fractional improvement stopping criterion before incrementing. + # Note this never passes in iteration 0 because last_rms_resid is inf. + if self.delta_rms_thresh > 0: + rms_resid = self._compute_rms_residual(sci_sub) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", + category=RuntimeWarning, + message="invalid value encountered in scalar divide") + fractional_diff = (last_rms_resid - rms_resid)/last_rms_resid + if fractional_diff < self.delta_rms_thresh: + msg = (f"Stopping at iteration {i}; too little improvement " + "since last iteration (hit delta_rms_thresh).") + log.info(msg) + break + last_rms_resid = rms_resid + + i += 1 + + # Reject outliers based on residual between sci and bkg. + # Updating the sci, var, and bkg nan values means that + # they are ignored by nanpercentile in the next iteration + limits = np.nanpercentile(sci_sub, (self.p, 100-self.p)) + mask += np.logical_or(sci_sub < limits[0], sci_sub > limits[1]) + self._update_nans(sci, bkg, var, mask) + + if i >= self.maxiter: + log.info(f"Stopped at maxiter ({i}).") + + self._iters_run_last_call = i + return self.err_weighted_mean(sci, bkg, var), mask + + + def err_weighted_mean(self, sci, bkg, var): + """Remove any var=0 values, which can happen for real data""" + mask = (var == 0) + self._update_nans(sci, bkg, var, mask) + return np.nansum(sci*bkg/var) / np.nansum(bkg*bkg/var) + + + def _update_nans(self, sci, bkg, var, mask): + sci[mask] = np.nan + bkg[mask] = np.nan + var[mask] = np.nan + + + def _compute_rms_residual(self, sci_sub): + """ + Calculate the background-subtracted RMS along the dispersion axis, which is found + by taking the median profile of the image collapsed along the cross-dispersion axis. + Note meta.wcsinfo.dispersion_axis is 1-indexed coming out of assign_wcs, i.e., in [1,2]. + So we need to """ + collapsing_axis = int(self.dispersion_axis - 1) + sci_sub_profile = np.nanmedian(sci_sub, axis=collapsing_axis) + return np.sqrt(np.nanmean(sci_sub_profile**2)) + + +def _sufficient_background_pixels(dq_array, bkg_mask, min_pixels=100): + """Count number of good pixels for background use. + + Check DQ flags of pixels selected for bkg use - XOR the DQ values with + the DO_NOT_USE flag to flip the DO_NOT_USE bit. Then count the number + of pixels that AND with the DO_NOT_USE flag, i.e. initially did not have + the DO_NOT_USE bit set. + """ + return np.count_nonzero((dq_array[bkg_mask] + ^ pixel['DO_NOT_USE']) + & pixel['DO_NOT_USE'] + ) > min_pixels + + +def _mask_from_source_cat(input_model, wl_range_name, mmag_extract=None): + """Create a mask that is False within bounding boxes of sources. + + Parameters + ---------- + input_model : JWST data model + input target exposure data model + + wl_range_name : str + Name of the wavelengthrange reference file + + mmag_extract : float + Minimum abmag of grism objects to extract + + Returns + ------- + bkg_mask : ndarray + Boolean mask: True for background, False for pixels that are + inside at least one of the source regions defined in the source + catalog. + """ + + shape = input_model.data.shape + bkg_mask = np.ones(shape, dtype=bool) + + reference_files = {"wavelengthrange": wl_range_name} + grism_obj_list = create_grism_bbox(input_model, reference_files, mmag_extract) + + for obj in grism_obj_list: + order_bounding = obj.order_bounding + for order in order_bounding.keys(): + ((ymin, ymax), (xmin, xmax)) = order_bounding[order] + xmin = int(math.floor(xmin)) + xmax = int(math.ceil(xmax)) + 1 # convert to slice limit + ymin = int(math.floor(ymin)) + ymax = int(math.ceil(ymax)) + 1 + xmin = max(xmin, 0) + xmax = min(xmax, shape[-1]) + ymin = max(ymin, 0) + ymax = min(ymax, shape[-2]) + bkg_mask[..., ymin:ymax, xmin:xmax] = False + + return bkg_mask diff --git a/jwst/background/tests/conftest.py b/jwst/background/tests/conftest.py new file mode 100644 index 0000000000..cfe5d98314 --- /dev/null +++ b/jwst/background/tests/conftest.py @@ -0,0 +1,7 @@ +import pytest +import pathlib + + +@pytest.fixture(scope="module") +def data_path(): + return pathlib.Path(__file__).parent / "data" \ No newline at end of file diff --git a/jwst/background/tests/test_background.py b/jwst/background/tests/test_background.py index 00b04c93a8..841aa3a362 100644 --- a/jwst/background/tests/test_background.py +++ b/jwst/background/tests/test_background.py @@ -1,31 +1,16 @@ """ Unit tests for background subtraction """ -import pathlib - -from astropy.stats import sigma_clipped_stats import pytest -import numpy as np from numpy.testing import assert_allclose from stdatamodels.jwst import datamodels -from stdatamodels.jwst.datamodels.dqflags import pixel - -from jwst.assign_wcs import AssignWcsStep from jwst.background import BackgroundStep -from jwst.stpipe import Step -from jwst.background.background_sub import (robust_mean, mask_from_source_cat, - no_NaN, sufficient_background_pixels) - - -@pytest.fixture(scope="module") -def data_path(): - return pathlib.Path(__file__).parent / "data" @pytest.fixture(scope='module') def background(tmp_path_factory): - """Generate a background image to feed to background step""" + """Generate a background image to feed to background step""" filename = tmp_path_factory.mktemp('background_input') filename = filename / 'background.fits' @@ -190,140 +175,6 @@ def test_nirspec_gwa_ytilt(tmp_cwd, background, science_image): back_image.close() -@pytest.fixture(scope='module') -def make_wfss_datamodel(data_path): - """Generate WFSS Observation""" - wcsinfo = { - 'dec_ref': -27.79156387419731, - 'ra_ref': 53.16247756038121, - 'roll_ref': 0.04254766236781744, - 'v2_ref': -290.1, - 'v3_ref': -697.5, - 'v3yangle': 0.56987, - 'vparity': -1} - - observation = { - 'date': '2023-01-05', - 'time': '8:59:37'} - - exposure = { - 'duration': 11.805952, - 'end_time': 58119.85416, - 'exposure_time': 11.776, - 'frame_time': 0.11776, - 'group_time': 0.11776, - 'groupgap': 0, - 'integration_time': 11.776, - 'nframes': 1, - 'ngroups': 8, - 'nints': 1, - 'nresets_between_ints': 0, - 'nsamples': 1, - 'sample_time': 10.0, - 'start_time': 58668.72509857639, - 'zero_frame': False} - - subarray = {'xsize': 2048, - 'ysize': 2048, - 'xstart': 1, - 'ystart': 1} - - instrument = { - 'filter_position': 1, - 'pupil_position': 1} - - image = datamodels.ImageModel((2048, 2048)) - - image.meta.wcsinfo._instance.update(wcsinfo) - image.meta.instrument._instance.update(instrument) - image.meta.observation._instance.update(observation) - image.meta.subarray._instance.update(subarray) - image.meta.exposure._instance.update(exposure) - image.data = np.random.rand(2048, 2048) - image.meta.source_catalog = str(data_path / "test_cat.ecsv") - - return image - - -filter_list = ['F250M', 'F277W', 'F335M', 'F356W', 'F460M', - 'F356W', 'F410M', 'F430M', 'F444W'] # + ['F480M', 'F322W2', 'F300M'] - - -@pytest.mark.parametrize("pupils", ['GRISMC', 'GRISMR']) -@pytest.mark.parametrize("filters", filter_list) -@pytest.mark.parametrize("detectors", ['NRCALONG', 'NRCBLONG']) -def test_nrc_wfss_background(tmp_cwd, filters, pupils, detectors, make_wfss_datamodel): - """Test background subtraction for NIRCAM WFSS modes.""" - data = make_wfss_datamodel - - data.meta.instrument.filter = filters - data.meta.instrument.pupil = pupils - data.meta.instrument.detector = detectors - data.meta.instrument.channel = 'LONG' - data.meta.instrument.name = 'NIRCAM' - data.meta.exposure.type = 'NRC_WFSS' - - if data.meta.instrument.detector == 'NRCALONG': - data.meta.instrument.module = 'A' - elif data.meta.instrument.detector == 'NRCBLONG': - data.meta.instrument.module = 'B' - - wcs_corrected = AssignWcsStep.call(data) - - # Get References - wavelenrange = Step().get_reference_file(wcs_corrected, "wavelengthrange") - bkg_file = Step().get_reference_file(wcs_corrected, 'wfssbkg') - - mask = mask_from_source_cat(wcs_corrected, wavelenrange) - - with datamodels.open(bkg_file) as bkg_ref: - bkg_ref = no_NaN(bkg_ref) - - # calculate backgrounds - pipeline_data_mean = robust_mean(wcs_corrected.data[mask]) - test_data_mean, _, _ = sigma_clipped_stats(wcs_corrected.data, sigma=2) - - pipeline_reference_mean = robust_mean(bkg_ref.data[mask]) - test_reference_mean, _, _ = sigma_clipped_stats(bkg_ref.data, sigma=2) - - assert np.isclose([pipeline_data_mean], [test_data_mean], rtol=1e-3) - assert np.isclose([pipeline_reference_mean], [test_reference_mean], rtol=1e-1) - - -@pytest.mark.parametrize("filters", ['GR150C', 'GR150R']) -@pytest.mark.parametrize("pupils", ['F090W', 'F115W', 'F140M', 'F150W', 'F158M', 'F200W']) -def test_nis_wfss_background(filters, pupils, make_wfss_datamodel): - """Test background subtraction for NIRISS WFSS modes.""" - data = make_wfss_datamodel - - data.meta.instrument.filter = filters - data.meta.instrument.pupil = pupils - data.meta.instrument.detector = 'NIS' - data.meta.instrument.name = 'NIRISS' - data.meta.exposure.type = 'NIS_WFSS' - - wcs_corrected = AssignWcsStep.call(data) - - # Get References - wavelenrange = Step().get_reference_file(wcs_corrected, "wavelengthrange") - bkg_file = Step().get_reference_file(wcs_corrected, 'wfssbkg') - - mask = mask_from_source_cat(wcs_corrected, wavelenrange) - - with datamodels.open(bkg_file) as bkg_ref: - bkg_ref = no_NaN(bkg_ref) - - # calculate backgrounds - pipeline_data_mean = robust_mean(wcs_corrected.data[mask]) - test_data_mean, _, _ = sigma_clipped_stats(wcs_corrected.data, sigma=2) - - pipeline_reference_mean = robust_mean(bkg_ref.data[mask]) - test_reference_mean, _, _ = sigma_clipped_stats(bkg_ref.data, sigma=2) - - assert np.isclose([pipeline_data_mean], [test_data_mean], rtol=1e-3) - assert np.isclose([pipeline_reference_mean], [test_reference_mean], rtol=1e-1) - - @pytest.mark.parametrize('data_shape,background_shape', [((10, 10), (10, 10)), ((10, 10), (20, 20)), @@ -371,57 +222,3 @@ def test_miri_subarray_partial_overlap(data_shape, background_shape): image.close() background.close() - - -def test_robust_mean(): - """Test robust mean calculation""" - data = np.random.rand(2048, 2048) - result = robust_mean(data) - test = np.mean(data) - - assert np.isclose([test], [result], rtol=1e-3) - - -def test_no_nan(): - """Make sure that nan values are filled with fill value""" - # Make data model - model = datamodels.ImageModel() - data = np.random.rand(10, 10) - - # Randomly insert NaNs - data.ravel()[np.random.choice(data.size, 10, replace=False)] = np.nan - model.data = data - - # Randomly select fill value - fill_val = np.random.randint(0, 20) - - # Call no_NaN - result = no_NaN(model, fill_value=fill_val) - - # Use np.nan to find NaNs. - test_result = np.isnan(model.data) - # Assign fill values to NaN indices - model.data[test_result] = fill_val - - # Make sure arrays are equal. - assert np.array_equal(model.data, result.data) - - -def test_sufficient_background_pixels(): - model = datamodels.ImageModel(data=np.zeros((2048, 2048)), - dq=np.zeros((2048, 2048))) - refpix_flags = pixel['DO_NOT_USE'] | pixel['REFERENCE_PIXEL'] - model.dq[:4, :] = refpix_flags - model.dq[-4:, :] = refpix_flags - model.dq[:, :4] = refpix_flags - model.dq[:, -4:] = refpix_flags - - bkg_mask = np.ones((2048, 2048), dtype=bool) - # With full array minux refpix available for bkg, should be sufficient - assert sufficient_background_pixels(model.dq, bkg_mask) - - bkg_mask[4: -4, :] = 0 - bkg_mask[:, 4: -4] = 0 - # Now mask out entire array, mocking full source coverage of detector - - # no pixels should be available for bkg - assert not sufficient_background_pixels(model.dq, bkg_mask) diff --git a/jwst/background/tests/test_background_wfss.py b/jwst/background/tests/test_background_wfss.py new file mode 100644 index 0000000000..d1219fbaa1 --- /dev/null +++ b/jwst/background/tests/test_background_wfss.py @@ -0,0 +1,361 @@ +import pytest +import numpy as np +from pathlib import Path + +from stdatamodels.jwst.datamodels.dqflags import pixel +from stdatamodels.jwst import datamodels +from jwst.stpipe import Step +from jwst.assign_wcs import AssignWcsStep +from jwst.background import BackgroundStep +from jwst.background.background_sub_wfss import (subtract_wfss_bkg, + _mask_from_source_cat, + _sufficient_background_pixels, + _ScalingFactorComputer) + +BKG_SCALING = 0.123 +DETECTOR_SHAPE = (2048, 2048) +INITIAL_NAN_FRACTION = 1e-4 +INITIAL_OUTLIER_FRACTION = 1e-3 + +@pytest.fixture(scope="module") +def known_bkg(): + """Make a simplified version of the reference background model data.""" + + ny, nx = DETECTOR_SHAPE + y, x = np.mgrid[:ny, :nx] + gradient = x * y / (nx*ny) + gradient = gradient - np.mean(gradient) + return gradient + 1 + + +@pytest.fixture(scope="module") +def mock_data(known_bkg): + """Synthetic data with NaNs, noise, and the known background structure + but rescaled. Later tests will ensure we can retrieve the proper scaling.""" + + err_scaling = 0.05 + nan_fraction = INITIAL_NAN_FRACTION + + # make random data and error arrays + rng = np.random.default_rng(seed=42) + data = rng.normal(0, 1, DETECTOR_SHAPE) + # ensure all errors are positive and not too close to zero + err = err_scaling*(1 + rng.normal(0, 1, DETECTOR_SHAPE)**2) + + # add NaNs + num_nans = int(data.size * nan_fraction) + nan_indices = np.unravel_index(rng.choice(data.size, num_nans), data.shape) + data[nan_indices] = np.nan + err[nan_indices] = np.nan + original_data_mean = np.nanmean(data) + + # add some outliers + num_outliers = int(data.size * INITIAL_OUTLIER_FRACTION) + outlier_indices = np.unravel_index(rng.choice(data.size, num_outliers), data.shape) + data[outlier_indices] = rng.normal(100, 1, num_outliers) + + data[nan_indices] = np.nan + err[nan_indices] = np.nan + + # also add a small background to the data with same structure + # as the known reference background to see if it will get removed + data += known_bkg*BKG_SCALING + + return data, err, original_data_mean + + +@pytest.fixture(scope='module') +def make_wfss_datamodel(data_path, mock_data): + + """Generate WFSS Observation""" + wcsinfo = { + 'dec_ref': -27.79156387419731, + 'ra_ref': 53.16247756038121, + 'roll_ref': 0.04254766236781744, + 'v2_ref': -290.1, + 'v3_ref': -697.5, + 'v3yangle': 0.56987, + 'vparity': -1} + + observation = { + 'date': '2023-01-05', + 'time': '8:59:37'} + + exposure = { + 'duration': 11.805952, + 'end_time': 58119.85416, + 'exposure_time': 11.776, + 'frame_time': 0.11776, + 'group_time': 0.11776, + 'groupgap': 0, + 'integration_time': 11.776, + 'nframes': 1, + 'ngroups': 8, + 'nints': 1, + 'nresets_between_ints': 0, + 'nsamples': 1, + 'sample_time': 10.0, + 'start_time': 58668.72509857639, + 'zero_frame': False} + + subarray = {'xsize': DETECTOR_SHAPE[0], + 'ysize': DETECTOR_SHAPE[1], + 'xstart': 1, + 'ystart': 1} + + instrument = { + 'filter_position': 1, + 'pupil_position': 1} + + image = datamodels.ImageModel(DETECTOR_SHAPE) + + image.meta.wcsinfo._instance.update(wcsinfo) + image.meta.instrument._instance.update(instrument) + image.meta.observation._instance.update(observation) + image.meta.subarray._instance.update(subarray) + image.meta.exposure._instance.update(exposure) + + image.data = mock_data[0] + image.err = mock_data[1] + image.original_data_mean = mock_data[2] #just add this here for convenience + image.dq = np.isnan(image.data) + + image.meta.source_catalog = str(data_path / "test_cat.ecsv") + + return image + + +@pytest.fixture +def make_nrc_wfss_datamodel(make_wfss_datamodel): + """Make a NIRCAM WFSS datamodel and call AssignWCS to populate its WCS""" + data = make_wfss_datamodel.copy() + data.meta.instrument.filter = 'F250M' + data.meta.instrument.pupil = 'GRISMC' + data.meta.instrument.detector = 'NRCALONG' + data.meta.instrument.channel = 'LONG' + data.meta.instrument.name = 'NIRCAM' + data.meta.exposure.type = 'NRC_WFSS' + data.meta.instrument.module = 'A' + result = AssignWcsStep.call(data) + + return result + + +@pytest.fixture +def make_nis_wfss_datamodel(make_wfss_datamodel): + """Make a NIRISS WFSS datamodel and call AssignWCS to populate its WCS""" + data = make_wfss_datamodel.copy() + data.meta.instrument.filter = 'GR150C' + data.meta.instrument.pupil = 'F090W' + data.meta.instrument.detector = 'NIS' + data.meta.instrument.name = 'NIRISS' + data.meta.exposure.type = 'NIS_WFSS' + result = AssignWcsStep.call(data) + + return result + + +@pytest.fixture() +def bkg_file(tmp_cwd, make_wfss_datamodel, known_bkg): + """Mock background reference file""" + + bkg_fname = "ref_bkg.fits" + bkg_image = make_wfss_datamodel.copy() + bkg_image.data = known_bkg + bkg_image.save(tmp_cwd / Path(bkg_fname)) + + return bkg_fname + + +def shared_tests(sci, mask, original_data_mean): + """Tests that are common to all WFSS modes + Note that NaN fraction test in test_nrc_wfss_background and test_nis_wfss_background + cannot be applied to the full run tests because the background reference files contain + NaNs in some cases (specifically for NIRISS)""" + + # re-mask data so "real" sources are ignored here + sci[~mask] = np.nan + + # test that the background has been subtracted from the data to within some fraction of + # the noise in the data. There's probably a first-principles way to determine the tolerance, + # but this is ok for the purposes of this test. + # ignore the outliers here too + sci[sci>50] = np.nan + tol = 0.01*np.nanstd(sci) + assert np.isclose(np.nanmean(sci), original_data_mean, atol=tol) + + +def test_nrc_wfss_background(make_nrc_wfss_datamodel, bkg_file): + """Test background subtraction for NIRCAM WFSS modes.""" + data = make_nrc_wfss_datamodel.copy() + + # Get References + wavelenrange = Step().get_reference_file(data, "wavelengthrange") + + # do the subtraction + result = subtract_wfss_bkg(data, bkg_file, wavelenrange) + sci = result.data.copy() + + # ensure NaN fraction did not increase. Rejecting outliers during determination + # of factor should not have carried over into result. + nan_frac = np.sum(np.isnan(sci))/sci.size + assert np.isclose(nan_frac, INITIAL_NAN_FRACTION, rtol=1e-2) + + # re-compute mask to ignore "real" sources for tests + mask = _mask_from_source_cat(result, wavelenrange) + + shared_tests(sci, mask, data.original_data_mean) + + +def test_nis_wfss_background(make_nis_wfss_datamodel, bkg_file): + """Test background subtraction for NIRISS WFSS modes.""" + data = make_nis_wfss_datamodel.copy() + + # Get References + wavelenrange = Step().get_reference_file(data, "wavelengthrange") + + # do the subtraction + result = subtract_wfss_bkg(data, bkg_file, wavelenrange) + sci = result.data.copy() + + # ensure NaN fraction did not increase. Rejecting outliers during determination + # of factor should not have carried over into result. + nan_frac = np.sum(np.isnan(sci))/sci.size + assert np.isclose(nan_frac, INITIAL_NAN_FRACTION, rtol=1e-2) + + mask = _mask_from_source_cat(result, wavelenrange) + shared_tests(sci, mask, data.original_data_mean) + + +# test both filters because they have opposite dispersion directions +@pytest.mark.parametrize("pupil", ["GRISMC", "GRISMR"]) +def test_nrc_wfss_full_run(pupil, make_nrc_wfss_datamodel): + """Test full run of NIRCAM WFSS background subtraction. + The residual structure in the background will not look as nice as in + test_nis_wfss_background because here it's taken from a reference file, + so the bkg has real detector imperfections + while the data is synthetic and just has a mock gradient""" + data = make_nrc_wfss_datamodel.copy() + data.meta.instrument.pupil = pupil + + # do the subtraction. set all options to ensure they are at least recognized + result = BackgroundStep.call(data, None, + wfss_maxiter=3, + wfss_outlier_percent=0.5, + wfss_rms_stop=0,) + + sci = result.data.copy() + # re-derive mask to ignore "real" sources for tests + wavelenrange = Step().get_reference_file(data, "wavelengthrange") + mask = _mask_from_source_cat(result, wavelenrange) + shared_tests(sci, mask, data.original_data_mean) + + +@pytest.mark.parametrize("filt", ["GR150C", "GR150R"]) +def test_nis_wfss_full_run(filt, make_nis_wfss_datamodel): + """Test full run of NIRISS WFSS background subtraction. + The residual structure in the background will not look as nice as in + test_nis_wfss_background because here it's taken from a reference file, + so the bkg has real detector imperfections + while the data is synthetic and just has a mock gradient""" + data = make_nis_wfss_datamodel.copy() + data.meta.instrument.filter = filt + + # do the subtraction. set all options to ensure they are at least recognized + result = BackgroundStep.call(data, None, + wfss_maxiter=3, + wfss_outlier_percent=0.5, + wfss_rms_stop=0,) + + sci = result.data.copy() + # re-derive mask to ignore "real" sources for tests + wavelenrange = Step().get_reference_file(data, "wavelengthrange") + mask = _mask_from_source_cat(result, wavelenrange) + shared_tests(sci, mask, data.original_data_mean) + + +def test_sufficient_background_pixels(): + model = datamodels.ImageModel(data=np.zeros((2048, 2048)), + dq=np.zeros((2048, 2048))) + refpix_flags = pixel['DO_NOT_USE'] | pixel['REFERENCE_PIXEL'] + model.dq[:4, :] = refpix_flags + model.dq[-4:, :] = refpix_flags + model.dq[:, :4] = refpix_flags + model.dq[:, -4:] = refpix_flags + + bkg_mask = np.ones((2048, 2048), dtype=bool) + # With full array minux refpix available for bkg, should be sufficient + assert _sufficient_background_pixels(model.dq, bkg_mask) + + bkg_mask[4: -4, :] = 0 + bkg_mask[:, 4: -4] = 0 + # Now mask out entire array, mocking full source coverage of detector - + # no pixels should be available for bkg + assert not _sufficient_background_pixels(model.dq, bkg_mask) + + +def test_weighted_mean(make_wfss_datamodel, bkg_file): + + sci = make_wfss_datamodel.data + var = make_wfss_datamodel.err**2 + with datamodels.open(bkg_file) as bkg_model: + bkg = bkg_model.data + + # put 0.1% zero values in variance to ensure coverage of previous bug where zero-valued + # variances in real data caused factor = 1/np.inf = 0 + rng = np.random.default_rng(seed=42) + n_bad = int(var.size / 1000) + bad_i = rng.choice(var.size-1, n_bad) + var[np.unravel_index(bad_i, var.shape)] = 0.0 + + # instantiate scaling factor computer + rescaler = _ScalingFactorComputer() + + # just get the weighted mean without iteration + # to check it's as expected, mask outliers + sci[sci>50] = np.nan + factor = rescaler.err_weighted_mean(sci, bkg, var) + original_data_mean = make_wfss_datamodel.original_data_mean + expected_factor = BKG_SCALING+original_data_mean + assert np.isclose(factor, expected_factor, atol=1e-3) + + # ensure it still works after iteration + for niter in [1,2,5]: + for p in [2, 0.5, 0.1]: + rescaler = _ScalingFactorComputer(p=p, maxiter=niter) + assert rescaler.delta_rms_thresh == 0 #check rms_thresh=None input sets thresh properly + + factor, mask_out = rescaler(sci, bkg, var) + mask_fraction = np.sum(mask_out)/mask_out.size + max_mask_fraction = p*niter*2 + INITIAL_NAN_FRACTION + + assert np.isclose(factor, expected_factor, atol=1e-3) + assert mask_fraction <= max_mask_fraction + assert mask_fraction > INITIAL_NAN_FRACTION + + # test that variance stopping criterion works + # tune the RMS thresh to take roughly half the iterations + # need lots of significant digits here because iterating makes little difference + # for this test case + maxiter = 10 + delta_rms_thresh = 1e-4 + p = 100*INITIAL_OUTLIER_FRACTION/2 + rescaler = _ScalingFactorComputer(p=p, + dispersion_axis=1, + delta_rms_thresh=delta_rms_thresh, + maxiter=maxiter) + factor, mask_out = rescaler(sci, bkg, var) + assert rescaler._iters_run_last_call < maxiter + + # test putting mask=None works ok, and that maxiter=0 just gives you err weighted mean + rescaler = _ScalingFactorComputer(maxiter=0) + factor, mask_out = rescaler(sci, bkg, var) + assert np.all(mask_out == 0) + assert factor == rescaler.err_weighted_mean(sci, bkg, var) + + # test invalid inputs + with pytest.raises(ValueError): + rescaler = _ScalingFactorComputer(dispersion_axis=5, delta_rms_thresh=1) + + with pytest.raises(ValueError): + rescaler = _ScalingFactorComputer(dispersion_axis=None, delta_rms_thresh=1)