Skip to content

Commit

Permalink
Refactoring according to suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
valenad1 committed Sep 26, 2023
1 parent 4ec9229 commit e52c287
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 16 deletions.
16 changes: 8 additions & 8 deletions h2o-algos/src/main/java/hex/adaboost/AdaBoost.java
Original file line number Diff line number Diff line change
Expand Up @@ -119,16 +119,16 @@ private void buildAdaboost() {
DKV.put(model);
Scope.untrack(model._key);
_model._output.models[n] = model._key;
Frame score = model.score(_trainWithWeights);
Scope.track(score);
Frame predictions = model.score(_trainWithWeights);
Scope.track(predictions);

CountWeTask countWe = new CountWeTask().doAll(_trainWithWeights.vec(_weightsName), _trainWithWeights.vec(_parms._response_column), score.vec("predict"));
double e_m = countWe.We / countWe.W;
double alpha_m = _parms._learn_rate * Math.log((1 - e_m) / e_m);
_model._output.alphas[n] = alpha_m;
CountWeTask countWe = new CountWeTask().doAll(_trainWithWeights.vec(_weightsName), _trainWithWeights.vec(_parms._response_column), predictions.vec("predict"));
double eM = countWe.We / countWe.W;
double alphaM = _parms._learn_rate * Math.log((1 - eM) / eM);
_model._output.alphas[n] = alphaM;

UpdateWeightsTask updateWeightsTask = new UpdateWeightsTask(alpha_m);
updateWeightsTask.doAll(_trainWithWeights.vec(_weightsName), _trainWithWeights.vec(_parms._response_column), score.vec("predict"));
UpdateWeightsTask updateWeightsTask = new UpdateWeightsTask(alphaM);
updateWeightsTask.doAll(_trainWithWeights.vec(_weightsName), _trainWithWeights.vec(_parms._response_column), predictions.vec("predict"));
_job.update(1);
_model.update(_job);
LOG.info((n + 1) + ". estimator was built in " + timer.toString());
Expand Down
2 changes: 1 addition & 1 deletion h2o-algos/src/main/java/hex/adaboost/AdaBoostModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ protected double[] score0(double[] data, double[] preds) {
linearCombination += _output.alphas[i]*-1;
alphas0 += _output.alphas[i];
} else {
linearCombination += _output.alphas[i]*1;
linearCombination += _output.alphas[i];
alphas1 += _output.alphas[i];
}
}
Expand Down
14 changes: 7 additions & 7 deletions h2o-algos/src/main/java/hex/adaboost/UpdateWeightsTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,22 @@
* Update weights according to AdaBoost algorithm
*/
class UpdateWeightsTask extends MRTask<UpdateWeightsTask> {
double exp_am;
double exp_am_inverse;
double expAm;
double expAmInverse;

public UpdateWeightsTask(double alpha_m) {
exp_am = Math.exp(alpha_m);
exp_am_inverse = Math.exp(-alpha_m);
public UpdateWeightsTask(double alphaM) {
expAm = Math.exp(alphaM);
expAmInverse = Math.exp(-alphaM);
}

@Override
public void map(Chunk weights, Chunk response, Chunk predict) {
for (int row = 0; row < weights._len; row++) {
double weight = weights.atd(row);
if (response.at8(row) != predict.at8(row)) {
weights.set(row, weight * exp_am);
weights.set(row, weight * expAm);
} else {
weights.set(row, weight * exp_am_inverse);
weights.set(row, weight * expAmInverse);
}
}
}
Expand Down

0 comments on commit e52c287

Please sign in to comment.