diff --git a/Libs/Optimize/Optimize.cpp b/Libs/Optimize/Optimize.cpp index fcadbb32fa..dbcec64210 100644 --- a/Libs/Optimize/Optimize.cpp +++ b/Libs/Optimize/Optimize.cpp @@ -1667,6 +1667,12 @@ void Optimize::SetTimePtsPerSubject(int time_pts_per_subject) { this->m_timepts_ //--------------------------------------------------------------------------- int Optimize::GetTimePtsPerSubject() { return this->m_timepts_per_subject; } +//--------------------------------------------------------------------------- +void Optimize::SetExplanatoryVariables(std::vector val) { this->m_explanatory_variables = val; } + +//--------------------------------------------------------------------------- +std::vector Optimize::GetExplanatoryVariables() { return this->m_explanatory_variables; } + //--------------------------------------------------------------------------- void Optimize::SetOptimizationIterations(int optimization_iterations) { this->m_optimization_iterations = optimization_iterations; diff --git a/Libs/Optimize/Optimize.h b/Libs/Optimize/Optimize.h index 3f86a2f059..2744bd5842 100644 --- a/Libs/Optimize/Optimize.h +++ b/Libs/Optimize/Optimize.h @@ -158,10 +158,14 @@ class Optimize { m_mesh_ffc_mode = mesh_ffc_mode; m_sampler->SetMeshFFCMode(mesh_ffc_mode); } - //! Set the number of time points per subject (TODO: details) + //! Set the number of time points per subject used for Linear Regression or Mixed Effects Model optimization void SetTimePtsPerSubject(int time_pts_per_subject); - //! Get the number of time points per subject (TODO: details) + //! Get the number of time points per subject used for Linear Regression or Mixed Effects Model optimization int GetTimePtsPerSubject(); + //! Set Explanatory Variable for | used for Linear Regression or Mixed Effects Model optimization + void SetExplanatoryVariables(std::vector vals); + //! Get the number of time points per subject used for Linear Regression or Mixed Effects Model optimization + std::vector GetExplanatoryVariables(); //! Set the number of optimization iterations void SetOptimizationIterations(int optimization_iterations); //! Set the number of optimization iterations already completed (TODO: details) @@ -388,6 +392,7 @@ class Optimize { bool m_mesh_ffc_mode = 0; unsigned int m_timepts_per_subject = 1; + std::vector m_explanatory_variables; int m_optimization_iterations = 2000; int m_optimization_iterations_completed = 0; int m_iterations_per_split = 1000; diff --git a/Libs/Optimize/OptimizeParameters.cpp b/Libs/Optimize/OptimizeParameters.cpp index d7f8786002..ae6e97a883 100644 --- a/Libs/Optimize/OptimizeParameters.cpp +++ b/Libs/Optimize/OptimizeParameters.cpp @@ -45,6 +45,8 @@ const std::string checkpointing_interval = "checkpointing_interval"; const std::string save_init_splits = "save_init_splits"; const std::string keep_checkpoints = "keep_checkpoints"; const std::string use_disentangled_ssm = "use_disentangled_ssm"; +const std::string use_linear_regression = "use_linear_regression"; +const std::string time_points_per_subject = "time_points_per_subject"; const std::string field_attributes = "field_attributes"; const std::string field_attribute_weights = "field_attribute_weights"; const std::string use_geodesics_to_landmarks = "use_geodesics_to_landmarks"; @@ -53,6 +55,7 @@ const std::string particle_format = "particle_format"; const std::string geodesic_remesh_percent = "geodesic_remesh_percent"; const std::string shared_boundary = "shared_boundary"; const std::string shared_boundary_weight = "shared_boundary_weight"; +const std::string output_prefix = "output_prefix"; } // namespace Keys //--------------------------------------------------------------------------- @@ -93,8 +96,11 @@ OptimizeParameters::OptimizeParameters(ProjectHandle project) { Keys::geodesics_to_landmarks_weight, Keys::keep_checkpoints, Keys::use_disentangled_ssm, + Keys::use_linear_regression, + Keys::time_points_per_subject, Keys::particle_format, Keys::geodesic_remesh_percent, + Keys::output_prefix, Keys::shared_boundary, Keys::shared_boundary_weight}; @@ -201,6 +207,18 @@ bool OptimizeParameters::get_use_disentangled_ssm() { return params_.get(Keys::u //--------------------------------------------------------------------------- void OptimizeParameters::set_use_disentangled_ssm(bool value) { params_.set(Keys::use_disentangled_ssm, value); } +//--------------------------------------------------------------------------- +bool OptimizeParameters::get_use_linear_regression() { return params_.get(Keys::use_linear_regression, false); } + +//--------------------------------------------------------------------------- +void OptimizeParameters::set_use_linear_regression(bool value) { params_.set(Keys::use_linear_regression, value); } + +//--------------------------------------------------------------------------- +int OptimizeParameters::get_time_points_per_subject() { return params_.get(Keys::time_points_per_subject, 1); } + +//--------------------------------------------------------------------------- +void OptimizeParameters::set_time_points_per_subject(int value) { params_.set(Keys::time_points_per_subject, value); } + //--------------------------------------------------------------------------- bool OptimizeParameters::get_use_procrustes_scaling() { return params_.get(Keys::procrustes_scaling, false); } @@ -433,6 +451,9 @@ bool OptimizeParameters::set_up_optimize(Optimize* optimize) { optimize->SetMeshFFCMode(get_mesh_ffc_mode()); optimize->SetUseDisentangledSpatiotemporalSSM(get_use_disentangled_ssm()); optimize->set_particle_format(get_particle_format()); + optimize->SetTimePtsPerSubject(get_time_points_per_subject()); + optimize->SetUseRegression(get_use_linear_regression()); + optimize->SetUseMixedEffects(get_time_points_per_subject() > 1 ? true : false); optimize->SetSharedBoundaryEnabled(get_shared_boundary()); optimize->SetSharedBoundaryWeight(get_shared_boundary_weight()); @@ -628,6 +649,21 @@ bool OptimizeParameters::set_up_optimize(Optimize* optimize) { } } + // get explanatory variables for subjects if used for regression + if (get_use_linear_regression()) + { + std::vector exp_vars; + for (const auto& s : subjects) { + exp_vars.push_back(s->get_explanatory_variable()); + } + dynamic_cast( + optimize->GetSampler()->GetEnsembleRegressionEntropyFunction()->GetShapeMatrix()) + ->SetExplanatory(exp_vars); + dynamic_cast( + optimize->GetSampler()->GetEnsembleMixedEffectsEntropyFunction()->GetShapeMatrix()) + ->SetExplanatory(exp_vars); + } + std::vector filenames; int count = 0; domain_count = 0; @@ -853,6 +889,9 @@ void OptimizeParameters::set_geodesic_remesh_percent(double value) { params_.set(Keys::geodesic_remesh_percent, value); } +//--------------------------------------------------------------------------- +void OptimizeParameters::set_output_prefix(std::string value) { params_.set(Keys::output_prefix, value); } + //--------------------------------------------------------------------------- bool OptimizeParameters::get_shared_boundary() { return params_.get(Keys::shared_boundary, false); } diff --git a/Libs/Optimize/OptimizeParameters.h b/Libs/Optimize/OptimizeParameters.h index 6e37f063b7..99727771e6 100644 --- a/Libs/Optimize/OptimizeParameters.h +++ b/Libs/Optimize/OptimizeParameters.h @@ -60,6 +60,12 @@ class OptimizeParameters { bool get_use_disentangled_ssm(); void set_use_disentangled_ssm(bool value); + bool get_use_linear_regression(); + void set_use_linear_regression(bool value); + + int get_time_points_per_subject(); + void set_time_points_per_subject(int value); + bool get_use_procrustes(); void set_use_procrustes(bool value); @@ -137,6 +143,7 @@ class OptimizeParameters { double get_shared_boundary_weight(); void set_shared_boundary_weight(double value); + void set_output_prefix(std::string value); private: std::string get_output_prefix(); diff --git a/Libs/Particles/ParticleShapeStatistics.cpp b/Libs/Particles/ParticleShapeStatistics.cpp index 288eae4a44..f415567e3c 100644 --- a/Libs/Particles/ParticleShapeStatistics.cpp +++ b/Libs/Particles/ParticleShapeStatistics.cpp @@ -492,6 +492,25 @@ ParticleShapeStatistics::ParticleShapeStatistics(std::shared_ptr projec groups.push_back(1); } import_points(points, groups); + // TODO: importing regression params doesn't make sense here. take a look again later. +} + +//--------------------------------------------------------------------------- +Eigen::VectorXd ParticleShapeStatistics::compute_regression_mean( + const std::vector& explanatory_variables) const { + std::cout << "Computing mean for regression" << std::endl; + Eigen::VectorXd t = Eigen::Map( + explanatory_variables.data(), explanatory_variables.size()); + + // Ensure slope and intercept are initialized + if (slope_.size() == 0 || intercept_.size() == 0) { + throw std::runtime_error("Slope and Intercept not initialized yet!"); + } + + if (t.size() == 1) + return slope_ + intercept_ * t[0]; + else + return slope_ + intercept_.cwiseProduct(t); } //--------------------------------------------------------------------------- @@ -694,4 +713,6 @@ Eigen::MatrixXd ParticleShapeStatistics::get_group1_matrix() const { return grou //--------------------------------------------------------------------------- Eigen::MatrixXd ParticleShapeStatistics::get_group2_matrix() const { return group2_matrix_; } + + } // namespace shapeworks diff --git a/Libs/Particles/ParticleShapeStatistics.h b/Libs/Particles/ParticleShapeStatistics.h index 28b0972676..2a1602b806 100644 --- a/Libs/Particles/ParticleShapeStatistics.h +++ b/Libs/Particles/ParticleShapeStatistics.h @@ -88,6 +88,8 @@ class ParticleShapeStatistics { //! Returns the mean shape. const Eigen::VectorXd& get_mean() const { return mean_; } + + Eigen::VectorXd compute_regression_mean(const std::vector& explanatory_variables) const; const Eigen::VectorXd& get_group1_mean() const { return mean1_; } const Eigen::VectorXd& get_group2_mean() const { return mean2_; } @@ -135,6 +137,9 @@ class ParticleShapeStatistics { //! Set the meshes for each sample (used for some evaluation metrics) void set_meshes(const std::vector& meshes) { meshes_ = meshes; } + // import estimated parameters for regression + inline bool import_regression_parameters(Eigen::VectorXd slope, Eigen::VectorXd intercept) { slope_ = slope; intercept_ = intercept; return true;}; + private: unsigned int num_samples_group1_; unsigned int num_samples_group2_; @@ -150,6 +155,10 @@ class ParticleShapeStatistics { Eigen::VectorXd mean2_; Eigen::MatrixXd points_minus_mean_; Eigen::MatrixXd shapes_; + + // for regression tasks + Eigen::VectorXd slope_; + Eigen::VectorXd intercept_; std::vector percent_variance_by_mode_; Eigen::MatrixXd principals_; diff --git a/Libs/Project/ProjectReader.cpp b/Libs/Project/ProjectReader.cpp index d7b4142cd8..9029174296 100644 --- a/Libs/Project/ProjectReader.cpp +++ b/Libs/Project/ProjectReader.cpp @@ -68,6 +68,10 @@ void ProjectReader::load_subjects(StringMapList list) { if (contains(item, "excluded")) { subject->set_excluded(Variant(item["excluded"])); } + if (contains(item, "explanatory_variable")) { + subject->set_explanatory_variable(Variant(item["explanatory_variable"])); + } + if (name.empty()) { if (!subject->get_original_filenames().empty()) { name = StringUtils::getBaseFilenameWithoutExtension(subject->get_original_filenames()[0]); diff --git a/Libs/Project/Subject.cpp b/Libs/Project/Subject.cpp index 43b1bda204..57c1eb5d50 100644 --- a/Libs/Project/Subject.cpp +++ b/Libs/Project/Subject.cpp @@ -100,6 +100,12 @@ void Subject::set_fixed(bool fixed) { fixed_ = fixed; } //--------------------------------------------------------------------------- bool Subject::is_excluded() { return excluded_; } +//--------------------------------------------------------------------------- +void Subject::set_explanatory_variable(double val) { explanatory_variable_ = val; } + +//--------------------------------------------------------------------------- +double Subject::get_explanatory_variable() { return explanatory_variable_; } + //--------------------------------------------------------------------------- void Subject::set_excluded(bool excluded) { excluded_ = excluded; } diff --git a/Libs/Project/Subject.h b/Libs/Project/Subject.h index 3b3f176634..585cd07590 100644 --- a/Libs/Project/Subject.h +++ b/Libs/Project/Subject.h @@ -107,6 +107,11 @@ class Subject { //! Set if this subject is excluded or not void set_excluded(bool excluded); + //! Get the explanatory variable defined for the subject, used for Linear Regression and Mixed Effects Model for optimization + double get_explanatory_variable(); + //! Set the explanatory variable defined for the subject, used for Linear Regression and Mixed Effects Model for optimization + void set_explanatory_variable(double val); + //! Get the notes std::string get_notes(); //! Set the notes @@ -118,6 +123,7 @@ class Subject { std::string display_name_; bool fixed_ = false; bool excluded_ = false; + double explanatory_variable_ = std::numeric_limits::lowest(); StringList original_filenames_; StringList groomed_filenames_; StringList local_particle_filenames_; diff --git a/Studio/Analysis/AnalysisTool.cpp b/Studio/Analysis/AnalysisTool.cpp index 0a4d9b668d..e3a3127bfd 100644 --- a/Studio/Analysis/AnalysisTool.cpp +++ b/Studio/Analysis/AnalysisTool.cpp @@ -269,6 +269,11 @@ void AnalysisTool::on_reconstructionButton_clicked() { //--------------------------------------------------------------------------- int AnalysisTool::get_pca_mode() { return ui_->pcaModeSpinBox->value() - 1; } +//--------------------------------------------------------------------------- +bool AnalysisTool::get_regression_analysis_status() { + return ui_->enableRegressionCheckBox->isChecked(); +} + //--------------------------------------------------------------------------- double AnalysisTool::get_group_ratio() { double group_slider_value = ui_->group_slider->value(); @@ -497,6 +502,36 @@ void AnalysisTool::network_analysis_clicked() { app_->get_py_worker()->run_job(network_analysis_job_); } +//--------------------------------------------------------------------------- +Eigen::VectorXd load_regression_parameters(std::string filepath) { + std::ifstream infile(slope_file_path); + if (!infile.good()) { + throw std::runtime_error("Unable to open regression parameter file: \"" + + filepath + "\" for reading"); + } + try { + std::vector temp_values; + double value; + while (infile >> value) { + temp_values.push_back(value); + } + if (temp_values.empty()) { + std::cerr << "Error: No data found in file " << slope_file_path + << std::endl; + return Eigen::VectorXd(); + } + Eigen::VectorXd param_vector(temp_values.size()); + for (std::size_t i = 0; i < temp_values.size(); ++i) { + param_vector[i] = temp_values[i]; + } + return param_vector; + + } catch (json::exception& e) { + throw std::runtime_error("Unabled to parse regression parameter file " + + filepath + " : " + e.what()); + } +} + //----------------------------------------------------------------------------- bool AnalysisTool::compute_stats() { if (stats_ready_) { @@ -619,6 +654,24 @@ bool AnalysisTool::compute_stats() { compute_shape_evaluations(); } + can_run_regression_ = check_explanatory_variable_limits(); + std::cout << "can run regression set to " << can_run_regression_ << std::endl; + if (can_run_regression_) { + auto slope = load_regression_parameters( + session_->get_regression_param_file("slope")); // dM vector + auto intercept = load_regression_parameters( + session_->get_regression_param_file("intercept")); // dM vector + stats_.import_regression_parameters(slope, intercept); // set slope and intercept in stats object + ui_->regression_groupbox->setVisible(true); + ui_->explanatoryVariableSlider->setVisible(true); + ui_->enableRegressionCheckBox->setVisible(true); + } + else { + ui_->regression_groupbox->setVisible(false); + ui_->explanatoryVariableSlider->setVisible(false); + ui_->enableRegressionCheckBox->setVisible(false); + } + stats_ready_ = true; /// Set this to true to export long format sample data (e.g. for import into R) @@ -659,6 +712,20 @@ bool AnalysisTool::compute_stats() { return true; } +//--------------------------------------------------------------------------- +bool check_explanatory_variable_limits() { + auto subjects = session_->get_project()->get_subjects(); + explanatory_variable_limits_.resize(2, 0.0); + explanatory_variable_limits_[0] = std::numeric_limits::max(); + explanatory_variable_limits_[1] = std::numeric_limits::lowest(); + for (auto sub : subjects) { + double exp_val = sub->get_explanatory_variable(); + if (exp_val == std::numeric_limits::lowest()) return false; + explanatory_variable_limits_[0] = std::min(explanatory_variable_limits_[0], exp_val); + explanatory_variable_limits_[1] = std::max(explanatory_variable_limits_[1], exp_val); + } + return true; +} //----------------------------------------------------------------------------- Particles AnalysisTool::get_mean_shape_points() { if (!compute_stats()) { @@ -721,8 +788,8 @@ Particles AnalysisTool::get_shape_points(int mode, double value) { ui_->explained_variance->setText(""); ui_->cumulative_explained_variance->setText(""); } - - temp_shape_ = stats_.get_mean() + (e * (value * lambda)); + auto mean = !get_regression_analysis_status() ? stats_.get_mean() : stats_.compute_regression_mean(get_explanatory_variable_value()); + temp_shape_ = mean + (e * (value * lambda)); auto positions = temp_shape_; @@ -861,6 +928,8 @@ void AnalysisTool::store_settings() { params.set("network_pvalue_of_interest", ui_->network_pvalue_of_interest->text().toStdString()); params.set("network_pvalue_threshold", ui_->network_pvalue_threshold->text().toStdString()); + // params.set("regression_slope", session->) + session_->get_project()->set_parameters(Parameters::ANALYSIS_PARAMS, params); } @@ -985,6 +1054,13 @@ void AnalysisTool::on_pcaSlider_valueChanged() { Q_EMIT pca_update(); } +//--------------------------------------------------------------------------- +void AnalysisTool::on_explanatoryVariableSlider_valueChanged() { + // this will make the slider handle redraw making the UI appear more responsive + QCoreApplication::processEvents(); + Q_EMIT pca_update(); +} + //--------------------------------------------------------------------------- void AnalysisTool::on_group_slider_valueChanged() { // this will make the slider handle redraw making the UI appear more responsive @@ -1085,6 +1161,14 @@ double AnalysisTool::get_pca_value() { return value; } + +std::vector AnalysisTool::get_explanatory_variable_value() { + int slider_value = ui_->explanatoryVariableSlider->value(); + // return {t_min + (static_cast(slider_value) / 100.0) * (t_max - t_min)}; + return {explanatory_variable_limits_[0] + (static_cast(slider_value) / 100.0) * (explanatory_variable_limits_[1] - explanatory_variable_limits_[0])}; + +} + //--------------------------------------------------------------------------- void AnalysisTool::pca_labels_changed(QString value, QString eigen, QString lambda) { set_labels(QString("pca"), value); @@ -1103,6 +1187,7 @@ void AnalysisTool::update_slider() { void AnalysisTool::reset_stats() { stats_ready_ = false; evals_ready_ = false; + can_run_regression_ = false; ui_->tabWidget->setCurrentWidget(ui_->mean_tab); ui_->allSamplesRadio->setChecked(true); diff --git a/Studio/Analysis/AnalysisTool.h b/Studio/Analysis/AnalysisTool.h index d6198e972a..ae35bb5347 100644 --- a/Studio/Analysis/AnalysisTool.h +++ b/Studio/Analysis/AnalysisTool.h @@ -77,9 +77,15 @@ class AnalysisTool : public QWidget { double get_pca_value(); + std::vector get_explanatory_variable_value(); + bool pca_animate(); McaMode get_mca_level() const; + bool get_regression_analysis_status(); + + bool check_explanatory_variable_limits(); + int get_sample_number(); bool compute_stats(); @@ -135,6 +141,7 @@ class AnalysisTool : public QWidget { // PCA void on_pcaSlider_valueChanged(); + void on_explanatoryVariableSlider_valueChanged(); void on_group_slider_valueChanged(); void on_pcaModeSpinBox_valueChanged(int i); @@ -236,6 +243,8 @@ class AnalysisTool : public QWidget { void update_difference_particles(); Eigen::VectorXd get_mean_shape_particles(); + + Eigen::VectorXd load_regression_parameters(std::string filepath); ShapeHandle create_shape_from_points(Particles points); @@ -275,6 +284,9 @@ class AnalysisTool : public QWidget { std::string feature_map_; + std::vector explanatory_variable_limits_; + bool can_run_regression_; // decide if necessary variables are present to run regression in analysis + std::vector current_group_names_; std::vector current_group_values_; diff --git a/Studio/Analysis/AnalysisTool.ui b/Studio/Analysis/AnalysisTool.ui index da73800f75..a2316e5b90 100644 --- a/Studio/Analysis/AnalysisTool.ui +++ b/Studio/Analysis/AnalysisTool.ui @@ -1546,6 +1546,69 @@ QWidget#particles_panel { + + + + Regression Analysis + + + + 0 + + + 0 + + + 0 + + + 0 + + + + + Enable regression analysis if explanatory variables are provided + + + Enable Regression Analysis + + + + + + + Exp.Var. + + + + + + + false + + + Explanatory variable slider + + + -20 + + + 20 + + + Qt::Horizontal + + + QSlider::TicksBelow + + + 1 + + + + + + diff --git a/Studio/Data/Session.cpp b/Studio/Data/Session.cpp index 8513b99a1d..2feedec085 100644 --- a/Studio/Data/Session.cpp +++ b/Studio/Data/Session.cpp @@ -837,6 +837,8 @@ void Session::new_plane_point(PickResult result) { //--------------------------------------------------------------------------- QString Session::get_filename() { return filename_; } +// QString Session::get_parent_dir() + //--------------------------------------------------------------------------- int Session::get_num_shapes() { return shapes_.size(); } @@ -1416,4 +1418,19 @@ void Session::recompute_surfaces() { } Q_EMIT update_display(); } + +std::string get_regression_param_file(std::string param_name) { + QFileInfo fileInfo(filename_); + QString baseName = fileInfo.completeBaseName(); + + QDir projectDir = fileInfo.absoluteDir(); + QString particlesDir = baseName + "_particles"; + QString paramFilePath = projectDir.filePath(particlesDir); + paramFilePath = QDir(paramFilePath).filePath(param_name); + + if (!QFile::exists(paramFilePath)) { + return ""; + } + return paramFilePath.toStdString(); +} } // namespace shapeworks diff --git a/Studio/Data/Session.h b/Studio/Data/Session.h index a7d9f7a3f7..a445ec8b55 100644 --- a/Studio/Data/Session.h +++ b/Studio/Data/Session.h @@ -339,6 +339,8 @@ class Session : public QObject, public QEnableSharedFromThis { void new_plane_point(PickResult result); + std::string get_regression_param_file(std::string param_name = "slope"); + QWidget* parent_{nullptr}; Preferences& preferences_; diff --git a/Studio/Optimize/OptimizeTool.cpp b/Studio/Optimize/OptimizeTool.cpp index c5225f3639..705c0771c3 100644 --- a/Studio/Optimize/OptimizeTool.cpp +++ b/Studio/Optimize/OptimizeTool.cpp @@ -33,6 +33,7 @@ OptimizeTool::OptimizeTool(Preferences& prefs, Telemetry& telemetry) : preferenc connect(ui_->use_normals, &QCheckBox::toggled, this, &OptimizeTool::update_ui_elements); connect(ui_->procrustes, &QCheckBox::toggled, this, &OptimizeTool::update_ui_elements); connect(ui_->multiscale, &QCheckBox::toggled, this, &OptimizeTool::update_ui_elements); + connect(ui_->use_linear_regression, &QCheckBox::toggled, this, &OptimizeTool::update_ui_elements); connect(ui_->use_geodesics_from_landmarks, &QCheckBox::toggled, this, &OptimizeTool::update_ui_elements); connect(ui_->use_geodesic_distance, &QCheckBox::toggled, this, &OptimizeTool::update_ui_elements); @@ -63,7 +64,10 @@ OptimizeTool::OptimizeTool(Preferences& prefs, Telemetry& telemetry) : preferenc "It has no effect on the optimization"); ui_->shared_boundary->setToolTip("Use shared boundary optimization"); ui_->shared_boundary_weight->setToolTip("Weight of shared boundary optimization"); - ui_->use_disentangled_ssm->setToolTip("Use disentangled Optimization technique to build spatiotemporal SSM."); + ui_->use_disentangled_ssm->setToolTip("Use the disentangled optimization technique to build spatiotemporal SSM."); + ui_->use_linear_regression->setToolTip("Use the linear regression optimization technique, where correspondence particle optimization is performed by regressing shape against an explanatory variable. Ensure that the explanatory variable is specified in the data tab of the project file."); + ui_->time_points_per_subject->setToolTip("Number of timepoints/explanatory variables defined for each subject in the data tab. Note: More than 1 timepoint uses a mixed effects model; 1 timepoint uses linear regression. "); + // hidden for 6.5 release ui_->disentangled_label->hide(); @@ -71,6 +75,8 @@ OptimizeTool::OptimizeTool(Preferences& prefs, Telemetry& telemetry) : preferenc QIntValidator* above_zero = new QIntValidator(1, std::numeric_limits::max(), this); QIntValidator* zero_and_up = new QIntValidator(0, std::numeric_limits::max(), this); + QIntValidator* one_and_up = new QIntValidator(1, std::numeric_limits::max(), this); + QDoubleValidator* double_validator = new QDoubleValidator(0, std::numeric_limits::max(), 1000, this); @@ -86,6 +92,7 @@ OptimizeTool::OptimizeTool(Preferences& prefs, Telemetry& telemetry) : preferenc ui_->multiscale_particles->setValidator(above_zero); ui_->narrow_band->setValidator(double_validator); ui_->geodesics_to_landmarks_weight->setValidator(double_validator); + ui_->time_points_per_subject->setValidator(one_and_up); ui_->shared_boundary_weight->setValidator(double_validator); line_edits_.push_back(ui_->number_of_particles); @@ -100,7 +107,7 @@ OptimizeTool::OptimizeTool(Preferences& prefs, Telemetry& telemetry) : preferenc line_edits_.push_back(ui_->multiscale_particles); line_edits_.push_back(ui_->geodesics_to_landmarks_weight); line_edits_.push_back(ui_->narrow_band); - line_edits_.push_back(ui_->shared_boundary_weight); + for (QLineEdit* line_edit : line_edits_) { connect(line_edit, &QLineEdit::textChanged, this, &OptimizeTool::update_run_button); @@ -287,6 +294,9 @@ void OptimizeTool::load_params() { ui_->geodesics_to_landmarks_weight->setText(QString::number(params.get_geodesic_to_landmarks_weight())); ui_->use_disentangled_ssm->setChecked(params.get_use_disentangled_ssm()); + ui_->use_linear_regression->setChecked(params.get_use_linear_regression()); + ui_->time_points_per_subject->setText(QString::number(params.get_time_points_per_subject())); + ui_->procrustes->setChecked(params.get_use_procrustes()); ui_->procrustes_scaling->setChecked(params.get_use_procrustes_scaling()); ui_->procrustes_rotation_translation->setChecked(params.get_use_procrustes_rotation_translation()); @@ -334,6 +344,8 @@ void OptimizeTool::store_params() { params.set_use_geodesics_to_landmarks(ui_->use_geodesics_from_landmarks->isChecked()); params.set_geodesic_to_landmarks_weight(ui_->geodesics_to_landmarks_weight->text().toDouble()); params.set_use_disentangled_ssm(ui_->use_disentangled_ssm->isChecked()); + params.set_use_linear_regression(ui_->use_linear_regression->isChecked()); + params.set_time_points_per_subject(ui_->time_points_per_subject->text().toDouble()); params.set_use_procrustes(ui_->procrustes->isChecked()); params.set_use_procrustes_scaling(ui_->procrustes_scaling->isChecked()); @@ -382,6 +394,7 @@ void OptimizeTool::update_ui_elements() { ui_->procrustes_rotation_translation->setEnabled(ui_->procrustes->isChecked()); ui_->procrustes_interval->setEnabled(ui_->procrustes->isChecked()); ui_->multiscale_particles->setEnabled(ui_->multiscale->isChecked()); + ui_->time_points_per_subject->setEnabled(ui_->use_linear_regression->isChecked()); ui_->geodesics_to_landmarks_weight->setEnabled(ui_->use_geodesics_from_landmarks->isChecked()); ui_->geodesic_remesh_percent->setEnabled(ui_->use_geodesic_distance->isChecked()); diff --git a/Studio/Optimize/OptimizeTool.h b/Studio/Optimize/OptimizeTool.h index e62b58a896..30da28d10c 100644 --- a/Studio/Optimize/OptimizeTool.h +++ b/Studio/Optimize/OptimizeTool.h @@ -37,6 +37,8 @@ Q_OBJECT; //! Load params from project void load_params(); //! Store params to project + + void store_params(); //! Enable action buttons @@ -94,5 +96,6 @@ public Q_SLOTS: QElapsedTimer elapsed_timer_; Ui_OptimizeTool* ui_; + }; } diff --git a/Studio/Optimize/OptimizeTool.ui b/Studio/Optimize/OptimizeTool.ui index 2f1bc5d173..f741c6c05e 100644 --- a/Studio/Optimize/OptimizeTool.ui +++ b/Studio/Optimize/OptimizeTool.ui @@ -302,6 +302,13 @@ QWidget#optimize_panel { + + + + Use Linear Regression + + + @@ -428,6 +435,13 @@ QWidget#optimize_panel { + + + + Timepoints defined per subject + + + @@ -551,6 +565,16 @@ QWidget#optimize_panel { + + + + 1 + + + Qt::AlignCenter + + + @@ -867,6 +891,67 @@ QWidget#optimize_panel { + + + + + 0 + + + 0 + + + 0 + + + 0 + + + + + + + + + + + + Qt::Horizontal + + + + 40 + 20 + + + + + + + + Qt::Horizontal + + + + 40 + 20 + + + + + + + + + + + 100 + + + Qt::AlignCenter + + + @@ -1223,6 +1308,9 @@ QWidget#optimize_panel { use_landmarks narrow_band use_disentangled_ssm + use_linear_regression + time_points_per_subject +