Skip to content

Commit

Permalink
fix: ensure the sort manner is applied to the default CH config (#1258)
Browse files Browse the repository at this point in the history
This pull request corrects a bug that occurs when an entity selector
configuration is created without using the sorting settings,
particularly when the related entity has a difficulty weight factory
defined `@PlanningEntity(difficultyWeightFactoryClass...`.

The bug only affects solver configurations that do not define a phase
list, as the solver sets the entity selector configuration before inner
operations update the heuristic policy configuration. The issue does not
affect the value selector config because no config is created by default
(`DefaultSolverFactory#buildPhaseList`).
  • Loading branch information
zepfred authored Dec 6, 2024
1 parent c6c30e9 commit 3a1ef62
Show file tree
Hide file tree
Showing 9 changed files with 200 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,13 @@ public static <Solution_> EntitySelectorConfig getDefaultEntitySelectorConfigFor
EntitySelectorConfig entitySelectorConfig = new EntitySelectorConfig()
.withId(entityClass.getName())
.withEntityClass(entityClass);
if (EntitySelectorConfig.hasSorter(configPolicy.getEntitySorterManner(), entityDescriptor)) {
return deduceEntitySortManner(configPolicy, entityDescriptor, entitySelectorConfig);
}

public static <Solution_> EntitySelectorConfig deduceEntitySortManner(HeuristicConfigPolicy<Solution_> configPolicy,
EntityDescriptor<Solution_> entityDescriptor, EntitySelectorConfig entitySelectorConfig) {
if (configPolicy.getEntitySorterManner() != null
&& EntitySelectorConfig.hasSorter(configPolicy.getEntitySorterManner(), entityDescriptor)) {
entitySelectorConfig = entitySelectorConfig.withCacheType(SelectionCacheType.PHASE)
.withSelectionOrder(SelectionOrder.SORTED)
.withSorterManner(configPolicy.getEntitySorterManner());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

import ai.timefold.solver.core.config.constructionheuristic.placer.QueuedEntityPlacerConfig;
import ai.timefold.solver.core.config.heuristic.selector.common.SelectionCacheType;
Expand Down Expand Up @@ -103,11 +102,16 @@ public QueuedEntityPlacer<Solution_> buildEntityPlacer(HeuristicConfigPolicy<Sol
}

public EntitySelectorConfig buildEntitySelectorConfig(HeuristicConfigPolicy<Solution_> configPolicy) {
var entitySelectorConfig =
Objects.requireNonNullElseGet(config.getEntitySelectorConfig(), () -> {
var entityDescriptor = getTheOnlyEntityDescriptor(configPolicy.getSolutionDescriptor());
return getDefaultEntitySelectorConfigForEntity(configPolicy, entityDescriptor);
});
var entitySelectorConfig = config.getEntitySelectorConfig();
if (entitySelectorConfig == null) {
var entityDescriptor = getTheOnlyEntityDescriptor(configPolicy.getSolutionDescriptor());
entitySelectorConfig = getDefaultEntitySelectorConfigForEntity(configPolicy, entityDescriptor);
} else {
// The default phase configuration generates the entity selector config without an updated version of the configuration policy.
// We need to ensure that there are no missing sorting settings.
var entityDescriptor = deduceEntityDescriptor(configPolicy, entitySelectorConfig.getEntityClass());
entitySelectorConfig = deduceEntitySortManner(configPolicy, entityDescriptor, entitySelectorConfig);
}
var cacheType = entitySelectorConfig.getCacheType();
if (cacheType != null && cacheType.compareTo(SelectionCacheType.PHASE) < 0) {
throw new IllegalArgumentException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ protected ChangeMoveSelectorConfig buildChangeMoveSelectorConfig(
EntityDescriptor<Solution_> entityDescriptor = variableDescriptor.getEntityDescriptor();
EntitySelectorConfig changeEntitySelectorConfig = new EntitySelectorConfig()
.withEntityClass(entityDescriptor.getEntityClass());
if (EntitySelectorConfig.hasSorter(configPolicy.getEntitySorterManner(), entityDescriptor)) {
if (configPolicy.getEntitySorterManner() != null
&& EntitySelectorConfig.hasSorter(configPolicy.getEntitySorterManner(), entityDescriptor)) {
changeEntitySelectorConfig = changeEntitySelectorConfig.withCacheType(SelectionCacheType.PHASE)
.withSelectionOrder(SelectionOrder.SORTED)
.withSorterManner(configPolicy.getEntitySorterManner());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

import ai.timefold.solver.core.api.score.buildin.simple.SimpleScore;
import ai.timefold.solver.core.config.constructionheuristic.placer.QueuedEntityPlacerConfig;
import ai.timefold.solver.core.config.heuristic.selector.common.SelectionOrder;
import ai.timefold.solver.core.config.heuristic.selector.entity.EntitySorterManner;
import ai.timefold.solver.core.config.heuristic.selector.move.generic.ChangeMoveSelectorConfig;
import ai.timefold.solver.core.config.heuristic.selector.value.ValueSelectorConfig;
import ai.timefold.solver.core.impl.constructionheuristic.placer.Placement;
Expand All @@ -23,6 +26,7 @@
import ai.timefold.solver.core.impl.score.director.InnerScoreDirector;
import ai.timefold.solver.core.impl.solver.scope.SolverScope;
import ai.timefold.solver.core.impl.testdata.domain.TestdataValue;
import ai.timefold.solver.core.impl.testdata.domain.difficultyweight.TestdataDifficultyWeightSolution;
import ai.timefold.solver.core.impl.testdata.domain.multivar.TestdataMultiVarEntity;
import ai.timefold.solver.core.impl.testdata.domain.multivar.TestdataMultiVarSolution;

Expand Down Expand Up @@ -72,6 +76,21 @@ void buildFromUnfoldNew() {
assertEntityPlacement(placement, "e1", "e1v1", "e1v2", "e2v1", "e2v2");
}

@Test
void buildWithEntitySortManner() {
ChangeMoveSelectorConfig primaryMoveSelectorConfig = new ChangeMoveSelectorConfig()
.withValueSelectorConfig(new ValueSelectorConfig("primaryValue"));
var configPolicy = buildHeuristicConfigPolicy(TestdataDifficultyWeightSolution.buildSolutionDescriptor(),
EntitySorterManner.DECREASING_DIFFICULTY_IF_AVAILABLE);
QueuedEntityPlacerConfig placerConfig =
QueuedEntityPlacerFactory.unfoldNew(configPolicy, List.of(primaryMoveSelectorConfig));
var entityPlacer =
new QueuedEntityPlacerFactory<TestdataDifficultyWeightSolution>(placerConfig);
var entitySelectorConfig = entityPlacer.buildEntitySelectorConfig(configPolicy);
assertThat(entitySelectorConfig.getSelectionOrder()).isEqualTo(SelectionOrder.SORTED);
assertThat(entitySelectorConfig.getSorterManner()).isEqualTo(EntitySorterManner.DECREASING_DIFFICULTY_IF_AVAILABLE);
}

private TestdataMultiVarSolution generateTestdataSolution() {
TestdataMultiVarEntity entity1 = new TestdataMultiVarEntity("e1");
entity1.setPrimaryValue(new TestdataValue("e1v1"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import java.util.Random;

import ai.timefold.solver.core.config.heuristic.selector.entity.EntitySorterManner;
import ai.timefold.solver.core.config.solver.EnvironmentMode;
import ai.timefold.solver.core.impl.domain.solution.descriptor.SolutionDescriptor;
import ai.timefold.solver.core.impl.solver.ClassInstanceCache;
Expand All @@ -15,11 +16,18 @@ public static HeuristicConfigPolicy<TestdataSolution> buildHeuristicConfigPolicy

public static <Solution_> HeuristicConfigPolicy<Solution_>
buildHeuristicConfigPolicy(SolutionDescriptor<Solution_> solutionDescriptor) {
return buildHeuristicConfigPolicy(solutionDescriptor, null);
}

public static <Solution_> HeuristicConfigPolicy<Solution_>
buildHeuristicConfigPolicy(SolutionDescriptor<Solution_> solutionDescriptor,
EntitySorterManner entitySorterManner) {
return new HeuristicConfigPolicy.Builder<Solution_>()
.withEnvironmentMode(EnvironmentMode.REPRODUCIBLE)
.withRandom(new Random())
.withSolutionDescriptor(solutionDescriptor)
.withClassInstanceCache(ClassInstanceCache.create())
.withEntitySorterManner(entitySorterManner)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package ai.timefold.solver.core.impl.testdata.domain.difficultyweight;

import ai.timefold.solver.core.api.domain.entity.PlanningEntity;
import ai.timefold.solver.core.api.domain.variable.PlanningVariable;
import ai.timefold.solver.core.impl.domain.entity.descriptor.EntityDescriptor;
import ai.timefold.solver.core.impl.domain.variable.descriptor.GenuineVariableDescriptor;
import ai.timefold.solver.core.impl.testdata.domain.TestdataObject;

@PlanningEntity(difficultyWeightFactoryClass = TestdataDifficultyWeightFactory.class)
public class TestdataDifficultyWeightEntity extends TestdataObject {

public static EntityDescriptor<TestdataDifficultyWeightSolution> buildEntityDescriptor() {
return TestdataDifficultyWeightSolution.buildSolutionDescriptor()
.findEntityDescriptorOrFail(TestdataDifficultyWeightEntity.class);
}

public static GenuineVariableDescriptor<TestdataDifficultyWeightSolution> buildVariableDescriptorForValue() {
return buildEntityDescriptor().getGenuineVariableDescriptor("value");
}

private TestdataDifficultyWeightValue value;

public TestdataDifficultyWeightEntity(String code) {
super(code);
}

public TestdataDifficultyWeightEntity(String code, TestdataDifficultyWeightValue value) {
this(code);
this.value = value;
}

@PlanningVariable(valueRangeProviderRefs = "valueRange")
public TestdataDifficultyWeightValue getValue() {
return value;
}

public void setValue(TestdataDifficultyWeightValue value) {
this.value = value;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package ai.timefold.solver.core.impl.testdata.domain.difficultyweight;

import ai.timefold.solver.core.impl.heuristic.selector.common.decorator.SelectionSorterWeightFactory;

public class TestdataDifficultyWeightFactory implements
SelectionSorterWeightFactory<TestdataDifficultyWeightSolution, TestdataDifficultyWeightEntity> {

@Override
public TestdataDifficultyWeightComparable createSorterWeight(TestdataDifficultyWeightSolution solution,
TestdataDifficultyWeightEntity entity) {
return new TestdataDifficultyWeightComparable();
}

public static class TestdataDifficultyWeightComparable implements Comparable<TestdataDifficultyWeightComparable> {

@Override
public int compareTo(TestdataDifficultyWeightComparable other) {
return 0;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package ai.timefold.solver.core.impl.testdata.domain.difficultyweight;

import java.util.ArrayList;
import java.util.List;

import ai.timefold.solver.core.api.domain.solution.PlanningEntityCollectionProperty;
import ai.timefold.solver.core.api.domain.solution.PlanningScore;
import ai.timefold.solver.core.api.domain.solution.PlanningSolution;
import ai.timefold.solver.core.api.domain.solution.ProblemFactCollectionProperty;
import ai.timefold.solver.core.api.domain.valuerange.ValueRangeProvider;
import ai.timefold.solver.core.api.score.buildin.simple.SimpleScore;
import ai.timefold.solver.core.impl.domain.solution.descriptor.SolutionDescriptor;
import ai.timefold.solver.core.impl.testdata.domain.TestdataObject;

@PlanningSolution
public class TestdataDifficultyWeightSolution extends TestdataObject {

public static SolutionDescriptor<TestdataDifficultyWeightSolution> buildSolutionDescriptor() {
return SolutionDescriptor.buildSolutionDescriptor(TestdataDifficultyWeightSolution.class,
TestdataDifficultyWeightEntity.class);
}

public static TestdataDifficultyWeightSolution generateSolution() {
return generateSolution(5, 7);
}

public static TestdataDifficultyWeightSolution generateSolution(int valueListSize, int entityListSize) {
TestdataDifficultyWeightSolution solution = new TestdataDifficultyWeightSolution("Generated Solution 0");
List<TestdataDifficultyWeightValue> valueList = new ArrayList<>(valueListSize);
for (int i = 0; i < valueListSize; i++) {
TestdataDifficultyWeightValue value = new TestdataDifficultyWeightValue("Generated Value " + i);
valueList.add(value);
}
solution.setValueList(valueList);
List<TestdataDifficultyWeightEntity> entityList = new ArrayList<>(entityListSize);
for (int i = 0; i < entityListSize; i++) {
TestdataDifficultyWeightValue value = valueList.get(i % valueListSize);
TestdataDifficultyWeightEntity entity = new TestdataDifficultyWeightEntity("Generated Entity " + i, value);
entityList.add(entity);
}
solution.setEntityList(entityList);
return solution;
}

private List<TestdataDifficultyWeightValue> valueList;
private List<TestdataDifficultyWeightEntity> entityList;

private SimpleScore score;

public TestdataDifficultyWeightSolution(String code) {
super(code);
}

@ValueRangeProvider(id = "valueRange")
@ProblemFactCollectionProperty
public List<TestdataDifficultyWeightValue> getValueList() {
return valueList;
}

public void setValueList(List<TestdataDifficultyWeightValue> valueList) {
this.valueList = valueList;
}

@PlanningEntityCollectionProperty
public List<TestdataDifficultyWeightEntity> getEntityList() {
return entityList;
}

public void setEntityList(List<TestdataDifficultyWeightEntity> entityList) {
this.entityList = entityList;
}

@PlanningScore
public SimpleScore getScore() {
return score;
}

public void setScore(SimpleScore score) {
this.score = score;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package ai.timefold.solver.core.impl.testdata.domain.difficultyweight;

import ai.timefold.solver.core.impl.testdata.domain.TestdataObject;

public class TestdataDifficultyWeightValue extends TestdataObject {

public TestdataDifficultyWeightValue(String code) {
super(code);
}

}

0 comments on commit 3a1ef62

Please sign in to comment.