Skip to content

Commit

Permalink
[Backport 2.x][Refactor] Improve asynchronous test cases format (open…
Browse files Browse the repository at this point in the history
…search-project#3601) (opensearch-project#3630)

### Description
Signed-off-by: Peter Nied <[email protected]>
(cherry picked from commit 1ad8cf4)
  • Loading branch information
peternied authored Oct 31, 2023
1 parent 60e2ecf commit 515e559
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 165 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,27 @@

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.lang.management.GarbageCollectorMXBean;
import java.lang.management.ManagementFactory;
import java.lang.management.MemoryPoolMXBean;
import java.lang.management.MemoryUsage;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.zip.GZIPOutputStream;

import org.apache.http.HttpStatus;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
import org.apache.http.entity.ContentType;
import org.apache.http.message.BasicHeader;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.client.Client;
import org.opensearch.test.framework.AsyncActions;
import org.opensearch.test.framework.TestSecurityConfig;
import org.opensearch.test.framework.TestSecurityConfig.User;
import org.opensearch.test.framework.cluster.ClusterManager;
Expand All @@ -52,6 +48,7 @@
@RunWith(com.carrotsearch.randomizedtesting.RandomizedRunner.class)
@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
public class ResourceFocusedTests {
private final static Logger LOG = LogManager.getLogger(AsyncActions.class);
private static final User ADMIN_USER = new User("admin").roles(ALL_ACCESS);
private static final User LIMITED_USER = new User("limited_user").roles(
new TestSecurityConfig.Role("limited-role").clusterPermissions(
Expand Down Expand Up @@ -93,9 +90,8 @@ public void testUnauthenticatedFewBig() {
final String requestPath = "/*/_search";
final int parrallelism = 5;
final int totalNumberOfRequests = 100;
final boolean statsPrinter = false;

runResourceTest(size, requestPath, parrallelism, totalNumberOfRequests, statsPrinter);
runResourceTest(size, requestPath, parrallelism, totalNumberOfRequests);
}

@Test
Expand All @@ -105,9 +101,8 @@ public void testUnauthenticatedManyMedium() {
final String requestPath = "/*/_search";
final int parrallelism = 20;
final int totalNumberOfRequests = 10_000;
final boolean statsPrinter = false;

runResourceTest(size, requestPath, parrallelism, totalNumberOfRequests, statsPrinter);
runResourceTest(size, requestPath, parrallelism, totalNumberOfRequests);
}

@Test
Expand All @@ -116,62 +111,27 @@ public void testUnauthenticatedTonsSmall() {
final RequestBodySize size = RequestBodySize.Small;
final String requestPath = "/*/_search";
final int parrallelism = 100;
final int totalNumberOfRequests = 1_000_000;
final boolean statsPrinter = false;
final int totalNumberOfRequests = 15_000;

runResourceTest(size, requestPath, parrallelism, totalNumberOfRequests, statsPrinter);
runResourceTest(size, requestPath, parrallelism, totalNumberOfRequests);
}

private Long runResourceTest(
private void runResourceTest(
final RequestBodySize size,
final String requestPath,
final int parrallelism,
final int totalNumberOfRequests,
final boolean statsPrinter
final int totalNumberOfRequests
) {
final byte[] compressedRequestBody = createCompressedRequestBody(size);
try (final TestRestClient client = cluster.getRestClient(new BasicHeader("Content-Encoding", "gzip"))) {

if (statsPrinter) {
printStats();
}
final HttpPost post = new HttpPost(client.getHttpServerUri() + requestPath);
post.setEntity(new ByteArrayEntity(compressedRequestBody, ContentType.APPLICATION_JSON));

final ForkJoinPool forkJoinPool = new ForkJoinPool(parrallelism);

final List<CompletableFuture<Void>> waitingOn = IntStream.rangeClosed(1, totalNumberOfRequests)
.boxed()
.map(i -> CompletableFuture.runAsync(() -> client.executeRequest(post), forkJoinPool))
.collect(Collectors.toList());
Supplier<Long> getCount = () -> waitingOn.stream().filter(cf -> cf.isDone() && !cf.isCompletedExceptionally()).count();

CompletableFuture<Void> statPrinter = statsPrinter ? CompletableFuture.runAsync(() -> {
while (true) {
printStats();
System.err.println(" & Succesful completions: " + getCount.get());
try {
Thread.sleep(500);
} catch (Exception e) {
break;
}
}
}, forkJoinPool) : CompletableFuture.completedFuture(null);

final CompletableFuture<Void> allOfThem = CompletableFuture.allOf(waitingOn.toArray(new CompletableFuture[0]));

try {
allOfThem.get(30, TimeUnit.SECONDS);
statPrinter.cancel(true);
} catch (final Exception e) {
// Ignored
}

if (statsPrinter) {
printStats();
System.err.println(" & Succesful completions: " + getCount.get());
}
return getCount.get();
final var requests = AsyncActions.generate(() -> {
final HttpPost post = new HttpPost(client.getHttpServerUri() + requestPath);
post.setEntity(new ByteArrayEntity(compressedRequestBody, ContentType.APPLICATION_JSON));
return client.executeRequest(post);
}, parrallelism, totalNumberOfRequests);

AsyncActions.getAll(requests, 2, TimeUnit.MINUTES)
.forEach((response) -> { response.assertStatusCode(HttpStatus.SC_UNAUTHORIZED); });
}
}

Expand Down Expand Up @@ -217,51 +177,17 @@ private byte[] createCompressedRequestBody(final RequestBodySize size) {
gzipOutputStream.finish();

final byte[] compressedRequestBody = byteArrayOutputStream.toByteArray();
System.err.println(
"^^^"
+ String.format(
"Original size was %,d bytes, compressed to %,d bytes, ratio %,.2f",
uncompressedBytesSize,
compressedRequestBody.length,
((double) uncompressedBytesSize / compressedRequestBody.length)
)
LOG.info(
String.format(
"Original size was %,d bytes, compressed to %,d bytes, ratio %,.2f",
uncompressedBytesSize,
compressedRequestBody.length,
((double) uncompressedBytesSize / compressedRequestBody.length)
)
);
return compressedRequestBody;
} catch (final IOException ioe) {
throw new RuntimeException(ioe);
}
}

private void printStats() {
System.err.println("** Stats ");
printMemory();
printMemoryPools();
printGCPools();
}

private void printMemory() {
final Runtime runtime = Runtime.getRuntime();

final long totalMemory = runtime.totalMemory(); // Total allocated memory
final long freeMemory = runtime.freeMemory(); // Amount of free memory
final long usedMemory = totalMemory - freeMemory; // Amount of used memory

System.err.println(" Memory Total: " + totalMemory + " Free:" + freeMemory + " Used:" + usedMemory);
}

private void printMemoryPools() {
List<MemoryPoolMXBean> memoryPools = ManagementFactory.getMemoryPoolMXBeans();
for (MemoryPoolMXBean memoryPool : memoryPools) {
MemoryUsage usage = memoryPool.getUsage();
System.err.println(" " + memoryPool.getName() + " USED: " + usage.getUsed() + " MAX: " + usage.getMax());
}
}

private void printGCPools() {
List<GarbageCollectorMXBean> garbageCollectors = ManagementFactory.getGarbageCollectorMXBeans();
for (GarbageCollectorMXBean garbageCollector : garbageCollectors) {
System.err.println(" " + garbageCollector.getName() + " COLLECTION TIME: " + garbageCollector.getCollectionTime());
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
package org.opensearch.security.rest;

import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;
import org.apache.http.Header;
import org.apache.http.HttpStatus;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
Expand All @@ -21,11 +20,7 @@
import org.junit.Test;
import org.junit.runner.RunWith;

import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.CoreMatchers.anyOf;
import static org.hamcrest.MatcherAssert.assertThat;
import org.opensearch.test.framework.AsyncActions;
import org.opensearch.test.framework.TestSecurityConfig;
import org.opensearch.test.framework.cluster.ClusterManager;
import org.opensearch.test.framework.cluster.LocalCluster;
Expand All @@ -34,15 +29,14 @@
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.zip.GZIPOutputStream;
import org.opensearch.test.framework.cluster.TestRestClient.HttpResponse;

import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.opensearch.test.framework.TestSecurityConfig.AuthcDomain.AUTHC_HTTPBASIC_INTERNAL;
import static org.opensearch.test.framework.TestSecurityConfig.Role.ALL_ACCESS;
import static org.opensearch.test.framework.cluster.TestRestClientConfiguration.getBasicAuthHeader;
Expand All @@ -60,7 +54,7 @@ public class CompressionTests {
.build();

@Test
public void testAuthenticatedGzippedRequests() throws Exception {
public void testAuthenticatedGzippedRequests() {
final String requestPath = "/*/_search";
final int parallelism = 10;
final int totalNumberOfRequests = 100;
Expand All @@ -69,72 +63,54 @@ public void testAuthenticatedGzippedRequests() throws Exception {

final byte[] compressedRequestBody = createCompressedRequestBody(rawBody);
try (final TestRestClient client = cluster.getRestClient(ADMIN_USER, new BasicHeader("Content-Encoding", "gzip"))) {
final var requests = AsyncActions.generate(() -> {
final HttpPost post = new HttpPost(client.getHttpServerUri() + requestPath);
post.setEntity(new ByteArrayEntity(compressedRequestBody, ContentType.APPLICATION_JSON));
return client.executeRequest(post);
}, parallelism, totalNumberOfRequests);

final ForkJoinPool forkJoinPool = new ForkJoinPool(parallelism);

final List<CompletableFuture<HttpResponse>> waitingOn = IntStream.rangeClosed(1, totalNumberOfRequests)
.boxed()
.map(i -> CompletableFuture.supplyAsync(() -> {
final HttpPost post = new HttpPost(client.getHttpServerUri() + requestPath);
post.setEntity(new ByteArrayEntity(compressedRequestBody, ContentType.APPLICATION_JSON));
return client.executeRequest(post);
}, forkJoinPool))
.collect(Collectors.toList());

final CompletableFuture<Void> allOfThem = CompletableFuture.allOf(waitingOn.toArray(new CompletableFuture[0]));

allOfThem.get(30, TimeUnit.SECONDS);

waitingOn.stream().forEach(future -> {
try {
final HttpResponse response = future.get();
response.assertStatusCode(HttpStatus.SC_OK);
} catch (final Exception ex) {
throw new RuntimeException(ex);
}
});
;
AsyncActions.getAll(requests, 30, TimeUnit.SECONDS).forEach((response) -> { response.assertStatusCode(HttpStatus.SC_OK); });
}
}

@Test
public void testMixOfAuthenticatedAndUnauthenticatedGzippedRequests() throws Exception {
final String requestPath = "/*/_search";
final int parallelism = 10;
final int totalNumberOfRequests = 100;
final int totalNumberOfRequests = 50;

final String rawBody = "{ \"query\": { \"match\": { \"foo\": \"bar\" }}}";

final byte[] compressedRequestBody = createCompressedRequestBody(rawBody);
try (final TestRestClient client = cluster.getRestClient(new BasicHeader("Content-Encoding", "gzip"))) {

final ForkJoinPool forkJoinPool = new ForkJoinPool(parallelism);

final Header basicAuthHeader = getBasicAuthHeader(ADMIN_USER.getName(), ADMIN_USER.getPassword());

final List<CompletableFuture<HttpResponse>> waitingOn = IntStream.rangeClosed(1, totalNumberOfRequests)
.boxed()
.map(i -> CompletableFuture.supplyAsync(() -> {
final HttpPost post = new HttpPost(client.getHttpServerUri() + requestPath);
post.setEntity(new ByteArrayEntity(compressedRequestBody, ContentType.APPLICATION_JSON));
return i % 2 == 0 ? client.executeRequest(post) : client.executeRequest(post, basicAuthHeader);
}, forkJoinPool))
.collect(Collectors.toList());

final CompletableFuture<Void> allOfThem = CompletableFuture.allOf(waitingOn.toArray(new CompletableFuture[0]));

allOfThem.get(30, TimeUnit.SECONDS);

waitingOn.stream().forEach(future -> {
try {
final HttpResponse response = future.get();
assertThat(response.getBody(), not(containsString("json_parse_exception")));
assertThat(response.getStatusCode(), anyOf(equalTo(HttpStatus.SC_UNAUTHORIZED), equalTo(HttpStatus.SC_OK)));
} catch (final Exception ex) {
throw new RuntimeException(ex);
}
final CountDownLatch countDownLatch = new CountDownLatch(1);

final var authorizedRequests = AsyncActions.generate(() -> {
countDownLatch.await();
System.err.println("Generation triggered authorizedRequests");
final HttpPost post = new HttpPost(client.getHttpServerUri() + requestPath);
post.setEntity(new ByteArrayEntity(compressedRequestBody, ContentType.APPLICATION_JSON));
return client.executeRequest(post, getBasicAuthHeader(ADMIN_USER.getName(), ADMIN_USER.getPassword()));
}, parallelism, totalNumberOfRequests);

final var unauthorizedRequests = AsyncActions.generate(() -> {
countDownLatch.await();
System.err.println("Generation triggered unauthorizedRequests");
final HttpPost post = new HttpPost(client.getHttpServerUri() + requestPath);
post.setEntity(new ByteArrayEntity(compressedRequestBody, ContentType.APPLICATION_JSON));
return client.executeRequest(post);
}, parallelism, totalNumberOfRequests);

// Make sure all requests start at the same time
countDownLatch.countDown();

AsyncActions.getAll(authorizedRequests, 30, TimeUnit.SECONDS).forEach((response) -> {
assertThat(response.getStatusCode(), equalTo(HttpStatus.SC_OK));
});
AsyncActions.getAll(unauthorizedRequests, 30, TimeUnit.SECONDS).forEach((response) -> {
assertThat(response.getBody(), not(containsString("json_parse_exception")));
assertThat(response.getStatusCode(), equalTo(HttpStatus.SC_UNAUTHORIZED));
});
;
}
}

Expand Down
Loading

0 comments on commit 515e559

Please sign in to comment.