From 3a1ef628b07aac12232a961abf8a0f8af53c8068 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Frederico=20Gon=C3=A7alves?= Date: Fri, 6 Dec 2024 15:44:14 -0300 Subject: [PATCH] fix: ensure the sort manner is applied to the default CH config (#1258) This pull request corrects a bug that occurs when an entity selector configuration is created without using the sorting settings, particularly when the related entity has a difficulty weight factory defined `@PlanningEntity(difficultyWeightFactoryClass...`. The bug only affects solver configurations that do not define a phase list, as the solver sets the entity selector configuration before inner operations update the heuristic policy configuration. The issue does not affect the value selector config because no config is created by default (`DefaultSolverFactory#buildPhaseList`). --- .../core/impl/AbstractFromConfigFactory.java | 8 +- .../placer/QueuedEntityPlacerFactory.java | 16 ++-- .../placer/QueuedValuePlacerFactory.java | 3 +- .../entity/QueuedEntityPlacerFactoryTest.java | 19 +++++ .../HeuristicConfigPolicyTestUtils.java | 8 ++ .../TestdataDifficultyWeightEntity.java | 40 +++++++++ .../TestdataDifficultyWeightFactory.java | 21 +++++ .../TestdataDifficultyWeightSolution.java | 82 +++++++++++++++++++ .../TestdataDifficultyWeightValue.java | 11 +++ 9 files changed, 200 insertions(+), 8 deletions(-) create mode 100644 core/src/test/java/ai/timefold/solver/core/impl/testdata/domain/difficultyweight/TestdataDifficultyWeightEntity.java create mode 100644 core/src/test/java/ai/timefold/solver/core/impl/testdata/domain/difficultyweight/TestdataDifficultyWeightFactory.java create mode 100644 core/src/test/java/ai/timefold/solver/core/impl/testdata/domain/difficultyweight/TestdataDifficultyWeightSolution.java create mode 100644 core/src/test/java/ai/timefold/solver/core/impl/testdata/domain/difficultyweight/TestdataDifficultyWeightValue.java diff --git a/core/src/main/java/ai/timefold/solver/core/impl/AbstractFromConfigFactory.java b/core/src/main/java/ai/timefold/solver/core/impl/AbstractFromConfigFactory.java index 071b09a6a6..73c5e5a9ac 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/AbstractFromConfigFactory.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/AbstractFromConfigFactory.java @@ -29,7 +29,13 @@ public static EntitySelectorConfig getDefaultEntitySelectorConfigFor EntitySelectorConfig entitySelectorConfig = new EntitySelectorConfig() .withId(entityClass.getName()) .withEntityClass(entityClass); - if (EntitySelectorConfig.hasSorter(configPolicy.getEntitySorterManner(), entityDescriptor)) { + return deduceEntitySortManner(configPolicy, entityDescriptor, entitySelectorConfig); + } + + public static EntitySelectorConfig deduceEntitySortManner(HeuristicConfigPolicy configPolicy, + EntityDescriptor entityDescriptor, EntitySelectorConfig entitySelectorConfig) { + if (configPolicy.getEntitySorterManner() != null + && EntitySelectorConfig.hasSorter(configPolicy.getEntitySorterManner(), entityDescriptor)) { entitySelectorConfig = entitySelectorConfig.withCacheType(SelectionCacheType.PHASE) .withSelectionOrder(SelectionOrder.SORTED) .withSorterManner(configPolicy.getEntitySorterManner()); diff --git a/core/src/main/java/ai/timefold/solver/core/impl/constructionheuristic/placer/QueuedEntityPlacerFactory.java b/core/src/main/java/ai/timefold/solver/core/impl/constructionheuristic/placer/QueuedEntityPlacerFactory.java index e8cec325dc..a33271f874 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/constructionheuristic/placer/QueuedEntityPlacerFactory.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/constructionheuristic/placer/QueuedEntityPlacerFactory.java @@ -3,7 +3,6 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; -import java.util.Objects; import ai.timefold.solver.core.config.constructionheuristic.placer.QueuedEntityPlacerConfig; import ai.timefold.solver.core.config.heuristic.selector.common.SelectionCacheType; @@ -103,11 +102,16 @@ public QueuedEntityPlacer buildEntityPlacer(HeuristicConfigPolicy configPolicy) { - var entitySelectorConfig = - Objects.requireNonNullElseGet(config.getEntitySelectorConfig(), () -> { - var entityDescriptor = getTheOnlyEntityDescriptor(configPolicy.getSolutionDescriptor()); - return getDefaultEntitySelectorConfigForEntity(configPolicy, entityDescriptor); - }); + var entitySelectorConfig = config.getEntitySelectorConfig(); + if (entitySelectorConfig == null) { + var entityDescriptor = getTheOnlyEntityDescriptor(configPolicy.getSolutionDescriptor()); + entitySelectorConfig = getDefaultEntitySelectorConfigForEntity(configPolicy, entityDescriptor); + } else { + // The default phase configuration generates the entity selector config without an updated version of the configuration policy. + // We need to ensure that there are no missing sorting settings. + var entityDescriptor = deduceEntityDescriptor(configPolicy, entitySelectorConfig.getEntityClass()); + entitySelectorConfig = deduceEntitySortManner(configPolicy, entityDescriptor, entitySelectorConfig); + } var cacheType = entitySelectorConfig.getCacheType(); if (cacheType != null && cacheType.compareTo(SelectionCacheType.PHASE) < 0) { throw new IllegalArgumentException( diff --git a/core/src/main/java/ai/timefold/solver/core/impl/constructionheuristic/placer/QueuedValuePlacerFactory.java b/core/src/main/java/ai/timefold/solver/core/impl/constructionheuristic/placer/QueuedValuePlacerFactory.java index be3ddcc051..64777cf8eb 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/constructionheuristic/placer/QueuedValuePlacerFactory.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/constructionheuristic/placer/QueuedValuePlacerFactory.java @@ -91,7 +91,8 @@ protected ChangeMoveSelectorConfig buildChangeMoveSelectorConfig( EntityDescriptor entityDescriptor = variableDescriptor.getEntityDescriptor(); EntitySelectorConfig changeEntitySelectorConfig = new EntitySelectorConfig() .withEntityClass(entityDescriptor.getEntityClass()); - if (EntitySelectorConfig.hasSorter(configPolicy.getEntitySorterManner(), entityDescriptor)) { + if (configPolicy.getEntitySorterManner() != null + && EntitySelectorConfig.hasSorter(configPolicy.getEntitySorterManner(), entityDescriptor)) { changeEntitySelectorConfig = changeEntitySelectorConfig.withCacheType(SelectionCacheType.PHASE) .withSelectionOrder(SelectionOrder.SORTED) .withSorterManner(configPolicy.getEntitySorterManner()); diff --git a/core/src/test/java/ai/timefold/solver/core/impl/constructionheuristic/placer/entity/QueuedEntityPlacerFactoryTest.java b/core/src/test/java/ai/timefold/solver/core/impl/constructionheuristic/placer/entity/QueuedEntityPlacerFactoryTest.java index a37f85efa4..918f29af5d 100644 --- a/core/src/test/java/ai/timefold/solver/core/impl/constructionheuristic/placer/entity/QueuedEntityPlacerFactoryTest.java +++ b/core/src/test/java/ai/timefold/solver/core/impl/constructionheuristic/placer/entity/QueuedEntityPlacerFactoryTest.java @@ -8,9 +8,12 @@ import java.util.Arrays; import java.util.Iterator; +import java.util.List; import ai.timefold.solver.core.api.score.buildin.simple.SimpleScore; import ai.timefold.solver.core.config.constructionheuristic.placer.QueuedEntityPlacerConfig; +import ai.timefold.solver.core.config.heuristic.selector.common.SelectionOrder; +import ai.timefold.solver.core.config.heuristic.selector.entity.EntitySorterManner; import ai.timefold.solver.core.config.heuristic.selector.move.generic.ChangeMoveSelectorConfig; import ai.timefold.solver.core.config.heuristic.selector.value.ValueSelectorConfig; import ai.timefold.solver.core.impl.constructionheuristic.placer.Placement; @@ -23,6 +26,7 @@ import ai.timefold.solver.core.impl.score.director.InnerScoreDirector; import ai.timefold.solver.core.impl.solver.scope.SolverScope; import ai.timefold.solver.core.impl.testdata.domain.TestdataValue; +import ai.timefold.solver.core.impl.testdata.domain.difficultyweight.TestdataDifficultyWeightSolution; import ai.timefold.solver.core.impl.testdata.domain.multivar.TestdataMultiVarEntity; import ai.timefold.solver.core.impl.testdata.domain.multivar.TestdataMultiVarSolution; @@ -72,6 +76,21 @@ void buildFromUnfoldNew() { assertEntityPlacement(placement, "e1", "e1v1", "e1v2", "e2v1", "e2v2"); } + @Test + void buildWithEntitySortManner() { + ChangeMoveSelectorConfig primaryMoveSelectorConfig = new ChangeMoveSelectorConfig() + .withValueSelectorConfig(new ValueSelectorConfig("primaryValue")); + var configPolicy = buildHeuristicConfigPolicy(TestdataDifficultyWeightSolution.buildSolutionDescriptor(), + EntitySorterManner.DECREASING_DIFFICULTY_IF_AVAILABLE); + QueuedEntityPlacerConfig placerConfig = + QueuedEntityPlacerFactory.unfoldNew(configPolicy, List.of(primaryMoveSelectorConfig)); + var entityPlacer = + new QueuedEntityPlacerFactory(placerConfig); + var entitySelectorConfig = entityPlacer.buildEntitySelectorConfig(configPolicy); + assertThat(entitySelectorConfig.getSelectionOrder()).isEqualTo(SelectionOrder.SORTED); + assertThat(entitySelectorConfig.getSorterManner()).isEqualTo(EntitySorterManner.DECREASING_DIFFICULTY_IF_AVAILABLE); + } + private TestdataMultiVarSolution generateTestdataSolution() { TestdataMultiVarEntity entity1 = new TestdataMultiVarEntity("e1"); entity1.setPrimaryValue(new TestdataValue("e1v1")); diff --git a/core/src/test/java/ai/timefold/solver/core/impl/heuristic/HeuristicConfigPolicyTestUtils.java b/core/src/test/java/ai/timefold/solver/core/impl/heuristic/HeuristicConfigPolicyTestUtils.java index e43e367d5b..8f2d0b162a 100644 --- a/core/src/test/java/ai/timefold/solver/core/impl/heuristic/HeuristicConfigPolicyTestUtils.java +++ b/core/src/test/java/ai/timefold/solver/core/impl/heuristic/HeuristicConfigPolicyTestUtils.java @@ -2,6 +2,7 @@ import java.util.Random; +import ai.timefold.solver.core.config.heuristic.selector.entity.EntitySorterManner; import ai.timefold.solver.core.config.solver.EnvironmentMode; import ai.timefold.solver.core.impl.domain.solution.descriptor.SolutionDescriptor; import ai.timefold.solver.core.impl.solver.ClassInstanceCache; @@ -15,11 +16,18 @@ public static HeuristicConfigPolicy buildHeuristicConfigPolicy public static HeuristicConfigPolicy buildHeuristicConfigPolicy(SolutionDescriptor solutionDescriptor) { + return buildHeuristicConfigPolicy(solutionDescriptor, null); + } + + public static HeuristicConfigPolicy + buildHeuristicConfigPolicy(SolutionDescriptor solutionDescriptor, + EntitySorterManner entitySorterManner) { return new HeuristicConfigPolicy.Builder() .withEnvironmentMode(EnvironmentMode.REPRODUCIBLE) .withRandom(new Random()) .withSolutionDescriptor(solutionDescriptor) .withClassInstanceCache(ClassInstanceCache.create()) + .withEntitySorterManner(entitySorterManner) .build(); } diff --git a/core/src/test/java/ai/timefold/solver/core/impl/testdata/domain/difficultyweight/TestdataDifficultyWeightEntity.java b/core/src/test/java/ai/timefold/solver/core/impl/testdata/domain/difficultyweight/TestdataDifficultyWeightEntity.java new file mode 100644 index 0000000000..6bbd7f913c --- /dev/null +++ b/core/src/test/java/ai/timefold/solver/core/impl/testdata/domain/difficultyweight/TestdataDifficultyWeightEntity.java @@ -0,0 +1,40 @@ +package ai.timefold.solver.core.impl.testdata.domain.difficultyweight; + +import ai.timefold.solver.core.api.domain.entity.PlanningEntity; +import ai.timefold.solver.core.api.domain.variable.PlanningVariable; +import ai.timefold.solver.core.impl.domain.entity.descriptor.EntityDescriptor; +import ai.timefold.solver.core.impl.domain.variable.descriptor.GenuineVariableDescriptor; +import ai.timefold.solver.core.impl.testdata.domain.TestdataObject; + +@PlanningEntity(difficultyWeightFactoryClass = TestdataDifficultyWeightFactory.class) +public class TestdataDifficultyWeightEntity extends TestdataObject { + + public static EntityDescriptor buildEntityDescriptor() { + return TestdataDifficultyWeightSolution.buildSolutionDescriptor() + .findEntityDescriptorOrFail(TestdataDifficultyWeightEntity.class); + } + + public static GenuineVariableDescriptor buildVariableDescriptorForValue() { + return buildEntityDescriptor().getGenuineVariableDescriptor("value"); + } + + private TestdataDifficultyWeightValue value; + + public TestdataDifficultyWeightEntity(String code) { + super(code); + } + + public TestdataDifficultyWeightEntity(String code, TestdataDifficultyWeightValue value) { + this(code); + this.value = value; + } + + @PlanningVariable(valueRangeProviderRefs = "valueRange") + public TestdataDifficultyWeightValue getValue() { + return value; + } + + public void setValue(TestdataDifficultyWeightValue value) { + this.value = value; + } +} diff --git a/core/src/test/java/ai/timefold/solver/core/impl/testdata/domain/difficultyweight/TestdataDifficultyWeightFactory.java b/core/src/test/java/ai/timefold/solver/core/impl/testdata/domain/difficultyweight/TestdataDifficultyWeightFactory.java new file mode 100644 index 0000000000..435e85a4c9 --- /dev/null +++ b/core/src/test/java/ai/timefold/solver/core/impl/testdata/domain/difficultyweight/TestdataDifficultyWeightFactory.java @@ -0,0 +1,21 @@ +package ai.timefold.solver.core.impl.testdata.domain.difficultyweight; + +import ai.timefold.solver.core.impl.heuristic.selector.common.decorator.SelectionSorterWeightFactory; + +public class TestdataDifficultyWeightFactory implements + SelectionSorterWeightFactory { + + @Override + public TestdataDifficultyWeightComparable createSorterWeight(TestdataDifficultyWeightSolution solution, + TestdataDifficultyWeightEntity entity) { + return new TestdataDifficultyWeightComparable(); + } + + public static class TestdataDifficultyWeightComparable implements Comparable { + + @Override + public int compareTo(TestdataDifficultyWeightComparable other) { + return 0; + } + } +} diff --git a/core/src/test/java/ai/timefold/solver/core/impl/testdata/domain/difficultyweight/TestdataDifficultyWeightSolution.java b/core/src/test/java/ai/timefold/solver/core/impl/testdata/domain/difficultyweight/TestdataDifficultyWeightSolution.java new file mode 100644 index 0000000000..5e6e3a5f3f --- /dev/null +++ b/core/src/test/java/ai/timefold/solver/core/impl/testdata/domain/difficultyweight/TestdataDifficultyWeightSolution.java @@ -0,0 +1,82 @@ +package ai.timefold.solver.core.impl.testdata.domain.difficultyweight; + +import java.util.ArrayList; +import java.util.List; + +import ai.timefold.solver.core.api.domain.solution.PlanningEntityCollectionProperty; +import ai.timefold.solver.core.api.domain.solution.PlanningScore; +import ai.timefold.solver.core.api.domain.solution.PlanningSolution; +import ai.timefold.solver.core.api.domain.solution.ProblemFactCollectionProperty; +import ai.timefold.solver.core.api.domain.valuerange.ValueRangeProvider; +import ai.timefold.solver.core.api.score.buildin.simple.SimpleScore; +import ai.timefold.solver.core.impl.domain.solution.descriptor.SolutionDescriptor; +import ai.timefold.solver.core.impl.testdata.domain.TestdataObject; + +@PlanningSolution +public class TestdataDifficultyWeightSolution extends TestdataObject { + + public static SolutionDescriptor buildSolutionDescriptor() { + return SolutionDescriptor.buildSolutionDescriptor(TestdataDifficultyWeightSolution.class, + TestdataDifficultyWeightEntity.class); + } + + public static TestdataDifficultyWeightSolution generateSolution() { + return generateSolution(5, 7); + } + + public static TestdataDifficultyWeightSolution generateSolution(int valueListSize, int entityListSize) { + TestdataDifficultyWeightSolution solution = new TestdataDifficultyWeightSolution("Generated Solution 0"); + List valueList = new ArrayList<>(valueListSize); + for (int i = 0; i < valueListSize; i++) { + TestdataDifficultyWeightValue value = new TestdataDifficultyWeightValue("Generated Value " + i); + valueList.add(value); + } + solution.setValueList(valueList); + List entityList = new ArrayList<>(entityListSize); + for (int i = 0; i < entityListSize; i++) { + TestdataDifficultyWeightValue value = valueList.get(i % valueListSize); + TestdataDifficultyWeightEntity entity = new TestdataDifficultyWeightEntity("Generated Entity " + i, value); + entityList.add(entity); + } + solution.setEntityList(entityList); + return solution; + } + + private List valueList; + private List entityList; + + private SimpleScore score; + + public TestdataDifficultyWeightSolution(String code) { + super(code); + } + + @ValueRangeProvider(id = "valueRange") + @ProblemFactCollectionProperty + public List getValueList() { + return valueList; + } + + public void setValueList(List valueList) { + this.valueList = valueList; + } + + @PlanningEntityCollectionProperty + public List getEntityList() { + return entityList; + } + + public void setEntityList(List entityList) { + this.entityList = entityList; + } + + @PlanningScore + public SimpleScore getScore() { + return score; + } + + public void setScore(SimpleScore score) { + this.score = score; + } + +} diff --git a/core/src/test/java/ai/timefold/solver/core/impl/testdata/domain/difficultyweight/TestdataDifficultyWeightValue.java b/core/src/test/java/ai/timefold/solver/core/impl/testdata/domain/difficultyweight/TestdataDifficultyWeightValue.java new file mode 100644 index 0000000000..de8556a5bf --- /dev/null +++ b/core/src/test/java/ai/timefold/solver/core/impl/testdata/domain/difficultyweight/TestdataDifficultyWeightValue.java @@ -0,0 +1,11 @@ +package ai.timefold.solver.core.impl.testdata.domain.difficultyweight; + +import ai.timefold.solver.core.impl.testdata.domain.TestdataObject; + +public class TestdataDifficultyWeightValue extends TestdataObject { + + public TestdataDifficultyWeightValue(String code) { + super(code); + } + +}