diff --git a/h2o-algos/src/main/java/hex/ensemble/StackedEnsembleModel.java b/h2o-algos/src/main/java/hex/ensemble/StackedEnsembleModel.java index d02d5733404a..fbad0f0d25bd 100644 --- a/h2o-algos/src/main/java/hex/ensemble/StackedEnsembleModel.java +++ b/h2o-algos/src/main/java/hex/ensemble/StackedEnsembleModel.java @@ -10,11 +10,10 @@ import water.fvec.NewChunk; import water.fvec.Vec; import water.udf.CFuncRef; -import water.util.FrameUtils; import water.util.Log; import water.util.MRUtils; import water.util.TwoDimTable; -import water.util.fp.Function; +import water.util.fp.Function2; import java.util.*; import java.util.stream.Stream; @@ -297,7 +296,7 @@ public Frame scoreContributions(Frame frame, Key destination_key, Job fun = subFrame -> { + Function2 fun = (subFrame, resultIsFinalFrame) -> { String[] columns = null; String[] colsWithBiasTerm = null; Frame indivContribs = baseLineContributions(subFrame, Key.make(destination_key + "_individual_contribs_for_subframe_"+subFrame._key), j, options, backgroundFrame); @@ -309,7 +308,10 @@ public Frame scoreContributions(Frame frame, Key destination_key, Job one result with the destination key + : Key.make(destination_key + "_for_subframe_"+subFrame._key), + colsWithBiasTerm, null); } finally { indivContribs.delete(true); } @@ -317,10 +319,9 @@ public Frame scoreContributions(Frame frame, Key destination_key, Job H2O.CLOUD._memary.length || // could be map-reduced over the bg frame !ContributionsWithBackgroundFrameTask.enoughMinMemory(numOfUsefulBaseModels() * ContributionsWithBackgroundFrameTask.estimatePerNodeMinimalMemory(frame.numCols(), frame, backgroundFrame))) // or we have no other choice due to memory - return SplitToChunksApplyCombine.splitApplyCombine(frame, fun, destination_key); + return SplitToChunksApplyCombine.splitApplyCombine(frame, (fr -> fun.apply(fr, false)), destination_key); else { - Frame result = fun.apply(frame); - result._key = destination_key; + Frame result = fun.apply(frame, true); DKV.put(result); return result; } diff --git a/h2o-algos/src/test/java/hex/deeplearning/DeepLearningSHAPTest.java b/h2o-algos/src/test/java/hex/deeplearning/DeepLearningSHAPTest.java index 200e5fa9df35..8249d7912c11 100644 --- a/h2o-algos/src/test/java/hex/deeplearning/DeepLearningSHAPTest.java +++ b/h2o-algos/src/test/java/hex/deeplearning/DeepLearningSHAPTest.java @@ -66,7 +66,7 @@ public void testClassificationCompactSHAP() { Val val = Rapids.exec("(sumaxis " + contribs._key + " 0 1)"); assertTrue(val instanceof ValFrame); res = val.getFrame(); - assertColsEquals(scored, res, 2, 0, 1e-6); + assertColsEquals(scored, res, 2, 0, 1e-4); } finally { fr.delete(); bgFr.delete(); @@ -109,7 +109,7 @@ public void testClassificationOriginalSHAP() { Val val = Rapids.exec("(sumaxis " + contribs._key + " 0 1)"); assertTrue(val instanceof ValFrame); res = val.getFrame(); - assertColsEquals(scored, res, 2, 0, 1e-6); + assertColsEquals(scored, res, 2, 0, 1e-4); } finally { fr.delete(); bgFr.delete(); @@ -152,7 +152,7 @@ public void testRegressionCompactSHAP() { Val val = Rapids.exec("(sumaxis " + contribs._key + " 0 1)"); assertTrue(val instanceof ValFrame); res = val.getFrame(); - assertColsEquals(scored, res, 0, 0, 1e-5); + assertColsEquals(scored, res, 0, 0, 1e-4); } finally { fr.delete(); bgFr.delete(); @@ -194,7 +194,7 @@ public void testRegressionOriginalSHAP() { Val val = Rapids.exec("(sumaxis " + contribs._key + " 0 1)"); assertTrue(val instanceof ValFrame); res = val.getFrame(); - assertColsEquals(scored, res, 0, 0, 1e-5); + assertColsEquals(scored, res, 0, 0, 1e-4); } finally { fr.delete(); bgFr.delete(); @@ -332,7 +332,7 @@ public void testRegressionReLUDeepSHAPComparison() { Frame scored = null; Frame contribs = null; Frame res = null; - double eps = 1e-5; + double eps = 1e-4; try { // Launch Deep Learning DeepLearningParameters params = new DeepLearningParameters(); @@ -776,7 +776,7 @@ public void testClassificationReLUDeepSHAPComparison() { Frame scored = null; Frame contribs = null; Frame res = null; - double eps = 1e-5; + double eps = 1e-4; try { // Launch Deep Learning DeepLearningParameters params = new DeepLearningParameters(); diff --git a/h2o-algos/src/test/java/hex/ensemble/StackedEnsembleSHAPTest.java b/h2o-algos/src/test/java/hex/ensemble/StackedEnsembleSHAPTest.java index 681b9d8e0bcc..dc884013b825 100644 --- a/h2o-algos/src/test/java/hex/ensemble/StackedEnsembleSHAPTest.java +++ b/h2o-algos/src/test/java/hex/ensemble/StackedEnsembleSHAPTest.java @@ -127,7 +127,7 @@ public void testClassificationCompactSHAP() { DKV.remove(Key.make("expContrSum")); assertTrue(val instanceof ValFrame); res = val.getFrame(); - assertColsEquals(scored, res, 2, 0, 1e-6); + assertColsEquals(scored, res, 2, 0, 1e-4); } finally { fr.delete(); bgFr.delete(); @@ -200,7 +200,7 @@ public void testClassificationCompactOutputSpaceSHAP() { Val val = Rapids.exec("(sumaxis " + contribs._key + " 0 1)"); assertTrue(val instanceof ValFrame); res = val.getFrame(); - assertColsEquals(scored, res, 2, 0, 1e-6); + assertColsEquals(scored, res, 2, 0, 1e-4); } finally { fr.delete(); bgFr.delete(); @@ -237,7 +237,7 @@ public void testRegressionCompactSHAP() { Val val = Rapids.exec("(sumaxis " + contribs._key + " 0 1)"); assertTrue(val instanceof ValFrame); res = val.getFrame(); - assertColsEquals(scored, res, 0, 0, 1e-5); + assertColsEquals(scored, res, 0, 0, 1e-4); } finally { fr.delete(); bgFr.delete(); @@ -274,7 +274,7 @@ public void testRegressionOriginalSHAP() { Val val = Rapids.exec("(sumaxis " + contribs._key + " 0 1)"); assertTrue(val instanceof ValFrame); res = val.getFrame(); - assertColsEquals(scored, res, 0, 0, 1e-5); + assertColsEquals(scored, res, 0, 0, 1e-4); } finally { fr.delete(); bgFr.delete();