Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Regression Modeling - [WIP] #2268

Draft
wants to merge 22 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Libs/Optimize/Optimize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1667,6 +1667,12 @@ void Optimize::SetTimePtsPerSubject(int time_pts_per_subject) { this->m_timepts_
//---------------------------------------------------------------------------
int Optimize::GetTimePtsPerSubject() { return this->m_timepts_per_subject; }

//---------------------------------------------------------------------------
void Optimize::SetExplanatoryVariables(std::vector<double> val) { this->m_explanatory_variables = val; }

//---------------------------------------------------------------------------
std::vector<double> Optimize::GetExplanatoryVariables() { return this->m_explanatory_variables; }

//---------------------------------------------------------------------------
void Optimize::SetOptimizationIterations(int optimization_iterations) {
this->m_optimization_iterations = optimization_iterations;
Expand Down
9 changes: 7 additions & 2 deletions Libs/Optimize/Optimize.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,14 @@ class Optimize {
m_mesh_ffc_mode = mesh_ffc_mode;
m_sampler->SetMeshFFCMode(mesh_ffc_mode);
}
//! Set the number of time points per subject (TODO: details)
//! Set the number of time points per subject used for Linear Regression or Mixed Effects Model optimization
void SetTimePtsPerSubject(int time_pts_per_subject);
//! Get the number of time points per subject (TODO: details)
//! Get the number of time points per subject used for Linear Regression or Mixed Effects Model optimization
int GetTimePtsPerSubject();
//! Set the explanatory variables used for Linear Regression or Mixed Effects Model optimization
void SetExplanatoryVariables(std::vector<double> vals);
//! Get the explanatory variables used for Linear Regression or Mixed Effects Model optimization
std::vector<double> GetExplanatoryVariables();
//! Set the number of optimization iterations
void SetOptimizationIterations(int optimization_iterations);
//! Set the number of optimization iterations already completed (TODO: details)
Expand Down Expand Up @@ -388,6 +392,7 @@ class Optimize {
bool m_mesh_ffc_mode = 0;

unsigned int m_timepts_per_subject = 1;
std::vector<double> m_explanatory_variables;
int m_optimization_iterations = 2000;
int m_optimization_iterations_completed = 0;
int m_iterations_per_split = 1000;
Expand Down
39 changes: 39 additions & 0 deletions Libs/Optimize/OptimizeParameters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ const std::string checkpointing_interval = "checkpointing_interval";
const std::string save_init_splits = "save_init_splits";
const std::string keep_checkpoints = "keep_checkpoints";
const std::string use_disentangled_ssm = "use_disentangled_ssm";
const std::string use_linear_regression = "use_linear_regression";
const std::string time_points_per_subject = "time_points_per_subject";
const std::string field_attributes = "field_attributes";
const std::string field_attribute_weights = "field_attribute_weights";
const std::string use_geodesics_to_landmarks = "use_geodesics_to_landmarks";
Expand All @@ -53,6 +55,7 @@ const std::string particle_format = "particle_format";
const std::string geodesic_remesh_percent = "geodesic_remesh_percent";
const std::string shared_boundary = "shared_boundary";
const std::string shared_boundary_weight = "shared_boundary_weight";
const std::string output_prefix = "output_prefix";
} // namespace Keys

//---------------------------------------------------------------------------
Expand Down Expand Up @@ -93,8 +96,11 @@ OptimizeParameters::OptimizeParameters(ProjectHandle project) {
Keys::geodesics_to_landmarks_weight,
Keys::keep_checkpoints,
Keys::use_disentangled_ssm,
Keys::use_linear_regression,
Keys::time_points_per_subject,
Keys::particle_format,
Keys::geodesic_remesh_percent,
Keys::output_prefix,
Keys::shared_boundary,
Keys::shared_boundary_weight};

Expand Down Expand Up @@ -201,6 +207,18 @@ bool OptimizeParameters::get_use_disentangled_ssm() { return params_.get(Keys::u
//---------------------------------------------------------------------------
// Store `value` under the use_disentangled_ssm parameter key.
void OptimizeParameters::set_use_disentangled_ssm(bool value) {
  params_.set(Keys::use_disentangled_ssm, value);
}

//---------------------------------------------------------------------------
// Look up the use_linear_regression parameter; `false` is passed to
// params_.get() as the fallback value.
bool OptimizeParameters::get_use_linear_regression() {
  return params_.get(Keys::use_linear_regression, false);
}

//---------------------------------------------------------------------------
// Store `value` under the use_linear_regression parameter key.
void OptimizeParameters::set_use_linear_regression(bool value) {
  params_.set(Keys::use_linear_regression, value);
}

//---------------------------------------------------------------------------
// Look up the time_points_per_subject parameter; `1` is passed to
// params_.get() as the fallback value.
int OptimizeParameters::get_time_points_per_subject() {
  return params_.get(Keys::time_points_per_subject, 1);
}

//---------------------------------------------------------------------------
// Store `value` under the time_points_per_subject parameter key.
void OptimizeParameters::set_time_points_per_subject(int value) {
  params_.set(Keys::time_points_per_subject, value);
}

//---------------------------------------------------------------------------
// Look up the procrustes_scaling parameter; `false` is passed to
// params_.get() as the fallback value.
bool OptimizeParameters::get_use_procrustes_scaling() {
  return params_.get(Keys::procrustes_scaling, false);
}

Expand Down Expand Up @@ -433,6 +451,9 @@ bool OptimizeParameters::set_up_optimize(Optimize* optimize) {
optimize->SetMeshFFCMode(get_mesh_ffc_mode());
optimize->SetUseDisentangledSpatiotemporalSSM(get_use_disentangled_ssm());
optimize->set_particle_format(get_particle_format());
optimize->SetTimePtsPerSubject(get_time_points_per_subject());
optimize->SetUseRegression(get_use_linear_regression());
// Mixed effects is switched on whenever there is more than one time point per subject.
optimize->SetUseMixedEffects(get_time_points_per_subject() > 1);
optimize->SetSharedBoundaryEnabled(get_shared_boundary());
optimize->SetSharedBoundaryWeight(get_shared_boundary_weight());

Expand Down Expand Up @@ -628,6 +649,21 @@ bool OptimizeParameters::set_up_optimize(Optimize* optimize) {
}
}

// get explanatory variables for subjects if used for regression
if (get_use_linear_regression()) {
  // NOTE(review): a subject with no "explanatory_variable" entry keeps its in-class default
  // (std::numeric_limits<double>::lowest() in Subject.h) — confirm that value is acceptable here.
  std::vector<double> exp_vars;
  for (const auto& s : subjects) {
    exp_vars.push_back(s->get_explanatory_variable());
  }

  // dynamic_cast yields nullptr when the shape matrix is not of the requested type, so the
  // values are only forwarded to matrices that really are regression / mixed-effects matrices
  // instead of dereferencing a possibly-null pointer.
  auto* regression_matrix = dynamic_cast<LinearRegressionShapeMatrix*>(
      optimize->GetSampler()->GetEnsembleRegressionEntropyFunction()->GetShapeMatrix());
  if (regression_matrix != nullptr) {
    regression_matrix->SetExplanatory(exp_vars);
  }

  auto* mixed_effects_matrix = dynamic_cast<MixedEffectsShapeMatrix*>(
      optimize->GetSampler()->GetEnsembleMixedEffectsEntropyFunction()->GetShapeMatrix());
  if (mixed_effects_matrix != nullptr) {
    mixed_effects_matrix->SetExplanatory(exp_vars);
  }
}

std::vector<std::string> filenames;
int count = 0;
domain_count = 0;
Expand Down Expand Up @@ -853,6 +889,9 @@ void OptimizeParameters::set_geodesic_remesh_percent(double value) {
params_.set(Keys::geodesic_remesh_percent, value);
}

//---------------------------------------------------------------------------
// Store `value` under the output_prefix parameter key.
void OptimizeParameters::set_output_prefix(std::string value) {
  params_.set(Keys::output_prefix, value);
}

//---------------------------------------------------------------------------
// Look up the shared_boundary parameter; `false` is passed to
// params_.get() as the fallback value.
bool OptimizeParameters::get_shared_boundary() {
  return params_.get(Keys::shared_boundary, false);
}

Expand Down
7 changes: 7 additions & 0 deletions Libs/Optimize/OptimizeParameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ class OptimizeParameters {
bool get_use_disentangled_ssm();
void set_use_disentangled_ssm(bool value);

bool get_use_linear_regression();
void set_use_linear_regression(bool value);

int get_time_points_per_subject();
void set_time_points_per_subject(int value);

bool get_use_procrustes();
void set_use_procrustes(bool value);

Expand Down Expand Up @@ -137,6 +143,7 @@ class OptimizeParameters {

double get_shared_boundary_weight();
void set_shared_boundary_weight(double value);
void set_output_prefix(std::string value);

private:
std::string get_output_prefix();
Expand Down
21 changes: 21 additions & 0 deletions Libs/Particles/ParticleShapeStatistics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,25 @@ ParticleShapeStatistics::ParticleShapeStatistics(std::shared_ptr<Project> projec
groups.push_back(1);
}
import_points(points, groups);
// TODO: importing regression params doesn't make sense here. take a look again later.
}

//---------------------------------------------------------------------------
// Evaluate the linear regression model  mean = intercept + slope * t  using the slope and
// intercept previously supplied through import_regression_parameters().
//
// explanatory_variables: either a single value t (applied as a scalar to the whole slope
//   vector) or one value per element of the slope vector (applied element-wise).
// Returns the predicted vector, with the same length as the slope/intercept.
// Throws std::runtime_error if the parameters are uninitialized or inconsistent in size, if no
//   explanatory variable is given, or if a multi-valued input does not match the slope length.
Eigen::VectorXd ParticleShapeStatistics::compute_regression_mean(
    const std::vector<double>& explanatory_variables) const {
  std::cout << "Computing mean for regression" << std::endl;

  // Ensure slope and intercept are initialized
  if (slope_.size() == 0 || intercept_.size() == 0) {
    throw std::runtime_error("Slope and Intercept not initialized yet!");
  }
  if (slope_.size() != intercept_.size()) {
    throw std::runtime_error("Slope and Intercept have different sizes!");
  }
  if (explanatory_variables.empty()) {
    throw std::runtime_error("No explanatory variables provided for regression mean!");
  }

  // Single explanatory value: scale the slope by the scalar and offset by the intercept.
  if (explanatory_variables.size() == 1) {
    return intercept_ + slope_ * explanatory_variables[0];
  }

  // Several values: view the std::vector as an Eigen vector (no copy) and apply it element-wise.
  const Eigen::Map<const Eigen::VectorXd> t(explanatory_variables.data(),
                                            static_cast<Eigen::Index>(explanatory_variables.size()));
  if (t.size() != slope_.size()) {
    // cwiseProduct requires equally sized operands; fail loudly rather than rely on Eigen asserts.
    throw std::runtime_error("Number of explanatory variables does not match the regression parameters!");
  }
  return intercept_ + slope_.cwiseProduct(t);
}

//---------------------------------------------------------------------------
Expand Down Expand Up @@ -694,4 +713,6 @@ Eigen::MatrixXd ParticleShapeStatistics::get_group1_matrix() const { return grou
//---------------------------------------------------------------------------
// Return a copy of the stored group2_matrix_.
Eigen::MatrixXd ParticleShapeStatistics::get_group2_matrix() const {
  return group2_matrix_;
}



} // namespace shapeworks
9 changes: 9 additions & 0 deletions Libs/Particles/ParticleShapeStatistics.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ class ParticleShapeStatistics {

//! Returns the mean shape.
const Eigen::VectorXd& get_mean() const {
  return mean_;
}

//! Compute the regression mean for the given explanatory variable(s) from the slope and intercept
//! supplied via import_regression_parameters(); throws std::runtime_error if they are uninitialized.
Eigen::VectorXd compute_regression_mean(const std::vector<double>& explanatory_variables) const;
//! Returns a reference to the stored group 1 mean vector (mean1_).
const Eigen::VectorXd& get_group1_mean() const {
  return mean1_;
}
//! Returns a reference to the stored group 2 mean vector (mean2_).
const Eigen::VectorXd& get_group2_mean() const {
  return mean2_;
}

Expand Down Expand Up @@ -135,6 +137,9 @@ class ParticleShapeStatistics {
//! Set the meshes for each sample (used for some evaluation metrics)
void set_meshes(const std::vector<Mesh>& meshes) {
  meshes_ = meshes;
}

//! Import estimated regression parameters: stores copies of the given slope and intercept.
//! Always returns true.
bool import_regression_parameters(Eigen::VectorXd slope, Eigen::VectorXd intercept) {
  slope_ = slope;
  intercept_ = intercept;
  return true;
}

private:
unsigned int num_samples_group1_;
unsigned int num_samples_group2_;
Expand All @@ -150,6 +155,10 @@ class ParticleShapeStatistics {
Eigen::VectorXd mean2_;
Eigen::MatrixXd points_minus_mean_;
Eigen::MatrixXd shapes_;

// for regression tasks
Eigen::VectorXd slope_;
Eigen::VectorXd intercept_;

std::vector<double> percent_variance_by_mode_;
Eigen::MatrixXd principals_;
Expand Down
4 changes: 4 additions & 0 deletions Libs/Project/ProjectReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ void ProjectReader::load_subjects(StringMapList list) {
if (contains(item, "excluded")) {
subject->set_excluded(Variant(item["excluded"]));
}
if (contains(item, "explanatory_variable")) {
subject->set_explanatory_variable(Variant(item["explanatory_variable"]));
}

if (name.empty()) {
if (!subject->get_original_filenames().empty()) {
name = StringUtils::getBaseFilenameWithoutExtension(subject->get_original_filenames()[0]);
Expand Down
6 changes: 6 additions & 0 deletions Libs/Project/Subject.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ void Subject::set_fixed(bool fixed) { fixed_ = fixed; }
//---------------------------------------------------------------------------
// Return the stored excluded flag.
bool Subject::is_excluded() {
  return excluded_;
}

//---------------------------------------------------------------------------
// Store the explanatory variable for this subject.
void Subject::set_explanatory_variable(double val) {
  explanatory_variable_ = val;
}

//---------------------------------------------------------------------------
// Return the explanatory variable stored for this subject.
double Subject::get_explanatory_variable() {
  return explanatory_variable_;
}

//---------------------------------------------------------------------------
// Store the excluded flag for this subject.
void Subject::set_excluded(bool excluded) {
  excluded_ = excluded;
}

Expand Down
6 changes: 6 additions & 0 deletions Libs/Project/Subject.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ class Subject {
//! Set if this subject is excluded or not
void set_excluded(bool excluded);

//! Get the explanatory variable defined for the subject, used for Linear Regression and Mixed Effects Model for optimization
double get_explanatory_variable();
//! Set the explanatory variable defined for the subject, used for Linear Regression and Mixed Effects Model for optimization
void set_explanatory_variable(double val);

//! Get the notes
std::string get_notes();
//! Set the notes
Expand All @@ -118,6 +123,7 @@ class Subject {
std::string display_name_;
bool fixed_ = false;
bool excluded_ = false;
double explanatory_variable_ = std::numeric_limits<double>::lowest();
StringList original_filenames_;
StringList groomed_filenames_;
StringList local_particle_filenames_;
Expand Down
Loading
Loading