Skip to content

Commit

Permalink
UHeapBasedRateTracker uses time provider to allow simluating of time …
Browse files Browse the repository at this point in the history
…in unit tests

Signed-off-by: Peter Nied <[email protected]>
  • Loading branch information
peternied committed Jan 9, 2024
1 parent 649a2fc commit 0b4294d
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
package org.opensearch.security.util.ratetracking;

import java.util.Arrays;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.function.LongSupplier;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
Expand All @@ -33,16 +35,22 @@ public class HeapBasedRateTracker<ClientIdType> implements RateTracker<ClientIdT
private final Logger log = LogManager.getLogger(this.getClass());

private final Cache<ClientIdType, ClientRecord> cache;
private final LongSupplier timeProvider;
private final long timeWindowMs;
private final int maxTimeOffsets;

public HeapBasedRateTracker(long timeWindowMs, int allowedTries, int maxEntries) {
this(timeWindowMs, allowedTries, maxEntries, null);
}

public HeapBasedRateTracker(long timeWindowMs, int allowedTries, int maxEntries, LongSupplier timeProvider) {
if (allowedTries < 2) {
throw new IllegalArgumentException("allowedTries must be >= 2");
}

this.timeWindowMs = timeWindowMs;
this.maxTimeOffsets = allowedTries > 2 ? allowedTries - 2 : 0;
this.timeProvider = Optional.ofNullable(timeProvider).orElse(System::currentTimeMillis);
this.cache = CacheBuilder.newBuilder()
.expireAfterAccess(this.timeWindowMs, TimeUnit.MILLISECONDS)
.maximumSize(maxEntries)
Expand Down Expand Up @@ -89,7 +97,7 @@ private class ClientRecord {
private short timeOffsetEnd = -1;

synchronized boolean track() {
long timestamp = System.currentTimeMillis();
long timestamp = timeProvider.getAsLong();

if (this.startTime == -1 || timestamp - getMostRecent() >= timeWindowMs) {
this.startTime = timestamp;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

package org.opensearch.security.auth.limiting;

import org.junit.Ignore;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.LongSupplier;

import org.junit.Test;

import org.opensearch.security.util.ratetracking.HeapBasedRateTracker;
Expand All @@ -27,9 +29,12 @@

public class HeapBasedRateTrackerTest {

private final AtomicLong currentTime = new AtomicLong(1);
private LongSupplier timeProvider = () -> currentTime.getAndAdd(1);

@Test
public void simpleTest() throws Exception {
HeapBasedRateTracker<String> tracker = new HeapBasedRateTracker<>(100, 5, 100_000);
HeapBasedRateTracker<String> tracker = new HeapBasedRateTracker<>(100, 5, 100_000, timeProvider);

assertFalse(tracker.track("a"));
assertFalse(tracker.track("a"));
Expand All @@ -40,9 +45,8 @@ public void simpleTest() throws Exception {
}

@Test
@Ignore // https://github.com/opensearch-project/security/issues/2193
public void expiryTest() throws Exception {
HeapBasedRateTracker<String> tracker = new HeapBasedRateTracker<>(100, 5, 100_000);
HeapBasedRateTracker<String> tracker = new HeapBasedRateTracker<>(100, 5, 100_000, timeProvider);

assertFalse(tracker.track("a"));
assertFalse(tracker.track("a"));
Expand All @@ -58,42 +62,41 @@ public void expiryTest() throws Exception {

assertFalse(tracker.track("c"));

Thread.sleep(50);
currentTime.addAndGet(50);

assertFalse(tracker.track("c"));
assertFalse(tracker.track("c"));
assertFalse(tracker.track("c"));

Thread.sleep(55);
currentTime.addAndGet(55);

assertFalse(tracker.track("c"));
assertTrue(tracker.track("c"));

assertFalse(tracker.track("a"));

Thread.sleep(55);
currentTime.addAndGet(55);
assertFalse(tracker.track("c"));
assertFalse(tracker.track("c"));
assertTrue(tracker.track("c"));

}

@Test
@Ignore // https://github.com/opensearch-project/security/issues/2193
public void maxTwoTriesTest() throws Exception {
HeapBasedRateTracker<String> tracker = new HeapBasedRateTracker<>(100, 2, 100_000);
HeapBasedRateTracker<String> tracker = new HeapBasedRateTracker<>(100, 2, 100_000, timeProvider);

assertFalse(tracker.track("a"));
assertTrue(tracker.track("a"));

assertFalse(tracker.track("b"));
Thread.sleep(50);
currentTime.addAndGet(50);
assertTrue(tracker.track("b"));

Thread.sleep(55);
currentTime.addAndGet(55);
assertTrue(tracker.track("b"));

Thread.sleep(105);
currentTime.addAndGet(105);
assertFalse(tracker.track("b"));
assertTrue(tracker.track("b"));

Expand Down

0 comments on commit 0b4294d

Please sign in to comment.