Skip to content

Commit

Permalink
Merge pull request #24 from LSSTDESC/user/aimalz/renaming
Browse files Browse the repository at this point in the history
naming consistency/clarity within src/rail/estimation
  • Loading branch information
aimalz authored Jul 14, 2023
2 parents 39959e4 + 267be43 commit 91adfec
Show file tree
Hide file tree
Showing 10 changed files with 31 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
import qp


class NaiveStack(PZSummarizer):
class NaiveStackSummarizer(PZSummarizer):
"""Summarizer which simply histograms a point estimate
"""

name = 'NaiveStack'
name = 'NaiveStackSummarizer'
config_options = PZSummarizer.config_options.copy()
config_options.update(zmin=Param(float, 0.0, msg="The minimum redshift of the z grid"),
zmax=Param(float, 3.0, msg="The maximum redshift of the z grid"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
import qp


class PointEstimateHist(PZSummarizer):
class PointEstHistSummarizer(PZSummarizer):
"""Summarizer which simply histograms a point estimate
"""

name = 'PointEstimateHist'
name = 'PointEstHistSummarizer'
config_options = PZSummarizer.config_options.copy()
config_options.update(zmin=Param(float, 0.0, msg="The minimum redshift of the z grid"),
zmax=Param(float, 3.0, msg="The maximum redshift of the z grid"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
import qp


class RandomPZ(CatEstimator):
class RandomGaussEstimator(CatEstimator):
"""Random CatEstimator
"""

name = 'RandomPZ'
name = 'RandomGaussEstimator'
inputs = [('input', TableHandle)]
config_options = CatEstimator.config_options.copy()
config_options.update(rand_width=Param(float, 0.025, "ad hock width of PDF"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ def __init__(self, zgrid, pdf, zmode):
self.zmode = zmode


class Inform_trainZ(CatInformer):
class TrainZInformer(CatInformer):
"""Train an Estimator which returns a global PDF for all galaxies
"""

name = 'Inform_trainZ'
name = 'TrainZInformer'
config_options = CatInformer.config_options.copy()
config_options.update(zmin=SHARED_PARAMS,
zmax=SHARED_PARAMS,
Expand Down Expand Up @@ -56,11 +56,11 @@ def run(self):
self.add_data('model', self.model)


class TrainZ(CatEstimator):
class TrainZEstimator(CatEstimator):
"""CatEstimator which returns a global PDF for all galaxies
"""

name = 'TrainZ'
name = 'TrainZEstimator'
config_options = CatEstimator.config_options.copy()
config_options.update(zmin=SHARED_PARAMS,
zmax=SHARED_PARAMS,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
TEENY = 1.e-15


class VarInferenceStack(PZSummarizer):
class VarInfStackSummarizer(PZSummarizer):
"""Variational inference summarizer based on notebook created by Markus Rau
The summzarizer is appropriate for the likelihoods returned by
template-based codes, for which the NaiveSummarizer are not appropriate.
Expand All @@ -32,7 +32,7 @@ class VarInferenceStack(PZSummarizer):
number of samples used in dirichlet to determind error bar
"""

name = 'VarInferenceStack'
name = 'VarInfStackSummarizer'
config_options = PZSummarizer.config_options.copy()
config_options.update(zmin=Param(float, 0.0, msg="The minimum redshift of the z grid"),
zmax=Param(float, 3.0, msg="The maximum redshift of the z grid"),
Expand Down
2 changes: 1 addition & 1 deletion src/rail/estimation/summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def summarize(self, input_data):

class SZPZSummarizer(RailStage): #pragma: no cover
"""The base class for classes that use two sets of data: a photometry sample with
spec-z values, and a photometry sample with unknown redshifts, e.g. simpleSOM and
spec-z values, and a photometry sample with unknown redshifts, e.g. minisom_som and
outputs a QP Ensemble with bootstrap realization of the N(z) distribution
"""
name = 'SZPZtoNZSummarizer'
Expand Down
10 changes: 5 additions & 5 deletions src/rail/examples_data/goldenspike_data/goldenspike.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,21 @@ stages:
- classname: QuantityCut
name: quantity_cut
nprocess: 1
- classname: Inform_trainZ
- classname: TrainZInformer
name: inform_trainZ
nprocess: 1
- classname: Estimator
name: estimate_bpz
nprocess: 1
- classname: TrainZ
- classname: TrainZEstimator
name: estimate_trainZ
nprocess: 1
- classname: RandomPZ
- classname: RandomGaussEstimator
name: estimate_randomZ
nprocess: 1
- classname: PointEstimateHist
- classname: PointEstHistSummarizer
name: point_estimate_test
nprocess: 1
- classname: NaiveStack
- classname: NaiveStackSummarizer
name: naive_stack_test
nprocess: 1
10 changes: 5 additions & 5 deletions src/rail/stages/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

from rail.estimation.estimator import *
from rail.estimation.summarizer import *
from rail.estimation.algos.naiveStack import *
from rail.estimation.algos.randomPZ import *
from rail.estimation.algos.pointEstimateHist import *
from rail.estimation.algos.trainZ import *
from rail.estimation.algos.varInference import *
from rail.estimation.algos.naive_stack import *
from rail.estimation.algos.random_gauss import *
from rail.estimation.algos.point_est_hist import *
from rail.estimation.algos.train_z import *
from rail.estimation.algos.var_inf import *

from rail.creation.degrader import *
#from rail.creation.degradation.spectroscopic_degraders import *
Expand Down
8 changes: 4 additions & 4 deletions tests/estimation/test_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from rail.core.algo_utils import one_algo
from rail.core.stage import RailStage
from rail.estimation.algos import randomPZ, trainZ
from rail.estimation.algos import random_gauss, train_z

sci_ver_str = scipy.__version__.split(".")

Expand All @@ -25,7 +25,7 @@ def test_random_pz():
}
# zb_expected = np.array([1.359, 0.013, 0.944, 1.831, 2.982, 1.565, 0.308, 0.157, 0.986, 1.679])
train_algo = None
pz_algo = randomPZ.RandomPZ
pz_algo = random_gauss.RandomGaussEstimator
results, rerun_results, rerun3_results = one_algo(
"RandomPZ", train_algo, pz_algo, train_config_dict, estim_config_dict
)
Expand All @@ -44,8 +44,8 @@ def test_train_pz():
zb_expected = np.repeat(0.1445183, 10)
pdf_expected = np.zeros(shape=(301,))
pdf_expected[10:16] = [7, 23, 8, 23, 26, 13]
train_algo = trainZ.Inform_trainZ
pz_algo = trainZ.TrainZ
train_algo = train_z.TrainZInformer
pz_algo = train_z.TrainZEstimator
results, rerun_results, rerun3_results = one_algo(
"TrainZ", train_algo, pz_algo, train_config_dict, estim_config_dict
)
Expand Down
8 changes: 4 additions & 4 deletions tests/estimation/test_summarizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from rail.core.data import QPHandle
from rail.core.stage import RailStage
from rail.core.utils import RAILDIR
from rail.estimation.algos import naiveStack, pointEstimateHist, varInference
from rail.estimation.algos import naive_stack, point_est_hist, var_inf

testdata = os.path.join(RAILDIR, "rail/examples_data/testdata/output_BPZ_lite.fits")
DS = RailStage.data_store
Expand All @@ -29,17 +29,17 @@ def one_algo(key, summarizer_class, summary_kwargs):

def test_naive_stack():
summary_config_dict = {}
summarizer_class = naiveStack.NaiveStack
summarizer_class = naive_stack.NaiveStackSummarizer
results = one_algo("NaiveStack", summarizer_class, summary_config_dict)


def test_point_estimate_hist():
summary_config_dict = {}
summarizer_class = pointEstimateHist.PointEstimateHist
summarizer_class = point_est_hist.PointEstHistSummarizer
results = one_algo("PointEstimateHist", summarizer_class, summary_config_dict)


def test_var_inference_stack():
summary_config_dict = {}
summarizer_class = varInference.VarInferenceStack
summarizer_class = var_inf.VarInfStackSummarizer
results = one_algo("VariationalInference", summarizer_class, summary_config_dict)

0 comments on commit 91adfec

Please sign in to comment.