Skip to content

Commit

Permalink
minor fix and reordering
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexandr-Solovev committed Dec 3, 2024
1 parent 1bef482 commit 8d103c1
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 85 deletions.
23 changes: 19 additions & 4 deletions ...gorithms/k_nearest_neighbors/kdtree_knn_classification_predict_dense_default_batch_impl.i
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include "src/algorithms/k_nearest_neighbors/kdtree_knn_classification_model_impl.h"
#include "src/algorithms/k_nearest_neighbors/kdtree_knn_impl.i"
#include "src/algorithms/k_nearest_neighbors/knn_heap.h"
#include <iostream>

namespace daal
{
Expand Down Expand Up @@ -145,11 +146,27 @@ Status KNNClassificationPredictKernel<algorithmFpType, defaultDense, cpu>::compu

if (par3 == NULL) return Status(ErrorNullParameterNotSupported);

const Model * const model = static_cast<const Model *>(m);
const auto & kdTreeTable = *(model->impl()->getKDTreeTable());
const Model * const model = static_cast<const Model *>(m);
const KDTreeTable & kdTreeTable = *(model->impl()->getKDTreeTable());
const KDTreeNode * const nodes = static_cast<const KDTreeNode *>(kdTreeTable.getArray());
const size_t xRowCount = x->getNumberOfRows();

const algorithmFpType base = 2.0;
const algorithmFpType baseInPower = Math::sPowx(base, Math::sCeil(Math::sLog(base * xRowCount - 1) / Math::sLog(base)));
DAAL_ASSERT(baseInPower > 0)
const size_t maxKDTreeNodeCount = ((size_t)baseInPower * __KDTREE_MAX_NODE_COUNT_MULTIPLICATION_FACTOR) / __KDTREE_LEAF_BUCKET_SIZE + 1;
for (int index = 0; index < maxKDTreeNodeCount; index++)
{
const KDTreeNode & node = nodes[index];

std::cout << "Node Index: " << index << ", Dimension: " << node.dimension << ", Cut Point: " << node.cutPoint
<< ", Left Index: " << node.leftIndex << ", Right Index: " << node.rightIndex << std::endl;
}

const auto rootTreeNodeIndex = model->impl()->getRootNodeIndex();
const NumericTable & data = *(model->impl()->getData());
const NumericTable * labels = nullptr;

if (resultsToEvaluate != 0)
{
labels = model->impl()->getLabels().get();
Expand All @@ -164,8 +181,6 @@ Status KNNClassificationPredictKernel<algorithmFpType, defaultDense, cpu>::compu
}
const size_t heapSize = (iSize / 16 + 1) * 16;

const size_t xRowCount = x->getNumberOfRows();
const algorithmFpType base = 2.0;
const size_t expectedMaxDepth = (Math::sLog(xRowCount) / Math::sLog(base) + 1) * __KDTREE_DEPTH_MULTIPLICATION_FACTOR;
const size_t stackSize = Math::sPowx(base, Math::sCeil(Math::sLog(expectedMaxDepth) / Math::sLog(base)));
struct Local
Expand Down
Loading

0 comments on commit 8d103c1

Please sign in to comment.