From 12b43afdd02a64ce386e71439e3b04e1cedcc4b6 Mon Sep 17 00:00:00 2001 From: Melanie Clarke Date: Thu, 14 Nov 2024 10:43:51 -0500 Subject: [PATCH] First version of PSF extraction, contains hacks Co-authored-by: jemorrison --- jwst/extract_1d/extract.py | 101 ++++++++++-- jwst/extract_1d/extract_1d_step.py | 27 +++- jwst/extract_1d/psf_profile.py | 245 +++++++++++++++++++++++++++++ 3 files changed, 352 insertions(+), 21 deletions(-) create mode 100644 jwst/extract_1d/psf_profile.py diff --git a/jwst/extract_1d/extract.py b/jwst/extract_1d/extract.py index a09475b3af..5a374f191a 100644 --- a/jwst/extract_1d/extract.py +++ b/jwst/extract_1d/extract.py @@ -16,6 +16,7 @@ from jwst.lib.wcs_utils import get_wavelengths from jwst.extract_1d import extract1d, spec_wcs from jwst.extract_1d.apply_apcorr import select_apcorr +from jwst.extract_1d.psf_profile import psf_profile log = logging.getLogger(__name__) log.setLevel(logging.DEBUG) @@ -23,6 +24,9 @@ WFSS_EXPTYPES = ['NIS_WFSS', 'NRC_WFSS', 'NRC_GRISM'] """Exposure types to be regarded as wide-field slitless spectroscopy.""" +OPTIMAL_EXPTYPES = ['MIR_LRS-FIXEDSLIT'] +"""Exposure types for which optimal extraction is available.""" + ANY = "ANY" """Wildcard for slit name. @@ -139,7 +143,8 @@ def read_apcorr_ref(refname, exptype): def get_extract_parameters(ref_dict, input_model, slitname, sp_order, meta, smoothing_length, bkg_fit, bkg_order, use_source_posn, - subtract_background): + subtract_background, extraction_type, specwcs_ref_name, + psf_ref_name): """Get extraction parameter values. Parameters @@ -202,6 +207,17 @@ def get_extract_parameters(ref_dict, input_model, slitname, sp_order, meta, subtract_background : bool or None If False, all background parameters will be ignored. + extraction_type : str + Extraction type ('box' or 'optimal'). Optimal extraction is + only available if `specwcs_ref_name` and `psf_ref_name` are + not 'N/A'. + + specwcs_ref_name : str + The name of the specwcs reference file, or "N/A". + + psf_ref_name : str + The name of the PSF reference file, or "N/A". + Returns ------- extract_params : dict @@ -228,6 +244,8 @@ def get_extract_parameters(ref_dict, input_model, slitname, sp_order, meta, extract_params['bkg_order'] = 0 # because no background sub. extract_params['subtract_background'] = False extract_params['extraction_type'] = 'box' + extract_params['specwcs_ref_name'] = 'N/A' + extract_params['psf_ref_name'] = 'N/A' if use_source_posn is None: extract_params['use_source_posn'] = False @@ -326,9 +344,10 @@ def get_extract_parameters(ref_dict, input_model, slitname, sp_order, meta, # If the user supplied a value, use that value. extract_params['smoothing_length'] = smoothing_length - # Default the extraction type to 'box': 'optimal' - # is not yet supported. - extract_params['extraction_type'] = 'box' + # Set the extraction type to 'box' or 'optimal' + extract_params['extraction_type'] = extraction_type + extract_params['specwcs'] = specwcs_ref_name + extract_params['psf'] = psf_ref_name break @@ -1206,10 +1225,27 @@ def define_aperture(input_model, slit, extract_params, exp_type): # Offet extract parameters by location - nominal shift_by_source_location(location, nominal_location, extract_params) + else: + middle_pix, middle_wl, location = None, None, None # Make a spatial profile, including source shifts if necessary - profile, lower_limit, upper_limit = box_profile( - data_shape, extract_params, wl_array, return_limits=True) + if extract_params['extraction_type'] == 'optimal': + if location is None: + nominal_profile = box_profile(data_shape, extract_params, wl_array, + label='nominal aperture') + middle_pix, location = aperture_center(nominal_profile) + if extract_params['dispaxis'] == HORIZONTAL: + middle_wl = wl_array[middle_pix, location] + else: + middle_wl = wl_array[location, middle_pix] + + profile, lower_limit, upper_limit = psf_profile( + data_model, extract_params['psf'], extract_params['specwcs'], + middle_wl, location) + + else: + profile, lower_limit, upper_limit = box_profile( + data_shape, extract_params, wl_array, return_limits=True) # Make sure profile weights are zero where wavelengths are invalid profile[~np.isfinite(wl_array)] = 0.0 @@ -1383,8 +1419,9 @@ def extract_one_slit(data_model, integ, profile, bg_profile, extract_params): def create_extraction(input_model, slit, output_model, extract_ref_dict, slitname, sp_order, smoothing_length, bkg_fit, bkg_order, use_source_posn, exp_type, - subtract_background, apcorr_ref_model, log_increment, - save_profile): + subtract_background, apcorr_ref_model, + extraction_type, specwcs_ref_name, psf_ref_name, + log_increment, save_profile): """Extract spectra from an input model and append to an output model. Input data, specified in the `slit` or `input_model`, should contain data @@ -1463,6 +1500,14 @@ def create_extraction(input_model, slit, output_model, apcorr_ref_model : DataModel or None The aperture correction reference datamodel, containing the APCORR reference file data. + extraction_type : str + Extraction type ('box' or 'optimal'). Optimal extraction is + only available if `specwcs_ref_name` and `psf_ref_name` are + not 'N/A'. + specwcs_ref_name : str + The name of the specwcs reference file, or "N/A". + psf_ref_name : str + The name of the PSF reference file, or "N/A". log_increment : int If greater than 0 and the input data are multi-integration, a message will be written to the log every `log_increment` integrations. @@ -1543,7 +1588,7 @@ def create_extraction(input_model, slit, output_model, extract_params = get_extract_parameters( extract_ref_dict, data_model, slitname, sp_order, input_model.meta, smoothing_length, bkg_fit,bkg_order, use_source_posn, - subtract_background + subtract_background, extraction_type, specwcs_ref_name, psf_ref_name ) if extract_params['match'] == NO_MATCH: @@ -1784,9 +1829,10 @@ def create_extraction(input_model, slit, output_model, return profile_model -def run_extract1d(input_model, extract_ref_name, apcorr_ref_name, smoothing_length, - bkg_fit, bkg_order, log_increment, subtract_background, - use_source_posn, save_profile): +def run_extract1d(input_model, extract_ref_name, apcorr_ref_name, + specwcs_ref_name, psf_ref_name, extraction_type, + smoothing_length, bkg_fit, bkg_order, log_increment, + subtract_background, use_source_posn, save_profile): """Extract all 1-D spectra from an input model. Parameters @@ -1798,7 +1844,18 @@ def run_extract1d(input_model, extract_ref_name, apcorr_ref_name, smoothing_leng The name of the extract1d reference file, or "N/A". apcorr_ref_name : str - Name of the APCORR reference file. Default is None + Name of the APCORR reference file. Default is None. + + specwcs_ref_name : str + The name of the specwcs reference file, or "N/A". + + psf_ref_name : str + The name of the PSF reference file, or "N/A". + + extraction_type : str + Extraction type ('box' or 'optimal'). Optimal extraction is + only available if `specwcs_ref_name` and `psf_ref_name` are + not 'N/A'. smoothing_length : int or None Width of a boxcar function for smoothing the background regions. @@ -1863,6 +1920,14 @@ def run_extract1d(input_model, extract_ref_name, apcorr_ref_name, smoothing_leng if apcorr_ref_name is not None and apcorr_ref_name != 'N/A': apcorr_ref_model = read_apcorr_ref(apcorr_ref_name, exp_type) + # Check for non-null specwcs and PSF reference files + if (exp_type not in OPTIMAL_EXPTYPES + or specwcs_ref_name == 'N/A' or psf_ref_name == 'N/A'): + if extraction_type != 'box': + log.warning(f'Optimal extraction is not available for EXP_TYPE {exp_type}') + log.warning('Defaulting to box extraction.') + extraction_type = 'box' + # Set up the output model output_model = datamodels.MultiSpecModel() if hasattr(meta_source, "int_times"): @@ -1944,8 +2009,9 @@ def run_extract1d(input_model, extract_ref_name, apcorr_ref_name, smoothing_leng input_model, slit, output_model, extract_ref_dict, slitname, sp_order, smoothing_length, bkg_fit, bkg_order, use_source_posn, exp_type, - subtract_background, apcorr_ref_model, log_increment, - save_profile) + subtract_background, apcorr_ref_model, + extraction_type, specwcs_ref_name, psf_ref_name, + log_increment, save_profile) except ContinueError: pass @@ -1979,8 +2045,9 @@ def run_extract1d(input_model, extract_ref_name, apcorr_ref_name, smoothing_leng input_model, slit, output_model, extract_ref_dict, slitname, sp_order, smoothing_length, bkg_fit, bkg_order, use_source_posn, exp_type, - subtract_background, apcorr_ref_model, log_increment, - save_profile) + subtract_background, apcorr_ref_model, + extraction_type, specwcs_ref_name, psf_ref_name, + log_increment, save_profile) except ContinueError: pass diff --git a/jwst/extract_1d/extract_1d_step.py b/jwst/extract_1d/extract_1d_step.py index 3d030c8bfe..b48957e078 100644 --- a/jwst/extract_1d/extract_1d_step.py +++ b/jwst/extract_1d/extract_1d_step.py @@ -15,6 +15,12 @@ class Extract1dStep(Step): Attributes ---------- + extraction_type : str + If 'box', a standard extraction is performed, summing over an + aperture box. If 'optimal', a PSF-based extraction is performed. + Currently, optimal extraction is only available for MIRI LRS Fixed + Slit data. + use_source_posn : bool or None If True, the source and background extraction positions specified in the extract1d reference file (or the default position, if there is no @@ -154,6 +160,7 @@ class Extract1dStep(Step): class_alias = "extract_1d" spec = """ + extraction_type = option("box", "optimal", default="box") # Perform box or optimal extraction use_source_posn = boolean(default=None) # use source coords to center extractions? apply_apcorr = boolean(default=True) # apply aperture corrections? log_increment = integer(default=50) # increment for multi-integration log messages @@ -186,7 +193,8 @@ class Extract1dStep(Step): soss_modelname = output_file(default = None) # Filename for optional model output of traces and pixel weights """ - reference_file_types = ['extract1d', 'apcorr', 'pastasoss', 'specprofile', 'speckernel'] + reference_file_types = ['extract1d', 'apcorr', 'pastasoss', 'specprofile', + 'speckernel', 'specwcs'] #, 'psf'] def _get_extract_reference_files_by_mode(self, model, exp_type): """Get extraction reference files with special handling by exposure type.""" @@ -207,7 +215,15 @@ def _get_extract_reference_files_by_mode(self, model, exp_type): if apcorr_ref != 'N/A': self.log.info(f'Using APCORR file {apcorr_ref}') - return extract_ref, apcorr_ref + if exp_type in extract.OPTIMAL_EXPTYPES: + specwcs_ref = self.get_reference_file(model, 'specwcs') + #psf_ref = self.get_reference_file(model, 'psf') + psf_ref = "MIRI_LRS_SLIT_EPSF_20240602_WAVE_1D.fits" + else: + specwcs_ref = 'N/A' + psf_ref = 'N/A' + + return extract_ref, apcorr_ref, specwcs_ref, psf_ref def _extract_soss(self, model): """Extract NIRISS SOSS spectra.""" @@ -381,8 +397,8 @@ def process(self, input): result = ModelContainer() for model in input_model: # Get the reference file names - extract_ref, apcorr_ref = self._get_extract_reference_files_by_mode( - model, exp_type) + extract_ref, apcorr_ref, specwcs_ref, psf_ref = \ + self._get_extract_reference_files_by_mode(model, exp_type) profile = None if isinstance(model, datamodels.IFUCubeModel): @@ -394,6 +410,9 @@ def process(self, input): model, extract_ref, apcorr_ref, + specwcs_ref, + psf_ref, + self.extraction_type, self.smoothing_length, self.bkg_fit, self.bkg_order, diff --git a/jwst/extract_1d/psf_profile.py b/jwst/extract_1d/psf_profile.py new file mode 100644 index 0000000000..8d93ac0727 --- /dev/null +++ b/jwst/extract_1d/psf_profile.py @@ -0,0 +1,245 @@ +import logging +import numpy as np + +from scipy.interpolate import CubicSpline +from scipy import interpolate +from scipy import ndimage +from astropy.io import fits + +from stdatamodels.jwst.datamodels import MiriLrsPsfModel + +log = logging.getLogger(__name__) +log.setLevel(logging.DEBUG) + + +HORIZONTAL = 1 +VERTICAL = 2 +"""Dispersion direction, predominantly horizontal or vertical.""" + + +def open_specwcs(specwcs_ref_name: str, exp_type: str): + """Open the specwcs reference file. + + Currently only works on MIRI LRS-FIXEDSLIT exposures. + + Parameters + ---------- + specwcs_ref_name : str + The name of the specwcs reference file. This file contains + information of the trace location. For MIRI LRS-FIXEDSlIT it + is a FITS file containing the x,y center of the trace. + ext_type : str + The exposure type of the data. + + Returns + ------- + trace, wave_trace, wavetab + Center of the trace in x and y for a given wavelength. + + """ + if exp_type == 'MIR_LRS-FIXEDSLIT': + # use fits to read file (the datamodel does not have all that is needed) + ref = fits.open(specwcs_ref_name) + + with ref: + lrsdata = np.array([d for d in ref[1].data]) + # Get the zero point from the reference data. + # The zero_point is X, Y (which should be COLUMN, ROW) + # These are 1-indexed in CDP-7 (i.e., SIAF convention) so must be converted to 0-indexed + # for lrs_fixedslit + zero_point = ref[0].header['imx'] - 1, ref[0].header['imy'] - 1 + + # In the lrsdata reference table, X_center,Y_center, wavelength relative to zero_point + + xcen = lrsdata[:, 0] + ycen = lrsdata[:, 1] + wavetab = lrsdata[:, 2] + trace = xcen + zero_point[0] + wave_trace = ycen + zero_point[1] + + else: + raise NotImplementedError(f'Specwcs files for EXP_TYPE {exp_type} ' + f'are not supported.') + + return trace, wave_trace, wavetab + + +def open_psf(psf_refname: str, exp_type: str): + """Open the PSF reference file. + + Parameters + ---------- + psf_ref_name : str + The name of the psf reference file. + ext_type : str + The exposure type of the data. + + Returns + ------- + psf_model : MiriLrsPsfModel + Currently only works on MIRI LRS-FIXEDSLIT exposures. + Returns the EPSF model. + + """ + if exp_type == 'MIR_LRS-FIXEDSLIT': + # The information we read in from PSF file is: + # center_col: psf_model.meta.psf.center_col + # super sample factor: psf_model.meta.psf.subpix) + # psf : psf_model.data (2d) + # wavelength of PSF planes: psf_model.wave + psf_model = MiriLrsPsfModel(psf_refname) + + else: + raise NotImplementedError(f'PSF files for EXP_TYPE {exp_type} ' + f'are not supported.') + return psf_model + + +def psf_profile(input_model, psf_ref_name, specwcs_ref_name, middle_wl, location): + """Create a spatial profile from a PSF reference. + + Currently only works on MIRI LRS-FIXEDSLIT exposures. + Input data must be point source. + + The extraction routine can support multiple sources for + simultaneous extraction, but for this first version, we will assume + one source only, located at the planned position (dither RA/Dec), and + return a single profile. + + Parameters + ---------- + input_model : data model + This can be either the input science file or one SlitModel out of + a list of slits. + psf_ref_name : str + PSF reference filename. + specwcs_ref_name : str + Reference file containing information on the spectral trace. + middle_wl : float or None + Wavelength value to use as the center of the trace. If not provided, + the wavelength at the center of the bounding box will be used. + location : float or None + Spatial index to use as the center of the trace. If not provided, + the location at the center of the bounding box will be used. + + Returns + ------- + profile : ndarray + Spatial profile matching the input data. + lower_limit : int + Lower limit of the aperture in the cross-dispersion direction. + For PSF profiles, this is always set to the lower edge of the bounding box, + since the full array may have non-zero weight. + upper_limit : int + Upper limit of the aperture in the cross-dispersion direction. + For PSF profiles, this is always set to the upper edge of the bounding box, + since the full array may have non-zero weight. + """ + # Check input exposure type + exp_type = input_model.meta.exposure.type + if exp_type != 'MIR_LRS-FIXEDSLIT': + raise NotImplementedError(f'PSF extraction is not supported for ' + f'EXP_TYPE {exp_type}') + + # Read in reference files + trace, wave_trace, wavetab = open_specwcs(specwcs_ref_name, exp_type) + psf_model = open_psf(psf_ref_name, exp_type) + + dispaxis = input_model.meta.wcsinfo.dispersion_direction + wcs = input_model.meta.wcs + bbox = wcs.bounding_box + center_x = np.mean(bbox[0]) + center_y = np.mean(bbox[1]) + + # Determine the location using the WCS + if middle_wl is None: + _, _, middle_wl = wcs(center_x, center_y) + if location is None: + if dispaxis == HORIZONTAL: + location = center_y + else: + location = center_x + + y0 = int(np.ceil(bbox[1][0])) + y1 = int(np.ceil(bbox[1][1])) + x0 = int(np.round(bbox[0][0])) + x1 = int(np.round(bbox[0][1])) + cutout = input_model.data[y0:y1, x0:x1] + + # Perform fit of reference trace and corresponding wavelength + # The wavelength for the reference trace does not exactly line up exactly with the data + cs = CubicSpline(wavetab, trace) + cen_shift = cs(middle_wl) + shift = location - cen_shift + log.info(f'Centering profile on spectrum at {location}, wavelength {middle_wl}') + log.info(f'For this wavelength, the reference trace location is at {cen_shift}') + log.info(f'Shift to apply to ref trace: {shift}') + + # todo - if possible, fix this for s2d - + # cen_shift is wrong, wavelengths don't match PSF + + # adjust the trace to the slit region + trace_cutout = trace - bbox[0][0] + trace_shift = trace_cutout + shift + psf_wave = psf_model.wave + + # trace_shift: for each wavelength in the PSF, this is the shift in x to apply + # to the PSF image to shift it to fall on the source. + # wavetab : this is the wavelength corresponding to the trace. + # This wavelength may not match exactly to the PSF. + + # Determine what the shifts per row are for the wavelengths + # given by the model PSF + psf_subpix = psf_model.meta.psf.subpix + + psf_interp = interpolate.interp1d(wavetab, trace_shift, fill_value="extrapolate") + psf_shift = psf_interp(psf_wave) + psf_shift = psf_model.meta.psf.center_col - (psf_shift * psf_subpix) + + # Note: this assumes that data wavelengths are identical to PSF wavelengths + data_shape = cutout.shape + _y, _x = np.mgrid[:data_shape[0], :data_shape[1]] + + # Scale cross-dispersion coordinates by subpixel value and shift by trace + if dispaxis == HORIZONTAL: + if data_shape[1] != psf_shift.size: + log.error('Data shape does not match PSF reference.') + log.error('Optimal extraction must be performed on cal files.') + raise NotImplementedError('Optimal extraction not implemented for resampled data.') + + _y_sh = _y * psf_subpix + psf_shift + sprofile = ndimage.map_coordinates(psf_model.data, [_y_sh, _x], order=1) + else: + if data_shape[0] != psf_shift.size: + log.error('Data shape does not match PSF reference.') + log.error('Optimal extraction must be performed on cal files.') + raise NotImplementedError('Optimal extraction not implemented for resampled data.') + + _x_sh = _x * psf_subpix + psf_shift[:, np.newaxis] + sprofile = ndimage.map_coordinates(psf_model.data, [_y, _x_sh], order=1) + + # Normalize the spatial profile at each dispersion element + if dispaxis == HORIZONTAL: + psum = np.sum(sprofile, axis=0) + sprofile[:, psum > 0] = sprofile[:, psum > 0] / psum[psum > 0] + sprofile[:, psum <= 0] = 0.0 + else: + psum = np.sum(sprofile, axis=1) + sprofile[psum > 0, :] = sprofile[psum > 0, :] / psum[psum > 0, None] + sprofile[psum <= 0, :] = 0.0 + sprofile[~np.isfinite(sprofile)] = 0.0 + sprofile[sprofile < 0] = 0.0 + + # Make the output profile, matching the input data + data_shape = input_model.data.shape + profile = np.full(data_shape, 0.0) + output_y = _y + y0 + output_x = _x + x0 + valid = (output_y >= 0) & (output_y < y1) & (output_x >= 0) & (output_x < x1) + profile[output_y[valid], output_x[valid]] = sprofile[valid] + + if dispaxis == HORIZONTAL: + limits = (y0, y1) + else: + limits = (x0, x1) + return profile, *limits