Skip to content

Commit

Permalink
extra tests for the Relative Difference Prior
Browse files Browse the repository at this point in the history
split the tests for the different priors in different classes, such
that we can have specific tests for each. Currently, only added some for the RDP,
including a Lipschitz test and limit for large values.
  • Loading branch information
KrisThielemans committed Apr 27, 2024
1 parent 26393f4 commit d669022
Showing 1 changed file with 50 additions and 3 deletions.
53 changes: 50 additions & 3 deletions src/recon_test/test_priors.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "stir/Verbosity.h"
#include "stir/Succeeded.h"
#include "stir/num_threads.h"
#include "stir/numerics/norm.h"
#include <iostream>
#include <memory>
#include <boost/random/uniform_01.hpp>
Expand Down Expand Up @@ -475,9 +476,53 @@ class RelativeDifferencePriorTests : public GeneralisedPriorTests
{
public:
using GeneralisedPriorTests::GeneralisedPriorTests;
void run_specific_tests(const std::string& test_name,
RelativeDifferencePrior<float>& rdp,
const shared_ptr<target_type>& target_sptr);
void run_tests() override;
};

void
RelativeDifferencePriorTests::run_specific_tests(const std::string& test_name,
RelativeDifferencePrior<float>& rdp,
const shared_ptr<DiscretisedDensity<3, float>>& target_sptr)
{
std::cerr << "----- test " << test_name << " --> RDP gradient limit tests\n";
shared_ptr<target_type> grad_sptr(target_sptr->get_empty_copy());
const Array<3, float> weights = rdp.get_weights() * rdp.get_penalisation_factor();
const bool do_kappa = rdp.get_kappa_sptr() != 0;
// strictly speaking, we should be checking product of the kappas in a neighbourhood, but they usually very smoothly. In any
// case, this will give an upper-bound
const double kappa2_max = do_kappa ? square(rdp.get_kappa_sptr()->find_max()) : 1.;
const auto weights_sum = weights.sum() * kappa2_max;

if (rdp.get_epsilon() > 0)
{
const double grad_Lipschitz = 4 * weights_sum * kappa2_max / rdp.get_epsilon();

rdp.compute_gradient(*grad_sptr, *target_sptr);
check_if_less(norm(grad_sptr->begin_all(), grad_sptr->end_all()),
grad_Lipschitz * norm(target_sptr->begin_all(), target_sptr->end_all()) * 1.001F,
"gradient Lipschitz with x = input_image, y = 0");
}

shared_ptr<target_type> delta_sptr(target_sptr->get_empty_copy());
delta_sptr->fill(0.F);

auto idx = make_coordinate(1, 1, 1);
(*delta_sptr)[idx] = 1E10F * rdp.get_epsilon();
rdp.compute_gradient(*grad_sptr, *delta_sptr);
check_if_less((*grad_sptr)[idx], weights_sum / (1 + rdp.get_gamma()), "RDP gradient large limit");
(*delta_sptr)[idx] = 1E20F * rdp.get_epsilon();
rdp.compute_gradient(*grad_sptr, *delta_sptr);
check_if_less((*grad_sptr)[idx], weights_sum / (1 + rdp.get_gamma()), "RDP gradient very large limit");
// check at boundary
idx = make_coordinate(0, 0, 0);
(*delta_sptr)[idx] = 1E10F * rdp.get_epsilon();
rdp.compute_gradient(*grad_sptr, *delta_sptr);
check_if_less((*grad_sptr)[idx], weights_sum / (1 + rdp.get_gamma()), "RDP gradient large limit at boundary");
}

void
RelativeDifferencePriorTests::run_tests()
{
Expand All @@ -488,15 +533,18 @@ RelativeDifferencePriorTests::run_tests()
{
// gamma is default and epsilon is 0.0
RelativeDifferencePrior<float> objective_function(false, 1.F, 2.F, 0.F);
this->configure_prior_tests(true, true, false); // RDP, with epsilon = 0.0, will fail the numerical Hessian test
this->configure_prior_tests(
true, true, false); // RDP, with epsilon = 0.0, will fail the numerical Hessian test (it can become infinity)
this->run_tests_for_objective_function("RDP_no_kappa_no_eps", objective_function, density_sptr);
this->run_specific_tests("RDP_specific_no_kappa_no_eps", objective_function, density_sptr);
}
std::cerr << "\n\nTests for Relative Difference Prior with epsilon = 0.1\n";
{
// gamma is default and epsilon is "small"
RelativeDifferencePrior<float> objective_function(false, 1.F, 2.F, 0.1F);
this->configure_prior_tests(true, true, true); // With a large enough epsilon the RDP Hessian numerical test will pass
this->run_tests_for_objective_function("RDP_no_kappa_with_eps", objective_function, density_sptr);
this->run_specific_tests("RDP_specific_no_kappa_with_eps", objective_function, density_sptr);
}
}

Expand Down Expand Up @@ -584,11 +632,10 @@ main(int argc, char** argv)
everything_ok = everything_ok && tests.is_everything_ok();
}
{
PLSPriorTests tests(argc > 1 ? argv[1] : nullptr);
LogCoshPriorTests tests(argc > 1 ? argv[1] : nullptr);
tests.run_tests();
everything_ok = everything_ok && tests.is_everything_ok();
}


return everything_ok ? EXIT_SUCCESS : EXIT_FAILURE;
}

0 comments on commit d669022

Please sign in to comment.