diff --git a/maps/CMakeLists.txt b/maps/CMakeLists.txt index 1edd89d9..65344584 100644 --- a/maps/CMakeLists.txt +++ b/maps/CMakeLists.txt @@ -32,3 +32,4 @@ add_spt3g_test(mask_operators) add_spt3g_test(healpix_maps) add_spt3g_test(fitsio) add_spt3g_test(map_modules) +add_spt3g_test(mask_serialization_test) diff --git a/maps/include/maps/G3SkyMapMask.h b/maps/include/maps/G3SkyMapMask.h index 7d56fe43..04fc1c2b 100644 --- a/maps/include/maps/G3SkyMapMask.h +++ b/maps/include/maps/G3SkyMapMask.h @@ -108,7 +108,8 @@ class G3SkyMapMask : public G3FrameObject { private: G3SkyMapMask() {} // Fake out for serialization - template void serialize(A &ar, const unsigned v); + template void load(A &ar, const unsigned v); + template void save(A &ar, const unsigned v) const; friend class cereal::access; std::vector data_; @@ -121,8 +122,12 @@ class G3SkyMapMask : public G3FrameObject { SET_LOGGER("G3SkyMapMask"); }; +namespace cereal { + template struct specialize {}; +} + G3_POINTERS(G3SkyMapMask); -G3_SERIALIZABLE(G3SkyMapMask, 1); +G3_SERIALIZABLE(G3SkyMapMask, 2); #endif diff --git a/maps/src/G3SkyMapMask.cxx b/maps/src/G3SkyMapMask.cxx index cf371d26..94cb6962 100644 --- a/maps/src/G3SkyMapMask.cxx +++ b/maps/src/G3SkyMapMask.cxx @@ -397,15 +397,56 @@ G3SkyMapMask::MakeBinaryMap() const } template -void G3SkyMapMask::serialize(A &ar, unsigned v) +void G3SkyMapMask::load(A &ar, unsigned v) { using namespace cereal; ar & make_nvp("G3FrameObject", base_class(this)); ar & make_nvp("parent", parent_); - ar & make_nvp("data", data_); + if (v < 2) { + ar & make_nvp("data", data_); + } else { + std::vector packed; + size_t nbits; + ar & make_nvp("data", packed); + data_.resize(packed.size()*8); + for (ssize_t i = 0; i < packed.size(); i++) + for (int j = 0; j < 8; j++) + data_[i*8 + j] = (packed[i] >> j) & 1; + ar & make_nvp("nbits", nbits); // In case not a multiple of 8 + data_.resize(nbits); + } +} + +template +void G3SkyMapMask::save(A &ar, unsigned v) const +{ + using namespace cereal; + ar & make_nvp("G3FrameObject", base_class(this)); + ar & make_nvp("parent", parent_); + + std::vector packed((data_.size() / 8) + + ((data_.size() % 8) ? 1 : 0)); + // Pack data in all bits up to a multiple of 8, then the remainder + // Two pieces so the compiler can unroll the inner loop in the first + // part + for (ssize_t i = 0; i < data_.size() / 8; i++) { + packed[i] = 0; + for (int j = 0; j < 8; j++) + packed[i] |= !!data_[i*8 + j] << j; + } + if (data_.size() % 8) { + const ssize_t i = packed.size() - 1; + packed[i] = 0; + for (int j = 0; j < data_.size() - i*8; j++) + packed[i] |= !!data_[i*8 + j] << j; + } + + ar & make_nvp("data", packed); + ar & make_nvp("nbits", data_.size()); } + -G3_SERIALIZABLE_CODE(G3SkyMapMask); +G3_SPLIT_SERIALIZABLE_CODE(G3SkyMapMask); static bool skymapmask_getitem(const G3SkyMapMask &m, boost::python::object index) diff --git a/maps/tests/mask_serialization_test.py b/maps/tests/mask_serialization_test.py new file mode 100755 index 00000000..5064a09d --- /dev/null +++ b/maps/tests/mask_serialization_test.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python + +from spt3g import core, maps + +import numpy + +m = maps.HealpixSkyMap(nside=256) + +data = numpy.random.uniform(0, 1, size=m.size) > 0.5 +mask = maps.G3SkyMapMask(m, data) + +fr = core.G3Frame(core.G3FrameType.Map) +fr['mask'] = mask + +w = core.G3Writer(filename='masktest.g3') +w(fr) +del w + +f = core.G3File('masktest.g3').next() +recovereddata = numpy.asarray(f['mask']) + +assert (recovereddata == data).all() +