Skip to content

Commit

Permalink
Add docs
Browse files Browse the repository at this point in the history
  • Loading branch information
maurever committed Dec 11, 2024
1 parent 02a7c3b commit a8d2ae9
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 13 deletions.
12 changes: 5 additions & 7 deletions h2o-algos/src/main/java/hex/knn/KNN.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import hex.*;
import water.DKV;
import water.Key;
import water.Scope;
import water.fvec.Chunk;
import water.fvec.Frame;
Expand Down Expand Up @@ -47,7 +48,7 @@ class KNNDriver extends Driver {
@Override
public void computeImpl() {
KNNModel model = null;
Frame result = null;
Frame result = new Frame(Key.make("KNN_distances"));
Frame tmpResult = null;
try {
init(true); // Initialize parameters
Expand All @@ -64,20 +65,17 @@ public void computeImpl() {
int responseColumnIndex = train.find(responseColumn);
int nChunks = train.anyVec().nChunks();
int nCols = train.numCols();
// split data into chunks to calculate distances in parallel task
for (int i = 0; i < nChunks; i++) {
Chunk[] query = new Chunk[nCols];
for (int j = 0; j < nCols; j++) {
query[j] = train.vec(j).chunkForChunkIdx(i).deepCopy();
}
KNNDistanceTask task = new KNNDistanceTask(_parms._k, query, _parms._distance, idColumnIndex, idColumn, idType, responseColumnIndex, responseColumn);
tmpResult = task.doAll(train).outputFrame();
if (result == null) {
result = tmpResult;
} else {
result = result.add(tmpResult);
}
// merge result from a chunk
result = result.add(tmpResult);
}
result = result.deepCopy("KNN_distances");
DKV.put(result._key, result);
model._output.setDistancesKey(result._key);
Scope.untrack(result);
Expand Down
19 changes: 14 additions & 5 deletions h2o-algos/src/main/java/hex/knn/KNNDistanceTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public class KNNDistanceTask extends MRTask<KNNDistanceTask> {
public byte _idColumnType;

/**
*
* Calculate distances for a particular chunk
*/
public KNNDistanceTask(int k, Chunk[] query, KNNDistance distance, int idIndex, String idColumn, byte idType, int responseIndex, String responseColumn){
this._k = k;
Expand Down Expand Up @@ -70,9 +70,13 @@ public void reduce(KNNDistanceTask mrt) {
this._topNNeighboursMaps.reduce(inputMap);
}

/**
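* Fill the given vecs with the data collected in the top-N neighbour maps.
* @param vecs vecs to fill; the id column is expected at index 0
* @return the same array of vecs, filled with the calculated data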
*/
public Vec[] fillVecs(Vec[] vecs){
for (int i = 0; i < vecs[0].length(); i++) {
// the id is at index 0 in vecs
String id = _idColumnType == Vec.T_STR ? vecs[0].stringAt(i) : String.valueOf(vecs[0].at8(i));
TopNTreeMap<KNNKey, Object> topNMap = _topNNeighboursMaps.get(id);
Iterator<KNNKey> distances = topNMap.keySet().stream().iterator();
Expand All @@ -92,6 +96,10 @@ public Vec[] fillVecs(Vec[] vecs){
return vecs;
}

/**
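* Generate the output frame with the calculated distances.
* @return frame whose first column is the id column, followed by k "dist_" columns, k neighbour id columns and k neighbour response columns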
*/
public Frame outputFrame() {
int newVecsSize = _k*3+1;
Vec[] vecs = new Vec[newVecsSize];
Expand All @@ -108,9 +116,10 @@ public Frame outputFrame() {
vecs[0] = id;
names[0] = _idColumn;
for (int i = 1; i < _k+1; i++) {
names[i] = "dist_"+i;
names[_k+i] = _idColumn+"_"+i;
names[2*_k+i] = _responseColumn+"_"+i;
// names of columns
names[i] = "dist_"+i; // this could be customized
names[_k+i] = _idColumn+"_"+i; // this could be customized
names[2*_k+i] = _responseColumn+"_"+i; // this could be customized
vecs[i] = id.makeZero();
vecs[i] = vecs[i].toNumericVec();
vecs[_k+i] = id.makeZero();
Expand Down
4 changes: 3 additions & 1 deletion h2o-algos/src/main/java/hex/knn/KNNScoringTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public class KNNScoringTask extends MRTask<KNNScoringTask> {
public int _domainSize;

/**
*
* Go through the whole input frame to find the k nearest distances and score based on them.
*/
public KNNScoringTask(double[] query, int k, int domainSize, KNNDistance distance, int idIndex, byte idType, int responseIndex){
this._k = k;
Expand Down Expand Up @@ -62,11 +62,13 @@ public double[] score(){
for (int value: _distancesMap.values()){
scores[value+1]++;
}
// normalize the result score by _k
for (int i = 1; i < _domainSize+1; i++) {
if(scores[i] != 0) {
scores[i] = scores[i]/_k;
}
}
// decide the class by the max score
scores[0] = ArrayUtils.maxIndex(scores)-1;
return scores;
}
Expand Down

0 comments on commit a8d2ae9

Please sign in to comment.