From 2b090a823f3b5a02aeaa36c4b7a641ffd50b146b Mon Sep 17 00:00:00 2001 From: Jose Luis Blanco-Claraco Date: Mon, 27 Nov 2023 23:23:38 +0100 Subject: [PATCH] Fix RKNN search and add more robust unit tests --- include/nanoflann.hpp | 9 ++-- tests/test_main.cpp | 114 +++++++++++++++++++++++++++++++++++++++--- 2 files changed, 112 insertions(+), 11 deletions(-) diff --git a/include/nanoflann.hpp b/include/nanoflann.hpp index 0d207c5..64b7795 100644 --- a/include/nanoflann.hpp +++ b/include/nanoflann.hpp @@ -193,8 +193,7 @@ class KNNResultSet CountType size() const { return count; } bool empty() const { return count == 0; } - - bool full() const { return count == capacity; } + bool full() const { return count == capacity; } /** * Called during search to add an element matching the criteria. @@ -278,8 +277,8 @@ class RKNNResultSet } CountType size() const { return count; } - - bool full() const { return count == capacity; } + bool empty() const { return count == 0; } + bool full() const { return count == capacity; } /** * Called during search to add an element matching the criteria. @@ -288,7 +287,7 @@ class RKNNResultSet */ bool addPoint(DistanceType dist, IndexType index) { - if (dist > maximumSearchDistanceSquared) return false; // stop search + if (dist > maximumSearchDistanceSquared) return true; CountType i; for (i = count; i > 0; --i) diff --git a/tests/test_main.cpp b/tests/test_main.cpp index 316f4d5..66303d3 100644 --- a/tests/test_main.cpp +++ b/tests/test_main.cpp @@ -151,7 +151,7 @@ void L2_vs_bruteforce_test( const auto nFound = resultSet.size(); // Brute force neighbors: - std::multimap bf_nn; + std::multimap bf_nn; { for (size_t i = 0; i < nSamples; i++) { @@ -163,18 +163,99 @@ void L2_vs_bruteforce_test( } } - // Keep knn closest ones: - while (bf_nn.size() > num_results) bf_nn.erase((++bf_nn.rbegin()).base()); + // Keep bruteforce solutions indexed by idx instead of distances, + // to handle correctly almost or exactly coindicing distances for >=2 NN: + std::map bf_idx2dist; + for (const auto& kv : bf_nn) bf_idx2dist[kv.second] = kv.first; // Compare: - EXPECT_EQ(bf_nn.size(), nFound); if (!bf_nn.empty()) { auto it = bf_nn.begin(); for (size_t i = 0; i < nFound; ++i, ++it) { - EXPECT_EQ(it->second, ret_indexes[i]); + // Distances must be in exact order: EXPECT_NEAR(it->first, out_dists_sqr[i], 1e-3); + + // indices may be not in the (rare) case of a tie: + EXPECT_NEAR(bf_idx2dist.at(ret_indexes[i]), out_dists_sqr[i], 1e-3) + << "For: numToSearch=" << numToSearch + << " out_dists_sqr[i]=" << out_dists_sqr[i] << "\n"; + } + } +} + +template +void rknn_L2_vs_bruteforce_test( + const size_t nSamples, const size_t DIM, const size_t numToSearch, + const NUM maxRadiusSqr) +{ + std::vector> samples; + + const NUM max_range = NUM(20.0); + + // Generate points: + generateRandomPointCloud(samples, nSamples, DIM, max_range); + + // Query point: + std::vector query_pt(DIM); + for (size_t d = 0; d < DIM; d++) + query_pt[d] = static_cast(max_range * (rand() % 1000) / (1000.0)); + + // construct a kd-tree index: + // Dimensionality set at run-time (default: L2) + // ------------------------------------------------------------ + typedef KDTreeVectorOfVectorsAdaptor>, NUM> + my_kd_tree_t; + + my_kd_tree_t mat_index(DIM /*dim*/, samples, 10 /* max leaf */); + + // do a knn search + const size_t num_results = numToSearch; + std::vector ret_indexes(num_results); + std::vector out_dists_sqr(num_results); + + nanoflann::RKNNResultSet resultSet(num_results, maxRadiusSqr); + + resultSet.init(&ret_indexes[0], &out_dists_sqr[0]); + mat_index.index->findNeighbors(resultSet, &query_pt[0]); + + const auto nFound = resultSet.size(); + + // Brute force neighbors: + std::multimap bf_nn; + { + for (size_t i = 0; i < nSamples; i++) + { + double dist = 0.0; + for (size_t d = 0; d < DIM; d++) + dist += (query_pt[d] - samples[i][d]) * + (query_pt[d] - samples[i][d]); + + if (dist <= maxRadiusSqr) bf_nn.emplace(dist, i); + } + } + + // Keep bruteforce solutions indexed by idx instead of distances, + // to handle correctly almost or exactly coindicing distances for >=2 NN: + std::map bf_idx2dist; + for (const auto& kv : bf_nn) bf_idx2dist[kv.second] = kv.first; + + // Compare: + if (!bf_nn.empty()) + { + auto it = bf_nn.begin(); + for (size_t i = 0; i < nFound; ++i, ++it) + { + // Distances must be in exact order: + EXPECT_NEAR(it->first, out_dists_sqr[i], 1e-3) + << "For: numToSearch=" << numToSearch + << " out_dists_sqr[i]=" << out_dists_sqr[i] << "\n"; + + // indices may be not in the (rare) case of a tie: + EXPECT_NEAR(bf_idx2dist.at(ret_indexes[i]), out_dists_sqr[i], 1e-3) + << "For: numToSearch=" << numToSearch + << " out_dists_sqr[i]=" << out_dists_sqr[i] << "\n"; } } } @@ -524,7 +605,7 @@ TEST(kdtree, L2_vs_bruteforce) srand(static_cast(time(nullptr))); for (int knn = 1; knn < 20; knn += 3) { - for (int i = 0; i < 10; i++) + for (int i = 0; i < 500; i++) { L2_vs_bruteforce_test(100, 2, knn); L2_vs_bruteforce_test(100, 3, knn); @@ -537,6 +618,27 @@ TEST(kdtree, L2_vs_bruteforce) } } +TEST(kdtree, L2_vs_bruteforce_rknn) +{ + srand(static_cast(time(nullptr))); + for (int knn = 1; knn < 20; knn += 3) + { + for (int r = 1; r < 5; r++) + { + for (int i = 0; i < 100; i++) + { + rknn_L2_vs_bruteforce_test(100, 2, knn, 9.0f * r * r); + rknn_L2_vs_bruteforce_test(100, 3, knn, 9.0f * r * r); + rknn_L2_vs_bruteforce_test(100, 7, knn, 9.0f * r * r); + + rknn_L2_vs_bruteforce_test(100, 2, knn, 9.0 * r * r); + rknn_L2_vs_bruteforce_test(100, 3, knn, 9.0 * r * r); + rknn_L2_vs_bruteforce_test(100, 7, knn, 9.0 * r * r); + } + } + } +} + TEST(kdtree, SO3_vs_bruteforce) { srand(static_cast(time(nullptr)));