Skip to content

Commit

Permalink
Merge pull request #54 from LSSTDESC/issue/53/nocolumns
Browse files Browse the repository at this point in the history
remove dependencies on .columns file #53
  • Loading branch information
sschmidt23 authored Sep 20, 2024
2 parents d3306df + f497d1c commit deefe81
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 27 deletions.
3 changes: 0 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,6 @@ build-backend = "setuptools.build_meta"
[tool.setuptools_scm]
write_to = "src/rail/bpz/_version.py"

[tool.setuptools.package-data]
"rail.examples.estimation.configs" = ["*.columns"]

[tool.pytest.ini_options]
testpaths = [
"tests",
Expand Down
18 changes: 10 additions & 8 deletions src/rail/estimation/algos/bpz_lite.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,13 @@
from ceci.config import StageParameter as Param
from rail.estimation.estimator import CatEstimator, CatInformer
from rail.utils.path_utils import RAILDIR
from rail.bpz.utils import RAIL_BPZ_DIR
from rail.core.common_params import SHARED_PARAMS


default_filter_list = ["DC2LSST_u", "DC2LSST_g", "DC2LSST_r",
"DC2LSST_i", "DC2LSST_z", "DC2LSST_y"]


def nzfunc(z, z0, alpha, km, m, m0): # pragma: no cover
zm = z0 + (km * (m - m0))
return np.power(z, alpha) * np.exp(-1. * np.power((z / zm), alpha))
Expand Down Expand Up @@ -76,8 +79,6 @@ class BPZliteInformer(CatInformer):
"SED, FILTER, and AB directories. If left to "
"default `None` it will use the install "
"directory for rail + rail/examples_data/estimation_data/data"),
columns_file=Param(str, os.path.join(RAIL_BPZ_DIR, "rail/examples_data/estimation_data/configs/test_bpz.columns"),
msg="name of the file specifying the columns"),
spectra_file=Param(str, "CWWSB4.list",
msg="name of the file specifying the list of SEDs to use"),
m0=Param(float, 20.0, msg="reference apparent mag, used in prior param"),
Expand Down Expand Up @@ -266,8 +267,9 @@ class BPZliteEstimator(CatEstimator):
"SED, FILTER, and AB directories. If left to "
"default `None` it will use the install "
"directory for rail + ../examples_data/estimation_data/data"),
columns_file=Param(str, os.path.join(RAIL_BPZ_DIR, "rail/examples_data/estimation_data/configs/test_bpz.columns"),
msg="name of the file specifying the columns"),
filter_list=Param(list, default_filter_list,
msg="list of filter files names (with no '.sed' suffix). Filters must be"
"in FILTER dir. MUST BE IN SAME ORDER as 'bands'"),
spectra_file=Param(str, "CWWSB4.list",
msg="name of the file specifying the list of SEDs to use"),
madau_flag=Param(str, "no",
Expand Down Expand Up @@ -309,6 +311,8 @@ def __init__(self, args, **kwargs):
raise ValueError("Number of bands specified in bands must be equal to number of mag errors specified in err_bands!")
if self.config.ref_band not in self.config.bands: # pragma: no cover
raise ValueError(f"reference band not found in bands specified in bands: {str(self.config.bands)}")
if len(self.config.bands) != len(self.config.err_bands) or len(self.config.bands) != len(self.config.filter_list):
raise ValueError(f"length of bands, err_bands, and filter_list are not the same!")

def _initialize_run(self):
super()._initialize_run()
Expand Down Expand Up @@ -346,9 +350,7 @@ def _load_templates(self):
z = self.zgrid

data_path = self.data_path
columns_file = self.config.columns_file
ignore_rows = ["M_0", "OTHER", "ID", "Z_S"]
filters = [f for f in get_str(columns_file, 0) if f not in ignore_rows]
filters = self.config.filter_list

spectra_file = os.path.join(data_path, "SED", self.config.spectra_file)
spectra = [s[:-4] for s in get_str(spectra_file)]
Expand Down
14 changes: 0 additions & 14 deletions src/rail/examples_data/estimation_data/configs/test_bpz.columns

This file was deleted.

25 changes: 23 additions & 2 deletions tests/test_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,31 @@ def test_bpz_lite_wkernel_flatprior():
'zp_errors': np.array([0.01, 0.01, 0.01, 0.01, 0.01, 0.01]),
'mag_err_min': 0.005,
'hdf5_groupname': 'photometry'}
# zb_expected = np.array([0.18, 2.88, 0.12, 0.15, 2.97, 2.78, 0.11, 0.19,
# 2.98, 2.92])
train_algo = None
pz_algo = bpz_lite.BPZliteEstimator
results, rerun_results, rerun3_results = one_algo("BPZ_lite", train_algo, pz_algo, train_config_dict, estim_config_dict)
# assert np.isclose(results.ancil['zmode'], zb_expected).all()
assert np.isclose(results.ancil['zmode'], rerun_results.ancil['zmode']).all()


def test_wrong_number_of_filters():
train_config_dict = {}
estim_config_dict = {'zmin': 0.0, 'zmax': 3.0,
'dz': 0.01,
'nzbins': 301,
'data_path': None,
'columns_file': os.path.join(RAIL_BPZ_DIR, "rail/examples_data/estimation_data/configs/test_bpz.columns"),
'spectra_file': "CWWSB4.list",
'madau_flag': 'no',
'ref_band': 'mag_i_lsst',
'prior_file': 'flat',
'p_min': 0.005,
'gauss_kernel': 0.1,
'zp_errors': np.array([0.01, 0.01, 0.01, 0.01, 0.01, 0.01]),
'mag_err_min': 0.005,
'filter_list': ['DC2LSST_u', 'DC2LSST_g'],
'hdf5_groupname': 'photometry'}
train_algo = None
with pytest.raises(ValueError):
pz_algo = bpz_lite.BPZliteEstimator
_, _, _ = one_algo("BPZ_lite", train_algo, pz_algo, train_config_dict, estim_config_dict)

0 comments on commit deefe81

Please sign in to comment.