Merge pull request #3148 from bmorris3/rampviz-jwst-parser
JWST L1 ramp parser for Rampviz
bmorris3 authored Aug 26, 2024
2 parents 2160823 + 22546d9 commit a9b559b
Showing 8 changed files with 227 additions and 105 deletions.
2 changes: 2 additions & 0 deletions CHANGES.rst
@@ -20,6 +20,8 @@ New Features

- The standalone version of jdaviz now uses solara instead of voila, resulting in faster load times. [#2909]

- New configuration for ramp/Level 1 data products from Roman WFI and JWST. [#3120, #3148]

Cubeviz
^^^^^^^

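The new configuration noted in the changelog entry above can be exercised like this; a minimal sketch, assuming the Rampviz helper forwards parser keywords such as `integration` (the filename here is hypothetical):

```python
from jdaviz import Rampviz

rampviz = Rampviz()
# load the first integration from a (hypothetical) JWST Level 1b ramp file:
rampviz.load_data("jw02732001001_02101_00001_nrcb1_uncal.fits", integration=0)
rampviz.show()
```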
156 changes: 82 additions & 74 deletions jdaviz/configs/rampviz/plugins/parsers.py
@@ -4,10 +4,9 @@
import astropy.units as u
from astropy.io import fits
from astropy.nddata import NDData, NDDataArray
from stdatamodels.jwst.datamodels import Level1bModel

from jdaviz.core.registries import data_parser_registry
from jdaviz.utils import (
standardize_metadata, download_uri_to_path,
PRIHDR_KEY, standardize_roman_metadata
@@ -25,7 +24,8 @@

@data_parser_registry("ramp-data-parser")
def parse_data(app, file_obj, data_type=None, data_label=None,
parent=None, cache=None, local_path=None, timeout=None,
integration=0):
"""
Attempts to parse a data file and auto-populate available viewers in
rampviz.
@@ -53,6 +53,10 @@ def parse_data(app, file_obj, data_type=None, data_label=None,
remote requests in seconds (passed to
`~astropy.utils.data.download_file` or
`~astroquery.mast.Conf.timeout`).
integration : int, optional
    JWST Level 1b products bundle multiple integrations of a time series
    into a single ramp file. If the observations are JWST Level 1b
    products, the integration with this index will be selected.
"""

group_viewer_reference_name = app._jdaviz_helper._default_group_viewer_reference_name
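To make the indexing concrete: the SCI array in a Level 1b file is four-dimensional, and the parser keeps the full ramp of a single integration. A toy numpy sketch (shapes are illustrative):

```python
import numpy as np

# SCI data from a Level 1b file: (integrations, groups, ny, nx)
ramps = np.zeros((3, 10, 32, 32), dtype=np.uint16)

integration = 0
ramp_cube = ramps[integration].astype(int)
print(ramp_cube.shape)  # (10, 32, 32): all groups and pixels of one integration
```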
@@ -101,6 +105,7 @@ def parse_data(app, file_obj, data_type=None, data_label=None,
with fits.open(file_obj) as hdulist:
_parse_hdulist(
app, hdulist, file_name=data_label or file_name,
integration=integration,
group_viewer_reference_name=group_viewer_reference_name,
diff_viewer_reference_name=diff_viewer_reference_name,
)
@@ -121,6 +126,20 @@ def parse_data(app, file_obj, data_type=None, data_label=None,
meta=getattr(file_obj, 'meta')
)

elif isinstance(file_obj, Level1bModel):
metadata = standardize_metadata({
key: value for key, value in file_obj.to_flat_dict().items()
if key.startswith('meta')
})

_parse_ramp_cube(
app, file_obj.data[integration], u.DN,
data_label or file_obj.__class__.__name__,
group_viewer_reference_name,
diff_viewer_reference_name,
meta=metadata
)

elif HAS_ROMAN_DATAMODELS and isinstance(file_obj, rdd.DataModel):
with rdd.open(file_obj) as pf:
_roman_3d_to_glue_data(
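The `Level1bModel` branch above flattens the datamodel metadata and keeps only the keys starting with `meta`; a toy sketch of that filtering, with illustrative dict contents standing in for `to_flat_dict()`:

```python
# illustrative stand-in for Level1bModel.to_flat_dict():
flat = {
    'meta.instrument.name': 'NIRCAM',
    'meta.exposure.ngroups': 10,
    'extra_keyword': 'dropped',
}

metadata = {key: value for key, value in flat.items()
            if key.startswith('meta')}
print(metadata)  # only the 'meta.*' entries survive
```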
@@ -133,6 +152,13 @@ def parse_data(app, file_obj, data_type=None, data_label=None,
raise NotImplementedError(f'Unsupported data format: {file_obj}')


def _swap_axes(x):
# swap axes per the conventions of ramp cubes
# (group axis comes first) and the default in
# rampviz (group axis expected last)
return np.swapaxes(x, 0, -1)
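A quick check of what `_swap_axes` does to a group-first cube (shapes are illustrative):

```python
import numpy as np

cube = np.zeros((10, 32, 34))       # (groups, ny, nx): group axis first
swapped = np.swapaxes(cube, 0, -1)  # group axis moved to the end
print(swapped.shape)                # (34, 32, 10)
```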


def _roman_3d_to_glue_data(
app, file_obj, data_label,
group_viewer_reference_name=None,
@@ -143,12 +169,6 @@ def _roman_3d_to_glue_data(
Parse a Roman 3D ramp cube file (Level 1),
usually with suffix '_uncal.asdf'.
"""

# update viewer reference names for Roman ramp cubes:
# app._update_viewer_reference_name()

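For comparison with the JWST path, a sketch of opening a Roman Level 1 ramp directly (the filename is hypothetical, and `roman_datamodels` must be installed):

```python
import roman_datamodels.datamodels as rdd

# hypothetical Roman WFI Level 1 filename:
with rdd.open('r0000101001001001001_0001_wfi01_uncal.asdf') as ramp:
    print(ramp.data.shape)  # (groups, ny, nx): group axis first
```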
@@ -199,88 +219,76 @@ def _swap_axes(x):

def _parse_hdulist(
app, hdulist, file_name=None,
integration=None,
group_viewer_reference_name=None,
diff_viewer_reference_name=None,
):
if file_name is None and hasattr(hdulist, 'file_name'):
file_name = hdulist.file_name
else:
file_name = file_name or "Unknown HDU object"

hdu = hdulist[1] # extension containing the ramp
if hdu.header['NAXIS'] != 4:
    raise ValueError(f"Expected a ramp with NAXIS=4 (with axes: "
                     f"integrations, groups, x, y), but got "
                     f"NAXIS={hdu.header['NAXIS']}.")

if 'BUNIT' in hdu.header:
    try:
        flux_unit = u.Unit(hdu.header['BUNIT'])
    except Exception:
        logging.warning("Invalid BUNIT, using DN as data unit")
        flux_unit = u.DN
else:
    logging.warning("Missing BUNIT, using DN as data unit")
    flux_unit = u.DN

# index the ramp array by the integration to load; returns all groups and pixels.
# cast from uint16 to integers:
ramp_cube = hdu.data[integration].astype(int)

metadata = standardize_metadata(hdu.header)
if hdu.name != 'PRIMARY' and 'PRIMARY' in hdulist:
    metadata[PRIHDR_KEY] = standardize_metadata(hdulist['PRIMARY'].header)

_parse_ramp_cube(
app, ramp_cube, flux_unit, file_name,
group_viewer_reference_name,
diff_viewer_reference_name,
meta=metadata
)
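`_parse_hdulist` therefore expects the layout of a JWST `*_uncal.fits` file: a PRIMARY HDU plus a 4D SCI image extension. A toy HDUList with that shape, for illustration only:

```python
import numpy as np
from astropy.io import fits

# illustrative 4D ramp: (integrations, groups, ny, nx)
sci = fits.ImageHDU(np.zeros((3, 10, 32, 32), dtype=np.uint16), name='SCI')
sci.header['BUNIT'] = 'DN'
hdulist = fits.HDUList([fits.PrimaryHDU(), sci])  # hdulist[1] is the ramp
```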


def _parse_ramp_cube(app, ramp_cube_data, flux_unit, file_name,
group_viewer_reference_name, diff_viewer_reference_name,
meta=None):
# the group axis comes first in `ramp_cube_data`; the last two axes are spatial:
diff_data = np.vstack([
# begin with a group of zeros, so
# that `diff_data.ndim == data.ndim`
np.zeros((1, *ramp_cube_data[0].shape)),
np.diff(ramp_cube_data, axis=0)
])

ramp_cube = NDDataArray(_swap_axes(ramp_cube_data), unit=flux_unit, meta=meta)
diff_cube = NDDataArray(_swap_axes(diff_data), unit=flux_unit, meta=meta)

group_data_label = app.return_data_label(file_name, ext="DATA")
diff_data_label = app.return_data_label(file_name, ext="DIFF")

for data_entry, data_label, viewer_ref in zip(
(ramp_cube, diff_cube),
(group_data_label, diff_data_label),
(group_viewer_reference_name, diff_viewer_reference_name)
):
app.add_data(data_entry, data_label)
app.add_data_to_viewer(viewer_ref, data_label)

# load these cubes into the cache:
app._jdaviz_helper.cube_cache[data_label] = data_entry

app._jdaviz_helper._loaded_flux_cube = app.data_collection[group_data_label]
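The zero group prepended in `_parse_ramp_cube` keeps the diff cube the same shape as the ramp cube, so both load into parallel viewers; a quick numpy check with illustrative shapes:

```python
import numpy as np

ramp_cube_data = np.arange(2 * 3 * 3).reshape((2, 3, 3))  # (groups, ny, nx)
diff_data = np.vstack([
    np.zeros((1, *ramp_cube_data[0].shape)),  # zeroth group has no difference
    np.diff(ramp_cube_data, axis=0),          # group-to-group differences
])
assert diff_data.shape == ramp_cube_data.shape
```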


def _parse_ndarray(
jdaviz/configs/rampviz/plugins/ramp_extraction/ramp_extraction.py
@@ -156,7 +156,9 @@ def _on_subset_update(self, msg={}):
return

# glue region has transposed coords relative to cached cube:
region_mask = region.to_mask().to_image(
self.cube.shape[:-1][::-1]
).astype(bool).T
cube_subset = self.cube[region_mask] # shape: (N pixels extracted, M groups)

n_pixels_in_extraction = cube_subset.shape[0]
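The `[::-1]` fix above builds the mask on the image-ordered (ny, nx) grid before transposing back to the cached cube's (nx, ny) orientation; a standalone sketch with illustrative shapes:

```python
import numpy as np
from regions import CirclePixelRegion, PixCoord

cube = np.zeros((40, 30, 10))  # cached cube: (nx, ny, groups)
region = CirclePixelRegion(center=PixCoord(12.5, 15.5), radius=2)

# mask on the (ny, nx) image grid, then transpose to the cube's (nx, ny):
region_mask = region.to_mask().to_image(cube.shape[:-1][::-1]).astype(bool).T
cube_subset = cube[region_mask]
print(cube_subset.shape)  # (N pixels inside the region, 10 groups)
```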
@@ -292,7 +294,7 @@ def _update_aperture_method_on_function_change(self, *args):

@property
def cube(self):
return self.app._jdaviz_helper.cube_cache.get(self.dataset.selected)

@property
def slice_display_unit(self):
18 changes: 17 additions & 1 deletion jdaviz/configs/rampviz/tests/test_helper.py
@@ -3,7 +3,7 @@


@pytest.mark.skipif(not HAS_ROMAN_DATAMODELS, reason="roman_datamodels is not installed")
def test_load_data_roman(rampviz_helper, roman_level_1_ramp):
rampviz_helper.load_data(roman_level_1_ramp)

# on ramp cube load (1), the parser loads a diff cube (2) and
@@ -17,3 +17,19 @@ def test_load_data_roman(rampviz_helper, roman_level_1_ramp):

assert viewer.axis_x.label == 'Group'
assert viewer.axis_y.label == 'DN'


def test_load_data_jwst(rampviz_helper, jwst_level_1b_ramp):
rampviz_helper.load_data(jwst_level_1b_ramp)

# on ramp cube load (1), the parser loads a diff cube (2) and
# the ramp extraction plugin produces a default extraction (3):
assert len(rampviz_helper.app.data_collection) == 3

# each viewer should have one loaded data entry:
for refname in 'group-viewer, diff-viewer, integration-viewer'.split(', '):
viewer = rampviz_helper.app.get_viewer(refname)
assert len(viewer.state.layers) == 1

assert viewer.axis_x.label == 'Group'
assert viewer.axis_y.label == 'DN'
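The `jwst_level_1b_ramp` fixture used here is defined in a conftest change not expanded on this page; a minimal sketch of such a fixture, assuming `Level1bModel` accepts a shape tuple (as JWST datamodels generally do):

```python
import numpy as np
import pytest
from stdatamodels.jwst.datamodels import Level1bModel


@pytest.fixture
def jwst_level_1b_ramp():
    # (integrations, groups, ny, nx); fill each group with its index
    # so the ramp increases monotonically, as a real ramp would:
    model = Level1bModel((1, 10, 25, 25))
    model.data[:, :, :, :] = np.arange(10, dtype=np.uint16)[None, :, None, None]
    return model
```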
22 changes: 15 additions & 7 deletions jdaviz/configs/rampviz/tests/test_ramp_extraction.py
@@ -5,22 +5,30 @@


@pytest.mark.skipif(not HAS_ROMAN_DATAMODELS, reason="roman_datamodels is not installed")
def test_previews_roman(rampviz_helper, roman_level_1_ramp):
_ramp_extraction_previews(rampviz_helper, roman_level_1_ramp)


def test_previews_jwst(rampviz_helper, jwst_level_1b_ramp):
_ramp_extraction_previews(rampviz_helper, jwst_level_1b_ramp)


def _ramp_extraction_previews(_rampviz_helper, _ramp_file):
_rampviz_helper.load_data(_ramp_file)

# add subset:
region = CirclePixelRegion(center=PixCoord(12.5, 15.5), radius=2)
_rampviz_helper.load_regions(region)
ramp_extr = _rampviz_helper.plugins['Ramp Extraction']._obj

subsets = _rampviz_helper.app.get_subsets()
ramp_cube = _rampviz_helper.app.data_collection[0]
n_groups = ramp_cube.shape[-1]

assert len(subsets) == 1
assert 'Subset 1' in subsets

integration_viewer = _rampviz_helper.app.get_viewer('integration-viewer')

# contains a layer for the default ramp extraction and the subset:
assert len(integration_viewer.layers) == 2
