Remove compounding retries within PrimaryShardReplicationSource (#12043)
This change removes the retries inside PrimaryShardReplicationSource and relies on retries in a single place at the start of replication:
SegmentReplicationTargetService's processLatestReceivedCheckpoint, which runs after each replication failure or success.
The timeout on the removed retries is the cause of flaky failures in SegmentReplication's bwc test within IndexingIT that can occur
on node disconnect: the retries persist for over a minute against a primary node that has already been relocated or shut down, causing the test to time out.
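As an aside, the compounding effect is easy to model. The following is a minimal, self-contained Java sketch; the class, method names, and timeout values are hypothetical illustrations of the behavior described above, not OpenSearch code.

import java.time.Duration;

// Hypothetical illustration only -- not OpenSearch code.
public class CompoundingRetriesSketch {

    // Simulates a transport call to a primary that has been relocated or shut down:
    // every attempt burns a full timeout and then fails.
    private static Duration failedAttempt(Duration timeout) {
        return timeout;
    }

    // An inner client that retries its own call -- the pattern this commit removes.
    private static Duration innerClientWithRetries(int attempts, Duration timeout) {
        Duration spent = Duration.ZERO;
        for (int i = 0; i < attempts; i++) {
            spent = spent.plus(failedAttempt(timeout));
        }
        return spent;
    }

    public static void main(String[] args) {
        Duration timeout = Duration.ofSeconds(10); // hypothetical per-attempt timeout
        int innerAttempts = 6;                     // hypothetical inner retry budget
        int outerAttempts = 3;                     // hypothetical replication restarts

        // Compounding: every restart pays the full inner retry budget against the dead node.
        Duration compounded = Duration.ZERO;
        for (int i = 0; i < outerAttempts; i++) {
            compounded = compounded.plus(innerClientWithRetries(innerAttempts, timeout));
        }

        // Single retry layer: each restart fails fast after one timeout.
        Duration single = Duration.ZERO;
        for (int i = 0; i < outerAttempts; i++) {
            single = single.plus(failedAttempt(timeout));
        }

        System.out.println("nested retries, worst case:     " + compounded); // PT3M
        System.out.println("single retry layer, worst case: " + single);     // PT30S
    }
}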

This change also simplifies the cancellation flow on the target service before the shard is closed.
Previously we "requested" a cancel, which did not remove the target from the ongoing replications collection until a cancellation failure was thrown.
The transport calls from PrimaryShardReplicationSource are no longer wrapped in CancellableThreads by the client, so a call to "cancel" will not throw.
Instead, we now immediately remove the target and decref/close it.
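A rough, self-contained sketch of the new cancellation path follows; the Target class, refCount handling, and method names are simplified stand-ins, not the actual ReplicationCollection or SegmentReplicationTarget APIs.

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical stand-in for the cancellation flow described above -- not OpenSearch code.
public class CancellationFlowSketch {

    static class Target {
        final long id;
        // The ongoing-replications collection holds the initial reference.
        final AtomicInteger refCount = new AtomicInteger(1);

        Target(long id) {
            this.id = id;
        }

        void cancel(String reason) {
            System.out.println("target " + id + " cancelled: " + reason);
        }

        void decRef() {
            if (refCount.decrementAndGet() == 0) {
                System.out.println("target " + id + " closed, resources released");
            }
        }
    }

    private final Map<Long, Target> ongoing = new ConcurrentHashMap<>();

    void start(Target target) {
        ongoing.put(target.id, target);
    }

    // Remove the target from the collection and release the collection's reference
    // immediately, so nothing lingers until a cancellation failure surfaces.
    void cancelForShard(long id, String reason) {
        Target target = ongoing.remove(id);
        if (target != null) {
            target.cancel(reason);
            target.decRef();
        }
    }

    public static void main(String[] args) {
        CancellationFlowSketch sketch = new CancellationFlowSketch();
        sketch.start(new Target(1L));
        sketch.cancelForShard(1L, "Shard closing"); // prints "cancelled" then "closed"
    }
}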

Signed-off-by: Marc Handalian <[email protected]>
mch2 authored Jan 30, 2024
1 parent fb2c5f2 commit 11644d5
Showing 8 changed files with 133 additions and 123 deletions.
@@ -262,7 +262,6 @@ public void testIndexing() throws Exception {
* @throws Exception if index creation fail
* @throws UnsupportedOperationException if cluster type is unknown
*/
@AwaitsFix(bugUrl = "https://github.com/opensearch-project/OpenSearch/issues/7679")
public void testIndexingWithSegRep() throws Exception {
if (UPGRADE_FROM_VERSION.before(Version.V_2_4_0)) {
logger.info("--> Skip test for version {} where segment replication feature is not available", UPGRADE_FROM_VERSION);
@@ -594,6 +594,67 @@ public void testCancellation() throws Exception {
assertDocCounts(docCount, primaryNode);
}

public void testCancellationDuringGetCheckpointInfo() throws Exception {
cancelDuringReplicaAction(SegmentReplicationSourceService.Actions.GET_CHECKPOINT_INFO);
}

public void testCancellationDuringGetSegments() throws Exception {
cancelDuringReplicaAction(SegmentReplicationSourceService.Actions.GET_SEGMENT_FILES);
}

private void cancelDuringReplicaAction(String actionToBlock) throws Exception {
// this test stubs transport calls specific to node-node replication.
assumeFalse(
"Skipping the test as it's not compatible with segment replication with remote store.",
segmentReplicationWithRemoteEnabled()
);
final String primaryNode = internalCluster().startDataOnlyNode();
createIndex(INDEX_NAME, Settings.builder().put(indexSettings()).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 1).build());
ensureYellow(INDEX_NAME);

final String replicaNode = internalCluster().startDataOnlyNode();
ensureGreen(INDEX_NAME);
final SegmentReplicationTargetService targetService = internalCluster().getInstance(
SegmentReplicationTargetService.class,
replicaNode
);
final IndexShard replicaShard = getIndexShard(replicaNode, INDEX_NAME);
CountDownLatch startCancellationLatch = new CountDownLatch(1);
CountDownLatch latch = new CountDownLatch(1);

MockTransportService primaryTransportService = (MockTransportService) internalCluster().getInstance(
TransportService.class,
primaryNode
);
primaryTransportService.addRequestHandlingBehavior(actionToBlock, (handler, request, channel, task) -> {
logger.info("action {}", actionToBlock);
try {
startCancellationLatch.countDown();
latch.await();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
});

// index a doc and trigger replication
client().prepareIndex(INDEX_NAME).setId("1").setSource("foo", "bar").setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE).get();

// remove the replica and ensure it is cleaned up.
startCancellationLatch.await();
SegmentReplicationTarget target = targetService.get(replicaShard.shardId());
assertAcked(
client().admin()
.indices()
.prepareUpdateSettings(INDEX_NAME)
.setSettings(Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0))
);
assertEquals("Replication not closed: " + target.getId(), 0, target.refCount());
assertEquals("Store has a positive refCount", 0, replicaShard.store().refCount());
// stop the replica; this will run additional checks on shutdown to ensure the replica and its store are closed properly
internalCluster().stopRandomNode(InternalTestCluster.nameFilter(replicaNode));
latch.countDown();
}

public void testStartReplicaAfterPrimaryIndexesDocs() throws Exception {
final String primaryNode = internalCluster().startDataOnlyNode();
createIndex(INDEX_NAME, Settings.builder().put(indexSettings()).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0).build());
@@ -23,8 +23,6 @@
import org.opensearch.indices.replication.SegmentReplicationState;
import org.opensearch.indices.replication.SegmentReplicationTarget;
import org.opensearch.indices.replication.SegmentReplicationTargetService;
import org.opensearch.indices.replication.common.ReplicationCollection;
import org.opensearch.test.InternalTestCluster;
import org.opensearch.test.OpenSearchIntegTestCase;
import org.opensearch.test.disruption.SlowClusterStateProcessing;

@@ -33,6 +31,8 @@
import java.util.Set;
import java.util.concurrent.TimeUnit;

import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertAcked;

/**
* This class runs tests with remote store + segRep while blocking file downloads
*/
@@ -59,22 +59,18 @@ public void testCancelReplicationWhileSyncingSegments() throws Exception {
indexSingleDoc();
refresh(INDEX_NAME);
waitForBlock(replicaNode, REPOSITORY_NAME, TimeValue.timeValueSeconds(10));
final SegmentReplicationState state = targetService.getOngoingEventSegmentReplicationState(indexShard.shardId());
assertEquals(SegmentReplicationState.Stage.GET_FILES, state.getStage());
ReplicationCollection.ReplicationRef<SegmentReplicationTarget> segmentReplicationTargetReplicationRef = targetService.get(
state.getReplicationId()
);
final SegmentReplicationTarget segmentReplicationTarget = segmentReplicationTargetReplicationRef.get();
// close the target ref here otherwise it will hold a refcount
segmentReplicationTargetReplicationRef.close();
SegmentReplicationTarget segmentReplicationTarget = targetService.get(indexShard.shardId());
assertNotNull(segmentReplicationTarget);
assertEquals(SegmentReplicationState.Stage.GET_FILES, segmentReplicationTarget.state().getStage());
assertTrue(segmentReplicationTarget.refCount() > 0);
internalCluster().stopRandomNode(InternalTestCluster.nameFilter(primaryNode));
assertBusy(() -> {
assertTrue(indexShard.routingEntry().primary());
assertNull(targetService.getOngoingEventSegmentReplicationState(indexShard.shardId()));
assertEquals("Target should be closed", 0, segmentReplicationTarget.refCount());
});
assertAcked(
client().admin()
.indices()
.prepareUpdateSettings(INDEX_NAME)
.setSettings(Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0))
);
assertNull(targetService.getOngoingEventSegmentReplicationState(indexShard.shardId()));
assertEquals("Target should be closed", 0, segmentReplicationTarget.refCount());
unblockNode(REPOSITORY_NAME, replicaNode);
cleanupRepo();
}
@@ -85,7 +81,6 @@ public void testCancelReplicationWhileFetchingMetadata() throws Exception {

final Set<String> dataNodeNames = internalCluster().getDataNodeNames();
final String replicaNode = getNode(dataNodeNames, false);
final String primaryNode = getNode(dataNodeNames, true);

SegmentReplicationTargetService targetService = internalCluster().getInstance(SegmentReplicationTargetService.class, replicaNode);
ensureGreen(INDEX_NAME);
@@ -94,22 +89,18 @@
indexSingleDoc();
refresh(INDEX_NAME);
waitForBlock(replicaNode, REPOSITORY_NAME, TimeValue.timeValueSeconds(10));
final SegmentReplicationState state = targetService.getOngoingEventSegmentReplicationState(indexShard.shardId());
assertEquals(SegmentReplicationState.Stage.GET_CHECKPOINT_INFO, state.getStage());
ReplicationCollection.ReplicationRef<SegmentReplicationTarget> segmentReplicationTargetReplicationRef = targetService.get(
state.getReplicationId()
);
final SegmentReplicationTarget segmentReplicationTarget = segmentReplicationTargetReplicationRef.get();
// close the target ref here otherwise it will hold a refcount
segmentReplicationTargetReplicationRef.close();
SegmentReplicationTarget segmentReplicationTarget = targetService.get(indexShard.shardId());
assertNotNull(segmentReplicationTarget);
assertEquals(SegmentReplicationState.Stage.GET_CHECKPOINT_INFO, segmentReplicationTarget.state().getStage());
assertTrue(segmentReplicationTarget.refCount() > 0);
internalCluster().stopRandomNode(InternalTestCluster.nameFilter(primaryNode));
assertBusy(() -> {
assertTrue(indexShard.routingEntry().primary());
assertNull(targetService.getOngoingEventSegmentReplicationState(indexShard.shardId()));
assertEquals("Target should be closed", 0, segmentReplicationTarget.refCount());
});
assertAcked(
client().admin()
.indices()
.prepareUpdateSettings(INDEX_NAME)
.setSettings(Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0))
);
assertNull(targetService.get(indexShard.shardId()));
assertEquals("Target should be closed", 0, segmentReplicationTarget.refCount());
unblockNode(REPOSITORY_NAME, replicaNode);
cleanupRepo();
}
@@ -8,16 +8,14 @@

package org.opensearch.indices.replication;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionListenerResponseHandler;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.index.store.StoreFileMetadata;
import org.opensearch.indices.recovery.RecoverySettings;
import org.opensearch.indices.recovery.RetryableTransportClient;
import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportRequestOptions;
import org.opensearch.transport.TransportService;

@@ -35,9 +33,7 @@
*/
public class PrimaryShardReplicationSource implements SegmentReplicationSource {

private static final Logger logger = LogManager.getLogger(PrimaryShardReplicationSource.class);

private final RetryableTransportClient transportClient;
private final TransportService transportService;

private final DiscoveryNode sourceNode;
private final DiscoveryNode targetNode;
@@ -52,12 +48,7 @@ public PrimaryShardReplicationSource(
DiscoveryNode sourceNode
) {
this.targetAllocationId = targetAllocationId;
this.transportClient = new RetryableTransportClient(
transportService,
sourceNode,
recoverySettings.internalActionRetryTimeout(),
logger
);
this.transportService = transportService;
this.sourceNode = sourceNode;
this.targetNode = targetNode;
this.recoverySettings = recoverySettings;
@@ -69,10 +60,14 @@ public void getCheckpointMetadata(
ReplicationCheckpoint checkpoint,
ActionListener<CheckpointInfoResponse> listener
) {
final Writeable.Reader<CheckpointInfoResponse> reader = CheckpointInfoResponse::new;
final ActionListener<CheckpointInfoResponse> responseListener = ActionListener.map(listener, r -> r);
final CheckpointInfoRequest request = new CheckpointInfoRequest(replicationId, targetAllocationId, targetNode, checkpoint);
transportClient.executeRetryableAction(GET_CHECKPOINT_INFO, request, responseListener, reader);
transportService.sendRequest(
sourceNode,
GET_CHECKPOINT_INFO,
request,
TransportRequestOptions.builder().withTimeout(recoverySettings.internalActionRetryTimeout()).build(),
new ActionListenerResponseHandler<>(listener, CheckpointInfoResponse::new, ThreadPool.Names.GENERIC)
);
}

@Override
@@ -88,29 +83,24 @@ public void getSegmentFiles(
// MultiFileWriter takes care of progress tracking for downloads in this scenario
// TODO: Move state management and tracking into replication methods and use chunking and data
// copy mechanisms only from MultiFileWriter
final Writeable.Reader<GetSegmentFilesResponse> reader = GetSegmentFilesResponse::new;
final ActionListener<GetSegmentFilesResponse> responseListener = ActionListener.map(listener, r -> r);
final GetSegmentFilesRequest request = new GetSegmentFilesRequest(
replicationId,
targetAllocationId,
targetNode,
filesToFetch,
checkpoint
);
final TransportRequestOptions options = TransportRequestOptions.builder()
.withTimeout(recoverySettings.internalActionLongTimeout())
.build();
transportClient.executeRetryableAction(GET_SEGMENT_FILES, request, options, responseListener, reader);
transportService.sendRequest(
sourceNode,
GET_SEGMENT_FILES,
request,
TransportRequestOptions.builder().withTimeout(recoverySettings.internalActionLongTimeout()).build(),
new ActionListenerResponseHandler<>(listener, GetSegmentFilesResponse::new, ThreadPool.Names.GENERIC)
);
}

@Override
public String getDescription() {
return sourceNode.getName();
}

@Override
public void cancel() {
transportClient.cancel();
}

}
@@ -83,6 +83,16 @@ protected void closeInternal() {
}
}

@Override
protected void onCancel(String reason) {
try {
notifyListener(new ReplicationFailedException(reason), false);
} finally {
source.cancel();
cancellableThreads.cancel(reason);
}
}

@Override
protected String getPrefix() {
return REPLICATION_PREFIX + UUIDs.randomBase64UUID() + ".";
@@ -320,16 +330,4 @@ private void finalizeReplication(CheckpointInfoResponse checkpointInfoResponse)
}
}
}

/**
* Trigger a cancellation, this method will not close the target a subsequent call to #fail is required from target service.
*/
@Override
public void cancel(String reason) {
if (finished.get() == false) {
logger.trace(new ParameterizedMessage("Cancelling replication for target {}", description()));
cancellableThreads.cancel(reason);
source.cancel();
}
}
}
@@ -84,10 +84,6 @@ public class SegmentReplicationTargetService extends AbstractLifecycleComponent
private final ClusterService clusterService;
private final TransportService transportService;

public ReplicationRef<SegmentReplicationTarget> get(long replicationId) {
return onGoingReplications.get(replicationId);
}

/**
* The internal actions
*
@@ -158,6 +154,7 @@ protected void doStart() {
@Override
protected void doStop() {
if (DiscoveryNode.isDataNode(clusterService.getSettings())) {
assert onGoingReplications.size() == 0 : "Replication collection should be empty on shutdown";
clusterService.removeListener(this);
}
}
@@ -201,7 +198,7 @@ public void clusterChanged(ClusterChangedEvent event) {
@Override
public void beforeIndexShardClosed(ShardId shardId, @Nullable IndexShard indexShard, Settings indexSettings) {
if (indexShard != null && indexShard.indexSettings().isSegRepEnabled()) {
onGoingReplications.requestCancel(indexShard.shardId(), "Shard closing");
onGoingReplications.cancelForShard(indexShard.shardId(), "Shard closing");
latestReceivedCheckpoint.remove(shardId);
}
}
@@ -223,7 +220,7 @@ public void afterIndexShardStarted(IndexShard indexShard) {
@Override
public void shardRoutingChanged(IndexShard indexShard, @Nullable ShardRouting oldRouting, ShardRouting newRouting) {
if (oldRouting != null && indexShard.indexSettings().isSegRepEnabled() && oldRouting.primary() == false && newRouting.primary()) {
onGoingReplications.requestCancel(indexShard.shardId(), "Shard has been promoted to primary");
onGoingReplications.cancelForShard(indexShard.shardId(), "Shard has been promoted to primary");
latestReceivedCheckpoint.remove(indexShard.shardId());
}
}
@@ -255,6 +252,14 @@ public SegmentReplicationState getSegmentReplicationState(ShardId shardId) {
.orElseGet(() -> getlatestCompletedEventSegmentReplicationState(shardId));
}

public ReplicationRef<SegmentReplicationTarget> get(long replicationId) {
return onGoingReplications.get(replicationId);
}

public SegmentReplicationTarget get(ShardId shardId) {
return onGoingReplications.getOngoingReplicationTarget(shardId);
}

/**
* Invoked when a new checkpoint is received from a primary shard.
* It checks if a new checkpoint should be processed or not and starts replication if needed.
@@ -454,7 +459,13 @@ protected boolean processLatestReceivedCheckpoint(IndexShard replicaShard, Thread thread) {
latestPublishedCheckpoint
)
);
Runnable runnable = () -> onNewCheckpoint(latestReceivedCheckpoint.get(replicaShard.shardId()), replicaShard);
Runnable runnable = () -> {
// If we retry, ensure the shard is not in the process of being closed.
// It will be removed from the indexService's collection before the shard is actually marked as closed.
if (indicesService.getShardOrNull(replicaShard.shardId()) != null) {
onNewCheckpoint(latestReceivedCheckpoint.get(replicaShard.shardId()), replicaShard);
}
};
// Checks if we are using same thread and forks if necessary.
if (thread == Thread.currentThread()) {
threadPool.generic().execute(runnable);
@@ -548,9 +559,6 @@ public ReplicationRunner(long replicationId) {

@Override
public void onFailure(Exception e) {
try (final ReplicationRef<SegmentReplicationTarget> ref = onGoingReplications.get(replicationId)) {
logger.error(() -> new ParameterizedMessage("Error during segment replication, {}", ref.get().description()), e);
}
onGoingReplications.fail(replicationId, new ReplicationFailedException("Unexpected Error during replication", e), false);
}
