Skip to content

Commit

Permalink
Fix alternative names with '--max-imbalance'
Browse files Browse the repository at this point in the history
  • Loading branch information
jacquev6 committed Oct 16, 2023
1 parent b82dd5d commit 0834268
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions lincs/liblincs/generation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,11 +203,7 @@ Alternatives generate_uniform_classified_alternatives(
criteria_values[criterion_index] = values_distributions[criterion_index](gen);
}

alternatives.push_back(Alternative{
"Alternative " + std::to_string(alt_index + 1),
criteria_values,
std::nullopt,
});
alternatives.push_back(Alternative{"", criteria_values, std::nullopt});
}

Alternatives alts{problem, alternatives};
Expand Down Expand Up @@ -295,8 +291,6 @@ Alternatives generate_balanced_classified_alternatives(
while (min_size > 0) {
++iterations_with_no_effect;

// @todo(Bug, soon) Fix naming of generated alternatives:
// index starts at zero in each call to 'generate_uniform_classified_alternatives' so names are duplicated
Alternatives candidates = generate_uniform_classified_alternatives(problem, model, multiplier * alternatives_count, gen);

for (const auto& candidate : candidates.alternatives) {
Expand Down Expand Up @@ -364,11 +358,15 @@ Alternatives generate_classified_alternatives(

std::mt19937 gen(random_seed);

if (max_imbalance) {
return generate_balanced_classified_alternatives(problem, model, alternatives_count, *max_imbalance, gen);
} else {
return generate_uniform_classified_alternatives(problem, model, alternatives_count, gen);
Alternatives alternatives = max_imbalance ?
generate_balanced_classified_alternatives(problem, model, alternatives_count, *max_imbalance, gen) :
generate_uniform_classified_alternatives(problem, model, alternatives_count, gen);

for (unsigned alternative_index = 0; alternative_index != alternatives_count; ++alternative_index) {
alternatives.alternatives[alternative_index].name = "Alternative " + std::to_string(alternative_index + 1);
}

return alternatives;
}

void check_histogram(const Problem& problem, const Model& model, const std::optional<float> max_imbalance, const unsigned a, const unsigned b) {
Expand All @@ -395,6 +393,14 @@ TEST_CASE("Generate balanced classified alternatives") {
check_histogram(problem, model, 0.0, 50, 50);
}

TEST_CASE("Generate balanced classified alternatives - names are correct") {
Problem problem = generate_classification_problem(3, 2, 42);
Model model = generate_mrsort_classification_model(problem, 42, 2);
Alternatives alternatives = generate_classified_alternatives(problem, model, 100, 42, 0.);

CHECK(alternatives.alternatives[99].name == "Alternative 100");
}

TEST_CASE("Generate balanced classified alternatives - many seeds") {
// Assert that we can generate a balanced learning set for all generated models

Expand Down

0 comments on commit 0834268

Please sign in to comment.