Skip to content

Commit

Permalink
Merge pull request #19 from nmearl/fix/use-specutils-uncerts
Browse files Browse the repository at this point in the history
Use uncertainties and mask from `Spectrum1D`
  • Loading branch information
eteq authored Aug 13, 2020
2 parents f27f484 + 0325316 commit b518b0d
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 23 deletions.
86 changes: 66 additions & 20 deletions glue_astronomy/translators/spectrum1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,16 @@
from astropy.wcs import WCS
from astropy import units as u
from astropy.wcs import WCSSUB_SPECTRAL
from astropy.nddata import StdDevUncertainty, InverseVariance, VarianceUncertainty

from glue_astronomy.spectral_coordinates import SpectralCoordinates

from specutils import Spectrum1D

UNCERT_REF = {'std': StdDevUncertainty,
'var': VarianceUncertainty,
'ivar': InverseVariance}


@data_translator(Spectrum1D)
class Specutils1DHandler:
Expand All @@ -20,7 +25,19 @@ def to_data(self, obj):
data = Data(coords=coords)
data['flux'] = obj.flux
data.get_component('flux').units = str(obj.unit)

# Include uncertainties if they exist
if obj.uncertainty is not None:
data['uncertainty'] = obj.uncertainty.quantity
data.get_component('uncertainty').units = str(obj.unit)
data.meta.update({'uncertainty_type': obj.uncertainty.uncertainty_type})

# Include mask if it exists
if obj.mask is not None:
data['mask'] = obj.mask

data.meta.update(obj.meta)

return data

def to_object(self, data_or_subset, attribute=None, statistic='mean'):
Expand Down Expand Up @@ -70,32 +87,61 @@ def to_object(self, data_or_subset, attribute=None, statistic='mean'):
elif attribute is None:
if len(data.main_components) == 1:
attribute = data.main_components[0]
# If no specific attribute is defined, attempt to retrieve
# both the flux and uncertainties
elif any([x.label in ('flux', 'uncertainty') for x in data.components]):
attribute = [data.find_component_id('flux'),
data.find_component_id('uncertainty')]
else:
raise ValueError("Data object has more than one attribute, so "
"you will need to specify which one to use as "
"the flux for the spectrum using the "
"attribute= keyword argument.")

component = data.get_component(attribute)
def parse_attributes(attributes):
data_kwargs = {}

# Get mask if there is one defined, or if this is a subset
if subset_state is None:
mask = None
else:
mask = data.get_mask(subset_state=subset_state)
mask = ~mask

# Collapse values and mask to profile
if data.ndim > 1:
# Get units and attach to value
values = data.compute_statistic(statistic, attribute, axis=axes,
subset_state=subset_state)
if mask is not None:
collapse_axes = tuple([x for x in range(1, data.ndim)])
mask = np.all(mask, collapse_axes)
else:
values = data.get_data(attribute)
for attribute in attributes:
component = data.get_component(attribute)

# Get mask if there is one defined, or if this is a subset
if subset_state is None:
mask = None
else:
mask = data.get_mask(subset_state=subset_state)
mask = ~mask

# Collapse values and mask to profile
if data.ndim > 1:
# Get units and attach to value
values = data.compute_statistic(statistic, attribute, axis=axes,
subset_state=subset_state)
if mask is not None:
collapse_axes = tuple([x for x in range(1, data.ndim)])
mask = np.all(mask, collapse_axes)
else:
values = data.get_data(attribute)

attribute_label = attribute.label

if attribute_label not in ('flux', 'uncertainty'):
attribute_label = 'flux'

values = values * u.Unit(component.units)

# If the attribute is uncertainty, we must coerce it to a
# specific uncertainty type. If no value exists in the glue
# object meta dictionary, use standard deviation.
if attribute_label == 'uncertainty':
values = UNCERT_REF[
data.meta.get('uncertainty_type', 'std')](values)

data_kwargs.update({attribute_label: values,
'mask': mask})

return data_kwargs

values = values * u.Unit(component.units)
data_kwargs = parse_attributes(
[attribute] if not hasattr(attribute, '__len__') else attribute)

return Spectrum1D(values, mask=mask, **kwargs)
return Spectrum1D(**data_kwargs, **kwargs)
26 changes: 23 additions & 3 deletions glue_astronomy/translators/tests/test_spectrum1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from astropy import units as u
from astropy.wcs import WCS
from astropy.tests.helper import assert_quantity_allclose
from astropy.nddata import StdDevUncertainty

from glue.core import Data, DataCollection
from glue.core.component import Component
Expand Down Expand Up @@ -134,7 +135,11 @@ def test_from_spectrum1d(mode):

kwargs = {'spectral_axis': [1, 2, 3, 4] * u.Hz}

spec = Spectrum1D([2, 3, 4, 5] * u.Jy, **kwargs)
spec = Spectrum1D([2, 3, 4, 5] * u.Jy,
uncertainty=StdDevUncertainty(
[0.1, 0.1, 0.1, 0.1] * u.Jy),
mask=[False, False, False, False],
**kwargs)

data_collection = DataCollection()

Expand All @@ -143,14 +148,29 @@ def test_from_spectrum1d(mode):
data = data_collection['spectrum']

assert isinstance(data, Data)
assert len(data.main_components) == 1
assert len(data.main_components) == 3
assert data.main_components[0].label == 'flux'
assert_allclose(data['flux'], [2, 3, 4, 5])
component = data.get_component('flux')
assert component.units == 'Jy'

# Check round-tripping
# Check uncertainty parsing within glue data object
assert data.main_components[1].label == 'uncertainty'
assert_allclose(data['uncertainty'], [0.1, 0.1, 0.1, 0.1])
component = data.get_component('uncertainty')
assert component.units == 'Jy'

# Check round-tripping via single attribute reference
spec_new = data.get_object(attribute='flux')
assert isinstance(spec_new, Spectrum1D)
assert_quantity_allclose(spec_new.spectral_axis, [1, 2, 3, 4] * u.Hz)
assert_quantity_allclose(spec_new.flux, [2, 3, 4, 5] * u.Jy)
assert spec_new.uncertainty is None

# Check complete round-tripping, including uncertainties
spec_new = data.get_object()
assert isinstance(spec_new, Spectrum1D)
assert_quantity_allclose(spec_new.spectral_axis, [1, 2, 3, 4] * u.Hz)
assert_quantity_allclose(spec_new.flux, [2, 3, 4, 5] * u.Jy)
assert spec_new.uncertainty is not None
assert_quantity_allclose(spec_new.uncertainty.quantity, [0.1, 0.1, 0.1, 0.1] * u.Jy)

0 comments on commit b518b0d

Please sign in to comment.