diff --git a/glue/core/state.py b/glue/core/state.py index 333f9ffc8..47f9f1d68 100644 --- a/glue/core/state.py +++ b/glue/core/state.py @@ -65,7 +65,7 @@ def load(rec, context) import numpy as np from matplotlib.colors import Colormap from matplotlib import cm -from astropy.units import NamedUnit, Unit +from astropy.units import UnitBase, Unit from astropy.wcs import WCS import shapely @@ -622,14 +622,15 @@ def _load_slice(rec, context): return slice(rec['start'], rec['stop'], rec['step']) -@saver(NamedUnit) -def _save_named_unit(unit, context): - return dict(named_unit=unit.to_string()) +@saver(UnitBase) +def _save_unit_base(unit, context): + return dict(unit_base=unit.to_string()) -@loader(NamedUnit) -def _load_named_unit(rec, context): - return Unit(rec["named_unit"]) +@loader(UnitBase) +def _load_unit_base(rec, context): + return Unit(rec["unit_base"]) + @saver(WCS) diff --git a/glue/core/tests/test_state.py b/glue/core/tests/test_state.py index c2be7e0b0..3656f83b0 100644 --- a/glue/core/tests/test_state.py +++ b/glue/core/tests/test_state.py @@ -312,6 +312,21 @@ def test_astropy_units(): unit2 = clone(unit) assert unit2 is unit + unit = u.km + unit2 = clone(unit) + assert unit2 is unit + + +@requires_astropy +def test_astropy_compound_units(): + import astropy.units as u + unit = u.m / u.s + unit2 = clone(unit) + assert unit2 == unit + unit = u.W / u.m**2 / u.nm + unit2 = clone(unit) + assert unit2 == unit + class DummyClass(object): pass