Skip to content

Commit

Permalink
Add PrimaryShardBatchAllocator to take allocation decisions for a bat…
Browse files Browse the repository at this point in the history
…ch of shards (#8916)

* Add PrimaryShardBatchAllocator to take allocation decisions for a batch of shards

Signed-off-by: Aman Khare <[email protected]>
  • Loading branch information
amkhar authored Mar 19, 2024
1 parent 5e2034c commit a499d1e
Show file tree
Hide file tree
Showing 10 changed files with 754 additions and 268 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import org.opensearch.core.index.Index;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.env.NodeEnvironment;
import org.opensearch.gateway.TransportNodesGatewayStartedShardHelper.GatewayStartedShard;
import org.opensearch.index.IndexService;
import org.opensearch.index.IndexSettings;
import org.opensearch.index.MergePolicyProvider;
Expand Down Expand Up @@ -720,11 +721,11 @@ public Settings onNodeStopped(String nodeName) throws Exception {
);

assertThat(response.getNodes(), hasSize(1));
assertThat(response.getNodes().get(0).allocationId(), notNullValue());
assertThat(response.getNodes().get(0).getGatewayShardStarted().allocationId(), notNullValue());
if (corrupt) {
assertThat(response.getNodes().get(0).storeException(), notNullValue());
assertThat(response.getNodes().get(0).getGatewayShardStarted().storeException(), notNullValue());
} else {
assertThat(response.getNodes().get(0).storeException(), nullValue());
assertThat(response.getNodes().get(0).getGatewayShardStarted().storeException(), nullValue());
}

// start another node so cluster consistency checks won't time out due to the lack of state
Expand Down Expand Up @@ -764,11 +765,11 @@ public void testSingleShardFetchUsingBatchAction() {
);
final Index index = resolveIndex(indexName);
final ShardId shardId = new ShardId(index, 0);
TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShard nodeGatewayStartedShards = response.getNodesMap()
GatewayStartedShard gatewayStartedShard = response.getNodesMap()
.get(searchShardsResponse.getNodes()[0].getId())
.getNodeGatewayStartedShardsBatch()
.get(shardId);
assertNodeGatewayStartedShardsHappyCase(nodeGatewayStartedShards);
assertNodeGatewayStartedShardsHappyCase(gatewayStartedShard);
}

public void testShardFetchMultiNodeMultiIndexesUsingBatchAction() {
Expand All @@ -792,11 +793,8 @@ public void testShardFetchMultiNodeMultiIndexesUsingBatchAction() {
ShardId shardId = clusterSearchShardsGroup.getShardId();
assertEquals(1, clusterSearchShardsGroup.getShards().length);
String nodeId = clusterSearchShardsGroup.getShards()[0].currentNodeId();
TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShard nodeGatewayStartedShards = response.getNodesMap()
.get(nodeId)
.getNodeGatewayStartedShardsBatch()
.get(shardId);
assertNodeGatewayStartedShardsHappyCase(nodeGatewayStartedShards);
GatewayStartedShard gatewayStartedShard = response.getNodesMap().get(nodeId).getNodeGatewayStartedShardsBatch().get(shardId);
assertNodeGatewayStartedShardsHappyCase(gatewayStartedShard);
}
}

Expand All @@ -816,13 +814,13 @@ public void testShardFetchCorruptedShardsUsingBatchAction() throws Exception {
new TransportNodesListGatewayStartedShardsBatch.Request(getDiscoveryNodes(), shardIdShardAttributesMap)
);
DiscoveryNode[] discoveryNodes = getDiscoveryNodes();
TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShard nodeGatewayStartedShards = response.getNodesMap()
GatewayStartedShard gatewayStartedShard = response.getNodesMap()
.get(discoveryNodes[0].getId())
.getNodeGatewayStartedShardsBatch()
.get(shardId);
assertNotNull(nodeGatewayStartedShards.storeException());
assertNotNull(nodeGatewayStartedShards.allocationId());
assertTrue(nodeGatewayStartedShards.primary());
assertNotNull(gatewayStartedShard.storeException());
assertNotNull(gatewayStartedShard.allocationId());
assertTrue(gatewayStartedShard.primary());
}

public void testSingleShardStoreFetchUsingBatchAction() throws ExecutionException, InterruptedException {
Expand Down Expand Up @@ -950,12 +948,10 @@ private void assertNodeStoreFilesMetadataSuccessCase(
assertNotNull(storeFileMetadata.peerRecoveryRetentionLeases());
}

private void assertNodeGatewayStartedShardsHappyCase(
TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShard nodeGatewayStartedShards
) {
assertNull(nodeGatewayStartedShards.storeException());
assertNotNull(nodeGatewayStartedShards.allocationId());
assertTrue(nodeGatewayStartedShards.primary());
private void assertNodeGatewayStartedShardsHappyCase(GatewayStartedShard gatewayStartedShard) {
assertNull(gatewayStartedShard.storeException());
assertNotNull(gatewayStartedShard.allocationId());
assertTrue(gatewayStartedShard.primary());
}

private void prepareIndex(String indexName, int numberOfPrimaryShards) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,9 @@ void finish() {
storeStatuses.add(
new IndicesShardStoresResponse.StoreStatus(
response.getNode(),
response.allocationId(),
response.getGatewayShardStarted().allocationId(),
allocationStatus,
response.storeException()
response.getGatewayShardStarted().storeException()
)
);
}
Expand Down Expand Up @@ -308,7 +308,8 @@ private IndicesShardStoresResponse.StoreStatus.AllocationStatus getAllocationSta
* A shard exists/existed in a node only if shard state file exists in the node
*/
private boolean shardExistsInNode(final NodeGatewayStartedShards response) {
return response.storeException() != null || response.allocationId() != null;
return response.getGatewayShardStarted().storeException() != null
|| response.getGatewayShardStarted().allocationId() != null;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import org.opensearch.cluster.routing.allocation.decider.Decision.Type;
import org.opensearch.env.ShardLockObtainFailedException;
import org.opensearch.gateway.AsyncShardFetch.FetchResult;
import org.opensearch.gateway.TransportNodesGatewayStartedShardHelper.NodeGatewayStartedShard;
import org.opensearch.gateway.TransportNodesListGatewayStartedShards.NodeGatewayStartedShards;

import java.util.ArrayList;
Expand Down Expand Up @@ -125,27 +126,37 @@ public AllocateUnassignedDecision makeAllocationDecision(
return decision;
}
final FetchResult<NodeGatewayStartedShards> shardState = fetchData(unassignedShard, allocation);
List<NodeGatewayStartedShards> nodeShardStates = adaptToNodeStartedShardList(shardState);
List<NodeGatewayStartedShard> nodeShardStates = adaptToNodeStartedShardList(shardState);
return getAllocationDecision(unassignedShard, allocation, nodeShardStates, logger);
}

/**
* Transforms {@link FetchResult} of {@link NodeGatewayStartedShards} to {@link List} of {@link NodeGatewayStartedShards}
* Transforms {@link FetchResult} of {@link NodeGatewayStartedShards} to {@link List} of {@link NodeGatewayStartedShard}
* Returns null if {@link FetchResult} does not have any data.
*/
private static List<NodeGatewayStartedShards> adaptToNodeStartedShardList(FetchResult<NodeGatewayStartedShards> shardsState) {
private static List<NodeGatewayStartedShard> adaptToNodeStartedShardList(FetchResult<NodeGatewayStartedShards> shardsState) {
if (!shardsState.hasData()) {
return null;
}
List<NodeGatewayStartedShards> nodeShardStates = new ArrayList<>();
shardsState.getData().forEach((node, nodeGatewayStartedShard) -> { nodeShardStates.add(nodeGatewayStartedShard); });
List<NodeGatewayStartedShard> nodeShardStates = new ArrayList<>();
shardsState.getData().forEach((node, nodeGatewayStartedShard) -> {
nodeShardStates.add(
new NodeGatewayStartedShard(
nodeGatewayStartedShard.getGatewayShardStarted().allocationId(),
nodeGatewayStartedShard.getGatewayShardStarted().primary(),
nodeGatewayStartedShard.getGatewayShardStarted().replicationCheckpoint(),
nodeGatewayStartedShard.getGatewayShardStarted().storeException(),
node
)
);
});
return nodeShardStates;
}

protected AllocateUnassignedDecision getAllocationDecision(
ShardRouting unassignedShard,
RoutingAllocation allocation,
List<NodeGatewayStartedShards> shardState,
List<NodeGatewayStartedShard> shardState,
Logger logger
) {
final boolean explain = allocation.debugDecision();
Expand Down Expand Up @@ -236,7 +247,7 @@ protected AllocateUnassignedDecision getAllocationDecision(
nodesToAllocate = buildNodesToAllocate(allocation, nodeShardsResult.orderedAllocationCandidates, unassignedShard, true);
if (nodesToAllocate.yesNodeShards.isEmpty() == false) {
final DecidedNode decidedNode = nodesToAllocate.yesNodeShards.get(0);
final NodeGatewayStartedShards nodeShardState = decidedNode.nodeShardState;
final NodeGatewayStartedShard nodeShardState = decidedNode.nodeShardState;
logger.debug(
"[{}][{}]: allocating [{}] to [{}] on forced primary allocation",
unassignedShard.index(),
Expand Down Expand Up @@ -296,11 +307,11 @@ protected AllocateUnassignedDecision getAllocationDecision(
*/
private static List<NodeAllocationResult> buildNodeDecisions(
NodesToAllocate nodesToAllocate,
List<NodeGatewayStartedShards> fetchedShardData,
List<NodeGatewayStartedShard> fetchedShardData,
Set<String> inSyncAllocationIds
) {
List<NodeAllocationResult> nodeResults = new ArrayList<>();
Collection<NodeGatewayStartedShards> ineligibleShards = new ArrayList<>();
Collection<NodeGatewayStartedShard> ineligibleShards = new ArrayList<>();
if (nodesToAllocate != null) {
final Set<DiscoveryNode> discoNodes = new HashSet<>();
nodeResults.addAll(
Expand Down Expand Up @@ -334,21 +345,21 @@ private static List<NodeAllocationResult> buildNodeDecisions(
return nodeResults;
}

private static ShardStoreInfo shardStoreInfo(NodeGatewayStartedShards nodeShardState, Set<String> inSyncAllocationIds) {
private static ShardStoreInfo shardStoreInfo(NodeGatewayStartedShard nodeShardState, Set<String> inSyncAllocationIds) {
final Exception storeErr = nodeShardState.storeException();
final boolean inSync = nodeShardState.allocationId() != null && inSyncAllocationIds.contains(nodeShardState.allocationId());
return new ShardStoreInfo(nodeShardState.allocationId(), inSync, storeErr);
}

private static final Comparator<NodeGatewayStartedShards> NO_STORE_EXCEPTION_FIRST_COMPARATOR = Comparator.comparing(
(NodeGatewayStartedShards state) -> state.storeException() == null
private static final Comparator<NodeGatewayStartedShard> NO_STORE_EXCEPTION_FIRST_COMPARATOR = Comparator.comparing(
(NodeGatewayStartedShard state) -> state.storeException() == null
).reversed();
private static final Comparator<NodeGatewayStartedShards> PRIMARY_FIRST_COMPARATOR = Comparator.comparing(
NodeGatewayStartedShards::primary
private static final Comparator<NodeGatewayStartedShard> PRIMARY_FIRST_COMPARATOR = Comparator.comparing(
NodeGatewayStartedShard::primary
).reversed();

private static final Comparator<NodeGatewayStartedShards> HIGHEST_REPLICATION_CHECKPOINT_FIRST_COMPARATOR = Comparator.comparing(
NodeGatewayStartedShards::replicationCheckpoint,
private static final Comparator<NodeGatewayStartedShard> HIGHEST_REPLICATION_CHECKPOINT_FIRST_COMPARATOR = Comparator.comparing(
NodeGatewayStartedShard::replicationCheckpoint,
Comparator.nullsLast(Comparator.naturalOrder())
);

Expand All @@ -362,12 +373,12 @@ protected static NodeShardsResult buildNodeShardsResult(
boolean matchAnyShard,
Set<String> ignoreNodes,
Set<String> inSyncAllocationIds,
List<NodeGatewayStartedShards> shardState,
List<NodeGatewayStartedShard> shardState,
Logger logger
) {
List<NodeGatewayStartedShards> nodeShardStates = new ArrayList<>();
List<NodeGatewayStartedShard> nodeShardStates = new ArrayList<>();
int numberOfAllocationsFound = 0;
for (NodeGatewayStartedShards nodeShardState : shardState) {
for (NodeGatewayStartedShard nodeShardState : shardState) {
DiscoveryNode node = nodeShardState.getNode();
String allocationId = nodeShardState.allocationId();

Expand Down Expand Up @@ -432,21 +443,18 @@ protected static NodeShardsResult buildNodeShardsResult(
return new NodeShardsResult(nodeShardStates, numberOfAllocationsFound);
}

private static Comparator<NodeGatewayStartedShards> createActiveShardComparator(
boolean matchAnyShard,
Set<String> inSyncAllocationIds
) {
private static Comparator<NodeGatewayStartedShard> createActiveShardComparator(boolean matchAnyShard, Set<String> inSyncAllocationIds) {
/**
* Orders the active shards copies based on below comparators
* 1. No store exception i.e. shard copy is readable
* 2. Prefer previous primary shard
* 3. Prefer shard copy with the highest replication checkpoint. It is NO-OP for doc rep enabled indices.
*/
final Comparator<NodeGatewayStartedShards> comparator; // allocation preference
final Comparator<NodeGatewayStartedShard> comparator; // allocation preference
if (matchAnyShard) {
// prefer shards with matching allocation ids
Comparator<NodeGatewayStartedShards> matchingAllocationsFirst = Comparator.comparing(
(NodeGatewayStartedShards state) -> inSyncAllocationIds.contains(state.allocationId())
Comparator<NodeGatewayStartedShard> matchingAllocationsFirst = Comparator.comparing(
(NodeGatewayStartedShard state) -> inSyncAllocationIds.contains(state.allocationId())
).reversed();
comparator = matchingAllocationsFirst.thenComparing(NO_STORE_EXCEPTION_FIRST_COMPARATOR)
.thenComparing(PRIMARY_FIRST_COMPARATOR)
Expand All @@ -464,14 +472,14 @@ private static Comparator<NodeGatewayStartedShards> createActiveShardComparator(
*/
private static NodesToAllocate buildNodesToAllocate(
RoutingAllocation allocation,
List<NodeGatewayStartedShards> nodeShardStates,
List<NodeGatewayStartedShard> nodeShardStates,
ShardRouting shardRouting,
boolean forceAllocate
) {
List<DecidedNode> yesNodeShards = new ArrayList<>();
List<DecidedNode> throttledNodeShards = new ArrayList<>();
List<DecidedNode> noNodeShards = new ArrayList<>();
for (NodeGatewayStartedShards nodeShardState : nodeShardStates) {
for (NodeGatewayStartedShard nodeShardState : nodeShardStates) {
RoutingNode node = allocation.routingNodes().node(nodeShardState.getNode().getId());
if (node == null) {
continue;
Expand Down Expand Up @@ -502,10 +510,10 @@ private static NodesToAllocate buildNodesToAllocate(
* This class encapsulates the result of a call to {@link #buildNodeShardsResult}
*/
static class NodeShardsResult {
final List<NodeGatewayStartedShards> orderedAllocationCandidates;
final List<NodeGatewayStartedShard> orderedAllocationCandidates;
final int allocationsFound;

NodeShardsResult(List<NodeGatewayStartedShards> orderedAllocationCandidates, int allocationsFound) {
NodeShardsResult(List<NodeGatewayStartedShard> orderedAllocationCandidates, int allocationsFound) {
this.orderedAllocationCandidates = orderedAllocationCandidates;
this.allocationsFound = allocationsFound;
}
Expand All @@ -531,10 +539,10 @@ protected static class NodesToAllocate {
* by the allocator for allocating to the node that holds the shard copy.
*/
private static class DecidedNode {
final NodeGatewayStartedShards nodeShardState;
final NodeGatewayStartedShard nodeShardState;
final Decision decision;

private DecidedNode(NodeGatewayStartedShards nodeShardState, Decision decision) {
private DecidedNode(NodeGatewayStartedShard nodeShardState, Decision decision) {
this.nodeShardState = nodeShardState;
this.decision = decision;
}
Expand Down
Loading

0 comments on commit a499d1e

Please sign in to comment.