diff --git a/lincs/liblincs/learning/mrsort-by-weights-profiles-breed/initialize-profiles/probabilistic-maximal-discrimination-power-per-criterion.cpp b/lincs/liblincs/learning/mrsort-by-weights-profiles-breed/initialize-profiles/probabilistic-maximal-discrimination-power-per-criterion.cpp index 3c8f287a..ccf94d63 100644 --- a/lincs/liblincs/learning/mrsort-by-weights-profiles-breed/initialize-profiles/probabilistic-maximal-discrimination-power-per-criterion.cpp +++ b/lincs/liblincs/learning/mrsort-by-weights-profiles-breed/initialize-profiles/probabilistic-maximal-discrimination-power-per-criterion.cpp @@ -3,6 +3,9 @@ #include "probabilistic-maximal-discrimination-power-per-criterion.hpp" #include "../../../chrones.hpp" +#include "../../../generation.hpp" // Only for tests + +#include "../../../vendored/doctest.h" // Keep last because it defines really common names like CHECK that we don't want injected into other headers namespace lincs { @@ -10,10 +13,12 @@ namespace lincs { InitializeProfilesForProbabilisticMaximalDiscriminationPowerPerCriterion::InitializeProfilesForProbabilisticMaximalDiscriminationPowerPerCriterion(LearningData& learning_data_) : learning_data(learning_data_) { CHRONE(); + generators.reserve(learning_data.criteria_count); for (unsigned criterion_index = 0; criterion_index != learning_data.criteria_count; ++criterion_index) { - generators.emplace_back(); + auto& generator = generators.emplace_back(); + generator.reserve(learning_data.categories_count - 1); for (unsigned profile_index = 0; profile_index != learning_data.categories_count - 1; ++profile_index) { - generators.back().emplace_back(get_candidate_probabilities(criterion_index, profile_index)); + generator.emplace_back(get_candidate_probabilities(criterion_index, profile_index)); } } } @@ -88,26 +93,16 @@ void InitializeProfilesForProbabilisticMaximalDiscriminationPowerPerCriterion::i const unsigned profile_index = category_index - 1; float value = generators[criterion_index][profile_index](learning_data.urbgs[model_index]); + // Enforce profiles ordering constraint if (criterion.category_correlation == Criterion::CategoryCorrelation::growing) { if (profile_index != learning_data.categories_count - 2) { value = std::min(value, learning_data.profiles[criterion_index][profile_index + 1][model_index]); } - // @todo(Project management, soon) Add a unit test that triggers the following assertion - // (This will require removing the code to enforce the order of profiles above) - // Then restore the code to enforce the order of profiles - // Note, this assertion does not protect us from initializing a model with two identical profiles. - // Is it really that bad? - assert( - profile_index == learning_data.categories_count - 2 - || learning_data.profiles[criterion_index][profile_index + 1][model_index] >= value); } else { assert(criterion.category_correlation == Criterion::CategoryCorrelation::decreasing); if (profile_index != learning_data.categories_count - 2) { value = std::max(value, learning_data.profiles[criterion_index][profile_index + 1][model_index]); } - assert( - profile_index == learning_data.categories_count - 2 - || learning_data.profiles[criterion_index][profile_index + 1][model_index] <= value); } learning_data.profiles[criterion_index][profile_index][model_index] = value; @@ -116,4 +111,40 @@ void InitializeProfilesForProbabilisticMaximalDiscriminationPowerPerCriterion::i } } +TEST_CASE("Initialize profiles - respect ordering") { + Problem problem{ + { + Criterion( + "Criterion 1", + Criterion::ValueType::real, + Criterion::CategoryCorrelation::growing, + 0.0, 1.0 + ), + Criterion( + "Criterion 2", + Criterion::ValueType::real, + Criterion::CategoryCorrelation::decreasing, + 0.0, 1.0 + ) + }, + { + Category("Category 1"), + Category("Category 2"), + Category("Category 3"), + } + }; + Model model = generate_mrsort_classification_model(problem, 42); + auto learning_set = generate_classified_alternatives(problem, model, 1000, 42, 0.1); + auto learning_data = LearnMrsortByWeightsProfilesBreed::LearningData::make(problem, learning_set, 1, 42); + InitializeProfilesForProbabilisticMaximalDiscriminationPowerPerCriterion initializer(learning_data); + + for (unsigned iteration = 0; iteration != 10; ++iteration) { + initializer.initialize_profiles(0, 1); + // Both CHECKs fail at least once when the 'Enforce profiles ordering constraint' code is removed + CHECK(learning_data.profiles[0][0][0] <= learning_data.profiles[0][1][0]); + CHECK(learning_data.profiles[1][0][0] >= learning_data.profiles[1][1][0]); + } +} + + } // namespace lincs