Skip to content

Commit

Permalink
Merge pull request #18 from LSSTDESC/ceci2
Browse files Browse the repository at this point in the history
Update for ceci version 2
  • Loading branch information
BStoelzner authored Jul 17, 2024
2 parents 6b55d57 + 997fa3e commit a1da1ee
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 23 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ classifiers = [
]
dynamic = ["version"]
dependencies = [
"pz-rail-base",
"pz-rail-base>=1.0.3",
"scikit-learn",
]

Expand Down
8 changes: 4 additions & 4 deletions src/rail/estimation/algos/k_nearneigh.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ class KNearNeighInformer(CatInformer):
nneigh_min=Param(int, 3, msg="int, min number of near neighbors to use for PDF fit"),
nneigh_max=Param(int, 7, msg="int, max number of near neighbors to use ofr PDF fit"))

def __init__(self, args, comm=None):
def __init__(self, args, **kwargs):
""" Constructor
Do CatInformer specific initialization, then check on bands """
CatInformer.__init__(self, args, comm=comm)
super().__init__(args, **kwargs)

usecols = self.config.bands.copy()
usecols.append(self.config.redshift_col)
Expand Down Expand Up @@ -150,15 +150,15 @@ class KNearNeighEstimator(CatEstimator):
mag_limits=SHARED_PARAMS,
redshift_col=SHARED_PARAMS)

def __init__(self, args, comm=None):
def __init__(self, args, **kwargs):
""" Constructor:
Do Estimator specific initialization """
self.sigma = None
self.numneigh = None
self.model = None
self.trainszs = None
self.zgrid = None
CatEstimator.__init__(self, args, comm=comm)
super().__init__(args, **kwargs)
usecols = self.config.bands.copy()
usecols.append(self.config.redshift_col)
self.usecols = usecols
Expand Down
9 changes: 2 additions & 7 deletions src/rail/estimation/algos/nz_dir.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,6 @@ class NZDirInformer(CatInformer):
distance_delta=Param(float, 1.e-6, msg="padding for distance calculation"),
hdf5_groupname=Param(str, "photometry", msg="name of hdf5 group for data, if None, then set to ''"))

def __init__(self, args, comm=None):
""" Constructor:
Do Informer specific initialization """
CatInformer.__init__(self, args, comm=comm)

def run(self):
from sklearn.neighbors import NearestNeighbors

Expand Down Expand Up @@ -118,15 +113,15 @@ class NZDirSummarizer(CatEstimator):
outputs = [('output', QPHandle),
('single_NZ', QPHandle)]

def __init__(self, args, comm=None):
def __init__(self, args, **kwargs):
self.zgrid = None
self.model = None
self.distances = None
self.szusecols = None
self.szweights = None
self.sz_mag_data = None
self.bincents = None
CatEstimator.__init__(self, args, comm=comm)
super().__init__(args, **kwargs)

def open_model(self, **kwargs):
CatEstimator.open_model(self, **kwargs)
Expand Down
7 changes: 0 additions & 7 deletions src/rail/estimation/algos/random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,6 @@ class RandomForestInformer(CatInformer):
random_seed=Param(int, msg="random seed"),
no_assign=Param(int, -99, msg="Value for no assignment flag"),)
outputs = [('model', ModelHandle)]

def __init__(self, args, comm=None):
CatInformer.__init__(self, args, comm=comm)

def run(self):
# Load the training data
Expand Down Expand Up @@ -107,10 +104,6 @@ class RandomForestClassifier(CatClassifier):
class_bands=Param(tuple, ["r","i","z"], msg="Which bands to use for classification"),
bands=Param(dict, {"r":"mag_r_lsst", "i":"mag_i_lsst", "z":"mag_z_lsst"}, msg="column names for the the bands"),)
outputs = [('output', Hdf5Handle)]

def __init__(self, args, comm=None):
CatClassifier.__init__(self, args, comm=comm)


def open_model(self, **kwargs):
CatClassifier.open_model(self, **kwargs)
Expand Down
8 changes: 4 additions & 4 deletions src/rail/estimation/algos/sklearn_neurnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ class SklNeurNetInformer(CatInformer):
"not optimally)"))


def __init__(self, args, comm=None):
def __init__(self, args, **kwargs):
""" Constructor:
Do CatInformer specific initialization """
CatInformer.__init__(self, args, comm=comm)
super().__init__(args, **kwargs)
if self.config.ref_band not in self.config.bands:
raise ValueError("ref_band not present in bands list! ")

Expand Down Expand Up @@ -127,10 +127,10 @@ class SklNeurNetEstimator(CatEstimator):
nondetect_val=SHARED_PARAMS,
bands=SHARED_PARAMS)

def __init__(self, args, comm=None):
def __init__(self, args, **kwargs):
""" Constructor:
Do CatEstimator specific initialization """
CatEstimator.__init__(self, args, comm=comm)
super().__init__(args, **kwargs)
if self.config.ref_band not in self.config.bands:
raise ValueError("ref_band is not in list of bands!")

Expand Down

0 comments on commit a1da1ee

Please sign in to comment.