Skip to content

Commit

Permalink
introduce templated LM_distributable function
Browse files Browse the repository at this point in the history
  • Loading branch information
KrisThielemans committed May 11, 2024
1 parent 927ab41 commit 24f369d
Showing 1 changed file with 98 additions and 71 deletions.
169 changes: 98 additions & 71 deletions src/recon_buildblock/distributable.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -595,75 +595,18 @@ distributable_computation(const shared_ptr<ForwardProjectorByBin>& forward_proje
% wall_clock_timer.value());
}

constexpr float max_quotient = 10000.F;

inline void
LM_gradient(DiscretisedDensity<3, float>& output_image,
const Bin& measured_bin,
const DiscretisedDensity<3, float>& input_image,
const ProjMatrixElemsForOneBin& row,
const float add_term)
{
Bin fwd_bin = measured_bin;
fwd_bin.set_bin_value(0.0f);
row.forward_project(fwd_bin, input_image);

if (add_term)
fwd_bin.set_bin_value(fwd_bin.get_bin_value() + add_term);

float measured_div_fwd = 0.F;
if (measured_bin.get_bin_value() <= max_quotient * fwd_bin.get_bin_value())
measured_div_fwd = measured_bin.get_bin_value() / fwd_bin.get_bin_value();
else
return;

fwd_bin.set_bin_value(measured_div_fwd);
row.back_project(output_image, fwd_bin);
}

/* Hessian
\sum_e A_e^t (y_e/(A_e lambda+ c)^2 A_e x)
*/
inline void
LM_Hessian(DiscretisedDensity<3, float>& output_image,
const Bin& measured_bin,
const DiscretisedDensity<3, float>& input_image,
const DiscretisedDensity<3, float>& rhs,
const ProjMatrixElemsForOneBin& row,
const float add_term)
{
Bin fwd_bin = measured_bin;
fwd_bin.set_bin_value(0.0f);
row.forward_project(fwd_bin, input_image);

if (add_term)
fwd_bin.set_bin_value(fwd_bin.get_bin_value() + add_term);

if (measured_bin.get_bin_value() > max_quotient * fwd_bin.get_bin_value())
return;
float measured_div_fwd2 = measured_bin.get_bin_value() / square(fwd_bin.get_bin_value());

fwd_bin.set_bin_value(0.0f);
row.forward_project(fwd_bin, rhs);

if (fwd_bin.get_bin_value() == 0)
return;

fwd_bin.set_bin_value(measured_div_fwd2 * fwd_bin.get_bin_value());
row.back_project(output_image, fwd_bin);
}

template <typename CallBackT>
void
LM_distributable_computation(const shared_ptr<ProjMatrixByBin> PM_sptr,
const shared_ptr<ProjDataInfo>& proj_data_info_sptr,
DiscretisedDensity<3, float>* output_image_ptr,
const DiscretisedDensity<3, float>* input_image_ptr,
const std::vector<BinAndCorr>& record_ptr,
const int subset_num,
const int num_subsets,
const bool has_add,
const bool accumulate)
LM_distributable_computation_template(const shared_ptr<ProjMatrixByBin> PM_sptr,
const shared_ptr<ProjDataInfo>& proj_data_info_sptr,
DiscretisedDensity<3, float>* output_image_ptr,
const DiscretisedDensity<3, float>* input_image_ptr,
const std::vector<BinAndCorr>& record_ptr,
const int subset_num,
const int num_subsets,
const bool has_add,
const bool accumulate,
CallBackT&& call_back)
{

CPUTimer CPU_timer;
Expand Down Expand Up @@ -724,7 +667,7 @@ LM_distributable_computation(const shared_ptr<ProjMatrixByBin> PM_sptr,
local_output_image_sptrs[thread_num].reset(output_image_ptr->get_empty_copy());
}

Bin measured_bin = record_ptr.at(ievent).my_bin;
const Bin& measured_bin = record_ptr.at(ievent).my_bin;

if (num_subsets > 1)
{
Expand All @@ -739,8 +682,11 @@ LM_distributable_computation(const shared_ptr<ProjMatrixByBin> PM_sptr,
}

PM_sptr->get_proj_matrix_elems_for_one_bin(local_row[thread_num], measured_bin);
LM_gradient(*local_output_image_sptrs[thread_num], measured_bin, *input_image_ptr,
local_row[thread_num], has_add ? record_ptr.at(ievent).my_corr : 0);
call_back(*local_output_image_sptrs[thread_num],
measured_bin,
*input_image_ptr,
local_row[thread_num],
has_add ? record_ptr.at(ievent).my_corr : 0);
}
}
#ifdef STIR_OPENMP
Expand All @@ -760,4 +706,85 @@ LM_distributable_computation(const shared_ptr<ProjMatrixByBin> PM_sptr,
% wall_clock_timer.value());
}

constexpr float max_quotient = 10000.F;

inline void
LM_gradient(DiscretisedDensity<3, float>& output_image,
const Bin& measured_bin,
const DiscretisedDensity<3, float>& input_image,
const ProjMatrixElemsForOneBin& row,
const float add_term)
{
Bin fwd_bin = measured_bin;
fwd_bin.set_bin_value(0.0f);
row.forward_project(fwd_bin, input_image);

if (add_term)
fwd_bin.set_bin_value(fwd_bin.get_bin_value() + add_term);

float measured_div_fwd = 0.F;
if (measured_bin.get_bin_value() <= max_quotient * fwd_bin.get_bin_value())
measured_div_fwd = measured_bin.get_bin_value() / fwd_bin.get_bin_value();
else
return;

fwd_bin.set_bin_value(measured_div_fwd);
row.back_project(output_image, fwd_bin);
}

/* Hessian
\sum_e A_e^t (y_e/(A_e lambda+ c)^2 A_e x)
*/
inline void
LM_Hessian(DiscretisedDensity<3, float>& output_image,
const Bin& measured_bin,
const DiscretisedDensity<3, float>& input_image,
const DiscretisedDensity<3, float>& rhs,
const ProjMatrixElemsForOneBin& row,
const float add_term)
{
Bin fwd_bin = measured_bin;
fwd_bin.set_bin_value(0.0f);
row.forward_project(fwd_bin, input_image);

if (add_term)
fwd_bin.set_bin_value(fwd_bin.get_bin_value() + add_term);

if (measured_bin.get_bin_value() > max_quotient * fwd_bin.get_bin_value())
return;
float measured_div_fwd2 = measured_bin.get_bin_value() / square(fwd_bin.get_bin_value());

fwd_bin.set_bin_value(0.0f);
row.forward_project(fwd_bin, rhs);

if (fwd_bin.get_bin_value() == 0)
return;

fwd_bin.set_bin_value(measured_div_fwd2 * fwd_bin.get_bin_value());
row.back_project(output_image, fwd_bin);
}

void
LM_distributable_computation(const shared_ptr<ProjMatrixByBin> PM_sptr,
const shared_ptr<ProjDataInfo>& proj_data_info_sptr,
DiscretisedDensity<3, float>* output_image_ptr,
const DiscretisedDensity<3, float>* input_image_ptr,
const std::vector<BinAndCorr>& record_ptr,
const int subset_num,
const int num_subsets,
const bool has_add,
const bool accumulate)
{
LM_distributable_computation_template(PM_sptr,
proj_data_info_sptr,
output_image_ptr,
input_image_ptr,
record_ptr,
subset_num,
num_subsets,
has_add,
accumulate,
LM_gradient);
}
END_NAMESPACE_STIR

0 comments on commit 24f369d

Please sign in to comment.