diff --git a/h2o-algos/src/test/java/hex/hglm/HGLMUtilTest.java b/h2o-algos/src/test/java/hex/hglm/HGLMUtilTest.java index 5315c33272c0..a94900ed922c 100644 --- a/h2o-algos/src/test/java/hex/hglm/HGLMUtilTest.java +++ b/h2o-algos/src/test/java/hex/hglm/HGLMUtilTest.java @@ -9,6 +9,7 @@ import static hex.hglm.HGLMTask.ComputationEngineTask.sumAfjAfjAfjTYj; import static hex.hglm.HGLMUtils.*; +import static hex.hglm.MetricBuilderHGLM.calHGLMllg; import static org.junit.Assert.assertEquals; @RunWith(H2ORunner.class) @@ -196,4 +197,35 @@ public void checkCumSum(int numLevel2, int numFixedLength) { TestUtil.checkDoubleArrays(sumAfjTAfj, sumAfjTAfjMat.getArray(), 1e-12); TestUtil.checkArrays(sumAfjTYj, sumAfjTYjMat.transpose().getArray()[0], 1e-12); } + + @Test + public void testCalLlg() { + checkCalLlg(1, 2, 10, 1); + checkCalLlg(10, 20, 100, 3); + checkCalLlg(18, 8, 20, 1); + checkCalLlg(8, 8, 10, 1); + } + + public void checkCalLlg(int numRandomCoeff, int numLevel2, int nobs, int multiplier) { + double varResidual = Math.abs(genRandomMatrix(1, 1, 123)[0][0]); + double yMinsXFixSquare = Math.abs(genRandomMatrix(1,1, 124)[0][0]); + double[][] tmat = genSymPsdMatrix(numRandomCoeff, 124, multiplier); + Matrix tmatInv = new Matrix(tmat).inverse(); + double oneOVar = 1.0/varResidual; + double oneOVarSq = oneOVar*oneOVar; + double llgManual = nobs*Math.log(2*Math.PI)+oneOVar*yMinsXFixSquare; + double[][] yMinusXFixTimesZ = genRandomMatrix(numLevel2, numRandomCoeff, 126); + double[][][] zjTimesZj = new double[numLevel2][][]; + + for (int ind2=0; ind2