Skip to content

Commit

Permalink
Removed AsyncShardBatchFetch class
Browse files Browse the repository at this point in the history
Signed-off-by: Gaurav Chandani <[email protected]>
  • Loading branch information
Gaurav614 committed Aug 30, 2023
1 parent a32b0ee commit 63b0ed4
Show file tree
Hide file tree
Showing 9 changed files with 157 additions and 92 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ void start() {
} else {
for (Tuple<ShardId, String> shard : shards) {
InternalAsyncFetch fetch = new InternalAsyncFetch(logger, "shard_stores", shard.v1(), shard.v2(), listShardStoresInfo);
fetch.fetchData(nodes, Collections.<String>emptySet());
fetch.fetchData(nodes, Collections.emptyMap());
}
}
}
Expand Down Expand Up @@ -224,7 +224,7 @@ protected synchronized void processAsyncFetch(
List<FailedNodeException> failures,
long fetchingRound
) {
fetchResponses.add(new Response(shardId, responses, failures));
fetchResponses.add(new Response(shardToCustomDataPath.keySet().iterator().next(), responses, failures));
if (expectedOps.countDown()) {
finish();
}
Expand Down Expand Up @@ -314,7 +314,7 @@ private boolean shardExistsInNode(final NodeGatewayStartedShards response) {
}

@Override
protected void reroute(ShardId shardId, String reason) {
protected void reroute(String shardId, String reason) {
// no-op
}

Expand Down
131 changes: 87 additions & 44 deletions server/src/main/java/org/opensearch/gateway/AsyncShardFetch.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,11 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicLong;

import static java.util.Collections.emptySet;
import static java.util.Collections.unmodifiableSet;
import static java.util.Collections.emptyMap;
import static java.util.Collections.unmodifiableMap;

/**
* Allows asynchronously fetching shard related data from other nodes for allocation, without blocking
Expand All @@ -77,18 +76,22 @@ public abstract class AsyncShardFetch<T extends BaseNodeResponse> implements Rel
* An action that lists the relevant shard data that needs to be fetched.
*/
public interface Lister<NodesResponse extends BaseNodesResponse<NodeResponse>, NodeResponse extends BaseNodeResponse> {
void list(ShardId shardId, @Nullable String customDataPath, DiscoveryNode[] nodes, ActionListener<NodesResponse> listener);
void list(Map<ShardId, String> shardIdsWithCustomDataPath, DiscoveryNode[] nodes, ActionListener<NodesResponse> listener);

}

protected final Logger logger;
protected final String type;
protected final ShardId shardId;
protected final String customDataPath;

protected final Map<ShardId,String> shardToCustomDataPath;
private final Lister<BaseNodesResponse<T>, T> action;
private final Map<String, NodeEntry<T>> cache = new HashMap<>();
private final Set<String> nodesToIgnore = new HashSet<>();
private final AtomicLong round = new AtomicLong();
private boolean closed;
private final String logKey;
private final Map<ShardId, Set<String>> shardToIgnoreNodes = new HashMap<>();

private final boolean enableBatchMode;

@SuppressWarnings("unchecked")
protected AsyncShardFetch(
Expand All @@ -100,11 +103,30 @@ protected AsyncShardFetch(
) {
this.logger = logger;
this.type = type;
this.shardId = Objects.requireNonNull(shardId);
this.customDataPath = Objects.requireNonNull(customDataPath);
shardToCustomDataPath =new HashMap<>();
shardToCustomDataPath.put(shardId, customDataPath);
this.action = (Lister<BaseNodesResponse<T>, T>) action;
this.logKey = "ShardId=[" + shardId.toString() + "]";
enableBatchMode = false;
}

/**
 * Creates a fetcher that runs in batch mode, i.e. one that is handed a map of shards
 * instead of a single shard.
 *
 * @param logger                logger used for the trace/warn output of this fetcher
 * @param type                  short description of the kind of data being fetched, used in log messages
 * @param shardToCustomDataPath shards to fetch, mapped to their custom data path; must not be {@code null}
 * @param action                lister that is invoked to list the data on the nodes
 * @param batchId               identifier of the batch, used to build the log key
 * @throws NullPointerException if {@code shardToCustomDataPath} is {@code null}
 */
@SuppressWarnings("unchecked")
protected AsyncShardFetch(
    Logger logger,
    String type,
    Map<ShardId, String> shardToCustomDataPath,
    Lister<? extends BaseNodesResponse<T>, T> action,
    String batchId
) {
    this.logger = logger;
    this.type = type;
    // fail fast: a null map would otherwise be passed on to the lister unchecked
    this.shardToCustomDataPath = java.util.Objects.requireNonNull(shardToCustomDataPath, "shardToCustomDataPath");
    // unchecked cast: the declared lister type carries a wildcard for the nodes response
    this.action = (Lister<BaseNodesResponse<T>, T>) action;
    this.logKey = "BatchID=[" + batchId + "]";
    this.enableBatchMode = true;
}


@Override
public synchronized void close() {
this.closed = true;
Expand All @@ -130,11 +152,26 @@ public synchronized int getNumberOfInFlightFetches() {
* The ignoreNodes are nodes that are supposed to be ignored for this round, since fetching is async, we need
* to keep them around and make sure we add them back when all the responses are fetched and returned.
*/
public synchronized FetchResult<T> fetchData(DiscoveryNodes nodes, Set<String> ignoreNodes) {
public synchronized FetchResult<T> fetchData(DiscoveryNodes nodes, Map<ShardId, Set<String>> ignoreNodes) {
if (closed) {
throw new IllegalStateException(shardId + ": can't fetch data on closed async fetch");
throw new IllegalStateException(logKey + ": can't fetch data on closed async fetch");
}

if(enableBatchMode == false){
// we will do assertions here on ignoreNodes
assert ignoreNodes.size() <=1 : "Can only have at-most one shard";
if(ignoreNodes.size() == 1) {
assert shardToCustomDataPath.containsKey(ignoreNodes.keySet().iterator().next()) : "ShardId should be same as initialised in fetcher";
}
}

// add the nodes to ignore to the list of nodes to ignore for each shard
for (Map.Entry<ShardId, Set<String>> ignoreNodesEntry : ignoreNodes.entrySet()) {
Set<String> ignoreNodesSet = shardToIgnoreNodes.getOrDefault(ignoreNodesEntry.getKey(), new HashSet<>());
ignoreNodesSet.addAll(ignoreNodesEntry.getValue());
shardToIgnoreNodes.put(ignoreNodesEntry.getKey(), ignoreNodesSet);
}
nodesToIgnore.addAll(ignoreNodes);

fillShardCacheWithDataNodes(cache, nodes);
List<NodeEntry<T>> nodesToFetch = findNodesToFetch(cache);
if (nodesToFetch.isEmpty() == false) {
Expand All @@ -153,7 +190,7 @@ public synchronized FetchResult<T> fetchData(DiscoveryNodes nodes, Set<String> i

// if we are still fetching, return null to indicate it
if (hasAnyNodeFetching(cache)) {
return new FetchResult<>(shardId, null, emptySet());
return new FetchResult<>(null, emptyMap());
} else {
// nothing to fetch, yay, build the return value
Map<DiscoveryNode, T> fetchData = new HashMap<>();
Expand All @@ -177,16 +214,19 @@ public synchronized FetchResult<T> fetchData(DiscoveryNodes nodes, Set<String> i
}
}
}
Set<String> allIgnoreNodes = unmodifiableSet(new HashSet<>(nodesToIgnore));

Map<ShardId, Set<String>> allIgnoreNodesMap = unmodifiableMap(new HashMap<>(shardToIgnoreNodes));
// clear the nodes to ignore, we had a successful run in fetching everything we can
// we need to try them if another full run is needed
nodesToIgnore.clear();
shardToIgnoreNodes.clear();
// if at least one node failed, make sure to have a protective reroute
// here, just in case this round won't find anything, and we need to retry fetching data
if (failedNodes.isEmpty() == false || allIgnoreNodes.isEmpty() == false) {
reroute(shardId, "nodes failed [" + failedNodes.size() + "], ignored [" + allIgnoreNodes.size() + "]");
if (failedNodes.isEmpty() == false || allIgnoreNodesMap.values().stream().anyMatch(ignoreNodeSet -> ignoreNodeSet.isEmpty() == false)) {
reroute(logKey, "nodes failed [" + failedNodes.size() + "], ignored ["
+ allIgnoreNodesMap.values().stream().mapToInt(Set::size).sum() + "]");
}
return new FetchResult<>(shardId, fetchData, allIgnoreNodes);

return new FetchResult<>(fetchData, allIgnoreNodesMap);
}
}

Expand All @@ -199,10 +239,10 @@ public synchronized FetchResult<T> fetchData(DiscoveryNodes nodes, Set<String> i
protected synchronized void processAsyncFetch(List<T> responses, List<FailedNodeException> failures, long fetchingRound) {
if (closed) {
// we are closed, no need to process this async fetch at all
logger.trace("{} ignoring fetched [{}] results, already closed", shardId, type);
logger.trace("{} ignoring fetched [{}] results, already closed", logKey, type);
return;
}
logger.trace("{} processing fetched [{}] results", shardId, type);
logger.trace("{} processing fetched [{}] results", logKey, type);

if (responses != null) {
for (T response : responses) {
Expand All @@ -212,7 +252,7 @@ protected synchronized void processAsyncFetch(List<T> responses, List<FailedNode
assert nodeEntry.getFetchingRound() > fetchingRound : "node entries only replaced by newer rounds";
logger.trace(
"{} received response for [{}] from node {} for an older fetching round (expected: {} but was: {})",
shardId,
logKey,
nodeEntry.getNodeId(),
type,
nodeEntry.getFetchingRound(),
Expand All @@ -221,29 +261,29 @@ protected synchronized void processAsyncFetch(List<T> responses, List<FailedNode
} else if (nodeEntry.isFailed()) {
logger.trace(
"{} node {} has failed for [{}] (failure [{}])",
shardId,
logKey,
nodeEntry.getNodeId(),
type,
nodeEntry.getFailure()
);
} else {
// if the entry is there, for the right fetching round and not marked as failed already, process it
logger.trace("{} marking {} as done for [{}], result is [{}]", shardId, nodeEntry.getNodeId(), type, response);
logger.trace("{} marking {} as done for [{}], result is [{}]", logKey, nodeEntry.getNodeId(), type, response);
nodeEntry.doneFetching(response);
}
}
}
}
if (failures != null) {
for (FailedNodeException failure : failures) {
logger.trace("{} processing failure {} for [{}]", shardId, failure, type);
logger.trace("{} processing failure {} for [{}]", logKey, failure, type);
NodeEntry<T> nodeEntry = cache.get(failure.nodeId());
if (nodeEntry != null) {
if (nodeEntry.getFetchingRound() != fetchingRound) {
assert nodeEntry.getFetchingRound() > fetchingRound : "node entries only replaced by newer rounds";
logger.trace(
"{} received failure for [{}] from node {} for an older fetching round (expected: {} but was: {})",
shardId,
logKey,
nodeEntry.getNodeId(),
type,
nodeEntry.getFetchingRound(),
Expand All @@ -261,7 +301,7 @@ protected synchronized void processAsyncFetch(List<T> responses, List<FailedNode
logger.warn(
() -> new ParameterizedMessage(
"{}: failed to list shard for {} on node [{}]",
shardId,
logKey,
type,
failure.nodeId()
),
Expand All @@ -273,13 +313,13 @@ protected synchronized void processAsyncFetch(List<T> responses, List<FailedNode
}
}
}
reroute(shardId, "post_response");
reroute(logKey, "post_response");
}

/**
* Implement this in order to schedule another round that causes a call to fetch data.
*/
protected abstract void reroute(ShardId shardId, String reason);
protected abstract void reroute(String logKey, String reason);

/**
* Clear cache for node, ensuring next fetch will fetch a fresh copy.
Expand Down Expand Up @@ -334,8 +374,8 @@ private boolean hasAnyNodeFetching(Map<String, NodeEntry<T>> shardCache) {
*/
// visible for testing
void asyncFetch(final DiscoveryNode[] nodes, long fetchingRound) {
logger.trace("{} fetching [{}] from {}", shardId, type, nodes);
action.list(shardId, customDataPath, nodes, new ActionListener<BaseNodesResponse<T>>() {
logger.trace("{} fetching [{}] from {}", logKey, type, nodes);
action.list(shardToCustomDataPath, nodes, new ActionListener<BaseNodesResponse<T>>() {
@Override
public void onResponse(BaseNodesResponse<T> response) {
processAsyncFetch(response.getNodes(), response.failures(), fetchingRound);
Expand All @@ -358,15 +398,13 @@ public void onFailure(Exception e) {
*/
public static class FetchResult<T extends BaseNodeResponse> {

private final ShardId shardId;
private final Map<DiscoveryNode, T> data;
private final Set<String> ignoreNodes;
private final Map<DiscoveryNode, T> data;
private final Map<ShardId, Set<String>> ignoredShardToNodes;

public FetchResult(ShardId shardId, Map<DiscoveryNode, T> data, Set<String> ignoreNodes) {
this.shardId = shardId;
this.data = data;
this.ignoreNodes = ignoreNodes;
}
/**
 * Creates a fetch result.
 *
 * @param data        the fetched data per node; {@code null} is passed while a fetch is still in progress
 * @param ignoreNodes per shard, the ids of the nodes that were ignored
 */
public FetchResult(Map<DiscoveryNode, T> data, Map<ShardId, Set<String>> ignoreNodes) {
    this.ignoredShardToNodes = ignoreNodes;
    this.data = data;
}

/**
* Does the result actually contain data? If not, then there are ongoing fetch
Expand All @@ -385,15 +423,20 @@ public Map<DiscoveryNode, T> getData() {
return this.data;
}

/**
* Process any changes needed to the allocation based on this fetch result.
*/
public void processAllocation(RoutingAllocation allocation) {
for (String ignoreNode : ignoreNodes) {
allocation.addIgnoreShardForNode(shardId, ignoreNode);
/**
 * Applies this fetch result to the given allocation: every node recorded as ignored for a
 * shard is registered on the allocation as ignored for that shard.
 *
 * @param allocation the allocation to update
 */
public void processAllocation(RoutingAllocation allocation) {
    ignoredShardToNodes.forEach((shardId, nodeIds) -> {
        // an empty set simply contributes nothing
        for (String nodeId : nodeIds) {
            allocation.addIgnoreShardForNode(shardId, nodeId);
        }
    });
}
}

/**
* A node entry, holding the state of the fetched data for a specific shard
Expand Down
22 changes: 14 additions & 8 deletions server/src/main/java/org/opensearch/gateway/GatewayAllocator.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import org.opensearch.indices.store.TransportNodesListShardStoreMetadata;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Set;
import java.util.Spliterators;
Expand Down Expand Up @@ -226,7 +227,9 @@ private static void clearCacheForPrimary(
AsyncShardFetch<TransportNodesListShardStoreMetadata.NodeStoreFilesMetadata> fetch,
RoutingAllocation allocation
) {
ShardRouting primary = allocation.routingNodes().activePrimary(fetch.shardId);
assert fetch.shardToCustomDataPath.size() == 1 : "expected only one shard";
ShardId shardId = fetch.shardToCustomDataPath.keySet().iterator().next();
ShardRouting primary = allocation.routingNodes().activePrimary(shardId);
if (primary != null) {
fetch.clearCacheForNode(primary.currentNodeId());
}
Expand Down Expand Up @@ -254,20 +257,19 @@ class InternalAsyncFetch<T extends BaseNodeResponse> extends AsyncShardFetch<T>
}

@Override
protected void reroute(ShardId shardId, String reason) {
logger.trace("{} scheduling reroute for {}", shardId, reason);
protected void reroute(String logKey, String reason) {
logger.trace("{} scheduling reroute for {}", logKey, reason);
assert rerouteService != null;
rerouteService.reroute(
"async_shard_fetch",
Priority.HIGH,
ActionListener.wrap(
r -> logger.trace("{} scheduled reroute completed for {}", shardId, reason),
e -> logger.debug(new ParameterizedMessage("{} scheduled reroute failed for {}", shardId, reason), e)
r -> logger.trace("{} scheduled reroute completed for {}", logKey, reason),
e -> logger.debug(new ParameterizedMessage("{} scheduled reroute failed for {}", logKey, reason), e)
)
);
}
}

class InternalPrimaryShardAllocator extends PrimaryShardAllocator {

private final TransportNodesListGatewayStartedShards startedAction;
Expand All @@ -293,7 +295,9 @@ protected AsyncShardFetch.FetchResult<TransportNodesListGatewayStartedShards.Nod
);
AsyncShardFetch.FetchResult<TransportNodesListGatewayStartedShards.NodeGatewayStartedShards> shardState = fetch.fetchData(
allocation.nodes(),
allocation.getIgnoreNodes(shard.shardId())
new HashMap<>() {{
put(shard.shardId(), allocation.getIgnoreNodes(shard.shardId()));
}}
);

if (shardState.hasData()) {
Expand Down Expand Up @@ -328,7 +332,9 @@ protected AsyncShardFetch.FetchResult<TransportNodesListShardStoreMetadata.NodeS
);
AsyncShardFetch.FetchResult<TransportNodesListShardStoreMetadata.NodeStoreFilesMetadata> shardStores = fetch.fetchData(
allocation.nodes(),
allocation.getIgnoreNodes(shard.shardId())
new HashMap<>() {{
put(shard.shardId(), allocation.getIgnoreNodes(shard.shardId()));
}}
);
if (shardStores.hasData()) {
shardStores.processAllocation(allocation);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
Expand Down Expand Up @@ -124,7 +125,10 @@ public TransportNodesListGatewayStartedShards(
}

@Override
public void list(ShardId shardId, String customDataPath, DiscoveryNode[] nodes, ActionListener<NodesGatewayStartedShards> listener) {
public void list(Map<ShardId, String> shardIdsWithCustomDataPath, DiscoveryNode[] nodes, ActionListener<NodesGatewayStartedShards> listener) {
assert shardIdsWithCustomDataPath.size() == 1 : "only one shard should be specified";
final ShardId shardId = shardIdsWithCustomDataPath.keySet().iterator().next();
final String customDataPath = shardIdsWithCustomDataPath.get(shardId);
execute(new Request(shardId, customDataPath, nodes), listener);
}

Expand Down
Loading

0 comments on commit 63b0ed4

Please sign in to comment.