Skip to content

Commit

Permalink
WIP: always use LM_distributable_computation for gradient
Browse files Browse the repository at this point in the history
currently still failing
  • Loading branch information
KrisThielemans committed May 9, 2024
1 parent 2270c16 commit b168163
Showing 1 changed file with 23 additions and 134 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,11 @@ PoissonLogLikelihoodWithLinearModelForMeanAndListModeDataWithProjMatrixByBin<Tar
this->cache_lm_file = true;
return cache_listmode_file();
}
else
{
this->cache_lm_file = false;
this->cache_size = 4000; // 1000000;
}

return Succeeded::yes;
}
Expand Down Expand Up @@ -721,145 +726,29 @@ PoissonLogLikelihoodWithLinearModelForMeanAndListModeDataWithProjMatrixByBin<
"actual_compute_subset_gradient_without_penalty(): cannot subtract subset sensitivity because "
"use_subset_sensitivities is false. This will result in an error in the gradient computation.");

if (this->cache_lm_file)
unsigned int icache = 0;
bool stop = false;
while (!stop)
{
for (unsigned int icache = 0; icache < this->num_cache_files; ++icache)
if (this->cache_lm_file)
{
load_listmode_cache_file(icache);
LM_distributable_computation(this->PM_sptr,
this->proj_data_info_sptr,
&gradient,
&current_estimate,
record_cache,
subset_num,
this->num_subsets,
this->has_add,
/* accumulate = */ icache != 0);
stop = (icache + 1) == this->num_cache_files;
this->load_listmode_cache_file(icache);
}
}
else
{
// list_mode_data_sptr->set_get_position(start_time);
// TODO implement function that will do this for a random time

this->list_mode_data_sptr->reset();

const double start_time = this->frame_defs.get_start_time(this->current_frame_num);
const double end_time = this->frame_defs.get_end_time(this->current_frame_num);

long num_used_events = 0;
const float max_quotient = 10000.F;

double current_time = 0.;

// need get_bin_value(), so currently need to cast to either of the 2 below (ugly. TODO)
shared_ptr<ProjDataFromStream> add_from_stream_sptr;
shared_ptr<ProjDataInMemory> add_in_mem_sptr;

if (!is_null_ptr(this->additive_proj_data_sptr))
else
{
add_in_mem_sptr = std::dynamic_pointer_cast<ProjDataInMemory>(this->additive_proj_data_sptr);
if (is_null_ptr(add_in_mem_sptr))
{
add_from_stream_sptr = std::dynamic_pointer_cast<ProjDataFromStream>(this->additive_proj_data_sptr);
// TODO could create a ProjDataInMemory instead, but for now we give up.
if (is_null_ptr(add_from_stream_sptr))
error("Additive projection data is in unsupported file format. You need to create an Interfile copy. sorry.");
}
}

ProjMatrixElemsForOneBin proj_matrix_row;
gradient.fill(0);
shared_ptr<ListRecord> record_sptr = this->list_mode_data_sptr->get_empty_record_sptr();
ListRecord& record = *record_sptr;

VectorWithOffset<ListModeData::SavedPosition> frame_start_positions(1, static_cast<int>(this->frame_defs.get_num_frames()));

while (true)
{

if (this->list_mode_data_sptr->get_next_record(record) == Succeeded::no)
{
info("End of listmode file!", 2);
break; // get out of while loop
}

if (record.is_time())
{
current_time = record.time().get_time_in_secs();
if (this->do_time_frame && current_time >= end_time)
{
break; // get out of while loop
}
}

if (current_time < start_time)
continue;

if (record.is_event() && record.event().is_prompt())
{
Bin measured_bin;
measured_bin.set_bin_value(1.0f);
record.event().get_bin(measured_bin, *this->proj_data_info_sptr);

if (measured_bin.get_bin_value() != 1.0f
|| measured_bin.segment_num() < this->proj_data_info_sptr->get_min_segment_num()
|| measured_bin.segment_num() > this->proj_data_info_sptr->get_max_segment_num()
|| measured_bin.tangential_pos_num() < this->proj_data_info_sptr->get_min_tangential_pos_num()
|| measured_bin.tangential_pos_num() > this->proj_data_info_sptr->get_max_tangential_pos_num()
|| measured_bin.axial_pos_num() < this->proj_data_info_sptr->get_min_axial_pos_num(measured_bin.segment_num())
|| measured_bin.axial_pos_num() > this->proj_data_info_sptr->get_max_axial_pos_num(measured_bin.segment_num())
|| measured_bin.timing_pos_num() < this->proj_data_info_sptr->get_min_tof_pos_num()
|| measured_bin.timing_pos_num() > this->proj_data_info_sptr->get_max_tof_pos_num())
{
continue;
}

measured_bin.set_bin_value(1.0f);
// If more than 1 subsets, check if the current bin belongs to the current.
bool in_subset = true;
if (this->num_subsets > 1)
{
Bin basic_bin = measured_bin;
this->PM_sptr->get_symmetries_ptr()->find_basic_bin(basic_bin);
in_subset = (subset_num == static_cast<int>(basic_bin.view_num() % this->num_subsets));
}
if (in_subset)
{
this->PM_sptr->get_proj_matrix_elems_for_one_bin(proj_matrix_row, measured_bin);
Bin fwd_bin;
fwd_bin.set_bin_value(0.0f);
proj_matrix_row.forward_project(fwd_bin, current_estimate);
// additive sinogram
if (!is_null_ptr(this->additive_proj_data_sptr))
{
// TODO simplify once we don't need the casting for get_bin_value() anymore
const float add_value = add_in_mem_sptr ? add_in_mem_sptr->get_bin_value(measured_bin)
: add_from_stream_sptr->get_bin_value(measured_bin);
const float value = fwd_bin.get_bin_value() + add_value;
fwd_bin.set_bin_value(value);
}

if (measured_bin.get_bin_value() <= max_quotient * fwd_bin.get_bin_value())
{
const float measured_div_fwd = 1.0f / fwd_bin.get_bin_value();
measured_bin.set_bin_value(measured_div_fwd);
proj_matrix_row.back_project(gradient, measured_bin);
}
}

++num_used_events;

if (num_used_events % 200000L == 0)
info(boost::format("Used Events: %1% ") % num_used_events);

// if we use event-count-based processing, see if we need to stop
if (this->num_events_to_use > 0)
if (num_used_events >= this->num_events_to_use)
break;
}
stop = this->load_listmode_batch();
}
info(boost::format("Number of used events (for all subsets): %1%") % num_used_events);
LM_distributable_computation(this->PM_sptr,
this->proj_data_info_sptr,
&gradient,
&current_estimate,
record_cache,
subset_num,
this->num_subsets,
this->has_add,
/* accumulate = */ icache != 0);
++icache;
}

if (!add_sensitivity)
Expand Down

0 comments on commit b168163

Please sign in to comment.