Skip to content

Commit

Permalink
Add test ensuring profiles are initialized respecting their ordering
Browse files Browse the repository at this point in the history
  • Loading branch information
jacquev6 committed Oct 16, 2023
1 parent 78cfc81 commit fec0731
Showing 1 changed file with 44 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,22 @@
#include "probabilistic-maximal-discrimination-power-per-criterion.hpp"

#include "../../../chrones.hpp"
#include "../../../generation.hpp" // Only for tests

#include "../../../vendored/doctest.h" // Keep last because it defines really common names like CHECK that we don't want injected into other headers


namespace lincs {

InitializeProfilesForProbabilisticMaximalDiscriminationPowerPerCriterion::InitializeProfilesForProbabilisticMaximalDiscriminationPowerPerCriterion(LearningData& learning_data_) : learning_data(learning_data_) {
CHRONE();

generators.reserve(learning_data.criteria_count);
for (unsigned criterion_index = 0; criterion_index != learning_data.criteria_count; ++criterion_index) {
generators.emplace_back();
auto& generator = generators.emplace_back();
generator.reserve(learning_data.categories_count - 1);
for (unsigned profile_index = 0; profile_index != learning_data.categories_count - 1; ++profile_index) {
generators.back().emplace_back(get_candidate_probabilities(criterion_index, profile_index));
generator.emplace_back(get_candidate_probabilities(criterion_index, profile_index));
}
}
}
Expand Down Expand Up @@ -88,26 +93,16 @@ void InitializeProfilesForProbabilisticMaximalDiscriminationPowerPerCriterion::i
const unsigned profile_index = category_index - 1;
float value = generators[criterion_index][profile_index](learning_data.urbgs[model_index]);

// Enforce profiles ordering constraint
if (criterion.category_correlation == Criterion::CategoryCorrelation::growing) {
if (profile_index != learning_data.categories_count - 2) {
value = std::min(value, learning_data.profiles[criterion_index][profile_index + 1][model_index]);
}
// @todo(Project management, soon) Add a unit test that triggers the following assertion
// (This will require removing the code to enforce the order of profiles above)
// Then restore the code to enforce the order of profiles
// Note, this assertion does not protect us from initializing a model with two identical profiles.
// Is it really that bad?
assert(
profile_index == learning_data.categories_count - 2
|| learning_data.profiles[criterion_index][profile_index + 1][model_index] >= value);
} else {
assert(criterion.category_correlation == Criterion::CategoryCorrelation::decreasing);
if (profile_index != learning_data.categories_count - 2) {
value = std::max(value, learning_data.profiles[criterion_index][profile_index + 1][model_index]);
}
assert(
profile_index == learning_data.categories_count - 2
|| learning_data.profiles[criterion_index][profile_index + 1][model_index] <= value);
}

learning_data.profiles[criterion_index][profile_index][model_index] = value;
Expand All @@ -116,4 +111,40 @@ void InitializeProfilesForProbabilisticMaximalDiscriminationPowerPerCriterion::i
}
}

TEST_CASE("Initialize profiles - respect ordering") {
Problem problem{
{
Criterion(
"Criterion 1",
Criterion::ValueType::real,
Criterion::CategoryCorrelation::growing,
0.0, 1.0
),
Criterion(
"Criterion 2",
Criterion::ValueType::real,
Criterion::CategoryCorrelation::decreasing,
0.0, 1.0
)
},
{
Category("Category 1"),
Category("Category 2"),
Category("Category 3"),
}
};
Model model = generate_mrsort_classification_model(problem, 42);
auto learning_set = generate_classified_alternatives(problem, model, 1000, 42, 0.1);
auto learning_data = LearnMrsortByWeightsProfilesBreed::LearningData::make(problem, learning_set, 1, 42);
InitializeProfilesForProbabilisticMaximalDiscriminationPowerPerCriterion initializer(learning_data);

for (unsigned iteration = 0; iteration != 10; ++iteration) {
initializer.initialize_profiles(0, 1);
// Both CHECKs fail at least once when the 'Enforce profiles ordering constraint' code is removed
CHECK(learning_data.profiles[0][0][0] <= learning_data.profiles[0][1][0]);
CHECK(learning_data.profiles[1][0][0] >= learning_data.profiles[1][1][0]);
}
}


} // namespace lincs

0 comments on commit fec0731

Please sign in to comment.