Skip to content

Commit

Permalink
perf: revert most of 2c4715f (#1319)
Browse files Browse the repository at this point in the history
The commit in question caused a 10 % regression in some benchmarks.
We only keep changes that are guaranteed not to regress.
The reverted changes will be re-evaluated and, if beneficial,
re-submitted later.
  • Loading branch information
triceo authored Jan 13, 2025
1 parent c41ba83 commit df5f405
Show file tree
Hide file tree
Showing 32 changed files with 427 additions and 444 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -125,35 +125,33 @@ public BavetConstraintSession<Score_> buildSession(Solution_ workingSolution, Co
buildNodeNetwork(workingSolution, constraintStreamSet, scoreInliner, nodeNetworkVisualizationConsumer));
}

@SuppressWarnings("unchecked")
private static <Solution_, Score_ extends Score<Score_>> NodeNetwork buildNodeNetwork(Solution_ workingSolution,
Set<BavetAbstractConstraintStream<Solution_>> constraintStreamSet, AbstractScoreInliner<Score_> scoreInliner,
Consumer<String> nodeNetworkVisualizationConsumer) {
/*
* Build constraintStreamSet in reverse order to create downstream nodes first
* so every node only has final variables (some of which have downstream node method references).
*/
var buildHelper = new NodeBuildHelper<>(constraintStreamSet, scoreInliner);
var nodeList = buildNodeList(constraintStreamSet, buildHelper);
var declaredClassToNodeMap = new LinkedHashMap<Class<?>, List<AbstractForEachUniNode<?>>>();
var nodeList = buildNodeList(constraintStreamSet, buildHelper, node -> {
if (!(node instanceof AbstractForEachUniNode<?> forEachUniNode)) {
return;
}
var forEachClass = forEachUniNode.getForEachClass();
var forEachUniNodeList =
declaredClassToNodeMap.computeIfAbsent(forEachClass, k -> new ArrayList<>(2));
if (forEachUniNodeList.size() == 2) {
// Each class can have at most two forEach nodes: one including null vars, the other excluding them.
throw new IllegalStateException(
"Impossible state: For class (%s) there are already 2 nodes (%s), not adding another (%s)."
.formatted(forEachClass, forEachUniNodeList, forEachUniNode));
}
forEachUniNodeList.add(forEachUniNode);
});
if (nodeNetworkVisualizationConsumer != null) {
var visualisation = visualizeNodeNetwork(workingSolution, buildHelper, scoreInliner, nodeList);
var constraintSet = scoreInliner.getConstraints();
var visualisation = NodeGraph.of(workingSolution, nodeList, constraintSet, buildHelper::getNodeCreatingStream,
buildHelper::findParentNode)
.buildGraphvizDOT();
nodeNetworkVisualizationConsumer.accept(visualisation);
}
var declaredClassToNodeMap = new LinkedHashMap<Class<?>, List<AbstractForEachUniNode<Object>>>();
for (var node : nodeList) {
if (node instanceof AbstractForEachUniNode<?> forEachUniNode) {
var forEachClass = forEachUniNode.getForEachClass();
var forEachUniNodeList =
declaredClassToNodeMap.computeIfAbsent(forEachClass, k -> new ArrayList<>());
if (forEachUniNodeList.size() == 2) {
// Each class can have at most two forEach nodes: one including null vars, the other excluding them.
throw new IllegalStateException("Impossible state: For class (" + forEachClass
+ ") there are already 2 nodes (" + forEachUniNodeList + "), not adding another ("
+ forEachUniNode + ").");
}
forEachUniNodeList.add((AbstractForEachUniNode<Object>) forEachUniNode);
}
}
var layerMap = new TreeMap<Long, List<Propagator>>();
for (var node : nodeList) {
layerMap.computeIfAbsent(node.getLayerIndex(), k -> new ArrayList<>())
Expand All @@ -169,7 +167,8 @@ private static <Solution_, Score_ extends Score<Score_>> NodeNetwork buildNodeNe
}

private static <Solution_, Score_ extends Score<Score_>> List<AbstractNode> buildNodeList(
Set<BavetAbstractConstraintStream<Solution_>> constraintStreamSet, NodeBuildHelper<Score_> buildHelper) {
Set<BavetAbstractConstraintStream<Solution_>> constraintStreamSet, NodeBuildHelper<Score_> buildHelper,
Consumer<AbstractNode> nodeProcessor) {
/*
* Build constraintStreamSet in reverse order to create downstream nodes first
* so every node only has final variables (some of which have downstream node method references).
Expand All @@ -188,16 +187,11 @@ private static <Solution_, Score_ extends Score<Score_>> List<AbstractNode> buil
*/
node.setId(nextNodeId++);
node.setLayerIndex(determineLayerIndex(node, buildHelper));
nodeProcessor.accept(node);
}
return nodeList;
}

public static <Solution_, Score_ extends Score<Score_>> String visualizeNodeNetwork(Solution_ solution,
NodeBuildHelper<Score_> buildHelper, AbstractScoreInliner<Score_> scoreInliner, List<AbstractNode> nodeList) {
return NodeGraph.of(solution, buildHelper, nodeList, scoreInliner)
.buildGraphvizDOT();
}

/**
* Nodes are propagated in layers.
* See {@link PropagationQueue} and {@link AbstractNode} for details.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
* @param layeredNodes nodes grouped first by their layer, then by their index within the layer;
* propagation needs to happen in this order.
*/
record NodeNetwork(Map<Class<?>, List<AbstractForEachUniNode<Object>>> declaredClassToNodeMap, Propagator[][] layeredNodes) {
record NodeNetwork(Map<Class<?>, List<AbstractForEachUniNode<?>>> declaredClassToNodeMap, Propagator[][] layeredNodes) {

public static final NodeNetwork EMPTY = new NodeNetwork(Map.of(), new Propagator[0][0]);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
package ai.timefold.solver.core.impl.score.stream.bavet.bi;

import java.util.function.BiFunction;
import java.util.function.Function;

import ai.timefold.solver.core.api.function.TriPredicate;
import ai.timefold.solver.core.impl.score.stream.bavet.common.AbstractIndexedIfExistsNode;
import ai.timefold.solver.core.impl.score.stream.bavet.common.ExistsCounter;
import ai.timefold.solver.core.impl.score.stream.bavet.common.index.IndexProperties;
import ai.timefold.solver.core.impl.score.stream.bavet.common.index.Indexer;
import ai.timefold.solver.core.impl.score.stream.bavet.common.index.IndexerFactory.BiMapping;
import ai.timefold.solver.core.impl.score.stream.bavet.common.index.IndexerFactory.UniMapping;
import ai.timefold.solver.core.impl.score.stream.bavet.common.tuple.BiTuple;
import ai.timefold.solver.core.impl.score.stream.bavet.common.tuple.TupleLifecycle;
import ai.timefold.solver.core.impl.score.stream.bavet.common.tuple.UniTuple;

final class IndexedIfExistsBiNode<A, B, C> extends AbstractIndexedIfExistsNode<BiTuple<A, B>, C> {

private final BiMapping<A, B> mappingAB;
private final BiFunction<A, B, IndexProperties> mappingAB;
private final TriPredicate<A, B, C> filtering;

public IndexedIfExistsBiNode(boolean shouldExist,
BiMapping<A, B> mappingAB, UniMapping<C> mappingC,
BiFunction<A, B, IndexProperties> mappingAB, Function<C, IndexProperties> mappingC,
int inputStoreIndexLeftProperties, int inputStoreIndexLeftCounterEntry, int inputStoreIndexRightProperties,
int inputStoreIndexRightEntry,
TupleLifecycle<BiTuple<A, B>> nextNodesTupleLifecycle,
Expand All @@ -29,7 +31,7 @@ public IndexedIfExistsBiNode(boolean shouldExist,
}

public IndexedIfExistsBiNode(boolean shouldExist,
BiMapping<A, B> mappingAB, UniMapping<C> mappingC,
BiFunction<A, B, IndexProperties> mappingAB, Function<C, IndexProperties> mappingC,
int inputStoreIndexLeftProperties, int inputStoreIndexLeftCounterEntry, int inputStoreIndexLeftTrackerList,
int inputStoreIndexRightProperties, int inputStoreIndexRightEntry, int inputStoreIndexRightTrackerList,
TupleLifecycle<BiTuple<A, B>> nextNodesTupleLifecycle,
Expand All @@ -45,7 +47,7 @@ public IndexedIfExistsBiNode(boolean shouldExist,
}

@Override
protected Object createIndexProperties(BiTuple<A, B> leftTuple) {
protected IndexProperties createIndexProperties(BiTuple<A, B> leftTuple) {
return mappingAB.apply(leftTuple.factA, leftTuple.factB);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,26 +1,29 @@
package ai.timefold.solver.core.impl.score.stream.bavet.bi;

import java.util.function.BiPredicate;
import java.util.function.Function;

import ai.timefold.solver.core.impl.score.stream.bavet.common.AbstractIndexedJoinNode;
import ai.timefold.solver.core.impl.score.stream.bavet.common.index.IndexProperties;
import ai.timefold.solver.core.impl.score.stream.bavet.common.index.Indexer;
import ai.timefold.solver.core.impl.score.stream.bavet.common.index.IndexerFactory.UniMapping;
import ai.timefold.solver.core.impl.score.stream.bavet.common.tuple.BiTuple;
import ai.timefold.solver.core.impl.score.stream.bavet.common.tuple.TupleLifecycle;
import ai.timefold.solver.core.impl.score.stream.bavet.common.tuple.UniTuple;

final class IndexedJoinBiNode<A, B> extends AbstractIndexedJoinNode<UniTuple<A>, B, BiTuple<A, B>> {

private final UniMapping<A> mappingA;
private final Function<A, IndexProperties> mappingA;
private final BiPredicate<A, B> filtering;
private final int outputStoreSize;

public IndexedJoinBiNode(UniMapping<A> mappingA, UniMapping<B> mappingB,
public IndexedJoinBiNode(Function<A, IndexProperties> mappingA, Function<B, IndexProperties> mappingB,
int inputStoreIndexA, int inputStoreIndexEntryA, int inputStoreIndexOutTupleListA,
int inputStoreIndexB, int inputStoreIndexEntryB, int inputStoreIndexOutTupleListB,
TupleLifecycle<BiTuple<A, B>> nextNodesTupleLifecycle, BiPredicate<A, B> filtering,
int outputStoreSize, int outputStoreIndexOutEntryA, int outputStoreIndexOutEntryB,
Indexer<UniTuple<A>> indexerA, Indexer<UniTuple<B>> indexerB) {
int outputStoreSize,
int outputStoreIndexOutEntryA, int outputStoreIndexOutEntryB,
Indexer<UniTuple<A>> indexerA,
Indexer<UniTuple<B>> indexerB) {
super(mappingB,
inputStoreIndexA, inputStoreIndexEntryA, inputStoreIndexOutTupleListA,
inputStoreIndexB, inputStoreIndexEntryB, inputStoreIndexOutTupleListB,
Expand All @@ -33,7 +36,7 @@ public IndexedJoinBiNode(UniMapping<A> mappingA, UniMapping<B> mappingB,
}

@Override
protected Object createIndexPropertiesLeft(UniTuple<A> leftTuple) {
protected IndexProperties createIndexPropertiesLeft(UniTuple<A> leftTuple) {
return mappingA.apply(leftTuple.factA);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ protected final void updateUnchangedCounterLeft(ExistsCounter<LeftTuple_> counte
}

protected void updateCounterLeft(ExistsCounter<LeftTuple_> counter) {
var state = counter.state;
TupleState state = counter.state;
if (shouldExist ? counter.countRight > 0 : counter.countRight == 0) {
// Insert or update
switch (state) {
Expand Down Expand Up @@ -120,14 +120,14 @@ protected void decrementCounterRight(ExistsCounter<LeftTuple_> counter) {

protected ElementAwareList<FilteringTracker<LeftTuple_>> updateRightTrackerList(UniTuple<Right_> rightTuple) {
ElementAwareList<FilteringTracker<LeftTuple_>> rightTrackerList = rightTuple.getStore(inputStoreIndexRightTrackerList);
rightTrackerList.forEach(tracker -> {
decrementCounterRight(tracker.counter);
tracker.remove();
});
for (FilteringTracker<LeftTuple_> tuple : rightTrackerList) {
decrementCounterRight(tuple.counter);
tuple.remove();
}
return rightTrackerList;
}

protected void updateCounterFromLeft(UniTuple<Right_> rightTuple, LeftTuple_ leftTuple, ExistsCounter<LeftTuple_> counter,
protected void updateCounterFromLeft(LeftTuple_ leftTuple, UniTuple<Right_> rightTuple, ExistsCounter<LeftTuple_> counter,
ElementAwareList<FilteringTracker<LeftTuple_>> leftTrackerList) {
if (testFiltering(leftTuple, rightTuple)) {
counter.countRight++;
Expand All @@ -137,12 +137,12 @@ protected void updateCounterFromLeft(UniTuple<Right_> rightTuple, LeftTuple_ lef
}
}

protected void updateCounterFromRight(ExistsCounter<LeftTuple_> counter, UniTuple<Right_> rightTuple,
protected void updateCounterFromRight(UniTuple<Right_> rightTuple, ExistsCounter<LeftTuple_> counter,
ElementAwareList<FilteringTracker<LeftTuple_>> rightTrackerList) {
var leftTuple = counter.leftTuple;
if (testFiltering(leftTuple, rightTuple)) {
if (testFiltering(counter.leftTuple, rightTuple)) {
incrementCounterRight(counter);
ElementAwareList<FilteringTracker<LeftTuple_>> leftTrackerList = leftTuple.getStore(inputStoreIndexLeftTrackerList);
ElementAwareList<FilteringTracker<LeftTuple_>> leftTrackerList =
counter.leftTuple.getStore(inputStoreIndexLeftTrackerList);
new FilteringTracker<>(counter, leftTrackerList, rightTrackerList);
}
}
Expand Down Expand Up @@ -173,7 +173,6 @@ public Propagator getPropagator() {
}

protected static final class FilteringTracker<LeftTuple_ extends AbstractTuple> {

final ExistsCounter<LeftTuple_> counter;
private final ElementAwareListEntry<FilteringTracker<LeftTuple_>> leftTrackerEntry;
private final ElementAwareListEntry<FilteringTracker<LeftTuple_>> rightTrackerEntry;
Expand Down
Loading

0 comments on commit df5f405

Please sign in to comment.