From 0834268753a293a4cf6f15db370b13dee538450a Mon Sep 17 00:00:00 2001 From: Vincent Jacques Date: Mon, 16 Oct 2023 12:53:49 +0000 Subject: [PATCH] Fix alternative names with '--max-imbalance' --- lincs/liblincs/generation.cpp | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/lincs/liblincs/generation.cpp b/lincs/liblincs/generation.cpp index 6b4b2075..2bb9b4fb 100644 --- a/lincs/liblincs/generation.cpp +++ b/lincs/liblincs/generation.cpp @@ -203,11 +203,7 @@ Alternatives generate_uniform_classified_alternatives( criteria_values[criterion_index] = values_distributions[criterion_index](gen); } - alternatives.push_back(Alternative{ - "Alternative " + std::to_string(alt_index + 1), - criteria_values, - std::nullopt, - }); + alternatives.push_back(Alternative{"", criteria_values, std::nullopt}); } Alternatives alts{problem, alternatives}; @@ -295,8 +291,6 @@ Alternatives generate_balanced_classified_alternatives( while (min_size > 0) { ++iterations_with_no_effect; - // @todo(Bug, soon) Fix naming of generated alternatives: - // index starts at zero in each call to 'generate_uniform_classified_alternatives' so names are duplicated Alternatives candidates = generate_uniform_classified_alternatives(problem, model, multiplier * alternatives_count, gen); for (const auto& candidate : candidates.alternatives) { @@ -364,11 +358,15 @@ Alternatives generate_classified_alternatives( std::mt19937 gen(random_seed); - if (max_imbalance) { - return generate_balanced_classified_alternatives(problem, model, alternatives_count, *max_imbalance, gen); - } else { - return generate_uniform_classified_alternatives(problem, model, alternatives_count, gen); + Alternatives alternatives = max_imbalance ? + generate_balanced_classified_alternatives(problem, model, alternatives_count, *max_imbalance, gen) : + generate_uniform_classified_alternatives(problem, model, alternatives_count, gen); + + for (unsigned alternative_index = 0; alternative_index != alternatives_count; ++alternative_index) { + alternatives.alternatives[alternative_index].name = "Alternative " + std::to_string(alternative_index + 1); } + + return alternatives; } void check_histogram(const Problem& problem, const Model& model, const std::optional max_imbalance, const unsigned a, const unsigned b) { @@ -395,6 +393,14 @@ TEST_CASE("Generate balanced classified alternatives") { check_histogram(problem, model, 0.0, 50, 50); } +TEST_CASE("Generate balanced classified alternatives - names are correct") { + Problem problem = generate_classification_problem(3, 2, 42); + Model model = generate_mrsort_classification_model(problem, 42, 2); + Alternatives alternatives = generate_classified_alternatives(problem, model, 100, 42, 0.); + + CHECK(alternatives.alternatives[99].name == "Alternative 100"); +} + TEST_CASE("Generate balanced classified alternatives - many seeds") { // Assert that we can generate a balanced learning set for all generated models