Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add new termination config based on the best solution improvement #1234

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
9 changes: 9 additions & 0 deletions benchmark/src/main/resources/benchmark.xsd
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,15 @@
<xs:element minOccurs="0" name="moveCountLimit" type="xs:long"/>


<xs:element minOccurs="0" name="stopFlatLineDetectionRatio" type="xs:double"/>


<xs:element minOccurs="0" name="noStopFlatLineDetectionRatio" type="xs:double"/>


<xs:element minOccurs="0" name="delayFlatLineSecondsSpentLimit" type="xs:long"/>


<xs:element maxOccurs="unbounded" minOccurs="0" name="termination" type="tns:terminationConfig"/>


Expand Down
11 changes: 11 additions & 0 deletions core/src/build/revapi-differences.json
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,17 @@
"old": "method Score_ ai.timefold.solver.core.api.score.constraint.ConstraintMatch<Score_ extends ai.timefold.solver.core.api.score.Score<Score_>>::getScore()",
"new": "method Score_ ai.timefold.solver.core.api.score.constraint.ConstraintMatch<Score_ extends ai.timefold.solver.core.api.score.Score<Score_>>::getScore()",
"justification": "False positive after addition of @NonNull annotation"
},
{
"ignore": true,
"code": "java.annotation.attributeValueChanged",
"old": "class ai.timefold.solver.core.config.solver.termination.TerminationConfig",
"new": "class ai.timefold.solver.core.config.solver.termination.TerminationConfig",
"annotationType": "jakarta.xml.bind.annotation.XmlType",
"attribute": "propOrder",
"oldValue": "{\"terminationClass\", \"terminationCompositionStyle\", \"spentLimit\", \"millisecondsSpentLimit\", \"secondsSpentLimit\", \"minutesSpentLimit\", \"hoursSpentLimit\", \"daysSpentLimit\", \"unimprovedSpentLimit\", \"unimprovedMillisecondsSpentLimit\", \"unimprovedSecondsSpentLimit\", \"unimprovedMinutesSpentLimit\", \"unimprovedHoursSpentLimit\", \"unimprovedDaysSpentLimit\", \"unimprovedScoreDifferenceThreshold\", \"bestScoreLimit\", \"bestScoreFeasible\", \"stepCountLimit\", \"unimprovedStepCountLimit\", \"scoreCalculationCountLimit\", \"terminationConfigList\"}",
"newValue": "{\"terminationClass\", \"terminationCompositionStyle\", \"spentLimit\", \"millisecondsSpentLimit\", \"secondsSpentLimit\", \"minutesSpentLimit\", \"hoursSpentLimit\", \"daysSpentLimit\", \"unimprovedSpentLimit\", \"unimprovedMillisecondsSpentLimit\", \"unimprovedSecondsSpentLimit\", \"unimprovedMinutesSpentLimit\", \"unimprovedHoursSpentLimit\", \"unimprovedDaysSpentLimit\", \"unimprovedScoreDifferenceThreshold\", \"bestScoreLimit\", \"bestScoreFeasible\", \"stepCountLimit\", \"unimprovedStepCountLimit\", \"scoreCalculationCountLimit\", \"moveCountLimit\", \"stopFlatLineDetectionRatio\", \"noStopFlatLineDetectionRatio\", \"delayFlatLineSecondsSpentLimit\", \"terminationConfigList\"}",
"justification": "Add new termination config"
}
]
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
"unimprovedStepCountLimit",
"scoreCalculationCountLimit",
"moveCountLimit",
"stopFlatLineDetectionRatio",
"noStopFlatLineDetectionRatio",
"delayFlatLineSecondsSpentLimit",
"terminationConfigList"
})
public class TerminationConfig extends AbstractConfig<TerminationConfig> {
Expand Down Expand Up @@ -78,6 +81,10 @@ public class TerminationConfig extends AbstractConfig<TerminationConfig> {

private Long moveCountLimit = null;

private Double stopFlatLineDetectionRatio = null;
private Double noStopFlatLineDetectionRatio = null;
private Long delayFlatLineSecondsSpentLimit = null;

@XmlElement(name = "termination")
private List<TerminationConfig> terminationConfigList = null;

Expand Down Expand Up @@ -257,6 +264,30 @@ public void setMoveCountLimit(@Nullable Long moveCountLimit) {
this.moveCountLimit = moveCountLimit;
}

public @Nullable Double getStopFlatLineDetectionRatio() {
return stopFlatLineDetectionRatio;
}

public void setStopFlatLineDetectionRatio(@Nullable Double stopFlatLineDetectionRatio) {
this.stopFlatLineDetectionRatio = stopFlatLineDetectionRatio;
}

public @Nullable Double getNoStopFlatLineDetectionRatio() {
return noStopFlatLineDetectionRatio;
}

public void setNoStopFlatLineDetectionRatio(@Nullable Double noStopFlatLineDetectionRatio) {
this.noStopFlatLineDetectionRatio = noStopFlatLineDetectionRatio;
}

public @Nullable Long getDelayFlatLineSecondsSpentLimit() {
return delayFlatLineSecondsSpentLimit;
}

public void setDelayFlatLineSecondsSpentLimit(@Nullable Long delayFlatLineSecondsSpentLimit) {
this.delayFlatLineSecondsSpentLimit = delayFlatLineSecondsSpentLimit;
}

public @Nullable List<@NonNull TerminationConfig> getTerminationConfigList() {
return terminationConfigList;
}
Expand Down Expand Up @@ -380,6 +411,21 @@ public TerminationConfig withTerminationClass(Class<? extends Termination> termi
return this;
}

public @NonNull TerminationConfig withStopFlatLineDetectionRatio(@NonNull Double stopFlatLineDetectionRatio) {
this.stopFlatLineDetectionRatio = stopFlatLineDetectionRatio;
return this;
}

public @NonNull TerminationConfig withNoStopFlatLineDetectionRatio(@NonNull Double noStopFlatLineDetectionRatio) {
this.noStopFlatLineDetectionRatio = noStopFlatLineDetectionRatio;
return this;
}

public @NonNull TerminationConfig withDelayFlatLineSecondsSpentLimit(@NonNull Long delayFlatLineSecondsSpentLimit) {
this.delayFlatLineSecondsSpentLimit = delayFlatLineSecondsSpentLimit;
return this;
}

public @NonNull TerminationConfig
withTerminationConfigList(@NonNull List<@NonNull TerminationConfig> terminationConfigList) {
this.terminationConfigList = terminationConfigList;
Expand Down Expand Up @@ -489,6 +535,9 @@ public boolean isConfigured() {
unimprovedStepCountLimit != null ||
scoreCalculationCountLimit != null ||
moveCountLimit != null ||
stopFlatLineDetectionRatio != null ||
noStopFlatLineDetectionRatio != null ||
delayFlatLineSecondsSpentLimit != null ||
isTerminationListConfigured();
}

Expand Down Expand Up @@ -529,6 +578,12 @@ private boolean isTerminationListConfigured() {
inheritedConfig.getScoreCalculationCountLimit());
moveCountLimit = ConfigUtils.inheritOverwritableProperty(moveCountLimit,
inheritedConfig.getMoveCountLimit());
stopFlatLineDetectionRatio = ConfigUtils.inheritOverwritableProperty(stopFlatLineDetectionRatio,
inheritedConfig.getStopFlatLineDetectionRatio());
noStopFlatLineDetectionRatio = ConfigUtils.inheritOverwritableProperty(noStopFlatLineDetectionRatio,
inheritedConfig.getNoStopFlatLineDetectionRatio());
delayFlatLineSecondsSpentLimit = ConfigUtils.inheritOverwritableProperty(delayFlatLineSecondsSpentLimit,
inheritedConfig.getDelayFlatLineSecondsSpentLimit());
terminationConfigList = ConfigUtils.inheritMergeableListConfig(
terminationConfigList, inheritedConfig.getTerminationConfigList());
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ public abstract sealed class AbstractTermination<Solution_>
implements Termination<Solution_>
permits AbstractCompositeTermination, BasicPlumbingTermination, BestScoreFeasibleTermination, BestScoreTermination,
ChildThreadPlumbingTermination, MoveCountTermination, PhaseToSolverTerminationBridge, ScoreCalculationCountTermination,
StepCountTermination, TimeMillisSpentTermination, UnimprovedStepCountTermination,
StepCountTermination, TimeMillisSpentTermination, UnimprovedBestSolutionTermination, UnimprovedStepCountTermination,
UnimprovedTimeMillisSpentScoreDifferenceThresholdTermination, UnimprovedTimeMillisSpentTermination {

protected final transient Logger logger = LoggerFactory.getLogger(getClass());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,13 @@ The termination with bestScoreFeasible (%s) can only be used with a score type \
if (terminationConfig.getMoveCountLimit() != null) {
terminationList.add(new MoveCountTermination<>(terminationConfig.getMoveCountLimit()));
}
if (terminationConfig.getStopFlatLineDetectionRatio() != null
|| terminationConfig.getNoStopFlatLineDetectionRatio() != null
|| terminationConfig.getDelayFlatLineSecondsSpentLimit() != null) {
terminationList.add(new UnimprovedBestSolutionTermination<>(terminationConfig.getStopFlatLineDetectionRatio(),
terminationConfig.getNoStopFlatLineDetectionRatio(),
terminationConfig.getDelayFlatLineSecondsSpentLimit()));
}
terminationList.addAll(buildInnerTermination(configPolicy));
return buildTerminationFromList(terminationList);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
package ai.timefold.solver.core.impl.solver.termination;

import java.time.Clock;
import java.util.Objects;

import ai.timefold.solver.core.api.score.Score;
import ai.timefold.solver.core.impl.phase.scope.AbstractPhaseScope;
import ai.timefold.solver.core.impl.phase.scope.AbstractStepScope;
import ai.timefold.solver.core.impl.solver.scope.SolverScope;
import ai.timefold.solver.core.impl.solver.thread.ChildThreadType;

public final class UnimprovedBestSolutionTermination<Solution_> extends AbstractTermination<Solution_> {

// Evaluation delay to avoid early conclusions
private final long delayExecutionTimeMillis;
// This setting determines the amount of time
// that is allowed without any improvements since the last best solution was identified.
// For example, if the last solution was found at 10 seconds and the setting is configured to 0.5,
// the solver will stop if no improvement is made within 5 seconds.
private final double stopFlatLineDetectionRatio;
// This criterion functions similarly to the stopFlatLineDetectionRatio,
// as it is also used to identify periods without improvement.
// However, the key difference is that it focuses on detecting "flat lines" between solution improvements.
// When a flat line is detected after the solution has improved,
// it indicates that the previous duration was not enough to terminate the process,
// but it indicates that the solver will begin
// re-evaluating the termination criterion from the last improvement before the recent improvement.
private final double noStopFlatLineDetectionRatio;
private final Clock clock;
// The field stores the time of the first best solution of the current curve.
// If a solving process involves multiple curves,
// the value is tied to the growth of the last curve analyzed.
protected long initialCurvePointMillis;
protected long lastImprovementMillis;
private Score<?> previousBest;
protected Score<?> currentBest;
protected boolean waitForFirstBestScore;
protected Boolean terminate;

public UnimprovedBestSolutionTermination(Double stopFlatLineDetectionRatio,
Double noStopFlatLineDetectionRatio, Long delayFlatLineSecondsSpentLimit) {
this(stopFlatLineDetectionRatio, noStopFlatLineDetectionRatio, delayFlatLineSecondsSpentLimit, Clock.systemUTC());
}

public UnimprovedBestSolutionTermination(Double stopFlatLineDetectionRatio, Double noStopFlatLineDetectionRatio,
Long delayFlatLineSecondsSpentLimit, Clock clock) {
this.stopFlatLineDetectionRatio = Objects.requireNonNull(stopFlatLineDetectionRatio,
"The field stopFlatLineDetectionRatio is required for the termination UnimprovedBestSolutionTermination");
this.noStopFlatLineDetectionRatio = Objects.requireNonNull(noStopFlatLineDetectionRatio,
"The field noStopFlatLineDetectionRatio is required for the termination UnimprovedBestSolutionTermination");
this.delayExecutionTimeMillis =
(Objects.requireNonNull(delayFlatLineSecondsSpentLimit,
"The field delayFlatLineSecondsSpentLimit is required for the termination UnimprovedBestSolutionTermination")
* 1000L);
this.clock = Objects.requireNonNull(clock);
if (stopFlatLineDetectionRatio < 0) {
throw new IllegalArgumentException(
"The stopFlatLineDetectionRatio (%.2f) cannot be negative.".formatted(stopFlatLineDetectionRatio));
}
if (noStopFlatLineDetectionRatio < 0) {
throw new IllegalArgumentException(
"The noStopFlatLineDetectionRatio (%.2f) cannot be negative.".formatted(noStopFlatLineDetectionRatio));
}
if (noStopFlatLineDetectionRatio > stopFlatLineDetectionRatio) {
throw new IllegalArgumentException(
"The noStopFlatLineDetectionRatio (%.2f) cannot be greater than stopFlatLineDetectionRatio (%.2f)."
.formatted(noStopFlatLineDetectionRatio, stopFlatLineDetectionRatio));
}
if (delayFlatLineSecondsSpentLimit < 0) {
throw new IllegalArgumentException(
"The delayFlatLineSecondsSpentLimit (%d) cannot be negative.".formatted(delayFlatLineSecondsSpentLimit));
}
}

public long getDelayExecutionTimeMillis() {
return delayExecutionTimeMillis;
}

public double getStopFlatLineDetectionRatio() {
return stopFlatLineDetectionRatio;
}

public double getNoStopFlatLineDetectionRatio() {
return noStopFlatLineDetectionRatio;
}

// ************************************************************************
// Lifecycle methods
// ************************************************************************

@Override
@SuppressWarnings("unchecked")
public void phaseStarted(AbstractPhaseScope<Solution_> phaseScope) {
super.phaseStarted(phaseScope);
initialCurvePointMillis = clock.millis();
lastImprovementMillis = 0L;
currentBest = phaseScope.getBestScore();
previousBest = currentBest;
waitForFirstBestScore = true;
terminate = null;
}

@Override
public void stepStarted(AbstractStepScope<Solution_> stepScope) {
super.stepStarted(stepScope);
terminate = null;
}

@Override
@SuppressWarnings({ "unchecked", "rawtypes" })
public void stepEnded(AbstractStepScope<Solution_> stepScope) {
super.stepEnded(stepScope);
if (waitForFirstBestScore) {
waitForFirstBestScore = ((Score) currentBest).compareTo(stepScope.getScore()) >= 0;
}
}

// ************************************************************************
// Terminated methods
// ************************************************************************

@Override
public boolean isSolverTerminated(SolverScope<Solution_> solverScope) {
throw new UnsupportedOperationException(
"%s can only be used for phase termination.".formatted(getClass().getSimpleName()));
}

@Override
public boolean isPhaseTerminated(AbstractPhaseScope<Solution_> phaseScope) {
if (terminate != null) {
return terminate;
}
// Validate if there is a first best solution
if (waitForFirstBestScore) {
return false;
}
var currentTimeMillis = clock.millis();
var improved = currentBest.compareTo(phaseScope.getBestScore()) < 0;
var lastImprovementInterval = lastImprovementMillis - initialCurvePointMillis;
var completeInterval = currentTimeMillis - initialCurvePointMillis;
var newInterval = currentTimeMillis - lastImprovementMillis;
if (improved) {
// If there is a flat line between the last and new best solutions,
// the initial value becomes the most recent best score,
// as it would be the starting point for the new curve.
var minInterval = Math.floor(lastImprovementInterval * noStopFlatLineDetectionRatio);
var maxInterval = Math.floor(lastImprovementInterval * stopFlatLineDetectionRatio);
if (lastImprovementMillis > 0 && completeInterval >= delayExecutionTimeMillis && newInterval >= minInterval
&& newInterval < maxInterval) {
initialCurvePointMillis = lastImprovementMillis;
previousBest = currentBest;
if (logger.isInfoEnabled()) {
logger.debug("Starting a new curve with ({}), time interval ({}s)",
previousBest,
String.format("%.2f", completeInterval / 1000.0));
}
}
lastImprovementMillis = currentTimeMillis;
currentBest = phaseScope.getBestScore();
terminate = null;
return false;
} else {
if (completeInterval < delayExecutionTimeMillis) {
return false;
}
var maxInterval = Math.floor(lastImprovementInterval * stopFlatLineDetectionRatio);
if (newInterval > maxInterval) {
terminate = true;
return true;
} else {
terminate = null;
return false;
}
}
}

// ************************************************************************
// Time gradient methods
// ************************************************************************

@Override
public double calculateSolverTimeGradient(SolverScope<Solution_> solverScope) {
throw new UnsupportedOperationException(
"%s can only be used for phase termination.".formatted(getClass().getSimpleName()));
}

@Override
public double calculatePhaseTimeGradient(AbstractPhaseScope<Solution_> phaseScope) {
// The value will change during the solving process.
// Therefore, it is not possible to provide a number asymptotically incrementally
return -1.0;
}

// ************************************************************************
// Other methods
// ************************************************************************

@Override
public UnimprovedBestSolutionTermination<Solution_> createChildThreadTermination(SolverScope<Solution_> solverScope,
ChildThreadType childThreadType) {
return new UnimprovedBestSolutionTermination<>(stopFlatLineDetectionRatio, noStopFlatLineDetectionRatio,
delayExecutionTimeMillis / 1000, clock);
}

@Override
public String toString() {
return "UnimprovedBestSolutionTermination(%.2f, %.2f, %d)".formatted(stopFlatLineDetectionRatio,
noStopFlatLineDetectionRatio, delayExecutionTimeMillis / 1000);
}
}
6 changes: 6 additions & 0 deletions core/src/main/resources/solver.xsd
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,12 @@

<xs:element minOccurs="0" name="moveCountLimit" type="xs:long"/>

<xs:element minOccurs="0" name="stopFlatLineDetectionRatio" type="xs:double"/>

<xs:element minOccurs="0" name="noStopFlatLineDetectionRatio" type="xs:double"/>

<xs:element minOccurs="0" name="delayFlatLineSecondsSpentLimit" type="xs:long"/>

<xs:element maxOccurs="unbounded" minOccurs="0" name="termination" type="tns:terminationConfig"/>

</xs:sequence>
Expand Down
Loading
Loading