diff --git a/h2o-algos/src/main/java/hex/ensemble/StackedEnsembleModel.java b/h2o-algos/src/main/java/hex/ensemble/StackedEnsembleModel.java
index d02d5733404a..fbad0f0d25bd 100644
--- a/h2o-algos/src/main/java/hex/ensemble/StackedEnsembleModel.java
+++ b/h2o-algos/src/main/java/hex/ensemble/StackedEnsembleModel.java
@@ -10,11 +10,10 @@
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.udf.CFuncRef;
-import water.util.FrameUtils;
import water.util.Log;
import water.util.MRUtils;
import water.util.TwoDimTable;
-import water.util.fp.Function;
+import water.util.fp.Function2;
import java.util.*;
import java.util.stream.Stream;
@@ -297,7 +296,7 @@ public Frame scoreContributions(Frame frame, Key destination_key, Job fun = subFrame -> {
+ Function2 fun = (subFrame, resultIsFinalFrame) -> {
String[] columns = null;
String[] colsWithBiasTerm = null;
Frame indivContribs = baseLineContributions(subFrame, Key.make(destination_key + "_individual_contribs_for_subframe_"+subFrame._key), j, options, backgroundFrame);
@@ -309,7 +308,10 @@ public Frame scoreContributions(Frame frame, Key destination_key, Job one result with the destination key
+ : Key.make(destination_key + "_for_subframe_"+subFrame._key),
+ colsWithBiasTerm, null);
} finally {
indivContribs.delete(true);
}
@@ -317,10 +319,9 @@ public Frame scoreContributions(Frame frame, Key destination_key, Job H2O.CLOUD._memary.length || // could be map-reduced over the bg frame
!ContributionsWithBackgroundFrameTask.enoughMinMemory(numOfUsefulBaseModels() *
ContributionsWithBackgroundFrameTask.estimatePerNodeMinimalMemory(frame.numCols(), frame, backgroundFrame))) // or we have no other choice due to memory
- return SplitToChunksApplyCombine.splitApplyCombine(frame, fun, destination_key);
+ return SplitToChunksApplyCombine.splitApplyCombine(frame, (fr -> fun.apply(fr, false)), destination_key);
else {
- Frame result = fun.apply(frame);
- result._key = destination_key;
+ Frame result = fun.apply(frame, true);
DKV.put(result);
return result;
}
diff --git a/h2o-algos/src/test/java/hex/deeplearning/DeepLearningSHAPTest.java b/h2o-algos/src/test/java/hex/deeplearning/DeepLearningSHAPTest.java
index 200e5fa9df35..8249d7912c11 100644
--- a/h2o-algos/src/test/java/hex/deeplearning/DeepLearningSHAPTest.java
+++ b/h2o-algos/src/test/java/hex/deeplearning/DeepLearningSHAPTest.java
@@ -66,7 +66,7 @@ public void testClassificationCompactSHAP() {
Val val = Rapids.exec("(sumaxis " + contribs._key + " 0 1)");
assertTrue(val instanceof ValFrame);
res = val.getFrame();
- assertColsEquals(scored, res, 2, 0, 1e-6);
+ assertColsEquals(scored, res, 2, 0, 1e-4);
} finally {
fr.delete();
bgFr.delete();
@@ -109,7 +109,7 @@ public void testClassificationOriginalSHAP() {
Val val = Rapids.exec("(sumaxis " + contribs._key + " 0 1)");
assertTrue(val instanceof ValFrame);
res = val.getFrame();
- assertColsEquals(scored, res, 2, 0, 1e-6);
+ assertColsEquals(scored, res, 2, 0, 1e-4);
} finally {
fr.delete();
bgFr.delete();
@@ -152,7 +152,7 @@ public void testRegressionCompactSHAP() {
Val val = Rapids.exec("(sumaxis " + contribs._key + " 0 1)");
assertTrue(val instanceof ValFrame);
res = val.getFrame();
- assertColsEquals(scored, res, 0, 0, 1e-5);
+ assertColsEquals(scored, res, 0, 0, 1e-4);
} finally {
fr.delete();
bgFr.delete();
@@ -194,7 +194,7 @@ public void testRegressionOriginalSHAP() {
Val val = Rapids.exec("(sumaxis " + contribs._key + " 0 1)");
assertTrue(val instanceof ValFrame);
res = val.getFrame();
- assertColsEquals(scored, res, 0, 0, 1e-5);
+ assertColsEquals(scored, res, 0, 0, 1e-4);
} finally {
fr.delete();
bgFr.delete();
@@ -332,7 +332,7 @@ public void testRegressionReLUDeepSHAPComparison() {
Frame scored = null;
Frame contribs = null;
Frame res = null;
- double eps = 1e-5;
+ double eps = 1e-4;
try {
// Launch Deep Learning
DeepLearningParameters params = new DeepLearningParameters();
@@ -776,7 +776,7 @@ public void testClassificationReLUDeepSHAPComparison() {
Frame scored = null;
Frame contribs = null;
Frame res = null;
- double eps = 1e-5;
+ double eps = 1e-4;
try {
// Launch Deep Learning
DeepLearningParameters params = new DeepLearningParameters();
diff --git a/h2o-algos/src/test/java/hex/ensemble/StackedEnsembleSHAPTest.java b/h2o-algos/src/test/java/hex/ensemble/StackedEnsembleSHAPTest.java
index 681b9d8e0bcc..dc884013b825 100644
--- a/h2o-algos/src/test/java/hex/ensemble/StackedEnsembleSHAPTest.java
+++ b/h2o-algos/src/test/java/hex/ensemble/StackedEnsembleSHAPTest.java
@@ -127,7 +127,7 @@ public void testClassificationCompactSHAP() {
DKV.remove(Key.make("expContrSum"));
assertTrue(val instanceof ValFrame);
res = val.getFrame();
- assertColsEquals(scored, res, 2, 0, 1e-6);
+ assertColsEquals(scored, res, 2, 0, 1e-4);
} finally {
fr.delete();
bgFr.delete();
@@ -200,7 +200,7 @@ public void testClassificationCompactOutputSpaceSHAP() {
Val val = Rapids.exec("(sumaxis " + contribs._key + " 0 1)");
assertTrue(val instanceof ValFrame);
res = val.getFrame();
- assertColsEquals(scored, res, 2, 0, 1e-6);
+ assertColsEquals(scored, res, 2, 0, 1e-4);
} finally {
fr.delete();
bgFr.delete();
@@ -237,7 +237,7 @@ public void testRegressionCompactSHAP() {
Val val = Rapids.exec("(sumaxis " + contribs._key + " 0 1)");
assertTrue(val instanceof ValFrame);
res = val.getFrame();
- assertColsEquals(scored, res, 0, 0, 1e-5);
+ assertColsEquals(scored, res, 0, 0, 1e-4);
} finally {
fr.delete();
bgFr.delete();
@@ -274,7 +274,7 @@ public void testRegressionOriginalSHAP() {
Val val = Rapids.exec("(sumaxis " + contribs._key + " 0 1)");
assertTrue(val instanceof ValFrame);
res = val.getFrame();
- assertColsEquals(scored, res, 0, 0, 1e-5);
+ assertColsEquals(scored, res, 0, 0, 1e-4);
} finally {
fr.delete();
bgFr.delete();