Skip to content

Commit

Permalink
[perf] reproj partial maps
Browse files Browse the repository at this point in the history
- option to skip map pixels in ``ReprojectMaps``
  • Loading branch information
LiYunyang committed Dec 2, 2024
1 parent 32d3871 commit c155fa8
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 7 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ doc/module-autodocs.rst
doc/moddoc_*
*.g3
.DS_Store
.idea/
48 changes: 44 additions & 4 deletions maps/python/map_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,15 +871,26 @@ class ReprojectMaps(object):
weighted : bool
If True (default), ensure that maps have had weights applied before
reprojection. Otherwise, reproject maps without checking the weights.
partial : bool=False
If True, the reproj will be performed on a partial map (of the output map),
defined by the mask. If the mask is not provided, it will be detrmined from
the non-zero pixels of the first reprojected map.
mask : G3SkyMapMask, G3SkyMap, or np.ndarray, Optional.
Mask to be used for partial reproject. This should be of the same size as the output map.
For numpy array, all zeros/inf/nan/hp.UNSEEN pixels are skipped.
"""

def __init__(self, map_stub=None, rebin=1, interp=False, weighted=True):
def __init__(self, map_stub=None, rebin=1, interp=False, weighted=True, partial=False, mask=None):
assert map_stub is not None, "map_stub argument required"
self.stub = map_stub.clone(False)
self.stub.pol_type = None
self.rebin = rebin
self.interp = interp
self.weighted = weighted
self._mask = None
self.partial = partial

self.mask = mask

def __call__(self, frame):
if isinstance(frame, core.G3Frame) and frame.type != core.G3FrameType.Map:
Expand All @@ -905,15 +916,44 @@ def __call__(self, frame):

if key in "TQUH":
mnew = self.stub.clone(False)
maps.reproj_map(m, mnew, rebin=self.rebin, interp=self.interp)
maps.reproj_map(m, mnew, rebin=self.rebin, interp=self.interp, mask=self.mask)

elif key in ["Wpol", "Wunpol"]:
mnew = maps.G3SkyMapWeights(self.stub)
for wkey in mnew.keys():
maps.reproj_map(
m[wkey], mnew[wkey], rebin=self.rebin, interp=self.interp
m[wkey], mnew[wkey], rebin=self.rebin, interp=self.interp, mask=self.mask
)

frame[key] = mnew

self.mask = mnew
return frame

@property
def mask(self):
return self._mask

@mask.setter
def mask(self, mask):
if mask is None:
return
if self._mask is None and self.partial:
if isinstance(mask, maps.G3SkyMapMask):
self._mask = mask
elif isinstance(mask, maps.G3SkyMap):
self._mask = maps.G3SkyMapMask(mask, use_data=True, zero_nans=True, zero_infs=True)
elif isinstance(mask, np.ndarray):
from healpy import UNSEEN
tmp = self.stub.clone(False)
mask_copy = np.ones(mask.shape, dtype=int)
bad = np.logical_or.reduce([
np.isnan(mask),
np.isinf(mask),
mask==0,
mask==UNSEEN
])
mask_copy[bad] = 0
tmp[:] = mask_copy
self._mask = maps.G3SkyMapMask(tmp, use_data=True)
else:
raise TypeError("Mask must be a G3SkyMapMask, G3SkyMap, or numpy array")
16 changes: 13 additions & 3 deletions maps/src/maputils.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,8 @@ void FlattenPol(FlatSkyMapPtr Q, FlatSkyMapPtr U, G3SkyMapWeightsPtr W, double h
}


void ReprojMap(G3SkyMapConstPtr in_map, G3SkyMapPtr out_map, int rebin, bool interp)
void ReprojMap(G3SkyMapConstPtr in_map, G3SkyMapPtr out_map, int rebin, bool interp,
G3SkyMapMaskConstPtr out_map_mask)
{
bool rotate = false; // no transform
Quat q_rot; // quaternion for rotating from output to input coordinate system
Expand Down Expand Up @@ -310,8 +311,13 @@ void ReprojMap(G3SkyMapConstPtr in_map, G3SkyMapPtr out_map, int rebin, bool int
out_map->pol_conv = in_map->pol_conv;
}

size_t stop = out_map->size();
if (rebin > 1) {
for (size_t i = 0; i < out_map->size(); i++) {
for (size_t i = 0; i < stop; i++) {
if (!!out_map_mask && !out_map_mask->at(i)) {
(*out_map)[i] = 0;
continue;
}
double val = 0;
auto quats = out_map->GetRebinQuats(i, rebin);
if (rotate)
Expand All @@ -328,7 +334,11 @@ void ReprojMap(G3SkyMapConstPtr in_map, G3SkyMapPtr out_map, int rebin, bool int
}
}
} else {
for (size_t i = 0; i < out_map->size(); i++) {
for (size_t i = 0; i < stop; i++) {
if (!!out_map_mask && !out_map_mask->at(i)) {
(*out_map)[i] = 0;
continue;
}
double val = 0;
auto q = out_map->PixelToQuat(i);
if (rotate)
Expand Down

0 comments on commit c155fa8

Please sign in to comment.