diff --git a/jwst/resample/resample.py b/jwst/resample/resample.py index 11d1a12011d..c3c1b3f2909 100644 --- a/jwst/resample/resample.py +++ b/jwst/resample/resample.py @@ -8,6 +8,7 @@ from stcal.resample import ResampleModelIO, ResampleCoAdd, ResampleSingle from stcal.resample.utils import get_tmeasure from drizzle.resample import Drizzle +from stdatamodels.jwst.datamodels.dqflags import pixel from ..datamodels import ModelContainer from ..model_blender import blendmeta @@ -58,6 +59,7 @@ def new_model(self, image_shape=None, file_name=None): class ResampleJWSTCoAdd(ResampleJWSTModelIO, ResampleCoAdd): # resample_array_names = [ # {'attr': 'data', 'variance', 'exptime'] + dq_flag_name_map = pixel def __init__(self, *args, blendheaders=True, **kwargs): super().__init__(*args, **kwargs) self._blendheaders = blendheaders @@ -107,177 +109,10 @@ def _check_var_array(self, data_model, array_name): return False return True - def extra_pre_resample_setup(self): - self._var_rnoise_sum = np.full(self._output_array_shape, np.nan) - self._var_poisson_sum = np.full(self._output_array_shape, np.nan) - self._var_flat_sum = np.full(self._output_array_shape, np.nan) - # self._total_weight_var_rnoise = np.zeros(self._output_array_shape) - self._total_weight_var_poisson = np.zeros(self._output_array_shape) - self._total_weight_var_flat = np.zeros(self._output_array_shape) - - def post_process_resample_model(self, data_model, driz_init_kwargs, add_image_kwargs): - log.info("Resampling variance components") - - # create resample objects for the three variance arrays: - driz_init_kwargs = { - 'kernel': self.kernel, - 'fillval': np.nan, - 'out_shape': self._output_array_shape, - # 'exptime': 1.0, - 'no_ctx': True, - } - driz_rnoise = Drizzle(**driz_init_kwargs) - driz_poisson = Drizzle(**driz_init_kwargs) - driz_flat = Drizzle(**driz_init_kwargs) - - # Resample read-out noise and compute weight map for variance arrays - if self._check_var_array(data_model, 'var_rnoise'): - data = np.sqrt(data_model.var_rnoise) - driz_rnoise.add_image(data, **add_image_kwargs) - var = driz_rnoise.out_img - np.square(var, out=var) - - weight_mask = var > 0 - - # Set the weight for the image from the weight type - if self.weight_type == "ivm": - weight_mask = var > 0 - weight = np.ones(self._output_array_shape) - weight[weight_mask] = np.reciprocal(var[weight_mask]) - weight_mask &= (weight > 0.0) - # Add the inverse of the resampled variance to a running sum. - # Update only pixels (in the running sum) with valid new values: - self._var_rnoise_sum[weight_mask] = np.nansum( - [ - self._var_rnoise_sum[weight_mask], - weight[weight_mask] - ], - axis=0 - ) - elif self.weight_type == "exptime": - weight = np.full( - self._output_array_shape, - get_tmeasure(data_model)[0], - ) - weight_mask = np.ones(self._output_array_shape, dtype=bool) - self._var_rnoise_sum = np.nansum( - [self._var_rnoise_sum, weight], - axis=0 - ) - else: - weight = np.ones(self._output_array_shape) - weight_mask = np.ones(self._output_array_shape, dtype=bool) - self._var_rnoise_sum = np.nansum( - [self._var_rnoise_sum, weight], - axis=0 - ) - else: - weight = np.ones(self._output_array_shape) - weight_mask = np.ones(self._output_array_shape, dtype=bool) - - if self._check_var_array(data_model, 'var_poisson'): - data = np.sqrt(data_model.var_poisson) - driz_poisson.add_image(data, **add_image_kwargs) - var = driz_poisson.out_img - np.square(var, out=var) - - mask = (var > 0) & weight_mask - - # Add the inverse of the resampled variance to a running sum. - # Update only pixels (in the running sum) with valid new values: - self._var_poisson_sum[mask] = np.nansum( - [ - self._var_poisson_sum[mask], - var[mask] * weight[mask] * weight[mask] - ], - axis=0 - ) - self._total_weight_var_poisson[mask] += weight[mask] - - if self._check_var_array(data_model, 'var_flat'): - data = np.sqrt(data_model.var_flat) - driz_flat.add_image(data, **add_image_kwargs) - var = driz_flat.out_img - np.square(var, out=var) - - mask = (var > 0) & weight_mask - - # Add the inverse of the resampled variance to a running sum. - # Update only pixels (in the running sum) with valid new values: - self._var_flat_sum[mask] = np.nansum( - [ - self._var_flat_sum[mask], - var[mask] * weight[mask] * weight[mask] - ], - axis=0 - ) - self._total_weight_var_flat[mask] += weight[mask] - - def finalize_resample(self): - # We now have a sum of the weighted resampled variances. - # Divide by the total weights, squared, and set in the output model. - # Zero weight and missing values are NaN in the output. - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "invalid value*", RuntimeWarning) - warnings.filterwarnings("ignore", "divide by zero*", RuntimeWarning) - - odt = self._output_model.data.dtype - - # readout noise - np.reciprocal(self._var_rnoise_sum, out=self._var_rnoise_sum) - self._output_model.var_rnoise = self._var_rnoise_sum.astype(dtype=odt) - - # Poisson noise - for _ in range(2): - np.divide( - self._var_poisson_sum, - self._total_weight_var_poisson, - out=self._var_poisson_sum - ) - self._output_model.var_poisson = self._var_poisson_sum.astype(dtype=odt) - - # flat's noise - for _ in range(2): - np.divide( - self._var_flat_sum, - self._total_weight_var_flat, - out=self._var_flat_sum - ) - self._output_model.var_flat = self._var_flat_sum.astype(dtype=odt) - - # compute total error: - vars = np.array( - [ - self._var_rnoise_sum, - self._var_poisson_sum, - self._var_flat_sum, - ] - ) - all_nan_mask = np.any(np.isnan(vars), axis=0) - self._output_model.err = np.sqrt(np.nansum(vars, axis=0)).astype(dtype=odt) - self._output_model.err[all_nan_mask] = np.nan - - del vars - del self._var_rnoise_sum - del self._var_poisson_sum - del self._var_flat_sum - # del self._total_weight_var_rnoise - del self._total_weight_var_poisson - del self._total_weight_var_flat - - # update meta for the output model: - self._output_model.meta.cal_step.resample = 'COMPLETE' - _update_fits_wcsinfo(self._output_model) - util.update_s_region_imaging(self._output_model) - self._output_model.meta.asn.pool_name = self._input_models.asn_pool_name - self._output_model.meta.asn.table_name = self._input_models.asn_table_name - self._output_model.meta.resample.pixel_scale_ratio = self._pixel_scale_ratio - self._output_model.meta.resample.pixfrac = self.pixfrac - # TODO: Not sure about funct. signature and also I don't like it needs # to open input files again. Should we store meta of all inputs? # Should blendmeta.blendmodels be redesigned to blend one meta at a time? - def blend_output_metadata(self, output_model): + def blend_output_metadata(self): """ Create new output metadata based on blending all input metadata. """ if not self._blendheaders: @@ -290,15 +125,26 @@ def blend_output_metadata(self, output_model): log.info(f'Blending metadata for {self._output_filename}') blendmeta.blendmodels( - output_model, + self._output_model, inputs=self._input_models, output=self._output_filename, ignore=ignore_list ) + def final_post_processing(self): + # update meta for the output model: + self._output_model.meta.cal_step.resample = 'COMPLETE' + _update_fits_wcsinfo(self._output_model) + util.update_s_region_imaging(self._output_model) + self._output_model.meta.asn.pool_name = self._input_models.asn_pool_name + self._output_model.meta.asn.table_name = self._input_models.asn_table_name + self._output_model.meta.resample.pixel_scale_ratio = self._pixel_scale_ratio + self._output_model.meta.resample.pixfrac = self.pixfrac + self.blend_output_metadata() + class ResampleJWSTSingle(ResampleJWSTModelIO, ResampleSingle): - pass + dq_flag_name_map = pixel def _update_fits_wcsinfo(model):