diff --git a/changes/9022.extract_1d.rst b/changes/9022.extract_1d.rst index 874ee7606d..2446dabeba 100644 --- a/changes/9022.extract_1d.rst +++ b/changes/9022.extract_1d.rst @@ -1 +1,2 @@ -Added on option to calculate a source trace from WCS and expected source positions for NIRSpec BOTS/MOS/FS data and perform a box extraction around that trace. This option is turned on by default for the BOTS mode. +Expanded the ``use_source_posn`` option to calculate a source trace from WCS and expected source positions for unresampled NIRSpec and MIRI LRS fixed slit data. +Added the step parameter ``position_offset`` to allow an additional aperture offset in pixels. diff --git a/docs/jwst/extract_1d/arguments.rst b/docs/jwst/extract_1d/arguments.rst index 8637bb4694..5d3b47e7ae 100644 --- a/docs/jwst/extract_1d/arguments.rst +++ b/docs/jwst/extract_1d/arguments.rst @@ -26,21 +26,13 @@ Step Arguments for Slit and Slitless Spectroscopic Data file should be shifted to account for the expected position of the source. If None (the default), the step will decide whether to use the source position based on the observing mode and the source type. By default, source position corrections - are attempted only for NIRSpec MOS and NIRSpec and MIRI LRS fixed-slit point sources. + are attempted only for point sources in NIRSpec MOS/FS/BOTS and MIRI LRS fixed-slit exposures. Set to False to ignore position estimates for all modes; set to True to additionally attempt - source position correction for NIRSpec BOTS data or extended sources. + source position correction for extended sources. -``--use_trace`` - Specify whether to calculate a 2D trace and extract the pixels around that trace - within the ``extract_width`` defined in the :ref:`EXTRACT1D `. - If None (the default), the step will decide whether to use the source position based - on the observing mode. Currently this option is only available for NIRSpec BOTS, MOS, - and fixed-slit modes, where the trace will be calculated from the expected positions - of the sources. This option is used by default for the NIRSpec BOTS mode. - -``--trace_offset`` +``--position_offset`` Specify a number of pixels (fractional pixels are allowed) to offset the - calculated trace if ``use_trace`` is set to True. The default is 0. + extraction aperture from the nominal position. The default is 0. ``--smoothing_length`` If ``smoothing_length`` is greater than 1 (and is an odd integer), the diff --git a/jwst/extract_1d/extract.py b/jwst/extract_1d/extract.py index 5d7ae1e642..f78b826e0f 100644 --- a/jwst/extract_1d/extract.py +++ b/jwst/extract_1d/extract.py @@ -4,6 +4,7 @@ from json.decoder import JSONDecodeError from astropy.modeling import polynomial +from gwcs.wcstools import grid_from_bounding_box from scipy.interpolate import interp1d from stdatamodels.jwst import datamodels from stdatamodels.jwst.datamodels.apcorr import ( @@ -20,7 +21,7 @@ __all__ = ['run_extract1d', 'read_extract1d_ref', 'read_apcorr_ref', 'get_extract_parameters', 'box_profile', 'aperture_center', - 'location_from_wcs', 'shift_by_source_location', 'nirspec_trace_from_wcs', + 'location_from_wcs', 'shift_by_offset', 'define_aperture', 'extract_one_slit', 'create_extraction'] @@ -30,6 +31,9 @@ WFSS_EXPTYPES = ['NIS_WFSS', 'NRC_WFSS', 'NRC_GRISM'] """Exposure types to be regarded as wide-field slitless spectroscopy.""" +SRCPOS_EXPTYPES = ['MIR_LRS-FIXEDSLIT', 'NRS_FIXEDSLIT', 'NRS_MSASPEC', 'NRS_BRIGHTOBJ'] +"""Exposure types for which source position can be estimated.""" + ANY = "ANY" """Wildcard for slit name. @@ -142,7 +146,7 @@ def read_apcorr_ref(refname, exptype): def get_extract_parameters(ref_dict, input_model, slitname, sp_order, meta, smoothing_length=None, bkg_fit=None, bkg_order=None, use_source_posn=None, subtract_background=None, - use_trace=None, trace_offset=0): + position_offset=0.0): """Get extraction parameter values. Parameters @@ -205,15 +209,8 @@ def get_extract_parameters(ref_dict, input_model, slitname, sp_order, meta, subtract_background : bool or None, optional If False, all background parameters will be ignored. - use_trace : bool or None, optional - If True the trace of the source will be calculated and used for - 1D extraction. - If None, the value specified in `ref_dict` will be used. - Otherwise it it will be set to False for all exposure types - except NIRSpec BOTS data. - - trace_offset : float or None, optional - Pixel offset to apply to the automatically calucated trace. + position_offset : float or None, optional + Pixel offset to apply to the nominal source location. If None, the value specified in `ref_dict` will be used or it will default to 0. @@ -246,8 +243,7 @@ def get_extract_parameters(ref_dict, input_model, slitname, sp_order, meta, extract_params['use_source_posn'] = False # no source position correction extract_params['position_correction'] = 0 extract_params['independent_var'] = 'pixel' - extract_params['use_trace'] = False - extract_params['trace_offset'] = 0. + extract_params['position_offset'] = 0. extract_params['trace'] = None # Note that extract_params['dispaxis'] is not assigned. # This will be done later, possibly slit by slit. @@ -322,7 +318,7 @@ def get_extract_parameters(ref_dict, input_model, slitname, sp_order, meta, if use_source_posn is None: # no value set on command line if use_source_posn_aper is None: # no value set in ref file # Use a suitable default - if meta.exposure.type in ['MIR_LRS-FIXEDSLIT', 'NRS_FIXEDSLIT', 'NRS_MSASPEC']: + if meta.exposure.type in SRCPOS_EXPTYPES: use_source_posn = True log.info(f"Turning on source position correction " f"for exp_type = {meta.exposure.type}") @@ -331,21 +327,7 @@ def get_extract_parameters(ref_dict, input_model, slitname, sp_order, meta, else: # use the value from the ref file use_source_posn = use_source_posn_aper extract_params['use_source_posn'] = use_source_posn - - - use_trace_aper = aper.get('use_trace', None) # value from the extract1d ref file - if use_trace is None: # no value set on command line - if use_trace_aper is None: # no value set in ref file - # Use a suitable default - if meta.exposure.type in ['NRS_BRIGHTOBJ',]: - use_trace = True - log.info(f"Turning on trace extraction exp_type = {meta.exposure.type}") - else: - use_trace = False - else: # use the value from the ref file - use_trace = use_trace_aper - extract_params['use_trace'] = use_trace - extract_params['trace_offset'] = trace_offset + extract_params['position_offset'] = position_offset extract_params['trace'] = None extract_params['extract_width'] = aper.get('extract_width') @@ -396,21 +378,10 @@ def log_initial_parameters(extract_params): return log.debug("Extraction parameters:") - log.debug(f"dispaxis = {extract_params['dispaxis']}") - log.debug(f"spectral order = {extract_params['spectral_order']}") - log.debug(f"initial xstart = {extract_params['xstart']}") - log.debug(f"initial xstop = {extract_params['xstop']}") - log.debug(f"initial ystart = {extract_params['ystart']}") - log.debug(f"initial ystop = {extract_params['ystop']}") - log.debug(f"extract_width = {extract_params['extract_width']}") - log.debug(f"initial src_coeff = {extract_params['src_coeff']}") - log.debug(f"initial bkg_coeff = {extract_params['bkg_coeff']}") - log.debug(f"bkg_fit = {extract_params['bkg_fit']}") - log.debug(f"bkg_order = {extract_params['bkg_order']}") - log.debug(f"smoothing_length = {extract_params['smoothing_length']}") - log.debug(f"independent_var = {extract_params['independent_var']}") - log.debug(f"use_source_posn = {extract_params['use_source_posn']}") - log.debug(f"extraction_type = {extract_params['extraction_type']}") + skip_keys = {'match', 'trace'} + for key, value in extract_params.items(): + if key not in skip_keys: + log.debug(f" {key} = {value}") def create_poly(coeff): @@ -826,8 +797,8 @@ def box_profile(shape, extract_params, wl_array, coefficients='src_coeff', # Set aperture region, in this priority order: # 1. src_coeff upper and lower limits (or bkg_coeff, for background profile) - # 2. Using a trace +/- the extraction width - # 3. center of start/stop values +/- extraction width + # 2. trace +/- extraction width / 2 + # 3. center of start/stop values +/- extraction width / 2 # 4. start/stop values profile = np.full(shape, 0.0) if extract_params[coefficients] is not None: @@ -887,17 +858,19 @@ def box_profile(shape, extract_params, wl_array, coefficients='src_coeff', width = extract_params['extract_width'] trace = extract_params['trace'] + if extract_params['dispaxis'] != HORIZONTAL: + trace = np.tile(trace, (shape[1], 1)).T + lower_limit_region = trace - (width - 1.0) / 2.0 upper_limit_region = lower_limit_region + width - 1 _set_weight_from_limits(profile, dval, lower_limit_region, upper_limit_region) - lower_limit = np.mean(lower_limit_region) - upper_limit = np.mean(upper_limit_region) + lower_limit = np.nanmean(lower_limit_region) + upper_limit = np.nanmean(upper_limit_region) log.info(f'Mean {label} start/stop from trace: ' - f'{lower_limit:.2f} -> {upper_limit:.2f} (inclusive)') - + f'{lower_limit:.2f} -> {upper_limit:.2f} (inclusive)') elif extract_params['extract_width'] is not None: # Limits from extraction width at center of ystart/stop if present, @@ -1020,7 +993,6 @@ def location_from_wcs(input_model, slit): ---------- input_model : DataModel The input science model containing metadata information. - slit : DataModel or None One slit from a MultiSlitModel (or similar), or None. The WCS and target coordinates will be retrieved from `slit` @@ -1036,29 +1008,29 @@ def location_from_wcs(input_model, slit): nominal extraction location, in case it varies along the spectrum. The offset will then be the difference between `location` (below) and the nominal location. - middle_wl : float or None The wavelength at pixel `middle`. - location : float or None Pixel coordinate in the cross-dispersion direction within the spectral image that is at the planned target location. The spectral extraction region should be centered here. + trace : ndarray or None + An array of source positions, one per dispersion element, corresponding + to the location at each point in the wavelength array. If the + input data is resampled, the trace corresponds directly to the + location. """ if slit is not None: - wcs_source = slit + shape = slit.data.shape[-2:] + wcs = slit.meta.wcs + dispaxis = slit.meta.wcsinfo.dispersion_direction else: - wcs_source = input_model - wcs = wcs_source.meta.wcs - dispaxis = wcs_source.meta.wcsinfo.dispersion_direction + shape = input_model.data.shape[-2:] + wcs = input_model.meta.wcs + dispaxis = input_model.meta.wcsinfo.dispersion_direction bb = wcs.bounding_box # ((x0, x1), (y0, y1)) if bb is None: - if slit is None: - shape = input_model.data.shape - else: - shape = slit.data.shape - bb = wcs_bbox_from_shape(shape) if dispaxis == HORIZONTAL: @@ -1084,13 +1056,15 @@ def location_from_wcs(input_model, slit): lower = bb[0][0] upper = bb[0][1] - # We need transform[2], a 1-D array of wavelengths crossing the spectrum - # near its middle. + # Get the wavelengths for the valid data in the sky transform, + # average to get the middle wavelength fwd_transform = wcs(x, y) middle_wl = np.nanmean(fwd_transform[2]) exp_type = input_model.meta.exposure.type + trace = None if exp_type in ['NRS_FIXEDSLIT', 'NRS_MSASPEC', 'NRS_BRIGHTOBJ']: + log.info("Using source_xpos and source_ypos to center extraction.") if slit is None: xpos = input_model.source_xpos ypos = input_model.source_ypos @@ -1101,12 +1075,15 @@ def location_from_wcs(input_model, slit): slit2det = wcs.get_transform('slit_frame', 'detector') if 'gwa' in wcs.available_frames: # Input is not resampled, wavelengths need to be meters - x_y = slit2det(xpos, ypos, middle_wl * 1e-6) + _, location = slit2det(xpos, ypos, middle_wl * 1e-6) else: - x_y = slit2det(xpos, ypos, middle_wl) - log.info("Using source_xpos and source_ypos to center extraction.") + _, location = slit2det(xpos, ypos, middle_wl) + + if ~np.isnan(location): + trace = _nirspec_trace_from_wcs(shape, bb, wcs, xpos, ypos) elif exp_type == 'MIR_LRS-FIXEDSLIT': + log.info("Using dithered_ra and dithered_dec to center extraction.") try: if slit is None: dithra = input_model.meta.dither.dithered_ra @@ -1114,23 +1091,21 @@ def location_from_wcs(input_model, slit): else: dithra = slit.meta.dither.dithered_ra dithdec = slit.meta.dither.dithered_dec - x_y = wcs.backward_transform(dithra, dithdec, middle_wl) + location, _ = wcs.backward_transform(dithra, dithdec, middle_wl) + except (AttributeError, TypeError): log.warning("Dithered pointing location not found in wcsinfo.") - return None, None, None - else: - log.warning(f"Source position cannot be found for EXP_TYPE {exp_type}") - return None, None, None + return None, None, None, None - # location is the XD location of the spectrum: - if dispaxis == HORIZONTAL: - location = x_y[1] + if ~np.isnan(location): + trace = _miri_trace_from_wcs(shape, bb, wcs, dithra, dithdec) else: - location = x_y[0] + log.warning(f"Source position cannot be found for EXP_TYPE {exp_type}") + return None, None, None, None if np.isnan(location): log.warning('Source position could not be determined from WCS.') - return None, None, None + return None, None, None, None # If the target is at the edge of the image or at the edge of the # non-NaN area, we can't use the WCS to find the @@ -1138,48 +1113,38 @@ def location_from_wcs(input_model, slit): if location < lower or location > upper: log.warning(f"WCS implies the target is at {location:.2f}, which is outside the bounding box,") log.warning("so we can't get spectrum location using the WCS") - return None, None, None - - return middle, middle_wl, location + return None, None, None, None + return middle, middle_wl, location, trace -def shift_by_source_location(location, nominal_location, extract_params): - """Shift the nominal extraction parameters by the source location. - The offset applied is `location` - `nominal_location`, along - the cross-dispersion direction. +def shift_by_offset(offset, extract_params, update_trace=True): + """Shift the nominal extraction parameters by a pixel offset. Start, stop, and polynomial coefficient values for source and background are updated in place in the `extract_params` dictionary. + The source trace value, if present, is also updated if desired. Parameters ---------- - location : float - The source location in the cross-dispersion direction - at which to center the extraction aperture. - nominal_location : float - The center of the nominal extraction aperture, in the - cross-dispersion direction, according to the extraction - parameters. + offset : float + Cross-dispersion offset to apply, in pixels. extract_params : dict Extraction parameters to update, as created by - `get_extraction_parameters`, and corresponding to the - specified nominal location. + `get_extraction_parameters`. + update_trace : bool + If True, the trace in `extract_params['trace']` is also updated + if present. """ - - # Get the center of the nominal aperture - offset = location - nominal_location - log.info(f"Nominal location is {nominal_location:.2f}, " - f"so offset is {offset:.2f} pixels") - - # Shift aperture limits by the difference between the - # source location and the nominal center + # Shift polynomial coefficients coeff_params = ['src_coeff', 'bkg_coeff'] for params in coeff_params: if extract_params[params] is not None: for coeff_list in extract_params[params]: coeff_list[0] += offset + + # Shift start/stop values if extract_params['dispaxis'] == HORIZONTAL: start_stop_params = ['ystart', 'ystop'] else: @@ -1188,7 +1153,12 @@ def shift_by_source_location(location, nominal_location, extract_params): if extract_params[params] is not None: extract_params[params] += offset -def nirspec_trace_from_wcs(input_model, slit, trace_offset=0): + # Shift the full trace + if update_trace and extract_params['trace'] is not None: + extract_params['trace'] += offset + + +def _nirspec_trace_from_wcs(shape, bounding_box, wcs_ref, source_xpos, source_ypos): """Calculate NIRSpec source trace from WCS. The source trace is calculated by projecting the recorded source @@ -1197,18 +1167,19 @@ def nirspec_trace_from_wcs(input_model, slit, trace_offset=0): Parameters ---------- - input_model : DataModel - The input science model containing metadata information. - - slit : DataModel or None - One slit from a MultiSlitModel (or similar), or None. - The WCS and target coordinates will be retrieved from `slit` - unless `slit` is None. In that case, they will be retrieved - from `input_model`. - - trace_offset : int, optional - Signed number of pixels to offset the trace in the cross- - dispersion direction, by default 0. + shape : tuple of int + 2D shape for the full input data array, (ny, nx). + bounding_box : tuple + A pair of tuples, each consisting of two numbers. + Represents the range of useful pixel values in both dimensions, + ((xmin, xmax), (ymin, ymax)). + wcs_ref : `~gwcs.WCS` + WCS for the input data model, containing slit and detector + transforms. + source_xpos : float + Slit position, in the x direction, for the target. + source_ypos : float + Slit position, in the y direction, for the target. Returns ------- @@ -1216,27 +1187,12 @@ def nirspec_trace_from_wcs(input_model, slit, trace_offset=0): Fractional pixel positions in the y (cross-dispersion direction) of the trace for each x (dispersion direction) pixel. """ - if slit is None: - shape = input_model.data.shape - wcs_ref = input_model.meta.wcs - source_xpos = getattr(input_model, "source_xpos", 0.0) - source_ypos = getattr(input_model, "source_ypos", 0.0) - else: - shape = slit.data.shape - wcs_ref = slit.meta.wcs - source_xpos = getattr(slit, "source_xpos", 0.0) - source_ypos = getattr(slit, "source_ypos", 0.0) - - log.info(f"Determining the spectral trace for source_xpos = {source_xpos:.3f} and source_xpos = {source_ypos:.3f}") - - nx = shape[-1] - ny = shape[-2] - y, x = np.meshgrid(np.arange(ny), np.arange(nx), indexing="ij") - - d2s = wcs_ref.get_transform("detector", "slit_frame") + x, y = grid_from_bounding_box(bounding_box) + nx = int(bounding_box[0][1] - bounding_box[0][0]) # Calculate the wavelengths in the slit frame because they are in # meters for cal files and um for s2d files + d2s = wcs_ref.get_transform("detector", "slit_frame") _, _, slit_wavelength = d2s(x,y) # Make an initial array of wavelengths that will cover the wavelength range of the data @@ -1256,11 +1212,74 @@ def nirspec_trace_from_wcs(input_model, slit, trace_offset=0): # direction interp_trace = interp1d(trace_x, trace_y, fill_value='extrapolate') - # Shift the trace if a shift has been supplied + # Get the trace position for each dispersion element trace = interp_trace(np.arange(nx)) - trace += trace_offset - return trace + # Place the trace in the full array + full_trace = np.full(shape[1], np.nan) + x0 = int(np.ceil(bounding_box[0][0])) + full_trace[x0:x0 + nx] = trace + + return full_trace + + +def _miri_trace_from_wcs(shape, bounding_box, wcs_ref, source_ra, source_dec): + """Calculate MIRI LRS fixed slit source trace from WCS. + + The source trace is calculated by projecting the recorded source + positions dithered_ra/dec from the world frame onto detector pixels. + + Parameters + ---------- + shape : tuple of int + 2D shape for the full input data array, (ny, nx). + bounding_box : tuple + A pair of tuples, each consisting of two numbers. + Represents the range of useful pixel values in both dimensions, + ((xmin, xmax), (ymin, ymax)). + wcs_ref : `~gwcs.WCS` + WCS for the input data model, containing sky and detector + transforms, forward and backward. + source_ra : float + RA coordinate for the target. + source_dec : float + Dec coordinate for the target. + + Returns + ------- + trace : ndarray of float + Fractional pixel positions in the x (cross-dispersion direction) + of the trace for each y (dispersion direction) pixel. + """ + x, y = grid_from_bounding_box(bounding_box) + ny = int(bounding_box[1][1] - bounding_box[1][0]) + + # Calculate the wavelengths for the full array + _, _, slit_wavelength = wcs_ref(x, y) + + # Make an initial array of wavelengths that will cover the wavelength range of the data + wave_vals = np.linspace(np.nanmin(slit_wavelength), np.nanmax(slit_wavelength), ny) + + # Get arrays of the source position + pos_ra = np.full(ny, source_ra) + pos_dec = np.full(ny, source_dec) + + # Calculate the expected center of the source trace + trace_x, trace_y = wcs_ref.backward_transform(pos_ra, pos_dec, wave_vals) + + # Interpolate the trace to a regular pixel grid in the dispersion + # direction + interp_trace = interp1d(trace_y, trace_x, fill_value='extrapolate') + + # Get the trace position for each dispersion element within the bounding box + trace = interp_trace(np.arange(ny)) + + # Place the trace in the full array + full_trace = np.full(shape[0], np.nan) + y0 = int(np.ceil(bounding_box[1][0])) + full_trace[y0:y0 + ny] = trace + + return full_trace def define_aperture(input_model, slit, extract_params, exp_type): @@ -1319,20 +1338,11 @@ def define_aperture(input_model, slit, extract_params, exp_type): # Get a wavelength array for the data wl_array = get_wavelengths(data_model, exp_type, extract_params['spectral_order']) - # Calculate a trace from WCS and update extract parameters - if extract_params['use_trace']: - if exp_type in ['NRS_BRIGHTOBJ', 'NRS_FIXEDSLIT', 'NRS_MSASPEC']: - trace = nirspec_trace_from_wcs( - input_model, slit, trace_offset=extract_params['trace_offset']) - extract_params['trace'] = trace - else: - log.warning('Calculating a trace is not supported for this mode.') - # Shift aperture definitions by source position if needed # Extract parameters are updated in place - if extract_params['use_source_posn'] and extract_params['trace'] is None: + if extract_params['use_source_posn']: # Source location from WCS - middle_pix, middle_wl, location = location_from_wcs(input_model, slit) + middle_pix, middle_wl, location, trace = location_from_wcs(input_model, slit) if location is not None: log.info(f"Computed source location is {location:.2f}, " @@ -1344,9 +1354,22 @@ def define_aperture(input_model, slit, extract_params, exp_type): nominal_location, _ = aperture_center( nominal_profile, extract_params['dispaxis'], middle_pix=middle_pix) - # Offet extract parameters by location - nominal - shift_by_source_location(location, nominal_location, extract_params) + # Offset extract parameters by location - nominal + offset = location - nominal_location + log.info(f"Nominal location is {nominal_location:.2f}, " + f"so offset is {offset:.2f} pixels") + shift_by_offset(offset, extract_params, update_trace=False) + else: + middle_pix, middle_wl, location, trace = None, None, None, None + + # Store the trace, if computed + extract_params['trace'] = trace + # Add an extra position offset if desired, from extract_params['position_offset'] + offset = extract_params.get('position_offset', 0.0) + if offset != 0.0: + log.info(f"Applying additional cross-dispersion offset {offset:.2f} pixels") + shift_by_offset(offset, extract_params, update_trace=True) # Make a spatial profile, including source shifts if necessary profile, lower_limit, upper_limit = box_profile( @@ -1947,7 +1970,7 @@ def create_extraction(input_model, slit, output_model, def run_extract1d(input_model, extract_ref_name="N/A", apcorr_ref_name=None, smoothing_length=None, bkg_fit=None, bkg_order=None, log_increment=50, subtract_background=None, - use_source_posn=None, use_trace=None, trace_offset=0, + use_source_posn=None, position_offset=0.0, save_profile=False, save_scene_model=False): """Extract all 1-D spectra from an input model. @@ -1983,12 +2006,9 @@ def run_extract1d(input_model, extract_ref_name="N/A", apcorr_ref_name=None, If True, the target and background positions specified in the reference file (or the default position, if there is no reference file) will be shifted to account for source position offset. - use_trace : bool or None - If True, a source trace will be calculated and used for 1D extraction. - If None, the value in the extract_1d reference file will be used. - trace_offset : float - Number of pixels to shift the calculated trace in the cross-dispersion - direction if use_trace is True. + position_offset : float + Number of pixels to shift the nominal source position in the + cross-dispersion direction. save_profile : bool If True, the spatial profiles created for the input model will be returned as ImageModels. If False, the return value is None. @@ -2084,9 +2104,9 @@ def run_extract1d(input_model, extract_ref_name="N/A", apcorr_ref_name=None, save_profile=save_profile, save_scene_model=save_scene_model, smoothing_length=smoothing_length, bkg_fit=bkg_fit, bkg_order=bkg_order, - use_source_posn=use_source_posn, subtract_background=subtract_background, - use_trace=use_trace, trace_offset=trace_offset) + use_source_posn=use_source_posn, + position_offset=position_offset) except ContinueError: continue @@ -2121,9 +2141,9 @@ def run_extract1d(input_model, extract_ref_name="N/A", apcorr_ref_name=None, save_profile=save_profile, save_scene_model=save_scene_model, smoothing_length=smoothing_length, bkg_fit=bkg_fit, bkg_order=bkg_order, - use_source_posn=use_source_posn, subtract_background=subtract_background, - use_trace=use_trace, trace_offset=trace_offset) + use_source_posn=use_source_posn, + position_offset=position_offset) except ContinueError: pass @@ -2160,9 +2180,9 @@ def run_extract1d(input_model, extract_ref_name="N/A", apcorr_ref_name=None, save_profile=save_profile, save_scene_model=save_scene_model, smoothing_length=smoothing_length, bkg_fit=bkg_fit, bkg_order=bkg_order, - use_source_posn=use_source_posn, subtract_background=subtract_background, - use_trace=use_trace, trace_offset=trace_offset) + use_source_posn=use_source_posn, + position_offset=position_offset) except ContinueError: pass diff --git a/jwst/extract_1d/extract_1d_step.py b/jwst/extract_1d/extract_1d_step.py index afde247cf6..3738ee51fc 100644 --- a/jwst/extract_1d/extract_1d_step.py +++ b/jwst/extract_1d/extract_1d_step.py @@ -162,8 +162,7 @@ class Extract1dStep(Step): apply_apcorr = boolean(default=True) # apply aperture corrections? use_source_posn = boolean(default=None) # use source coords to center extractions? - use_trace = boolean(default=None) # use source trace for extraction - trace_offset = float(default=0) # number of pixels to shift source trace in the cross-dispersion direction + position_offset = float(default=0) # number of pixels to shift source trace in the cross-dispersion direction smoothing_length = integer(default=None) # background smoothing size bkg_fit = option("poly", "mean", "median", None, default=None) # background fitting type bkg_order = integer(default=None, min=0) # order of background polynomial fit @@ -424,8 +423,7 @@ def process(self, input): self.log_increment, self.subtract_background, self.use_source_posn, - self.use_trace, - self.trace_offset, + self.position_offset, self.save_profile, self.save_scene_model, ) diff --git a/jwst/extract_1d/tests/conftest.py b/jwst/extract_1d/tests/conftest.py index 467f03f385..00456e991f 100644 --- a/jwst/extract_1d/tests/conftest.py +++ b/jwst/extract_1d/tests/conftest.py @@ -34,9 +34,24 @@ def simple_wcs_function(x, y): # Add a bounding box simple_wcs_function.bounding_box = wcs_bbox_from_shape(shape) - # Add a few expected attributes, so they can be monkeypatched as needed - simple_wcs_function.get_transform = None - simple_wcs_function.backward_transform = None + # Define a simple transform + def get_transform(*args, **kwargs): + def return_results(*args, **kwargs): + if len(args) == 2: + zeros = np.zeros(args[0].shape) + wave, _ = np.meshgrid(args[0], args[1]) + return zeros, zeros, wave + if len(args) == 3: + try: + nx = len(args[0]) + except TypeError: + nx = 1 + pix = np.arange(nx) + trace = np.ones(nx) + return pix, trace + return return_results + + simple_wcs_function.get_transform = get_transform simple_wcs_function.available_frames = [] return simple_wcs_function @@ -68,10 +83,17 @@ def simple_wcs_function(x, y): # Add a bounding box simple_wcs_function.bounding_box = wcs_bbox_from_shape(shape) - # Add a few expected attributes, so they can be monkeypatched as needed - simple_wcs_function.get_transform = None - simple_wcs_function.backward_transform = None - simple_wcs_function.available_frames = [] + # Mock a simple backward transform + def backward_transform(*args, **kwargs): + try: + nx = len(args[0]) + except TypeError: + nx = 1 + pix = np.arange(nx) + trace = np.ones(nx) + return trace, pix + + simple_wcs_function.backward_transform = backward_transform return simple_wcs_function diff --git a/jwst/extract_1d/tests/test_extract.py b/jwst/extract_1d/tests/test_extract.py index 618f6c1dc7..9e4dbd5baa 100644 --- a/jwst/extract_1d/tests/test_extract.py +++ b/jwst/extract_1d/tests/test_extract.py @@ -31,7 +31,7 @@ def extract1d_ref_dict(): {'id': 'slit6', 'use_source_posn': True}, {'id': 'slit7', 'spectral_order': 20}, {'id': 'S200A1'}, - {'id': 'S1600A1'} + {'id': 'S1600A1', 'use_source_posn': False} ] ref_dict = {'apertures': apertures} return ref_dict @@ -59,10 +59,9 @@ def extract_defaults(): 'spectral_order': 1, 'src_coeff': None, 'subtract_background': False, - 'trace_offset': 0, + 'position_offset': 0.0, 'trace': None, 'use_source_posn': False, - 'use_trace': False, 'xstart': 0, 'xstop': 49, 'ystart': 0, @@ -141,7 +140,7 @@ def test_get_extract_parameters_default( extract1d_ref_dict, input_model, 'slit1', 1, input_model.meta) # returned value has defaults except that use_source_posn - # is switched on for NRS_FIXEDSLIT and use_trace is False + # is switched on for NRS_FIXEDSLIT expected = extract_defaults expected['use_source_posn'] = True @@ -201,13 +200,14 @@ def test_get_extract_parameters_no_match( def test_get_extract_parameters_source_posn_exptype( mock_nirspec_bots, extract1d_ref_dict, extract_defaults): input_model = mock_nirspec_bots + input_model.meta.exposure.type = 'NRS_LAMP' # match a bare entry params = ex.get_extract_parameters( extract1d_ref_dict, input_model, 'slit1', 1, input_model.meta, use_source_posn=None) - # use_source_posn is switched off for NRS_BRIGHTOBJ + # use_source_posn is switched off for unknown exptypes assert params['use_source_posn'] is False @@ -753,6 +753,45 @@ def test_box_profile_from_width(extract_defaults, dispaxis): assert np.all(profile[8:] == 0.0) +@pytest.mark.parametrize('dispaxis', [1, 2]) +def test_box_profile_from_trace(extract_defaults, dispaxis): + shape = (10, 10) + wl_array = np.empty(shape) + wl_array[:] = np.linspace(3, 5, 10) + + params = extract_defaults + params['dispaxis'] = dispaxis + + # Set a linear trace + params['trace'] = np.arange(10) + 1.5 + + # Set the width to 4 pixels + params['extract_width'] = 4.0 + + # Make the profile + profile, lower, upper = ex.box_profile( + shape, extract_defaults, wl_array, return_limits=True) + if dispaxis == 2: + profile = profile.T + + expected = [[1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]] + + assert np.allclose(profile, expected) + + # upper and lower limits are averages + assert np.isclose(lower, 4.5) + assert np.isclose(upper, 7.5) + + @pytest.mark.parametrize('middle', [None, 7]) @pytest.mark.parametrize('dispaxis', [1, 2]) def test_aperture_center(middle, dispaxis): @@ -822,27 +861,19 @@ def test_location_from_wcs_nirspec( monkeypatch, mock_nirspec_fs_one_slit, resampled, is_slit, missing_bbox): model = mock_nirspec_fs_one_slit - # monkey patch in a transform for the wcs - def slit2det(*args, **kwargs): - def return_one(*args, **kwargs): - return 0.0, 1.0 - return return_one - - monkeypatch.setattr(model.meta.wcs, 'get_transform', slit2det) - if not resampled: - # also mock available frames, so it looks like unresampled cal data + # mock available frames, so it looks like unresampled cal data monkeypatch.setattr(model.meta.wcs, 'available_frames', ['gwa']) if missing_bbox: - # also mock a missing bounding box - should have same results + # mock a missing bounding box - should have same results # for the test data monkeypatch.setattr(model.meta.wcs, 'bounding_box', None) if is_slit: - middle, middle_wl, location = ex.location_from_wcs(model, model) + middle, middle_wl, location, trace = ex.location_from_wcs(model, model) else: - middle, middle_wl, location = ex.location_from_wcs(model, None) + middle, middle_wl, location, trace = ex.location_from_wcs(model, None) # middle pixel is center of dispersion axis assert middle == int((model.data.shape[1] - 1) / 2) @@ -853,6 +884,9 @@ def return_one(*args, **kwargs): # location is 1.0 - from the mocked transform function assert location == 1.0 + # trace is the same, in an array + assert np.all(trace == 1.0) + @pytest.mark.parametrize('is_slit', [True, False]) def test_location_from_wcs_miri(monkeypatch, mock_miri_lrs_fs, is_slit): @@ -866,11 +900,17 @@ def return_one(*args, **kwargs): monkeypatch.setattr(model.meta.wcs, 'backward_transform', radec2det()) + # mock the trace function + def mock_trace(*args, **kwargs): + return np.full(model.data.shape[-2], 1.0) + + monkeypatch.setattr(ex, '_miri_trace_from_wcs', mock_trace) + # Get the slit center from the WCS if is_slit: - middle, middle_wl, location = ex.location_from_wcs(model, model) + middle, middle_wl, location, trace = ex.location_from_wcs(model, model) else: - middle, middle_wl, location = ex.location_from_wcs(model, None) + middle, middle_wl, location, trace = ex.location_from_wcs(model, None) # middle pixel is center of dispersion axis assert middle == int((model.data.shape[0] - 1) / 2) @@ -881,12 +921,18 @@ def return_one(*args, **kwargs): # location is 1.0 - from the mocked transform function assert location == 1.0 + # trace is the same, in an array + assert np.all(trace == 1.0) + def test_location_from_wcs_missing_data(mock_miri_lrs_fs, log_watcher): + model = mock_miri_lrs_fs + model.meta.wcs.backward_transform = None + # model is missing WCS information - None values are returned log_watcher.message = "Dithered pointing location not found" - result = ex.location_from_wcs(mock_miri_lrs_fs, None) - assert result == (None, None, None) + result = ex.location_from_wcs(model, None) + assert result == (None, None, None, None) log_watcher.assert_seen() @@ -894,7 +940,7 @@ def test_location_from_wcs_wrong_exptype(mock_niriss_soss, log_watcher): # model is not a handled exposure type log_watcher.message = "Source position cannot be found for EXP_TYPE" result = ex.location_from_wcs(mock_niriss_soss, None) - assert result == (None, None, None) + assert result == (None, None, None, None) log_watcher.assert_seen() @@ -913,7 +959,7 @@ def return_one(*args, **kwargs): # WCS transform returns NaN for the location log_watcher.message = "Source position could not be determined" result = ex.location_from_wcs(model, None) - assert result == (None, None, None) + assert result == (None, None, None, None) log_watcher.assert_seen() @@ -929,84 +975,100 @@ def return_one(*args, **kwargs): monkeypatch.setattr(model.meta.wcs, 'get_transform', slit2det) + # mock the trace function + def mock_trace(*args, **kwargs): + return np.full(model.data.shape[-1], 1.0) + + monkeypatch.setattr(ex, '_nirspec_trace_from_wcs', mock_trace) + # WCS transform a value outside the bounding box log_watcher.message = "outside the bounding box" result = ex.location_from_wcs(model, None) - assert result == (None, None, None) + assert result == (None, None, None, None) log_watcher.assert_seen() -def test_shift_by_source_location_horizontal(extract_defaults): - location = 12.5 - nominal_location = 15.0 - offset = location - nominal_location +def test_shift_by_offset_horizontal(extract_defaults): + offset = 2.5 extract_params = extract_defaults.copy() extract_params['dispaxis'] = 1 + extract_params['position_offset'] = offset - ex.shift_by_source_location(location, nominal_location, extract_params) + ex.shift_by_offset(offset, extract_params) assert extract_params['xstart'] == extract_defaults['xstart'] assert extract_params['xstop'] == extract_defaults['xstop'] assert extract_params['ystart'] == extract_defaults['ystart'] + offset assert extract_params['ystop'] == extract_defaults['ystop'] + offset -def test_shift_by_source_location_vertical(extract_defaults): - location = 12.5 - nominal_location = 15.0 - offset = location - nominal_location +def test_shift_by_offset_vertical(extract_defaults): + offset = 2.5 extract_params = extract_defaults.copy() extract_params['dispaxis'] = 2 + extract_params['position_offset'] = offset - ex.shift_by_source_location(location, nominal_location, extract_params) + ex.shift_by_offset(offset, extract_params) assert extract_params['xstart'] == extract_defaults['xstart'] + offset assert extract_params['xstop'] == extract_defaults['xstop'] + offset assert extract_params['ystart'] == extract_defaults['ystart'] assert extract_params['ystop'] == extract_defaults['ystop'] -def test_shift_by_source_location_coeff(extract_defaults): - location = 6.5 - nominal_location = 4.0 - offset = location - nominal_location +def test_shift_by_offset_coeff(extract_defaults): + offset = 2.5 extract_params = extract_defaults.copy() extract_params['dispaxis'] = 1 + extract_params['position_offset'] = offset extract_params['src_coeff'] = [[2.5, 1.0], [6.5, 1.0]] extract_params['bkg_coeff'] = [[-0.5], [3.0], [6.0], [9.5]] - ex.shift_by_source_location(location, nominal_location, extract_params) + ex.shift_by_offset(offset, extract_params) assert extract_params['src_coeff'] == [[2.5 + offset, 1.0], [6.5 + offset, 1.0]] assert extract_params['bkg_coeff'] == [[-0.5 + offset], [3.0 + offset], [6.0 + offset], [9.5 + offset]] -@pytest.mark.parametrize('is_slit', [True, False]) -def test_nirspec_trace_from_wcs( - monkeypatch, mock_nirspec_fs_one_slit, is_slit): - model = mock_nirspec_fs_one_slit +def test_shift_by_offset_trace(extract_defaults): + offset = 2.5 - # monkey patch in a transform for the wcs - def slit2det(*args, **kwargs): - def return_results(*args, **kwargs): - if len(args) == 2: - zeros = np.zeros(args[0].shape) - wave, _ = np.meshgrid(args[0], args[1]) - return zeros, zeros, wave - if len(args) == 3: - pix = np.arange(len(args[0])) - trace = np.ones(len(args[0])) - return pix, trace - return return_results + extract_params = extract_defaults.copy() + extract_params['dispaxis'] = 1 + extract_params['position_offset'] = offset + extract_params['trace'] = np.arange(10, dtype=float) - monkeypatch.setattr(model.meta.wcs, 'get_transform', slit2det) + ex.shift_by_offset(offset, extract_params, update_trace=True) + assert np.all(extract_params['trace'] == np.arange(10) + offset) - if is_slit: - trace = ex.nirspec_trace_from_wcs(model, model) - else: - trace = ex.nirspec_trace_from_wcs(model, None) +def test_shift_by_offset_trace_no_update(extract_defaults): + offset = 2.5 + + extract_params = extract_defaults.copy() + extract_params['dispaxis'] = 1 + extract_params['position_offset'] = offset + extract_params['trace'] = np.arange(10, dtype=float) + + ex.shift_by_offset(offset, extract_params, update_trace=False) + assert np.all(extract_params['trace'] == np.arange(10)) + + +def test_nirspec_trace_from_wcs(mock_nirspec_fs_one_slit): + model = mock_nirspec_fs_one_slit + trace = ex._nirspec_trace_from_wcs(model.data.shape, model.meta.wcs.bounding_box, + model.meta.wcs, 1.0, 1.0) + # mocked model contains some mock transforms as well - all ones are expected + assert np.all(trace == np.ones(model.data.shape[-1])) + + +def test_miri_trace_from_wcs(mock_miri_lrs_fs): + model = mock_miri_lrs_fs + trace = ex._miri_trace_from_wcs(model.data.shape, model.meta.wcs.bounding_box, + model.meta.wcs, 1.0, 1.0) + + # mocked model contains some mock transforms as well - all ones are expected assert np.all(trace == np.ones(model.data.shape[-1])) @@ -1125,7 +1187,7 @@ def test_define_aperture_use_source(monkeypatch, mock_nirspec_fs_one_slit, extra # mock the source location function def mock_source_location(*args): - return 24, 7.74, 9.5 + return 24, 7.74, 9.5, np.full(model.data.shape[-1], 9.5) monkeypatch.setattr(ex, 'location_from_wcs', mock_source_location) @@ -1141,6 +1203,24 @@ def mock_source_location(*args): assert np.all(profile[13:] == 0.0) +def test_define_aperture_extra_offset(mock_nirspec_fs_one_slit, extract_defaults): + model = mock_nirspec_fs_one_slit + extract_defaults['dispaxis'] = 1 + slit = None + exptype = 'NRS_FIXEDSLIT' + + extract_defaults['position_offset'] = 2.0 + + result = ex.define_aperture(model, slit, extract_defaults, exptype) + _, _, _, profile, _, limits = result + assert profile.shape == model.data.shape + + # Default profile is shifted 2 pixels up + assert np.all(profile[:2] == 0.0) + assert np.all(profile[2:] == 1.0) + assert limits == (2, model.data.shape[0] + 1, 0, model.data.shape[1] - 1) + + def test_extract_one_slit_horizontal(mock_nirspec_fs_one_slit, extract_defaults, simple_profile, background_profile): # update parameters to subtract background @@ -1346,7 +1426,8 @@ def test_create_extraction_missing_wavelengths(create_extraction_inputs, log_wat model.wavelength = np.full_like(model.data, np.nan) log_watcher.message = 'Spectrum is empty; no valid data' with pytest.raises(ex.ContinueError): - ex.create_extraction(*create_extraction_inputs) + with pytest.warns(RuntimeWarning, match='All-NaN'): + ex.create_extraction(*create_extraction_inputs) log_watcher.assert_seen() @@ -1357,10 +1438,9 @@ def test_create_extraction_nrs_apcorr(create_extraction_inputs, nirspec_fs_apcor model.meta.cal_step.photom = 'COMPLETE' create_extraction_inputs[0] = model - # Set use_trace to false because the mock does not have a WCS log_watcher.message = 'Tabulating aperture correction' ex.create_extraction(*create_extraction_inputs, apcorr_ref_model=nirspec_fs_apcorr, - use_source_posn=False, use_trace=False) + use_source_posn=False) log_watcher.assert_seen() @@ -1369,10 +1449,11 @@ def test_create_extraction_one_int(create_extraction_inputs, mock_nirspec_bots, model = mock_nirspec_bots model.data = model.data[0].reshape(1, *model.data.shape[-2:]) create_extraction_inputs[0] = model + create_extraction_inputs[4] = 'S1600A1' log_watcher.message = '1 integration done' ex.create_extraction( - *create_extraction_inputs, log_increment=1, use_trace=False) + *create_extraction_inputs, log_increment=1) output_model = create_extraction_inputs[2] assert len(output_model.spec) == 1 log_watcher.assert_seen() @@ -1381,10 +1462,11 @@ def test_create_extraction_one_int(create_extraction_inputs, mock_nirspec_bots, def test_create_extraction_log_increment( create_extraction_inputs, mock_nirspec_bots, log_watcher): create_extraction_inputs[0] = mock_nirspec_bots + create_extraction_inputs[4] = 'S1600A1' # all integrations are logged log_watcher.message = '... 9 integrations done' - ex.create_extraction(*create_extraction_inputs, log_increment=1, use_trace=False) + ex.create_extraction(*create_extraction_inputs, log_increment=1) log_watcher.assert_seen() @@ -1399,7 +1481,7 @@ def test_create_extraction_use_source( # mock the source location function def mock_source_location(*args): - return 24, 7.74, 9.5 + return 24, 7.74, 9.5, np.full(model.data.shape[-1], 9.5) monkeypatch.setattr(ex, 'location_from_wcs', mock_source_location) @@ -1417,36 +1499,33 @@ def mock_source_location(*args): log_watcher.assert_seen() -@pytest.mark.parametrize('use_trace', [True, False, None]) @pytest.mark.parametrize('extract_width', [None, 7]) def test_create_extraction_use_trace( monkeypatch, create_extraction_inputs, mock_nirspec_bots, - use_trace, extract_width, log_watcher): + extract_width, log_watcher): model = mock_nirspec_bots create_extraction_inputs[0] = model aper = create_extraction_inputs[3]['apertures'] create_extraction_inputs[4] = 'S1600A1' for i in range(len(aper)): if aper[i]['id'] == 'S1600A1': + aper[i]['use_source_posn'] = True aper[i]['extract_width'] = extract_width - aper[i]['trace_offset'] = 0 + aper[i]['position_offset'] = 0 # mock the source location function - def mock_trace(*args, **kwargs): - return np.full(model.data.shape[-1], 25) - - monkeypatch.setattr(ex, 'nirspec_trace_from_wcs', mock_trace) + def mock_source_location(*args): + return 24, 7.74, 25, np.full(model.data.shape[-1], 25) - if use_trace is not False and extract_width is not None: + monkeypatch.setattr(ex, 'location_from_wcs', mock_source_location) + if extract_width is not None: # If explicitly set to True, or unspecified + source type is POINT, # source position is used log_watcher.message = 'aperture start/stop from trace: 22' - elif extract_width is not None: - log_watcher.message = 'Aperture start/stop: 21.5' else: # If False, source trace is not used log_watcher.message = 'Aperture start/stop: 0' - ex.create_extraction(*create_extraction_inputs, use_trace=use_trace) + ex.create_extraction(*create_extraction_inputs) log_watcher.assert_seen() @@ -1496,7 +1575,6 @@ def test_run_extract1d_save_cube_scene(mock_nirspec_bots): scene_model.close() - def test_run_extract1d_tso(mock_nirspec_bots): model = mock_nirspec_bots output_model, _, _ = ex.run_extract1d(model) diff --git a/jwst/extract_1d/tests/test_extract_1d_step.py b/jwst/extract_1d/tests/test_extract_1d_step.py index 43739015ee..0493679097 100644 --- a/jwst/extract_1d/tests/test_extract_1d_step.py +++ b/jwst/extract_1d/tests/test_extract_1d_step.py @@ -53,7 +53,7 @@ def test_extract_nirspec_mos_multi_slit(mock_nirspec_mos, simple_wcs): def test_extract_nirspec_bots(mock_nirspec_bots, simple_wcs): result = Extract1dStep.call( - mock_nirspec_bots, apply_apcorr=False, use_trace=False) + mock_nirspec_bots, apply_apcorr=False, use_source_posn=False) assert result.meta.cal_step.extract_1d == 'COMPLETE' assert (result.spec[0].name == 'S1600A1') diff --git a/jwst/regtest/test_nirspec_bots_extract1d.py b/jwst/regtest/test_nirspec_bots_extract1d.py index 148b60e41f..d04c7cfbfc 100644 --- a/jwst/regtest/test_nirspec_bots_extract1d.py +++ b/jwst/regtest/test_nirspec_bots_extract1d.py @@ -20,6 +20,7 @@ def run_extract(rtdata_module, request): # Run the calwebb_spec2 pipeline; args = ["extract_1d", rtdata.input, f"--override_extract1d={ref_file}", + "--use_source_posn=False", "--suffix=x1dints"] Step.from_cmdline(args)