Skip to content

Commit

Permalink
feat: add constraint weight to ScoreAnalysis (#416)
Browse files Browse the repository at this point in the history
Introduces constraint weight to ScoreAnalysis.
Sorts constraint matches by constraint weight first.
  • Loading branch information
triceo authored Nov 15, 2023
1 parent 4e7d7b3 commit 4415f4d
Show file tree
Hide file tree
Showing 13 changed files with 109 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
import java.util.stream.Stream;

import ai.timefold.solver.core.api.score.Score;
import ai.timefold.solver.core.api.score.calculator.ConstraintMatchAwareIncrementalScoreCalculator;
import ai.timefold.solver.core.api.score.constraint.ConstraintRef;
import ai.timefold.solver.core.api.score.stream.ConstraintJustification;
import ai.timefold.solver.core.api.solver.SolutionManager;
import ai.timefold.solver.core.impl.score.constraint.DefaultConstraintMatchTotal;
import ai.timefold.solver.core.impl.util.CollectionUtils;

/**
Expand All @@ -19,32 +21,49 @@
*
* @param <Score_>
* @param constraintRef never null
* @param weight never null
* @param score never null
* @param matches null if analysis not available;
* empty if constraint has no matches, but still non-zero constraint weight;
* non-empty if constraint has matches.
* This is a {@link List} to simplify access to individual elements,
* but it contains no duplicates just like {@link HashSet} wouldn't.
*/
public record ConstraintAnalysis<Score_ extends Score<Score_>>(ConstraintRef constraintRef, Score_ score,
List<MatchAnalysis<Score_>> matches) {
public record ConstraintAnalysis<Score_ extends Score<Score_>>(ConstraintRef constraintRef, Score_ weight,
Score_ score, List<MatchAnalysis<Score_>> matches) {

static <Score_ extends Score<Score_>> ConstraintAnalysis<Score_> of(ConstraintRef constraintRef, Score_ score) {
return new ConstraintAnalysis<>(constraintRef, score, null);
static <Score_ extends Score<Score_>> ConstraintAnalysis<Score_> of(ConstraintRef constraintRef, Score_ constraintWeight,
Score_ score) {
return new ConstraintAnalysis<>(constraintRef, constraintWeight, score, null);
}

public ConstraintAnalysis {
Objects.requireNonNull(constraintRef);
if (weight == null) {
/*
* Only possible in ConstraintMatchAwareIncrementalScoreCalculator and/or tests.
* Easy doesn't support constraint analysis at all.
* CS always provides constraint weights.
*/
throw new IllegalArgumentException("""
The constraint weight must be non-null.
Maybe use a non-deprecated %s constructor in your %s implementation?
"""
.stripTrailing()
.formatted(DefaultConstraintMatchTotal.class.getSimpleName(),
ConstraintMatchAwareIncrementalScoreCalculator.class.getSimpleName()));
}
Objects.requireNonNull(score);
}

ConstraintAnalysis<Score_> negate() {
if (matches == null) {
return ConstraintAnalysis.of(constraintRef, score.negate());
return ConstraintAnalysis.of(constraintRef, weight.negate(), score.negate());
} else {
var negatedMatchAnalyses = matches.stream()
.map(MatchAnalysis::negate)
.toList();
return new ConstraintAnalysis<>(constraintRef, score.negate(), negatedMatchAnalyses);
return new ConstraintAnalysis<>(constraintRef, weight.negate(), score.negate(), negatedMatchAnalyses);
}
}

Expand All @@ -71,9 +90,10 @@ static <Score_ extends Score<Score_>> ConstraintAnalysis<Score_> diff(
.formatted(constraintAnalysis, otherConstraintAnalysis, constraintRef));
}
// Compute the diff.
var constraintWeightDifference = constraintAnalysis.weight().subtract(otherConstraintAnalysis.weight());
var scoreDifference = constraintAnalysis.score().subtract(otherConstraintAnalysis.score());
if (matchAnalyses == null) {
return ConstraintAnalysis.of(constraintRef, scoreDifference);
return ConstraintAnalysis.of(constraintRef, constraintWeightDifference, scoreDifference);
}
var matchAnalysisMap = mapMatchesToJustifications(matchAnalyses);
var otherMatchAnalysisMap = mapMatchesToJustifications(otherMatchAnalyses);
Expand All @@ -99,7 +119,7 @@ static <Score_ extends Score<Score_>> ConstraintAnalysis<Score_> diff(
}
})
.collect(Collectors.toList());
return new ConstraintAnalysis<>(constraintRef, scoreDifference, result);
return new ConstraintAnalysis<>(constraintRef, constraintWeightDifference, scoreDifference, result);
}

private static <Score_ extends Score<Score_>> Map<ConstraintJustification, MatchAnalysis<Score_>>
Expand All @@ -121,9 +141,11 @@ static <Score_ extends Score<Score_>> ConstraintAnalysis<Score_> diff(
@Override
public String toString() {
if (matches == null) {
return "(" + score + ", no match analysis)";
return "(%s at %s, no matches)"
.formatted(score, weight);
} else {
return "(" + score + ", " + matches.size() + " matches)";
return "(%s at %s, %s matches)"
.formatted(score, weight, matches.size());
}
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package ai.timefold.solver.core.api.score.analysis;

import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;
import java.util.TreeMap;
Expand Down Expand Up @@ -38,8 +40,13 @@
*
* @param score never null
* @param constraintMap never null;
* constraints will be present even if they have no matches, unless their weight is zero;
* for each constraint identified by its {@link Constraint#getConstraintRef()},
* the {@link ConstraintAnalysis} that describes the impact of that constraint on the overall score.
* Constraints are present even if they have no matches, unless their weight is zero;
* zero-weight constraints are not present.
* Entries in the map have a stable iteration order; items are ordered first by {@link ConstraintAnalysis#weight()},
* then by {@link ConstraintAnalysis#constraintRef()}.
*
* @param <Score_>
*/
public record ScoreAnalysis<Score_ extends Score<Score_>>(Score_ score,
Expand All @@ -52,17 +59,17 @@ public record ScoreAnalysis<Score_ extends Score<Score_>>(Score_ score,
throw new IllegalArgumentException("The constraintMap must not be empty.");
}
// Ensure consistent order and no external interference.
constraintMap = Collections.unmodifiableMap(new TreeMap<>(constraintMap));
}

/**
* For each constraint identified by its {@link Constraint#getConstraintRef()} id},
* the {@link ConstraintAnalysis} that describes the impact of that constraint on the overall score.
*
* @return never null, unmodifiable
*/
public Map<ConstraintRef, ConstraintAnalysis<Score_>> constraintMap() {
return constraintMap;
var comparator = Comparator.<ConstraintAnalysis<Score_>, Score_> comparing(ConstraintAnalysis::weight)
.reversed()
.thenComparing(ConstraintAnalysis::constraintRef);
constraintMap = Collections.unmodifiableMap(constraintMap.values()
.stream()
.sorted(comparator)
.collect(Collectors.toMap(
ConstraintAnalysis::constraintRef,
Function.identity(),
(constraintAnalysis, otherConstraintAnalysis) -> constraintAnalysis,
LinkedHashMap::new)));
}

/**
Expand Down Expand Up @@ -127,8 +134,4 @@ public ScoreAnalysis<Score_> diff(ScoreAnalysis<Score_> other) {
return new ScoreAnalysis<>(score.subtract(other.score()), result);
}

@Override
public String toString() {
return "(" + score + ", " + constraintMap + ")";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ public DefaultConstraintMatchTotal(String constraintPackage, String constraintNa
this(ConstraintRef.of(constraintPackage, constraintName));
}

/**
*
* @deprecated Prefer {@link #DefaultConstraintMatchTotal(ConstraintRef, Score_)}.
*/
@Deprecated(forRemoval = true, since = "1.5.0")
public DefaultConstraintMatchTotal(ConstraintRef constraintRef) {
this.constraintRef = requireNonNull(constraintRef);
this.constraintWeight = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,12 @@ static <Score_ extends Score<Score_>> ConstraintAnalysis<Score_> getConstraintAn
return new MatchAnalysis<>(constraintMatchTotal.getConstraintRef(), score, entry.getKey());
})
.toList();
return new ConstraintAnalysis<>(constraintMatchTotal.getConstraintRef(), constraintMatchTotal.getScore(),
return new ConstraintAnalysis<>(constraintMatchTotal.getConstraintRef(), constraintMatchTotal.getConstraintWeight(),
constraintMatchTotal.getScore(),
matchAnalyses);
} else {
return new ConstraintAnalysis<>(constraintMatchTotal.getConstraintRef(), constraintMatchTotal.getScore(), null);
return new ConstraintAnalysis<>(constraintMatchTotal.getConstraintRef(), constraintMatchTotal.getConstraintWeight(),
constraintMatchTotal.getScore(), null);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public class TestdataIncrementalScoreCalculator
public void resetWorkingSolution(TestdataSolution workingSolution) {
score = 0;
constraintMatchTotal = new DefaultConstraintMatchTotal<>(
ConstraintRef.of("ai.timefold.solver.core.impl.testdata.domain", "testConstraint"));
ConstraintRef.of("ai.timefold.solver.core.impl.testdata.domain", "testConstraint"), SimpleScore.ONE);
indictmentMap = new HashMap<>();
for (TestdataEntity left : workingSolution.getEntityList()) {
TestdataValue value = left.getValue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ public class TestdataShadowingChainedIncrementalScoreCalculator
public void resetWorkingSolution(TestdataShadowingChainedSolution workingSolution) {
score = 0;
constraintMatchTotal = new DefaultConstraintMatchTotal<>(
ConstraintRef.of("ai.timefold.solver.core.impl.testdata.domain.chained.shadow", "testConstraint"));
ConstraintRef.of("ai.timefold.solver.core.impl.testdata.domain.chained.shadow", "testConstraint"),
SimpleScore.ONE);
indictmentMap = new HashMap<>();
for (TestdataShadowingChainedEntity left : workingSolution.getChainedEntityList()) {
String code = left.getCode();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ public class TestdataListWithShadowHistoryIncrementalScoreCalculator
public void resetWorkingSolution(TestdataListSolutionWithShadowHistory workingSolution) {
score = 0;
constraintMatchTotal = new DefaultConstraintMatchTotal<>(
ConstraintRef.of("ai.timefold.solver.core.impl.testdata.domain.chained.shadow", "testConstraint"));
ConstraintRef.of("ai.timefold.solver.core.impl.testdata.domain.chained.shadow", "testConstraint"),
SimpleScore.ONE);
indictmentMap = new HashMap<>();
for (TestdataListEntityWithShadowHistory left : workingSolution.getEntityList()) {
String code = left.getCode();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public final class TestdataNullableIncrementalScoreCalculator
public void resetWorkingSolution(TestdataNullableSolution workingSolution) {
score = 0;
constraintMatchTotal = new DefaultConstraintMatchTotal<>(
ConstraintRef.of("ai.timefold.solver.core.impl.testdata.domain.shadow", "testConstraint"));
ConstraintRef.of("ai.timefold.solver.core.impl.testdata.domain.shadow", "testConstraint"), SimpleScore.ONE);
indictmentMap = new HashMap<>();
for (TestdataNullableEntity left : workingSolution.getEntityList()) {
TestdataValue value = left.getValue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ public class TestdataShadowedIncrementalScoreCalculator
public void resetWorkingSolution(TestdataShadowedSolution workingSolution) {
score = 0;
constraintMatchTotal = new DefaultConstraintMatchTotal<>(
ConstraintRef.of("ai.timefold.solver.core.impl.testdata.domain.shadow", "testConstraint"));
ConstraintRef.of("ai.timefold.solver.core.impl.testdata.domain.shadow", "testConstraint"),
SimpleScore.ONE);
indictmentMap = new HashMap<>();
for (TestdataShadowedEntity left : workingSolution.getEntityList()) {
TestdataValue value = left.getValue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -440,23 +440,26 @@ public void resetWorkingSolution(MachineReassignment workingSolution, boolean co
@Override
public Collection<ConstraintMatchTotal<HardSoftLongScore>> getConstraintMatchTotals() {
DefaultConstraintMatchTotal<HardSoftLongScore> maximumCapacityMatchTotal =
new DefaultConstraintMatchTotal<>(ConstraintRef.of(CONSTRAINT_PACKAGE, MrConstraints.MAXIMUM_CAPACITY));
getConstraintMatchTotal(MrConstraints.MAXIMUM_CAPACITY, HardSoftLongScore.ONE_HARD);
DefaultConstraintMatchTotal<HardSoftLongScore> serviceConflictMatchTotal =
new DefaultConstraintMatchTotal<>(ConstraintRef.of(CONSTRAINT_PACKAGE, MrConstraints.SERVICE_CONFLICT));
getConstraintMatchTotal(MrConstraints.SERVICE_CONFLICT, HardSoftLongScore.ONE_HARD);
DefaultConstraintMatchTotal<HardSoftLongScore> serviceLocationSpreadMatchTotal =
new DefaultConstraintMatchTotal<>(ConstraintRef.of(CONSTRAINT_PACKAGE, MrConstraints.SERVICE_LOCATION_SPREAD));
getConstraintMatchTotal(MrConstraints.SERVICE_LOCATION_SPREAD, HardSoftLongScore.ONE_HARD);
DefaultConstraintMatchTotal<HardSoftLongScore> serviceDependencyMatchTotal =
new DefaultConstraintMatchTotal<>(ConstraintRef.of(CONSTRAINT_PACKAGE, MrConstraints.SERVICE_DEPENDENCY));
getConstraintMatchTotal(MrConstraints.SERVICE_DEPENDENCY, HardSoftLongScore.ONE_HARD);
DefaultConstraintMatchTotal<HardSoftLongScore> loadCostMatchTotal =
new DefaultConstraintMatchTotal<>(ConstraintRef.of(CONSTRAINT_PACKAGE, MrConstraints.LOAD_COST));
getConstraintMatchTotal(MrConstraints.LOAD_COST, HardSoftLongScore.ONE_SOFT);
DefaultConstraintMatchTotal<HardSoftLongScore> balanceCostMatchTotal =
new DefaultConstraintMatchTotal<>(ConstraintRef.of(CONSTRAINT_PACKAGE, MrConstraints.BALANCE_COST));
getConstraintMatchTotal(MrConstraints.BALANCE_COST, HardSoftLongScore.ONE_SOFT);
DefaultConstraintMatchTotal<HardSoftLongScore> processMoveCostMatchTotal =
new DefaultConstraintMatchTotal<>(ConstraintRef.of(CONSTRAINT_PACKAGE, MrConstraints.PROCESS_MOVE_COST));
getConstraintMatchTotal(MrConstraints.PROCESS_MOVE_COST,
HardSoftLongScore.ofSoft(globalPenaltyInfo.getProcessMoveCostWeight()));
DefaultConstraintMatchTotal<HardSoftLongScore> serviceMoveCostMatchTotal =
new DefaultConstraintMatchTotal<>(ConstraintRef.of(CONSTRAINT_PACKAGE, MrConstraints.SERVICE_MOVE_COST));
getConstraintMatchTotal(MrConstraints.SERVICE_MOVE_COST,
HardSoftLongScore.ofSoft(globalPenaltyInfo.getServiceMoveCostWeight()));
DefaultConstraintMatchTotal<HardSoftLongScore> machineMoveCostMatchTotal =
new DefaultConstraintMatchTotal<>(ConstraintRef.of(CONSTRAINT_PACKAGE, MrConstraints.MACHINE_MOVE_COST));
getConstraintMatchTotal(MrConstraints.MACHINE_MOVE_COST,
HardSoftLongScore.ofSoft(globalPenaltyInfo.getMachineMoveCostWeight()));

for (MrServiceScorePart serviceScorePart : serviceScorePartMap.values()) {
MrService service = serviceScorePart.service;
Expand Down Expand Up @@ -540,6 +543,11 @@ public Collection<ConstraintMatchTotal<HardSoftLongScore>> getConstraintMatchTot
return constraintMatchTotalList;
}

private static DefaultConstraintMatchTotal<HardSoftLongScore> getConstraintMatchTotal(String constraintName,
HardSoftLongScore constraintWeight) {
return new DefaultConstraintMatchTotal<>(ConstraintRef.of(CONSTRAINT_PACKAGE, constraintName), constraintWeight);
}

@Override
public Map<Object, Indictment<HardSoftLongScore>> getIndictmentMap() {
return null; // Calculate it non-incrementally from getConstraintMatchTotals()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@ public final ScoreAnalysis<Score_> deserialize(JsonParser p, DeserializationCont
var constraintPackage = constraintNode.get("package").asText();
var constraintName = constraintNode.get("name").asText();
var constraintRef = ConstraintRef.of(constraintPackage, constraintName);
var constraintWeight = parseScore(constraintNode.get("weight").asText());
var constraintScore = parseScore(constraintNode.get("score").asText());
var matchScoreList = new ArrayList<MatchAnalysis<Score_>>();
JsonNode matchesNode = constraintNode.get("matches");
if (matchesNode == null) {
constraintAnalysisList.put(constraintRef, new ConstraintAnalysis<>(constraintRef, constraintScore, null));
constraintAnalysisList.put(constraintRef,
new ConstraintAnalysis<>(constraintRef, constraintWeight, constraintScore, null));
} else {
constraintNode.get("matches").forEach(matchNode -> {
var matchScore = parseScore(matchNode.get("score").asText());
Expand All @@ -48,7 +50,7 @@ public final ScoreAnalysis<Score_> deserialize(JsonParser p, DeserializationCont
matchScoreList.add(new MatchAnalysis<>(constraintRef, matchScore, justification));
});
constraintAnalysisList.put(constraintRef,
new ConstraintAnalysis<>(constraintRef, constraintScore, matchScoreList));
new ConstraintAnalysis<>(constraintRef, constraintWeight, constraintScore, matchScoreList));
}
});
return new ScoreAnalysis<>(score, constraintAnalysisList);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ public void serialize(ScoreAnalysis<Score_> value, JsonGenerator gen, Serializer
Map<String, Object> constraintAnalysisMap = new LinkedHashMap<>();
constraintAnalysisMap.put("package", constraintRef.packageName());
constraintAnalysisMap.put("name", constraintRef.constraintName());
constraintAnalysisMap.put("weight", constraintAnalysis.weight().toString());
constraintAnalysisMap.put("score", constraintAnalysis.score().toString());
if (constraintAnalysis.matches() != null) {
List<Map<String, Object>> matchAnalysis = new ArrayList<>(constraintAnalysis.matches().size());
Expand Down
Loading

0 comments on commit 4415f4d

Please sign in to comment.