Implement forEach
triceo committed Nov 2, 2024
1 parent 12014e8 commit d67aae2
Showing 24 changed files with 346 additions and 162 deletions.
@@ -20,19 +20,21 @@ public interface DatasetFactory<Solution_> {
* Create a cached stream of arbitrary values.
* The values are read from the planning solution.
*
* @param clz the class of the facts to read
* @param extractor the function that extracts the facts from the planning solution
* @param <A> the type of the facts
* @return a new data stream of the extracted facts
*/
<A> @NonNull UniDataStream<Solution_, A> forEach(@NonNull SolutionExtractor<Solution_, A> extractor);
<A> @NonNull UniDataStream<Solution_, A> forEach(@NonNull Class<A> clz, @NonNull SolutionExtractor<Solution_, A> extractor);

/**
* Create a cached stream of arbitrary values.
*
* @param clz the class of the facts to read
* @param collection the facts to read
* @param <A> the type of the facts
* @return a new data stream of the given facts
*/
<A> @NonNull UniDataStream<Solution_, A> forEach(@NonNull Collection<A> collection);
<A> @NonNull UniDataStream<Solution_, A> forEach(@NonNull Class<A> clz, @NonNull Collection<A> collection);

}
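
For illustration, a possible use of the two new overloads; ShiftSchedule, Shift and getShifts() are hypothetical domain names, and the extractor is assumed to return a Stream of facts (as ForEachFromSolutionUniNode.read() further down suggests):

// Hypothetical usage sketch; ShiftSchedule, Shift and getShifts() are illustrative names only.
static void defineStreams(DatasetFactory<ShiftSchedule> datasetFactory, List<Shift> pinnedShifts) {
    // Read the facts from the planning solution through an extractor.
    UniDataStream<ShiftSchedule, Shift> fromSolution =
            datasetFactory.forEach(Shift.class, (solutionView, solution) -> solution.getShifts().stream());
    // Read the facts from a fixed collection instead.
    UniDataStream<ShiftSchedule, Shift> fromCollection =
            datasetFactory.forEach(Shift.class, pinnedShifts);
}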
@@ -0,0 +1,52 @@
package ai.timefold.solver.core.impl.bavet;

import java.util.IdentityHashMap;
import java.util.Map;

import ai.timefold.solver.core.impl.bavet.uni.AbstractForEachUniNode;

public abstract class AbstractSession {

protected final NodeNetwork nodeNetwork;
private final Map<Class<?>, AbstractForEachUniNode<Object>[]> effectiveClassToNodeArrayMap;

protected AbstractSession(NodeNetwork nodeNetwork) {
this.nodeNetwork = nodeNetwork;
this.effectiveClassToNodeArrayMap = new IdentityHashMap<>(nodeNetwork.forEachNodeCount());
}

public final void insert(Object fact) {
var factClass = fact.getClass();
for (var node : findNodes(factClass)) {
node.insert(fact);
}
}

@SuppressWarnings("unchecked")
private AbstractForEachUniNode<Object>[] findNodes(Class<?> factClass) {
// Map.computeIfAbsent() would have created capturing lambdas on the hot path; this explicit get/put does not.
var nodeArray = effectiveClassToNodeArrayMap.get(factClass);
if (nodeArray == null) {
nodeArray = nodeNetwork.getForEachNodes(factClass)
.filter(AbstractForEachUniNode::supportsIndividualUpdates)
.toArray(AbstractForEachUniNode[]::new);
effectiveClassToNodeArrayMap.put(factClass, nodeArray);
}
return nodeArray;
}

public final void update(Object fact) {
var factClass = fact.getClass();
for (var node : findNodes(factClass)) {
node.update(fact);
}
}

public final void retract(Object fact) {
var factClass = fact.getClass();
for (var node : findNodes(factClass)) {
node.retract(fact);
}
}

}
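
For contrast, a Map.computeIfAbsent() variant of findNodes() would be functionally equivalent, but the capturing lambda would be allocated on every call, which the explicit get/put above avoids; a minimal sketch of the avoided variant (the method name is hypothetical):

// Sketch of the avoided alternative: the lambda captures this and is allocated per call.
@SuppressWarnings("unchecked")
private AbstractForEachUniNode<Object>[] findNodesViaComputeIfAbsent(Class<?> factClass) {
    return effectiveClassToNodeArrayMap.computeIfAbsent(factClass,
            clz -> nodeNetwork.getForEachNodes(clz)
                    .filter(AbstractForEachUniNode::supportsIndividualUpdates)
                    .toArray(AbstractForEachUniNode[]::new));
}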
@@ -4,6 +4,7 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Stream;

import ai.timefold.solver.core.impl.bavet.common.Propagator;
import ai.timefold.solver.core.impl.bavet.uni.AbstractForEachUniNode;
@@ -17,10 +18,9 @@
* @param layeredNodes nodes grouped first by their layer, then by their index within the layer;
* propagation needs to happen in this order.
*/
public record NodeNetwork<Solution_>(Map<Class<?>, List<AbstractForEachUniNode<Solution_, ?>>> declaredClassToNodeMap,
Propagator[][] layeredNodes) {
public record NodeNetwork(Map<Class<?>, List<AbstractForEachUniNode<?>>> declaredClassToNodeMap, Propagator[][] layeredNodes) {

public static final NodeNetwork<?> EMPTY = new NodeNetwork<>(Map.of(), new Propagator[0][0]);
public static final NodeNetwork EMPTY = new NodeNetwork(Map.of(), new Propagator[0][0]);

public int forEachNodeCount() {
return declaredClassToNodeMap.size();
@@ -30,14 +30,18 @@ public int layerCount() {
return layeredNodes.length;
}

@SuppressWarnings("unchecked")
public AbstractForEachUniNode<Solution_, Object>[] getApplicableForEachNodes(Class<?> factClass) {
public Stream<AbstractForEachUniNode<?>> getForEachNodes() {
return declaredClassToNodeMap.values()
.stream()
.flatMap(List::stream);
}

public Stream<AbstractForEachUniNode<?>> getForEachNodes(Class<?> factClass) {
return declaredClassToNodeMap.entrySet()
.stream()
.filter(entry -> entry.getKey().isAssignableFrom(factClass))
.map(Map.Entry::getValue)
.flatMap(List::stream)
.toArray(AbstractForEachUniNode[]::new);
.flatMap(List::stream);
}

public void propagate() {
@@ -67,7 +71,7 @@ private static void propagateInLayer(Propagator[] nodesInLayer) {
public boolean equals(Object o) {
if (this == o)
return true;
if (!(o instanceof NodeNetwork<?> that))
if (!(o instanceof NodeNetwork that))
return false;
return Objects.equals(declaredClassToNodeMap, that.declaredClassToNodeMap)
&& Objects.deepEquals(layeredNodes, that.layeredNodes);
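
Since getForEachNodes(Class) matches declared classes with isAssignableFrom(), a fact also reaches nodes declared for any of its supertypes; a small sketch, where NightShift extends Shift are hypothetical types:

// Hypothetical: nodes declared for Shift (a supertype) also receive NightShift facts.
Stream<AbstractForEachUniNode<?>> nodesForNightShift = nodeNetwork.getForEachNodes(NightShift.class);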
@@ -15,7 +15,7 @@
import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle;
import ai.timefold.solver.core.impl.bavet.uni.AbstractForEachUniNode;

public abstract class AbstractNodeBuildHelper<Stream_ extends Stream> {
public abstract class AbstractNodeBuildHelper<Stream_ extends BavetStream> {

private final Set<Stream_> activeStreamSet;
private final Map<AbstractNode, Stream_> nodeCreatorMap;
@@ -44,7 +44,7 @@ public void addNode(AbstractNode node, Stream_ creator) {
public void addNode(AbstractNode node, Stream_ creator, Stream_ parent) {
reversedNodeList.add(node);
nodeCreatorMap.put(node, creator);
if (!(node instanceof AbstractForEachUniNode<?, ?>)) {
if (!(node instanceof AbstractForEachUniNode<?>)) {
if (parent == null) {
throw new IllegalStateException("Impossible state: The node (" + node + ") has no parent (" + parent + ").");
}
@@ -87,7 +87,7 @@ public <Tuple_ extends AbstractTuple> TupleLifecycle<Tuple_> getAggregatedTupleL

@SuppressWarnings("unchecked")
private static <Stream_, Tuple_ extends AbstractTuple> TupleLifecycle<Tuple_> getTupleLifecycle(Stream_ stream,
Map<Stream_, TupleLifecycle<? extends AbstractTuple>> tupleLifecycleMap) {
Map<Stream_, TupleLifecycle<? extends AbstractTuple>> tupleLifecycleMap) {
var tupleLifecycle = (TupleLifecycle<Tuple_>) tupleLifecycleMap.get(stream);
if (tupleLifecycle == null) {
throw new IllegalStateException("Impossible state: the stream (" + stream + ") hasn't built a node yet.");
@@ -1,6 +1,6 @@
package ai.timefold.solver.core.impl.bavet.common;

public interface Stream {
public interface BavetStream {

<Stream_> Stream_ getParent();

@@ -1,6 +1,6 @@
package ai.timefold.solver.core.impl.bavet.common;

public interface BavetStreamBinaryOperation<Stream_ extends Stream> {
public interface BavetStreamBinaryOperation<Stream_ extends BavetStream> {

/**
* @return An instance of fore bridge.
@@ -311,7 +311,7 @@ AbstractNode build(KeyA_ keyMappingA,
TupleLifecycle<Tuple_> nextNodesTupleLifecycle, int outputStoreSize, EnvironmentMode environmentMode);
}

<Stream_ extends Stream> void build(AbstractNodeBuildHelper<Stream_> buildHelper, Stream_ parentTupleSource,
<Stream_ extends BavetStream> void build(AbstractNodeBuildHelper<Stream_> buildHelper, Stream_ parentTupleSource,
Stream_ aftStream, List<Stream_> aftStreamChildList, Stream_ thisStream, List<Stream_> thisStreamChildList,
EnvironmentMode environmentMode);

@@ -19,7 +19,7 @@ public GroupNodeConstructorWithAccumulate(Object equalityKey,
}

@Override
public <Stream_ extends Stream> void build(AbstractNodeBuildHelper<Stream_> buildHelper, Stream_ parentTupleSource,
public <Stream_ extends BavetStream> void build(AbstractNodeBuildHelper<Stream_> buildHelper, Stream_ parentTupleSource,
Stream_ aftStream, List<Stream_> aftStreamChildList, Stream_ bridgeStream, List<Stream_> bridgeStreamChildList,
EnvironmentMode environmentMode) {
if (!bridgeStreamChildList.isEmpty()) {
@@ -19,7 +19,7 @@ public GroupNodeConstructorWithoutAccumulate(Object equalityKey,
}

@Override
public <Stream_ extends Stream> void build(AbstractNodeBuildHelper<Stream_> buildHelper, Stream_ parentTupleSource,
public <Stream_ extends BavetStream> void build(AbstractNodeBuildHelper<Stream_> buildHelper, Stream_ parentTupleSource,
Stream_ aftStream, List<Stream_> aftStreamChildList, Stream_ bridgeStream, List<Stream_> bridgeStreamChildList,
EnvironmentMode environmentMode) {
if (!bridgeStreamChildList.isEmpty()) {
@@ -14,29 +14,29 @@
* Filtering nodes are expensive.
* Considering that most streams start with a nullity check on genuine planning variables,
* it makes sense to create a specialized version of the node for this case ({@link ForEachExcludingUnassignedUniNode}),
* as opposed to forcing an extra filter node on the generic case ({@link ForEachIncludingUnassignedUniNode}).
* as opposed to forcing an extra filter node on the generic case ({@link ForEachUniNode}).
*
* @param <A> the type of the facts processed by this node
*/
public abstract sealed class AbstractForEachUniNode<Solution_, A>
public abstract sealed class AbstractForEachUniNode<A>
extends AbstractNode
permits ForEachExcludingUnassignedUniNode, ForEachIncludingUnassignedUniNode {
permits ForEachExcludingUnassignedUniNode, ForEachUniNode, ForEachStaticUniNode {

private final Class<A> forEachClass;
private final int outputStoreSize;
private final StaticPropagationQueue<UniTuple<A>> propagationQueue;
protected final Map<A, UniTuple<A>> tupleMap = new IdentityHashMap<>(1000);

public AbstractForEachUniNode(Class<A> forEachClass, TupleLifecycle<UniTuple<A>> nextNodesTupleLifecycle,
protected AbstractForEachUniNode(Class<A> forEachClass, TupleLifecycle<UniTuple<A>> nextNodesTupleLifecycle,
int outputStoreSize) {
this.forEachClass = forEachClass;
this.outputStoreSize = outputStoreSize;
this.propagationQueue = new StaticPropagationQueue<>(nextNodesTupleLifecycle);
}

public void insert(A a) {
UniTuple<A> tuple = new UniTuple<>(a, outputStoreSize);
UniTuple<A> old = tupleMap.put(a, tuple);
var tuple = new UniTuple<>(a, outputStoreSize);
var old = tupleMap.put(a, tuple);
if (old != null) {
throw new IllegalStateException("The fact (" + a + ") was already inserted, so it cannot insert again.");
}
@@ -45,8 +45,8 @@ public void insert(A a) {

public abstract void update(A a);

protected final void innerUpdate(A a, UniTuple<A> tuple) {
TupleState state = tuple.state;
protected final void updateExisting(A a, UniTuple<A> tuple) {
var state = tuple.state;
if (state.isDirty()) {
if (state == TupleState.DYING || state == TupleState.ABORTING) {
throw new IllegalStateException("The fact (" + a + ") was retracted, so it cannot update.");
@@ -58,11 +58,15 @@ protected final void innerUpdate(A a, UniTuple<A> tuple) {
}

public void retract(A a) {
UniTuple<A> tuple = tupleMap.remove(a);
var tuple = tupleMap.remove(a);
if (tuple == null) {
throw new IllegalStateException("The fact (" + a + ") was never inserted, so it cannot retract.");
}
TupleState state = tuple.state;
retractExisting(a, tuple);
}

protected void retractExisting(A a, UniTuple<A> tuple) {
var state = tuple.state;
if (state.isDirty()) {
if (state == TupleState.DYING || state == TupleState.ABORTING) {
throw new IllegalStateException("The fact (" + a + ") was already retracted, so it cannot retract.");
@@ -82,6 +86,10 @@ public final Class<A> getForEachClass() {
return forEachClass;
}

public boolean supportsIndividualUpdates() {
return true;
}

@Override
public final String toString() {
return super.toString() + "(" + forEachClass.getSimpleName() + ")";
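
The "nullity check on genuine planning variables" mentioned in the Javadoc above typically boils down to a predicate like the following, which ForEachExcludingUnassignedUniNode absorbs instead of requiring a separate filter node (Shift and getEmployee() are hypothetical names):

// Typical filter absorbed by the excluding node rather than a downstream filter node.
Predicate<Shift> isAssigned = shift -> shift.getEmployee() != null;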
@@ -6,7 +6,7 @@
import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle;
import ai.timefold.solver.core.impl.bavet.common.tuple.UniTuple;

public final class ForEachExcludingUnassignedUniNode<Solution_, A> extends AbstractForEachUniNode<Solution_, A> {
public final class ForEachExcludingUnassignedUniNode<A> extends AbstractForEachUniNode<A> {

private final Predicate<A> filter;

@@ -26,22 +26,23 @@ public void insert(A a) {

@Override
public void update(A a) {
UniTuple<A> tuple = tupleMap.get(a);
var tuple = tupleMap.get(a);
if (tuple == null) { // The tuple was never inserted because it did not pass the filter.
insert(a);
} else if (filter.test(a)) {
innerUpdate(a, tuple);
} else {
super.retract(a); // Call super.retract() to avoid testing the filter again.
updateExisting(a, tuple);
} else { // Tuple no longer passes the filter.
retract(a);
}
}

@Override
public void retract(A a) {
if (!filter.test(a)) { // The tuple was never inserted because it did not pass the filter.
var tuple = tupleMap.remove(a);
if (tuple == null) { // The tuple was never inserted because it did not pass the filter.
return;
}
super.retract(a);
super.retractExisting(a, tuple);
}

}
@@ -0,0 +1,68 @@
package ai.timefold.solver.core.impl.bavet.uni;

import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.Objects;

import ai.timefold.solver.core.api.move.SolutionExtractor;
import ai.timefold.solver.core.api.move.SolutionView;
import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle;
import ai.timefold.solver.core.impl.bavet.common.tuple.UniTuple;

public final class ForEachFromSolutionUniNode<Solution_, A>
extends ForEachUniNode<A> {

private final SolutionExtractor<Solution_, A> solutionExtractor;

public ForEachFromSolutionUniNode(Class<A> forEachClass, SolutionExtractor<Solution_, A> solutionExtractor,
TupleLifecycle<UniTuple<A>> nextNodesTupleLifecycle, int outputStoreSize) {
super(forEachClass, nextNodesTupleLifecycle, outputStoreSize);
this.solutionExtractor = Objects.requireNonNull(solutionExtractor);
}

public void read(SolutionView<Solution_> solutionView, Solution_ solution) {
var seenFactSet = Collections.newSetFromMap(new IdentityHashMap<A, Boolean>());
solutionExtractor.apply(solutionView, solution).forEach(a -> {
if (seenFactSet.contains(a)) { // Eliminate duplicates in the source data.
return;
}
seenFactSet.add(a);
var tuple = tupleMap.get(a);
if (tuple == null) {
super.insert(a);
} else {
updateExisting(a, tuple);
}
});
// Retract all tuples that were not seen in the source data.
var iterator = tupleMap.entrySet().iterator();
while (iterator.hasNext()) {
var entry = iterator.next();
var fact = entry.getKey();
if (!seenFactSet.contains(fact)) {
iterator.remove();
retractExisting(fact, entry.getValue());
}
}
}

@Override
public void insert(A a) {
throw new IllegalStateException("Impossible state: solution-based node cannot insert.");
}

@Override
public void update(A a) {
throw new IllegalStateException("Impossible state: solution-based node cannot update.");
}

@Override
public void retract(A a) {
throw new IllegalStateException("Impossible state: solution-based node cannot retract.");
}

@Override
public boolean supportsIndividualUpdates() {
return false;
}
}
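
A standalone sketch of the reconcile-by-identity pattern that read() applies, using plain collections instead of the actual node and tuple API (ReconcileSketch and everything in it is illustrative only):

import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Illustrative only: mirrors the insert/update/retract decisions of read() with plain collections.
public final class ReconcileSketch {

    private final Map<Object, String> tupleMap = new IdentityHashMap<>();

    void read(List<Object> extractedFacts) {
        Set<Object> seenFactSet = Collections.newSetFromMap(new IdentityHashMap<>());
        for (Object fact : extractedFacts) {
            if (!seenFactSet.add(fact)) {
                continue; // Eliminate duplicates in the source data.
            }
            String tuple = tupleMap.get(fact);
            if (tuple == null) {
                tupleMap.put(fact, "tuple(" + fact + ")"); // Insert: a new tuple is created.
                System.out.println("insert " + fact);
            } else {
                System.out.println("update " + fact); // Update: the existing tuple is refreshed.
            }
        }
        // Retract everything that was inserted before but is no longer extracted.
        tupleMap.entrySet().removeIf(entry -> {
            if (seenFactSet.contains(entry.getKey())) {
                return false;
            }
            System.out.println("retract " + entry.getKey());
            return true;
        });
    }

    public static void main(String[] args) {
        Object a = "a", b = "b", c = "c";
        ReconcileSketch sketch = new ReconcileSketch();
        sketch.read(List.of(a, b));    // insert a, insert b
        sketch.read(List.of(b, b, c)); // update b (the duplicate is ignored), insert c, retract a
    }
}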