diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000000..7e5a9ce091
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,250 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000000..105ce2da2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/jwst.iml b/.idea/jwst.iml
new file mode 100644
index 0000000000..524fd84d4c
--- /dev/null
+++ b/.idea/jwst.iml
@@ -0,0 +1,15 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000000..558e5ddba0
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000000..79fc4059f9
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000000..94a25f7f4c
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/workspace.xml b/.idea/workspace.xml
new file mode 100644
index 0000000000..880d93f233
--- /dev/null
+++ b/.idea/workspace.xml
@@ -0,0 +1,37 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 1681931910213
+
+
+ 1681931910213
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/jwst/extract_1d/extract_1d_step.py b/jwst/extract_1d/extract_1d_step.py
index 44c8f71ca2..ac6b99a7dc 100644
--- a/jwst/extract_1d/extract_1d_step.py
+++ b/jwst/extract_1d/extract_1d_step.py
@@ -426,8 +426,7 @@ def process(self, input):
return input_model
# Load reference files.
- spectrace_ref_name = self.get_reference_file(input_model, 'spectrace')
- wavemap_ref_name = self.get_reference_file(input_model, 'wavemap')
+ pastasoss_ref_name = self.get_reference_file(input_model, 'pastasoss') # Add support for "pastasoss" reffile
specprofile_ref_name = self.get_reference_file(input_model, 'specprofile')
speckernel_ref_name = self.get_reference_file(input_model, 'speckernel')
@@ -452,8 +451,7 @@ def process(self, input):
# Run the extraction.
result, ref_outputs, atoca_outputs = soss_extract.run_extract1d(
input_model,
- spectrace_ref_name,
- wavemap_ref_name,
+ pastasoss_ref_name,
specprofile_ref_name,
speckernel_ref_name,
subarray,
diff --git a/jwst/extract_1d/soss_extract/soss_extract.py b/jwst/extract_1d/soss_extract/soss_extract.py
index 1a23994197..127b6a53e7 100644
--- a/jwst/extract_1d/soss_extract/soss_extract.py
+++ b/jwst/extract_1d/soss_extract/soss_extract.py
@@ -13,6 +13,7 @@
from .soss_syscor import make_background_mask, soss_background
from .soss_solver import solve_transform, transform_wavemap, transform_profile, transform_coords
from .atoca import ExtractionEngine, MaskOverlapError
+from .soss_wavemaps import get_soss_wavemaps
from .atoca_utils import (ThroughputSOSS, WebbKernel, grid_from_map, mask_bad_dispersion_direction,
make_combined_adaptive_grid, get_wave_p_or_m, oversample_grid)
from .soss_boxextract import get_box_weights, box_extract, estim_error_nearest_data
@@ -21,15 +22,12 @@
log.setLevel(logging.DEBUG)
-def get_ref_file_args(ref_files, transform):
+def get_ref_file_args(ref_files):
"""Prepare the reference files for the extraction engine.
Parameters
----------
ref_files : dict
A dictionary of the reference file DataModels.
- transform : array or list
- A 3-element array describing the rotation and translation to apply
- to the reference files in order to match the observation.
Returns
-------
@@ -39,32 +37,18 @@ def get_ref_file_args(ref_files, transform):
"""
# The wavelength maps for order 1 and 2.
- wavemap_ref = ref_files['wavemap']
+ pastasoss_ref = ref_files['pastasoss']
+ pwcpos = pastasoss_ref.map[0].pwcpos # Make sure the pastasoss_ref object has these attrs
+ subarray = pastasoss_ref.map[0].subarray # Make sure the pastasoss_ref object has these attrs
+ pad = pastasoss_ref.map[0].padding # Make sure the pastasoss_ref object has these attrs
- ovs = wavemap_ref.map[0].oversampling
- pad = wavemap_ref.map[0].padding
+ # Use pastasoss to generate the appropriate wavemap for the given PWCPOS
+ (wavemap_o1, wavemap_o2), (spectrace_o1, spectrace_o2) = get_soss_wavemaps(pwcpos=pwcpos, subarray=subarray, padding=True, padsize=pad)
- wavemap_o1 = transform_wavemap(transform, wavemap_ref.map[0].data, ovs, pad)
- wavemap_o2 = transform_wavemap(transform, wavemap_ref.map[1].data, ovs, pad)
-
- # Make sure all pixels follow the expected direction of the dispersion
- wavemap_o1, flag_o1 = mask_bad_dispersion_direction(wavemap_o1)
- wavemap_o2, flag_o2 = mask_bad_dispersion_direction(wavemap_o2)
-
- # Warn if not all pixels were corrected
- msg_warning = 'Some pixels in order {} do not follow the expected dispersion axis'
- if not flag_o1:
- log.warning(msg_warning.format(1))
- if not flag_o2:
- log.warning(msg_warning.format(2))
-
- # The spectral profiles for order 1 and 2.
+ # The spectral profiles for order 1 and 2 (no transform necessary using pastasoss)
specprofile_ref = ref_files['specprofile']
- ovs = specprofile_ref.profile[0].oversampling
- pad = specprofile_ref.profile[0].padding
-
- specprofile_o1 = transform_profile(transform, specprofile_ref.profile[0].data, ovs, pad, norm=False)
- specprofile_o2 = transform_profile(transform, specprofile_ref.profile[1].data, ovs, pad, norm=False)
+ specprofile_o1 = specprofile_ref.profile[0].data
+ specprofile_o2 = specprofile_ref.profile[1].data
# The throughput curves for order 1 and 2.
spectrace_ref = ref_files['spectrace']
@@ -106,18 +90,18 @@ def get_ref_file_args(ref_files, transform):
return [wavemap_o1, wavemap_o2], [specprofile_o1, specprofile_o2], [throughput_o1, throughput_o2], [kernels_o1, kernels_o2]
-def get_trace_1d(ref_files, transform, order, cols=None):
+def get_trace_1d(ref_files, order, transform=None, cols=None):
"""Get the x, y, wavelength of the trace after applying the transform.
Parameters
----------
ref_files : dict
A dictionary of the reference file DataModels.
- transform : array or list
+ order : int
+ The spectral order for which to return the trace parameters.
+ transform : array or list, optional
A 3-element list or array describing the rotation and translation
to apply to the reference files in order to match the
observation.
- order : int
- The spectral order for which to return the trace parameters.
cols : array[int], optional
The columns on the detector for which to compute the trace
parameters. If not given, all columns will be computed.
@@ -133,21 +117,30 @@ def get_trace_1d(ref_files, transform, order, cols=None):
else:
xtrace = cols
- spectrace_ref = ref_files['spectrace']
+ pastasoss_ref = ref_files['pastasoss']
# Read x, y, wavelength for the relevant order.
- xref = spectrace_ref.trace[order - 1].data['X']
- yref = spectrace_ref.trace[order - 1].data['Y']
- waveref = spectrace_ref.trace[order - 1].data['WAVELENGTH']
+ xref = pastasoss_ref.trace[order - 1].data['X']
+ yref = pastasoss_ref.trace[order - 1].data['Y']
+ waveref = pastasoss_ref.trace[order - 1].data['WAVELENGTH']
+
+ # No transform necessary is using pastasoss
+ if transform is None:
+ xtrace = xref
+ ytrace = yref
+ wavetrace = waveref
+
+ # Transform the trace if using a static PWCPOS
+ else:
- # Rotate and shift the positions based on transform.
- angle, xshift, yshift = transform
- xrot, yrot = transform_coords(angle, xshift, yshift, xref, yref)
+ # Rotate and shift the positions based on transform.
+ angle, xshift, yshift = transform
+ xrot, yrot = transform_coords(angle, xshift, yshift, xref, yref)
- # Interpolate y and wavelength to the requested columns.
- sort = np.argsort(xrot)
- ytrace = np.interp(xtrace, xrot[sort], yrot[sort])
- wavetrace = np.interp(xtrace, xrot[sort], waveref[sort])
+ # Interpolate y and wavelength to the requested columns.
+ sort = np.argsort(xrot)
+ ytrace = np.interp(xtrace, xrot[sort], yrot[sort])
+ wavetrace = np.interp(xtrace, xrot[sort], waveref[sort])
return xtrace, ytrace, wavetrace
@@ -1015,7 +1008,7 @@ def extract_image(decontaminated_data, scierr, scimask, box_weights, bad_pix='mo
return fluxes, fluxerrs, npixels
-def run_extract1d(input_model, spectrace_ref_name, wavemap_ref_name,
+def run_extract1d(input_model, pastasoss_ref_name,
specprofile_ref_name, speckernel_ref_name, subarray,
soss_filter, soss_kwargs):
"""Run the spectral extraction on NIRISS SOSS data.
@@ -1023,10 +1016,8 @@ def run_extract1d(input_model, spectrace_ref_name, wavemap_ref_name,
----------
input_model : DataModel
The input DataModel.
- spectrace_ref_name : str
- Name of the spectrace reference file.
- wavemap_ref_name : str
- Name of the wavemap reference file.
+ pastasoss_ref_name : str
+ Name of the pastasoss reference file.
specprofile_ref_name : str
Name of the specprofile reference file.
speckernel_ref_name : str
@@ -1052,14 +1043,12 @@ def run_extract1d(input_model, spectrace_ref_name, wavemap_ref_name,
order_str_2_int = {f'Order {order}': order for order in [1, 2, 3]}
# Read the reference files.
- spectrace_ref = datamodels.SpecTraceModel(spectrace_ref_name)
- wavemap_ref = datamodels.WaveMapModel(wavemap_ref_name)
+ pastasoss_ref = datamodels.PastasossModel(pastasoss_ref_name)
specprofile_ref = datamodels.SpecProfileModel(specprofile_ref_name)
speckernel_ref = datamodels.SpecKernelModel(speckernel_ref_name)
ref_files = dict()
- ref_files['spectrace'] = spectrace_ref
- ref_files['wavemap'] = wavemap_ref
+ ref_files['pastasoss'] = pastasoss_ref
ref_files['specprofile'] = specprofile_ref
ref_files['speckernel'] = speckernel_ref
diff --git a/jwst/extract_1d/soss_extract/soss_traces.py b/jwst/extract_1d/soss_extract/soss_traces.py
new file mode 100644
index 0000000000..d1f17f10f4
--- /dev/null
+++ b/jwst/extract_1d/soss_extract/soss_traces.py
@@ -0,0 +1,219 @@
+# Module to predict SOSS trace positions for a given spectral order(s)
+# given the a pupil wheel position angle taken from the "PWCPOS" fits
+# header keyword.
+
+from dataclasses import dataclass
+from typing import Tuple
+
+import numpy as np
+from scipy.interpolate import interp1d
+from pkg_resources import resource_filename
+
+from pastasoss.wavecal import get_wavecal_meta_for_spectral_order
+from pastasoss.wavecal import get_wavelengths
+
+PWCPOS_CMD = 245.7600 # Commanded PWCPOS for the GR700XD
+
+
+# TODO: order 3 currently unsupport ATM. Will be support in the future: TBD
+REFERENCE_TRACE_FILES = {
+ "order1": "jwst_niriss_gr700xd_order1_trace_refmodel.txt",
+ "order2": "jwst_niriss_gr700xd_order2_trace_refmodel_002.txt",
+ # order 3 currently unsupport ATM. Will be support in the future: TBD
+}
+
+REFERENCE_WAVECAL_MODELS = {
+ "order1": resource_filename(
+ __name__, "data/jwst_niriss_gr700xd_wavelength_model_order1.json"
+ ),
+ "order2": resource_filename(
+ __name__, "data/jwst_niriss_gr700xd_wavelength_model_order2_002.json"
+ ),
+}
+
+
+@dataclass
+class TraceModel:
+ order: str
+ x: np.ndarray
+ y: np.ndarray
+ wavelength: np.ndarray
+
+
+def rotate(
+ x: np.ndarray,
+ y: np.ndarray,
+ angle: float,
+ origin: Tuple[float, float] = (0, 0),
+ interp: bool = True,
+) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Applies a rotation transformation to a set of 2D points.
+
+ Parameters
+ ----------
+ x : np.ndarray
+ The x-coordinates of the points to be transformed.
+ y : np.ndarray
+ The y-coordinates of the points to be transformed.
+ angle : float
+ The angle (in degrees) by which to rotate the points.
+ origin : Tuple[float, float], optional
+ The point about which to rotate the points. Default is (0, 0).
+ interp : bool, optional
+ Whether to interpolate the rotated positions onto the original x-pixel
+ column values. Default is True.
+
+ Returns
+ -------
+ Tuple[np.ndarray, np.ndarray]
+ The x and y coordinates of the rotated points.
+
+ Examples
+ --------
+ >>> x = np.array([0, 1, 2, 3])
+ >>> y = np.array([0, 1, 2, 3])
+ >>> x_rot, y_rot = rotate(x, y, 90)
+ """
+
+ # shift to rotate about center
+ xy_center = np.atleast_2d(origin).T
+ xy = np.vstack([x, y])
+
+ # Rotation transform matrix
+ radians = np.radians(angle)
+ c, s = np.cos(radians), np.sin(radians)
+ R = np.array([[c, -s], [s, c]])
+
+ # apply transformation
+ x_new, y_new = R @ (xy - xy_center) + xy_center
+
+ # interpolate rotated positions onto x-pixel column values (default)
+ if interp:
+ # interpolate new coordinates onto original x values and mask values
+ # outside of the domain of the image 0<=x<=2047 and 0<=y<=255.
+ y_new = interp1d(x_new, y_new, fill_value="extrapolate")(x)
+ mask = np.where(y_new <= 255.0)
+ x = x[mask]
+ y_new = y_new[mask]
+ return x, y_new
+
+ return x_new, y_new
+
+
+def get_reference_trace(
+ file: str,
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """
+ Load in the reference trace positions given a file associated with a given
+ spectral order.
+
+ Parameters
+ ----------
+ file : str
+ The path to the file containing the reference trace positions.
+
+ Returns
+ -------
+ Tuple[np.ndarray, np.ndarray, np.ndarray]
+ A tuple containing the x, y positions and origin of the reference
+ traces.
+
+ Examples
+ --------
+ >>> x, y, origin = get_reference_traces_positions('ref_filename.txt')
+ """
+ filepath = resource_filename(__name__, f"data/{file}")
+ traces = np.loadtxt(filepath)
+ origin = traces[0]
+ x = traces[1:, 0]
+ y = traces[1:, 1]
+ return x, y, origin
+
+
+def get_soss_traces(
+ pwcpos: float, order: str = "123", subarray: str = 'SUBSTRIP256', interp: bool = True
+) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ This is the primary method for generate the gr700xd trace position given a
+ pupil wheel positions angle provided in the FITS header under keyword
+ PWCPOS. The traces for a given spectral order are found by perform a
+ rotation transformation using the refence trace positions at the commanded
+ PWCPOS=245.76 degrees. This methods yield sub-pixel performance and will be
+ improved upon in later interations as more NIRISS/SOSS observations become
+ available.
+
+ Parameters
+ ----------
+ pwcpos : float
+ The pupil wheel positions angle provided in the FITS header under
+ keyword PWCPOS.
+ order : str, optional
+ The spectral order to compute the new traces for. Default is '123'.
+ Support for order 3 will be added at a later date.
+ subarray : str
+ The subarray being used, ['SUBSTRIP96', 'SUBSTRIP256']
+ interp : bool, optional
+ Whether to interpolate the rotated positions onto the original x-pixel
+ column values. Default is True.
+
+ Returns
+ -------
+ Tuple[np.ndarray, np.ndarray]]
+ If `order` is '1', a tuple of the x and y coordinates of the rotated
+ points for the first spectral order.
+ If `order` is '2', a tuple of the x and y coordinates of the rotated
+ points for the second spectral order.
+ If `order` is '3' or a combination of '1', '2', and '3', a list of
+ tuples of the x and y coordinates of the rotated points for each
+ spectral order.
+
+ Raises
+ ------
+ ValueError
+ If `order` is not '1', '2', '3', or a combination of '1', '2', and '3'.
+
+ Examples
+ --------
+ >>> x_new, y_new = get_trace_from_reference_transform(2.3)
+ """
+
+ norders = len(order)
+
+ if norders > 3:
+ raise ValueError("order must be: 1,2, 3.")
+ if norders > 1:
+ # recursively compute the new traces for each order
+ return [get_soss_traces(pwcpos, m, subarray) for m in order]
+
+ # This might be an alternative way of writing this
+ # if 'order'+order in REFERENCE_TRACE_FILES.keys():
+ # ref_file = REFERENCE_TRACE_FILES["order"+order]
+ # This section can definite be refactored in a later version
+ elif order == "1":
+ ref_trace_file = REFERENCE_TRACE_FILES["order1"]
+ wave_cal_model_meta = get_wavecal_meta_for_spectral_order("order1")
+
+ elif order == "2":
+ ref_trace_file = REFERENCE_TRACE_FILES["order2"]
+ wave_cal_model_meta = get_wavecal_meta_for_spectral_order("order2")
+
+ elif order == "3":
+ print("The software currently does not support order 3 at this time.")
+ return None
+
+ # reference trace data
+ x, y, origin = get_reference_trace(ref_trace_file)
+
+ # Offset for SUBSTRIP96
+ if subarray == 'SUBSTRIP96':
+ y -= 10
+
+ # rotated reference trace
+ x_new, y_new = rotate(x, y, pwcpos - PWCPOS_CMD, origin, interp=interp)
+
+ # wavelength associated to trace at given pwcpos value
+ wavelengths = get_wavelengths(x_new, pwcpos, wave_cal_model_meta)
+
+ # return x_new, y_new, wavelengths
+ return TraceModel(order, x_new, y_new, wavelengths)
diff --git a/jwst/extract_1d/soss_extract/soss_wavemaps.py b/jwst/extract_1d/soss_extract/soss_wavemaps.py
new file mode 100644
index 0000000000..0429b97c34
--- /dev/null
+++ b/jwst/extract_1d/soss_extract/soss_wavemaps.py
@@ -0,0 +1,220 @@
+"""
+Module to generate 2D SOSS wavemap arrays from 1D wavelength solutions
+
+Author: Joe Filippazzo
+Date: 02/28/2024
+Usage:
+import pastasoss
+pwcpos = 245.8
+wavemaps = pastasoss.get_soss_wavemaps(pwcpos)
+"""
+
+import numpy as np
+
+from pastasoss.soss_traces import get_soss_traces, PWCPOS_CMD
+
+
+def extrapolate_to_wavegrid(w_grid, wavelength, quantity):
+ """
+ Extrapolates quantities on the right and the left of a given array of quantity
+
+ Parameters
+ ----------
+ w_grid : sequence
+ The wavelength grid to interpolate onto
+ wavelength : sequence
+ The native wavelength values of the data
+ quantity : sequence
+ The data to interpolate
+
+ Returns
+ -------
+ Array
+ The interpolated quantities
+ """
+ sorted = np.argsort(wavelength)
+ q = quantity[sorted]
+ w = wavelength[sorted]
+
+ # Determine the slope on the right of the array
+ slope_right = (q[-1] - q[-2]) / (w[-1] - w[-2])
+ # extrapolate at wavelengths larger than the max on the right
+ indright = np.where(w_grid > w[-1])[0]
+ q_right = q[-1] + (w_grid[indright] - w[-1]) * slope_right
+ # Determine the slope on the left of the array
+ slope_left = (q[1] - q[0]) / (w[1] - w[0])
+ # extrapolate at wavelengths smaller than the min on the left
+ indleft = np.where(w_grid < w[0])[0]
+ q_left = q[0] + (w_grid[indleft] - w[0]) * slope_left
+ # Construct and extrapolated array of the quantity
+ w = np.concatenate((w_grid[indleft], w, w_grid[indright]))
+ q = np.concatenate((q_left, q, q_right))
+
+ # resample at the w_grid everywhere
+ q_grid = np.interp(w_grid, w, q)
+
+ return q_grid
+
+
+def calc_2d_wave_map(wave_grid, x_dms, y_dms, tilt, oversample=2, padding=10, maxiter=5, dtol=1e-2):
+ """Compute the 2D wavelength map on the detector.
+
+ Parameters
+ ----------
+ wave_grid : sequence
+ The wavelength corresponding to the x_dms, y_dms, and tilt values.
+ x_dms : sequence
+ The trace x position on the detector in DMS coordinates.
+ y_dms : sequence
+ The trace y position on the detector in DMS coordinates.
+ tilt : sequence
+ The trace tilt angle in degrees.
+ oversample : int
+ The oversampling factor of the input coordinates.
+ padding : int
+ The native pixel padding around the edge of the detector.
+ maxiter : int
+ The maximum number of iterations used when solving for the wavelength at each pixel.
+ dtol : float
+ The tolerance of the iterative solution in pixels.
+
+ Returns
+ -------
+ Array
+ An array containing the wavelength at each pixel on the detector.
+ """
+ os = np.copy(oversample)
+ xpad = np.copy(padding)
+ ypad = np.copy(padding)
+
+ # No need to compute wavelengths across the entire detector, slightly larger than SUBSTRIP256 will do.
+ dimx, dimy = 2048, 300
+ y_dms = y_dms + (dimy - 2048) # Adjust y-coordinate to area of interest.
+
+ # Generate the oversampled grid of pixel coordinates.
+ x_vec = np.arange((dimx + 2 * xpad) * os) / os - (os - 1) / (2 * os) - xpad
+ y_vec = np.arange((dimy + 2 * ypad) * os) / os - (os - 1) / (2 * os) - ypad
+ x_grid, y_grid = np.meshgrid(x_vec, y_vec)
+
+ # Iteratively compute the wavelength at each pixel.
+ delta_x = 0.0 # A shift in x represents a shift in wavelength.
+ for niter in range(maxiter):
+
+ # Assume all y have same wavelength.
+ wave_iterated = np.interp(x_grid - delta_x, x_dms[::-1],
+ wave_grid[::-1]) # Invert arrays to get increasing x.
+
+ # Compute the tilt angle at the wavelengths.
+ tilt_tmp = np.interp(wave_iterated, wave_grid, tilt)
+
+ # Compute the trace position at the wavelengths.
+ x_estimate = np.interp(wave_iterated, wave_grid, x_dms)
+ y_estimate = np.interp(wave_iterated, wave_grid, y_dms)
+
+ # Project that back to pixel coordinates.
+ x_iterated = x_estimate + (y_grid - y_estimate) * np.tan(np.deg2rad(tilt_tmp))
+
+ # Measure error between requested and iterated position.
+ delta_x = delta_x + (x_iterated - x_grid)
+
+ # If the desired precision has been reached end iterations.
+ if np.all(np.abs(x_iterated - x_grid) < dtol):
+ break
+
+ # Evaluate the final wavelength map, this time setting out-of-bounds values to NaN.
+ wave_map_2d = np.interp(x_grid - delta_x, x_dms[::-1], wave_grid[::-1], left=np.nan, right=np.nan)
+
+ # Extend to full detector size.
+ tmp = np.full((os * (dimx + 2 * xpad), os * (dimx + 2 * xpad)), fill_value=np.nan)
+ tmp[-os * (dimy + 2 * ypad):] = wave_map_2d
+ wave_map_2d = tmp
+
+ return wave_map_2d
+
+
+def get_soss_wavemaps(pwcpos=PWCPOS_CMD, subarray='SUBSTRIP256', padding=False, padsize=20, spectraces=False):
+ """
+ Generate order 1 and 2 2D wavemaps from the rotated SOSS trace positions
+
+ Parameters
+ ----------
+ pwcpos : float
+ The pupil wheel position
+ subarray: str
+ The subarray name, ['FULL', 'SUBSTRIP256', 'SUBSTRIP96']
+ padding : bool
+ Include padding on map edges (only needed for reference files)
+ padsize: int
+ The size of the padding to include on each side
+ spectraces : bool
+ Return the interpolated spectraces as well
+
+ Returns
+ -------
+ Array, Array
+ The 2D wavemaps and corresponding 1D spectraces
+ """
+ traces_order1, traces_order2 = get_soss_traces(pwcpos=pwcpos, order='12', subarray=subarray, interp=True)
+
+ # Make wavemap from trace center wavelengths, padding to shape (296, 2088)
+ wavemin = 0.5
+ wavemax = 5.5
+ nwave = 5001
+ wave_grid = np.linspace(wavemin, wavemax, nwave)
+
+ # Extrapolate wavelengths for order 1 trace
+ xtrace_order1 = extrapolate_to_wavegrid(wave_grid, traces_order1.wavelength, traces_order1.x)
+ ytrace_order1 = extrapolate_to_wavegrid(wave_grid, traces_order1.wavelength, traces_order1.y)
+ spectrace_1 = np.array([xtrace_order1, ytrace_order1, wave_grid])
+
+ # Set cutoff for order 2 where it runs off the detector
+ o2_cutoff = 1783
+ w_o2_tmp = traces_order2.wavelength[:o2_cutoff]
+ w_o2 = np.zeros(2040) * np.nan
+ w_o2[:o2_cutoff] = w_o2_tmp
+ y_o2_tmp = traces_order2.y[:o2_cutoff]
+ y_o2 = np.zeros(2040) * np.nan
+ y_o2[:o2_cutoff] = y_o2_tmp
+ x_o2 = np.copy(traces_order1.x)
+
+ # Fill for column > 1400 with linear extrapolation
+ m = w_o2[o2_cutoff - 1] - w_o2[o2_cutoff - 2]
+ dx = np.arange(2040 - o2_cutoff) + 1
+ w_o2[o2_cutoff:] = w_o2[o2_cutoff - 1] + m * dx
+ m = y_o2[o2_cutoff - 1] - y_o2[o2_cutoff - 2]
+ dx = np.arange(2040 - o2_cutoff) + 1
+ y_o2[o2_cutoff:] = y_o2[o2_cutoff - 1] + m * dx
+
+ # Extrapolate wavelengths for order 2 trace
+ xtrace_order2 = extrapolate_to_wavegrid(wave_grid, w_o2, x_o2)
+ ytrace_order2 = extrapolate_to_wavegrid(wave_grid, w_o2, y_o2)
+ spectrace_2 = np.array([xtrace_order2, ytrace_order2, wave_grid])
+
+ # Make wavemap from wavelength solution for order 1
+ wavemap_1 = calc_2d_wave_map(wave_grid, xtrace_order1, ytrace_order1, np.zeros_like(xtrace_order1), oversample=1, padding=padsize)
+
+ # Make wavemap from wavelength solution for order 2
+ wavemap_2 = calc_2d_wave_map(wave_grid, xtrace_order2, ytrace_order2, np.zeros_like(xtrace_order2), oversample=1, padding=padsize)
+
+ # Extrapolate wavemap to FULL frame
+ wavemap_1[:-256 - padsize, :] = wavemap_1[-256 - padsize]
+ wavemap_2[:-256 - padsize, :] = wavemap_2[-256 - padsize]
+
+ # Trim to subarray
+ if subarray == 'SUBSTRIP256':
+ wavemap_1 = wavemap_1[1792 - padsize:2048 + padsize, :]
+ wavemap_2 = wavemap_2[1792 - padsize:2048 + padsize, :]
+ if subarray == 'SUBSTRIP96':
+ wavemap_1 = wavemap_1[1792 - padsize:1792 + 96 + padsize, :]
+ wavemap_2 = wavemap_2[1792 - padsize:1792 + 96 + padsize, :]
+
+ # Remove padding if necessary
+ if not padding:
+ wavemap_1 = wavemap_1[padsize:-padsize, padsize:-padsize]
+ wavemap_2 = wavemap_2[padsize:-padsize, padsize:-padsize]
+
+ if spectraces:
+ return np.array([wavemap_1, wavemap_2]), np.array([spectrace_1, spectrace_2])
+
+ else:
+ return np.array([wavemap_1, wavemap_2])
diff --git a/jwst/extract_1d/soss_extract/wavecal.py b/jwst/extract_1d/soss_extract/wavecal.py
new file mode 100644
index 0000000000..f885c6f28f
--- /dev/null
+++ b/jwst/extract_1d/soss_extract/wavecal.py
@@ -0,0 +1,290 @@
+# Wavecal module for handling the wavelength calibration model. For now we
+# will use the reference json files that has all the model information. Ideally
+# we would like to use just the train model for production but for now this is
+# this is fine.
+
+import json
+from dataclasses import dataclass
+from functools import partial
+from typing import Any, Dict
+
+import numpy as np
+import numpy.typing as npt
+from pkg_resources import resource_filename
+
+# explicitly defining the commanded position here as well (this is temp)
+PWCPOS_CMD = 245.76
+
+# # TODO: order 3 currently unsupport ATM. Will be support in the future: TBD
+# REFERENCE_WAVECAL_MODELS = {
+# "order1": resource_filename(
+# __name__, "data/jwst_niriss_gr700xd_wavelength_model_order1.json"
+# ),
+# "order2": resource_filename(
+# __name__, "data/jwst_niriss_gr700xd_wavelength_model_order2_002.json"
+# ),
+# }
+
+
+# dataclass object to store the metadata from a wavecal model
+@dataclass
+class NIRISS_GR700XD_WAVECAL_META:
+ """
+ Datamodel object container for the wavecal meta data in the json reference
+ file.
+ """
+
+ order: str
+ coefficients: npt.NDArray[np.float64]
+ intercept: npt.NDArray[np.float64]
+ scaler_data_min_: float
+ scaler_data_max_: float
+ poly_degree: int
+
+
+def load_wavecal_model(filename: str) -> Dict[str, Any]:
+ """
+ Load a wavecalibration model from a JSON file.
+
+ Parameters
+ ----------
+ filename : str
+ The path to the JSON file containing the wavecalibration model.
+
+ Returns
+ -------
+ dict
+ A dictionary representing the loaded wavecalibration model.
+
+ Notes
+ -----
+ This function reads the specified JSON file and returns its contents as a
+ dictionary, which typically contains information about a wavecalibration
+ model used in data analysis.
+ """
+ with open(filename, "r") as file:
+ wavecal_model = json.load(file)
+ return wavecal_model
+
+
+def get_wavecal_meta_for_spectral_order(
+ order: str,
+) -> NIRISS_GR700XD_WAVECAL_META:
+ """
+ Retrieve wavecalibration model metadata for a specific spectral order.
+
+ Parameters
+ ----------
+ order : str
+ The spectral order for which wavecalibration metadata is requested.
+ Valid options are 'order 1', 'order 2', or 'order 3'.
+
+ Returns
+ -------
+ NIRISS_GR700XD_WAVECAL_META
+ An object containing wavecalibration metadata, including order, model
+ coefficients, intercept, and scaler data bounds.
+
+ Raises
+ ------
+ ValueError
+ If the provided 'order' is not one of the valid options.
+
+ Notes
+ -----
+ This function retrieves wavecalibration metadata for the specified spectral
+ order. It first checks if the 'order' is valid and then loads the
+ corresponding wavecalibration model. The model's coefficients, intercept,
+ and scaler data bounds are extracted and returned as part of the metadata
+ object.
+ """
+ # get the reference wavecal file name
+ valid_orders = ['order 1', 'order 2']
+ if order not in valid_orders:
+ raise ValueError(f"valid orders are: {valid_orders}.")
+
+ # get the appropiate reference file name given the order
+ reference_filename = REFERENCE_WAVECAL_MODELS[order]
+
+ # load in the model
+ wavecal_model = load_wavecal_model(reference_filename)
+
+ # model coefficients
+ poly_degree = wavecal_model["model"]["poly_deg"]
+ coefficients = wavecal_model["model"]["coef"]
+ intercept = wavecal_model["model"]["intercept"]
+
+ # info for scaling inputs
+ scaler_data_min_ = wavecal_model["model"]["scaler"]["data_min_"]
+ scaler_data_max_ = wavecal_model["model"]["scaler"]["data_max_"]
+
+ return NIRISS_GR700XD_WAVECAL_META(
+ order, coefficients, intercept, scaler_data_min_, scaler_data_max_, poly_degree
+ )
+
+
+def get_wavelengths(
+ x: np.ndarray, pwcpos: float, wavecal_meta: NIRISS_GR700XD_WAVECAL_META
+) -> np.ndarray:
+ """Get the associated wavelength values for a given spectral order"""
+ if wavecal_meta.order == "order1":
+ wavelengths = wavecal_model_order1_poly(x, pwcpos, wavecal_meta)
+ elif wavecal_meta.order == "order2":
+ # raise NotImplementedError("Order 2 not implemented")
+ wavelengths = wavecal_model_order2_poly(x, pwcpos, wavecal_meta)
+ elif wavecal_meta.order == "order3":
+ raise ValueError("Order 3 not supported at this time")
+ else:
+ raise ValueError("not a valid order")
+
+ return wavelengths
+
+
+def min_max_scaler(x, x_min, x_max):
+ """
+ Apply min-max scaling to input values.
+
+ Parameters
+ ----------
+ x : float or numpy.ndarray
+ The input value(s) to be scaled.
+ x_min : float
+ The minimum value in the range to which 'x' will be scaled.
+ x_max : float
+ The maximum value in the range to which 'x' will be scaled.
+
+ Returns
+ -------
+ float or numpy.ndarray
+ The scaled value(s) in the range [0, 1].
+
+ Notes
+ -----
+ Min-max scaling is a data normalization technique that scales input values
+ 'x' to the range [0, 1] based on the provided minimum and maximum values,
+ 'x_min' and 'x_max'. This function is applicable to both individual values
+ and arrays of values. This function will use the min/max values from the
+ training data of the wavecal model.
+ """
+ # scaling the input x values
+ x_scaled = (x - x_min) / (x_max - x_min)
+ return x_scaled
+
+
+def wavecal_model_order1_poly(x, pwcpos, wavecal_meta: NIRISS_GR700XD_WAVECAL_META):
+ """compute order 1 wavelengths"""
+ x_scaler = partial(
+ min_max_scaler,
+ **{
+ "x_min": wavecal_meta.scaler_data_min_[0],
+ "x_max": wavecal_meta.scaler_data_max_[0],
+ },
+ )
+
+ pwcpos_offset_scaler = partial(
+ min_max_scaler,
+ **{
+ "x_min": wavecal_meta.scaler_data_min_[1],
+ "x_max": wavecal_meta.scaler_data_max_[1],
+ },
+ )
+
+ def get_poly_features(x: np.array, offset: np.array) -> np.ndarray:
+ """polynomial features for the order 1 wavecal model"""
+ poly_features = np.array(
+ [
+ x,
+ offset,
+ x**2,
+ x * offset,
+ offset**2,
+ x**3,
+ x**2 * offset,
+ x * offset**2,
+ offset**3,
+ x**4,
+ x**3 * offset,
+ x**2 * offset**2,
+ x * offset**3,
+ offset**4,
+ x**5,
+ x**4 * offset,
+ x**3 * offset**2,
+ x**2 * offset**3,
+ x * offset**4,
+ offset**5,
+ ]
+ )
+ return poly_features
+
+ # extract model weights and intercept
+ coef = wavecal_meta.coefficients
+ intercept = wavecal_meta.intercept
+
+ # get pixel columns and then scaled
+ x_scaled = x_scaler(x)
+
+ # offset
+ offset = np.ones_like(x) * (pwcpos - PWCPOS_CMD)
+ offset_scaled = pwcpos_offset_scaler(offset)
+
+ # polynomial features
+ poly_features = get_poly_features(x_scaled, offset_scaled)
+ wavelengths = coef @ poly_features + intercept
+
+ return wavelengths
+
+
+def wavecal_model_order2_poly(x, pwcpos, wavecal_meta: NIRISS_GR700XD_WAVECAL_META):
+ """compute order 2 wavelengths"""
+ x_scaler = partial(
+ min_max_scaler,
+ **{
+ "x_min": wavecal_meta.scaler_data_min_[0],
+ "x_max": wavecal_meta.scaler_data_max_[0],
+ },
+ )
+
+ pwcpos_offset_scaler = partial(
+ min_max_scaler,
+ **{
+ "x_min": wavecal_meta.scaler_data_min_[1],
+ "x_max": wavecal_meta.scaler_data_max_[1],
+ },
+ )
+
+ def get_poly_features(x: np.array, offset: np.array) -> np.ndarray:
+ """Polynomial features for the order 2 wavecal model"""
+ poly_features = np.array(
+ [
+ x,
+ offset,
+ x**2,
+ x * offset,
+ offset**2,
+ x**3,
+ x**2 * offset,
+ x * offset**2,
+ offset**3,
+ ]
+ )
+ return poly_features
+
+ # coef and intercept
+ coef = wavecal_meta.coefficients
+ intercept = wavecal_meta.intercept
+
+ # get pixel columns and then scaled
+ x_scaled = x_scaler(x)
+
+ # offset
+ # offset = np.ones_like(x) * (pwcpos - PWCPOS_CMD)
+ # this will need to get changed later...
+ offset = np.ones_like(x) * pwcpos
+ offset_scaled = pwcpos_offset_scaler(offset)
+
+ # polynomial features
+ poly_features = get_poly_features(x_scaled, offset_scaled)
+ wavelengths = coef @ poly_features + intercept
+
+ return wavelengths