From ec037a1585796e6c17b6b471e0b05d1f395b2394 Mon Sep 17 00:00:00 2001 From: Veronika Maurerova Date: Fri, 6 Dec 2024 15:00:47 +0100 Subject: [PATCH] Add cosine distance test, fix cosine distance calculation, fix scoring task --- .../src/main/java/hex/knn/CosineDistance.java | 2 +- .../src/main/java/hex/knn/KNNScoringTask.java | 16 ++-- h2o-algos/src/test/java/hex/knn/KNNTest.java | 96 +++++++++++++++++-- 3 files changed, 99 insertions(+), 15 deletions(-) diff --git a/h2o-algos/src/main/java/hex/knn/CosineDistance.java b/h2o-algos/src/main/java/hex/knn/CosineDistance.java index a6aea8bf792b..a324cfb10e81 100644 --- a/h2o-algos/src/main/java/hex/knn/CosineDistance.java +++ b/h2o-algos/src/main/java/hex/knn/CosineDistance.java @@ -23,6 +23,6 @@ public double[] calculateValues(double v1, double v2, double[] values) { @Override public double result(double[] values) { assert values.length == valuesLength; - return 1 - (values[0] / (Math.sqrt(values[1] * values[2]))); + return 1 - (values[0] / (Math.sqrt(values[1]) * Math.sqrt(values[2]))); } } diff --git a/h2o-algos/src/main/java/hex/knn/KNNScoringTask.java b/h2o-algos/src/main/java/hex/knn/KNNScoringTask.java index 06eb66cadf2d..49c9f9418535 100644 --- a/h2o-algos/src/main/java/hex/knn/KNNScoringTask.java +++ b/h2o-algos/src/main/java/hex/knn/KNNScoringTask.java @@ -33,15 +33,17 @@ public KNNScoringTask(double[] query, int k, int domainSize, KNNDistance distanc @Override public void map(Chunk[] cs) { - int inputColNum = cs.length-2; + int inputColNum = cs.length; int inputRowNum = cs[0]._len; - for (int j = 0; j < inputRowNum; j++) { // go over all input data rows - String inputDataId = _idColumnType == Vec.T_STR ? cs[_idIndex].stringAt(j) : String.valueOf(cs[_idIndex].at8(j)); - int inputDataCategory = (int) cs[_responseIndex].at8(j); + for (int i = 0; i < inputRowNum; i++) { // go over all input data rows + String inputDataId = _idColumnType == Vec.T_STR ? cs[_idIndex].stringAt(i) : String.valueOf(cs[_idIndex].at8(i)); + int inputDataCategory = (int) cs[_responseIndex].at8(i); double[] distValues = _distance.initializeValues(); + int j = 0; for (int k = 0; k < inputColNum; k++) { // go over all columns - double queryColData = _queryData[k]; - double inputColData = cs[k].atd(j); + if(k == _idIndex || k == _responseIndex) continue; + double queryColData = _queryData[j++]; + double inputColData = cs[k].atd(i); distValues = _distance.calculateValues(queryColData, inputColData, distValues); } double dist = _distance.result(distValues); @@ -65,7 +67,7 @@ public double[] score(){ scores[i] = scores[i]/_k; } } - scores[0] = ArrayUtils.maxIndex(scores); + scores[0] = ArrayUtils.maxIndex(scores)-1; return scores; } } diff --git a/h2o-algos/src/test/java/hex/knn/KNNTest.java b/h2o-algos/src/test/java/hex/knn/KNNTest.java index c01050715504..68a80e91fdab 100644 --- a/h2o-algos/src/test/java/hex/knn/KNNTest.java +++ b/h2o-algos/src/test/java/hex/knn/KNNTest.java @@ -126,14 +126,12 @@ public void testSimpleFrameEuclidean() { Assert.assertEquals(preds.vec(2).at(0), 1.0, 0); Assert.assertEquals(preds.vec(0).at8(3), 0); - Assert.assertEquals(preds.vec(1).at(3), 0.5, 0); - Assert.assertEquals(preds.vec(2).at(3), 0.5, 0); + Assert.assertEquals(preds.vec(1).at(3), 1.0, 0); + Assert.assertEquals(preds.vec(2).at(3), 0.0, 0); ModelMetricsBinomial mm = (ModelMetricsBinomial) ModelMetrics.getFromDKV(knn, parms.train()); Assert.assertNotNull(mm); - Assert.assertEquals(mm.auc(), 0.75, 0); - - + Assert.assertEquals(mm.auc(), 1.0, 0); } finally { if (knn != null){ knn.delete(); @@ -150,7 +148,6 @@ public void testSimpleFrameEuclidean() { } } - @Test public void testSimpleFrameManhattan() { KNNModel knn = null; @@ -209,8 +206,82 @@ public void testSimpleFrameManhattan() { ModelMetricsBinomial mm = (ModelMetricsBinomial) ModelMetrics.getFromDKV(knn, parms.train()); Assert.assertNotNull(mm); - Assert.assertEquals(mm.auc(), 0.75, 0); + Assert.assertEquals(mm.auc(), 1.0, 0); + } finally { + if (knn != null){ + knn.delete(); + } + if (distances != null){ + distances.delete(); + } + if(fr != null) { + fr.delete(); + } + if(preds != null){ + preds.delete(); + } + } + } + + @Test + public void testSimpleFrameCosine() { + KNNModel knn = null; + Frame fr = null; + Frame preds = null; + Frame distances = null; + try { + fr = generateSimpleFrameForCosine(); + + String idColumn = "id"; + String response = "class"; + + DKV.put(fr); + KNNModel.KNNParameters parms = new KNNModel.KNNParameters(); + parms._train = fr._key; + parms._k = 2; + parms._distance = new CosineDistance(); + parms._response_column = response; + parms._id_column = idColumn; + parms._auc_type = MultinomialAucType.MACRO_OVR; + + parms._seed = 42; + KNN job = new KNN(parms); + knn = job.trainModel().get(); + Assert.assertNotNull(knn); + + distances = knn._output.getDistances(); + Assert.assertNotNull(distances); + Assert.assertEquals(distances.vec(0).at8(0), 1); + Assert.assertEquals(distances.vec(1).at(0), 0.0, 10e-5); + Assert.assertEquals(distances.vec(2).at(0), 1.0, 10e-5); + Assert.assertEquals(distances.vec(3).at8(0), 1); + Assert.assertEquals(distances.vec(4).at8(0), 3); + Assert.assertEquals(distances.vec(5).at8(0), 1); + Assert.assertEquals(distances.vec(6).at8(0),0); + + Assert.assertEquals(distances.vec(0).at8(1), 2); + Assert.assertEquals(distances.vec(1).at(1), 0.0, 10e-5); + Assert.assertEquals(distances.vec(2).at(1), 0.105573, 10e-5); + Assert.assertEquals(distances.vec(3).at8(1), 2); + Assert.assertEquals(distances.vec(4).at8(1), 4); + Assert.assertEquals(distances.vec(5).at8(1), 1); + Assert.assertEquals(distances.vec(6).at8(1), 0); + + preds = knn.score(fr); + Assert.assertNotNull(preds); + + Assert.assertEquals(preds.vec(0).at8(0), 1); + Assert.assertEquals(preds.vec(1).at(0), 0.5, 0); + Assert.assertEquals(preds.vec(2).at(0), 0.5, 0); + + Assert.assertEquals(preds.vec(0).at8(3), 0); + Assert.assertEquals(preds.vec(1).at(3), 1.0, 0); + Assert.assertEquals(preds.vec(2).at(3), 0.0, 0); + + ModelMetricsBinomial mm = (ModelMetricsBinomial) ModelMetrics.getFromDKV(knn, parms.train()); + Assert.assertNotNull(mm); + Assert.assertEquals(mm.auc(), 1.0, 0); } finally { if (knn != null){ knn.delete(); @@ -237,4 +308,15 @@ private Frame generateSimpleFrame(){ .withDataForCol(3, ar("1", "1", "0", "0")) .build(); } + + private Frame generateSimpleFrameForCosine(){ + return new TestFrameBuilder() + .withColNames("id", "C0", "C1", "class") + .withVecTypes(Vec.T_NUM, Vec.T_NUM, Vec.T_NUM, Vec.T_CAT) + .withDataForCol(0, ari(1, 2, 3, 4)) + .withDataForCol(1, ard(0.0, 1.0, 2.0, 3.0)) + .withDataForCol(2, ard(-1.0, 1.0, 0.0, 1.0)) + .withDataForCol(3, ar("1", "1", "0", "0")) + .build(); + } }