Skip to content

Commit

Permalink
Add cosine distance test, fix cosine distance calculation, fix scoring task
Browse files Browse the repository at this point in the history
  • Loading branch information
maurever committed Dec 6, 2024
1 parent f5117be commit ec037a1
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 15 deletions.
2 changes: 1 addition & 1 deletion h2o-algos/src/main/java/hex/knn/CosineDistance.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@ public double[] calculateValues(double v1, double v2, double[] values) {
@Override
public double result(double[] values) {
assert values.length == valuesLength;
return 1 - (values[0] / (Math.sqrt(values[1] * values[2])));
return 1 - (values[0] / (Math.sqrt(values[1]) * Math.sqrt(values[2])));
}
}
16 changes: 9 additions & 7 deletions h2o-algos/src/main/java/hex/knn/KNNScoringTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,17 @@ public KNNScoringTask(double[] query, int k, int domainSize, KNNDistance distanc

@Override
public void map(Chunk[] cs) {
int inputColNum = cs.length-2;
int inputColNum = cs.length;
int inputRowNum = cs[0]._len;
for (int j = 0; j < inputRowNum; j++) { // go over all input data rows
String inputDataId = _idColumnType == Vec.T_STR ? cs[_idIndex].stringAt(j) : String.valueOf(cs[_idIndex].at8(j));
int inputDataCategory = (int) cs[_responseIndex].at8(j);
for (int i = 0; i < inputRowNum; i++) { // go over all input data rows
String inputDataId = _idColumnType == Vec.T_STR ? cs[_idIndex].stringAt(i) : String.valueOf(cs[_idIndex].at8(i));
int inputDataCategory = (int) cs[_responseIndex].at8(i);
double[] distValues = _distance.initializeValues();
int j = 0;
for (int k = 0; k < inputColNum; k++) { // go over all columns
double queryColData = _queryData[k];
double inputColData = cs[k].atd(j);
if(k == _idIndex || k == _responseIndex) continue;
double queryColData = _queryData[j++];
double inputColData = cs[k].atd(i);
distValues = _distance.calculateValues(queryColData, inputColData, distValues);
}
double dist = _distance.result(distValues);
Expand All @@ -65,7 +67,7 @@ public double[] score(){
scores[i] = scores[i]/_k;
}
}
scores[0] = ArrayUtils.maxIndex(scores);
scores[0] = ArrayUtils.maxIndex(scores)-1;
return scores;
}
}
96 changes: 89 additions & 7 deletions h2o-algos/src/test/java/hex/knn/KNNTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,12 @@ public void testSimpleFrameEuclidean() {
Assert.assertEquals(preds.vec(2).at(0), 1.0, 0);

Assert.assertEquals(preds.vec(0).at8(3), 0);
Assert.assertEquals(preds.vec(1).at(3), 0.5, 0);
Assert.assertEquals(preds.vec(2).at(3), 0.5, 0);
Assert.assertEquals(preds.vec(1).at(3), 1.0, 0);
Assert.assertEquals(preds.vec(2).at(3), 0.0, 0);

ModelMetricsBinomial mm = (ModelMetricsBinomial) ModelMetrics.getFromDKV(knn, parms.train());
Assert.assertNotNull(mm);
Assert.assertEquals(mm.auc(), 0.75, 0);


Assert.assertEquals(mm.auc(), 1.0, 0);
} finally {
if (knn != null){
knn.delete();
Expand All @@ -150,7 +148,6 @@ public void testSimpleFrameEuclidean() {
}
}


@Test
public void testSimpleFrameManhattan() {
KNNModel knn = null;
Expand Down Expand Up @@ -209,8 +206,82 @@ public void testSimpleFrameManhattan() {

ModelMetricsBinomial mm = (ModelMetricsBinomial) ModelMetrics.getFromDKV(knn, parms.train());
Assert.assertNotNull(mm);
Assert.assertEquals(mm.auc(), 0.75, 0);
Assert.assertEquals(mm.auc(), 1.0, 0);
} finally {
if (knn != null){
knn.delete();
}
if (distances != null){
distances.delete();
}
if(fr != null) {
fr.delete();
}
if(preds != null){
preds.delete();
}
}
}

@Test
public void testSimpleFrameCosine() {
KNNModel knn = null;
Frame fr = null;
Frame preds = null;
Frame distances = null;
try {
fr = generateSimpleFrameForCosine();

String idColumn = "id";
String response = "class";

DKV.put(fr);
KNNModel.KNNParameters parms = new KNNModel.KNNParameters();
parms._train = fr._key;
parms._k = 2;
parms._distance = new CosineDistance();
parms._response_column = response;
parms._id_column = idColumn;
parms._auc_type = MultinomialAucType.MACRO_OVR;

parms._seed = 42;
KNN job = new KNN(parms);
knn = job.trainModel().get();
Assert.assertNotNull(knn);

distances = knn._output.getDistances();
Assert.assertNotNull(distances);

Assert.assertEquals(distances.vec(0).at8(0), 1);
Assert.assertEquals(distances.vec(1).at(0), 0.0, 10e-5);
Assert.assertEquals(distances.vec(2).at(0), 1.0, 10e-5);
Assert.assertEquals(distances.vec(3).at8(0), 1);
Assert.assertEquals(distances.vec(4).at8(0), 3);
Assert.assertEquals(distances.vec(5).at8(0), 1);
Assert.assertEquals(distances.vec(6).at8(0),0);

Assert.assertEquals(distances.vec(0).at8(1), 2);
Assert.assertEquals(distances.vec(1).at(1), 0.0, 10e-5);
Assert.assertEquals(distances.vec(2).at(1), 0.105573, 10e-5);
Assert.assertEquals(distances.vec(3).at8(1), 2);
Assert.assertEquals(distances.vec(4).at8(1), 4);
Assert.assertEquals(distances.vec(5).at8(1), 1);
Assert.assertEquals(distances.vec(6).at8(1), 0);

preds = knn.score(fr);
Assert.assertNotNull(preds);

Assert.assertEquals(preds.vec(0).at8(0), 1);
Assert.assertEquals(preds.vec(1).at(0), 0.5, 0);
Assert.assertEquals(preds.vec(2).at(0), 0.5, 0);

Assert.assertEquals(preds.vec(0).at8(3), 0);
Assert.assertEquals(preds.vec(1).at(3), 1.0, 0);
Assert.assertEquals(preds.vec(2).at(3), 0.0, 0);

ModelMetricsBinomial mm = (ModelMetricsBinomial) ModelMetrics.getFromDKV(knn, parms.train());
Assert.assertNotNull(mm);
Assert.assertEquals(mm.auc(), 1.0, 0);
} finally {
if (knn != null){
knn.delete();
Expand All @@ -237,4 +308,15 @@ private Frame generateSimpleFrame(){
.withDataForCol(3, ar("1", "1", "0", "0"))
.build();
}

/**
 * Builds the four-row fixture frame used by the cosine-distance test.
 *
 * <p>Columns, in order: numeric {@code id} (1, 2, 3, 4), numeric {@code C0}
 * (0.0, 1.0, 2.0, 3.0), numeric {@code C1} (-1.0, 1.0, 0.0, 1.0) and categorical
 * {@code class} ("1", "1", "0", "0").
 *
 * @return the frame produced by the builder
 */
private Frame generateSimpleFrameForCosine(){
    final Frame cosineFixture = new TestFrameBuilder()
            .withColNames("id", "C0", "C1", "class")
            .withVecTypes(Vec.T_NUM, Vec.T_NUM, Vec.T_NUM, Vec.T_CAT)
            .withDataForCol(0, ari(1, 2, 3, 4))                 // id
            .withDataForCol(1, ard(0.0, 1.0, 2.0, 3.0))         // C0
            .withDataForCol(2, ard(-1.0, 1.0, 0.0, 1.0))        // C1
            .withDataForCol(3, ar("1", "1", "0", "0"))          // class
            .build();
    return cosineFixture;
}
}

0 comments on commit ec037a1

Please sign in to comment.