Skip to content

Commit

Permalink
perf: reduce the LocationInList allocation rate (#1277)
Browse files Browse the repository at this point in the history
Although the actual move evaluation speed increase is small (~ 3 %), the
allocation rate drops substantially.
  • Loading branch information
triceo authored Dec 17, 2024
1 parent 26d68e6 commit 415de75
Show file tree
Hide file tree
Showing 13 changed files with 150 additions and 102 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import java.util.List;
import java.util.Map;
import java.util.Objects;

import ai.timefold.solver.core.impl.domain.variable.descriptor.ListVariableDescriptor;
import ai.timefold.solver.core.impl.domain.variable.index.IndexShadowVariableDescriptor;
Expand All @@ -26,7 +25,7 @@ final class ListVariableState<Solution_> {
private boolean requiresLocationMap = true;
private InnerScoreDirector<Solution_, ?> scoreDirector;
private int unassignedCount = 0;
private Map<Object, LocationInList> elementLocationMap;
private Map<Object, MutableLocationInList> elementLocationMap;

public ListVariableState(ListVariableDescriptor<Solution_> sourceVariableDescriptor) {
this.sourceVariableDescriptor = sourceVariableDescriptor;
Expand Down Expand Up @@ -69,8 +68,7 @@ public void initialize(InnerScoreDirector<Solution_, ?> scoreDirector, int initi

public void addElement(Object entity, List<Object> elements, Object element, int index) {
if (requiresLocationMap) {
var location = ElementLocation.of(entity, index);
var oldLocation = elementLocationMap.put(element, location);
var oldLocation = elementLocationMap.put(element, new MutableLocationInList(entity, index));
if (oldLocation != null) {
throw new IllegalStateException(
"The supply for list variable (%s) is corrupted, because the element (%s) at index (%d) already exists (%s)."
Expand Down Expand Up @@ -100,7 +98,7 @@ public void removeElement(Object entity, Object element, int index) {
"The supply for list variable (%s) is corrupted, because the element (%s) at index (%d) was already unassigned (%s)."
.formatted(sourceVariableDescriptor, element, index, oldElementLocation));
}
var oldIndex = oldElementLocation.index();
var oldIndex = oldElementLocation.getIndex();
if (oldIndex != index) {
throw new IllegalStateException(
"The supply for list variable (%s) is corrupted, because the element (%s) at index (%d) had an old index (%d) which is not the current index (%d)."
Expand Down Expand Up @@ -168,13 +166,22 @@ public boolean changeElement(Object entity, List<Object> elements, int index) {

private ChangeType processElementLocation(Object entity, Object element, int index) {
if (requiresLocationMap) { // Update the location and figure out if it is different from previous.
var newLocation = ElementLocation.of(entity, index);
var oldLocation = elementLocationMap.put(element, newLocation);
var oldLocation = elementLocationMap.get(element);
if (oldLocation == null) {
elementLocationMap.put(element, new MutableLocationInList(entity, index));
unassignedCount--;
return ChangeType.BOTH;
}
return compareLocations(entity, oldLocation.entity(), index, oldLocation.index());
var changeType = compareLocations(entity, oldLocation.getEntity(), index, oldLocation.getIndex());
if (changeType.anythingChanged) { // Replace the map value in-place, to avoid a put() on the hot path.
if (changeType.entityChanged) {
oldLocation.setEntity(entity);
}
if (changeType.indexChanged) {
oldLocation.setIndex(index);
}
}
return changeType;
} else { // Read the location and figure out if it is different from previous.
var oldEntity = getInverseSingleton(element);
if (oldEntity == null) {
Expand All @@ -199,27 +206,13 @@ private static ChangeType compareLocations(Object entity, Object otherEntity, in
}
}

private enum ChangeType {

BOTH(true, true),
INDEX(false, true),
NEITHER(false, false);

final boolean anythingChanged;
final boolean entityChanged;
final boolean indexChanged;

ChangeType(boolean entityChanged, boolean indexChanged) {
this.anythingChanged = entityChanged || indexChanged;
this.entityChanged = entityChanged;
this.indexChanged = indexChanged;
}

}

public ElementLocation getLocationInList(Object planningValue) {
if (requiresLocationMap) {
return Objects.requireNonNullElse(elementLocationMap.get(planningValue), ElementLocation.unassigned());
var mutableLocationInList = elementLocationMap.get(planningValue);
if (mutableLocationInList == null) {
return ElementLocation.unassigned();
}
return mutableLocationInList.getLocationInList();
} else { // At this point, both inverse and index are externalized.
var inverse = externalizedInverseProcessor.getInverseSingleton(planningValue);
if (inverse == null) {
Expand All @@ -235,7 +228,7 @@ public Integer getIndex(Object planningValue) {
if (elementLocation == null) {
return null;
}
return elementLocation.index();
return elementLocation.getIndex();
}
return externalizedIndexProcessor.getIndex(planningValue);
}
Expand All @@ -246,34 +239,35 @@ public Object getInverseSingleton(Object planningValue) {
if (elementLocation == null) {
return null;
}
return elementLocation.entity();
return elementLocation.getEntity();
}
return externalizedInverseProcessor.getInverseSingleton(planningValue);
}

public Object getPreviousElement(Object element) {
if (externalizedPreviousElementProcessor == null) {
var elementLocation = getLocationInList(element);
if (!(elementLocation instanceof LocationInList locationInList)) {
var mutableLocationInList = elementLocationMap.get(element);
if (mutableLocationInList == null) {
return null;
}
var index = locationInList.index();
var index = mutableLocationInList.getIndex();
if (index == 0) {
return null;
}
return sourceVariableDescriptor.getValue(locationInList.entity()).get(index - 1);
return sourceVariableDescriptor.getValue(mutableLocationInList.getEntity())
.get(index - 1);
}
return externalizedPreviousElementProcessor.getElement(element);
}

public Object getNextElement(Object element) {
if (externalizedNextElementProcessor == null) {
var elementLocation = getLocationInList(element);
if (!(elementLocation instanceof LocationInList locationInList)) {
var mutableLocationInList = elementLocationMap.get(element);
if (mutableLocationInList == null) {
return null;
}
var list = sourceVariableDescriptor.getValue(locationInList.entity());
var index = locationInList.index();
var list = sourceVariableDescriptor.getValue(mutableLocationInList.getEntity());
var index = mutableLocationInList.getIndex();
if (index == list.size() - 1) {
return null;
}
Expand All @@ -286,4 +280,66 @@ public int getUnassignedCount() {
return unassignedCount;
}

private enum ChangeType {

BOTH(true, true),
INDEX(false, true),
NEITHER(false, false);

final boolean anythingChanged;
final boolean entityChanged;
final boolean indexChanged;

ChangeType(boolean entityChanged, boolean indexChanged) {
this.anythingChanged = entityChanged || indexChanged;
this.entityChanged = entityChanged;
this.indexChanged = indexChanged;
}

}

/**
* This class is used to avoid creating a new {@link LocationInList} object every time we need to return a location.
* The actual value is held in a map and can be updated without doing a put() operation, which is more efficient.
* The {@link LocationInList} object is only created when it is actually requested,
* and stored until the next time the mutable state is updated and therefore the cache invalidated.
*/
private static final class MutableLocationInList {

private Object entity;
private int index;
private LocationInList locationInList;

public MutableLocationInList(Object entity, int index) {
this.entity = entity;
this.index = index;
}

public Object getEntity() {
return entity;
}

public void setEntity(Object entity) {
this.entity = entity;
this.locationInList = null;
}

public int getIndex() {
return index;
}

public void setIndex(int index) {
this.index = index;
this.locationInList = null;
}

public LocationInList getLocationInList() {
if (locationInList == null) {
locationInList = ElementLocation.of(entity, index);
}
return locationInList;
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class PooledEntityPlacerTest {

@Test
void oneMoveSelector() {
MoveSelector<TestdataSolution> moveSelector = SelectorTestUtils.mockMoveSelector(DummyMove.class,
MoveSelector<TestdataSolution> moveSelector = SelectorTestUtils.mockMoveSelector(
new DummyMove("a1"), new DummyMove("a2"), new DummyMove("b1"));

PooledEntityPlacer<TestdataSolution> placer = new PooledEntityPlacer<>(moveSelector);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ public static <Solution_> MimicReplayingSubListSelector<Solution_> mockReplaying
}

@SafeVarargs
public static <Solution_> MoveSelector<Solution_> mockMoveSelector(Class<?> moveClass, Move<Solution_>... moves) {
public static <Solution_> MoveSelector<Solution_> mockMoveSelector(Move<Solution_>... moves) {
MoveSelector<Solution_> moveSelector = mock(MoveSelector.class);
final List<Move<Solution_>> moveList = Arrays.asList(moves);
when(moveSelector.iterator()).thenAnswer(invocation -> moveList.iterator());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package ai.timefold.solver.core.impl.heuristic.selector.move;

import static ai.timefold.solver.core.impl.heuristic.HeuristicConfigPolicyTestUtils.buildHeuristicConfigPolicy;
import static org.assertj.core.api.Assertions.*;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;

Expand All @@ -14,7 +13,6 @@
import ai.timefold.solver.core.config.heuristic.selector.common.decorator.SelectionSorterOrder;
import ai.timefold.solver.core.config.heuristic.selector.move.MoveSelectorConfig;
import ai.timefold.solver.core.impl.heuristic.HeuristicConfigPolicy;
import ai.timefold.solver.core.impl.heuristic.move.DummyMove;
import ai.timefold.solver.core.impl.heuristic.move.Move;
import ai.timefold.solver.core.impl.heuristic.selector.SelectorTestUtils;
import ai.timefold.solver.core.impl.heuristic.selector.common.decorator.SelectionProbabilityWeightFactory;
Expand All @@ -32,7 +30,7 @@ class MoveSelectorFactoryTest {

@Test
void phaseOriginal() {
final MoveSelector<TestdataSolution> baseMoveSelector = SelectorTestUtils.mockMoveSelector(DummyMove.class);
final MoveSelector<TestdataSolution> baseMoveSelector = SelectorTestUtils.mockMoveSelector();
DummyMoveSelectorConfig moveSelectorConfig = new DummyMoveSelectorConfig();
moveSelectorConfig.setCacheType(SelectionCacheType.PHASE);
moveSelectorConfig.setSelectionOrder(SelectionOrder.ORIGINAL);
Expand All @@ -50,7 +48,7 @@ void phaseOriginal() {

@Test
void stepOriginal() {
final MoveSelector<TestdataSolution> baseMoveSelector = SelectorTestUtils.mockMoveSelector(DummyMove.class);
final MoveSelector<TestdataSolution> baseMoveSelector = SelectorTestUtils.mockMoveSelector();
DummyMoveSelectorConfig moveSelectorConfig = new DummyMoveSelectorConfig();
moveSelectorConfig.setCacheType(SelectionCacheType.STEP);
moveSelectorConfig.setSelectionOrder(SelectionOrder.ORIGINAL);
Expand All @@ -68,7 +66,7 @@ void stepOriginal() {

@Test
void justInTimeOriginal() {
final MoveSelector<TestdataSolution> baseMoveSelector = SelectorTestUtils.mockMoveSelector(DummyMove.class);
final MoveSelector<TestdataSolution> baseMoveSelector = SelectorTestUtils.mockMoveSelector();
DummyMoveSelectorConfig moveSelectorConfig = new DummyMoveSelectorConfig();
MoveSelectorFactory<TestdataSolution> moveSelectorFactory =
new AssertingMoveSelectorFactory(moveSelectorConfig, baseMoveSelector, SelectionCacheType.JUST_IN_TIME, false);
Expand All @@ -82,7 +80,7 @@ void justInTimeOriginal() {

@Test
void phaseRandom() {
final MoveSelector<TestdataSolution> baseMoveSelector = SelectorTestUtils.mockMoveSelector(DummyMove.class);
final MoveSelector<TestdataSolution> baseMoveSelector = SelectorTestUtils.mockMoveSelector();
DummyMoveSelectorConfig moveSelectorConfig = new DummyMoveSelectorConfig();
MoveSelectorFactory<TestdataSolution> moveSelectorFactory =
new AssertingMoveSelectorFactory(moveSelectorConfig, baseMoveSelector, SelectionCacheType.PHASE, false);
Expand All @@ -100,7 +98,7 @@ void phaseRandom() {

@Test
void stepRandom() {
final MoveSelector<TestdataSolution> baseMoveSelector = SelectorTestUtils.mockMoveSelector(DummyMove.class);
final MoveSelector<TestdataSolution> baseMoveSelector = SelectorTestUtils.mockMoveSelector();
DummyMoveSelectorConfig moveSelectorConfig = new DummyMoveSelectorConfig();
MoveSelectorFactory<TestdataSolution> moveSelectorFactory =
new AssertingMoveSelectorFactory(moveSelectorConfig, baseMoveSelector, SelectionCacheType.STEP, false);
Expand All @@ -118,7 +116,7 @@ void stepRandom() {

@Test
void justInTimeRandom() {
final MoveSelector<TestdataSolution> baseMoveSelector = SelectorTestUtils.mockMoveSelector(DummyMove.class);
final MoveSelector<TestdataSolution> baseMoveSelector = SelectorTestUtils.mockMoveSelector();
DummyMoveSelectorConfig moveSelectorConfig = new DummyMoveSelectorConfig();
MoveSelectorFactory<TestdataSolution> moveSelectorFactory =
new AssertingMoveSelectorFactory(moveSelectorConfig, baseMoveSelector, SelectionCacheType.JUST_IN_TIME, true);
Expand All @@ -132,7 +130,7 @@ void justInTimeRandom() {

@Test
void phaseShuffled() {
final MoveSelector<TestdataSolution> baseMoveSelector = SelectorTestUtils.mockMoveSelector(DummyMove.class);
final MoveSelector<TestdataSolution> baseMoveSelector = SelectorTestUtils.mockMoveSelector();
DummyMoveSelectorConfig moveSelectorConfig = new DummyMoveSelectorConfig();
MoveSelectorFactory<TestdataSolution> moveSelectorFactory =
new AssertingMoveSelectorFactory(moveSelectorConfig, baseMoveSelector, SelectionCacheType.PHASE, false);
Expand All @@ -149,7 +147,7 @@ void phaseShuffled() {

@Test
void stepShuffled() {
final MoveSelector<TestdataSolution> baseMoveSelector = SelectorTestUtils.mockMoveSelector(DummyMove.class);
final MoveSelector<TestdataSolution> baseMoveSelector = SelectorTestUtils.mockMoveSelector();
DummyMoveSelectorConfig moveSelectorConfig = new DummyMoveSelectorConfig();
MoveSelectorFactory<TestdataSolution> moveSelectorFactory =
new AssertingMoveSelectorFactory(moveSelectorConfig, baseMoveSelector, SelectionCacheType.STEP, false);
Expand All @@ -165,7 +163,7 @@ void stepShuffled() {

@Test
void justInTimeShuffled() {
final MoveSelector<TestdataSolution> baseMoveSelector = SelectorTestUtils.mockMoveSelector(DummyMove.class);
final MoveSelector<TestdataSolution> baseMoveSelector = SelectorTestUtils.mockMoveSelector();
DummyMoveSelectorConfig moveSelectorConfig = new DummyMoveSelectorConfig();
moveSelectorConfig.setCacheType(SelectionCacheType.JUST_IN_TIME);
moveSelectorConfig.setSelectionOrder(SelectionOrder.SHUFFLED);
Expand All @@ -179,7 +177,7 @@ void justInTimeShuffled() {

@Test
void validateSorting_incompatibleSelectionOrder() {
final MoveSelector<TestdataSolution> baseMoveSelector = SelectorTestUtils.mockMoveSelector(DummyMove.class);
final MoveSelector<TestdataSolution> baseMoveSelector = SelectorTestUtils.mockMoveSelector();
DummyMoveSelectorConfig moveSelectorConfig = new DummyMoveSelectorConfig();
moveSelectorConfig.setSorterOrder(SelectionSorterOrder.ASCENDING);

Expand All @@ -190,7 +188,7 @@ void validateSorting_incompatibleSelectionOrder() {

@Test
void applySorting_withoutAnySortingClass() {
final MoveSelector<TestdataSolution> baseMoveSelector = SelectorTestUtils.mockMoveSelector(DummyMove.class);
final MoveSelector<TestdataSolution> baseMoveSelector = SelectorTestUtils.mockMoveSelector();
DummyMoveSelectorConfig moveSelectorConfig = new DummyMoveSelectorConfig();
moveSelectorConfig.setSorterOrder(SelectionSorterOrder.ASCENDING);

Expand All @@ -203,7 +201,7 @@ void applySorting_withoutAnySortingClass() {

@Test
void applySorting_withSorterComparatorClass() {
final MoveSelector<TestdataSolution> baseMoveSelector = SelectorTestUtils.mockMoveSelector(DummyMove.class);
final MoveSelector<TestdataSolution> baseMoveSelector = SelectorTestUtils.mockMoveSelector();
DummyMoveSelectorConfig moveSelectorConfig = new DummyMoveSelectorConfig();
moveSelectorConfig.setSorterOrder(SelectionSorterOrder.ASCENDING);
moveSelectorConfig.setSorterComparatorClass(DummyComparator.class);
Expand All @@ -216,7 +214,7 @@ void applySorting_withSorterComparatorClass() {

@Test
void applyProbability_withProbabilityWeightFactoryClass() {
final MoveSelector<TestdataSolution> baseMoveSelector = SelectorTestUtils.mockMoveSelector(DummyMove.class);
final MoveSelector<TestdataSolution> baseMoveSelector = SelectorTestUtils.mockMoveSelector();
DummyMoveSelectorConfig moveSelectorConfig = new DummyMoveSelectorConfig();
moveSelectorConfig.setCacheType(SelectionCacheType.PHASE);
moveSelectorConfig.setProbabilityWeightFactoryClass(DummySelectionProbabilityWeightFactory.class);
Expand All @@ -241,7 +239,7 @@ public Move<TestdataSolution> doMove(ScoreDirector<TestdataSolution> scoreDirect
}
};
final MoveSelector<TestdataSolution> baseMoveSelector =
SelectorTestUtils.mockMoveSelector(DummyMove.class, notDoableMove);
SelectorTestUtils.mockMoveSelector(notDoableMove);
DummyMoveSelectorConfig moveSelectorConfig = new DummyMoveSelectorConfig();
MoveSelectorFactory<TestdataSolution> moveSelectorFactory =
new DummyMoveSelectorFactory(moveSelectorConfig, baseMoveSelector);
Expand Down
Loading

0 comments on commit 415de75

Please sign in to comment.