Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Made the BulkRetryStrategyTests less reliant on implementation specifics from OpenSearch #1346

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,35 @@
import com.amazon.dataprepper.metrics.PluginMetrics;
import com.amazon.dataprepper.model.configuration.PluginSetting;
import io.micrometer.core.instrument.Measurement;
import org.hamcrest.MatcherAssert;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.opensearch.OpenSearchException;
import org.opensearch.action.DocWriteRequest;
import org.opensearch.action.bulk.BulkItemResponse;
import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.bulk.BulkResponse;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.common.util.concurrent.OpenSearchRejectedExecutionException;
import org.opensearch.index.Index;
import org.opensearch.index.shard.ShardId;
import org.junit.Before;
import org.junit.Test;
import org.opensearch.rest.RestStatus;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.StringJoiner;
import java.util.UUID;
import java.util.function.BiConsumer;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.notNullValue;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class BulkRetryStrategyTests {
private static final String PLUGIN_NAME = "opensearch";
Expand All @@ -40,10 +47,12 @@ public class BulkRetryStrategyTests {
setPipelineName(PIPELINE_NAME);
}};
private static final PluginMetrics PLUGIN_METRICS = PluginMetrics.fromPluginSetting(PLUGIN_SETTING);
private BiConsumer<DocWriteRequest<?>, Throwable> logFailureConsumer;

@Before
public void metricsInit() {
@BeforeEach
public void setUp() {
MetricsTestUtil.initMetrics();
logFailureConsumer = mock(BiConsumer.class);
}

@Test
Expand Down Expand Up @@ -74,10 +83,9 @@ public void testExecuteSuccessOnFirstAttempt() throws Exception {
final String testIndex = "bar";
final FakeClient client = new FakeClient(testIndex);
client.successOnFirstAttempt = true;
final FakeLogger logger = new FakeLogger();

final BulkRetryStrategy bulkRetryStrategy = new BulkRetryStrategy(
client::bulk, logger::logFailure, PLUGIN_METRICS, BulkRequest::new);
client::bulk, logFailureConsumer, PLUGIN_METRICS, BulkRequest::new);
final BulkRequest testBulkRequest = new BulkRequest();
testBulkRequest.add(new IndexRequest(testIndex).id("1"));
testBulkRequest.add(new IndexRequest(testIndex).id("2"));
Expand Down Expand Up @@ -105,10 +113,9 @@ public void testExecuteSuccessOnFirstAttempt() throws Exception {
public void testExecuteRetryable() throws Exception {
final String testIndex = "bar";
final FakeClient client = new FakeClient(testIndex);
final FakeLogger logger = new FakeLogger();

final BulkRetryStrategy bulkRetryStrategy = new BulkRetryStrategy(
client::bulk, logger::logFailure, PLUGIN_METRICS, BulkRequest::new);
client::bulk, logFailureConsumer, PLUGIN_METRICS, BulkRequest::new);
final BulkRequest testBulkRequest = new BulkRequest();
testBulkRequest.add(new IndexRequest(testIndex).id("1"));
testBulkRequest.add(new IndexRequest(testIndex).id("2"));
Expand All @@ -122,9 +129,14 @@ public void testExecuteRetryable() throws Exception {
assertFalse(client.finalResponse.hasFailures());
assertEquals("3", client.finalRequest.requests().get(0).id());
assertEquals("4", client.finalRequest.requests().get(1).id());
final String logging = logger.msg.toString();
assertTrue(logging.contains("[bar][_doc][2]"));
assertFalse(logging.contains("[bar][_doc][1]"));

ArgumentCaptor<DocWriteRequest> loggerWriteRequestArgCaptor = ArgumentCaptor.forClass(DocWriteRequest.class);
ArgumentCaptor<Throwable> loggerThrowableArgCaptor = ArgumentCaptor.forClass(Throwable.class);
verify(logFailureConsumer).accept(loggerWriteRequestArgCaptor.capture(), loggerThrowableArgCaptor.capture());
MatcherAssert.assertThat(loggerWriteRequestArgCaptor.getValue(), notNullValue());
MatcherAssert.assertThat(loggerWriteRequestArgCaptor.getValue().index(), equalTo(testIndex));
MatcherAssert.assertThat(loggerWriteRequestArgCaptor.getValue().id(), equalTo("2"));
MatcherAssert.assertThat(loggerThrowableArgCaptor.getValue(), notNullValue());

// verify metrics
final List<Measurement> documentsSuccessFirstAttemptMeasurements = MetricsTestUtil.getMeasurementList(
Expand All @@ -149,10 +161,9 @@ public void testExecuteNonRetryableException() throws Exception {
final String testIndex = "bar";
final FakeClient client = new FakeClient(testIndex);
client.retryable = false;
final FakeLogger logger = new FakeLogger();

final BulkRetryStrategy bulkRetryStrategy = new BulkRetryStrategy(
client::bulk, logger::logFailure, PLUGIN_METRICS, BulkRequest::new);
client::bulk, logFailureConsumer, PLUGIN_METRICS, BulkRequest::new);
final BulkRequest testBulkRequest = new BulkRequest();
testBulkRequest.add(new IndexRequest(testIndex).id("1"));
testBulkRequest.add(new IndexRequest(testIndex).id("2"));
Expand All @@ -162,9 +173,17 @@ public void testExecuteNonRetryableException() throws Exception {
bulkRetryStrategy.execute(testBulkRequest);

assertEquals(1, client.attempt);
final String logging = logger.msg.toString();
for (int i = 1; i <= 4; i++) {
assertTrue(logging.contains(String.format("[bar][_doc][%d]", i)));

ArgumentCaptor<DocWriteRequest> loggerWriteRequestArgCaptor = ArgumentCaptor.forClass(DocWriteRequest.class);
ArgumentCaptor<Throwable> loggerExceptionArgCaptor = ArgumentCaptor.forClass(Throwable.class);
verify(logFailureConsumer, times(4))
.accept(loggerWriteRequestArgCaptor.capture(), isA(IllegalArgumentException.class));
final List<DocWriteRequest> allLoggerWriteRequests = loggerWriteRequestArgCaptor.getAllValues();
for (int i = 0; i < allLoggerWriteRequests.size(); i++) {
final DocWriteRequest actualFailedWrite = allLoggerWriteRequests.get(i);
MatcherAssert.assertThat(actualFailedWrite.index(), equalTo(testIndex));
String expectedIndexName = Integer.toString(i+1);
MatcherAssert.assertThat(actualFailedWrite.id(), equalTo(expectedIndexName));
}

// verify metrics
Expand All @@ -186,10 +205,9 @@ public void testExecuteNonRetryableResponse() throws Exception {
final FakeClient client = new FakeClient(testIndex);
client.retryable = false;
client.nonRetryableException = false;
final FakeLogger logger = new FakeLogger();

final BulkRetryStrategy bulkRetryStrategy = new BulkRetryStrategy(
client::bulk, logger::logFailure, PLUGIN_METRICS, BulkRequest::new);
client::bulk, logFailureConsumer, PLUGIN_METRICS, BulkRequest::new);
final BulkRequest testBulkRequest = new BulkRequest();
testBulkRequest.add(new IndexRequest(testIndex).id("1"));
testBulkRequest.add(new IndexRequest(testIndex).id("2"));
Expand All @@ -199,9 +217,17 @@ public void testExecuteNonRetryableResponse() throws Exception {
bulkRetryStrategy.execute(testBulkRequest);

assertEquals(1, client.attempt);
final String logging = logger.msg.toString();
for (int i = 2; i <= 4; i++) {
assertTrue(logging.contains(String.format("[bar][_doc][%d]", i)));

ArgumentCaptor<DocWriteRequest> loggerWriteRequestArgCaptor = ArgumentCaptor.forClass(DocWriteRequest.class);
ArgumentCaptor<Throwable> loggerExceptionArgCaptor = ArgumentCaptor.forClass(Throwable.class);
verify(logFailureConsumer, times(3))
.accept(loggerWriteRequestArgCaptor.capture(), isA(IllegalArgumentException.class));
final List<DocWriteRequest> allLoggerWriteRequests = loggerWriteRequestArgCaptor.getAllValues();
for (int i = 0; i < allLoggerWriteRequests.size(); i++) {
final DocWriteRequest actualFailedWrite = allLoggerWriteRequests.get(i);
MatcherAssert.assertThat(actualFailedWrite.index(), equalTo(testIndex));
String expectedIndexName = Integer.toString(i+2);
MatcherAssert.assertThat(actualFailedWrite.id(), equalTo(expectedIndexName));
}

// verify metrics
Expand All @@ -218,31 +244,30 @@ public void testExecuteNonRetryableResponse() throws Exception {
}

private static BulkItemResponse successItemResponse(final String index) {
final String docId = UUID.randomUUID().toString();
return new BulkItemResponse(1, DocWriteRequest.OpType.INDEX,
new IndexResponse(new ShardId(new Index(index, "fakeUUID"), 1),
"_doc", docId, 1, 1, 1, true));
return mock(BulkItemResponse.class);
}

private static BulkItemResponse badRequestItemResponse(final String index) {
final String docId = UUID.randomUUID().toString();
return new BulkItemResponse(1, DocWriteRequest.OpType.INDEX,
new BulkItemResponse.Failure(index, "_doc", docId,
new IllegalArgumentException()));
return customBulkFailureResponse(index, RestStatus.BAD_REQUEST, new IllegalArgumentException());
}

private static BulkItemResponse tooManyRequestItemResponse(final String index) {
final String docId = UUID.randomUUID().toString();
return new BulkItemResponse(1, DocWriteRequest.OpType.INDEX,
new BulkItemResponse.Failure(index, "_doc", docId,
new OpenSearchRejectedExecutionException()));
return customBulkFailureResponse(index, RestStatus.TOO_MANY_REQUESTS, new OpenSearchRejectedExecutionException());
}

private static BulkItemResponse internalServerErrorItemResponse(final String index) {
final String docId = UUID.randomUUID().toString();
return new BulkItemResponse(1, DocWriteRequest.OpType.INDEX,
new BulkItemResponse.Failure(index, "_doc", docId,
new IllegalAccessException()));
return customBulkFailureResponse(index, RestStatus.INTERNAL_SERVER_ERROR, new IllegalAccessException());
}

private static BulkItemResponse customBulkFailureResponse(final String index, final RestStatus restStatus, final Exception cause) {
final BulkItemResponse.Failure failure = mock(BulkItemResponse.Failure.class);
when(failure.getStatus()).thenReturn(restStatus);
when(failure.getCause()).thenReturn(cause);
final BulkItemResponse badResponse = mock(BulkItemResponse.class);
when(badResponse.isFailed()).thenReturn(true);
when(badResponse.status()).thenReturn(restStatus);
when(badResponse.getFailure()).thenReturn(failure);
return badResponse;
}

private static class FakeClient {
Expand Down Expand Up @@ -327,12 +352,4 @@ private BulkResponse bulkSuccessResponse(final BulkRequest bulkRequest) {
return new BulkResponse(bulkItemResponses, 10);
}
}

private static class FakeLogger {
StringBuilder msg = new StringBuilder();

public void logFailure(final DocWriteRequest<?> docWriteRequest, final Throwable t) {
msg.append(String.format("Document [%s] has failure: %s", docWriteRequest.toString(), t));
}
}
}