Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-6783 Add custom metric function to UpliftDRF #15592

Merged
merged 19 commits into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions h2o-algos/src/main/java/hex/schemas/UpliftDRFV3.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public static final class UpliftDRFParametersV3 extends SharedTreeV3.SharedTreeP
"categorical_encoding",
"distribution",
"check_constant_response",
"custom_metric_func",
"treatment_column",
"uplift_metric",
"auuc_type",
Expand Down
1 change: 1 addition & 0 deletions h2o-algos/src/main/java/hex/tree/SharedTree.java
Original file line number Diff line number Diff line change
Expand Up @@ -852,6 +852,7 @@ protected final boolean doScoringAndSaveModel(boolean finalScoring, boolean oob,
ModelMetrics mmv = scv.scoreAndMakeModelMetrics(_model, _parms.valid(), v, build_tree_one_node);
_lastScoredTree = _model._output._ntrees;
out._validation_metrics = mmv;
out._validation_metrics._description = "Validation metrics";
if (_model._output._ntrees>0 || scoreZeroTrees()) //don't score the 0-tree model - the error is too large
out._scored_valid[out._ntrees].fillFrom(mmv);
}
Expand Down
14 changes: 12 additions & 2 deletions h2o-algos/src/main/java/hex/tree/uplift/UpliftDRF.java
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,6 @@ public boolean providesVarImp() {
error("_treatment_column", "The treatment column has to be defined.");
if (_parms._custom_distribution_func != null)
error("_custom_distribution_func", "The custom distribution is not yet supported for Uplift DRF.");
if (_parms._custom_metric_func != null)
error("_custom_metric_func", "The custom metric is not yet supported for Uplift DRF.");
if (_parms._stopping_metric != ScoreKeeper.StoppingMetric.AUTO)
error("_stopping_metric", "The early stopping is not yet supported for Uplift DRF.");
if (_parms._stopping_rounds != 0)
Expand Down Expand Up @@ -404,6 +402,9 @@ static TwoDimTable createUpliftScoringHistoryTable(Model.Output _output,
colHeaders.add("Timestamp"); colTypes.add("string"); colFormat.add("%s");
colHeaders.add("Duration"); colTypes.add("string"); colFormat.add("%s");
colHeaders.add("Number of Trees"); colTypes.add("long"); colFormat.add("%d");
colHeaders.add("Training ATE"); colTypes.add("double"); colFormat.add("%d");
colHeaders.add("Training ATT"); colTypes.add("double"); colFormat.add("%d");
colHeaders.add("Training ATC"); colTypes.add("double"); colFormat.add("%d");
colHeaders.add("Training AUUC nbins"); colTypes.add("int"); colFormat.add("%d");
colHeaders.add("Training AUUC"); colTypes.add("double"); colFormat.add("%.5f");
colHeaders.add("Training AUUC normalized"); colTypes.add("double"); colFormat.add("%.5f");
Expand All @@ -413,6 +414,9 @@ static TwoDimTable createUpliftScoringHistoryTable(Model.Output _output,
}

if (_output._validation_metrics != null) {
colHeaders.add("Validation ATE"); colTypes.add("double"); colFormat.add("%d");
colHeaders.add("Validation ATT"); colTypes.add("double"); colFormat.add("%d");
colHeaders.add("Validation ATC"); colTypes.add("double"); colFormat.add("%d");
colHeaders.add("Validation AUUC nbins"); colTypes.add("int"); colFormat.add("%d");
colHeaders.add("Validation AUUC"); colTypes.add("double"); colFormat.add("%.5f");
colHeaders.add("Validation AUUC normalized"); colTypes.add("double"); colFormat.add("%.5f");
Expand Down Expand Up @@ -443,6 +447,9 @@ static TwoDimTable createUpliftScoringHistoryTable(Model.Output _output,
table.set(row, col++, PrettyPrint.msecs(_training_time_ms[i] - job.start_time(), true));
table.set(row, col++, i);
ScoreKeeper st = _scored_train[i];
table.set(row, col++, st._ate);
table.set(row, col++, st._att);
table.set(row, col++, st._atc);
table.set(row, col++, st._auuc_nbins);
table.set(row, col++, st._AUUC);
table.set(row, col++, st._auuc_normalized);
Expand All @@ -451,6 +458,9 @@ static TwoDimTable createUpliftScoringHistoryTable(Model.Output _output,

if (_output._validation_metrics != null) {
st = _scored_valid[i];
table.set(row, col++, st._ate);
table.set(row, col++, st._att);
table.set(row, col++, st._atc);
table.set(row, col++, st._auuc_nbins);
table.set(row, col++, st._AUUC);
table.set(row, col++, st._auuc_normalized);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package hex.util;

import hex.AUUC;
import hex.Model;
import hex.ScoreKeeper;
import hex.genmodel.utils.DistributionFamily;
Expand Down
10 changes: 9 additions & 1 deletion h2o-core/src/main/java/hex/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -2222,10 +2222,15 @@ protected void setupLocal() {
Chunk weightsChunk = _hasWeights && _computeMetrics ? chks[_output.weightsIdx()] : null;
Chunk offsetChunk = _output.hasOffset() ? chks[_output.offsetIdx()] : null;
Chunk responseChunk = null;
Chunk treatmentChunk = null;
float [] actual = null;
_mb = Model.this.makeMetricBuilder(_domain);
if (_computeMetrics) {
if (_output.hasResponse()) {
if (_output.hasTreatment()){
actual = new float[2];
responseChunk = chks[_output.responseIdx()];
treatmentChunk = chks[_output.treatmentIdx()];
} else if (_output.hasResponse()) {
actual = new float[1];
responseChunk = chks[_output.responseIdx()];
} else
Expand All @@ -2252,6 +2257,9 @@ protected void setupLocal() {
for (int i = 0; i < actual.length; ++i)
actual[i] = (float) data(chks, row, i);
}
if (treatmentChunk != null) {
actual[1] = (float) treatmentChunk.atd(row);
}
_mb.perRow(preds, actual, weight, offset, Model.this);
// Handle custom metric
customMetricPerRow(preds, actual, weight, offset, Model.this);
Expand Down
48 changes: 39 additions & 9 deletions h2o-core/src/main/java/hex/ModelMetricsBinomialUplift.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,17 @@

public class ModelMetricsBinomialUplift extends ModelMetricsSupervised {
public final AUUC _auuc;
public double _ate;
public double _att;
public double _atc;

public ModelMetricsBinomialUplift(Model model, Frame frame, long nobs, String[] domain,
double sigma, AUUC auuc,
public ModelMetricsBinomialUplift(Model model, Frame frame, long nobs, String[] domain,
double ate, double att, double atc, double sigma, AUUC auuc,
CustomMetric customMetric) {
super(model, frame, nobs, 0, domain, sigma, customMetric);
_ate = ate;
_att = att;
_atc = atc;
_auuc = auuc;
}

Expand All @@ -30,6 +36,9 @@ public static ModelMetricsBinomialUplift getFromDKV(Model model, Frame frame) {
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append(super.toString());
sb.append("ATE:" ).append((float) _ate).append("\n");
sb.append("ATT:" ).append((float) _att).append("\n");
sb.append("ATC:" ).append((float) _atc).append("\n");
if(_auuc != null){
sb.append("Default AUUC: ").append((float) _auuc.auuc()).append("\n");
sb.append("Qini AUUC: ").append((float) _auuc.auucByType(AUUC.AUUCType.qini)).append("\n");
Expand All @@ -50,6 +59,12 @@ public String toString() {
public double auucNormalized(){return _auuc.auucNormalized();}

public int nbins(){return _auuc._nBins;}

public double ate() {return _ate;}

public double att() {return _att;}

public double atc() {return _atc;}

@Override
protected StringBuilder appendToStringMetrics(StringBuilder sb) {
Expand Down Expand Up @@ -127,13 +142,13 @@ public UpliftBinomialMetrics(String[] domain, double[] thresholds) {
_mb = new MetricBuilderBinomialUplift(domain, thresholds);
Chunk uplift = chks[0];
Chunk actuals = chks[1];
Chunk treatment =chks[2];
Chunk treatment = chks[2];
double[] ds = new double[1];
float[] acts = new float[2];
for (int i=0; i<chks[0]._len;++i) {
ds[0] = uplift.atd(i);
acts[0] = (float) actuals.atd(i);
acts[1] = (float )treatment.atd(i);
acts[1] = (float) treatment.atd(i);
_mb.perRow(ds, acts, 1, 0, null);
}
}
Expand All @@ -143,7 +158,10 @@ public UpliftBinomialMetrics(String[] domain, double[] thresholds) {
public static class MetricBuilderBinomialUplift extends MetricBuilderSupervised<MetricBuilderBinomialUplift> {

protected AUUC.AUUCBuilder _auuc;

public double _sumTE;
public double _sumTETreatment;
public long _treatmentCount;

public MetricBuilderBinomialUplift( String[] domain, double[] thresholds) {
super(2,domain);
if(thresholds != null) {
Expand All @@ -163,17 +181,20 @@ public MetricBuilderBinomialUplift( String[] domain) {
public double[] perRow(double[] ds, float[] yact, double weight, double offset, Model m) {
assert _auuc == null || yact.length == 2 : "Treatment must be included in `yact` when calculating AUUC";
if(Float .isNaN(yact[0])) return ds; // No errors if actual is missing
if(ArrayUtils.hasNaNs(ds)) return ds; // No errors if prediction has missing values (can happen for GLM)
if(weight == 0 || Double.isNaN(weight)) return ds;
int y = (int)yact[0];
if (y != 0 && y != 1) return ds; // The actual is effectively a NaN
_wY += weight * y;
_wYY += weight * y * y;
_count++;
_wcount += weight;
int treatmentGroup = (int)yact[1]; // treatment = 1, control = 0
double treatmentEffect = ds[0] * weight;
_sumTE += treatmentEffect; // result prediction
_sumTETreatment += treatmentGroup * treatmentEffect;
_treatmentCount += treatmentGroup * weight;
if (_auuc != null) {
float treatment = yact[1];
_auuc.perRow(ds[0], weight, y, treatment);
_auuc.perRow(treatmentEffect, weight, y, treatmentGroup);
}
return ds;
}
Expand All @@ -183,6 +204,9 @@ public double[] perRow(double[] ds, float[] yact, double weight, double offset,
if(_auuc != null) {
_auuc.reduce(mb._auuc);
}
_sumTE += mb._sumTE;
_sumTETreatment += mb._sumTETreatment;
_treatmentCount += mb._treatmentCount;
}

/**
Expand Down Expand Up @@ -231,15 +255,21 @@ private ModelMetrics makeModelMetrics(final Model m, final Frame f, final Frame

private ModelMetrics makeModelMetrics(Model m, Frame f, AUUC auuc) {
double sigma = Double.NaN;
double ate = Double.NaN;
double atc = Double.NaN;
double att = Double.NaN;
if(_wcount > 0) {
if (auuc == null) {
sigma = weightedSigma();
auuc = new AUUC(_auuc, m._parms._auuc_type);
}
ate = _sumTE/_wcount;
att = _sumTETreatment/_treatmentCount;
atc = (_sumTE-_sumTETreatment)/(_wcount-_treatmentCount);
} else {
auuc = new AUUC();
}
ModelMetricsBinomialUplift mm = new ModelMetricsBinomialUplift(m, f, _count, _domain, sigma, auuc, _customMetric);
ModelMetricsBinomialUplift mm = new ModelMetricsBinomialUplift(m, f, _count, _domain, ate, att, atc, sigma, auuc, _customMetric);
if (m!=null) m.addModelMetrics(mm);
return mm;
}
Expand Down
6 changes: 6 additions & 0 deletions h2o-core/src/main/java/hex/ScoreKeeper.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ public class ScoreKeeper extends Iced {
public double _auuc_normalized = Double.NaN;
public double _qini = Double.NaN;
public int _auuc_nbins = 0;
public double _ate = Double.NaN;
public double _att = Double.NaN;
public double _atc = Double.NaN;

public ScoreKeeper() {}

Expand Down Expand Up @@ -125,6 +128,9 @@ else if (m instanceof ModelMetricsMultinomial) {
_auuc_normalized = ((ModelMetricsBinomialUplift)m).auucNormalized();
_qini = ((ModelMetricsBinomialUplift)m).qini();
_auuc_nbins = ((ModelMetricsBinomialUplift)m).nbins();
_ate = ((ModelMetricsBinomialUplift)m).ate();
_att = ((ModelMetricsBinomialUplift)m).att();
_atc = ((ModelMetricsBinomialUplift)m).atc();
}
if (customMetric != null ) {
_custom_metric = customMetric.value;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@
public class ModelMetricsBinomialUpliftV3<I extends ModelMetricsBinomialUplift, S extends water.api.schemas3.ModelMetricsBinomialUpliftV3<I, S>>
extends ModelMetricsBaseV3<I,S> {

@API(help="Average Treatment Effect.", direction=API.Direction.OUTPUT)
public double ate;

@API(help="Average Treatment Effect on the Treated.", direction=API.Direction.OUTPUT)
public double att;

@API(help="Average Treatment Effect on the Control.", direction=API.Direction.OUTPUT)
public double atc;

@API(help="The default AUUC for this scoring run.", direction=API.Direction.OUTPUT)
public double AUUC;

Expand Down Expand Up @@ -40,6 +49,9 @@ public S fillFromImpl(ModelMetricsBinomialUplift modelMetrics) {

AUUC auuc = modelMetrics._auuc;
if (null != auuc) {
ate = modelMetrics.ate();
att = modelMetrics.att();
atc = modelMetrics.atc();
AUUC = auuc.auuc();
auuc_normalized = auuc.auucNormalized();
qini = auuc.qini();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
``upload_custom_metric``
------------------------

- Available in: GBM, DRF, Deeplearning
- Available in: GBM, DRF, Deeplearning, UpliftDRF
- Hyperparameter: no

Description
Expand Down
Loading