diff --git a/h2o-algos/src/main/java/hex/adaboost/AdaBoost.java b/h2o-algos/src/main/java/hex/adaboost/AdaBoost.java index 0d467989da11..1e461fea578f 100644 --- a/h2o-algos/src/main/java/hex/adaboost/AdaBoost.java +++ b/h2o-algos/src/main/java/hex/adaboost/AdaBoost.java @@ -7,6 +7,8 @@ import hex.glm.GLMModel; import hex.tree.drf.DRF; import hex.tree.drf.DRFModel; +import hex.tree.gbm.GBM; +import hex.tree.gbm.GBMModel; import org.apache.log4j.Logger; import water.*; import water.exceptions.H2OModelBuilderIllegalArgumentException; @@ -166,6 +168,8 @@ private ModelBuilder chooseWeakLearner(Frame frame) { switch (_parms._weak_learner) { case GLM: return getGLMWeakLearner(frame); + case GBM: + return getGBMWeakLearner(frame); default: case DRF: return getDRFWeakLearner(frame); @@ -196,6 +200,19 @@ private GLM getGLMWeakLearner(Frame frame) { return new GLM(parms); } + private GBM getGBMWeakLearner(Frame frame) { + GBMModel.GBMParameters parms = new GBMModel.GBMParameters(); + parms._train = frame._key; + parms._response_column = _parms._response_column; + parms._weights_column = _weightsName; + parms._min_rows = 1; + parms._ntrees = 1; + parms._sample_rate = 1; + parms._max_depth = 1; + parms._seed = _parms._seed; + return new GBM(parms); + } + public TwoDimTable createModelSummaryTable() { List colHeaders = new ArrayList<>(); List colTypes = new ArrayList<>(); diff --git a/h2o-algos/src/main/java/hex/adaboost/AdaBoostModel.java b/h2o-algos/src/main/java/hex/adaboost/AdaBoostModel.java index fb22947c7f16..e769779d376c 100644 --- a/h2o-algos/src/main/java/hex/adaboost/AdaBoostModel.java +++ b/h2o-algos/src/main/java/hex/adaboost/AdaBoostModel.java @@ -11,7 +11,7 @@ public class AdaBoostModel extends Model { private static final Logger LOG = Logger.getLogger(AdaBoostModel.class); - public enum Algorithm {DRF, GLM, AUTO} + public enum Algorithm {DRF, GLM, GBM, AUTO} public AdaBoostModel(Key selfKey, AdaBoostParameters parms, AdaBoostOutput output) { diff --git a/h2o-algos/src/main/java/hex/tree/gbm/GBMModel.java b/h2o-algos/src/main/java/hex/tree/gbm/GBMModel.java index 662f5ef0af6b..900ec18f9e53 100755 --- a/h2o-algos/src/main/java/hex/tree/gbm/GBMModel.java +++ b/h2o-algos/src/main/java/hex/tree/gbm/GBMModel.java @@ -373,4 +373,11 @@ public void map(Chunk[] chk, NewChunk[] nchk) { }.withPostMapAction(JobUpdatePostMap.forJob(j)).doAll(types, vs).outputFrame(destination_key, names, domains); } + @Override + public double score(double[] data) { + double[] pred = score0(data, new double[_output.nclasses() + 1], 0, _output._ntrees); + score0PostProcessSupervised(pred, data); + return pred[0]; + } + } diff --git a/h2o-algos/src/test/java/hex/adaboost/AdaBoostTest.java b/h2o-algos/src/test/java/hex/adaboost/AdaBoostTest.java index 37b1644d5d6f..65de33ee8830 100644 --- a/h2o-algos/src/test/java/hex/adaboost/AdaBoostTest.java +++ b/h2o-algos/src/test/java/hex/adaboost/AdaBoostTest.java @@ -488,4 +488,31 @@ public void testBasicTrainAndScoreGLM() { Scope.exit(); } } + + @Test + public void testBasicTrainAndScoreGBM() { + try { + Scope.enter(); + Frame train = Scope.track(parseTestFile("smalldata/prostate/prostate.csv")); + Frame test = Scope.track(parseTestFile("smalldata/prostate/prostate.csv")); + String response = "CAPSULE"; + train.toCategoricalCol(response); + AdaBoostModel.AdaBoostParameters p = new AdaBoostModel.AdaBoostParameters(); + p._train = train._key; + p._seed = 0xDECAF; + p._n_estimators = 50; + p._weak_learner = AdaBoostModel.Algorithm.GBM; + p._response_column = response; + + AdaBoost adaBoost = new AdaBoost(p); + AdaBoostModel adaBoostModel = adaBoost.trainModel().get(); + Scope.track_generic(adaBoostModel); + assertNotNull(adaBoostModel); + + Frame score = adaBoostModel.score(test); + Scope.track(score); + } finally { + Scope.exit(); + } + } }