diff --git a/plugins/cache-ehcache/src/main/java/org/opensearch/cache/keystore/KeyStoreStats.java b/plugins/cache-ehcache/src/main/java/org/opensearch/cache/keystore/KeyStoreStats.java new file mode 100644 index 0000000000000..37fdf47163779 --- /dev/null +++ b/plugins/cache-ehcache/src/main/java/org/opensearch/cache/keystore/KeyStoreStats.java @@ -0,0 +1,44 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.cache.keystore; + +import org.opensearch.common.metrics.CounterMetric; + +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * A stats holder for use in KeyLookupStore implementations. + * Getters should be exposed by the KeyLookupStore which uses it. + */ +public class KeyStoreStats { + // Number of entries + protected CounterMetric size; + // Memory cap in bytes + protected long memSizeCapInBytes; + // Number of add attempts + protected CounterMetric numAddAttempts; + // Number of collisions + protected CounterMetric numCollisions; + // True if the store is at capacity + protected AtomicBoolean atCapacity; + // Number of removal attempts + protected CounterMetric numRemovalAttempts; + // Number of successful removal attempts + protected CounterMetric numSuccessfulRemovals; + + protected KeyStoreStats(long memSizeCapInBytes) { + this.size = new CounterMetric(); + this.numAddAttempts = new CounterMetric(); + this.numCollisions = new CounterMetric(); + this.memSizeCapInBytes = memSizeCapInBytes; + this.atCapacity = new AtomicBoolean(false); + this.numRemovalAttempts = new CounterMetric(); + this.numSuccessfulRemovals = new CounterMetric(); + } +} diff --git a/plugins/cache-ehcache/src/main/java/org/opensearch/cache/keystore/RBMIntKeyLookupStore.java b/plugins/cache-ehcache/src/main/java/org/opensearch/cache/keystore/RBMIntKeyLookupStore.java new file mode 100644 index 0000000000000..fe7f64712eda6 --- /dev/null +++ b/plugins/cache-ehcache/src/main/java/org/opensearch/cache/keystore/RBMIntKeyLookupStore.java @@ -0,0 +1,381 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.cache.keystore; + +import org.opensearch.common.cache.keystore.KeyLookupStore; +import org.opensearch.common.metrics.CounterMetric; +import org.opensearch.core.common.unit.ByteSizeValue; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantReadWriteLock; + +import org.roaringbitmap.RoaringBitmap; + +/** + * This class implements KeyLookupStore using a roaring bitmap with a modulo applied to values. + * The modulo increases the density of values, which makes RBMs more memory-efficient. The recommended modulo is ~2^28. + * It also maintains a hash set of values which have had collisions. Values which haven't had collisions can be + * safely removed from the store. The fraction of collided values should be low, + * about 0.5% for a store with 10^7 values and a modulo of 2^28. + * The store estimates its memory footprint and will stop adding more values once it reaches its memory cap. + */ +public class RBMIntKeyLookupStore implements KeyLookupStore { + /** + * An enum representing modulo values for use in the keystore + */ + public enum KeystoreModuloValue { + NONE(0), // No modulo applied + TWO_TO_THIRTY_ONE((int) Math.pow(2, 31)), + TWO_TO_TWENTY_NINE((int) Math.pow(2, 29)), // recommended value + TWO_TO_TWENTY_EIGHT((int) Math.pow(2, 28)), + TWO_TO_TWENTY_SIX((int) Math.pow(2, 26)); + + private final int value; + + private KeystoreModuloValue(int value) { + this.value = value; + } + + public int getValue() { + return this.value; + } + } + + // The modulo applied to values before adding into the RBM + protected final int modulo; + private final int modulo_bitmask; + // Since our modulo is always a power of two we can optimize it by ANDing with a particular bitmask + KeyStoreStats stats; + private RoaringBitmap rbm; + private HashMap collidedIntCounters; + private HashMap> removalSets; + private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock(); + private final Lock readLock = lock.readLock(); + private final Lock writeLock = lock.writeLock(); + private long mostRecentByteEstimate; + + // Refresh size estimate every X new elements. Refreshes use the RBM's internal size estimator, which takes ~0.01 ms, + // so we don't want to do it on every get(), and it doesn't matter much if there are +- 10000 keys in this store + // in terms of storage impact + static final int REFRESH_SIZE_EST_INTERVAL = 10_000; + + // Use this constructor to specify memory cap with default modulo = 2^28, which we found in experiments + // to be the best tradeoff between lower memory usage and risk of collisions + public RBMIntKeyLookupStore(long memSizeCapInBytes) { + this(KeystoreModuloValue.TWO_TO_TWENTY_EIGHT, memSizeCapInBytes); + } + + // Use this constructor to specify memory cap and modulo + public RBMIntKeyLookupStore(KeystoreModuloValue moduloValue, long memSizeCapInBytes) { + this.modulo = moduloValue.getValue(); + if (modulo > 0) { + this.modulo_bitmask = modulo - 1; // keep last log_2(modulo) bits + } else { + this.modulo_bitmask = -1; // -1 in twos complement is all ones -> includes all bits -> same as no modulo + } + this.stats = new KeyStoreStats(memSizeCapInBytes); + this.rbm = new RoaringBitmap(); + this.collidedIntCounters = new HashMap<>(); + this.removalSets = new HashMap<>(); + this.mostRecentByteEstimate = 0L; + } + + private int transform(int value) { + return value & modulo_bitmask; + } + + private void handleCollisions(int transformedValue) { + stats.numCollisions.inc(); + CounterMetric numCollisions = collidedIntCounters.get(transformedValue); + if (numCollisions == null) { // First time the transformedValue has had a collision + numCollisions = new CounterMetric(); + numCollisions.inc(2); // initialize to 2, since the first collision means 2 keys have collided + collidedIntCounters.put(transformedValue, numCollisions); + } else { + numCollisions.inc(); + } + } + + private boolean shouldUpdateByteEstimate() { + return getSize() % REFRESH_SIZE_EST_INTERVAL == 0; + } + + private boolean isAtCapacityLimit() { + return getMemorySizeCapInBytes() > 0 && mostRecentByteEstimate > getMemorySizeCapInBytes(); + } + + @Override + public boolean add(Integer value) { + if (value == null) { + return false; + } + stats.numAddAttempts.inc(); + + if (shouldUpdateByteEstimate()) { + mostRecentByteEstimate = computeMemorySizeInBytes(); + } + if (isAtCapacityLimit()) { + stats.atCapacity.set(true); + return false; + } + int transformedValue = transform(value); + + writeLock.lock(); + try { + if (!rbm.contains(transformedValue)) { + rbm.add(transformedValue); + stats.size.inc(); + return true; + } + // If the value is already pending removal, take it out of the removalList + HashSet removalSet = removalSets.get(transformedValue); + if (removalSet != null) { + removalSet.remove(value); + // Don't increment the counter - this is handled by handleCollisions() later + if (removalSet.isEmpty()) { + removalSets.remove(transformedValue); + } + } + + handleCollisions(transformedValue); + return false; + } finally { + writeLock.unlock(); + } + } + + @Override + public boolean contains(Integer value) { + if (value == null) { + return false; + } + int transformedValue = transform(value); + readLock.lock(); + try { + return rbm.contains(transformedValue); + } finally { + readLock.unlock(); + } + } + + public Integer getInternalRepresentation(Integer value) { + if (value == null) { + return 0; + } + return Integer.valueOf(transform(value)); + } + + /** + * Attempts to remove a value from the keystore. WARNING: Removing keys which have not been added to the keystore + * may cause undefined behavior, including future false negatives!! + * @param value The value to attempt to remove. + * @return true if the value was removed, false otherwise + */ + @Override + public boolean remove(Integer value) { + if (value == null) { + return false; + } + int transformedValue = transform(value); + readLock.lock(); + try { + if (!rbm.contains(transformedValue)) { // saves additional transform() call + return false; + } + stats.numRemovalAttempts.inc(); + } finally { + readLock.unlock(); + } + writeLock.lock(); + try { + CounterMetric numCollisions = collidedIntCounters.get(transformedValue); + if (numCollisions != null) { + // This transformed value has had a collision before + HashSet removalSet = removalSets.get(transformedValue); + if (removalSet == null) { + // First time a removal has been attempted for this transformed value + HashSet newRemovalSet = new HashSet<>(); + newRemovalSet.add(value); // Add the key value, not the transformed value, to the list of attempted removals for this + // transformedValue + removalSets.put(transformedValue, newRemovalSet); + numCollisions.dec(); + } else { + if (removalSet.contains(value)) { + return false; // We have already attempted to remove this value. Do nothing + } + removalSet.add(value); + numCollisions.dec(); + // If numCollisions has reached zero, we can safely remove all values in removalList + if (numCollisions.count() == 0) { + removeFromRBM(transformedValue); + collidedIntCounters.remove(transformedValue); + removalSets.remove(transformedValue); + return true; + } + } + return false; + } + // Otherwise, there's not been a collision for this transformedValue, so we can safely remove + removeFromRBM(transformedValue); + return true; + } finally { + writeLock.unlock(); + } + } + + // Helper fn for remove() + private void removeFromRBM(int transformedValue) { + if (!lock.isWriteLockedByCurrentThread()) { + throw new IllegalStateException("Write Lock must be held when calling this method"); + } + rbm.remove(transformedValue); + stats.size.dec(); + stats.numSuccessfulRemovals.inc(); + } + + @Override + public int getSize() { + readLock.lock(); + try { + return (int) stats.size.count(); + } finally { + readLock.unlock(); + } + } + + public int getAddAttempts() { + return (int) stats.numAddAttempts.count(); + } + + public int getCollisions() { + return (int) stats.numCollisions.count(); + } + + public boolean isCollision(Integer value1, Integer value2) { + if (value1 == null || value2 == null) { + return false; + } + return transform(value1) == transform(value2); + } + + /* + The built-in RBM size estimator is known to work very badly for randomly-distributed data, like the hashes we will be using. + See https://github.com/RoaringBitmap/RoaringBitmap/issues/257. + We ran tests to determine what multiplier you need to get true size from reported size, as a function of log10(# entries / modulo), + and found this piecewise linear function was a good approximation across different modulos. + */ + static double getRBMSizeMultiplier(int numEntries, int modulo) { + double effectiveModulo = (double) modulo / 2; + /* This model was created when we used % operator to calculate modulo. This has range (-modulo, modulo). + Now we have optimized to use a bitmask, which has range [0, modulo). So the number of possible values stored + is halved. */ + if (modulo == 0) { + effectiveModulo = Math.pow(2, 32); + } + double x = Math.log10((double) numEntries / effectiveModulo); + if (x < -5) { + return 7.0; + } + if (x < -2.75) { + return -2.5 * x - 5.5; + } + if (x <= 0) { + return -3.0 / 22.0 * x + 1; + } + return 1; + } + + /** + * Return the most recent memory size estimate, without updating it. + * @return the size estimate (bytes) + */ + @Override + public long getMemorySizeInBytes() { + return mostRecentByteEstimate; + } + + /** + * Calculate a new memory size estimate. This is somewhat expensive, so we don't call this every time we run get(). + * @return a new size estimate (bytes) + */ + private long computeMemorySizeInBytes() { + double multiplier = getRBMSizeMultiplier((int) stats.size.count(), modulo); + return (long) (rbm.getSizeInBytes() * multiplier); + } + + @Override + public long getMemorySizeCapInBytes() { + return stats.memSizeCapInBytes; + } + + @Override + public boolean isFull() { + return stats.atCapacity.get(); + } + + @Override + public void regenerateStore(Integer[] newValues) { + rbm.clear(); + collidedIntCounters = new HashMap<>(); + removalSets = new HashMap<>(); + stats.size = new CounterMetric(); + stats.numAddAttempts = new CounterMetric(); + stats.numCollisions = new CounterMetric(); + stats.numRemovalAttempts = new CounterMetric(); + stats.numSuccessfulRemovals = new CounterMetric(); + for (int i = 0; i < newValues.length; i++) { + if (newValues[i] != null) { + add(newValues[i]); + } + } + } + + @Override + public void clear() { + regenerateStore(new Integer[] {}); + } + + public int getNumRemovalAttempts() { + return (int) stats.numRemovalAttempts.count(); + } + + public int getNumSuccessfulRemovals() { + return (int) stats.numSuccessfulRemovals.count(); + } + + public boolean valueHasHadCollision(Integer value) { + if (value == null) { + return false; + } + return collidedIntCounters.containsKey(transform(value)); + } + + CounterMetric getNumCollisionsForValue(int value) { // package private for testing + return collidedIntCounters.get(transform(value)); + } + + HashSet getRemovalSetForValue(int value) { + return removalSets.get(transform(value)); + } + + /** + * Function to set a new memory size cap. + * TODO: Integrate this with the tiered caching cluster settings PR once this is raised. + * @param newMemSizeCap The new cap size. + */ + protected void setMemSizeCap(ByteSizeValue newMemSizeCap) { + stats.memSizeCapInBytes = newMemSizeCap.getBytes(); + mostRecentByteEstimate = getMemorySizeInBytes(); + if (mostRecentByteEstimate > getMemorySizeCapInBytes()) { + stats.atCapacity.set(true); + } + } +} diff --git a/plugins/cache-ehcache/src/test/java/org/opensearch/cache/keystore/RBMIntKeyLookupStoreTests.java b/plugins/cache-ehcache/src/test/java/org/opensearch/cache/keystore/RBMIntKeyLookupStoreTests.java new file mode 100644 index 0000000000000..f1a65cfafc3e7 --- /dev/null +++ b/plugins/cache-ehcache/src/test/java/org/opensearch/cache/keystore/RBMIntKeyLookupStoreTests.java @@ -0,0 +1,409 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.cache.keystore; + +import org.opensearch.common.Randomness; +import org.opensearch.common.metrics.CounterMetric; +import org.opensearch.core.common.unit.ByteSizeUnit; +import org.opensearch.core.common.unit.ByteSizeValue; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.Random; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ThreadPoolExecutor; + +import org.roaringbitmap.RoaringBitmap; + +public class RBMIntKeyLookupStoreTests extends OpenSearchTestCase { + + final int BYTES_IN_MB = 1048576; + + public void testInit() { + long memCap = 100 * BYTES_IN_MB; + RBMIntKeyLookupStore kls = new RBMIntKeyLookupStore(memCap); + assertEquals(0, kls.getSize()); + assertEquals(RBMIntKeyLookupStore.KeystoreModuloValue.TWO_TO_TWENTY_EIGHT.getValue(), kls.modulo); + assertEquals(memCap, kls.getMemorySizeCapInBytes()); + } + + public void testTransformationLogic() throws Exception { + int modulo = (int) Math.pow(2, 29); + RBMIntKeyLookupStore kls = new RBMIntKeyLookupStore(RBMIntKeyLookupStore.KeystoreModuloValue.TWO_TO_TWENTY_NINE, 0L); + int offset = 3; + for (int i = 0; i < 4; i++) { // after this we run into max value, but thats not a flaw with the class design + int posValue = i * modulo + offset; + kls.add(posValue); + assertEquals(offset, (int) kls.getInternalRepresentation(posValue)); + int negValue = -(i * modulo + offset); + kls.add(negValue); + assertEquals(modulo - offset, (int) kls.getInternalRepresentation(negValue)); + } + assertEquals(2, kls.getSize()); + int[] testVals = new int[] { 0, 1, -1, -23495, 23058, modulo, -modulo, Integer.MAX_VALUE, Integer.MIN_VALUE }; + for (int value : testVals) { + assertTrue(kls.getInternalRepresentation(value) < modulo); + assertTrue(kls.getInternalRepresentation(value) >= 0); + } + RBMIntKeyLookupStore no_modulo_kls = new RBMIntKeyLookupStore(RBMIntKeyLookupStore.KeystoreModuloValue.NONE, 0L); + Random rand = Randomness.get(); + for (int i = 0; i < 100; i++) { + int val = rand.nextInt(); + assertEquals(val, (int) no_modulo_kls.getInternalRepresentation(val)); + } + } + + public void testContains() throws Exception { + RBMIntKeyLookupStore kls = new RBMIntKeyLookupStore(RBMIntKeyLookupStore.KeystoreModuloValue.TWO_TO_TWENTY_NINE, 0L); + RBMIntKeyLookupStore noModuloKls = new RBMIntKeyLookupStore(RBMIntKeyLookupStore.KeystoreModuloValue.NONE, 0L); + for (int i = 0; i < RBMIntKeyLookupStore.REFRESH_SIZE_EST_INTERVAL + 1000; i++) { + // set upper bound > number of elements to trigger a size check, ensuring we test that too + kls.add(i); + assertTrue(kls.contains(i)); + noModuloKls.add(i); + assertTrue(noModuloKls.contains(i)); + } + } + + public void testAddingStatsGetters() throws Exception { + RBMIntKeyLookupStore.KeystoreModuloValue moduloValue = RBMIntKeyLookupStore.KeystoreModuloValue.TWO_TO_TWENTY_SIX; + RBMIntKeyLookupStore kls = new RBMIntKeyLookupStore(moduloValue, 0L); + kls.add(15); + kls.add(-15); + assertEquals(2, kls.getAddAttempts()); + assertEquals(0, kls.getCollisions()); + + int offset = 1; + for (int i = 0; i < 10; i++) { + kls.add(i * moduloValue.getValue() + offset); + } + assertEquals(12, kls.getAddAttempts()); + assertEquals(9, kls.getCollisions()); + } + + public void testRegenerateStore() throws Exception { + int numToAdd = 10000000; + Random rand = Randomness.get(); + RBMIntKeyLookupStore kls = new RBMIntKeyLookupStore(RBMIntKeyLookupStore.KeystoreModuloValue.TWO_TO_TWENTY_NINE, 0L); + for (int i = 0; i < numToAdd; i++) { + kls.add(i); + } + assertEquals(numToAdd, kls.getSize()); + Integer[] newVals = new Integer[1000]; // margin accounts for collisions + for (int j = 0; j < newVals.length; j++) { + newVals[j] = rand.nextInt(); + } + kls.regenerateStore(newVals); + assertTrue(Math.abs(kls.getSize() - newVals.length) < 3); // inexact due to collisions + + // test clear() + kls.clear(); + assertEquals(0, kls.getSize()); + } + + public void testAddingDuplicates() throws Exception { + RBMIntKeyLookupStore kls = new RBMIntKeyLookupStore(0L); + int numToAdd = 4820411; + for (int i = 0; i < numToAdd; i++) { + kls.add(i); + kls.add(i); + } + for (int j = 0; j < 1000; j++) { + kls.add(577); + } + assertEquals(numToAdd, kls.getSize()); + } + + public void testMemoryCapBlocksAdd() throws Exception { + // Now that we're using a modified version of rbm.getSizeInBytes(), which doesn't provide an inverse function, + // we have to test filling just an RBM with random test values first so that we can get the resulting memory cap limit + // to use with our modified size estimate. + // This is much noisier so the precision is lower. + + // It is necessary to use randomly distributed integers for both parts of this test, as we would do with hashes in the cache, + // as that's what our size estimator is designed for. + // If we add a run of integers, our size estimator is not valid, especially for small RBMs. + + int[] maxEntriesArr = new int[] { 1342000, 100000, 3000000 }; + long[] rbmReportedSizes = new long[4]; + Random rand = Randomness.get(); + for (int j = 0; j < maxEntriesArr.length; j++) { + RoaringBitmap rbm = new RoaringBitmap(); + for (int i = 0; i < maxEntriesArr[j]; i++) { + rbm.add(rand.nextInt()); + } + rbmReportedSizes[j] = rbm.getSizeInBytes(); + } + RBMIntKeyLookupStore.KeystoreModuloValue moduloValue = RBMIntKeyLookupStore.KeystoreModuloValue.TWO_TO_TWENTY_NINE; + for (int i = 0; i < maxEntriesArr.length; i++) { + double multiplier = RBMIntKeyLookupStore.getRBMSizeMultiplier(maxEntriesArr[i], moduloValue.getValue()); + long memSizeCapInBytes = (long) (rbmReportedSizes[i] * multiplier); + RBMIntKeyLookupStore kls = new RBMIntKeyLookupStore(moduloValue, memSizeCapInBytes); + for (int j = 0; j < maxEntriesArr[i] + 5000; j++) { + kls.add(rand.nextInt()); + } + assertTrue(Math.abs(maxEntriesArr[i] - kls.getSize()) < (double) maxEntriesArr[i] / 10); + } + } + + public void testConcurrency() throws Exception { + Random rand = Randomness.get(); + for (int j = 0; j < 5; j++) { // test with different numbers of threads + RBMIntKeyLookupStore kls = new RBMIntKeyLookupStore(RBMIntKeyLookupStore.KeystoreModuloValue.TWO_TO_TWENTY_NINE, 0L); + int numThreads = rand.nextInt(50) + 1; + ThreadPoolExecutor executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(numThreads); + // In this test we want to add the first 200K numbers and check they're all correctly there. + // We do some duplicates too to ensure those aren't incorrectly added. + int amountToAdd = 200000; + ArrayList> wasAdded = new ArrayList<>(amountToAdd); + ArrayList> duplicatesWasAdded = new ArrayList<>(); + for (int i = 0; i < amountToAdd; i++) { + wasAdded.add(null); + } + for (int i = 0; i < amountToAdd; i++) { + final int val = i; + Future fut = executor.submit(() -> { + boolean didAdd; + try { + didAdd = kls.add(val); + } catch (Exception e) { + throw new RuntimeException(e); + } + return didAdd; + }); + wasAdded.set(val, fut); + if (val % 1000 == 0) { + // do a duplicate add + Future duplicateFut = executor.submit(() -> { + boolean didAdd; + try { + didAdd = kls.add(val); + } catch (Exception e) { + throw new RuntimeException(e); + } + return didAdd; + }); + duplicatesWasAdded.add(duplicateFut); + } + } + int originalAdds = 0; + int duplicateAdds = 0; + for (Future fut : wasAdded) { + if (fut.get()) { + originalAdds++; + } + } + for (Future duplicateFut : duplicatesWasAdded) { + if (duplicateFut.get()) { + duplicateAdds++; + } + } + for (int i = 0; i < amountToAdd; i++) { + assertTrue(kls.contains(i)); + } + assertEquals(amountToAdd, originalAdds + duplicateAdds); + assertEquals(amountToAdd, kls.getSize()); + assertEquals(amountToAdd / 1000, kls.getCollisions()); + executor.shutdown(); + } + } + + public void testRemoveNoCollisions() throws Exception { + long memCap = 100L * BYTES_IN_MB; + int numToAdd = 195000; + RBMIntKeyLookupStore kls = new RBMIntKeyLookupStore(RBMIntKeyLookupStore.KeystoreModuloValue.NONE, memCap); + // there should be no collisions for sequential positive numbers up to modulo + for (int i = 0; i < numToAdd; i++) { + kls.add(i); + } + for (int i = 0; i < 1000; i++) { + assertTrue(kls.remove(i)); + assertFalse(kls.contains(i)); + assertFalse(kls.valueHasHadCollision(i)); + } + assertEquals(numToAdd - 1000, kls.getSize()); + } + + public void testRemoveWithCollisions() throws Exception { + int modulo = (int) Math.pow(2, 26); + long memCap = 100L * BYTES_IN_MB; + RBMIntKeyLookupStore kls = new RBMIntKeyLookupStore(RBMIntKeyLookupStore.KeystoreModuloValue.TWO_TO_TWENTY_SIX, memCap); + for (int i = 0; i < 10; i++) { + kls.add(i); + if (i % 2 == 1) { + kls.add(-i); + assertFalse(kls.valueHasHadCollision(i)); + kls.add(i + modulo); + assertTrue(kls.valueHasHadCollision(i)); + } else { + assertFalse(kls.valueHasHadCollision(i)); + } + } + assertEquals(15, kls.getSize()); + for (int i = 0; i < 10; i++) { + boolean didRemove = kls.remove(i); + if (i % 2 == 1) { + // we expect a collision with i + modulo, so we can't remove + assertFalse(didRemove); + assertTrue(kls.contains(i)); + // but we should be able to remove -i + boolean didRemoveNegative = kls.remove(-i); + assertTrue(didRemoveNegative); + assertFalse(kls.contains(-i)); + } else { + // we expect no collision + assertTrue(didRemove); + assertFalse(kls.contains(i)); + assertFalse(kls.valueHasHadCollision(i)); + } + } + assertEquals(5, kls.getSize()); + int offset = 12; + kls.add(offset); + for (int j = 1; j < 5; j++) { + kls.add(offset + j * modulo); + } + assertEquals(6, kls.getSize()); + assertFalse(kls.remove(offset + modulo)); + assertTrue(kls.valueHasHadCollision(offset + 15 * modulo)); + assertTrue(kls.contains(offset + 17 * modulo)); + } + + public void testNullInputs() throws Exception { + RBMIntKeyLookupStore kls = new RBMIntKeyLookupStore(RBMIntKeyLookupStore.KeystoreModuloValue.TWO_TO_TWENTY_NINE, 0L); + assertFalse(kls.add(null)); + assertFalse(kls.contains(null)); + assertEquals(0, (int) kls.getInternalRepresentation(null)); + assertFalse(kls.remove(null)); + assertFalse(kls.isCollision(null, null)); + assertEquals(0, kls.getAddAttempts()); + Integer[] newVals = new Integer[] { 1, 17, -2, null, -4, null }; + kls.regenerateStore(newVals); + assertEquals(4, kls.getSize()); + } + + public void testRemovalLogic() throws Exception { + RBMIntKeyLookupStore.KeystoreModuloValue moduloValue = RBMIntKeyLookupStore.KeystoreModuloValue.TWO_TO_TWENTY_SIX; + int modulo = moduloValue.getValue(); + RBMIntKeyLookupStore kls = new RBMIntKeyLookupStore(moduloValue, 0L); + + // Test standard sequence: add K1, K2, K3 which all transform to C, then: + // Remove K3 + // Remove K2, re-add it, re-remove it twice (duplicate should do nothing) + // Remove K1, which should finally actually remove everything + int c = -42; + int k1 = c + modulo; + int k2 = c + 2 * modulo; + int k3 = c + 3 * modulo; + kls.add(k1); + assertTrue(kls.contains(k1)); + assertTrue(kls.contains(k3)); + kls.add(k2); + CounterMetric numCollisions = kls.getNumCollisionsForValue(k2); + assertNotNull(numCollisions); + assertEquals(2, numCollisions.count()); + kls.add(k3); + assertEquals(3, numCollisions.count()); + assertEquals(1, kls.getSize()); + + boolean removed = kls.remove(k3); + assertFalse(removed); + HashSet removalSet = kls.getRemovalSetForValue(k3); + assertEquals(1, removalSet.size()); + assertTrue(removalSet.contains(k3)); + assertEquals(2, numCollisions.count()); + assertEquals(1, kls.getSize()); + + removed = kls.remove(k2); + assertFalse(removed); + assertEquals(2, removalSet.size()); + assertTrue(removalSet.contains(k2)); + assertEquals(1, numCollisions.count()); + assertEquals(1, kls.getSize()); + + kls.add(k2); + assertEquals(1, removalSet.size()); + assertFalse(removalSet.contains(k2)); + assertEquals(2, numCollisions.count()); + assertEquals(1, kls.getSize()); + + removed = kls.remove(k2); + assertFalse(removed); + assertEquals(2, removalSet.size()); + assertTrue(removalSet.contains(k2)); + assertEquals(1, numCollisions.count()); + assertEquals(1, kls.getSize()); + + removed = kls.remove(k2); + assertFalse(removed); + assertEquals(2, removalSet.size()); + assertTrue(removalSet.contains(k2)); + assertEquals(1, numCollisions.count()); + assertEquals(1, kls.getSize()); + + removed = kls.remove(k1); + assertTrue(removed); + assertNull(kls.getRemovalSetForValue(k1)); + assertNull(kls.getNumCollisionsForValue(k1)); + assertEquals(0, kls.getSize()); + } + + public void testRemovalLogicWithHashCollision() throws Exception { + RBMIntKeyLookupStore.KeystoreModuloValue moduloValue = RBMIntKeyLookupStore.KeystoreModuloValue.TWO_TO_TWENTY_SIX; + int modulo = moduloValue.getValue(); + RBMIntKeyLookupStore kls = new RBMIntKeyLookupStore(moduloValue, 0L); + + // Test adding K1 twice (maybe two keys hash to K1), then removing it twice. + // We expect it to be unable to remove the last one, but there should be no false negatives. + int c = 77; + int k1 = c + modulo; + int k2 = c + 2 * modulo; + kls.add(k1); + kls.add(k2); + CounterMetric numCollisions = kls.getNumCollisionsForValue(k1); + assertEquals(2, numCollisions.count()); + kls.add(k1); + assertEquals(3, numCollisions.count()); + + boolean removed = kls.remove(k1); + assertFalse(removed); + HashSet removalSet = kls.getRemovalSetForValue(k1); + assertTrue(removalSet.contains(k1)); + assertEquals(2, numCollisions.count()); + + removed = kls.remove(k2); + assertFalse(removed); + assertTrue(removalSet.contains(k2)); + assertEquals(1, numCollisions.count()); + + removed = kls.remove(k1); + assertFalse(removed); + assertTrue(removalSet.contains(k1)); + assertEquals(1, numCollisions.count()); + assertTrue(kls.contains(k1)); + assertTrue(kls.contains(k2)); + } + + public void testSetMemSizeCap() throws Exception { + RBMIntKeyLookupStore kls = new RBMIntKeyLookupStore(0L); // no memory cap + Random rand = Randomness.get(); + for (int i = 0; i < RBMIntKeyLookupStore.REFRESH_SIZE_EST_INTERVAL * 3; i++) { + kls.add(rand.nextInt()); + } + long memSize = kls.getMemorySizeInBytes(); + assertEquals(0, kls.getMemorySizeCapInBytes()); + kls.setMemSizeCap(new ByteSizeValue(memSize / 2, ByteSizeUnit.BYTES)); + // check the keystore is now full and has its lower cap + assertTrue(kls.isFull()); + assertEquals(memSize / 2, kls.getMemorySizeCapInBytes()); + assertFalse(kls.add(rand.nextInt())); + } +}