Skip to content

Commit

Permalink
Fix RKNN search and add more robust unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jlblancoc committed Nov 27, 2023
1 parent 78dac97 commit 2b090a8
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 11 deletions.
9 changes: 4 additions & 5 deletions include/nanoflann.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,7 @@ class KNNResultSet

CountType size() const { return count; }
bool empty() const { return count == 0; }

bool full() const { return count == capacity; }
bool full() const { return count == capacity; }

/**
* Called during search to add an element matching the criteria.
Expand Down Expand Up @@ -278,8 +277,8 @@ class RKNNResultSet
}

CountType size() const { return count; }

bool full() const { return count == capacity; }
bool empty() const { return count == 0; }
bool full() const { return count == capacity; }

/**
* Called during search to add an element matching the criteria.
Expand All @@ -288,7 +287,7 @@ class RKNNResultSet
*/
bool addPoint(DistanceType dist, IndexType index)
{
if (dist > maximumSearchDistanceSquared) return false; // stop search
if (dist > maximumSearchDistanceSquared) return true;

CountType i;
for (i = count; i > 0; --i)
Expand Down
114 changes: 108 additions & 6 deletions tests/test_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ void L2_vs_bruteforce_test(
const auto nFound = resultSet.size();

// Brute force neighbors:
std::multimap<double /*dist*/, size_t /*idx*/> bf_nn;
std::multimap<NUM /*dist*/, size_t /*idx*/> bf_nn;
{
for (size_t i = 0; i < nSamples; i++)
{
Expand All @@ -163,18 +163,99 @@ void L2_vs_bruteforce_test(
}
}

// Keep knn closest ones:
while (bf_nn.size() > num_results) bf_nn.erase((++bf_nn.rbegin()).base());
// Keep bruteforce solutions indexed by idx instead of distances,
// to handle correctly almost or exactly coindicing distances for >=2 NN:
std::map<size_t, NUM> bf_idx2dist;
for (const auto& kv : bf_nn) bf_idx2dist[kv.second] = kv.first;

// Compare:
EXPECT_EQ(bf_nn.size(), nFound);
if (!bf_nn.empty())
{
auto it = bf_nn.begin();
for (size_t i = 0; i < nFound; ++i, ++it)
{
EXPECT_EQ(it->second, ret_indexes[i]);
// Distances must be in exact order:
EXPECT_NEAR(it->first, out_dists_sqr[i], 1e-3);

// indices may be not in the (rare) case of a tie:
EXPECT_NEAR(bf_idx2dist.at(ret_indexes[i]), out_dists_sqr[i], 1e-3)
<< "For: numToSearch=" << numToSearch
<< " out_dists_sqr[i]=" << out_dists_sqr[i] << "\n";
}
}
}

template <typename NUM>
void rknn_L2_vs_bruteforce_test(
const size_t nSamples, const size_t DIM, const size_t numToSearch,
const NUM maxRadiusSqr)
{
std::vector<std::vector<NUM>> samples;

const NUM max_range = NUM(20.0);

// Generate points:
generateRandomPointCloud(samples, nSamples, DIM, max_range);

// Query point:
std::vector<NUM> query_pt(DIM);
for (size_t d = 0; d < DIM; d++)
query_pt[d] = static_cast<NUM>(max_range * (rand() % 1000) / (1000.0));

// construct a kd-tree index:
// Dimensionality set at run-time (default: L2)
// ------------------------------------------------------------
typedef KDTreeVectorOfVectorsAdaptor<std::vector<std::vector<NUM>>, NUM>
my_kd_tree_t;

my_kd_tree_t mat_index(DIM /*dim*/, samples, 10 /* max leaf */);

// do a knn search
const size_t num_results = numToSearch;
std::vector<size_t> ret_indexes(num_results);
std::vector<NUM> out_dists_sqr(num_results);

nanoflann::RKNNResultSet<NUM> resultSet(num_results, maxRadiusSqr);

resultSet.init(&ret_indexes[0], &out_dists_sqr[0]);
mat_index.index->findNeighbors(resultSet, &query_pt[0]);

const auto nFound = resultSet.size();

// Brute force neighbors:
std::multimap<NUM /*dist*/, size_t /*idx*/> bf_nn;
{
for (size_t i = 0; i < nSamples; i++)
{
double dist = 0.0;
for (size_t d = 0; d < DIM; d++)
dist += (query_pt[d] - samples[i][d]) *
(query_pt[d] - samples[i][d]);

if (dist <= maxRadiusSqr) bf_nn.emplace(dist, i);
}
}

// Keep bruteforce solutions indexed by idx instead of distances,
// to handle correctly almost or exactly coindicing distances for >=2 NN:
std::map<size_t, NUM> bf_idx2dist;
for (const auto& kv : bf_nn) bf_idx2dist[kv.second] = kv.first;

// Compare:
if (!bf_nn.empty())
{
auto it = bf_nn.begin();
for (size_t i = 0; i < nFound; ++i, ++it)
{
// Distances must be in exact order:
EXPECT_NEAR(it->first, out_dists_sqr[i], 1e-3)
<< "For: numToSearch=" << numToSearch
<< " out_dists_sqr[i]=" << out_dists_sqr[i] << "\n";

// indices may be not in the (rare) case of a tie:
EXPECT_NEAR(bf_idx2dist.at(ret_indexes[i]), out_dists_sqr[i], 1e-3)
<< "For: numToSearch=" << numToSearch
<< " out_dists_sqr[i]=" << out_dists_sqr[i] << "\n";
}
}
}
Expand Down Expand Up @@ -524,7 +605,7 @@ TEST(kdtree, L2_vs_bruteforce)
srand(static_cast<unsigned int>(time(nullptr)));
for (int knn = 1; knn < 20; knn += 3)
{
for (int i = 0; i < 10; i++)
for (int i = 0; i < 500; i++)
{
L2_vs_bruteforce_test<float>(100, 2, knn);
L2_vs_bruteforce_test<float>(100, 3, knn);
Expand All @@ -537,6 +618,27 @@ TEST(kdtree, L2_vs_bruteforce)
}
}

TEST(kdtree, L2_vs_bruteforce_rknn)
{
srand(static_cast<unsigned int>(time(nullptr)));
for (int knn = 1; knn < 20; knn += 3)
{
for (int r = 1; r < 5; r++)
{
for (int i = 0; i < 100; i++)
{
rknn_L2_vs_bruteforce_test<float>(100, 2, knn, 9.0f * r * r);
rknn_L2_vs_bruteforce_test<float>(100, 3, knn, 9.0f * r * r);
rknn_L2_vs_bruteforce_test<float>(100, 7, knn, 9.0f * r * r);

rknn_L2_vs_bruteforce_test<double>(100, 2, knn, 9.0 * r * r);
rknn_L2_vs_bruteforce_test<double>(100, 3, knn, 9.0 * r * r);
rknn_L2_vs_bruteforce_test<double>(100, 7, knn, 9.0 * r * r);
}
}
}
}

TEST(kdtree, SO3_vs_bruteforce)
{
srand(static_cast<unsigned int>(time(nullptr)));
Expand Down

0 comments on commit 2b090a8

Please sign in to comment.