Skip to content

Commit

Permalink
Larger coverage in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jlblancoc committed Nov 27, 2023
1 parent 4065aae commit 78dac97
Showing 1 changed file with 39 additions and 27 deletions.
66 changes: 39 additions & 27 deletions tests/test_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <cmath> // for abs()
#include <cstdlib>
#include <iostream>
#include <map>
#include <nanoflann.hpp>

#include "../examples/utils.h"
Expand Down Expand Up @@ -114,7 +115,8 @@ void generateRandomPointCloud(
}

template <typename NUM>
void L2_vs_bruteforce_test(const size_t nSamples, const size_t DIM)
void L2_vs_bruteforce_test(
const size_t nSamples, const size_t DIM, const size_t numToSearch)
{
std::vector<std::vector<NUM>> samples;

Expand All @@ -137,7 +139,7 @@ void L2_vs_bruteforce_test(const size_t nSamples, const size_t DIM)
my_kd_tree_t mat_index(DIM /*dim*/, samples, 10 /* max leaf */);

// do a knn search
const size_t num_results = 1;
const size_t num_results = numToSearch;
std::vector<size_t> ret_indexes(num_results);
std::vector<NUM> out_dists_sqr(num_results);

Expand All @@ -146,28 +148,35 @@ void L2_vs_bruteforce_test(const size_t nSamples, const size_t DIM)
resultSet.init(&ret_indexes[0], &out_dists_sqr[0]);
mat_index.index->findNeighbors(resultSet, &query_pt[0]);

// Brute force:
double min_dist_L2 = std::numeric_limits<double>::max();
size_t min_idx = std::numeric_limits<size_t>::max();
const auto nFound = resultSet.size();

// Brute force neighbors:
std::multimap<double /*dist*/, size_t /*idx*/> bf_nn;
{
for (size_t i = 0; i < nSamples; i++)
{
double dist = 0.0;
for (size_t d = 0; d < DIM; d++)
dist += (query_pt[d] - samples[i][d]) *
(query_pt[d] - samples[i][d]);
if (dist < min_dist_L2)
{
min_dist_L2 = dist;
min_idx = i;
}
bf_nn.emplace(dist, i);
}
ASSERT_TRUE(min_idx != std::numeric_limits<size_t>::max());
}

// Keep knn closest ones:
while (bf_nn.size() > num_results) bf_nn.erase((++bf_nn.rbegin()).base());

// Compare:
EXPECT_EQ(min_idx, ret_indexes[0]);
EXPECT_NEAR(min_dist_L2, out_dists_sqr[0], 1e-3);
EXPECT_EQ(bf_nn.size(), nFound);
if (!bf_nn.empty())
{
auto it = bf_nn.begin();
for (size_t i = 0; i < nFound; ++i, ++it)
{
EXPECT_EQ(it->second, ret_indexes[i]);
EXPECT_NEAR(it->first, out_dists_sqr[i], 1e-3);
}
}
}

template <typename NUM>
Expand Down Expand Up @@ -385,8 +394,8 @@ void L2_dynamic_vs_bruteforce_test(const size_t nSamples)
}

template <typename NUM>
void L2_concurrent_build_vs_bruteforce_test(const size_t nSamples,
const size_t DIM)
void L2_concurrent_build_vs_bruteforce_test(
const size_t nSamples, const size_t DIM)
{
std::vector<std::vector<NUM>> samples;

Expand All @@ -406,7 +415,8 @@ void L2_concurrent_build_vs_bruteforce_test(const size_t nSamples,
typedef KDTreeVectorOfVectorsAdaptor<std::vector<std::vector<NUM>>, NUM>
my_kd_tree_t;

my_kd_tree_t mat_index(DIM /*dim*/, samples, 10 /* max leaf */, 0 /* concurrent build */);
my_kd_tree_t mat_index(
DIM /*dim*/, samples, 10 /* max leaf */, 0 /* concurrent build */);

// do a knn search
const size_t num_results = 1;
Expand Down Expand Up @@ -443,8 +453,7 @@ void L2_concurrent_build_vs_bruteforce_test(const size_t nSamples,
}

template <typename NUM>
void L2_concurrent_build_vs_L2_test(const size_t nSamples,
const size_t DIM)
void L2_concurrent_build_vs_L2_test(const size_t nSamples, const size_t DIM)
{
std::vector<std::vector<NUM>> samples;

Expand All @@ -464,8 +473,8 @@ void L2_concurrent_build_vs_L2_test(const size_t nSamples,
typedef KDTreeVectorOfVectorsAdaptor<std::vector<std::vector<NUM>>, NUM>
my_kd_tree_t;

my_kd_tree_t mat_index_concurrent_build(DIM /*dim*/, samples,
10 /* max leaf */, 0 /* concurrent build */);
my_kd_tree_t mat_index_concurrent_build(
DIM /*dim*/, samples, 10 /* max leaf */, 0 /* concurrent build */);
my_kd_tree_t mat_index(DIM /*dim*/, samples, 10 /* max leaf */);

// Compare:
Expand Down Expand Up @@ -513,15 +522,18 @@ TEST(kdtree, robust_empty_tree)
TEST(kdtree, L2_vs_bruteforce)
{
srand(static_cast<unsigned int>(time(nullptr)));
for (int i = 0; i < 10; i++)
for (int knn = 1; knn < 20; knn += 3)
{
L2_vs_bruteforce_test<float>(100, 2);
L2_vs_bruteforce_test<float>(100, 3);
L2_vs_bruteforce_test<float>(100, 7);
for (int i = 0; i < 10; i++)
{
L2_vs_bruteforce_test<float>(100, 2, knn);
L2_vs_bruteforce_test<float>(100, 3, knn);
L2_vs_bruteforce_test<float>(100, 7, knn);

L2_vs_bruteforce_test<double>(100, 2);
L2_vs_bruteforce_test<double>(100, 3);
L2_vs_bruteforce_test<double>(100, 7);
L2_vs_bruteforce_test<double>(100, 2, knn);
L2_vs_bruteforce_test<double>(100, 3, knn);
L2_vs_bruteforce_test<double>(100, 7, knn);
}
}
}

Expand Down

0 comments on commit 78dac97

Please sign in to comment.