diff --git a/h2o-algos/src/main/java/hex/knn/KNN.java b/h2o-algos/src/main/java/hex/knn/KNN.java index dcfb2cf46dc7..ae2ede5def64 100644 --- a/h2o-algos/src/main/java/hex/knn/KNN.java +++ b/h2o-algos/src/main/java/hex/knn/KNN.java @@ -2,6 +2,7 @@ import hex.*; import water.DKV; +import water.Key; import water.Scope; import water.fvec.Chunk; import water.fvec.Frame; @@ -47,7 +48,7 @@ class KNNDriver extends Driver { @Override public void computeImpl() { KNNModel model = null; - Frame result = null; + Frame result = new Frame(Key.make("KNN_distances")); Frame tmpResult = null; try { init(true); // Initialize parameters @@ -64,6 +65,7 @@ public void computeImpl() { int responseColumnIndex = train.find(responseColumn); int nChunks = train.anyVec().nChunks(); int nCols = train.numCols(); + // split data into chunks to calculate distances in parallel task for (int i = 0; i < nChunks; i++) { Chunk[] query = new Chunk[nCols]; for (int j = 0; j < nCols; j++) { @@ -71,13 +73,9 @@ public void computeImpl() { } KNNDistanceTask task = new KNNDistanceTask(_parms._k, query, _parms._distance, idColumnIndex, idColumn, idType, responseColumnIndex, responseColumn); tmpResult = task.doAll(train).outputFrame(); - if (result == null) { - result = tmpResult; - } else { - result = result.add(tmpResult); - } + // merge result from a chunk + result = result.add(tmpResult); } - result = result.deepCopy("KNN_distances"); DKV.put(result._key, result); model._output.setDistancesKey(result._key); Scope.untrack(result); diff --git a/h2o-algos/src/main/java/hex/knn/KNNDistanceTask.java b/h2o-algos/src/main/java/hex/knn/KNNDistanceTask.java index 13bc78ee215f..31bca772ede3 100644 --- a/h2o-algos/src/main/java/hex/knn/KNNDistanceTask.java +++ b/h2o-algos/src/main/java/hex/knn/KNNDistanceTask.java @@ -21,7 +21,7 @@ public class KNNDistanceTask extends MRTask { public byte _idColumnType; /** - * + * Calculate distances for a particular chunk */ public KNNDistanceTask(int k, Chunk[] query, KNNDistance 
distance, int idIndex, String idColumn, byte idType, int responseIndex, String responseColumn){ this._k = k; @@ -70,9 +70,13 @@ public void reduce(KNNDistanceTask mrt) { this._topNNeighboursMaps.reduce(inputMap); } + /** + * Get data from maps to Frame + * @param vecs + * @return filled array of vecs with calculated data + */ public Vec[] fillVecs(Vec[] vecs){ for (int i = 0; i < vecs[0].length(); i++) { - // id is on 0 index in vecs String id = _idColumnType == Vec.T_STR ? vecs[0].stringAt(i) : String.valueOf(vecs[0].at8(i)); TopNTreeMap topNMap = _topNNeighboursMaps.get(id); Iterator distances = topNMap.keySet().stream().iterator(); @@ -92,6 +96,10 @@ public Vec[] fillVecs(Vec[] vecs){ return vecs; } + /** + * Generate output frame with calculated distances. + * @return + */ public Frame outputFrame() { int newVecsSize = _k*3+1; Vec[] vecs = new Vec[newVecsSize]; @@ -108,9 +116,10 @@ public Frame outputFrame() { vecs[0] = id; names[0] = _idColumn; for (int i = 1; i < _k+1; i++) { - names[i] = "dist_"+i; - names[_k+i] = _idColumn+"_"+i; - names[2*_k+i] = _responseColumn+"_"+i; + // names of columns + names[i] = "dist_"+i; // this could be customized + names[_k+i] = _idColumn+"_"+i; // this could be customized + names[2*_k+i] = _responseColumn+"_"+i; // this could be customized vecs[i] = id.makeZero(); vecs[i] = vecs[i].toNumericVec(); vecs[_k+i] = id.makeZero(); diff --git a/h2o-algos/src/main/java/hex/knn/KNNScoringTask.java b/h2o-algos/src/main/java/hex/knn/KNNScoringTask.java index 7758dd6c2234..8838e95c288c 100644 --- a/h2o-algos/src/main/java/hex/knn/KNNScoringTask.java +++ b/h2o-algos/src/main/java/hex/knn/KNNScoringTask.java @@ -18,7 +18,7 @@ public class KNNScoringTask extends MRTask { public int _domainSize; /** - * + * Go through the whole input frame to find the k nearest distances and score based on them. 
*/ public KNNScoringTask(double[] query, int k, int domainSize, KNNDistance distance, int idIndex, byte idType, int responseIndex){ this._k = k; @@ -62,11 +62,13 @@ public double[] score(){ for (int value: _distancesMap.values()){ scores[value+1]++; } + // normalize the result score by _k for (int i = 1; i < _domainSize+1; i++) { if(scores[i] != 0) { scores[i] = scores[i]/_k; } } + // decide the class by the max score scores[0] = ArrayUtils.maxIndex(scores)-1; return scores; }