Skip to content

Commit

Permalink
feat: Add method to SolverConfig for setting unimproved termination…
Browse files Browse the repository at this point in the history
… limit
  • Loading branch information
triceo committed Dec 4, 2024
1 parent 9d7a502 commit 9c6f7b4
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,17 @@ public void setMonitoringConfig(@Nullable MonitoringConfig monitoringConfig) {
return this;
}

/**
* As defined by {@link TerminationConfig#withUnimprovedSpentLimit(Duration)}, but returns this.
*/
public @NonNull SolverConfig withTerminationUnimprovedSpentLimit(@NonNull Duration unimprovedSpentLimit) {
if (terminationConfig == null) {
terminationConfig = new TerminationConfig();
}
terminationConfig.setUnimprovedSpentLimit(unimprovedSpentLimit);
return this;
}

public @NonNull SolverConfig
withNearbyDistanceMeterClass(@NonNull Class<? extends NearbyDistanceMeter<?, ?>> distanceMeterClass) {
this.nearbyDistanceMeterClass = distanceMeterClass;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatNoException;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;

import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Reader;
import java.io.StringReader;
import java.io.StringWriter;
import java.io.UncheckedIOException;
import java.io.Writer;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.List;
Expand Down Expand Up @@ -55,7 +55,6 @@
import ai.timefold.solver.core.impl.testdata.domain.record.TestdataRecordSolution;

import org.apache.commons.io.IOUtils;
import org.assertj.core.api.Assertions;
import org.jspecify.annotations.NonNull;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
Expand All @@ -71,26 +70,29 @@ class SolverConfigTest {
@ParameterizedTest
@ValueSource(strings = { TEST_SOLVER_CONFIG_WITHOUT_NAMESPACE, TEST_SOLVER_CONFIG_WITH_NAMESPACE })
void xmlConfigRemainsSameAfterReadWrite(String solverConfigResource) throws IOException {
SolverConfig jaxbSolverConfig = readSolverConfig(solverConfigResource);
var jaxbSolverConfig = readSolverConfig(solverConfigResource);

Writer stringWriter = new StringWriter();
var stringWriter = new StringWriter();
solverConfigIO.write(jaxbSolverConfig, stringWriter);
String jaxbString = stringWriter.toString();
var jaxbString = stringWriter.toString();

String originalXml = IOUtils.toString(
SolverConfigTest.class.getResourceAsStream(solverConfigResource), StandardCharsets.UTF_8);
var originalXml =
IOUtils.toString(SolverConfigTest.class.getResourceAsStream(solverConfigResource), StandardCharsets.UTF_8);

// During writing the solver config, the solver element's namespace is removed.
String solverElementWithNamespace = SolverConfig.XML_ELEMENT_NAME + " xmlns=\"" + SolverConfig.XML_NAMESPACE + "\"";
var solverElementWithNamespace = """
%s xmlns="%s"
""".formatted(SolverConfig.XML_ELEMENT_NAME, SolverConfig.XML_NAMESPACE)
.trim();
if (originalXml.contains(solverElementWithNamespace)) {
originalXml = originalXml.replace(solverElementWithNamespace, SolverConfig.XML_ELEMENT_NAME);
}
assertThat(jaxbString).isXmlEqualTo(originalXml);
assertThat(jaxbString).isEqualToIgnoringWhitespace(originalXml);
}

@Test
void readXmlConfigWithNamespace() {
SolverConfig solverConfig = readSolverConfig(TEST_SOLVER_CONFIG_WITH_NAMESPACE);
var solverConfig = readSolverConfig(TEST_SOLVER_CONFIG_WITH_NAMESPACE);

assertThat(solverConfig).isNotNull();
assertThat(solverConfig.getPhaseConfigList())
Expand All @@ -103,7 +105,7 @@ void readXmlConfigWithNamespace() {
}

private SolverConfig readSolverConfig(String solverConfigResource) {
try (Reader reader = new InputStreamReader(SolverConfigTest.class.getResourceAsStream(solverConfigResource))) {
try (var reader = new InputStreamReader(SolverConfigTest.class.getResourceAsStream(solverConfigResource))) {
return solverConfigIO.read(reader);
} catch (IOException ioException) {
throw new UncheckedIOException(ioException);
Expand All @@ -112,29 +114,32 @@ private SolverConfig readSolverConfig(String solverConfigResource) {

@Test
void whiteCharsInClassName() {
String solutionClassName = "ai.timefold.solver.core.impl.testdata.domain.TestdataSolution";
String xmlFragment = String.format("<solver xmlns=\"https://timefold.ai/xsd/solver\">%n"
+ " <solutionClass> %s %n" // Intentionally included white chars around the class name.
+ " </solutionClass>%n"
+ "</solver>", solutionClassName);
SolverConfig solverConfig = solverConfigIO.read(new StringReader(xmlFragment));
var solutionClassName = "ai.timefold.solver.core.impl.testdata.domain.TestdataSolution";
// Intentionally included white chars around the class name.
var xmlFragment = """
<solver xmlns="https://timefold.ai/xsd/solver">
<solutionClass> %s\s
</solutionClass>
</solver>""".formatted(solutionClassName);
var solverConfig = solverConfigIO.read(new StringReader(xmlFragment));
assertThat(solverConfig.getSolutionClass().getName()).isEqualTo(solutionClassName);
}

@Test
void readAndValidateInvalidSolverConfig_failsIndicatingTheIssue() {
String solverConfigXml = "<solver xmlns=\"https://timefold.ai/xsd/solver\">\n"
+ " <constructionHeuristic>\n"
+ " <changeMoveSelector>\n"
+ " <valueSelector>\n"
// Intentionally wrong: variableName should be an attribute of the <valueSelector/>
+ " <variableName>subValue</variableName>\n"
+ " </valueSelector>\n"
+ " </changeMoveSelector>\n"
+ " </constructionHeuristic>\n"
+ "</solver>";

StringReader stringReader = new StringReader(solverConfigXml);
// Intentionally wrong: variableName should be an attribute of the <valueSelector/>
var solverConfigXml = """
<solver xmlns="https://timefold.ai/xsd/solver">
<constructionHeuristic>
<changeMoveSelector>
<valueSelector>
<variableName>subValue</variableName>
</valueSelector>
</changeMoveSelector>
</constructionHeuristic>
</solver>""";

var stringReader = new StringReader(solverConfigXml);
assertThatExceptionOfType(TimefoldXmlSerializationException.class)
.isThrownBy(() -> solverConfigIO.read(stringReader))
.withRootCauseExactlyInstanceOf(SAXParseException.class)
Expand All @@ -143,7 +148,7 @@ void readAndValidateInvalidSolverConfig_failsIndicatingTheIssue() {

@Test
void withEasyScoreCalculatorClass() {
SolverConfig solverConfig = new SolverConfig();
var solverConfig = new SolverConfig();
assertThat(solverConfig.getScoreDirectorFactoryConfig()).isNull();
solverConfig.withEasyScoreCalculatorClass(DummyEasyScoreCalculator.class);
assertThat(solverConfig.getScoreDirectorFactoryConfig().getEasyScoreCalculatorClass())
Expand All @@ -152,7 +157,7 @@ void withEasyScoreCalculatorClass() {

@Test
void withConstraintProviderClass() {
SolverConfig solverConfig = new SolverConfig();
var solverConfig = new SolverConfig();
assertThat(solverConfig.getScoreDirectorFactoryConfig()).isNull();
solverConfig.withConstraintProviderClass(DummyConstraintProvider.class);
assertThat(solverConfig.getScoreDirectorFactoryConfig().getConstraintProviderClass())
Expand All @@ -161,25 +166,36 @@ void withConstraintProviderClass() {

@Test
void withTerminationSpentLimit() {
SolverConfig solverConfig = new SolverConfig();
var solverConfig = new SolverConfig();
var duration = Duration.ofMinutes(2);
assertThat(solverConfig.getTerminationConfig()).isNull();
solverConfig.withTerminationSpentLimit(Duration.ofMinutes(2));
solverConfig.withTerminationSpentLimit(duration);
assertThat(solverConfig.getTerminationConfig().getSpentLimit())
.isEqualTo(Duration.ofMinutes(2));
.isEqualTo(duration);
}

@Test
void withTerminationUnimprovedSpentLimit() {
var solverConfig = new SolverConfig();
var duration = Duration.ofMinutes(2);
assertThat(solverConfig.getTerminationConfig()).isNull();
solverConfig.withTerminationUnimprovedSpentLimit(duration);
assertThat(solverConfig.getTerminationConfig().getUnimprovedSpentLimit())
.isEqualTo(duration);
}

@Test
void inherit() {
SolverConfig originalSolverConfig = readSolverConfig(TEST_SOLVER_CONFIG_WITHOUT_NAMESPACE);
SolverConfig inheritedSolverConfig =
var originalSolverConfig = readSolverConfig(TEST_SOLVER_CONFIG_WITHOUT_NAMESPACE);
var inheritedSolverConfig =
new SolverConfig().inherit(originalSolverConfig);
assertThat(inheritedSolverConfig).usingRecursiveComparison().isEqualTo(originalSolverConfig);
}

@Test
void visitReferencedClasses() {
SolverConfig solverConfig = readSolverConfig(TEST_SOLVER_CONFIG_WITHOUT_NAMESPACE);
Consumer<Class<?>> classVisitor = mock(Consumer.class);
var solverConfig = readSolverConfig(TEST_SOLVER_CONFIG_WITHOUT_NAMESPACE);
var classVisitor = (Consumer<Class<?>>) mock(Consumer.class);
solverConfig.visitReferencedClasses(classVisitor);
verify(classVisitor, atLeastOnce()).accept(TestdataAnnotatedExtendedSolution.class);
verify(classVisitor, atLeastOnce()).accept(TestdataEntity.class);
Expand All @@ -201,7 +217,7 @@ void solutionIsARecord() {
var solverConfig = new SolverConfig()
.withSolutionClass(DummyRecordSolution.class)
.withEntityClasses(TestdataEntity.class);
Assertions.assertThatThrownBy(() -> SolverFactory.create(solverConfig))
assertThatThrownBy(() -> SolverFactory.create(solverConfig))
.hasMessageContaining(DummyRecordSolution.class.getSimpleName())
.hasMessageContaining("record");
}
Expand All @@ -211,7 +227,7 @@ void entityIsARecord() {
var solverConfig = new SolverConfig()
.withSolutionClass(DummySolutionWithRecordEntity.class)
.withEntityClasses(DummyRecordEntity.class);
Assertions.assertThatThrownBy(() -> SolverFactory.create(solverConfig))
assertThatThrownBy(() -> SolverFactory.create(solverConfig))
.hasMessageContaining(DummyRecordEntity.class.getSimpleName())
.hasMessageContaining("record");
}
Expand All @@ -227,7 +243,7 @@ void variableWithPlanningIdIsARecord() {
.buildSolver();

var solution = TestdataRecordSolution.generateSolution();
Assertions.assertThatNoException().isThrownBy(() -> solver.solve(solution));
assertThatNoException().isThrownBy(() -> solver.solve(solution));
}

@Test
Expand All @@ -241,7 +257,7 @@ void domainClassesAreInterfaces() {
.buildSolver();

var solution = TestdataInterfaceSolution.generateSolution();
Assertions.assertThatNoException().isThrownBy(() -> solver.solve(solution));
assertThatNoException().isThrownBy(() -> solver.solve(solution));
}

@Test
Expand All @@ -250,7 +266,7 @@ void entityWithTwoPlanningListVariables() {
.withSolutionClass(DummySolutionWithTwoListVariablesEntity.class)
.withEntityClasses(DummyEntityWithTwoListVariables.class)
.withEasyScoreCalculatorClass(DummyRecordEasyScoreCalculator.class);
Assertions.assertThatThrownBy(() -> SolverFactory.create(solverConfig))
assertThatThrownBy(() -> SolverFactory.create(solverConfig))
.isExactlyInstanceOf(UnsupportedOperationException.class)
.hasMessageContaining(DummyEntityWithTwoListVariables.class.getSimpleName())
.hasMessageContaining("firstListVariable")
Expand All @@ -263,7 +279,7 @@ void entityWithMixedBasicAndPlanningListVariables() {
.withSolutionClass(DummySolutionWithMixedSimpleAndListVariableEntity.class)
.withEntityClasses(DummyEntityWithMixedSimpleAndListVariable.class)
.withEasyScoreCalculatorClass(DummyRecordEasyScoreCalculator.class);
Assertions.assertThatThrownBy(() -> SolverFactory.create(solverConfig))
assertThatThrownBy(() -> SolverFactory.create(solverConfig))
.isExactlyInstanceOf(UnsupportedOperationException.class)
.hasMessageContaining(DummyEntityWithMixedSimpleAndListVariable.class.getSimpleName())
.hasMessageContaining("listVariable")
Expand All @@ -278,7 +294,7 @@ private record DummyRecordSolution(
}

@PlanningSolution
private class DummySolutionWithRecordEntity {
private static class DummySolutionWithRecordEntity {

@PlanningEntityCollectionProperty
List<DummyRecordEntity> entities;
Expand All @@ -294,7 +310,7 @@ private record DummyRecordEntity(
}

@PlanningSolution
private class DummySolutionWithMixedSimpleAndListVariableEntity {
private static class DummySolutionWithMixedSimpleAndListVariableEntity {

@PlanningEntityCollectionProperty
List<DummyEntityWithMixedSimpleAndListVariable> entities;
Expand All @@ -311,7 +327,7 @@ private class DummySolutionWithMixedSimpleAndListVariableEntity {
}

@PlanningEntity
private class DummyEntityWithMixedSimpleAndListVariable {
private static class DummyEntityWithMixedSimpleAndListVariable {

@PlanningListVariable(valueRangeProviderRefs = "listValueRange")
private List<DummyEntityForListVariable> listVariable;
Expand All @@ -322,7 +338,7 @@ private class DummyEntityWithMixedSimpleAndListVariable {
}

@PlanningSolution
private class DummySolutionWithTwoListVariablesEntity {
private static class DummySolutionWithTwoListVariablesEntity {

@PlanningEntityCollectionProperty
List<DummyEntityWithTwoListVariables> entities;
Expand All @@ -339,7 +355,7 @@ private class DummySolutionWithTwoListVariablesEntity {
}

@PlanningEntity
private class DummyEntityWithTwoListVariables {
private static class DummyEntityWithTwoListVariables {

@PlanningListVariable(valueRangeProviderRefs = "firstListValueRange")
private List<DummyEntityForListVariable> firstListVariable;
Expand All @@ -350,7 +366,7 @@ private class DummyEntityWithTwoListVariables {
}

@PlanningEntity
private class DummyEntityForListVariable {
private static class DummyEntityForListVariable {

}

Expand Down Expand Up @@ -394,7 +410,7 @@ public abstract static class DummyMoveIteratorFactory implements MoveIteratorFac
public abstract static class DummyMoveListFactory implements MoveListFactory<TestdataSolution> {
}

public class DummyNearbyDistanceClass implements NearbyDistanceMeter<String, String> {
public static class DummyNearbyDistanceClass implements NearbyDistanceMeter<String, String> {

@Override
public double getNearbyDistance(String origin, String destination) {
Expand Down

0 comments on commit 9c6f7b4

Please sign in to comment.