Skip to content

Commit

Permalink
Merge pull request #3167 from bmorris3/rampviz-nirspec-fix
Browse files Browse the repository at this point in the history
Fix for NIRSpec IRS2 readout mode
  • Loading branch information
bmorris3 authored Aug 29, 2024
2 parents 69e31b7 + 0e735d5 commit c46a0df
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 12 deletions.
2 changes: 1 addition & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ New Features

- The standalone version of jdaviz now uses solara instead of voila, resulting in faster load times. [#2909]

- New configuration for ramp/Level 1 data products from Roman WFI and JWST [#3120, #3148]
- New configuration for ramp/Level 1 data products from Roman WFI and JWST [#3120, #3148, #3167]

Cubeviz
^^^^^^^
Expand Down
50 changes: 42 additions & 8 deletions jdaviz/configs/rampviz/plugins/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,9 @@ def parse_data(app, file_obj, data_type=None, data_label=None,

elif isinstance(file_obj, Level1bModel):
metadata = standardize_metadata({
key: value for key, value in file_obj.to_flat_dict().items()
key: value for key, value in file_obj.to_flat_dict(
include_arrays=False)
.items()
if key.startswith('meta')
})

Expand All @@ -152,7 +154,7 @@ def parse_data(app, file_obj, data_type=None, data_label=None,
raise NotImplementedError(f'Unsupported data format: {file_obj}')


def _swap_axes(x):
def move_group_axis_last(x):
# swap axes per the conventions of ramp cubes
# (group axis comes first) and the default in
# rampviz (group axis expected last)
Expand Down Expand Up @@ -189,23 +191,27 @@ def _roman_3d_to_glue_data(
ramp_diff_data_label = f"{data_label}[DIFF]"

# load these cubes into the cache:
app._jdaviz_helper.cube_cache[ramp_cube_data_label] = NDDataArray(_swap_axes(data))
app._jdaviz_helper.cube_cache[ramp_diff_data_label] = NDDataArray(_swap_axes(diff_data))
app._jdaviz_helper.cube_cache[ramp_cube_data_label] = NDDataArray(
move_group_axis_last(data)
)
app._jdaviz_helper.cube_cache[ramp_diff_data_label] = NDDataArray(
move_group_axis_last(diff_data)
)

if meta is not None:
meta = standardize_roman_metadata(file_obj)

# load these cubes into the app:
_parse_ndarray(
app,
file_obj=_swap_axes(data),
file_obj=move_group_axis_last(data),
data_label=ramp_cube_data_label,
viewer_reference_name=group_viewer_reference_name,
meta=meta
)
_parse_ndarray(
app,
file_obj=_swap_axes(diff_data),
file_obj=move_group_axis_last(diff_data),
data_label=ramp_diff_data_label,
viewer_reference_name=diff_viewer_reference_name,
meta=meta
Expand Down Expand Up @@ -263,6 +269,27 @@ def _parse_hdulist(
def _parse_ramp_cube(app, ramp_cube_data, flux_unit, file_name,
group_viewer_reference_name, diff_viewer_reference_name,
meta=None):

# Identify NIRSpec IRS2 detector mode, which needs special treatment.
# jdox: https://jwst-docs.stsci.edu/jwst-near-infrared-spectrograph/nirspec-instrumentation/
# nirspec-detectors/nirspec-detector-readout-modes-and-patterns/nirspec-irs2-detector-readout-mode
if 'meta.model_type' in meta:
# this is a Level1bModel, which has metadata in a Node rather
# than a dictionary:
from_jwst_nirspec_irs2 = (
meta.get('meta._primary_header.TELESCOP') == 'JWST' and
meta.get('meta._primary_header.INSTRUME') == 'NIRSPEC' and
'IRS2' in meta.get('meta._primary_header.READPATT', '')
)
else:
# assume this was parsed from FITS:
header = meta.get('_primary_header', {})
from_jwst_nirspec_irs2 = (
header.get('TELESCOP') == 'JWST' and
header.get('INSTRUME') == 'NIRSPEC' and
'IRS2' in header.get('READPATT', '')
)

# last axis is the group axis, first two are spatial axes:
diff_data = np.vstack([
# begin with a group of zeros, so
Expand All @@ -271,8 +298,15 @@ def _parse_ramp_cube(app, ramp_cube_data, flux_unit, file_name,
np.diff(ramp_cube_data, axis=0)
])

ramp_cube = NDDataArray(_swap_axes(ramp_cube_data), unit=flux_unit, meta=meta)
diff_cube = NDDataArray(_swap_axes(diff_data), unit=flux_unit, meta=meta)
if from_jwst_nirspec_irs2:
# JWST/NIRSpec in IRS2 readout needs an additional axis swap for x and y:
def move_axes(x):
return np.swapaxes(move_group_axis_last(x), 0, 1)
else:
move_axes = move_group_axis_last

ramp_cube = NDDataArray(move_axes(ramp_cube_data), unit=flux_unit, meta=meta)
diff_cube = NDDataArray(move_axes(diff_data), unit=flux_unit, meta=meta)

group_data_label = app.return_data_label(file_name, ext="DATA")
diff_data_label = app.return_data_label(file_name, ext="DIFF")
Expand Down
45 changes: 45 additions & 0 deletions jdaviz/configs/rampviz/tests/test_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@


def test_load_rectangular_ramp(rampviz_helper, jwst_level_1b_rectangular_ramp):
rampviz_helper.load_data(jwst_level_1b_rectangular_ramp)

# drop the integration axis
original_cube_shape = jwst_level_1b_rectangular_ramp.shape[1:]

# on ramp cube load (1), the parser loads a diff cube (2) and
# the ramp extraction plugin produces a default extraction (3):
assert len(rampviz_helper.app.data_collection) == 3

parsed_cube_shape = rampviz_helper.app.data_collection[0].shape
assert parsed_cube_shape == (
original_cube_shape[2], original_cube_shape[1], original_cube_shape[0]
)


def test_load_nirspec_irs2(rampviz_helper, jwst_level_1b_rectangular_ramp):
# update the Level1bModel to have the header cards that are
# expected for an exposure from NIRSpec in IRS2 readout mode
jwst_level_1b_rectangular_ramp.update(
{
'meta': {
'_primary_header': {
"TELESCOP": "JWST",
"INSTRUME": "NIRSPEC",
"READPATT": "NRSIRS2"
}
}
}
)
rampviz_helper.load_data(jwst_level_1b_rectangular_ramp)

# drop the integration axis
original_cube_shape = jwst_level_1b_rectangular_ramp.shape[1:]

# on ramp cube load (1), the parser loads a diff cube (2) and
# the ramp extraction plugin produces a default extraction (3):
assert len(rampviz_helper.app.data_collection) == 3

parsed_cube_shape = rampviz_helper.app.data_collection[0].shape
assert parsed_cube_shape == (
original_cube_shape[1], original_cube_shape[2], original_cube_shape[0]
)
14 changes: 11 additions & 3 deletions jdaviz/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,22 +70,30 @@ def roman_level_1_ramp():
return data_model


@pytest.fixture
def jwst_level_1b_ramp():
def _make_jwst_ramp(shape=(1, 10, 25, 25)):
from stdatamodels.jwst.datamodels import Level1bModel

rng = np.random.default_rng(seed=42)

# JWST Level 1b ramp files have an additional preceding dimension
# compared with Roman. This dimension is the integration number
# in a sequence (if there's more than one in the visit).
shape = (1, 10, 25, 25)
data_model = Level1bModel(shape)
data_model.data = 100 + 3 * np.cumsum(rng.uniform(size=shape), axis=0)

return data_model


@pytest.fixture
def jwst_level_1b_ramp():
return _make_jwst_ramp()


@pytest.fixture
def jwst_level_1b_rectangular_ramp():
return _make_jwst_ramp(shape=(1, 10, 32, 25))


@pytest.fixture
def image_2d_wcs():
return WCS({'CTYPE1': 'RA---TAN', 'CUNIT1': 'deg', 'CDELT1': -0.0002777777778,
Expand Down

0 comments on commit c46a0df

Please sign in to comment.