Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batch Async Fetcher class changes #8742

Merged
merged 12 commits into from
Jan 2, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@
} else {
for (Tuple<ShardId, String> shard : shards) {
InternalAsyncFetch fetch = new InternalAsyncFetch(logger, "shard_stores", shard.v1(), shard.v2(), listShardStoresInfo);
fetch.fetchData(nodes, Collections.<String>emptySet());
fetch.fetchData(nodes, Collections.emptyMap());

Check warning on line 198 in server/src/main/java/org/opensearch/action/admin/indices/shards/TransportIndicesShardStoresAction.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/action/admin/indices/shards/TransportIndicesShardStoresAction.java#L198

Added line #L198 was not covered by tests
}
}
}
Expand Down Expand Up @@ -223,7 +223,7 @@
List<FailedNodeException> failures,
long fetchingRound
) {
fetchResponses.add(new Response(shardId, responses, failures));
fetchResponses.add(new Response(shardAttributesMap.keySet().iterator().next(), responses, failures));

Check warning on line 226 in server/src/main/java/org/opensearch/action/admin/indices/shards/TransportIndicesShardStoresAction.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/action/admin/indices/shards/TransportIndicesShardStoresAction.java#L226

Added line #L226 was not covered by tests
if (expectedOps.countDown()) {
finish();
}
Expand Down Expand Up @@ -312,7 +312,7 @@
}

@Override
protected void reroute(ShardId shardId, String reason) {
protected void reroute(String shardId, String reason) {
// no-op
}

Expand Down
140 changes: 103 additions & 37 deletions server/src/main/java/org/opensearch/gateway/AsyncShardFetch.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.indices.store.ShardAttributes;
import org.opensearch.transport.ReceiveTimeoutTransportException;

import java.util.ArrayList;
Expand All @@ -54,12 +55,11 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicLong;

import static java.util.Collections.emptySet;
import static java.util.Collections.unmodifiableSet;
import static java.util.Collections.emptyMap;
import static java.util.Collections.unmodifiableMap;

/**
* Allows to asynchronously fetch shard related data from other nodes for allocation, without blocking
Expand All @@ -69,6 +69,7 @@
* and once the results are back, it makes sure to schedule a reroute to make sure those results will
* be taken into account.
*
* It operates in one of two modes: fetching data for a single shard, or fetching data for a batch of shards.
* @opensearch.internal
*/
public abstract class AsyncShardFetch<T extends BaseNodeResponse> implements Releasable {
Expand All @@ -77,18 +78,21 @@
* An action that lists the relevant shard data that needs to be fetched.
*/
public interface Lister<NodesResponse extends BaseNodesResponse<NodeResponse>, NodeResponse extends BaseNodeResponse> {
void list(ShardId shardId, @Nullable String customDataPath, DiscoveryNode[] nodes, ActionListener<NodesResponse> listener);
void list(Map<ShardId, ShardAttributes> shardAttributesMap, DiscoveryNode[] nodes, ActionListener<NodesResponse> listener);

}

protected final Logger logger;
protected final String type;
protected final ShardId shardId;
protected final String customDataPath;
protected final Map<ShardId, ShardAttributes> shardAttributesMap;
private final Lister<BaseNodesResponse<T>, T> action;
private final Map<String, NodeEntry<T>> cache = new HashMap<>();
private final Set<String> nodesToIgnore = new HashSet<>();
private final AtomicLong round = new AtomicLong();
private boolean closed;
private final String reroutingKey;
private final Map<ShardId, Set<String>> shardToIgnoreNodes = new HashMap<>();

private final boolean enableBatchMode;

@SuppressWarnings("unchecked")
protected AsyncShardFetch(
Expand All @@ -100,9 +104,36 @@
) {
this.logger = logger;
this.type = type;
this.shardId = Objects.requireNonNull(shardId);
this.customDataPath = Objects.requireNonNull(customDataPath);
shardAttributesMap = new HashMap<>();
shardAttributesMap.put(shardId, new ShardAttributes(shardId, customDataPath));
this.action = (Lister<BaseNodesResponse<T>, T>) action;
this.reroutingKey = "ShardId=[" + shardId.toString() + "]";
enableBatchMode = false;
}

/**
* Added to fetch a batch of shards from nodes
*
* @param logger Logger
* @param type type of action
* @param shardAttributesMap Map of {@link ShardId} to {@link ShardAttributes} for the shards to be fetched
* @param action Transport Action
* @param batchId Single batch id that all shards in the given shardAttributesMap are expected to belong to; used for logging and later identification
*/
@SuppressWarnings("unchecked")
protected AsyncShardFetch(
Gaurav614 marked this conversation as resolved.
Show resolved Hide resolved
Logger logger,
String type,
Map<ShardId, ShardAttributes> shardAttributesMap,
Lister<? extends BaseNodesResponse<T>, T> action,
String batchId
) {
this.logger = logger;
this.type = type;
this.shardAttributesMap = shardAttributesMap;
this.action = (Lister<BaseNodesResponse<T>, T>) action;
this.reroutingKey = "BatchID=[" + batchId + "]";
enableBatchMode = true;
}

@Override
Expand Down Expand Up @@ -130,11 +161,32 @@
* The ignoreNodes are nodes that are supposed to be ignored for this round, since fetching is async, we need
* to keep them around and make sure we add them back when all the responses are fetched and returned.
*/
public synchronized FetchResult<T> fetchData(DiscoveryNodes nodes, Set<String> ignoreNodes) {
public synchronized FetchResult<T> fetchData(DiscoveryNodes nodes, Map<ShardId, Set<String>> ignoreNodes) {
Gaurav614 marked this conversation as resolved.
Show resolved Hide resolved
if (closed) {
throw new IllegalStateException(shardId + ": can't fetch data on closed async fetch");
throw new IllegalStateException(reroutingKey + ": can't fetch data on closed async fetch");
}
nodesToIgnore.addAll(ignoreNodes);

if (enableBatchMode == false) {
Gaurav614 marked this conversation as resolved.
Show resolved Hide resolved
// we will do assertions here on ignoreNodes
if (ignoreNodes.size() > 1) {
Gaurav614 marked this conversation as resolved.
Show resolved Hide resolved
throw new IllegalStateException(

Check warning on line 172 in server/src/main/java/org/opensearch/gateway/AsyncShardFetch.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/gateway/AsyncShardFetch.java#L172

Added line #L172 was not covered by tests
"Fetching Shard Data, " + reroutingKey + "Can only have atmost one shard" + "for non-batch mode"
);
}
if (ignoreNodes.size() == 1) {
if (shardAttributesMap.containsKey(ignoreNodes.keySet().iterator().next()) == false) {
throw new IllegalStateException("Shard Id must be same as initialized in AsyncShardFetch. Expecting = " + reroutingKey);

Check warning on line 178 in server/src/main/java/org/opensearch/gateway/AsyncShardFetch.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/gateway/AsyncShardFetch.java#L178

Added line #L178 was not covered by tests
}
}
}

// add the nodes to ignore to the list of nodes to ignore for each shard
for (Map.Entry<ShardId, Set<String>> ignoreNodesEntry : ignoreNodes.entrySet()) {
Set<String> ignoreNodesSet = shardToIgnoreNodes.getOrDefault(ignoreNodesEntry.getKey(), new HashSet<>());
ignoreNodesSet.addAll(ignoreNodesEntry.getValue());
shardToIgnoreNodes.put(ignoreNodesEntry.getKey(), ignoreNodesSet);
}

fillShardCacheWithDataNodes(cache, nodes);
List<NodeEntry<T>> nodesToFetch = findNodesToFetch(cache);
if (nodesToFetch.isEmpty() == false) {
Expand All @@ -153,7 +205,7 @@

// if we are still fetching, return null to indicate it
if (hasAnyNodeFetching(cache)) {
return new FetchResult<>(shardId, null, emptySet());
return new FetchResult<>(null, emptyMap());
Gaurav614 marked this conversation as resolved.
Show resolved Hide resolved
} else {
// nothing to fetch, yay, build the return value
Map<DiscoveryNode, T> fetchData = new HashMap<>();
Expand All @@ -177,16 +229,27 @@
}
}
}
Set<String> allIgnoreNodes = unmodifiableSet(new HashSet<>(nodesToIgnore));

Map<ShardId, Set<String>> allIgnoreNodesMap = unmodifiableMap(new HashMap<>(shardToIgnoreNodes));
// clear the nodes to ignore, we had a successful run in fetching everything we can
// we need to try them if another full run is needed
nodesToIgnore.clear();
shardToIgnoreNodes.clear();
// if at least one node failed, make sure to have a protective reroute
// here, just in case this round won't find anything, and we need to retry fetching data
if (failedNodes.isEmpty() == false || allIgnoreNodes.isEmpty() == false) {
reroute(shardId, "nodes failed [" + failedNodes.size() + "], ignored [" + allIgnoreNodes.size() + "]");

if (failedNodes.isEmpty() == false
|| allIgnoreNodesMap.values().stream().anyMatch(ignoreNodeSet -> ignoreNodeSet.isEmpty() == false)) {
Gaurav614 marked this conversation as resolved.
Show resolved Hide resolved
Gaurav614 marked this conversation as resolved.
Show resolved Hide resolved
reroute(
reroutingKey,
"nodes failed ["
+ failedNodes.size()
+ "], ignored ["
+ allIgnoreNodesMap.values().stream().mapToInt(Set::size).sum()
+ "]"
);
}
return new FetchResult<>(shardId, fetchData, allIgnoreNodes);

return new FetchResult<>(fetchData, allIgnoreNodesMap);
}
}

Expand All @@ -199,10 +262,10 @@
protected synchronized void processAsyncFetch(List<T> responses, List<FailedNodeException> failures, long fetchingRound) {
if (closed) {
// we are closed, no need to process this async fetch at all
logger.trace("{} ignoring fetched [{}] results, already closed", shardId, type);
logger.trace("{} ignoring fetched [{}] results, already closed", reroutingKey, type);

Check warning on line 265 in server/src/main/java/org/opensearch/gateway/AsyncShardFetch.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/gateway/AsyncShardFetch.java#L265

Added line #L265 was not covered by tests
return;
}
logger.trace("{} processing fetched [{}] results", shardId, type);
logger.trace("{} processing fetched [{}] results", reroutingKey, type);

if (responses != null) {
for (T response : responses) {
Expand All @@ -212,7 +275,7 @@
assert nodeEntry.getFetchingRound() > fetchingRound : "node entries only replaced by newer rounds";
logger.trace(
"{} received response for [{}] from node {} for an older fetching round (expected: {} but was: {})",
shardId,
reroutingKey,
nodeEntry.getNodeId(),
type,
nodeEntry.getFetchingRound(),
Expand All @@ -221,29 +284,29 @@
} else if (nodeEntry.isFailed()) {
logger.trace(
"{} node {} has failed for [{}] (failure [{}])",
shardId,
reroutingKey,
nodeEntry.getNodeId(),
type,
nodeEntry.getFailure()
);
} else {
// if the entry is there, for the right fetching round and not marked as failed already, process it
logger.trace("{} marking {} as done for [{}], result is [{}]", shardId, nodeEntry.getNodeId(), type, response);
logger.trace("{} marking {} as done for [{}], result is [{}]", reroutingKey, nodeEntry.getNodeId(), type, response);
nodeEntry.doneFetching(response);
}
}
}
}
if (failures != null) {
for (FailedNodeException failure : failures) {
logger.trace("{} processing failure {} for [{}]", shardId, failure, type);
logger.trace("{} processing failure {} for [{}]", reroutingKey, failure, type);
NodeEntry<T> nodeEntry = cache.get(failure.nodeId());
if (nodeEntry != null) {
if (nodeEntry.getFetchingRound() != fetchingRound) {
assert nodeEntry.getFetchingRound() > fetchingRound : "node entries only replaced by newer rounds";
logger.trace(
"{} received failure for [{}] from node {} for an older fetching round (expected: {} but was: {})",
shardId,
reroutingKey,
nodeEntry.getNodeId(),
type,
nodeEntry.getFetchingRound(),
Expand All @@ -261,7 +324,7 @@
logger.warn(
() -> new ParameterizedMessage(
"{}: failed to list shard for {} on node [{}]",
shardId,
reroutingKey,
type,
failure.nodeId()
),
Expand All @@ -273,13 +336,13 @@
}
}
}
reroute(shardId, "post_response");
reroute(reroutingKey, "post_response");
}

/**
* Implement this in order to schedule another round that causes a call to fetch data.
*/
protected abstract void reroute(ShardId shardId, String reason);
protected abstract void reroute(String reroutingKey, String reason);

/**
* Clear cache for node, ensuring next fetch will fetch a fresh copy.
Expand Down Expand Up @@ -334,8 +397,8 @@
*/
// visible for testing
void asyncFetch(final DiscoveryNode[] nodes, long fetchingRound) {
logger.trace("{} fetching [{}] from {}", shardId, type, nodes);
action.list(shardId, customDataPath, nodes, new ActionListener<BaseNodesResponse<T>>() {
logger.trace("{} fetching [{}] from {}", reroutingKey, type, nodes);
action.list(shardAttributesMap, nodes, new ActionListener<BaseNodesResponse<T>>() {
@Override
public void onResponse(BaseNodesResponse<T> response) {
processAsyncFetch(response.getNodes(), response.failures(), fetchingRound);
Expand All @@ -358,14 +421,12 @@
*/
public static class FetchResult<T extends BaseNodeResponse> {

private final ShardId shardId;
private final Map<DiscoveryNode, T> data;
private final Set<String> ignoreNodes;
private final Map<ShardId, Set<String>> ignoredShardToNodes;

public FetchResult(ShardId shardId, Map<DiscoveryNode, T> data, Set<String> ignoreNodes) {
this.shardId = shardId;
public FetchResult(Map<DiscoveryNode, T> data, Map<ShardId, Set<String>> ignoreNodes) {
this.data = data;
this.ignoreNodes = ignoreNodes;
this.ignoredShardToNodes = ignoreNodes;
}

/**
Expand All @@ -389,9 +450,14 @@
* Process any changes needed to the allocation based on this fetch result.
*/
public void processAllocation(RoutingAllocation allocation) {
for (String ignoreNode : ignoreNodes) {
allocation.addIgnoreShardForNode(shardId, ignoreNode);
for (Map.Entry<ShardId, Set<String>> entry : ignoredShardToNodes.entrySet()) {
ShardId shardId = entry.getKey();
Set<String> ignoreNodes = entry.getValue();
if (ignoreNodes.isEmpty() == false) {
ignoreNodes.forEach(nodeId -> allocation.addIgnoreShardForNode(shardId, nodeId));
}
}

}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import org.opensearch.indices.store.TransportNodesListShardStoreMetadata;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Set;
import java.util.Spliterators;
Expand Down Expand Up @@ -226,7 +227,9 @@
AsyncShardFetch<TransportNodesListShardStoreMetadata.NodeStoreFilesMetadata> fetch,
RoutingAllocation allocation
) {
ShardRouting primary = allocation.routingNodes().activePrimary(fetch.shardId);
assert fetch.shardAttributesMap.size() == 1 : "expected only one shard";
ShardId shardId = fetch.shardAttributesMap.keySet().iterator().next();
ShardRouting primary = allocation.routingNodes().activePrimary(shardId);

Check warning on line 232 in server/src/main/java/org/opensearch/gateway/GatewayAllocator.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/gateway/GatewayAllocator.java#L231-L232

Added lines #L231 - L232 were not covered by tests
if (primary != null) {
fetch.clearCacheForNode(primary.currentNodeId());
}
Expand Down Expand Up @@ -254,15 +257,15 @@
}

@Override
protected void reroute(ShardId shardId, String reason) {
logger.trace("{} scheduling reroute for {}", shardId, reason);
protected void reroute(String reroutingKey, String reason) {
logger.trace("{} scheduling reroute for {}", reroutingKey, reason);
assert rerouteService != null;
rerouteService.reroute(
"async_shard_fetch",
Priority.HIGH,
ActionListener.wrap(
r -> logger.trace("{} scheduled reroute completed for {}", shardId, reason),
e -> logger.debug(new ParameterizedMessage("{} scheduled reroute failed for {}", shardId, reason), e)
r -> logger.trace("{} scheduled reroute completed for {}", reroutingKey, reason),
e -> logger.debug(new ParameterizedMessage("{} scheduled reroute failed for {}", reroutingKey, reason), e)

Check warning on line 268 in server/src/main/java/org/opensearch/gateway/GatewayAllocator.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/gateway/GatewayAllocator.java#L268

Added line #L268 was not covered by tests
)
);
}
Expand Down Expand Up @@ -293,7 +296,11 @@
);
AsyncShardFetch.FetchResult<TransportNodesListGatewayStartedShards.NodeGatewayStartedShards> shardState = fetch.fetchData(
allocation.nodes(),
allocation.getIgnoreNodes(shard.shardId())
new HashMap<>() {
{
put(shard.shardId(), allocation.getIgnoreNodes(shard.shardId()));
}
}
);

if (shardState.hasData()) {
Expand Down Expand Up @@ -328,7 +335,11 @@
);
AsyncShardFetch.FetchResult<TransportNodesListShardStoreMetadata.NodeStoreFilesMetadata> shardStores = fetch.fetchData(
allocation.nodes(),
allocation.getIgnoreNodes(shard.shardId())
new HashMap<>() {

Check warning on line 338 in server/src/main/java/org/opensearch/gateway/GatewayAllocator.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/gateway/GatewayAllocator.java#L338

Added line #L338 was not covered by tests
{
put(shard.shardId(), allocation.getIgnoreNodes(shard.shardId()));
}

Check warning on line 341 in server/src/main/java/org/opensearch/gateway/GatewayAllocator.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/gateway/GatewayAllocator.java#L340-L341

Added lines #L340 - L341 were not covered by tests
}
);
if (shardStores.hasData()) {
shardStores.processAllocation(allocation);
Expand Down
Loading
Loading