Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Abstract calculate_point_estimate up one more level #48

Merged
merged 1 commit into from
Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/rail/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ def find_version():
#from .utilPhotometry import PhotormetryManipulator, HyperbolicSmoothing, HyperbolicMagnitudes
from .utilStages import ColumnMapper, RowSelector, TableConverter
from .introspection import RailEnv
from .point_estimation import PointEstimationMixin
124 changes: 124 additions & 0 deletions src/rail/core/point_estimation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import numpy as np
from numpy.typing import NDArray

class PointEstimationMixin:
    """Mixin that adds point-estimate ancillary data to a ``qp.Ensemble``.

    The host class must expose a ``config`` attribute that supports ``in``
    membership tests, item access (``config['key']``) and attribute access
    (``config.zmin``); all three forms are used below.
    """

    def calculate_point_estimates(self, qp_dist, grid=None):
        """Drive the calculation of point estimates for a qp.Ensemble.

        In this code base it is invoked from `CatEstimator._do_chunk_output`
        before each chunk is written.

        Parameters
        ----------
        qp_dist : qp.Ensemble
            The qp Ensemble instance that contains posterior estimates.
        grid : array-like, optional
            The grid on which to evaluate the point estimate. Only the `mode`
            estimate uses it; by default None.

        Returns
        -------
        qp.Ensemble
            The same `qp_dist` object. If the `calculated_point_estimates`
            config entry is present and non-empty, `set_ancil` has been called
            on it with a dict holding whichever of the keys 'mode', 'mean' and
            'median' were requested. Otherwise it is returned untouched.

        Notes
        -----
        If there are particularly efficient ways to calculate point estimates
        for a given subclass, the subclass can override any of:

        - `_calculate_mode_point_estimate`
        - `_calculate_mean_point_estimate`
        - `_calculate_median_point_estimate`
        """
        ancil_dict = {}
        calculated_point_estimates = []
        if 'calculated_point_estimates' in self.config:
            calculated_point_estimates = self.config['calculated_point_estimates']

        if 'mode' in calculated_point_estimates:
            ancil_dict['mode'] = self._calculate_mode_point_estimate(qp_dist, grid)

        if 'mean' in calculated_point_estimates:
            ancil_dict['mean'] = self._calculate_mean_point_estimate(qp_dist)

        if 'median' in calculated_point_estimates:
            ancil_dict['median'] = self._calculate_median_point_estimate(qp_dist)

        # Only touch the ensemble's ancillary data when something was requested,
        # so an ensemble with no configured point estimates is left unchanged.
        if calculated_point_estimates:
            qp_dist.set_ancil(ancil_dict)

        return qp_dist

    def _calculate_mode_point_estimate(self, qp_dist, grid=None) -> NDArray:
        """Return the mode of each posterior in a qp.Ensemble.

        Parameters
        ----------
        qp_dist : qp.Ensemble
            The qp Ensemble instance that contains posterior estimates.
        grid : array-like, optional
            The grid on which to evaluate the mode. If omitted, a grid is built
            with `np.linspace` from the `zmin`, `zmax` and `nzbins` config
            entries; by default None.

        Returns
        -------
        NDArray
            The result of `qp_dist.mode(grid=grid)`.

        Raises
        ------
        KeyError
            If `grid` is None and any of `zmin`, `zmax` or `nzbins` is missing
            from the stage configuration. The first missing key is reported.
        """
        if grid is None:
            # Validate before touching qp_dist so a bad config fails early.
            for key in ['zmin', 'zmax', 'nzbins']:
                if key not in self.config:
                    raise KeyError(
                        f"Expected `{key}` to be defined in stage "
                        "configuration dictionary in order to calculate mode."
                    )

            grid = np.linspace(self.config.zmin, self.config.zmax, self.config.nzbins)

        return qp_dist.mode(grid=grid)

    def _calculate_mean_point_estimate(self, qp_dist) -> NDArray:
        """Return the mean of each posterior in a qp.Ensemble.

        Parameters
        ----------
        qp_dist : qp.Ensemble
            The qp Ensemble instance that contains posterior estimates.

        Returns
        -------
        NDArray
            The result of `qp_dist.mean()`.
        """
        return qp_dist.mean()

    def _calculate_median_point_estimate(self, qp_dist) -> NDArray:
        """Return the median of each posterior in a qp.Ensemble.

        Parameters
        ----------
        qp_dist : qp.Ensemble
            The qp Ensemble instance that contains posterior estimates.

        Returns
        -------
        NDArray
            The result of `qp_dist.median()`.
        """
        return qp_dist.median()
126 changes: 3 additions & 123 deletions src/rail/estimation/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,17 @@
Abstract base classes defining Estimators of individual galaxy redshift uncertainties
"""
import gc
import numpy as np
from numpy.typing import NDArray

from rail.core.common_params import SHARED_PARAMS
from rail.core.data import TableHandle, QPHandle, ModelHandle
from rail.core.stage import RailStage

from rail.estimation.informer import CatInformer
from rail.core.point_estimation import PointEstimationMixin
# for backwards compatibility


class CatEstimator(RailStage):
class CatEstimator(RailStage, PointEstimationMixin):
"""The base class for making photo-z posterior estimates from catalog-like inputs
(i.e., tables with fluxes in photometric bands among the set of columns)

Expand Down Expand Up @@ -125,128 +124,9 @@ def _process_chunk(self, start, end, data, first):
raise NotImplementedError(f"{self.name}._process_chunk is not implemented") # pragma: no cover

def _do_chunk_output(self, qp_dstn, start, end, first):
    """Add point estimates to one chunk of posteriors and write it out.

    Parameters
    ----------
    qp_dstn : qp.Ensemble
        The posteriors for this chunk; passed through
        `calculate_point_estimates` before being written.
    start, end
        Chunk bounds, forwarded unchanged to the handle's `write_chunk`.
    first : bool
        When true, the 'output' handle is created and `initialize_write` is
        called on it before any data is written.
    """
    ensemble = self.calculate_point_estimates(qp_dstn)
    if first:
        # The handle is created once and reused for every later chunk.
        self._output_handle = self.add_handle('output', data=ensemble)
        self._output_handle.initialize_write(self._input_length, communicator=self.comm)
    writer = self._output_handle
    writer.set_data(ensemble, partial=True)
    writer.write_chunk(start, end)

def _calculate_point_estimates(self, qp_dist, grid=None):
    """Compute the configured point estimates and store them on the ensemble.

    Parameters
    ----------
    qp_dist : qp.Ensemble
        The qp Ensemble instance that contains posterior estimates.
    grid : array-like, optional
        Grid forwarded to the `mode` calculation; the other estimates do not
        use it. By default None.

    Returns
    -------
    qp.Ensemble
        The same `qp_dist` object. When the `calculated_point_estimates`
        config entry is present and non-empty, `set_ancil` has been called on
        it with whichever of 'mode', 'mean' and 'median' were requested.

    Notes
    -----
    Subclasses with a cheaper way to obtain an estimate can override
    `_calculate_mode_point_estimate`, `_calculate_mean_point_estimate` or
    `_calculate_median_point_estimate`.
    """
    requested = []
    if 'calculated_point_estimates' in self.config:
        requested = self.config.calculated_point_estimates

    estimates = {}

    if 'mode' in requested:
        estimates['mode'] = self._calculate_mode_point_estimate(qp_dist, grid)

    if 'mean' in requested:
        estimates['mean'] = self._calculate_mean_point_estimate(qp_dist)

    if 'median' in requested:
        estimates['median'] = self._calculate_median_point_estimate(qp_dist)

    # Leave the ensemble untouched when nothing was requested.
    if requested:
        qp_dist.set_ancil(estimates)

    return qp_dist

def _calculate_mode_point_estimate(self, qp_dist, grid=None) -> NDArray:
    """Return the mode of each posterior in a qp.Ensemble.

    Parameters
    ----------
    qp_dist : qp.Ensemble
        The qp Ensemble instance that contains posterior estimates.
    grid : array-like, optional
        The grid on which to evaluate the mode. If omitted, a grid is built
        with `np.linspace` from the `zmin`, `zmax` and `nzbins` config
        entries; by default None.

    Returns
    -------
    NDArray
        The result of `qp_dist.mode(grid=grid)`.

    Raises
    ------
    KeyError
        If `grid` is None and any of `zmin`, `zmax` or `nzbins` is missing
        from the stage configuration. The first missing key is reported.
    """
    if grid is None:
        # Validate before touching qp_dist so a bad config fails early.
        for key in ['zmin', 'zmax', 'nzbins']:
            if key not in self.config:
                raise KeyError(
                    f"Expected `{key}` to be defined in stage "
                    "configuration dictionary in order to calculate mode."
                )

        grid = np.linspace(self.config.zmin, self.config.zmax, self.config.nzbins)

    return qp_dist.mode(grid=grid)

def _calculate_mean_point_estimate(self, qp_dist) -> NDArray:
    """Return the mean of each posterior in a qp.Ensemble.

    Parameters
    ----------
    qp_dist : qp.Ensemble
        The qp Ensemble instance that contains posterior estimates.

    Returns
    -------
    NDArray
        The result of `qp_dist.mean()`.
    """
    mean_values = qp_dist.mean()
    return mean_values

def _calculate_median_point_estimate(self, qp_dist) -> NDArray:
    """Return the median of each posterior in a qp.Ensemble.

    Parameters
    ----------
    qp_dist : qp.Ensemble
        The qp Ensemble instance that contains posterior estimates.

    Returns
    -------
    NDArray
        The result of `qp_dist.median()`.
    """
    median_values = qp_dist.median()
    return median_values
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _calculate_mode_point_estimate(self, qp_dist=None, grid=None):
scales = 1 + 0.2*(np.random.uniform(size=(100,1))-0.5)
test_ensemble = qp.Ensemble(qp.stats.norm, data=dict(loc=locs, scale=scales))

result = test_estimator._calculate_point_estimates(test_ensemble, None)
result = test_estimator.calculate_point_estimates(test_ensemble)

assert np.all(result.ancil['mode'] == MEANING_OF_LIFE)

Expand All @@ -47,7 +47,7 @@ def test_basic_point_estimate():
locs = 2* (np.random.uniform(size=(100,1))-0.5)
scales = 1 + 0.2*(np.random.uniform(size=(100,1))-0.5)
test_ensemble = qp.Ensemble(qp.stats.norm, data=dict(loc=locs, scale=scales))
result = test_estimator._calculate_point_estimates(test_ensemble, None)
result = test_estimator.calculate_point_estimates(test_ensemble, None)

# note: we're not interested in testing the values of point estimates,
# just that they were added to the ancillary data.
Expand All @@ -63,7 +63,7 @@ def test_mode_no_grid():
test_estimator = CatEstimator.make_stage(name='test', **config_dict)

with pytest.raises(KeyError) as excinfo:
_ = test_estimator._calculate_point_estimates(None, None)
_ = test_estimator.calculate_point_estimates(None, None)

assert "to be defined in stage configuration" in str(excinfo.value)

Expand All @@ -78,6 +78,6 @@ def test_mode_no_point_estimates():
scales = 1 + 0.2*(np.random.uniform(size=(100,1))-0.5)
test_ensemble = qp.Ensemble(qp.stats.norm, data=dict(loc=locs, scale=scales))

output_ensemble = test_estimator._calculate_point_estimates(test_ensemble, None)
output_ensemble = test_estimator.calculate_point_estimates(test_ensemble, None)

assert output_ensemble.ancil is None
Loading