Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-15832 Fix UpliftDRF MOJO API, add docs, tests #15838

Merged
merged 5 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions h2o-algos/src/main/java/hex/generic/Generic.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ public class Generic extends ModelBuilder<GenericModel, GenericModelParameters,
allowedAlgos.add("coxph");
allowedAlgos.add("rulefit");
allowedAlgos.add("gam");
allowedAlgos.add("upliftdrf");

ALLOWED_MOJO_ALGOS = Collections.unmodifiableSet(allowedAlgos);
}
Expand Down
16 changes: 14 additions & 2 deletions h2o-algos/src/main/java/hex/generic/GenericModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
return new ModelMetricsRegressionCoxPH.MetricBuilderRegressionCoxPH("start", "stop", false, new String[0]);
case AnomalyDetection:
return new ModelMetricsAnomaly.MetricBuilderAnomaly();
case BinomialUplift:
return new ModelMetricsBinomialUplift.MetricBuilderBinomialUplift(domain, null);
default:
throw H2O.unimpl();
}
Expand Down Expand Up @@ -181,8 +183,15 @@ private void predict(EasyPredictModelWrapper wrapper, AdaptFrameParameters adapt
final String offsetColumn = adaptFrameParameters.getOffsetColumn();
final String weightsColumn = adaptFrameParameters.getWeightsColumn();
final String responseColumn = adaptFrameParameters.getResponseColumn();
final String treatmentColumn = adaptFrameParameters.getTreatmentColumn();
final boolean isClassifier = wrapper.getModel().isClassifier();
final float[] yact = new float[1];
final boolean isUplift = treatmentColumn != null;
final float[] yact;
if (isUplift) {
yact = new float[2];
} else {
yact = new float[1];
}
for (int row = 0; row < cs[0]._len; row++) {
RowData rowData = new RowData();
RowDataUtils.extractChunkRow(cs, _fr._names, types, row, rowData);
Expand All @@ -206,6 +215,9 @@ private void predict(EasyPredictModelWrapper wrapper, AdaptFrameParameters adapt
yact[0] = (float) idx;
} else
yact[0] = ((Number) response).floatValue();
if (isUplift){
yact[1] = (float) rowData.get(treatmentColumn);
}
_mb.perRow(result, yact, weight, offset, GenericModel.this);
}
}
Expand Down Expand Up @@ -285,7 +297,7 @@ public String getResponseColumn() {
return genModel.isSupervised() ? genModel.getResponseName() : null;
}
@Override
public String getTreatmentColumn() {return null;}
public String getTreatmentColumn() {return descriptor != null ? descriptor.treatmentColumn() : null;}
@Override
public double missingColumnsType() {
return Double.NaN;
Expand Down
87 changes: 84 additions & 3 deletions h2o-algos/src/main/java/hex/generic/GenericModelOutput.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import hex.genmodel.attributes.metrics.*;
import hex.genmodel.descriptor.ModelDescriptor;
import hex.tree.isofor.ModelMetricsAnomaly;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.TwoDimTable;

Expand All @@ -27,6 +28,7 @@ public GenericModelOutput(final ModelDescriptor modelDescriptor) {
_hasOffset = modelDescriptor.offsetColumn() != null;
_hasWeights = modelDescriptor.weightsColumn() != null;
_hasFold = modelDescriptor.foldColumn() != null;
_hasTreatment = modelDescriptor.treatmentColumn() != null;
_modelClassDist = modelDescriptor.modelClassDist();
_priorClassDist = modelDescriptor.priorClassDist();
_names = modelDescriptor.columnNames();
Expand All @@ -36,6 +38,7 @@ public GenericModelOutput(final ModelDescriptor modelDescriptor) {
_defaultThreshold = modelDescriptor.defaultThreshold();
_original_model_identifier = modelDescriptor.algoName();
_original_model_full_name = modelDescriptor.algoFullName();

}

public GenericModelOutput(final ModelDescriptor modelDescriptor, final ModelAttributes modelAttributes,
Expand Down Expand Up @@ -190,6 +193,15 @@ ordinalMetrics._logloss, customMetric(ordinalMetrics),
metricsCoxPH._sigma, metricsCoxPH._mae, metricsCoxPH._root_mean_squared_log_error, metricsCoxPH._mean_residual_deviance,
customMetric(mojoMetrics),
metricsCoxPH._concordance, metricsCoxPH._concordant, metricsCoxPH._discordant, metricsCoxPH._tied_y);
case BinomialUplift:
assert mojoMetrics instanceof MojoModelMetricsBinomialUplift;
MojoModelMetricsBinomialUplift metricsUplift = (MojoModelMetricsBinomialUplift) mojoMetrics;
AUUC.AUUCType auucType = AUUC.AUUCType.valueOf((String) modelAttributes.getParameterValueByName("auuc_type"));
AUUC auuc = createAUUC(auucType, metricsUplift._thresholds_and_metric_scores, metricsUplift._auuc_table, metricsUplift._aecu_table);
return new ModelMetricsBinomialUpliftGeneric(null, null, metricsUplift._nobs, _domains[_domains.length - 1],
metricsUplift._ate, metricsUplift._att, metricsUplift._atc, metricsUplift._sigma, auuc, customMetric(metricsUplift),
convertTable(metricsUplift._thresholds_and_metric_scores), convertTable(metricsUplift._auuc_table),
convertTable(metricsUplift._aecu_table), metricsUplift._description);
case Unknown:
case Clustering:
case AutoEncoder:
Expand Down Expand Up @@ -283,7 +295,7 @@ private static TwoDimTable[] convertTables(final Table[] inputTables) {
}
return tables;
}

private static TwoDimTable convertTable(final Table convertedTable){
if(convertedTable == null) return null;
final TwoDimTable table = new TwoDimTable(convertedTable.getTableHeader(), convertedTable.getTableDescription(),
Expand All @@ -295,8 +307,77 @@ private static TwoDimTable convertTable(final Table convertedTable){
table.set(j, i, convertedTable.getCell(i,j));
}
}

return table;
}


private static AUUC createAUUC(AUUC.AUUCType auucType, Table thresholds_and_metric_scores, Table auuc_table, Table aecu_table){
int nbins = thresholds_and_metric_scores.rows();
double[] ths = new double[nbins];
long[] freq = new long[nbins];
AUUC.AUUCType[] auucTypes = AUUC.AUUCType.values();
double[][] uplift = new double[auucTypes.length][nbins];
double[][] upliftNorm = new double[auucTypes.length][nbins];
double[][] upliftRand = new double[auucTypes.length][nbins];
double[] auuc = new double[auucTypes.length];
double[] auucNorm = new double[auucTypes.length];
double[] auucRand = new double[auucTypes.length];
double[] aecu = new double[auucTypes.length];

String[] thrHeader = thresholds_and_metric_scores.getColHeaders();
// threshold column index
int thrIndex = ArrayUtils.find(thrHeader, "thresholds");
int freqIndex = ArrayUtils.find(thrHeader, "n");

// uplift type indices
int[] upliftIndices = new int[auucTypes.length];
int[] upliftNormIndices = new int[auucTypes.length];
int[] upliftRandIndices = new int[auucTypes.length];
for (int i = 1; i < auucTypes.length; i++) {
String auucTypeName = auucTypes[i].name();
upliftIndices[i] = ArrayUtils.find(thrHeader, auucTypeName);
upliftNormIndices[i] = ArrayUtils.find(thrHeader, auucTypeName+"_normalized");
upliftRandIndices[i] = ArrayUtils.find(thrHeader, auucTypeName+"_random");
// AUTO setting
if(auucTypeName.equals(AUUC.AUUCType.nameAuto())){
upliftIndices[0] = upliftIndices[i];
upliftNormIndices[0] = upliftNormIndices[i];
upliftRandIndices[0] = upliftRandIndices[i];
}
}
// fill thresholds and uplift values from table
for (int i = 0; i < thresholds_and_metric_scores.rows(); i++) {
ths[i] = (double) thresholds_and_metric_scores.getCell(thrIndex, i);
freq[i] = (long) thresholds_and_metric_scores.getCell(freqIndex, i);
for (int j = 0; j < auucTypes.length; j++) {
uplift[j][i] = (double) thresholds_and_metric_scores.getCell(upliftIndices[j], i);
upliftNorm[j][i] = (double) thresholds_and_metric_scores.getCell(upliftNormIndices[j], i);
upliftRand[j][i] = (double) thresholds_and_metric_scores.getCell(upliftRandIndices[j], i);
}
}
// fill auuc values and aecu values
String[] auucHeader = auuc_table.getColHeaders();
String[] aecuHeader = aecu_table.getColHeaders();
for (int i = 1; i < auucTypes.length; i++) {
AUUC.AUUCType type = auucTypes[i];
String auucTypeName = type.name();
int colIndex = ArrayUtils.find(auucHeader, auucTypeName);
auuc[i] = (double) auuc_table.getCell(colIndex, 0);
auucNorm[i] = (double) auuc_table.getCell(colIndex, 1);
auucRand[i] = (double) auuc_table.getCell(colIndex, 2);
colIndex = ArrayUtils.find(aecuHeader, auucTypeName);
aecu[i] = (double) aecu_table.getCell(colIndex, 0);
if(auucTypeName.equals(AUUC.AUUCType.nameAuto())){
auuc[0] = auuc[i];
auucNorm[0] = auucNorm[i];
auucRand[0] = auucRand[i];
aecu[0] = aecu[i];
}
}
return new AUUC(ths, freq, auuc, auucNorm, auucRand, aecu, auucType, uplift, upliftNorm, upliftRand);
}

@Override
public boolean hasTreatment() {
return super.hasTreatment();
}
}
36 changes: 24 additions & 12 deletions h2o-core/src/main/java/hex/AUUC.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;

import java.util.Arrays;
import java.util.Iterator;
Expand Down Expand Up @@ -79,7 +78,6 @@ public AUUC(Vec probs, Vec y, Vec uplift, AUUCType auucType, int nbins) {
public AUUC(AUUCBuilder bldr, AUUCType auucType) {
this(bldr, true, auucType);
}


public AUUC(double[] customThresholds, Vec probs, Vec y, Vec uplift, AUUCType auucType) {
this(new AUUCImpl(customThresholds).doAll(probs, y, uplift)._bldr, auucType);
Expand Down Expand Up @@ -205,6 +203,25 @@ public AUUC() {
_upliftNormalized = new double[AUUCType.values().length][];
_upliftRandom = new double[AUUCType.values().length][];
}

public AUUC(double[] ths, long[] freq, double[] auuc, double[] auucNorm, double[] auucRand, double[] aecu,
AUUCType auucType, double[][] uplift, double[][] upliftNorm, double[][] upliftRand) {
_nBins = ths.length;
_n = freq[freq.length-1];
_ths = ths;
_frequencyCumsum = freq;
_treatment = _control = _yTreatment = _yControl = _frequency = new long[0];
_auucs = auuc;
_auucsNormalized = auucNorm;
_auucsRandom = auucRand;
_aecu = aecu;
_maxIdx = -1;
_auucType = auucType;
_auucTypeIndx = getIndexByAUUCType(_auucType);
_uplift = uplift;
_upliftNormalized = upliftNorm;
_upliftRandom = upliftRand;
}

public static double[] calculateQuantileThresholds(int groups, Vec preds) {
Frame fr = null;
Expand Down Expand Up @@ -443,19 +460,14 @@ public enum AUUCType {
* @return metric value */
abstract double exec(long treatment, long control, long yTreatment, long yControl );
public double exec(AUUC auc, int idx) { return exec(auc.treatment(idx),auc.control(idx),auc.yTreatment(idx),auc.yControl(idx)); }

public static final AUUCType[] VALUES = values();

public static AUUCType fromString(String strRepr) {
for (AUUCType tc : AUUCType.values()) {
if (tc.toString().equalsIgnoreCase(strRepr)) {
return tc;
}
}
return null;
}
public static final AUUCType[] VALUES_WITHOUT_AUTO = ArrayUtils.remove(values().clone(), ArrayUtils.find(AUUCType.values(), AUTO));

public double maxCriterion(AUUC auuc) { return exec(auuc, maxCriterionIdx(auuc)); }
public static String nameAuto(){
return qini.name();
}

/** Convert a criterion into a threshold index that maximizes the criterion
* @return Threshold index that maximizes the criterion
Expand Down
3 changes: 3 additions & 0 deletions h2o-core/src/main/java/hex/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -1162,6 +1162,7 @@ public String[] features() {
public String weightsName () { return _hasWeights ?_names[weightsIdx()]:null;}
public String offsetName () { return _hasOffset ?_names[offsetIdx()]:null;}
public String foldName () { return _hasFold ?_names[foldIdx()]:null;}
public String treatmentName() { return _hasTreatment ? _names[treatmentIdx()]: null;}
public InteractionBuilder interactionBuilder() { return null; }
// Vec layout is [c1,c2,...,cn, w?, o?, f?, u?, r]
// cn are predictor cols, r is response, w is weights, o is offset, f is fold and t is treatment - these are optional
Expand Down Expand Up @@ -3469,6 +3470,8 @@ protected class H2OModelDescriptor implements ModelDescriptor {
@Override
public String weightsColumn() { return _output.weightsName(); }
@Override
public String treatmentColumn() { return _output.treatmentName(); }
@Override
public String foldColumn() { return _output.foldName(); }
@Override
public ModelCategory getModelCategory() { return _output.getModelCategory(); }
Expand Down
1 change: 1 addition & 0 deletions h2o-core/src/main/java/hex/ModelMetrics.java
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,7 @@ public double[] perRow(double ds[], float yact[], double weight, double offset,
assert(weight == 1 && offset == 0);
return perRow(ds, yact, m);
}

public void reduce( T mb ) {
_sumsqe += mb._sumsqe;
_count += mb._count;
Expand Down
20 changes: 20 additions & 0 deletions h2o-core/src/main/java/hex/ModelMetricsBinomialUpliftGeneric.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package hex;

import water.fvec.Frame;
import water.util.TwoDimTable;

public class ModelMetricsBinomialUpliftGeneric extends ModelMetricsBinomialUplift {


public final TwoDimTable _thresholds_and_metric_scores;
public final TwoDimTable _auuc_table;
public final TwoDimTable _aecu_table;

public ModelMetricsBinomialUpliftGeneric(Model model, Frame frame, long nobs, String[] domain, double ate, double att, double atc, double sigma, AUUC auuc, CustomMetric customMetric, TwoDimTable thresholds_and_metric_scores, TwoDimTable auuc_table, TwoDimTable aecu_table, final String description) {
super(model, frame, nobs, domain, ate, att, atc, sigma, auuc, customMetric);
_thresholds_and_metric_scores = thresholds_and_metric_scores;
_auuc_table = auuc_table;
_aecu_table = aecu_table;
_description = description;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package water.api.schemas3;

import hex.ModelMetricsBinomialUpliftGeneric;

public class ModelMetricsBinomialUpliftGenericV3<I extends ModelMetricsBinomialUpliftGeneric, S extends ModelMetricsBinomialUpliftGenericV3<I, S>>
extends ModelMetricsBinomialUpliftV3<I, S> {

@Override
public S fillFromImpl(ModelMetricsBinomialUpliftGeneric modelMetrics) {
super.fillFromImpl(modelMetrics);
this.AUUC = modelMetrics._auuc.auuc();
this.auuc_normalized = modelMetrics._auuc.auucNormalized();
this.ate = modelMetrics.ate();
this.att = modelMetrics.att();
this.atc = modelMetrics.atc();
this.qini = modelMetrics.qini();

if (modelMetrics._auuc_table != null) { // Possibly overwrites whatever has been set in the ModelMetricsBinomialV3
this.auuc_table = new TwoDimTableV3().fillFromImpl(modelMetrics._auuc_table);
}
if (modelMetrics._aecu_table != null) { // Possibly overwrites whatever has been set in the ModelMetricsBinomialV3
this.aecu_table = new TwoDimTableV3().fillFromImpl(modelMetrics._aecu_table);
}
if (modelMetrics._thresholds_and_metric_scores != null) { // Possibly overwrites whatever has been set in the ModelMetricsBinomialV3
this.thresholds_and_metric_scores = new TwoDimTableV3().fillFromImpl(modelMetrics._thresholds_and_metric_scores);
}
return (S) this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ public S fillFromImpl(ModelMetricsBinomialUplift modelMetrics) {
qini = auuc.qini();
// Fill TwoDimTable
String[] thresholds = new String[auuc._nBins];
AUUCType metrics[] = AUUCType.VALUES;
metrics = ArrayUtils.remove(metrics, Arrays.asList(metrics).indexOf(AUUCType.AUTO));
AUUCType metrics[] = AUUCType.VALUES_WITHOUT_AUTO;
int metricsLength = metrics.length;
long[] n = new long[auuc._nBins];
double[][] uplift = new double[metricsLength][];
Expand Down Expand Up @@ -89,7 +88,7 @@ public S fillFromImpl(ModelMetricsBinomialUplift modelMetrics) {
types [i + 1 + 2 * metricsLength] = "double";
formats [i + 1 + 2 * metricsLength] = "%f";
}
colHeaders[i + 1 + 2 * metricsLength] = "n"; types[i + 1 + 2 * metricsLength] = "int"; formats[i + 1 + 2 * metricsLength] = "%d";
colHeaders[i + 1 + 2 * metricsLength] = "n"; types[i + 1 + 2 * metricsLength] = "long"; formats[i + 1 + 2 * metricsLength] = "%d";
colHeaders[i + 2 + 2 * metricsLength] = "idx"; types[i + 2 + 2 * metricsLength] = "int"; formats[i + 2 + 2 * metricsLength] = "%d";
TwoDimTable thresholdsByMetrics = new TwoDimTable("Metrics for Thresholds", "Cumulative Uplift metrics for a given percentile", new String[auuc._nBins], colHeaders, types, formats, null );
for (i = 0; i < auuc._nBins; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ water.api.schemas3.ModelMetricsBinomialGLMGenericV3
water.api.schemas3.ModelMetricsBinomialV3
water.api.schemas3.ModelMetricsBinomialGenericV3
water.api.schemas3.ModelMetricsBinomialUpliftV3
water.api.schemas3.ModelMetricsBinomialUpliftGenericV3
water.api.schemas3.ModelMetricsClusteringV3
water.api.schemas3.ModelMetricsHGLMV3
water.api.schemas3.ModelMetricsHGLMGenericV3
Expand Down
3 changes: 2 additions & 1 deletion h2o-docs/src/product/data-science/upliftdrf.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ The current version of Uplift DRF is based on the implementation of DRF because
MOJO Support
''''''''''''

Uplift DRF currently doesn't support `MOJOs <../save-and-load-model.html#supported-mojos>`__.
Uplift DRF supports importing and exporting `MOJOs <../save-and-load-model.html#supported-mojos>`__.


Uplift DRF demo
~~~~~~~~~~~~~~~
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ public String foldColumn() {
return null;
}

@Override
public String treatmentColumn() { return null; }

@Override
public ModelCategory getModelCategory() {
return _finalModel._category;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ private static MojoModelMetrics determineModelMetricsType(final MojoModel mojoMo
} else return new MojoModelMetricsOrdinal();
case CoxPH:
return new MojoModelMetricsRegressionCoxPH();
case BinomialUplift:
return new MojoModelMetricsBinomialUplift();
case Unknown:
case Clustering:
case AutoEncoder:
Expand Down
Loading