Skip to content

Commit

Permalink
Add GBM as a weak learner
Browse files Browse the repository at this point in the history
  • Loading branch information
valenad1 committed Sep 18, 2023
1 parent a73217f commit 79eaef1
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 1 deletion.
17 changes: 17 additions & 0 deletions h2o-algos/src/main/java/hex/adaboost/AdaBoost.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import hex.glm.GLMModel;
import hex.tree.drf.DRF;
import hex.tree.drf.DRFModel;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;
import org.apache.log4j.Logger;
import water.*;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
Expand Down Expand Up @@ -166,6 +168,8 @@ private ModelBuilder chooseWeakLearner(Frame frame) {
switch (_parms._weak_learner) {
case GLM:
return getGLMWeakLearner(frame);
case GBM:
return getGBMWeakLearner(frame);
default:
case DRF:
return getDRFWeakLearner(frame);
Expand Down Expand Up @@ -196,6 +200,19 @@ private GLM getGLMWeakLearner(Frame frame) {
return new GLM(parms);
}

/**
 * Creates the GBM builder used as the weak learner for one AdaBoost iteration.
 *
 * <p>The builder is configured for a single tree ({@code _ntrees = 1}) of depth one
 * ({@code _max_depth = 1}) with {@code _min_rows = 1} and {@code _sample_rate = 1}.
 * It trains on the supplied frame using this AdaBoost's response column, weights
 * column name and seed.
 *
 * @param frame frame whose key becomes the GBM training frame
 * @return a new, not yet trained, GBM builder
 */
private GBM getGBMWeakLearner(Frame frame) {
  final GBMModel.GBMParameters gbmParams = new GBMModel.GBMParameters();

  // Data wiring: taken from the incoming frame and the AdaBoost parameters.
  gbmParams._train = frame._key;
  gbmParams._response_column = _parms._response_column;
  gbmParams._weights_column = _weightsName;
  gbmParams._seed = _parms._seed;

  // Learner shape: one tree, one level deep.
  gbmParams._ntrees = 1;
  gbmParams._max_depth = 1;
  gbmParams._min_rows = 1;
  gbmParams._sample_rate = 1;

  return new GBM(gbmParams);
}

public TwoDimTable createModelSummaryTable() {
List<String> colHeaders = new ArrayList<>();
List<String> colTypes = new ArrayList<>();
Expand Down
2 changes: 1 addition & 1 deletion h2o-algos/src/main/java/hex/adaboost/AdaBoostModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
public class AdaBoostModel extends Model<AdaBoostModel, AdaBoostModel.AdaBoostParameters, AdaBoostModel.AdaBoostOutput> {
private static final Logger LOG = Logger.getLogger(AdaBoostModel.class);

public enum Algorithm {DRF, GLM, AUTO}
public enum Algorithm {DRF, GLM, GBM, AUTO}

public AdaBoostModel(Key<AdaBoostModel> selfKey, AdaBoostParameters parms,
AdaBoostOutput output) {
Expand Down
7 changes: 7 additions & 0 deletions h2o-algos/src/main/java/hex/tree/gbm/GBMModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -373,4 +373,11 @@ public void map(Chunk[] chk, NewChunk[] nchk) {
}.withPostMapAction(JobUpdatePostMap.forJob(j)).doAll(types, vs).outputFrame(destination_key, names, domains);
}

/**
 * Scores a single row with this GBM model.
 *
 * <p>Runs {@code score0} over trees {@code 0.._output._ntrees} into a buffer of
 * {@code _output.nclasses() + 1} slots, applies {@code score0PostProcessSupervised},
 * and returns slot 0 of the result.
 *
 * @param data the row to score
 * @return element 0 of the post-processed prediction array
 *         (NOTE(review): presumably the predicted label/value - confirm against
 *         the layout produced by {@code score0})
 */
@Override
public double score(double[] data) {
  final double[] buffer = new double[_output.nclasses() + 1];
  final double[] predictions = score0(data, buffer, 0, _output._ntrees);
  score0PostProcessSupervised(predictions, data);
  return predictions[0];
}

}
27 changes: 27 additions & 0 deletions h2o-algos/src/test/java/hex/adaboost/AdaBoostTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -488,4 +488,31 @@ public void testBasicTrainAndScoreGLM() {
Scope.exit();
}
}

/**
 * Smoke test: trains AdaBoost with GBM as the weak learner on the prostate data set
 * and scores a second copy of the same file. It only checks that training yields a
 * non-null model and that scoring completes.
 */
@Test
public void testBasicTrainAndScoreGBM() {
  try {
    Scope.enter();
    final String responseColumn = "CAPSULE";

    // The same file is parsed twice so training and scoring use separate frames.
    Frame train = Scope.track(parseTestFile("smalldata/prostate/prostate.csv"));
    Frame test = Scope.track(parseTestFile("smalldata/prostate/prostate.csv"));
    train.toCategoricalCol(responseColumn);

    AdaBoostModel.AdaBoostParameters params = new AdaBoostModel.AdaBoostParameters();
    params._train = train._key;
    params._response_column = responseColumn;
    params._weak_learner = AdaBoostModel.Algorithm.GBM;
    params._n_estimators = 50;
    params._seed = 0xDECAF;

    AdaBoostModel adaBoostModel = new AdaBoost(params).trainModel().get();
    Scope.track_generic(adaBoostModel);
    assertNotNull(adaBoostModel);

    // Scoring must complete; the resulting frame is tracked for cleanup.
    Scope.track(adaBoostModel.score(test));
  } finally {
    Scope.exit();
  }
}
}

0 comments on commit 79eaef1

Please sign in to comment.