Add random forest classifier for tomographer #1
Merged
14 commits
d9bac43
add random forest classifier for tomographer
hangqianjun d86b13a
Add unit test and change some variable names
hangqianjun 08b08bc
Change according to naming convention
hangqianjun d9992aa
resolve conflict in import
hangqianjun 85d617e
add classifier algo into sklearn/init
hangqianjun 1440f49
fixed typo
hangqianjun d1ccee8
Fix typo and tests
hangqianjun 88dd06c
distinguish bands and class_bands
hangqianjun 840f0e1
Added a test to cover ID column
hangqianjun 3211540
Cleared outputs in tests
hangqianjun 55ffa19
resolved conflicts in test_algo.py
hangqianjun fc2098d
KNearNeighEstimator.pkl name accidentally changed, reverted
hangqianjun 61997a7
Changing names for informer and classifier according to review
hangqianjun 61f004d
trivial commit to force action
eacharles
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
""" | ||
An example classifier that uses catalogue information to | ||
classify objects into tomographic bins using random forest. | ||
This is the base method in TXPipe, adapted from TXpipe/binning/random_forest.py | ||
Note: extra dependence on sklearn and input training file. | ||
""" | ||
|
||
from collections import OrderedDict | ||
import numpy as np | ||
from ceci.config import StageParameter as Param | ||
from rail.estimation.classifier import CatClassifier | ||
from rail.estimation.informer import CatInformer | ||
from rail.core.data import TableHandle, ModelHandle, Hdf5Handle | ||
from sklearn.ensemble import RandomForestClassifier as skl_RandomForestClassifier | ||
|
||
from rail.core.common_params import SHARED_PARAMS | ||
|
||
|
||
class randomForestmodel: | ||
""" | ||
Temporary class to store the trained model. | ||
""" | ||
def __init__(self, skl_classifier, features): | ||
self.skl_classifier = skl_classifier | ||
self.features = features | ||
|
||
|
||
class RandomForestInformer(CatInformer): | ||
"""Train the random forest classifier""" | ||
|
||
name = 'Inform_randomForestClassifier' | ||
config_options = CatInformer.config_options.copy() | ||
config_options.update( | ||
class_bands=Param(tuple, ["r","i","z"], msg="Which bands to use for classification"), | ||
bands=Param(dict, {"r":"mag_r_lsst", "i":"mag_i_lsst", "z":"mag_z_lsst"}, msg="column names for the bands"), | ||
redshift_col=Param(str, "sz", msg="Redshift column name"), | ||
bin_edges=Param(tuple, [0,0.5,1.0], msg="Binning for training data"), | ||
random_seed=Param(int, msg="random seed"), | ||
no_assign=Param(int, -99, msg="Value for no assignment flag"),) | ||
outputs = [('model', ModelHandle)] | ||
|
||
def __init__(self, args, comm=None): | ||
CatInformer.__init__(self, args, comm=comm) | ||
|
||
def run(self): | ||
# Load the training data | ||
if self.config.hdf5_groupname: | ||
training_data_table = self.get_data('input')[self.config.hdf5_groupname] | ||
else: # pragma: no cover | ||
training_data_table = self.get_data('input') | ||
|
||
# Pull out the appropriate columns and combinations of the data | ||
print(f"Using these bands to train the tomography selector: {self.config.bands}") | ||
|
||
# Generate the training data that we will use | ||
# We record both the name of the column and the data itself | ||
features = [] | ||
training_data = [] | ||
for b1 in self.config.class_bands[:]: | ||
b1_cat=self.config.bands[b1] | ||
# First we use the magnitudes themselves | ||
features.append(b1) | ||
training_data.append(training_data_table[b1_cat]) | ||
# We also use the colours as training data, even the redundant ones | ||
for b2 in self.config.class_bands[:]: | ||
b2_cat=self.config.bands[b2] | ||
if b1 < b2: | ||
features.append(f"{b1}-{b2}") | ||
training_data.append(training_data_table[b1_cat] - training_data_table[b2_cat]) | ||
training_data = np.array(training_data).T | ||
|
||
print("Training data for bin classifier has shape ", training_data.shape) | ||
|
||
# Now put the training data into redshift bins | ||
# We use the no_assign value (default -99) to indicate that we are outside the desired ranges | ||
z = training_data_table[self.config.redshift_col] | ||
training_bin = np.repeat(self.config.no_assign, len(z)) | ||
print("Using these bin edges:", self.config.bin_edges) | ||
for i, zmin in enumerate(self.config.bin_edges[:-1]): | ||
zmax = self.config.bin_edges[i + 1] | ||
training_bin[(z > zmin) & (z < zmax)] = i | ||
ntrain_bin = ((z > zmin) & (z < zmax)).sum() | ||
print(f"Training set: {ntrain_bin} objects in tomographic bin {i}") | ||
|
||
# Can be replaced with any classifier | ||
skl_classifier = skl_RandomForestClassifier( | ||
max_depth=10, | ||
max_features=None, | ||
n_estimators=20, | ||
random_state=self.config.random_seed, | ||
) | ||
skl_classifier.fit(training_data, training_bin) | ||
|
||
#return classifier, features | ||
self.model = randomForestmodel(skl_classifier, features) | ||
self.add_data('model', self.model) | ||
|
||
|
||
class RandomForestClassifier(CatClassifier): | ||
"""Classifier that assigns tomographic | ||
bins based on random forest method""" | ||
|
||
name = 'randomForestClassifier' | ||
config_options = CatClassifier.config_options.copy() | ||
config_options.update( | ||
id_name=Param(str, "", msg="Column name for the object ID in the input data, if empty the row index is used as the ID."), | ||
class_bands=Param(tuple, ["r","i","z"], msg="Which bands to use for classification"), | ||
bands=Param(dict, {"r":"mag_r_lsst", "i":"mag_i_lsst", "z":"mag_z_lsst"}, msg="column names for the bands"),) | ||
outputs = [('output', Hdf5Handle)] | ||
|
||
def __init__(self, args, comm=None): | ||
CatClassifier.__init__(self, args, comm=comm) | ||
|
||
|
||
def open_model(self, **kwargs): | ||
CatClassifier.open_model(self, **kwargs) | ||
if self.model is None: # pragma: no cover | ||
return | ||
self.skl_classifier = self.model.skl_classifier | ||
self.features = self.model.features | ||
|
||
|
||
def run(self): | ||
"""Apply the classifier to the measured magnitudes""" | ||
|
||
if self.config.hdf5_groupname: | ||
test_data = self.get_data('input')[self.config.hdf5_groupname] | ||
else: # pragma: no cover | ||
test_data = self.get_data('input') | ||
|
||
data = [] | ||
for f in self.features: | ||
|
||
# may be a single band | ||
if "-" not in f: | ||
f_cat=self.config.bands[f] | ||
col = test_data[f_cat] | ||
# or a colour | ||
else: | ||
b1, b2 = f.split("-") | ||
b1_cat=self.config.bands[b1] | ||
b2_cat=self.config.bands[b2] | ||
col = (test_data[b1_cat] - test_data[b2_cat]) | ||
if np.all(~np.isfinite(col)): | ||
# entire column is NaN. Hopefully this will get deselected elsewhere | ||
col[:] = 30.0 | ||
else: | ||
ok = np.isfinite(col) | ||
col[~ok] = col[ok].max() | ||
data.append(col) | ||
data = np.array(data).T | ||
|
||
# Run the random forest on this data chunk | ||
bin_index = self.skl_classifier.predict(data) | ||
|
||
if self.config.id_name != "": | ||
# An ID column is configured: | ||
# read the object IDs from the input data | ||
obj_id = test_data[self.config.id_name] | ||
else: | ||
# No ID column configured: use the row index as the ID | ||
b=self.config.bands[self.config.class_bands[0]] | ||
obj_id = np.arange(len(test_data[b])) | ||
self.config.id_name="row_index" | ||
|
||
class_id = dict(data=OrderedDict([(self.config.id_name, obj_id), ("class_id", bin_index)])) | ||
self.add_data('output', class_id) |
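The data preparation in the diff above can be tried out without RAIL or scikit-learn. The following is a minimal sketch, assuming a plain dict of columns as input; the helper names `build_features`, `assign_bins`, and `fill_nonfinite` are hypothetical and are not part of the PR. It mirrors the feature ordering produced by the `b1 < b2` string comparison, the strict inequalities used for binning, and the replacement of non-finite values.

```python
import numpy as np


def build_features(table, bands, class_bands):
    """Hypothetical helper: magnitudes plus colours, as in the informer loop."""
    names, columns = [], []
    for b1 in class_bands:
        mag1 = np.asarray(table[bands[b1]], dtype=float)
        names.append(b1)
        columns.append(mag1)
        for b2 in class_bands:
            # String comparison keeps exactly one colour per band pair
            if b1 < b2:
                mag2 = np.asarray(table[bands[b2]], dtype=float)
                names.append(f"{b1}-{b2}")
                columns.append(mag1 - mag2)
    # One row per object, one column per feature
    return names, np.array(columns).T


def assign_bins(z, bin_edges, no_assign=-99):
    """Hypothetical helper: label each redshift with its bin index."""
    z = np.asarray(z, dtype=float)
    labels = np.repeat(no_assign, len(z))
    for i, (zmin, zmax) in enumerate(zip(bin_edges[:-1], bin_edges[1:])):
        # Strict inequalities, as in the diff: values on an edge stay unassigned
        labels[(z > zmin) & (z < zmax)] = i
    return labels


def fill_nonfinite(col, all_bad_value=30.0):
    """Hypothetical helper: replace NaN/inf as the classifier's run() does."""
    col = np.array(col, dtype=float)  # work on a copy, leave the input untouched
    ok = np.isfinite(col)
    if not ok.any():
        col[:] = all_bad_value
    else:
        col[~ok] = col[ok].max()
    return col


bands = {"r": "mag_r_lsst", "i": "mag_i_lsst", "z": "mag_z_lsst"}
table = {
    "mag_r_lsst": [24.0, 23.0],
    "mag_i_lsst": [23.5, 22.0],
    "mag_z_lsst": [23.0, 21.5],
}
names, matrix = build_features(table, bands, ["r", "i", "z"])
labels = assign_bins([0.2, 0.5, 0.7, 1.3], [0, 0.5, 1.0])
```

Two things stand out when running this. The feature order follows the string comparison rather than wavelength, so the classifier must rebuild features from the stored names, as the PR does through `self.features`. An object with a redshift exactly on a bin edge (0.5 here) receives the `no_assign` label.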
Nothing needs to be changed here, but I just want to note re: LSSTDESC/rail_base#39 that this will be the first instance of IDs in the RAIL-iverse, so it should be considered in decision-making about how to consistently reference these throughout src/rail code.
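For reference in that discussion, the ID handling in this PR amounts to the following. This is a minimal sketch, assuming a dict-like input table; `build_output` and its `fallback_name` argument are hypothetical names, not RAIL API. Unlike the PR, it returns the key it used instead of overwriting the `id_name` configuration value.

```python
from collections import OrderedDict

import numpy as np


def build_output(table, bin_index, id_name="", fallback_name="row_index"):
    """Hypothetical helper: pair each object ID with its predicted bin."""
    bin_index = np.asarray(bin_index)
    if id_name:
        # An ID column was configured: take the IDs from the input table
        key = id_name
        obj_id = np.asarray(table[id_name])
    else:
        # No ID column: fall back to the row index, stored under fallback_name
        key = fallback_name
        obj_id = np.arange(len(bin_index))
    return OrderedDict([(key, obj_id), ("class_id", bin_index)])


with_ids = build_output({"id": [10, 11, 12]}, [0, 1, 0], id_name="id")
without_ids = build_output({}, [1, 1])
```

The name of the first output column therefore depends on the configuration, which is one of the points a consistent ID convention would need to settle.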