Skip to content

Commit

Permalink
Fix shap pt.4 (#15823)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomasfryda authored Oct 13, 2023
1 parent ca34e97 commit 5a720e5
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 17 deletions.
15 changes: 8 additions & 7 deletions h2o-algos/src/main/java/hex/ensemble/StackedEnsembleModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.udf.CFuncRef;
import water.util.FrameUtils;
import water.util.Log;
import water.util.MRUtils;
import water.util.TwoDimTable;
import water.util.fp.Function;
import water.util.fp.Function2;

import java.util.*;
import java.util.stream.Stream;
Expand Down Expand Up @@ -297,7 +296,7 @@ public Frame scoreContributions(Frame frame, Key<Frame> destination_key, Job<Fra
if (options._outputPerReference)
return baseLineContributions(frame, destination_key, j , options, backgroundFrame);

Function<Frame, Frame> fun = subFrame -> {
Function2<Frame, Boolean, Frame> fun = (subFrame, resultIsFinalFrame) -> {
String[] columns = null;
String[] colsWithBiasTerm = null;
Frame indivContribs = baseLineContributions(subFrame, Key.make(destination_key + "_individual_contribs_for_subframe_"+subFrame._key), j, options, backgroundFrame);
Expand All @@ -309,18 +308,20 @@ public Frame scoreContributions(Frame frame, Key<Frame> destination_key, Job<Fra
return new ContributionsMeanAggregator(j,(int) subFrame.numRows(), columns.length + 1 /* (bias term) */, (int)backgroundFrame.numRows())
.withPostMapAction(JobUpdatePostMap.forJob(j))
.doAll(columns.length + 1, Vec.T_NUM, indivContribs)
.outputFrame(Key.make(destination_key + "_for_subframe_"+subFrame._key), colsWithBiasTerm, null);
.outputFrame(resultIsFinalFrame
? destination_key // no subframes -> one result with the destination key
: Key.make(destination_key + "_for_subframe_"+subFrame._key),
colsWithBiasTerm, null);
} finally {
indivContribs.delete(true);
}
};
if (backgroundFrame.anyVec().nChunks() > H2O.CLOUD._memary.length || // could be map-reduced over the bg frame
!ContributionsWithBackgroundFrameTask.enoughMinMemory(numOfUsefulBaseModels() *
ContributionsWithBackgroundFrameTask.estimatePerNodeMinimalMemory(frame.numCols(), frame, backgroundFrame))) // or we have no other choice due to memory
return SplitToChunksApplyCombine.splitApplyCombine(frame, fun, destination_key);
return SplitToChunksApplyCombine.splitApplyCombine(frame, (fr -> fun.apply(fr, false)), destination_key);
else {
Frame result = fun.apply(frame);
result._key = destination_key;
Frame result = fun.apply(frame, true);
DKV.put(result);
return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public void testClassificationCompactSHAP() {
Val val = Rapids.exec("(sumaxis " + contribs._key + " 0 1)");
assertTrue(val instanceof ValFrame);
res = val.getFrame();
assertColsEquals(scored, res, 2, 0, 1e-6);
assertColsEquals(scored, res, 2, 0, 1e-4);
} finally {
fr.delete();
bgFr.delete();
Expand Down Expand Up @@ -109,7 +109,7 @@ public void testClassificationOriginalSHAP() {
Val val = Rapids.exec("(sumaxis " + contribs._key + " 0 1)");
assertTrue(val instanceof ValFrame);
res = val.getFrame();
assertColsEquals(scored, res, 2, 0, 1e-6);
assertColsEquals(scored, res, 2, 0, 1e-4);
} finally {
fr.delete();
bgFr.delete();
Expand Down Expand Up @@ -152,7 +152,7 @@ public void testRegressionCompactSHAP() {
Val val = Rapids.exec("(sumaxis " + contribs._key + " 0 1)");
assertTrue(val instanceof ValFrame);
res = val.getFrame();
assertColsEquals(scored, res, 0, 0, 1e-5);
assertColsEquals(scored, res, 0, 0, 1e-4);
} finally {
fr.delete();
bgFr.delete();
Expand Down Expand Up @@ -194,7 +194,7 @@ public void testRegressionOriginalSHAP() {
Val val = Rapids.exec("(sumaxis " + contribs._key + " 0 1)");
assertTrue(val instanceof ValFrame);
res = val.getFrame();
assertColsEquals(scored, res, 0, 0, 1e-5);
assertColsEquals(scored, res, 0, 0, 1e-4);
} finally {
fr.delete();
bgFr.delete();
Expand Down Expand Up @@ -332,7 +332,7 @@ public void testRegressionReLUDeepSHAPComparison() {
Frame scored = null;
Frame contribs = null;
Frame res = null;
double eps = 1e-5;
double eps = 1e-4;
try {
// Launch Deep Learning
DeepLearningParameters params = new DeepLearningParameters();
Expand Down Expand Up @@ -776,7 +776,7 @@ public void testClassificationReLUDeepSHAPComparison() {
Frame scored = null;
Frame contribs = null;
Frame res = null;
double eps = 1e-5;
double eps = 1e-4;
try {
// Launch Deep Learning
DeepLearningParameters params = new DeepLearningParameters();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ public void testClassificationCompactSHAP() {
DKV.remove(Key.make("expContrSum"));
assertTrue(val instanceof ValFrame);
res = val.getFrame();
assertColsEquals(scored, res, 2, 0, 1e-6);
assertColsEquals(scored, res, 2, 0, 1e-4);
} finally {
fr.delete();
bgFr.delete();
Expand Down Expand Up @@ -200,7 +200,7 @@ public void testClassificationCompactOutputSpaceSHAP() {
Val val = Rapids.exec("(sumaxis " + contribs._key + " 0 1)");
assertTrue(val instanceof ValFrame);
res = val.getFrame();
assertColsEquals(scored, res, 2, 0, 1e-6);
assertColsEquals(scored, res, 2, 0, 1e-4);
} finally {
fr.delete();
bgFr.delete();
Expand Down Expand Up @@ -237,7 +237,7 @@ public void testRegressionCompactSHAP() {
Val val = Rapids.exec("(sumaxis " + contribs._key + " 0 1)");
assertTrue(val instanceof ValFrame);
res = val.getFrame();
assertColsEquals(scored, res, 0, 0, 1e-5);
assertColsEquals(scored, res, 0, 0, 1e-4);
} finally {
fr.delete();
bgFr.delete();
Expand Down Expand Up @@ -274,7 +274,7 @@ public void testRegressionOriginalSHAP() {
Val val = Rapids.exec("(sumaxis " + contribs._key + " 0 1)");
assertTrue(val instanceof ValFrame);
res = val.getFrame();
assertColsEquals(scored, res, 0, 0, 1e-5);
assertColsEquals(scored, res, 0, 0, 1e-4);
} finally {
fr.delete();
bgFr.delete();
Expand Down

0 comments on commit 5a720e5

Please sign in to comment.