Skip to content

Commit

Permalink
GH-15832 Fix UpliftDRF MOJO API, add docs, tests (#15838)
Browse files Browse the repository at this point in the history
* Fix Uplift MOJO API, add tests

* Add tests, fix predict

* Implement Generic logic behind Python/R uplift MOJO API
  • Loading branch information
maurever authored Nov 1, 2023
1 parent c4a4c86 commit c9e1ff2
Show file tree
Hide file tree
Showing 21 changed files with 539 additions and 22 deletions.
1 change: 1 addition & 0 deletions h2o-algos/src/main/java/hex/generic/Generic.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ public class Generic extends ModelBuilder<GenericModel, GenericModelParameters,
allowedAlgos.add("coxph");
allowedAlgos.add("rulefit");
allowedAlgos.add("gam");
allowedAlgos.add("upliftdrf");

ALLOWED_MOJO_ALGOS = Collections.unmodifiableSet(allowedAlgos);
}
Expand Down
16 changes: 14 additions & 2 deletions h2o-algos/src/main/java/hex/generic/GenericModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
return new ModelMetricsRegressionCoxPH.MetricBuilderRegressionCoxPH("start", "stop", false, new String[0]);
case AnomalyDetection:
return new ModelMetricsAnomaly.MetricBuilderAnomaly();
case BinomialUplift:
return new ModelMetricsBinomialUplift.MetricBuilderBinomialUplift(domain, null);
default:
throw H2O.unimpl();
}
Expand Down Expand Up @@ -181,8 +183,15 @@ private void predict(EasyPredictModelWrapper wrapper, AdaptFrameParameters adapt
final String offsetColumn = adaptFrameParameters.getOffsetColumn();
final String weightsColumn = adaptFrameParameters.getWeightsColumn();
final String responseColumn = adaptFrameParameters.getResponseColumn();
final String treatmentColumn = adaptFrameParameters.getTreatmentColumn();
final boolean isClassifier = wrapper.getModel().isClassifier();
final float[] yact = new float[1];
final boolean isUplift = treatmentColumn != null;
final float[] yact;
if (isUplift) {
yact = new float[2];
} else {
yact = new float[1];
}
for (int row = 0; row < cs[0]._len; row++) {
RowData rowData = new RowData();
RowDataUtils.extractChunkRow(cs, _fr._names, types, row, rowData);
Expand All @@ -206,6 +215,9 @@ private void predict(EasyPredictModelWrapper wrapper, AdaptFrameParameters adapt
yact[0] = (float) idx;
} else
yact[0] = ((Number) response).floatValue();
if (isUplift){
yact[1] = (float) rowData.get(treatmentColumn);
}
_mb.perRow(result, yact, weight, offset, GenericModel.this);
}
}
Expand Down Expand Up @@ -285,7 +297,7 @@ public String getResponseColumn() {
return genModel.isSupervised() ? genModel.getResponseName() : null;
}
@Override
public String getTreatmentColumn() {return null;}
public String getTreatmentColumn() {return descriptor != null ? descriptor.treatmentColumn() : null;}
@Override
public double missingColumnsType() {
return Double.NaN;
Expand Down
87 changes: 84 additions & 3 deletions h2o-algos/src/main/java/hex/generic/GenericModelOutput.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import hex.genmodel.attributes.metrics.*;
import hex.genmodel.descriptor.ModelDescriptor;
import hex.tree.isofor.ModelMetricsAnomaly;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.TwoDimTable;

Expand All @@ -27,6 +28,7 @@ public GenericModelOutput(final ModelDescriptor modelDescriptor) {
_hasOffset = modelDescriptor.offsetColumn() != null;
_hasWeights = modelDescriptor.weightsColumn() != null;
_hasFold = modelDescriptor.foldColumn() != null;
_hasTreatment = modelDescriptor.treatmentColumn() != null;
_modelClassDist = modelDescriptor.modelClassDist();
_priorClassDist = modelDescriptor.priorClassDist();
_names = modelDescriptor.columnNames();
Expand All @@ -36,6 +38,7 @@ public GenericModelOutput(final ModelDescriptor modelDescriptor) {
_defaultThreshold = modelDescriptor.defaultThreshold();
_original_model_identifier = modelDescriptor.algoName();
_original_model_full_name = modelDescriptor.algoFullName();

}

public GenericModelOutput(final ModelDescriptor modelDescriptor, final ModelAttributes modelAttributes,
Expand Down Expand Up @@ -190,6 +193,15 @@ ordinalMetrics._logloss, customMetric(ordinalMetrics),
metricsCoxPH._sigma, metricsCoxPH._mae, metricsCoxPH._root_mean_squared_log_error, metricsCoxPH._mean_residual_deviance,
customMetric(mojoMetrics),
metricsCoxPH._concordance, metricsCoxPH._concordant, metricsCoxPH._discordant, metricsCoxPH._tied_y);
case BinomialUplift:
assert mojoMetrics instanceof MojoModelMetricsBinomialUplift;
MojoModelMetricsBinomialUplift metricsUplift = (MojoModelMetricsBinomialUplift) mojoMetrics;
AUUC.AUUCType auucType = AUUC.AUUCType.valueOf((String) modelAttributes.getParameterValueByName("auuc_type"));
AUUC auuc = createAUUC(auucType, metricsUplift._thresholds_and_metric_scores, metricsUplift._auuc_table, metricsUplift._aecu_table);
return new ModelMetricsBinomialUpliftGeneric(null, null, metricsUplift._nobs, _domains[_domains.length - 1],
metricsUplift._ate, metricsUplift._att, metricsUplift._atc, metricsUplift._sigma, auuc, customMetric(metricsUplift),
convertTable(metricsUplift._thresholds_and_metric_scores), convertTable(metricsUplift._auuc_table),
convertTable(metricsUplift._aecu_table), metricsUplift._description);
case Unknown:
case Clustering:
case AutoEncoder:
Expand Down Expand Up @@ -283,7 +295,7 @@ private static TwoDimTable[] convertTables(final Table[] inputTables) {
}
return tables;
}

private static TwoDimTable convertTable(final Table convertedTable){
if(convertedTable == null) return null;
final TwoDimTable table = new TwoDimTable(convertedTable.getTableHeader(), convertedTable.getTableDescription(),
Expand All @@ -295,8 +307,77 @@ private static TwoDimTable convertTable(final Table convertedTable){
table.set(j, i, convertedTable.getCell(i,j));
}
}

return table;
}


private static AUUC createAUUC(AUUC.AUUCType auucType, Table thresholds_and_metric_scores, Table auuc_table, Table aecu_table){
int nbins = thresholds_and_metric_scores.rows();
double[] ths = new double[nbins];
long[] freq = new long[nbins];
AUUC.AUUCType[] auucTypes = AUUC.AUUCType.values();
double[][] uplift = new double[auucTypes.length][nbins];
double[][] upliftNorm = new double[auucTypes.length][nbins];
double[][] upliftRand = new double[auucTypes.length][nbins];
double[] auuc = new double[auucTypes.length];
double[] auucNorm = new double[auucTypes.length];
double[] auucRand = new double[auucTypes.length];
double[] aecu = new double[auucTypes.length];

String[] thrHeader = thresholds_and_metric_scores.getColHeaders();
// threshold column index
int thrIndex = ArrayUtils.find(thrHeader, "thresholds");
int freqIndex = ArrayUtils.find(thrHeader, "n");

// uplift type indices
int[] upliftIndices = new int[auucTypes.length];
int[] upliftNormIndices = new int[auucTypes.length];
int[] upliftRandIndices = new int[auucTypes.length];
for (int i = 1; i < auucTypes.length; i++) {
String auucTypeName = auucTypes[i].name();
upliftIndices[i] = ArrayUtils.find(thrHeader, auucTypeName);
upliftNormIndices[i] = ArrayUtils.find(thrHeader, auucTypeName+"_normalized");
upliftRandIndices[i] = ArrayUtils.find(thrHeader, auucTypeName+"_random");
// AUTO setting
if(auucTypeName.equals(AUUC.AUUCType.nameAuto())){
upliftIndices[0] = upliftIndices[i];
upliftNormIndices[0] = upliftNormIndices[i];
upliftRandIndices[0] = upliftRandIndices[i];
}
}
// fill thresholds and uplift values from table
for (int i = 0; i < thresholds_and_metric_scores.rows(); i++) {
ths[i] = (double) thresholds_and_metric_scores.getCell(thrIndex, i);
freq[i] = (long) thresholds_and_metric_scores.getCell(freqIndex, i);
for (int j = 0; j < auucTypes.length; j++) {
uplift[j][i] = (double) thresholds_and_metric_scores.getCell(upliftIndices[j], i);
upliftNorm[j][i] = (double) thresholds_and_metric_scores.getCell(upliftNormIndices[j], i);
upliftRand[j][i] = (double) thresholds_and_metric_scores.getCell(upliftRandIndices[j], i);
}
}
// fill auuc values and aecu values
String[] auucHeader = auuc_table.getColHeaders();
String[] aecuHeader = aecu_table.getColHeaders();
for (int i = 1; i < auucTypes.length; i++) {
AUUC.AUUCType type = auucTypes[i];
String auucTypeName = type.name();
int colIndex = ArrayUtils.find(auucHeader, auucTypeName);
auuc[i] = (double) auuc_table.getCell(colIndex, 0);
auucNorm[i] = (double) auuc_table.getCell(colIndex, 1);
auucRand[i] = (double) auuc_table.getCell(colIndex, 2);
colIndex = ArrayUtils.find(aecuHeader, auucTypeName);
aecu[i] = (double) aecu_table.getCell(colIndex, 0);
if(auucTypeName.equals(AUUC.AUUCType.nameAuto())){
auuc[0] = auuc[i];
auucNorm[0] = auucNorm[i];
auucRand[0] = auucRand[i];
aecu[0] = aecu[i];
}
}
return new AUUC(ths, freq, auuc, auucNorm, auucRand, aecu, auucType, uplift, upliftNorm, upliftRand);
}

@Override
public boolean hasTreatment() {
return super.hasTreatment();
}
}
36 changes: 24 additions & 12 deletions h2o-core/src/main/java/hex/AUUC.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;

import java.util.Arrays;
import java.util.Iterator;
Expand Down Expand Up @@ -79,7 +78,6 @@ public AUUC(Vec probs, Vec y, Vec uplift, AUUCType auucType, int nbins) {
public AUUC(AUUCBuilder bldr, AUUCType auucType) {
this(bldr, true, auucType);
}


public AUUC(double[] customThresholds, Vec probs, Vec y, Vec uplift, AUUCType auucType) {
this(new AUUCImpl(customThresholds).doAll(probs, y, uplift)._bldr, auucType);
Expand Down Expand Up @@ -205,6 +203,25 @@ public AUUC() {
_upliftNormalized = new double[AUUCType.values().length][];
_upliftRandom = new double[AUUCType.values().length][];
}

public AUUC(double[] ths, long[] freq, double[] auuc, double[] auucNorm, double[] auucRand, double[] aecu,
AUUCType auucType, double[][] uplift, double[][] upliftNorm, double[][] upliftRand) {
_nBins = ths.length;
_n = freq[freq.length-1];
_ths = ths;
_frequencyCumsum = freq;
_treatment = _control = _yTreatment = _yControl = _frequency = new long[0];
_auucs = auuc;
_auucsNormalized = auucNorm;
_auucsRandom = auucRand;
_aecu = aecu;
_maxIdx = -1;
_auucType = auucType;
_auucTypeIndx = getIndexByAUUCType(_auucType);
_uplift = uplift;
_upliftNormalized = upliftNorm;
_upliftRandom = upliftRand;
}

public static double[] calculateQuantileThresholds(int groups, Vec preds) {
Frame fr = null;
Expand Down Expand Up @@ -443,19 +460,14 @@ public enum AUUCType {
* @return metric value */
abstract double exec(long treatment, long control, long yTreatment, long yControl );
public double exec(AUUC auc, int idx) { return exec(auc.treatment(idx),auc.control(idx),auc.yTreatment(idx),auc.yControl(idx)); }

public static final AUUCType[] VALUES = values();

public static AUUCType fromString(String strRepr) {
for (AUUCType tc : AUUCType.values()) {
if (tc.toString().equalsIgnoreCase(strRepr)) {
return tc;
}
}
return null;
}
public static final AUUCType[] VALUES_WITHOUT_AUTO = ArrayUtils.remove(values().clone(), ArrayUtils.find(AUUCType.values(), AUTO));

public double maxCriterion(AUUC auuc) { return exec(auuc, maxCriterionIdx(auuc)); }
public static String nameAuto(){
return qini.name();
}

/** Convert a criterion into a threshold index that maximizes the criterion
* @return Threshold index that maximizes the criterion
Expand Down
3 changes: 3 additions & 0 deletions h2o-core/src/main/java/hex/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -1162,6 +1162,7 @@ public String[] features() {
public String weightsName () { return _hasWeights ?_names[weightsIdx()]:null;}
public String offsetName () { return _hasOffset ?_names[offsetIdx()]:null;}
public String foldName () { return _hasFold ?_names[foldIdx()]:null;}
public String treatmentName() { return _hasTreatment ? _names[treatmentIdx()]: null;}
public InteractionBuilder interactionBuilder() { return null; }
// Vec layout is [c1,c2,...,cn, w?, o?, f?, u?, r]
// cn are predictor cols, r is response, w is weights, o is offset, f is fold and t is treatment - these are optional
Expand Down Expand Up @@ -3469,6 +3470,8 @@ protected class H2OModelDescriptor implements ModelDescriptor {
@Override
public String weightsColumn() { return _output.weightsName(); }
@Override
public String treatmentColumn() { return _output.treatmentName(); }
@Override
public String foldColumn() { return _output.foldName(); }
@Override
public ModelCategory getModelCategory() { return _output.getModelCategory(); }
Expand Down
1 change: 1 addition & 0 deletions h2o-core/src/main/java/hex/ModelMetrics.java
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,7 @@ public double[] perRow(double ds[], float yact[], double weight, double offset,
assert(weight == 1 && offset == 0);
return perRow(ds, yact, m);
}

public void reduce( T mb ) {
_sumsqe += mb._sumsqe;
_count += mb._count;
Expand Down
20 changes: 20 additions & 0 deletions h2o-core/src/main/java/hex/ModelMetricsBinomialUpliftGeneric.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package hex;

import water.fvec.Frame;
import water.util.TwoDimTable;

public class ModelMetricsBinomialUpliftGeneric extends ModelMetricsBinomialUplift {


public final TwoDimTable _thresholds_and_metric_scores;
public final TwoDimTable _auuc_table;
public final TwoDimTable _aecu_table;

public ModelMetricsBinomialUpliftGeneric(Model model, Frame frame, long nobs, String[] domain, double ate, double att, double atc, double sigma, AUUC auuc, CustomMetric customMetric, TwoDimTable thresholds_and_metric_scores, TwoDimTable auuc_table, TwoDimTable aecu_table, final String description) {
super(model, frame, nobs, domain, ate, att, atc, sigma, auuc, customMetric);
_thresholds_and_metric_scores = thresholds_and_metric_scores;
_auuc_table = auuc_table;
_aecu_table = aecu_table;
_description = description;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package water.api.schemas3;

import hex.ModelMetricsBinomialUpliftGeneric;

public class ModelMetricsBinomialUpliftGenericV3<I extends ModelMetricsBinomialUpliftGeneric, S extends ModelMetricsBinomialUpliftGenericV3<I, S>>
extends ModelMetricsBinomialUpliftV3<I, S> {

@Override
public S fillFromImpl(ModelMetricsBinomialUpliftGeneric modelMetrics) {
super.fillFromImpl(modelMetrics);
this.AUUC = modelMetrics._auuc.auuc();
this.auuc_normalized = modelMetrics._auuc.auucNormalized();
this.ate = modelMetrics.ate();
this.att = modelMetrics.att();
this.atc = modelMetrics.atc();
this.qini = modelMetrics.qini();

if (modelMetrics._auuc_table != null) { // Possibly overwrites whatever has been set in the ModelMetricsBinomialV3
this.auuc_table = new TwoDimTableV3().fillFromImpl(modelMetrics._auuc_table);
}
if (modelMetrics._aecu_table != null) { // Possibly overwrites whatever has been set in the ModelMetricsBinomialV3
this.aecu_table = new TwoDimTableV3().fillFromImpl(modelMetrics._aecu_table);
}
if (modelMetrics._thresholds_and_metric_scores != null) { // Possibly overwrites whatever has been set in the ModelMetricsBinomialV3
this.thresholds_and_metric_scores = new TwoDimTableV3().fillFromImpl(modelMetrics._thresholds_and_metric_scores);
}
return (S) this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ public S fillFromImpl(ModelMetricsBinomialUplift modelMetrics) {
qini = auuc.qini();
// Fill TwoDimTable
String[] thresholds = new String[auuc._nBins];
AUUCType metrics[] = AUUCType.VALUES;
metrics = ArrayUtils.remove(metrics, Arrays.asList(metrics).indexOf(AUUCType.AUTO));
AUUCType metrics[] = AUUCType.VALUES_WITHOUT_AUTO;
int metricsLength = metrics.length;
long[] n = new long[auuc._nBins];
double[][] uplift = new double[metricsLength][];
Expand Down Expand Up @@ -89,7 +88,7 @@ public S fillFromImpl(ModelMetricsBinomialUplift modelMetrics) {
types [i + 1 + 2 * metricsLength] = "double";
formats [i + 1 + 2 * metricsLength] = "%f";
}
colHeaders[i + 1 + 2 * metricsLength] = "n"; types[i + 1 + 2 * metricsLength] = "int"; formats[i + 1 + 2 * metricsLength] = "%d";
colHeaders[i + 1 + 2 * metricsLength] = "n"; types[i + 1 + 2 * metricsLength] = "long"; formats[i + 1 + 2 * metricsLength] = "%d";
colHeaders[i + 2 + 2 * metricsLength] = "idx"; types[i + 2 + 2 * metricsLength] = "int"; formats[i + 2 + 2 * metricsLength] = "%d";
TwoDimTable thresholdsByMetrics = new TwoDimTable("Metrics for Thresholds", "Cumulative Uplift metrics for a given percentile", new String[auuc._nBins], colHeaders, types, formats, null );
for (i = 0; i < auuc._nBins; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ water.api.schemas3.ModelMetricsBinomialGLMGenericV3
water.api.schemas3.ModelMetricsBinomialV3
water.api.schemas3.ModelMetricsBinomialGenericV3
water.api.schemas3.ModelMetricsBinomialUpliftV3
water.api.schemas3.ModelMetricsBinomialUpliftGenericV3
water.api.schemas3.ModelMetricsClusteringV3
water.api.schemas3.ModelMetricsHGLMV3
water.api.schemas3.ModelMetricsHGLMGenericV3
Expand Down
3 changes: 2 additions & 1 deletion h2o-docs/src/product/data-science/upliftdrf.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ The current version of Uplift DRF is based on the implementation of DRF because
MOJO Support
''''''''''''

Uplift DRF currently doesn't support `MOJOs <../save-and-load-model.html#supported-mojos>`__.
Uplift DRF supports importing and exporting `MOJOs <../save-and-load-model.html#supported-mojos>`__.


Uplift DRF demo
~~~~~~~~~~~~~~~
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ public String foldColumn() {
return null;
}

@Override
public String treatmentColumn() { return null; }

@Override
public ModelCategory getModelCategory() {
return _finalModel._category;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ private static MojoModelMetrics determineModelMetricsType(final MojoModel mojoMo
} else return new MojoModelMetricsOrdinal();
case CoxPH:
return new MojoModelMetricsRegressionCoxPH();
case BinomialUplift:
return new MojoModelMetricsBinomialUplift();
case Unknown:
case Clustering:
case AutoEncoder:
Expand Down
Loading

0 comments on commit c9e1ff2

Please sign in to comment.