diff --git a/server/src/main/java/org/opensearch/gateway/AsyncShardsFetchPerNode.java b/server/src/main/java/org/opensearch/gateway/AsyncShardsFetchPerNode.java new file mode 100644 index 0000000000000..5391a0b8d7c25 --- /dev/null +++ b/server/src/main/java/org/opensearch/gateway/AsyncShardsFetchPerNode.java @@ -0,0 +1,432 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.gateway; + +import com.carrotsearch.hppc.cursors.ObjectObjectCursor; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.ExceptionsHelper; +import org.opensearch.OpenSearchTimeoutException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.nodes.BaseNodeResponse; +import org.opensearch.action.support.nodes.BaseNodesResponse; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.common.Nullable; +import org.opensearch.common.lease.Releasable; +import org.opensearch.index.shard.ShardId; +import org.opensearch.transport.ReceiveTimeoutTransportException; +import org.opensearch.common.util.concurrent.OpenSearchRejectedExecutionException; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.atomic.AtomicLong; + +/** + * This class is responsible for fetching shard data from nodes. It is analogous to AsyncShardFetch class since it fetches + * the data in asynchronous manner too. + * @param + */ +public abstract class AsyncShardsFetchPerNode implements Releasable { + + /** + * An action that lists the relevant shard data that needs to be fetched. + */ + public interface Lister, NodeResponse extends BaseNodeResponse> { + void list(DiscoveryNode[] nodes, Map shardsIdMap, ActionListener listener); + } + + protected final Logger logger; + protected final String type; + protected Map shardsToCustomDataPathMap; + private final AsyncShardsFetchPerNode.Lister, T> action; + protected final Map> cache = new HashMap<>(); + + private final Set nodesToIgnore = new HashSet<>(); + private final AtomicLong round = new AtomicLong(); + private boolean closed; + + @SuppressWarnings("unchecked") + protected AsyncShardsFetchPerNode( + Logger logger, + String type, + Map shardsToCustomDataPathMap, + AsyncShardsFetchPerNode.Lister, T> action + ) { + this.logger = logger; + this.type = type; + this.action = (AsyncShardsFetchPerNode.Lister, T>) action; + this.shardsToCustomDataPathMap = shardsToCustomDataPathMap; + } + + @Override + public synchronized void close() { + this.closed = true; + } + + protected abstract void reroute(String reason); + + /** + * Clear cache for node, ensuring next fetch will fetch a fresh copy. + */ + synchronized void clearCacheForNode(String nodeId) { + cache.remove(nodeId); + } + + /** This function is copy-pasted from AsyncShardFetch.java + * Fills the shard fetched data with new (data) nodes and a fresh NodeEntry, and removes from + * it nodes that are no longer part of the state. + */ + private void fillShardCacheWithDataNodes(Map> shardCache, DiscoveryNodes nodes) { + // verify that all current data nodes are there + for (ObjectObjectCursor cursor : nodes.getDataNodes()) { + DiscoveryNode node = cursor.value; + if (shardCache.containsKey(node.getId()) == false) { + shardCache.put(node.getId(), new AsyncShardsFetchPerNode.NodeEntry(node.getId())); + } + } + // remove nodes that are not longer part of the data nodes set + shardCache.keySet().removeIf(nodeId -> !nodes.nodeExists(nodeId)); + } + + /** + * This function is copy-pasted from AsyncShardFetch.java + * Finds all the nodes that need to be fetched. Those are nodes that have no + * data, and are not in fetch mode. + */ + private List> findNodesToFetch(Map> shardCache) { + List> nodesToFetch = new ArrayList<>(); + for (AsyncShardsFetchPerNode.NodeEntry nodeEntry : shardCache.values()) { + if (nodeEntry.hasData() == false && nodeEntry.isFetching() == false) { + nodesToFetch.add(nodeEntry); + } + } + return nodesToFetch; + } + + /** + * This function is copy-pasted from AsyncShardFetch.java + * Are there any nodes that are fetching data? + */ + private boolean hasAnyNodeFetching(Map> shardCache) { + for (AsyncShardsFetchPerNode.NodeEntry nodeEntry : shardCache.values()) { + if (nodeEntry.isFetching()) { + return true; + } + } + return false; + } + + /** + * This function is copy-pasted from AsyncShardFetch.java, fetchData(). Here we have modified the + * logging part for better debuggability and testing purpose + * @param nodes + * @return + */ + public synchronized AsyncShardsFetchPerNode.TestFetchResult testFetchData(DiscoveryNodes nodes){ + if (closed) { + throw new IllegalStateException("TEST: can't fetch data from nodes on closed async fetch"); + } + + logger.info("TEST- Fetching Unassigned Shards per node"); + fillShardCacheWithDataNodes(cache, nodes); + List> nodesToFetch = findNodesToFetch(cache); + if (nodesToFetch.isEmpty() == false) { + // mark all node as fetching and go ahead and async fetch them + // use a unique round id to detect stale responses in processAsyncFetch + final long fetchingRound = round.incrementAndGet(); + for (AsyncShardsFetchPerNode.NodeEntry nodeEntry : nodesToFetch) { + nodeEntry.markAsFetching(fetchingRound); + } + DiscoveryNode[] discoNodesToFetch = nodesToFetch.stream() + .map(AsyncShardsFetchPerNode.NodeEntry::getNodeId) + .map(nodes::get) + .toArray(DiscoveryNode[]::new); + asyncFetchShardPerNode(discoNodesToFetch, fetchingRound); + } + + if (hasAnyNodeFetching(cache)) { + return new AsyncShardsFetchPerNode.TestFetchResult<>( null); + } else { + // nothing to fetch, yay, build the return value + Map fetchData = new HashMap<>(); + Set failedNodes = new HashSet<>(); + for (Iterator>> it = cache.entrySet().iterator(); it.hasNext();) { + Map.Entry> entry = it.next(); + String nodeId = entry.getKey(); + AsyncShardsFetchPerNode.NodeEntry nodeEntry = entry.getValue(); + + DiscoveryNode node = nodes.get(nodeId); + if (node != null) { + if (nodeEntry.isFailed()) { + // if its failed, remove it from the list of nodes, so if this run doesn't work + // we try again next round to fetch it again + it.remove(); + failedNodes.add(nodeEntry.getNodeId()); + } else { + if (nodeEntry.getValue() != null) { + fetchData.put(node, nodeEntry.getValue()); + } + } + } + } + + // if at least one node failed, make sure to have a protective reroute + // here, just case this round won't find anything, and we need to retry fetching data + if (failedNodes.isEmpty() == false ) { + reroute("TEST--> nodes failed [" + failedNodes.size() ); + } + + return new AsyncShardsFetchPerNode.TestFetchResult<>(fetchData); + } + } + + /** This function is copy-pasted from AsyncShardFetch.java (asyncFetch()), with more verbose logging + * Async fetches data for the provided shard with the set of nodes that need to be fetched from. + */ + void asyncFetchShardPerNode(final DiscoveryNode[] nodes, long fetchingRound) { + logger.info("Fetching Unassigned Shards per node"); + action.list(nodes, shardsToCustomDataPathMap, new ActionListener>() { + @Override + public void onResponse(BaseNodesResponse tBaseNodesResponse) { + processTestAsyncFetch(tBaseNodesResponse.getNodes(),tBaseNodesResponse.failures(), fetchingRound); + } + + @Override + public void onFailure(Exception e) { + + List failures = new ArrayList<>(nodes.length); + for (final DiscoveryNode node : nodes) { + failures.add(new FailedNodeException(node.getId(), "Total failure in fetching", e)); + } + processTestAsyncFetch(null, failures, fetchingRound); + } + }); + } + + + /** This function is copy-pasted from AsyncShardFetch.java (processAsyncFetch()), with more verbose logging. + * + * Called by the response handler of the async action to fetch data. Verifies that its still working + * on the same cache generation, otherwise the results are discarded. It then goes and fills the relevant data for + * the shard (response + failures), issuing a reroute at the end of it to make sure there will be another round + * of allocations taking this new data into account. + */ + protected synchronized void processTestAsyncFetch(List responses, List failures, long fetchingRound){ + if (closed) { + // we are closed, no need to process this async fetch at all + logger.trace("TEST-Ignoring fetched [{}] results, already closed", type); + return; + } + + logger.trace("TEST-processing fetched results"); + + if (responses != null) { + for (T response : responses) { + AsyncShardsFetchPerNode.NodeEntry nodeEntry = cache.get(response.getNode().getId()); + if (nodeEntry != null) { + if (nodeEntry.getFetchingRound() != fetchingRound) { + assert nodeEntry.getFetchingRound() > fetchingRound : "node entries only replaced by newer rounds"; + logger.info( + "TEST--> received response for [{}] from node {} for an older fetching round (expected: {} but was: {})", + nodeEntry.getNodeId(), + type, + nodeEntry.getFetchingRound(), + fetchingRound + ); + } else if (nodeEntry.isFailed()) { + logger.info( + "node {} has failed for [{}] (failure [{}])", + nodeEntry.getNodeId(), + type, + nodeEntry.getFailure() + ); + } else { + // if the entry is there, for the right fetching round and not marked as failed already, process it + logger.info("TEST--> marking {} as done for [{}], result is [{}]", nodeEntry.getNodeId(), type, response); + nodeEntry.doneFetching(response); + } + } + } + } + if (failures != null) { + for (FailedNodeException failure : failures) { + logger.trace("processing failure {} for [{}]", failure, type); + AsyncShardsFetchPerNode.NodeEntry nodeEntry = cache.get(failure.nodeId()); + if (nodeEntry != null) { + if (nodeEntry.getFetchingRound() != fetchingRound) { + assert nodeEntry.getFetchingRound() > fetchingRound : "node entries only replaced by newer rounds"; + logger.trace( + "received failure for [{}] from node {} for an older fetching round (expected: {} but was: {})", + nodeEntry.getNodeId(), + type, + nodeEntry.getFetchingRound(), + fetchingRound + ); + } else if (nodeEntry.isFailed() == false) { + // if the entry is there, for the right fetching round and not marked as failed already, process it + Throwable unwrappedCause = ExceptionsHelper.unwrapCause(failure.getCause()); + // if the request got rejected or timed out, we need to try it again next time... + if (unwrappedCause instanceof OpenSearchRejectedExecutionException + || unwrappedCause instanceof ReceiveTimeoutTransportException + || unwrappedCause instanceof OpenSearchTimeoutException) { + nodeEntry.restartFetching(); + } else { + logger.warn( + () -> new ParameterizedMessage( + "failed to list shard for {} on node [{}]", + type, + failure.nodeId() + ), + failure + ); + nodeEntry.doneFetching(failure.getCause()); + } + } + } + } + } + + reroute("TEST_post_response"); + } + + protected synchronized void updateBatchOfShards(Map shardsToCustomDataPathMap){ + + // update only when current batch is completed + if(hasAnyNodeFetching(cache)==false && shardsToCustomDataPathMap.isEmpty()==false){ + this.shardsToCustomDataPathMap= shardsToCustomDataPathMap; + + // not intelligent enough right now to invalidate the diff. + // When batching the diff we can make it more intelligent + cache.values().forEach(NodeEntry::invalidateCurrentData); + } + } + + /** + * Analogous to FetchResult in AsyncShardFetch.java, but currently we dont accommodate ignoreNodes + * @param + */ + public static class TestFetchResult { + + private final Map nodesToShards; + + public TestFetchResult(Map nodesToShards) { + this.nodesToShards = nodesToShards; + } + + public Map getNodesToShards() { + return nodesToShards; + } + + public boolean hasData() { + return nodesToShards != null; + } + + } + + + /** + * A node entry, holding the state of the fetched data for a batch of shards + * for a giving node. + * + * It is analogous to NodeEntry in AsyncShardFetch.java + */ + static class NodeEntry { + + /* Copied and derived from AsyncShardFetch.java. Starts*/ + private final String nodeId; + private boolean fetching; + @Nullable + private T value; + private boolean valueSet; + private Throwable failure; + private long fetchingRound; + + NodeEntry(String nodeId) { + this.nodeId = nodeId; + } + + String getNodeId() { + return this.nodeId; + } + + boolean isFetching() { + return fetching; + } + + void markAsFetching(long fetchingRound) { + assert fetching == false : "double marking a node as fetching"; + this.fetching = true; + this.fetchingRound = fetchingRound; + } + + void doneFetching(T value) { + assert fetching : "setting value but not in fetching mode"; + assert failure == null : "setting value when failure already set"; + this.valueSet = true; + this.value = value; + this.fetching = false; + } + + void doneFetching(Throwable failure) { + assert fetching : "setting value but not in fetching mode"; + assert valueSet == false : "setting failure when already set value"; + assert failure != null : "setting failure can't be null"; + this.failure = failure; + this.fetching = false; + } + + void restartFetching() { + assert fetching : "restarting fetching, but not in fetching mode"; + assert valueSet == false : "value can't be set when restarting fetching"; + assert failure == null : "failure can't be set when restarting fetching"; + this.fetching = false; + } + + boolean isFailed() { + return failure != null; + } + + boolean hasData() { + return valueSet || failure != null; + } + + Throwable getFailure() { + assert hasData() : "getting failure when data has not been fetched"; + return failure; + } + + @Nullable + T getValue() { + assert failure == null : "trying to fetch value, but its marked as failed, check isFailed"; + assert valueSet : "value is not set, hasn't been fetched yet"; + return value; + } + + long getFetchingRound() { + return fetchingRound; + } + /* Copied and derived from AsyncShardFetch.java. Ends*/ + + void invalidateCurrentData() { + this.value=null; + valueSet=false; + fetchingRound=0; + failure=null; + } + } + + +} diff --git a/server/src/main/java/org/opensearch/gateway/GatewayAllocator.java b/server/src/main/java/org/opensearch/gateway/GatewayAllocator.java index cdcf813d9ede0..7c336bafc0699 100644 --- a/server/src/main/java/org/opensearch/gateway/GatewayAllocator.java +++ b/server/src/main/java/org/opensearch/gateway/GatewayAllocator.java @@ -43,6 +43,7 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.routing.RerouteService; +import org.opensearch.cluster.routing.RoutingNodes; import org.opensearch.cluster.routing.ShardRouting; import org.opensearch.cluster.routing.allocation.AllocateUnassignedDecision; import org.opensearch.cluster.routing.allocation.ExistingShardsAllocator; @@ -57,7 +58,9 @@ import org.opensearch.indices.store.TransportNodesListShardStoreMetadata; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentMap; import java.util.stream.Collectors; @@ -83,19 +86,28 @@ public class GatewayAllocator implements ExistingShardsAllocator { ShardId, AsyncShardFetch> asyncFetchStarted = ConcurrentCollections .newConcurrentMap(); - private final ConcurrentMap> asyncFetchStore = + private final ConcurrentMap> asyncFetchStore = ConcurrentCollections.newConcurrentMap(); + + + private Map shardsPerNode= ConcurrentCollections.newConcurrentMap(); + + private AsyncShardsFetchPerNode fetchShardsFromNodes=null; + private Set lastSeenEphemeralIds = Collections.emptySet(); + TransportNodesCollectGatewayStartedShard testAction; @Inject public GatewayAllocator( RerouteService rerouteService, TransportNodesListGatewayStartedShards startedAction, - TransportNodesListShardStoreMetadata storeAction + TransportNodesListShardStoreMetadata storeAction, + TransportNodesCollectGatewayStartedShard testAction ) { this.rerouteService = rerouteService; - this.primaryShardAllocator = new InternalPrimaryShardAllocator(startedAction); + this.primaryShardAllocator = new TestInternalPrimaryShardAllocator(testAction); this.replicaShardAllocator = new InternalReplicaShardAllocator(storeAction); + this.testAction=testAction; } @Override @@ -104,6 +116,8 @@ public void cleanCaches() { asyncFetchStarted.clear(); Releasables.close(asyncFetchStore.values()); asyncFetchStore.clear(); + Releasables.close(fetchShardsFromNodes); + shardsPerNode.clear(); } // for tests @@ -131,6 +145,13 @@ public void applyStartedShards(final List startedShards, final Rou Releasables.close(asyncFetchStarted.remove(startedShard.shardId())); Releasables.close(asyncFetchStore.remove(startedShard.shardId())); } + + // clean async object and cache for per DiscoverNode if all shards are assigned and none are ignore list + if (allocation.routingNodes().unassigned().isEmpty() && allocation.routingNodes().unassigned().isIgnoredEmpty()){ + Releasables.close(fetchShardsFromNodes); + shardsPerNode.clear(); + fetchShardsFromNodes =null; + } } @Override @@ -139,6 +160,12 @@ public void applyFailedShards(final List failedShards, final Routin Releasables.close(asyncFetchStarted.remove(failedShard.getRoutingEntry().shardId())); Releasables.close(asyncFetchStore.remove(failedShard.getRoutingEntry().shardId())); } + + // clean async object and cache for per DiscoverNode if all shards are assigned and none are ignore list + if (allocation.routingNodes().unassigned().isEmpty() && allocation.routingNodes().unassigned().isIgnoredEmpty()){ + Releasables.close(fetchShardsFromNodes); + shardsPerNode.clear(); + } } @Override @@ -146,6 +173,10 @@ public void beforeAllocation(final RoutingAllocation allocation) { assert primaryShardAllocator != null; assert replicaShardAllocator != null; ensureAsyncFetchStorePrimaryRecency(allocation); + + //build the view of shards per node here by doing transport calls on nodes and populate shardsPerNode + collectShardsPerNode(allocation); + } @Override @@ -168,6 +199,53 @@ public void allocateUnassigned( innerAllocatedUnassigned(allocation, primaryShardAllocator, replicaShardAllocator, shardRouting, unassignedAllocationHandler); } + private synchronized Map collectShardsPerNode(RoutingAllocation allocation) { + + Map batchOfUnassignedShardsWithCustomDataPath = getBatchOfUnassignedShardsWithCustomDataPath(allocation); + if (fetchShardsFromNodes == null) { + if (batchOfUnassignedShardsWithCustomDataPath.isEmpty()){ + return null; + } + fetchShardsFromNodes = new TestAsyncShardFetch<>(logger, "collect_shards", batchOfUnassignedShardsWithCustomDataPath, testAction); + } else { + //verify if any new shards need to be batched? + + // even if one shard is not in the map, we now update the batch with all unassigned shards + if (batchOfUnassignedShardsWithCustomDataPath.keySet().stream().allMatch(shard -> fetchShardsFromNodes.shardsToCustomDataPathMap.containsKey(shard)) == false) { + // right now update the complete map, but this can be optimized with only the diff + logger.info("Shards Batch not equal, updating it"); + if (fetchShardsFromNodes.shardsToCustomDataPathMap.keySet().equals(batchOfUnassignedShardsWithCustomDataPath.keySet()) == false) { + fetchShardsFromNodes.updateBatchOfShards(batchOfUnassignedShardsWithCustomDataPath); + } + } + } + + AsyncShardsFetchPerNode.TestFetchResult listOfNodeGatewayStartedShardsTestFetchResult = fetchShardsFromNodes.testFetchData(allocation.nodes()); + + if (listOfNodeGatewayStartedShardsTestFetchResult.getNodesToShards()==null) + { + logger.info("Fetching probably still going on some nodes for number of shards={}, current fetch = {}",fetchShardsFromNodes.shardsToCustomDataPathMap.size(),fetchShardsFromNodes.cache.size()); + return null; + } + else { + logger.info("Collecting of total shards ={}, over transport done", fetchShardsFromNodes.shardsToCustomDataPathMap.size()); + logger.info("Fetching from nodes done with size of nodes fetched= {}", listOfNodeGatewayStartedShardsTestFetchResult.getNodesToShards().size()); + // update the view for GatewayAllocator + shardsPerNode = listOfNodeGatewayStartedShardsTestFetchResult.getNodesToShards(); + return shardsPerNode; + } + } + + private Map getBatchOfUnassignedShardsWithCustomDataPath(RoutingAllocation allocation){ + Map map = new HashMap<>(); + RoutingNodes.UnassignedShards allUnassignedShards = allocation.routingNodes().unassigned(); + for (ShardRouting shardIterator : allUnassignedShards) { + if (shardIterator.primary()) + map.put(shardIterator.shardId(), IndexMetadata.INDEX_DATA_PATH_SETTING.get(allocation.metadata().index(shardIterator.index()).getSettings())); + } + return map; + } + // allow for testing infra to change shard allocators implementation protected static void innerAllocatedUnassigned( RoutingAllocation allocation, @@ -177,6 +255,7 @@ protected static void innerAllocatedUnassigned( ExistingShardsAllocator.UnassignedAllocationHandler unassignedAllocationHandler ) { assert shardRouting.unassigned(); + if (shardRouting.primary()) { primaryShardAllocator.allocateUnassigned(shardRouting, allocation, unassignedAllocationHandler); } else { @@ -268,6 +347,35 @@ protected void reroute(ShardId shardId, String reason) { } } + /** + * Analogous to InternalAsyncFetch. + * @param + */ + class TestAsyncShardFetch extends AsyncShardsFetchPerNode + { + TestAsyncShardFetch( + Logger logger, + String type, + Map map, + AsyncShardsFetchPerNode.Lister, T> action + ) { + super(logger, type, map, action); + } + + @Override + protected void reroute( String reason) { + logger.trace("TEST--scheduling reroute for {}", reason); + assert rerouteService != null; + rerouteService.reroute( + "TEST_async_shard_fetch", + Priority.HIGH, + ActionListener.wrap( + r -> logger.trace("TEST-scheduled reroute completed for {}", reason), + e -> logger.debug(new ParameterizedMessage("TEST- scheduled reroute failed for {}", reason), e) + ) + ); + } + } class InternalPrimaryShardAllocator extends PrimaryShardAllocator { private final TransportNodesListGatewayStartedShards startedAction; @@ -303,6 +411,40 @@ protected AsyncShardFetch.FetchResult fetchData(ShardRouting shard, RoutingAllocation allocation) { + ShardId shardId = shard.shardId(); + Map discoveryNodeListOfNodeGatewayStartedShardsMap = shardsPerNode; + + if (shardsPerNode.isEmpty()) { + return new AsyncShardFetch.FetchResult<>(shardId, null, Collections.emptySet()); + } + + HashMap dataToAdapt = new HashMap<>(); + for (DiscoveryNode node : discoveryNodeListOfNodeGatewayStartedShardsMap.keySet()) { + + TransportNodesCollectGatewayStartedShard.ListOfNodeGatewayStartedShards shardsOnThatNode = discoveryNodeListOfNodeGatewayStartedShardsMap.get(node); + if (shardsOnThatNode.getListOfNodeGatewayStartedShards().containsKey(shardId)) { + TransportNodesCollectGatewayStartedShard.NodeGatewayStartedShards nodeGatewayStartedShardsFromAdapt = shardsOnThatNode.getListOfNodeGatewayStartedShards().get(shardId); + // construct a object to adapt + TransportNodesListGatewayStartedShards.NodeGatewayStartedShards nodeGatewayStartedShardsToAdapt = new TransportNodesListGatewayStartedShards.NodeGatewayStartedShards(node, nodeGatewayStartedShardsFromAdapt.allocationId(), + nodeGatewayStartedShardsFromAdapt.primary(), nodeGatewayStartedShardsFromAdapt.replicationCheckpoint(), nodeGatewayStartedShardsFromAdapt.storeException()); + dataToAdapt.put(node, nodeGatewayStartedShardsToAdapt); + } + } + return new AsyncShardFetch.FetchResult<>(shardId, dataToAdapt, Collections.emptySet()); + } + } + + class InternalReplicaShardAllocator extends ReplicaShardAllocator { private final TransportNodesListShardStoreMetadata storeAction; diff --git a/server/src/main/java/org/opensearch/gateway/GatewayModule.java b/server/src/main/java/org/opensearch/gateway/GatewayModule.java index 59ec0243c88c9..3790cd5376a9d 100644 --- a/server/src/main/java/org/opensearch/gateway/GatewayModule.java +++ b/server/src/main/java/org/opensearch/gateway/GatewayModule.java @@ -47,6 +47,7 @@ protected void configure() { bind(GatewayService.class).asEagerSingleton(); bind(TransportNodesListGatewayMetaState.class).asEagerSingleton(); bind(TransportNodesListGatewayStartedShards.class).asEagerSingleton(); + bind(TransportNodesCollectGatewayStartedShard.class).asEagerSingleton(); bind(LocalAllocateDangledIndices.class).asEagerSingleton(); } } diff --git a/server/src/main/java/org/opensearch/gateway/TransportNodesCollectGatewayStartedShard.java b/server/src/main/java/org/opensearch/gateway/TransportNodesCollectGatewayStartedShard.java new file mode 100644 index 0000000000000..1f40a8811c642 --- /dev/null +++ b/server/src/main/java/org/opensearch/gateway/TransportNodesCollectGatewayStartedShard.java @@ -0,0 +1,471 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.gateway; + +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.LegacyESVersion; +import org.opensearch.OpenSearchException; +import org.opensearch.Version; +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionType; +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.nodes.BaseNodeRequest; +import org.opensearch.action.support.nodes.BaseNodeResponse; +import org.opensearch.action.support.nodes.BaseNodesRequest; +import org.opensearch.action.support.nodes.BaseNodesResponse; +import org.opensearch.action.support.nodes.TransportNodesAction; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.Nullable; +import org.opensearch.common.collect.HppcMaps; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.NamedXContentRegistry; +import org.opensearch.env.NodeEnvironment; +import org.opensearch.index.IndexSettings; +import org.opensearch.index.shard.IndexShard; +import org.opensearch.index.shard.ShardId; +import org.opensearch.index.shard.ShardPath; +import org.opensearch.index.shard.ShardStateMetadata; +import org.opensearch.index.store.Store; +import org.opensearch.indices.IndicesService; +import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ConcurrentMap; + +/** + * This transport action is used to fetch the all unassigned shard version from each node during primary allocation in {@link GatewayAllocator}. + * We use this to find out which node holds the latest shard version and which of them used to be a primary in order to allocate + * shards after node or cluster restarts. + * + * @opensearch.internal + */ +public class TransportNodesCollectGatewayStartedShard extends TransportNodesAction< + TransportNodesCollectGatewayStartedShard.Request, + TransportNodesCollectGatewayStartedShard.NodesGatewayStartedShards, + TransportNodesCollectGatewayStartedShard.NodeRequest, + TransportNodesCollectGatewayStartedShard.ListOfNodeGatewayStartedShards> + implements + AsyncShardsFetchPerNode.Lister< + TransportNodesCollectGatewayStartedShard.NodesGatewayStartedShards, + TransportNodesCollectGatewayStartedShard.ListOfNodeGatewayStartedShards> { + + public static final String ACTION_NAME = "internal:gateway/local/collect_shards"; + public static final ActionType TYPE = new ActionType<>(ACTION_NAME, NodesGatewayStartedShards::new); + + private final Settings settings; + private final NodeEnvironment nodeEnv; + private final IndicesService indicesService; + private final NamedXContentRegistry namedXContentRegistry; + + @Inject + public TransportNodesCollectGatewayStartedShard( + Settings settings, + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + NodeEnvironment env, + IndicesService indicesService, + NamedXContentRegistry namedXContentRegistry + ) { + super( + ACTION_NAME, + threadPool, + clusterService, + transportService, + actionFilters, + Request::new, + NodeRequest::new, + ThreadPool.Names.FETCH_SHARD_STARTED, + ListOfNodeGatewayStartedShards.class + ); + this.settings = settings; + this.nodeEnv = env; + this.indicesService = indicesService; + this.namedXContentRegistry = namedXContentRegistry; + } + + @Override + public void list(DiscoveryNode[] nodes, MapshardsIdMap,ActionListener listener) { + execute(new Request(nodes, shardsIdMap), listener); + } + + @Override + protected NodeRequest newNodeRequest(Request request) { + return new NodeRequest(request); + } + + @Override + protected ListOfNodeGatewayStartedShards newNodeResponse(StreamInput in) throws IOException { + return new ListOfNodeGatewayStartedShards(in); + } + + @Override + protected NodesGatewayStartedShards newResponse( + Request request, + List responses, + List failures + ) { + return new NodesGatewayStartedShards(clusterService.getClusterName(), responses, failures); + } + + @Override + protected ListOfNodeGatewayStartedShards nodeOperation(NodeRequest request) { + logger.info("TEST->Transport call- +TC"); + Map shardsOnNode = new HashMap<>(); + + /* This is node Operation is same as nodeOperation on TransportNodesListGatewayStartedShards, but it does over a loop + for all Unassigned shards + */ + for (Map.Entry unsassignedShardsMap : request.shardIdsWithCustomDataPath.entrySet()) { + try { + final ShardId shardId = unsassignedShardsMap.getKey(); + logger.trace("{} loading local shard state info", shardId); + ShardStateMetadata shardStateMetadata = ShardStateMetadata.FORMAT.loadLatestState( + logger, + namedXContentRegistry, + nodeEnv.availableShardPaths(shardId) + ); + if (shardStateMetadata != null) { + if (indicesService.getShardOrNull(shardId) == null) { + final String customDataPath; + if (unsassignedShardsMap.getValue() != null) { + customDataPath = unsassignedShardsMap.getValue(); + } else { + // TODO: Fallback for BWC with older OpenSearch versions. + // Remove once request.getCustomDataPath() always returns non-null + final IndexMetadata metadata = clusterService.state().metadata().index(shardId.getIndex()); + if (metadata != null) { + customDataPath = new IndexSettings(metadata, settings).customDataPath(); + } else { + logger.trace("{} node doesn't have meta data for the requests index", shardId); + throw new OpenSearchException("node doesn't have meta data for index " + shardId.getIndex()); + } + } + // we don't have an open shard on the store, validate the files on disk are openable + ShardPath shardPath = null; + try { + shardPath = ShardPath.loadShardPath(logger, nodeEnv, shardId, customDataPath); + if (shardPath == null) { + throw new IllegalStateException(shardId + " no shard path found"); + } + Store.tryOpenIndex(shardPath.resolveIndex(), shardId, nodeEnv::shardLock, logger); + } catch (Exception exception) { + final ShardPath finalShardPath = shardPath; + logger.trace( + () -> new ParameterizedMessage( + "{} can't open index for shard [{}] in path [{}]", + shardId, + shardStateMetadata, + (finalShardPath != null) ? finalShardPath.resolveIndex() : "" + ), + exception + ); + String allocationId = shardStateMetadata.allocationId != null ? shardStateMetadata.allocationId.getId() : null; + shardsOnNode.put(shardId, new NodeGatewayStartedShards( + allocationId, + shardStateMetadata.primary, + null, + exception + )); + } + } + + logger.info("TEST---> {} shard state info found: [{}]", shardId, shardStateMetadata); + String allocationId = shardStateMetadata.allocationId != null ? shardStateMetadata.allocationId.getId() : null; + final IndexShard shard = indicesService.getShardOrNull(shardId); + shardsOnNode.put(shardId, new NodeGatewayStartedShards( + allocationId, + shardStateMetadata.primary, + shard != null ? shard.getLatestReplicationCheckpoint() : null + )); + } + else { + logger.info("TEST--> {} no local shard info found", shardId); + shardsOnNode.put(shardId, new NodeGatewayStartedShards(null, false, null)); + } + } catch (Exception e) { + throw new OpenSearchException("failed to load started shards", e); + } + } + return new ListOfNodeGatewayStartedShards(clusterService.localNode(), shardsOnNode); + } + + /** + * The nodes request. + * + * @opensearch.internal + */ + public static class Request extends BaseNodesRequest { + + + private final Map shardIdStringMap; + + public Request(StreamInput in) throws IOException { + super(in); + shardIdStringMap=in.readMap(ShardId::new, StreamInput::readString); + } + + public Request(DiscoveryNode[] nodes, Map shardIdStringMap) { + super(nodes); + this.shardIdStringMap= Objects.requireNonNull(shardIdStringMap); + } + + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeMap(shardIdStringMap, (o, k) -> k.writeTo(o), StreamOutput::writeString); + } + + public Map getShardIdsMap() { + return shardIdStringMap; + } + } + + /** + * The nodes response. + * + * @opensearch.internal + */ + public static class NodesGatewayStartedShards extends BaseNodesResponse { + + public NodesGatewayStartedShards(StreamInput in) throws IOException { + super(in); + } + + public NodesGatewayStartedShards( + ClusterName clusterName, + List nodes, + List failures + ) { + super(clusterName, nodes, failures); + } + + @Override + protected List readNodesFrom(StreamInput in) throws IOException { + return in.readList(ListOfNodeGatewayStartedShards::new); + } + + @Override + protected void writeNodesTo(StreamOutput out, List nodes) throws IOException { + out.writeList(nodes); + } + } + + /** + * The request. + * + * @opensearch.internal + */ + public static class NodeRequest extends BaseNodeRequest { + + + private final Map shardIdsWithCustomDataPath; + + public NodeRequest(StreamInput in) throws IOException { + super(in); + shardIdsWithCustomDataPath=in.readMap(ShardId::new, StreamInput::readString); + } + + public NodeRequest(Request request) { + + this.shardIdsWithCustomDataPath=Objects.requireNonNull(request.getShardIdsMap()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeMap(shardIdsWithCustomDataPath, (o, k) -> k.writeTo(o), StreamOutput::writeString); + } + + } + + /** + * The response as stored by TransportNodesListGatewayStartedShards(to maintain backward compatibility). + * + * @opensearch.internal + */ + public static class NodeGatewayStartedShards { + private final String allocationId; + private final boolean primary; + private final Exception storeException; + private final ReplicationCheckpoint replicationCheckpoint; + + public NodeGatewayStartedShards(StreamInput in) throws IOException { + allocationId = in.readOptionalString(); + primary = in.readBoolean(); + if (in.readBoolean()) { + storeException = in.readException(); + } else { + storeException = null; + } + if (in.getVersion().onOrAfter(Version.V_2_3_0) && in.readBoolean()) { + replicationCheckpoint = new ReplicationCheckpoint(in); + } else { + replicationCheckpoint = null; + } + } + + public NodeGatewayStartedShards( + String allocationId, + boolean primary, + ReplicationCheckpoint replicationCheckpoint + ) { + this( allocationId, primary, replicationCheckpoint, null); + } + + public NodeGatewayStartedShards( + String allocationId, + boolean primary, + ReplicationCheckpoint replicationCheckpoint, + Exception storeException + ) { + this.allocationId = allocationId; + this.primary = primary; + this.replicationCheckpoint = replicationCheckpoint; + this.storeException = storeException; + } + + public String allocationId() { + return this.allocationId; + } + + public boolean primary() { + return this.primary; + } + + public ReplicationCheckpoint replicationCheckpoint() { + return this.replicationCheckpoint; + } + + public Exception storeException() { + return this.storeException; + } + + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(allocationId); + out.writeBoolean(primary); + if (storeException != null) { + out.writeBoolean(true); + out.writeException(storeException); + } else { + out.writeBoolean(false); + } + if (out.getVersion().onOrAfter(Version.V_2_3_0)) { + if (replicationCheckpoint != null) { + out.writeBoolean(true); + replicationCheckpoint.writeTo(out); + } else { + out.writeBoolean(false); + } + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + NodeGatewayStartedShards that = (NodeGatewayStartedShards) o; + + return primary == that.primary + && Objects.equals(allocationId, that.allocationId) + && Objects.equals(storeException, that.storeException) + && Objects.equals(replicationCheckpoint, that.replicationCheckpoint); + } + + @Override + public int hashCode() { + int result = (allocationId != null ? allocationId.hashCode() : 0); + result = 31 * result + (primary ? 1 : 0); + result = 31 * result + (storeException != null ? storeException.hashCode() : 0); + result = 31 * result + (replicationCheckpoint != null ? replicationCheckpoint.hashCode() : 0); + return result; + } + + @Override + public String toString() { + StringBuilder buf = new StringBuilder(); + buf.append("NodeGatewayStartedShards[").append("allocationId=").append(allocationId).append(",primary=").append(primary); + if (storeException != null) { + buf.append(",storeException=").append(storeException); + } + if (replicationCheckpoint != null) { + buf.append(",ReplicationCheckpoint=").append(replicationCheckpoint.toString()); + } + buf.append("]"); + return buf.toString(); + } + } + + public static class ListOfNodeGatewayStartedShards extends BaseNodeResponse { + public Map getListOfNodeGatewayStartedShards() { + return listOfNodeGatewayStartedShards; + } + + private final Map listOfNodeGatewayStartedShards; + public ListOfNodeGatewayStartedShards(StreamInput in) throws IOException { + super(in); + this.listOfNodeGatewayStartedShards = in.readMap(ShardId::new, NodeGatewayStartedShards::new); + } + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeMap(listOfNodeGatewayStartedShards, (o, k) -> k.writeTo(o),(o,v)->v.writeTo(o)); + } + + public ListOfNodeGatewayStartedShards(DiscoveryNode node, Map listOfNodeGatewayStartedShards) { + super(node); + this.listOfNodeGatewayStartedShards=listOfNodeGatewayStartedShards; + } + + } +}