diff --git a/jwst/extract_1d/extract_1d_step.py b/jwst/extract_1d/extract_1d_step.py
index 44c8f71ca2..ac6b99a7dc 100644
--- a/jwst/extract_1d/extract_1d_step.py
+++ b/jwst/extract_1d/extract_1d_step.py
@@ -426,8 +426,7 @@ def process(self, input):
             return input_model
 
         # Load reference files.
-        spectrace_ref_name = self.get_reference_file(input_model, 'spectrace')
-        wavemap_ref_name = self.get_reference_file(input_model, 'wavemap')
+        pastasoss_ref_name = self.get_reference_file(input_model, 'pastasoss')  # Add support for "pastasoss" reffile
         specprofile_ref_name = self.get_reference_file(input_model, 'specprofile')
         speckernel_ref_name = self.get_reference_file(input_model, 'speckernel')
@@ -452,8 +451,7 @@ def process(self, input):
             # Run the extraction.
             result, ref_outputs, atoca_outputs = soss_extract.run_extract1d(
                 input_model,
-                spectrace_ref_name,
-                wavemap_ref_name,
+                pastasoss_ref_name,
                 specprofile_ref_name,
                 speckernel_ref_name,
                 subarray,
diff --git a/jwst/extract_1d/soss_extract/soss_extract.py b/jwst/extract_1d/soss_extract/soss_extract.py
index 1a23994197..127b6a53e7 100644
--- a/jwst/extract_1d/soss_extract/soss_extract.py
+++ b/jwst/extract_1d/soss_extract/soss_extract.py
@@ -13,6 +13,7 @@
 from .soss_syscor import make_background_mask, soss_background
 from .soss_solver import solve_transform, transform_wavemap, transform_profile, transform_coords
 from .atoca import ExtractionEngine, MaskOverlapError
+from .soss_wavemaps import get_soss_wavemaps
 from .atoca_utils import (ThroughputSOSS, WebbKernel, grid_from_map, mask_bad_dispersion_direction,
                           make_combined_adaptive_grid, get_wave_p_or_m, oversample_grid)
 from .soss_boxextract import get_box_weights, box_extract, estim_error_nearest_data
@@ -21,15 +22,12 @@
 log.setLevel(logging.DEBUG)
 
 
-def get_ref_file_args(ref_files, transform):
+def get_ref_file_args(ref_files):
     """Prepare the reference files for the extraction engine.
 
     Parameters
     ----------
     ref_files : dict
         A dictionary of the reference file DataModels.
-    transform : array or list
-        A 3-element array describing the rotation and translation to apply
-        to the reference files in order to match the observation.
 
     Returns
     -------
@@ -39,32 +37,18 @@
     """
     # The wavelength maps for order 1 and 2.
-    wavemap_ref = ref_files['wavemap']
+    pastasoss_ref = ref_files['pastasoss']
+    pwcpos = pastasoss_ref.map[0].pwcpos  # Make sure the pastasoss_ref object has these attrs
+    subarray = pastasoss_ref.map[0].subarray  # Make sure the pastasoss_ref object has these attrs
+    pad = pastasoss_ref.map[0].padding  # Make sure the pastasoss_ref object has these attrs
 
-    ovs = wavemap_ref.map[0].oversampling
-    pad = wavemap_ref.map[0].padding
+    # Use pastasoss to generate the appropriate wavemap for the given PWCPOS
+    (wavemap_o1, wavemap_o2), (spectrace_o1, spectrace_o2) = get_soss_wavemaps(pwcpos=pwcpos, subarray=subarray, padding=True, padsize=pad)
 
-    wavemap_o1 = transform_wavemap(transform, wavemap_ref.map[0].data, ovs, pad)
-    wavemap_o2 = transform_wavemap(transform, wavemap_ref.map[1].data, ovs, pad)
-
-    # Make sure all pixels follow the expected direction of the dispersion
-    wavemap_o1, flag_o1 = mask_bad_dispersion_direction(wavemap_o1)
-    wavemap_o2, flag_o2 = mask_bad_dispersion_direction(wavemap_o2)
-
-    # Warn if not all pixels were corrected
-    msg_warning = 'Some pixels in order {} do not follow the expected dispersion axis'
-    if not flag_o1:
-        log.warning(msg_warning.format(1))
-    if not flag_o2:
-        log.warning(msg_warning.format(2))
-
-    # The spectral profiles for order 1 and 2.
+    # The spectral profiles for order 1 and 2 (no transform necessary using pastasoss)
     specprofile_ref = ref_files['specprofile']
-    ovs = specprofile_ref.profile[0].oversampling
-    pad = specprofile_ref.profile[0].padding
-
-    specprofile_o1 = transform_profile(transform, specprofile_ref.profile[0].data, ovs, pad, norm=False)
-    specprofile_o2 = transform_profile(transform, specprofile_ref.profile[1].data, ovs, pad, norm=False)
+    specprofile_o1 = specprofile_ref.profile[0].data
+    specprofile_o2 = specprofile_ref.profile[1].data
 
     # The throughput curves for order 1 and 2.
     spectrace_ref = ref_files['spectrace']
@@ -106,18 +90,18 @@
     return [wavemap_o1, wavemap_o2], [specprofile_o1, specprofile_o2], [throughput_o1, throughput_o2], [kernels_o1, kernels_o2]
 
 
-def get_trace_1d(ref_files, transform, order, cols=None):
+def get_trace_1d(ref_files, order, transform=None, cols=None):
     """Get the x, y, wavelength of the trace after applying the transform.
 
     Parameters
     ----------
     ref_files : dict
         A dictionary of the reference file DataModels.
-    transform : array or list
+    order : int
+        The spectral order for which to return the trace parameters.
+    transform : array or list, optional
         A 3-element list or array describing the rotation and translation
         to apply to the reference files in order to match the observation.
-    order : int
-        The spectral order for which to return the trace parameters.
     cols : array[int], optional
         The columns on the detector for which to compute the trace parameters.
         If not given, all columns will be computed.
 
     Returns
     -------
 
     else:
         xtrace = cols
 
-    spectrace_ref = ref_files['spectrace']
+    pastasoss_ref = ref_files['pastasoss']
 
     # Read x, y, wavelength for the relevant order.
-    xref = spectrace_ref.trace[order - 1].data['X']
-    yref = spectrace_ref.trace[order - 1].data['Y']
-    waveref = spectrace_ref.trace[order - 1].data['WAVELENGTH']
+    xref = pastasoss_ref.trace[order - 1].data['X']
+    yref = pastasoss_ref.trace[order - 1].data['Y']
+    waveref = pastasoss_ref.trace[order - 1].data['WAVELENGTH']
+
+    # No transform necessary if using pastasoss
+    if transform is None:
+        xtrace = xref
+        ytrace = yref
+        wavetrace = waveref
+
+    # Transform the trace if using a static PWCPOS
+    else:
 
-    # Rotate and shift the positions based on transform.
-    angle, xshift, yshift = transform
-    xrot, yrot = transform_coords(angle, xshift, yshift, xref, yref)
+        # Rotate and shift the positions based on transform.
+        angle, xshift, yshift = transform
+        xrot, yrot = transform_coords(angle, xshift, yshift, xref, yref)
 
-    # Interpolate y and wavelength to the requested columns.
-    sort = np.argsort(xrot)
-    ytrace = np.interp(xtrace, xrot[sort], yrot[sort])
-    wavetrace = np.interp(xtrace, xrot[sort], waveref[sort])
+        # Interpolate y and wavelength to the requested columns.
+        sort = np.argsort(xrot)
+        ytrace = np.interp(xtrace, xrot[sort], yrot[sort])
+        wavetrace = np.interp(xtrace, xrot[sort], waveref[sort])
 
     return xtrace, ytrace, wavetrace
 
@@ -1015,7 +1008,7 @@ def extract_image(decontaminated_data, scierr, scimask, box_weights, bad_pix='mo
     return fluxes, fluxerrs, npixels
 
 
-def run_extract1d(input_model, spectrace_ref_name, wavemap_ref_name,
+def run_extract1d(input_model, pastasoss_ref_name,
                   specprofile_ref_name, speckernel_ref_name, subarray,
                   soss_filter, soss_kwargs):
     """Run the spectral extraction on NIRISS SOSS data.
 
@@ -1023,10 +1016,8 @@ def run_extract1d(input_model, spectrace_ref_name, wavemap_ref_name,
     ----------
     input_model : DataModel
         The input DataModel.
-    spectrace_ref_name : str
-        Name of the spectrace reference file.
-    wavemap_ref_name : str
-        Name of the wavemap reference file.
+    pastasoss_ref_name : str
+        Name of the pastasoss reference file.
     specprofile_ref_name : str
         Name of the specprofile reference file.
     speckernel_ref_name : str
@@ -1052,14 +1043,12 @@
     order_str_2_int = {f'Order {order}': order for order in [1, 2, 3]}
 
     # Read the reference files.
-    spectrace_ref = datamodels.SpecTraceModel(spectrace_ref_name)
-    wavemap_ref = datamodels.WaveMapModel(wavemap_ref_name)
+    pastasoss_ref = datamodels.PastasossModel(pastasoss_ref_name)
     specprofile_ref = datamodels.SpecProfileModel(specprofile_ref_name)
     speckernel_ref = datamodels.SpecKernelModel(speckernel_ref_name)
 
     ref_files = dict()
-    ref_files['spectrace'] = spectrace_ref
-    ref_files['wavemap'] = wavemap_ref
+    ref_files['pastasoss'] = pastasoss_ref
     ref_files['specprofile'] = specprofile_ref
     ref_files['speckernel'] = speckernel_ref
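A minimal sketch of how the reworked interfaces fit together (illustrative only, not part of the patch; it assumes the `datamodels` module and the reference file name variables already present in the surrounding code):

pastasoss_ref = datamodels.PastasossModel(pastasoss_ref_name)   # replaces SpecTraceModel + WaveMapModel
ref_files = {'pastasoss': pastasoss_ref,
             'specprofile': datamodels.SpecProfileModel(specprofile_ref_name),
             'speckernel': datamodels.SpecKernelModel(speckernel_ref_name)}

# With pastasoss the trace is already rotated to the observed PWCPOS,
# so the solver transform argument becomes optional:
xtrace, ytrace, wavetrace = get_trace_1d(ref_files, order=1)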
diff --git a/jwst/extract_1d/soss_extract/soss_traces.py b/jwst/extract_1d/soss_extract/soss_traces.py
new file mode 100644
index 0000000000..d1f17f10f4
--- /dev/null
+++ b/jwst/extract_1d/soss_extract/soss_traces.py
@@ -0,0 +1,219 @@
+# Module to predict SOSS trace positions for a given spectral order(s)
+# given the pupil wheel position angle taken from the "PWCPOS" FITS
+# header keyword.
+
+from dataclasses import dataclass
+from typing import Tuple
+
+import numpy as np
+from scipy.interpolate import interp1d
+from pkg_resources import resource_filename
+
+from .wavecal import get_wavecal_meta_for_spectral_order
+from .wavecal import get_wavelengths
+
+PWCPOS_CMD = 245.7600  # Commanded PWCPOS for the GR700XD
+
+
+# TODO: order 3 is currently unsupported. It will be supported in the future: TBD
+REFERENCE_TRACE_FILES = {
+    "order1": "jwst_niriss_gr700xd_order1_trace_refmodel.txt",
+    "order2": "jwst_niriss_gr700xd_order2_trace_refmodel_002.txt",
+    # order 3 is currently unsupported. It will be supported in the future: TBD
+}
+
+REFERENCE_WAVECAL_MODELS = {
+    "order1": resource_filename(
+        __name__, "data/jwst_niriss_gr700xd_wavelength_model_order1.json"
+    ),
+    "order2": resource_filename(
+        __name__, "data/jwst_niriss_gr700xd_wavelength_model_order2_002.json"
+    ),
+}
+
+
+@dataclass
+class TraceModel:
+    order: str
+    x: np.ndarray
+    y: np.ndarray
+    wavelength: np.ndarray
+
+
+def rotate(
+    x: np.ndarray,
+    y: np.ndarray,
+    angle: float,
+    origin: Tuple[float, float] = (0, 0),
+    interp: bool = True,
+) -> Tuple[np.ndarray, np.ndarray]:
+    """
+    Applies a rotation transformation to a set of 2D points.
+
+    Parameters
+    ----------
+    x : np.ndarray
+        The x-coordinates of the points to be transformed.
+    y : np.ndarray
+        The y-coordinates of the points to be transformed.
+    angle : float
+        The angle (in degrees) by which to rotate the points.
+    origin : Tuple[float, float], optional
+        The point about which to rotate the points. Default is (0, 0).
+    interp : bool, optional
+        Whether to interpolate the rotated positions onto the original x-pixel
+        column values. Default is True.
+
+    Returns
+    -------
+    Tuple[np.ndarray, np.ndarray]
+        The x and y coordinates of the rotated points.
+
+    Examples
+    --------
+    >>> x = np.array([0, 1, 2, 3])
+    >>> y = np.array([0, 1, 2, 3])
+    >>> x_rot, y_rot = rotate(x, y, 90)
+    """
+
+    # shift to rotate about center
+    xy_center = np.atleast_2d(origin).T
+    xy = np.vstack([x, y])
+
+    # Rotation transform matrix
+    radians = np.radians(angle)
+    c, s = np.cos(radians), np.sin(radians)
+    R = np.array([[c, -s], [s, c]])
+
+    # apply transformation
+    x_new, y_new = R @ (xy - xy_center) + xy_center
+
+    # interpolate rotated positions onto x-pixel column values (default)
+    if interp:
+        # interpolate new coordinates onto original x values and mask values
+        # outside of the domain of the image 0<=x<=2047 and 0<=y<=255.
+        y_new = interp1d(x_new, y_new, fill_value="extrapolate")(x)
+        mask = np.where(y_new <= 255.0)
+        x = x[mask]
+        y_new = y_new[mask]
+        return x, y_new
+
+    return x_new, y_new
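A quick numeric check of the rotation helper (illustrative only, not part of the patch): with interp=False a 90 degree rotation about the origin maps points on the x-axis onto the y-axis.

import numpy as np

x = np.array([1.0, 2.0])
y = np.array([0.0, 0.0])
x_rot, y_rot = rotate(x, y, 90.0, origin=(0, 0), interp=False)
# x_rot ~ [0, 0] and y_rot ~ [1, 2]: a counter-clockwise rotation about (0, 0).
# With interp=True (the default) y_rot is instead resampled back onto the
# original x pixel columns, which is what the trace prediction needs.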
+
+
+def get_reference_trace(
+    file: str,
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    """
+    Load in the reference trace positions given a file associated with a given
+    spectral order.
+
+    Parameters
+    ----------
+    file : str
+        The path to the file containing the reference trace positions.
+
+    Returns
+    -------
+    Tuple[np.ndarray, np.ndarray, np.ndarray]
+        A tuple containing the x, y positions and origin of the reference
+        traces.
+
+    Examples
+    --------
+    >>> x, y, origin = get_reference_trace('ref_filename.txt')
+    """
+    filepath = resource_filename(__name__, f"data/{file}")
+    traces = np.loadtxt(filepath)
+    origin = traces[0]
+    x = traces[1:, 0]
+    y = traces[1:, 1]
+    return x, y, origin
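The reference trace files are plain-text tables; a sketch of the layout implied by the reader above (all numbers are placeholders, not real calibration data):

# First row: rotation origin (x, y); remaining rows: trace x, y positions.
# data/jwst_niriss_gr700xd_order1_trace_refmodel.txt
#     1024.0    50.0
#        4.0    80.1
#        5.0    80.0
#      ...
x, y, origin = get_reference_trace('jwst_niriss_gr700xd_order1_trace_refmodel.txt')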
+
+
+def get_soss_traces(
+    pwcpos: float, order: str = "123", subarray: str = 'SUBSTRIP256', interp: bool = True
+) -> Tuple[np.ndarray, np.ndarray]:
+    """
+    This is the primary method for generating the GR700XD trace positions given
+    the pupil wheel position angle provided in the FITS header under keyword
+    PWCPOS. The traces for a given spectral order are found by performing a
+    rotation transformation on the reference trace positions at the commanded
+    PWCPOS=245.76 degrees. This method yields sub-pixel performance and will be
+    improved upon in later iterations as more NIRISS/SOSS observations become
+    available.
+
+    Parameters
+    ----------
+    pwcpos : float
+        The pupil wheel position angle provided in the FITS header under
+        keyword PWCPOS.
+    order : str, optional
+        The spectral order(s) to compute the new traces for. Default is '123'.
+        Support for order 3 will be added at a later date.
+    subarray : str
+        The subarray being used, ['SUBSTRIP96', 'SUBSTRIP256']
+    interp : bool, optional
+        Whether to interpolate the rotated positions onto the original x-pixel
+        column values. Default is True.
+
+    Returns
+    -------
+    TraceModel or list of TraceModel
+        If a single order is requested, a TraceModel holding the rotated x and
+        y positions and the associated wavelengths for that spectral order.
+        If a combination of orders is requested, a list of TraceModel objects,
+        one per order.
+
+    Raises
+    ------
+    ValueError
+        If `order` is not '1', '2', '3', or a combination of '1', '2', and '3'.
+
+    Examples
+    --------
+    >>> trace = get_soss_traces(245.8, order='1')
+    """
+
+    norders = len(order)
+
+    if norders > 3:
+        raise ValueError("order must be a combination of '1', '2', and '3'.")
+    if norders > 1:
+        # recursively compute the new traces for each order
+        return [get_soss_traces(pwcpos, m, subarray, interp) for m in order]
+
+    # This might be an alternative way of writing this
+    # if 'order'+order in REFERENCE_TRACE_FILES.keys():
+    #     ref_file = REFERENCE_TRACE_FILES["order"+order]
+    # This section can definitely be refactored in a later version
+    elif order == "1":
+        ref_trace_file = REFERENCE_TRACE_FILES["order1"]
+        wave_cal_model_meta = get_wavecal_meta_for_spectral_order("order1")
+
+    elif order == "2":
+        ref_trace_file = REFERENCE_TRACE_FILES["order2"]
+        wave_cal_model_meta = get_wavecal_meta_for_spectral_order("order2")
+
+    elif order == "3":
+        print("Order 3 is not supported at this time.")
+        return None
+
+    # reference trace data
+    x, y, origin = get_reference_trace(ref_trace_file)
+
+    # Offset for SUBSTRIP96
+    if subarray == 'SUBSTRIP96':
+        y -= 10
+
+    # rotated reference trace
+    x_new, y_new = rotate(x, y, pwcpos - PWCPOS_CMD, origin, interp=interp)
+
+    # wavelength associated to trace at given pwcpos value
+    wavelengths = get_wavelengths(x_new, pwcpos, wave_cal_model_meta)
+
+    # return x_new, y_new, wavelengths
+    return TraceModel(order, x_new, y_new, wavelengths)
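A minimal usage sketch for the new trace predictor (illustrative only, not part of the patch; the PWCPOS value is made up):

# Predict the order 1 and order 2 traces for an observed pupil wheel position.
trace_o1, trace_o2 = get_soss_traces(pwcpos=245.79, order='12', subarray='SUBSTRIP256')

# Each element is a TraceModel dataclass holding the rotated trace and its wavelengths.
print(trace_o1.order)                                                # '1'
print(trace_o1.x.shape, trace_o1.y.shape, trace_o1.wavelength.shape)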
diff --git a/jwst/extract_1d/soss_extract/soss_wavemaps.py b/jwst/extract_1d/soss_extract/soss_wavemaps.py
new file mode 100644
index 0000000000..0429b97c34
--- /dev/null
+++ b/jwst/extract_1d/soss_extract/soss_wavemaps.py
@@ -0,0 +1,220 @@
+"""
+Module to generate 2D SOSS wavemap arrays from 1D wavelength solutions
+
+Author: Joe Filippazzo
+Date: 02/28/2024
+Usage:
+import pastasoss
+pwcpos = 245.8
+wavemaps = pastasoss.get_soss_wavemaps(pwcpos)
+"""
+
+import numpy as np
+
+from .soss_traces import get_soss_traces, PWCPOS_CMD
+
+
+def extrapolate_to_wavegrid(w_grid, wavelength, quantity):
+    """
+    Extrapolate a quantity to the left and right edges of a given wavelength grid.
+
+    Parameters
+    ----------
+    w_grid : sequence
+        The wavelength grid to interpolate onto
+    wavelength : sequence
+        The native wavelength values of the data
+    quantity : sequence
+        The data to interpolate
+
+    Returns
+    -------
+    Array
+        The interpolated quantities
+    """
+    sorted = np.argsort(wavelength)
+    q = quantity[sorted]
+    w = wavelength[sorted]
+
+    # Determine the slope on the right of the array
+    slope_right = (q[-1] - q[-2]) / (w[-1] - w[-2])
+    # extrapolate at wavelengths larger than the max on the right
+    indright = np.where(w_grid > w[-1])[0]
+    q_right = q[-1] + (w_grid[indright] - w[-1]) * slope_right
+    # Determine the slope on the left of the array
+    slope_left = (q[1] - q[0]) / (w[1] - w[0])
+    # extrapolate at wavelengths smaller than the min on the left
+    indleft = np.where(w_grid < w[0])[0]
+    q_left = q[0] + (w_grid[indleft] - w[0]) * slope_left
+    # Construct an extrapolated array of the quantity
+    w = np.concatenate((w_grid[indleft], w, w_grid[indright]))
+    q = np.concatenate((q_left, q, q_right))
+
+    # resample at the w_grid everywhere
+    q_grid = np.interp(w_grid, w, q)
+
+    return q_grid
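A toy check of the extrapolation helper (illustrative only, not part of the patch): a perfectly linear quantity should be reproduced exactly on and beyond the sampled range.

import numpy as np

w_grid = np.linspace(0.0, 10.0, 11)
wavelength = np.array([2.0, 3.0, 4.0])
quantity = 2.0 * wavelength + 1.0              # linear in wavelength
q_grid = extrapolate_to_wavegrid(w_grid, wavelength, quantity)
# q_grid equals 2 * w_grid + 1 everywhere, because the edge slopes are used
# to extend the relation to the left and right of the sampled wavelengths.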
+
+
+def calc_2d_wave_map(wave_grid, x_dms, y_dms, tilt, oversample=2, padding=10, maxiter=5, dtol=1e-2):
+    """Compute the 2D wavelength map on the detector.
+
+    Parameters
+    ----------
+    wave_grid : sequence
+        The wavelength corresponding to the x_dms, y_dms, and tilt values.
+    x_dms : sequence
+        The trace x position on the detector in DMS coordinates.
+    y_dms : sequence
+        The trace y position on the detector in DMS coordinates.
+    tilt : sequence
+        The trace tilt angle in degrees.
+    oversample : int
+        The oversampling factor of the input coordinates.
+    padding : int
+        The native pixel padding around the edge of the detector.
+    maxiter : int
+        The maximum number of iterations used when solving for the wavelength at each pixel.
+    dtol : float
+        The tolerance of the iterative solution in pixels.
+
+    Returns
+    -------
+    Array
+        An array containing the wavelength at each pixel on the detector.
+    """
+    os = np.copy(oversample)
+    xpad = np.copy(padding)
+    ypad = np.copy(padding)
+
+    # No need to compute wavelengths across the entire detector, slightly larger than SUBSTRIP256 will do.
+    dimx, dimy = 2048, 300
+    y_dms = y_dms + (dimy - 2048)  # Adjust y-coordinate to area of interest.
+
+    # Generate the oversampled grid of pixel coordinates.
+    x_vec = np.arange((dimx + 2 * xpad) * os) / os - (os - 1) / (2 * os) - xpad
+    y_vec = np.arange((dimy + 2 * ypad) * os) / os - (os - 1) / (2 * os) - ypad
+    x_grid, y_grid = np.meshgrid(x_vec, y_vec)
+
+    # Iteratively compute the wavelength at each pixel.
+    delta_x = 0.0  # A shift in x represents a shift in wavelength.
+    for niter in range(maxiter):
+
+        # Assume all y have same wavelength.
+        wave_iterated = np.interp(x_grid - delta_x, x_dms[::-1],
+                                  wave_grid[::-1])  # Invert arrays to get increasing x.
+
+        # Compute the tilt angle at the wavelengths.
+        tilt_tmp = np.interp(wave_iterated, wave_grid, tilt)
+
+        # Compute the trace position at the wavelengths.
+        x_estimate = np.interp(wave_iterated, wave_grid, x_dms)
+        y_estimate = np.interp(wave_iterated, wave_grid, y_dms)
+
+        # Project that back to pixel coordinates.
+        x_iterated = x_estimate + (y_grid - y_estimate) * np.tan(np.deg2rad(tilt_tmp))
+
+        # Measure error between requested and iterated position.
+        delta_x = delta_x + (x_iterated - x_grid)
+
+        # If the desired precision has been reached end iterations.
+        if np.all(np.abs(x_iterated - x_grid) < dtol):
+            break
+
+    # Evaluate the final wavelength map, this time setting out-of-bounds values to NaN.
+    wave_map_2d = np.interp(x_grid - delta_x, x_dms[::-1], wave_grid[::-1], left=np.nan, right=np.nan)
+
+    # Extend to full detector size.
+    tmp = np.full((os * (dimx + 2 * xpad), os * (dimx + 2 * xpad)), fill_value=np.nan)
+    tmp[-os * (dimy + 2 * ypad):] = wave_map_2d
+    wave_map_2d = tmp
+
+    return wave_map_2d
+
+
+def get_soss_wavemaps(pwcpos=PWCPOS_CMD, subarray='SUBSTRIP256', padding=False, padsize=20, spectraces=False):
+    """
+    Generate order 1 and 2 2D wavemaps from the rotated SOSS trace positions
+
+    Parameters
+    ----------
+    pwcpos : float
+        The pupil wheel position
+    subarray : str
+        The subarray name, ['FULL', 'SUBSTRIP256', 'SUBSTRIP96']
+    padding : bool
+        Include padding on map edges (only needed for reference files)
+    padsize : int
+        The size of the padding to include on each side
+    spectraces : bool
+        Return the interpolated spectraces as well
+
+    Returns
+    -------
+    Array, Array
+        The 2D wavemaps and corresponding 1D spectraces
+    """
+    traces_order1, traces_order2 = get_soss_traces(pwcpos=pwcpos, order='12', subarray=subarray, interp=True)
+
+    # Make wavemap from trace center wavelengths, padding to shape (296, 2088)
+    wavemin = 0.5
+    wavemax = 5.5
+    nwave = 5001
+    wave_grid = np.linspace(wavemin, wavemax, nwave)
+
+    # Extrapolate wavelengths for order 1 trace
+    xtrace_order1 = extrapolate_to_wavegrid(wave_grid, traces_order1.wavelength, traces_order1.x)
+    ytrace_order1 = extrapolate_to_wavegrid(wave_grid, traces_order1.wavelength, traces_order1.y)
+    spectrace_1 = np.array([xtrace_order1, ytrace_order1, wave_grid])
+
+    # Set cutoff for order 2 where it runs off the detector
+    o2_cutoff = 1783
+    w_o2_tmp = traces_order2.wavelength[:o2_cutoff]
+    w_o2 = np.zeros(2040) * np.nan
+    w_o2[:o2_cutoff] = w_o2_tmp
+    y_o2_tmp = traces_order2.y[:o2_cutoff]
+    y_o2 = np.zeros(2040) * np.nan
+    y_o2[:o2_cutoff] = y_o2_tmp
+    x_o2 = np.copy(traces_order1.x)
+
+    # Fill the columns redward of the cutoff with a linear extrapolation
+    m = w_o2[o2_cutoff - 1] - w_o2[o2_cutoff - 2]
+    dx = np.arange(2040 - o2_cutoff) + 1
+    w_o2[o2_cutoff:] = w_o2[o2_cutoff - 1] + m * dx
+    m = y_o2[o2_cutoff - 1] - y_o2[o2_cutoff - 2]
+    dx = np.arange(2040 - o2_cutoff) + 1
+    y_o2[o2_cutoff:] = y_o2[o2_cutoff - 1] + m * dx
+
+    # Extrapolate wavelengths for order 2 trace
+    xtrace_order2 = extrapolate_to_wavegrid(wave_grid, w_o2, x_o2)
+    ytrace_order2 = extrapolate_to_wavegrid(wave_grid, w_o2, y_o2)
+    spectrace_2 = np.array([xtrace_order2, ytrace_order2, wave_grid])
+
+    # Make wavemap from wavelength solution for order 1
+    wavemap_1 = calc_2d_wave_map(wave_grid, xtrace_order1, ytrace_order1, np.zeros_like(xtrace_order1), oversample=1, padding=padsize)
+
+    # Make wavemap from wavelength solution for order 2
+    wavemap_2 = calc_2d_wave_map(wave_grid, xtrace_order2, ytrace_order2, np.zeros_like(xtrace_order2), oversample=1, padding=padsize)
+
+    # Extrapolate wavemap to FULL frame
+    wavemap_1[:-256 - padsize, :] = wavemap_1[-256 - padsize]
+    wavemap_2[:-256 - padsize, :] = wavemap_2[-256 - padsize]
+
+    # Trim to subarray
+    if subarray == 'SUBSTRIP256':
+        wavemap_1 = wavemap_1[1792 - padsize:2048 + padsize, :]
+        wavemap_2 = wavemap_2[1792 - padsize:2048 + padsize, :]
+    if subarray == 'SUBSTRIP96':
+        wavemap_1 = wavemap_1[1792 - padsize:1792 + 96 + padsize, :]
+        wavemap_2 = wavemap_2[1792 - padsize:1792 + 96 + padsize, :]
+
+    # Remove padding if necessary
+    if not padding:
+        wavemap_1 = wavemap_1[padsize:-padsize, padsize:-padsize]
+        wavemap_2 = wavemap_2[padsize:-padsize, padsize:-padsize]
+
+    if spectraces:
+        return np.array([wavemap_1, wavemap_2]), np.array([spectrace_1, spectrace_2])
+
+    else:
+        return np.array([wavemap_1, wavemap_2])
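A minimal sketch of how get_ref_file_args above consumes this module (illustrative only, not part of the patch; the PWCPOS value is made up):

(wavemap_o1, wavemap_o2), (spectrace_o1, spectrace_o2) = get_soss_wavemaps(
    pwcpos=245.79, subarray='SUBSTRIP256', padding=True, padsize=20, spectraces=True)

# With padding=True and padsize=20 each SUBSTRIP256 wavemap comes back as a
# (296, 2088) array, and each spectrace is a (3, 5001) array of x, y, wavelength
# sampled on the common 0.5-5.5 micron wave_grid.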
diff --git a/jwst/extract_1d/soss_extract/wavecal.py b/jwst/extract_1d/soss_extract/wavecal.py
new file mode 100644
index 0000000000..f885c6f28f
--- /dev/null
+++ b/jwst/extract_1d/soss_extract/wavecal.py
@@ -0,0 +1,290 @@
+# Wavecal module for handling the wavelength calibration model. For now we
+# use the reference JSON files that hold all the model information. Ideally
+# we would like to use just the trained model for production, but for now
+# this is fine.
+
+import json
+from dataclasses import dataclass
+from functools import partial
+from typing import Any, Dict
+
+import numpy as np
+import numpy.typing as npt
+from pkg_resources import resource_filename
+
+# explicitly defining the commanded position here as well (this is temporary)
+PWCPOS_CMD = 245.76
+
+# TODO: order 3 is currently unsupported. It will be supported in the future: TBD
+REFERENCE_WAVECAL_MODELS = {
+    "order1": resource_filename(
+        __name__, "data/jwst_niriss_gr700xd_wavelength_model_order1.json"
+    ),
+    "order2": resource_filename(
+        __name__, "data/jwst_niriss_gr700xd_wavelength_model_order2_002.json"
+    ),
+}
+
+
+# dataclass object to store the metadata from a wavecal model
+@dataclass
+class NIRISS_GR700XD_WAVECAL_META:
+    """
+    Datamodel object container for the wavecal meta data in the json reference
+    file.
+    """
+
+    order: str
+    coefficients: npt.NDArray[np.float64]
+    intercept: npt.NDArray[np.float64]
+    scaler_data_min_: float
+    scaler_data_max_: float
+    poly_degree: int
+
+
+def load_wavecal_model(filename: str) -> Dict[str, Any]:
+    """
+    Load a wavecalibration model from a JSON file.
+
+    Parameters
+    ----------
+    filename : str
+        The path to the JSON file containing the wavecalibration model.
+
+    Returns
+    -------
+    dict
+        A dictionary representing the loaded wavecalibration model.
+
+    Notes
+    -----
+    This function reads the specified JSON file and returns its contents as a
+    dictionary, which typically contains information about a wavecalibration
+    model used in data analysis.
+    """
+    with open(filename, "r") as file:
+        wavecal_model = json.load(file)
+    return wavecal_model
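For orientation, a minimal sketch of the JSON layout that the loader above and get_wavecal_meta_for_spectral_order below expect (keys taken from the accessors in this module; all values here are placeholders, not real calibration data):

wavecal_model = {
    "model": {
        "poly_deg": 5,                      # polynomial degree of the fit
        "coef": [0.0] * 20,                 # one weight per polynomial feature
        "intercept": 0.0,
        "scaler": {
            "data_min_": [0.0, -0.25],      # [x, pwcpos offset] minima from training
            "data_max_": [2047.0, 0.25],    # [x, pwcpos offset] maxima from training
        },
    }
}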
+ """ + # get the reference wavecal file name + valid_orders = ['order 1', 'order 2'] + if order not in valid_orders: + raise ValueError(f"valid orders are: {valid_orders}.") + + # get the appropiate reference file name given the order + reference_filename = REFERENCE_WAVECAL_MODELS[order] + + # load in the model + wavecal_model = load_wavecal_model(reference_filename) + + # model coefficients + poly_degree = wavecal_model["model"]["poly_deg"] + coefficients = wavecal_model["model"]["coef"] + intercept = wavecal_model["model"]["intercept"] + + # info for scaling inputs + scaler_data_min_ = wavecal_model["model"]["scaler"]["data_min_"] + scaler_data_max_ = wavecal_model["model"]["scaler"]["data_max_"] + + return NIRISS_GR700XD_WAVECAL_META( + order, coefficients, intercept, scaler_data_min_, scaler_data_max_, poly_degree + ) + + +def get_wavelengths( + x: np.ndarray, pwcpos: float, wavecal_meta: NIRISS_GR700XD_WAVECAL_META +) -> np.ndarray: + """Get the associated wavelength values for a given spectral order""" + if wavecal_meta.order == "order1": + wavelengths = wavecal_model_order1_poly(x, pwcpos, wavecal_meta) + elif wavecal_meta.order == "order2": + # raise NotImplementedError("Order 2 not implemented") + wavelengths = wavecal_model_order2_poly(x, pwcpos, wavecal_meta) + elif wavecal_meta.order == "order3": + raise ValueError("Order 3 not supported at this time") + else: + raise ValueError("not a valid order") + + return wavelengths + + +def min_max_scaler(x, x_min, x_max): + """ + Apply min-max scaling to input values. + + Parameters + ---------- + x : float or numpy.ndarray + The input value(s) to be scaled. + x_min : float + The minimum value in the range to which 'x' will be scaled. + x_max : float + The maximum value in the range to which 'x' will be scaled. + + Returns + ------- + float or numpy.ndarray + The scaled value(s) in the range [0, 1]. + + Notes + ----- + Min-max scaling is a data normalization technique that scales input values + 'x' to the range [0, 1] based on the provided minimum and maximum values, + 'x_min' and 'x_max'. This function is applicable to both individual values + and arrays of values. This function will use the min/max values from the + training data of the wavecal model. 
+ """ + # scaling the input x values + x_scaled = (x - x_min) / (x_max - x_min) + return x_scaled + + +def wavecal_model_order1_poly(x, pwcpos, wavecal_meta: NIRISS_GR700XD_WAVECAL_META): + """compute order 1 wavelengths""" + x_scaler = partial( + min_max_scaler, + **{ + "x_min": wavecal_meta.scaler_data_min_[0], + "x_max": wavecal_meta.scaler_data_max_[0], + }, + ) + + pwcpos_offset_scaler = partial( + min_max_scaler, + **{ + "x_min": wavecal_meta.scaler_data_min_[1], + "x_max": wavecal_meta.scaler_data_max_[1], + }, + ) + + def get_poly_features(x: np.array, offset: np.array) -> np.ndarray: + """polynomial features for the order 1 wavecal model""" + poly_features = np.array( + [ + x, + offset, + x**2, + x * offset, + offset**2, + x**3, + x**2 * offset, + x * offset**2, + offset**3, + x**4, + x**3 * offset, + x**2 * offset**2, + x * offset**3, + offset**4, + x**5, + x**4 * offset, + x**3 * offset**2, + x**2 * offset**3, + x * offset**4, + offset**5, + ] + ) + return poly_features + + # extract model weights and intercept + coef = wavecal_meta.coefficients + intercept = wavecal_meta.intercept + + # get pixel columns and then scaled + x_scaled = x_scaler(x) + + # offset + offset = np.ones_like(x) * (pwcpos - PWCPOS_CMD) + offset_scaled = pwcpos_offset_scaler(offset) + + # polynomial features + poly_features = get_poly_features(x_scaled, offset_scaled) + wavelengths = coef @ poly_features + intercept + + return wavelengths + + +def wavecal_model_order2_poly(x, pwcpos, wavecal_meta: NIRISS_GR700XD_WAVECAL_META): + """compute order 2 wavelengths""" + x_scaler = partial( + min_max_scaler, + **{ + "x_min": wavecal_meta.scaler_data_min_[0], + "x_max": wavecal_meta.scaler_data_max_[0], + }, + ) + + pwcpos_offset_scaler = partial( + min_max_scaler, + **{ + "x_min": wavecal_meta.scaler_data_min_[1], + "x_max": wavecal_meta.scaler_data_max_[1], + }, + ) + + def get_poly_features(x: np.array, offset: np.array) -> np.ndarray: + """Polynomial features for the order 2 wavecal model""" + poly_features = np.array( + [ + x, + offset, + x**2, + x * offset, + offset**2, + x**3, + x**2 * offset, + x * offset**2, + offset**3, + ] + ) + return poly_features + + # coef and intercept + coef = wavecal_meta.coefficients + intercept = wavecal_meta.intercept + + # get pixel columns and then scaled + x_scaled = x_scaler(x) + + # offset + # offset = np.ones_like(x) * (pwcpos - PWCPOS_CMD) + # this will need to get changed later... + offset = np.ones_like(x) * pwcpos + offset_scaled = pwcpos_offset_scaler(offset) + + # polynomial features + poly_features = get_poly_features(x_scaled, offset_scaled) + wavelengths = coef @ poly_features + intercept + + return wavelengths