diff --git a/src/rail/core/__init__.py b/src/rail/core/__init__.py
index eaf530dd..569f1e77 100644
--- a/src/rail/core/__init__.py
+++ b/src/rail/core/__init__.py
@@ -20,3 +20,4 @@ def find_version():
 #from .utilPhotometry import PhotormetryManipulator, HyperbolicSmoothing, HyperbolicMagnitudes
 from .utilStages import ColumnMapper, RowSelector, TableConverter
 from .introspection import RailEnv
+from .point_estimation import PointEstimationMixin
diff --git a/src/rail/core/point_estimation.py b/src/rail/core/point_estimation.py
new file mode 100644
index 00000000..b9a98d1d
--- /dev/null
+++ b/src/rail/core/point_estimation.py
@@ -0,0 +1,137 @@
+import numpy as np
+from numpy.typing import NDArray
+
+
+class PointEstimationMixin:
+    """Mixin that adds point estimates (mode, mean, median) to a qp.Ensemble.
+
+    The host class must provide a `config` object that supports `in` tests,
+    item access and attribute access; the estimates to compute are read from
+    its optional `calculated_point_estimates` entry.
+    """
+
+    def calculate_point_estimates(self, qp_dist, grid=None):
+        """This function drives the calculation of point estimates for qp.Ensembles.
+        It is defined here, and called from `CatEstimator._do_chunk_output` just
+        before each chunk of output is written.
+
+        Parameters
+        ----------
+        qp_dist : qp.Ensemble
+            The qp Ensemble instance that contains posterior estimates.
+        grid : array-like, optional
+            The grid on which to evaluate the point estimate. Note that not all
+            point estimates require a grid to be provided, by default None.
+
+        Returns
+        -------
+        qp.Ensemble
+            The original `qp.Ensemble` with new ancillary point estimate data included.
+            The `Ensemble.ancil` keys are ['mean', 'mode', 'median'].
+            Any ancillary data already attached to the ensemble is preserved.
+
+        Notes
+        -----
+        If there are particularly efficient ways to calculate point estimates for
+        a given `CatEstimator` subclass, the subclass can implement any of the
+        `_calculate_<name>_point_estimate` for any of the point estimate calculator
+        methods, i.e.:
+
+            - `_calculate_mode_point_estimate`
+            - `_calculate_mean_point_estimate`
+            - `_calculate_median_point_estimate`
+        """
+
+        ancil_dict = dict()
+        calculated_point_estimates = []
+        if 'calculated_point_estimates' in self.config:
+            calculated_point_estimates = self.config['calculated_point_estimates']
+
+        if 'mode' in calculated_point_estimates:
+            mode_value = self._calculate_mode_point_estimate(qp_dist, grid)
+            ancil_dict.update(mode = mode_value)
+
+        if 'mean' in calculated_point_estimates:
+            mean_value = self._calculate_mean_point_estimate(qp_dist)
+            ancil_dict.update(mean = mean_value)
+
+        if 'median' in calculated_point_estimates:
+            median_value = self._calculate_median_point_estimate(qp_dist)
+            ancil_dict.update(median = median_value)
+
+        if calculated_point_estimates:
+            # Merge with ancillary data the estimator may already have attached
+            # to this chunk; `set_ancil` alone would replace it.
+            if qp_dist.ancil is None:
+                qp_dist.set_ancil(ancil_dict)
+            else:
+                qp_dist.add_to_ancil(ancil_dict)
+
+        return qp_dist
+
+    def _calculate_mode_point_estimate(self, qp_dist, grid=None) -> NDArray:
+        """Calculates and returns the mode values for a set of posterior estimates
+        in a qp.Ensemble instance.
+
+        Parameters
+        ----------
+        qp_dist : qp.Ensemble
+            The qp Ensemble instance that contains posterior estimates.
+        grid : array-like, optional
+            The grid on which to evaluate the `mode` point estimate, if a grid is
+            not provided, a default will be created at run time using `zmin`, `zmax`,
+            and `nzbins`, by default None
+
+        Returns
+        -------
+        NDArray
+            The mode value for each posterior in the qp.Ensemble
+
+        Raises
+        ------
+        KeyError
+            If `grid` is not provided, one will be created using stage_config
+            `zmin`, `zmax`, and `nzbins` keys. If any of those keys are missing,
+            we'll raise a KeyError.
+        """
+        if grid is None:
+            for key in ['zmin', 'zmax', 'nzbins']:
+                if key not in self.config:
+                    raise KeyError(f"Expected `{key}` to be defined in stage " \
+                        "configuration dictionary in order to calculate mode.")
+
+            grid = np.linspace(self.config.zmin, self.config.zmax, self.config.nzbins)
+
+        return qp_dist.mode(grid=grid)
+
+    def _calculate_mean_point_estimate(self, qp_dist) -> NDArray:
+        """Calculates and returns the mean values for a set of posterior estimates
+        in a qp.Ensemble instance.
+
+        Parameters
+        ----------
+        qp_dist : qp.Ensemble
+            The qp Ensemble instance that contains posterior estimates.
+
+        Returns
+        -------
+        NDArray
+            The mean value for each posterior in the qp.Ensemble
+        """
+        return qp_dist.mean()
+
+    def _calculate_median_point_estimate(self, qp_dist) -> NDArray:
+        """Calculates and returns the median values for a set of posterior estimates
+        in a qp.Ensemble instance.
+
+        Parameters
+        ----------
+        qp_dist : qp.Ensemble
+            The qp Ensemble instance that contains posterior estimates.
+
+        Returns
+        -------
+        NDArray
+            The median value for each posterior in the qp.Ensemble
+        """
+        return qp_dist.median()
diff --git a/src/rail/estimation/estimator.py b/src/rail/estimation/estimator.py
index 1f555550..a00c4523 100644
--- a/src/rail/estimation/estimator.py
+++ b/src/rail/estimation/estimator.py
@@ -2,18 +2,17 @@
 Abstract base classes defining Estimators of individual galaxy redshift uncertainties
 """
 import gc
-import numpy as np
-from numpy.typing import NDArray
 
 from rail.core.common_params import SHARED_PARAMS
 from rail.core.data import TableHandle, QPHandle, ModelHandle
 from rail.core.stage import RailStage
 
 from rail.estimation.informer import CatInformer
+from rail.core.point_estimation import PointEstimationMixin
 # for backwards compatibility
 
 
-class CatEstimator(RailStage):
+class CatEstimator(RailStage, PointEstimationMixin):
     """The base class for making photo-z posterior estimates from catalog-like inputs
     (i.e., tables with fluxes in photometric bands among the set of columns)
 
@@ -125,128 +124,9 @@ def _process_chunk(self, start, end, data, first):
         raise NotImplementedError(f"{self.name}._process_chunk is not implemented")  # pragma: no cover
 
     def _do_chunk_output(self, qp_dstn, start, end, first):
+        qp_dstn = self.calculate_point_estimates(qp_dstn)
         if first:
             self._output_handle = self.add_handle('output', data=qp_dstn)
             self._output_handle.initialize_write(self._input_length, communicator=self.comm)
         self._output_handle.set_data(qp_dstn, partial=True)
         self._output_handle.write_chunk(start, end)
-
-    def _calculate_point_estimates(self, qp_dist, grid=None):
-        """This function drives the calculation of point estimates for qp.Ensembles.
-        It is defined here, and called from the `_process_chunk` method in the
-        `CatEstimator` child classes.
-
-        Parameters
-        ----------
-        qp_dist : qp.Ensemble
-            The qp Ensemble instance that contains posterior estimates.
-        grid : array-like, optional
-            The grid on which to evaluate the point estimate. Note that not all
-            point estimates require a grid to be provided, by default None.
-
-        Returns
-        -------
-        qp.Ensemble
-            The original `qp.Ensemble` with new ancillary point estimate data included.
-            The `Ensemble.ancil` keys are ['mean', 'mode', 'median'].
-
-        Notes
-        -----
-        If there are particularly efficient ways to calculate point estimates for
-        a given `CatEstimator` subclass, the subclass can implement any of the
-        `_calculate__point_estimate` for any of the point estimate calculator
-        methods, i.e.:
-
-            - `_calculate_mode_point_estimate`
-            - `_calculate_mean_point_estimate`
-            - `_calculate_median_point_estimate`
-        """
-
-        ancil_dict = dict()
-        calculated_point_estimates = []
-        if 'calculated_point_estimates' in self.config:
-            calculated_point_estimates = self.config.calculated_point_estimates
-
-        if 'mode' in calculated_point_estimates:
-            mode_value = self._calculate_mode_point_estimate(qp_dist, grid)
-            ancil_dict.update(mode = mode_value)
-
-        if 'mean' in calculated_point_estimates:
-            mean_value = self._calculate_mean_point_estimate(qp_dist)
-            ancil_dict.update(mean = mean_value)
-
-        if 'median' in calculated_point_estimates:
-            median_value = self._calculate_median_point_estimate(qp_dist)
-            ancil_dict.update(median = median_value)
-
-        if calculated_point_estimates:
-            qp_dist.set_ancil(ancil_dict)
-
-        return qp_dist
-
-    def _calculate_mode_point_estimate(self, qp_dist, grid=None) -> NDArray:
-        """Calculates and returns the mode values for a set of posterior estimates
-        in a qp.Ensemble instance.
-
-        Parameters
-        ----------
-        qp_dist : qp.Ensemble
-            The qp Ensemble instance that contains posterior estimates.
-        grid : array-like, optional
-            The grid on which to evaluate the `mode` point estimate, if a grid is
-            not provided, a default will be created at run time using `zmin`, `zmax`,
-            and `nzbins`, by default None
-
-        Returns
-        -------
-        NDArray
-            The mode value for each posterior in the qp.Ensemble
-
-        Raises
-        ------
-        KeyError
-            If `grid` is not provided, one will be created using the config parameters
-            `zmin`, `zmax`, and `nzbins`. If any of those parameters are missing,
-            we'll raise a KeyError.
-        """
-        if grid is None:
-            for key in ['zmin', 'zmax', 'nzbins']:
-                if key not in self.config:
-                    raise KeyError(f"Expected `{key}` to be defined in stage " \
-                        "configuration dictionary in order to caluclate mode.")
-
-            grid = np.linspace(self.config.zmin, self.config.zmax, self.config.nzbins)
-
-        return qp_dist.mode(grid=grid)
-
-    def _calculate_mean_point_estimate(self, qp_dist) -> NDArray:
-        """Calculates and returns the mean values for a set of posterior estimates
-        in a qp.Ensemble instance.
-
-        Parameters
-        ----------
-        qp_dist : qp.Ensemble
-            The qp Ensemble instance that contains posterior estimates.
-
-        Returns
-        -------
-        NDArray
-            The mean value for each posterior in the qp.Ensemble
-        """
-        return qp_dist.mean()
-
-    def _calculate_median_point_estimate(self, qp_dist) -> NDArray:
-        """Calculates and returns the median values for a set of posterior estimates
-        in a qp.Ensemble instance.
-
-        Parameters
-        ----------
-        qp_dist : qp.Ensemble
-            The qp Ensemble instance that contains posterior estimates.
-
-        Returns
-        -------
-        NDArray
-            The median value for each posterior in the qp.Ensemble
-        """
-        return qp_dist.median()
diff --git a/tests/estimation/test_estimator.py b/tests/core/test_point_estimation.py
similarity index 89%
rename from tests/estimation/test_estimator.py
rename to tests/core/test_point_estimation.py
index c5d63e65..b09e63de 100644
--- a/tests/estimation/test_estimator.py
+++ b/tests/core/test_point_estimation.py
@@ -28,7 +28,7 @@ def _calculate_mode_point_estimate(self, qp_dist=None, grid=None):
     scales = 1 + 0.2*(np.random.uniform(size=(100,1))-0.5)
     test_ensemble = qp.Ensemble(qp.stats.norm, data=dict(loc=locs, scale=scales))
 
-    result = test_estimator._calculate_point_estimates(test_ensemble, None)
+    result = test_estimator.calculate_point_estimates(test_ensemble)
 
     assert np.all(result.ancil['mode'] == MEANING_OF_LIFE)
 
@@ -47,7 +47,7 @@ def test_basic_point_estimate():
     locs = 2* (np.random.uniform(size=(100,1))-0.5)
     scales = 1 + 0.2*(np.random.uniform(size=(100,1))-0.5)
     test_ensemble = qp.Ensemble(qp.stats.norm, data=dict(loc=locs, scale=scales))
-    result = test_estimator._calculate_point_estimates(test_ensemble, None)
+    result = test_estimator.calculate_point_estimates(test_ensemble, None)
 
     # note: we're not interested in testing the values of point estimates,
     # just that they were added to the ancillary data.
@@ -63,7 +63,7 @@ def test_mode_no_grid():
     test_estimator = CatEstimator.make_stage(name='test', **config_dict)
 
     with pytest.raises(KeyError) as excinfo:
-        _ = test_estimator._calculate_point_estimates(None, None)
+        _ = test_estimator.calculate_point_estimates(None, None)
 
     assert "to be defined in stage configuration" in str(excinfo.value)
 
@@ -78,6 +78,6 @@ def test_mode_no_point_estimates():
     scales = 1 + 0.2*(np.random.uniform(size=(100,1))-0.5)
     test_ensemble = qp.Ensemble(qp.stats.norm, data=dict(loc=locs, scale=scales))
 
-    output_ensemble = test_estimator._calculate_point_estimates(test_ensemble, None)
+    output_ensemble = test_estimator.calculate_point_estimates(test_ensemble, None)
 
     assert output_ensemble.ancil is None