Skip to content

Commit

Permalink
fix: Use inserted range when retracting range in toConnectedRanges
Browse files Browse the repository at this point in the history
The `ConnectedRangesCalculator` incorrectly assumes the range for each
value is constant. The fix is to save the range on insert, and then
pass the saved range on retract. The other calculators do not have
this problem, because either:

- Their input should be immutable (ex: number)
- They used a map in their internal data structure that
  saved the original value
  • Loading branch information
Christopher-Chianelli committed Aug 27, 2024
1 parent 85e6812 commit 0d0ad32
Show file tree
Hide file tree
Showing 36 changed files with 224 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@

import ai.timefold.solver.core.api.score.stream.common.ConnectedRangeChain;
import ai.timefold.solver.core.impl.score.stream.collector.connected_ranges.ConnectedRangeTracker;
import ai.timefold.solver.core.impl.score.stream.collector.connected_ranges.Range;

public final class ConnectedRangesCalculator<Interval_, Point_ extends Comparable<Point_>, Difference_ extends Comparable<Difference_>>
implements ObjectCalculator<Interval_, ConnectedRangeChain<Interval_, Point_, Difference_>> {
implements
ObjectCalculator<Interval_, ConnectedRangeChain<Interval_, Point_, Difference_>, Range<Interval_, Point_>> {

private final ConnectedRangeTracker<Interval_, Point_, Difference_> context;

Expand All @@ -21,13 +23,15 @@ public ConnectedRangesCalculator(Function<? super Interval_, ? extends Point_> s
}

@Override
public void insert(Interval_ result) {
context.add(context.getRange(result));
public Range<Interval_, Point_> insert(Interval_ result) {
final var saved = context.getRange(result);
context.add(saved);
return saved;
}

@Override
public void retract(Interval_ result) {
context.remove(context.getRange(result));
public void retract(Range<Interval_, Point_> range) {
context.remove(range);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,19 @@

import ai.timefold.solver.core.impl.util.MutableInt;

public final class IntDistinctCountCalculator<Input_> implements ObjectCalculator<Input_, Integer> {
public final class IntDistinctCountCalculator<Input_> implements ObjectCalculator<Input_, Integer, Input_> {
private final Map<Input_, MutableInt> countMap = new HashMap<>();

@Override
public void insert(Input_ input) {
public Input_ insert(Input_ input) {
countMap.computeIfAbsent(input, ignored -> new MutableInt()).increment();
return input;
}

@Override
public void retract(Input_ input) {
if (countMap.get(input).decrement() == 0) {
countMap.remove(input);
public void retract(Input_ mapped) {
if (countMap.get(mapped).decrement() == 0) {
countMap.remove(mapped);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,19 @@

import ai.timefold.solver.core.impl.util.MutableInt;

public final class LongDistinctCountCalculator<Input_> implements ObjectCalculator<Input_, Long> {
public final class LongDistinctCountCalculator<Input_> implements ObjectCalculator<Input_, Long, Input_> {
private final Map<Input_, MutableInt> countMap = new HashMap<>();

@Override
public void insert(Input_ input) {
public Input_ insert(Input_ input) {
countMap.computeIfAbsent(input, ignored -> new MutableInt()).increment();
return input;
}

@Override
public void retract(Input_ input) {
if (countMap.get(input).decrement() == 0) {
countMap.remove(input);
public void retract(Input_ mapped) {
if (countMap.get(mapped).decrement() == 0) {
countMap.remove(mapped);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package ai.timefold.solver.core.impl.score.stream.collector;

public sealed interface ObjectCalculator<Input_, Output_>
public sealed interface ObjectCalculator<Input_, Output_, Mapped_>
permits ConnectedRangesCalculator, IntDistinctCountCalculator, LongDistinctCountCalculator, ReferenceAverageCalculator,
ReferenceSumCalculator, SequenceCalculator {
void insert(Input_ input);
Mapped_ insert(Input_ input);

void retract(Input_ input);
void retract(Mapped_ mapped);

Output_ result();
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import java.util.function.BinaryOperator;
import java.util.function.Supplier;

public final class ReferenceAverageCalculator<Input_, Output_> implements ObjectCalculator<Input_, Output_> {
public final class ReferenceAverageCalculator<Input_, Output_> implements ObjectCalculator<Input_, Output_, Input_> {
int count = 0;
Input_ sum;
final BinaryOperator<Input_> adder;
Expand Down Expand Up @@ -51,15 +51,16 @@ public static Supplier<ReferenceAverageCalculator<Duration, Duration>> duration(
}

@Override
public void insert(Input_ input) {
public Input_ insert(Input_ input) {
count++;
sum = adder.apply(sum, input);
return input;
}

@Override
public void retract(Input_ input) {
public void retract(Input_ mapped) {
count--;
sum = subtractor.apply(sum, input);
sum = subtractor.apply(sum, mapped);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import java.util.function.BinaryOperator;

public final class ReferenceSumCalculator<Result_> implements ObjectCalculator<Result_, Result_> {
public final class ReferenceSumCalculator<Result_> implements ObjectCalculator<Result_, Result_, Result_> {
private Result_ current;
private final BinaryOperator<Result_> adder;
private final BinaryOperator<Result_> subtractor;
Expand All @@ -14,13 +14,14 @@ public ReferenceSumCalculator(Result_ current, BinaryOperator<Result_> adder, Bi
}

@Override
public void insert(Result_ input) {
public Result_ insert(Result_ input) {
current = adder.apply(current, input);
return input;
}

@Override
public void retract(Result_ input) {
current = subtractor.apply(current, input);
public void retract(Result_ mapped) {
current = subtractor.apply(current, mapped);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import ai.timefold.solver.core.impl.score.stream.collector.consecutive.ConsecutiveSetTree;

public final class SequenceCalculator<Result_>
implements ObjectCalculator<Result_, SequenceChain<Result_, Integer>> {
implements ObjectCalculator<Result_, SequenceChain<Result_, Integer>, Result_> {

private final ConsecutiveSetTree<Result_, Integer, Integer> context = new ConsecutiveSetTree<>(
(Integer a, Integer b) -> b - a,
Expand All @@ -20,9 +20,10 @@ public SequenceCalculator(ToIntFunction<Result_> indexMap) {
}

@Override
public void insert(Result_ result) {
public Result_ insert(Result_ result) {
var value = indexMap.applyAsInt(result);
context.add(result, value);
return result;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import ai.timefold.solver.core.impl.score.stream.collector.ReferenceAverageCalculator;

final class AverageReferenceBiCollector<A, B, Mapped_, Average_>
extends ObjectCalculatorBiCollector<A, B, Mapped_, Average_, ReferenceAverageCalculator<Mapped_, Average_>> {
extends ObjectCalculatorBiCollector<A, B, Mapped_, Average_, Mapped_, ReferenceAverageCalculator<Mapped_, Average_>> {
private final Supplier<ReferenceAverageCalculator<Mapped_, Average_>> calculatorSupplier;

AverageReferenceBiCollector(BiFunction<? super A, ? super B, ? extends Mapped_> mapper,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@

import ai.timefold.solver.core.api.score.stream.common.ConnectedRangeChain;
import ai.timefold.solver.core.impl.score.stream.collector.ConnectedRangesCalculator;
import ai.timefold.solver.core.impl.score.stream.collector.connected_ranges.Range;

final class ConnectedRangesBiConstraintCollector<A, B, Interval_, Point_ extends Comparable<Point_>, Difference_ extends Comparable<Difference_>>
extends
ObjectCalculatorBiCollector<A, B, Interval_, ConnectedRangeChain<Interval_, Point_, Difference_>, ConnectedRangesCalculator<Interval_, Point_, Difference_>> {
ObjectCalculatorBiCollector<A, B, Interval_, ConnectedRangeChain<Interval_, Point_, Difference_>, Range<Interval_, Point_>, ConnectedRangesCalculator<Interval_, Point_, Difference_>> {

private final Function<? super Interval_, ? extends Point_> startMap;
private final Function<? super Interval_, ? extends Point_> endMap;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

final class ConsecutiveSequencesBiConstraintCollector<A, B, Result_>
extends
ObjectCalculatorBiCollector<A, B, Result_, SequenceChain<Result_, Integer>, SequenceCalculator<Result_>> {
ObjectCalculatorBiCollector<A, B, Result_, SequenceChain<Result_, Integer>, Result_, SequenceCalculator<Result_>> {

private final ToIntFunction<Result_> indexMap;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import ai.timefold.solver.core.impl.score.stream.collector.IntDistinctCountCalculator;

final class CountDistinctIntBiCollector<A, B, Mapped_>
extends ObjectCalculatorBiCollector<A, B, Mapped_, Integer, IntDistinctCountCalculator<Mapped_>> {
extends ObjectCalculatorBiCollector<A, B, Mapped_, Integer, Mapped_, IntDistinctCountCalculator<Mapped_>> {
CountDistinctIntBiCollector(BiFunction<? super A, ? super B, ? extends Mapped_> mapper) {
super(mapper);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import ai.timefold.solver.core.impl.score.stream.collector.LongDistinctCountCalculator;

final class CountDistinctLongBiCollector<A, B, Mapped_>
extends ObjectCalculatorBiCollector<A, B, Mapped_, Long, LongDistinctCountCalculator<Mapped_>> {
extends ObjectCalculatorBiCollector<A, B, Mapped_, Long, Mapped_, LongDistinctCountCalculator<Mapped_>> {
CountDistinctLongBiCollector(BiFunction<? super A, ? super B, ? extends Mapped_> mapper) {
super(mapper);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import ai.timefold.solver.core.api.score.stream.bi.BiConstraintCollector;
import ai.timefold.solver.core.impl.score.stream.collector.ObjectCalculator;

abstract sealed class ObjectCalculatorBiCollector<A, B, Input_, Output_, Calculator_ extends ObjectCalculator<Input_, Output_>>
abstract sealed class ObjectCalculatorBiCollector<A, B, Input_, Output_, Mapped_, Calculator_ extends ObjectCalculator<Input_, Output_, Mapped_>>
implements BiConstraintCollector<A, B, Calculator_, Output_>
permits AverageReferenceBiCollector, ConnectedRangesBiConstraintCollector, ConsecutiveSequencesBiConstraintCollector,
CountDistinctIntBiCollector, CountDistinctLongBiCollector, SumReferenceBiCollector {
Expand All @@ -21,9 +21,9 @@ public ObjectCalculatorBiCollector(BiFunction<? super A, ? super B, ? extends In
@Override
public TriFunction<Calculator_, A, B, Runnable> accumulator() {
return (calculator, a, b) -> {
final Input_ mapped = mapper.apply(a, b);
calculator.insert(mapped);
return () -> calculator.retract(mapped);
final var mapped = mapper.apply(a, b);
final var saved = calculator.insert(mapped);
return () -> calculator.retract(saved);
};
}

Expand All @@ -38,7 +38,7 @@ public boolean equals(Object object) {
return true;
if (object == null || getClass() != object.getClass())
return false;
var that = (ObjectCalculatorBiCollector<?, ?, ?, ?, ?>) object;
var that = (ObjectCalculatorBiCollector<?, ?, ?, ?, ?, ?>) object;
return Objects.equals(mapper, that.mapper);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import ai.timefold.solver.core.impl.score.stream.collector.ReferenceSumCalculator;

final class SumReferenceBiCollector<A, B, Result_>
extends ObjectCalculatorBiCollector<A, B, Result_, Result_, ReferenceSumCalculator<Result_>> {
extends ObjectCalculatorBiCollector<A, B, Result_, Result_, Result_, ReferenceSumCalculator<Result_>> {
private final Result_ zero;
private final BinaryOperator<Result_> adder;
private final BinaryOperator<Result_> subtractor;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import ai.timefold.solver.core.impl.score.stream.collector.ReferenceAverageCalculator;

final class AverageReferenceQuadCollector<A, B, C, D, Mapped_, Average_>
extends ObjectCalculatorQuadCollector<A, B, C, D, Mapped_, Average_, ReferenceAverageCalculator<Mapped_, Average_>> {
extends
ObjectCalculatorQuadCollector<A, B, C, D, Mapped_, Average_, Mapped_, ReferenceAverageCalculator<Mapped_, Average_>> {
private final Supplier<ReferenceAverageCalculator<Mapped_, Average_>> calculatorSupplier;

AverageReferenceQuadCollector(QuadFunction<? super A, ? super B, ? super C, ? super D, ? extends Mapped_> mapper,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
import ai.timefold.solver.core.api.function.QuadFunction;
import ai.timefold.solver.core.api.score.stream.common.ConnectedRangeChain;
import ai.timefold.solver.core.impl.score.stream.collector.ConnectedRangesCalculator;
import ai.timefold.solver.core.impl.score.stream.collector.connected_ranges.Range;

final class ConnectedRangesQuadConstraintCollector<A, B, C, D, Interval_, Point_ extends Comparable<Point_>, Difference_ extends Comparable<Difference_>>
extends
ObjectCalculatorQuadCollector<A, B, C, D, Interval_, ConnectedRangeChain<Interval_, Point_, Difference_>, ConnectedRangesCalculator<Interval_, Point_, Difference_>> {
ObjectCalculatorQuadCollector<A, B, C, D, Interval_, ConnectedRangeChain<Interval_, Point_, Difference_>, Range<Interval_, Point_>, ConnectedRangesCalculator<Interval_, Point_, Difference_>> {

private final Function<? super Interval_, ? extends Point_> startMap;
private final Function<? super Interval_, ? extends Point_> endMap;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

final class ConsecutiveSequencesQuadConstraintCollector<A, B, C, D, Result_>
extends
ObjectCalculatorQuadCollector<A, B, C, D, Result_, SequenceChain<Result_, Integer>, SequenceCalculator<Result_>> {
ObjectCalculatorQuadCollector<A, B, C, D, Result_, SequenceChain<Result_, Integer>, Result_, SequenceCalculator<Result_>> {

private final ToIntFunction<Result_> indexMap;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import ai.timefold.solver.core.impl.score.stream.collector.IntDistinctCountCalculator;

final class CountDistinctIntQuadCollector<A, B, C, D, Mapped_>
extends ObjectCalculatorQuadCollector<A, B, C, D, Mapped_, Integer, IntDistinctCountCalculator<Mapped_>> {
extends ObjectCalculatorQuadCollector<A, B, C, D, Mapped_, Integer, Mapped_, IntDistinctCountCalculator<Mapped_>> {
CountDistinctIntQuadCollector(QuadFunction<? super A, ? super B, ? super C, ? super D, ? extends Mapped_> mapper) {
super(mapper);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import ai.timefold.solver.core.impl.score.stream.collector.LongDistinctCountCalculator;

final class CountDistinctLongQuadCollector<A, B, C, D, Mapped_>
extends ObjectCalculatorQuadCollector<A, B, C, D, Mapped_, Long, LongDistinctCountCalculator<Mapped_>> {
extends ObjectCalculatorQuadCollector<A, B, C, D, Mapped_, Long, Mapped_, LongDistinctCountCalculator<Mapped_>> {
CountDistinctLongQuadCollector(QuadFunction<? super A, ? super B, ? super C, ? super D, ? extends Mapped_> mapper) {
super(mapper);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import ai.timefold.solver.core.api.score.stream.quad.QuadConstraintCollector;
import ai.timefold.solver.core.impl.score.stream.collector.ObjectCalculator;

abstract sealed class ObjectCalculatorQuadCollector<A, B, C, D, Input_, Output_, Calculator_ extends ObjectCalculator<Input_, Output_>>
abstract sealed class ObjectCalculatorQuadCollector<A, B, C, D, Input_, Output_, Mapped_, Calculator_ extends ObjectCalculator<Input_, Output_, Mapped_>>
implements QuadConstraintCollector<A, B, C, D, Calculator_, Output_>
permits AverageReferenceQuadCollector, ConnectedRangesQuadConstraintCollector,
ConsecutiveSequencesQuadConstraintCollector, CountDistinctIntQuadCollector, CountDistinctLongQuadCollector,
Expand All @@ -23,9 +23,9 @@ public ObjectCalculatorQuadCollector(QuadFunction<? super A, ? super B, ? super
@Override
public PentaFunction<Calculator_, A, B, C, D, Runnable> accumulator() {
return (calculator, a, b, c, d) -> {
final Input_ mapped = mapper.apply(a, b, c, d);
calculator.insert(mapped);
return () -> calculator.retract(mapped);
final var mapped = mapper.apply(a, b, c, d);
final var saved = calculator.insert(mapped);
return () -> calculator.retract(saved);
};
}

Expand All @@ -40,7 +40,7 @@ public boolean equals(Object object) {
return true;
if (object == null || getClass() != object.getClass())
return false;
var that = (ObjectCalculatorQuadCollector<?, ?, ?, ?, ?, ?, ?>) object;
var that = (ObjectCalculatorQuadCollector<?, ?, ?, ?, ?, ?, ?, ?>) object;
return Objects.equals(mapper, that.mapper);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import ai.timefold.solver.core.impl.score.stream.collector.ReferenceSumCalculator;

final class SumReferenceQuadCollector<A, B, C, D, Result_>
extends ObjectCalculatorQuadCollector<A, B, C, D, Result_, Result_, ReferenceSumCalculator<Result_>> {
extends ObjectCalculatorQuadCollector<A, B, C, D, Result_, Result_, Result_, ReferenceSumCalculator<Result_>> {
private final Result_ zero;
private final BinaryOperator<Result_> adder;
private final BinaryOperator<Result_> subtractor;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import ai.timefold.solver.core.impl.score.stream.collector.ReferenceAverageCalculator;

final class AverageReferenceTriCollector<A, B, C, Mapped_, Average_>
extends ObjectCalculatorTriCollector<A, B, C, Mapped_, Average_, ReferenceAverageCalculator<Mapped_, Average_>> {
extends
ObjectCalculatorTriCollector<A, B, C, Mapped_, Average_, Mapped_, ReferenceAverageCalculator<Mapped_, Average_>> {
private final Supplier<ReferenceAverageCalculator<Mapped_, Average_>> calculatorSupplier;

AverageReferenceTriCollector(TriFunction<? super A, ? super B, ? super C, ? extends Mapped_> mapper,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
import ai.timefold.solver.core.api.function.TriFunction;
import ai.timefold.solver.core.api.score.stream.common.ConnectedRangeChain;
import ai.timefold.solver.core.impl.score.stream.collector.ConnectedRangesCalculator;
import ai.timefold.solver.core.impl.score.stream.collector.connected_ranges.Range;

final class ConnectedRangesTriConstraintCollector<A, B, C, Interval_, Point_ extends Comparable<Point_>, Difference_ extends Comparable<Difference_>>
extends
ObjectCalculatorTriCollector<A, B, C, Interval_, ConnectedRangeChain<Interval_, Point_, Difference_>, ConnectedRangesCalculator<Interval_, Point_, Difference_>> {
ObjectCalculatorTriCollector<A, B, C, Interval_, ConnectedRangeChain<Interval_, Point_, Difference_>, Range<Interval_, Point_>, ConnectedRangesCalculator<Interval_, Point_, Difference_>> {

private final Function<? super Interval_, ? extends Point_> startMap;
private final Function<? super Interval_, ? extends Point_> endMap;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

final class ConsecutiveSequencesTriConstraintCollector<A, B, C, Result_>
extends
ObjectCalculatorTriCollector<A, B, C, Result_, SequenceChain<Result_, Integer>, SequenceCalculator<Result_>> {
ObjectCalculatorTriCollector<A, B, C, Result_, SequenceChain<Result_, Integer>, Result_, SequenceCalculator<Result_>> {

private final ToIntFunction<Result_> indexMap;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import ai.timefold.solver.core.impl.score.stream.collector.IntDistinctCountCalculator;

final class CountDistinctIntTriCollector<A, B, C, Mapped_>
extends ObjectCalculatorTriCollector<A, B, C, Mapped_, Integer, IntDistinctCountCalculator<Mapped_>> {
extends ObjectCalculatorTriCollector<A, B, C, Mapped_, Integer, Mapped_, IntDistinctCountCalculator<Mapped_>> {
CountDistinctIntTriCollector(TriFunction<? super A, ? super B, ? super C, ? extends Mapped_> mapper) {
super(mapper);
}
Expand Down
Loading

0 comments on commit 0d0ad32

Please sign in to comment.