diff --git a/brahmap/interfaces/linearoperators.py b/brahmap/interfaces/linearoperators.py index b87a4bc..62feb10 100644 --- a/brahmap/interfaces/linearoperators.py +++ b/brahmap/interfaces/linearoperators.py @@ -346,6 +346,7 @@ def __init__(self, diag: np.ndarray, dtype=None): if dtype is not None: self.diag = np.asarray(diag, dtype=dtype) elif isinstance(diag, np.ndarray): + self.diag = diag dtype = self.diag.dtype else: dtype = np.float64 diff --git a/brahmap/mapmakers/GLS.py b/brahmap/mapmakers/GLS.py index ee95ba6..5ce3fdc 100644 --- a/brahmap/mapmakers/GLS.py +++ b/brahmap/mapmakers/GLS.py @@ -45,7 +45,7 @@ def compute_GLS_maps( inv_noise_cov_operator: ( ToeplitzLO | BlockLO | DiagonalOperator | InvNoiseCovLO_Uncorrelated ) = None, - threshold_cond: float = 1.0e3, + threshold: float = 1.0e-5, dtype_float=None, update_pointings_inplace: bool = True, GLSParameters: GLSParameters = GLSParameters(), @@ -78,7 +78,7 @@ def compute_GLS_maps( solver_type=GLSParameters.solver_type, pol_angles=pol_angles, noise_weights=inv_noise_cov_operator.diag, - threshold_cond=threshold_cond, + threshold=threshold, dtype_float=dtype_float, update_pointings_inplace=update_pointings_inplace, ) diff --git a/brahmap/mapmakers/__init__.py b/brahmap/mapmakers/__init__.py index 8c6f2f5..025bd49 100644 --- a/brahmap/mapmakers/__init__.py +++ b/brahmap/mapmakers/__init__.py @@ -7,10 +7,18 @@ from .lbsim_mapmakers import ( LBSimGLSParameters, LBSimGLSResult, + LBSim_InvNoiseCovLO_UnCorr, + LBSim_compute_GLS_maps_from_obs, LBSim_compute_GLS_maps, ) - __all__ = ["LBSimGLSParameters", "LBSimGLSResult", "LBSim_compute_GLS_maps"] + __all__ = [ + "LBSimGLSParameters", + "LBSimGLSResult", + "LBSim_InvNoiseCovLO_UnCorr", + "LBSim_compute_GLS_maps_from_obs", + "LBSim_compute_GLS_maps", + ] else: __all__ = [] diff --git a/brahmap/mapmakers/lbsim_mapmakers.py b/brahmap/mapmakers/lbsim_mapmakers.py index ed18360..f416bab 100644 --- a/brahmap/mapmakers/lbsim_mapmakers.py +++ b/brahmap/mapmakers/lbsim_mapmakers.py @@ -2,6 +2,7 @@ from dataclasses import dataclass, asdict import healpy as hp import litebird_sim as lbs +from typing import List from brahmap.mapmakers import GLSParameters, GLSResult, compute_GLS_maps from brahmap.linop import DiagonalOperator @@ -20,6 +21,136 @@ class LBSimGLSResult(GLSResult): coordinate_system: lbs.CoordinateSystem = lbs.CoordinateSystem.Galactic +def sample_counts(obs_list): + """Returns an array of size `n_obs x n_dets` containing the size of TOD for each detector in each observation. + + Args: + obs_list (_type_): List of observation + + Returns: + np.ndarray: sample count array + """ + count_arr = np.zeros((len(obs_list), obs_list[0].n_detectors), dtype=int) + for obs_idx, obs in enumerate(obs_list): + for det_idx in obs.det_idx: + count_arr[obs_idx, det_idx] = obs.tod[det_idx].shape[0] + return count_arr + + +def start_idx(arr, obs_idx, det_idx): + """Returns the starting index of TOD of detector `det_idx` in observation `obs_idx` in flattened array of TODs concatenated over observations. That is, the starting index of TOD of given det and obs index in a flat array given by `np.concatenate([getattr(obs, component) for obs in obs_list], axis=None)`. `arr` must be the output of `sample_count()` function.""" + idx = 0 + for i in range(obs_idx): + idx += sum(arr[i][:]) + idx += sum(arr[obs_idx][:det_idx]) + return idx + + +def end_idx(arr, obs_idx, det_idx): + """Similar to `start_idx()` function, but this one returns the ending index""" + idx = start_idx(arr, obs_idx, det_idx) + idx += arr[obs_idx, det_idx] + return idx + + +def det_sample_count_idx(arr, obs_idx, det_idx): + """ + This function returns the starting index of the TOD of a given detector for a given `obs_idx` + """ + return sum(arr[:, det_idx][:obs_idx]) + + +class LBSim_InvNoiseCovLO_UnCorr(InvNoiseCovLO_Uncorrelated): + """Here the `noise_variance` must be a dictionary of noise variance associated with detector names. This operator will arrange the blocks of noise variance in the same way as tods in the obs_list are distributed.""" + + def __init__( + self, + obs: lbs.Observation | List[lbs.Observation], + noise_variance: dict | None = None, + dtype=None, + ): + if isinstance(obs, lbs.Observation): + obs_list = [obs] + else: + obs_list = obs + + tod_counts = sample_counts(obs_list) + diag_len = tod_counts.sum() + tod_len = tod_counts.sum(axis=0) + det_list = list(obs_list[0].name) + + if noise_variance is None: + diagonal = np.ones(diag_len, dtype=(np.float64 if dtype is None else dtype)) + else: + noise_dict_keys = list(noise_variance.keys()) + if dtype is None: + dtype = noise_variance[noise_dict_keys[0]].dtype + for detector in det_list: + if detector not in noise_dict_keys: + idx = det_list.index(detector) + noise_variance[detector] = np.ones(tod_len[idx]) + if len(noise_variance[detector]) != tod_len[det_list.index(detector)]: + raise ValueError( + f"Incorrect length of noise variance for detector {detector}" + ) + + diagonal = np.empty(diag_len, dtype=dtype) + for obs_idx, obs in enumerate(obs_list): + for det_idx in obs.det_idx: + stdiagidx = start_idx(tod_counts, obs_idx, det_idx) + endiagidx = end_idx(tod_counts, obs_idx, det_idx) + sttodidx = det_sample_count_idx(tod_counts, obs_idx, det_idx) + entodidx = ( + det_sample_count_idx(tod_counts, obs_idx, det_idx) + + tod_counts[obs_idx, det_idx] + ) + + diagonal[stdiagidx:endiagidx] = noise_variance[det_list[det_idx]][ + sttodidx:entodidx + ] + + super(LBSim_InvNoiseCovLO_UnCorr, self).__init__(diag=diagonal, dtype=dtype) + + +def LBSim_compute_GLS_maps_from_obs( + nside: int, + obs: lbs.Observation | List[lbs.Observation], + pointings_flag: np.ndarray | List[np.ndarray] = None, + inv_noise_cov_diagonal: ( + LBSim_InvNoiseCovLO_UnCorr | InvNoiseCovLO_Uncorrelated | None + ) = None, + threshold: float = 1.0e-5, + dtype_float=None, + LBSimGLSParameters: LBSimGLSParameters = LBSimGLSParameters(), + component: str = "tod", +) -> LBSimGLSResult | tuple[ProcessTimeSamples, LBSimGLSResult]: + if isinstance(obs, lbs.Observation): + obs_list = [obs] + else: + obs_list = obs + + pointings = np.concatenate([ob.pointings for ob in obs_list]).reshape((-1, 2)) + pol_angles = np.concatenate( + [ob.psi for ob in obs_list], axis=None + ) # `axis=None` returns flattened arrays + tod = np.concatenate([getattr(obs, component) for obs in obs_list], axis=None) + + lbsim_gls_result = LBSim_compute_GLS_maps( + nside=nside, + pointings=pointings, + tod=tod, + pointings_flag=pointings_flag, + pol_angles=pol_angles, + inv_noise_cov_operator=inv_noise_cov_diagonal, + threshold=threshold, + dtype_float=dtype_float, + update_pointings_inplace=True, + LBSimGLSParameters=LBSimGLSParameters, + ) + + return lbsim_gls_result + + def LBSim_compute_GLS_maps( nside: int, pointings: np.ndarray, @@ -29,7 +160,7 @@ def LBSim_compute_GLS_maps( inv_noise_cov_operator: ( ToeplitzLO | BlockLO | DiagonalOperator | InvNoiseCovLO_Uncorrelated ) = None, - threshold_cond: float = 1.0e3, + threshold: float = 1.0e-5, dtype_float=None, update_pointings_inplace: bool = True, LBSimGLSParameters: LBSimGLSParameters = LBSimGLSParameters(), @@ -50,7 +181,7 @@ def LBSim_compute_GLS_maps( pointings_flag=pointings_flag, pol_angles=pol_angles, inv_noise_cov_operator=inv_noise_cov_operator, - threshold_cond=threshold_cond, + threshold=threshold, dtype_float=dtype_float, update_pointings_inplace=update_pointings_inplace, GLSParameters=LBSimGLSParameters,