Skip to content

Commit

Permalink
Removed AsyncShardBatchFetch class
Browse files Browse the repository at this point in the history
  • Loading branch information
Gaurav614 committed Aug 23, 2023
1 parent 850286c commit 490d77f
Show file tree
Hide file tree
Showing 16 changed files with 197 additions and 604 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ void start() {
} else {
for (Tuple<ShardId, String> shard : shards) {
InternalAsyncFetch fetch = new InternalAsyncFetch(logger, "shard_stores", shard.v1(), shard.v2(), listShardStoresInfo);
fetch.fetchData(nodes, Collections.<String>emptySet());
fetch.fetchData(nodes, Collections.emptyMap());
}
}
}
Expand Down Expand Up @@ -223,7 +223,7 @@ protected synchronized void processAsyncFetch(
List<FailedNodeException> failures,
long fetchingRound
) {
fetchResponses.add(new Response(shardId, responses, failures));
fetchResponses.add(new Response(shardToCustomDataPath.keySet().iterator().next(), responses, failures));
if (expectedOps.countDown()) {
finish();
}
Expand Down Expand Up @@ -312,7 +312,7 @@ private boolean shardExistsInNode(final NodeGatewayStartedShards response) {
}

@Override
protected void reroute(ShardId shardId, String reason) {
protected void reroute(String shardId, String reason) {
// no-op
}

Expand Down
131 changes: 87 additions & 44 deletions server/src/main/java/org/opensearch/gateway/AsyncShardFetch.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,11 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicLong;

import static java.util.Collections.emptySet;
import static java.util.Collections.unmodifiableSet;
import static java.util.Collections.emptyMap;
import static java.util.Collections.unmodifiableMap;

/**
* Allows to asynchronously fetch shard related data from other nodes for allocation, without blocking
Expand All @@ -77,18 +76,22 @@ public abstract class AsyncShardFetch<T extends BaseNodeResponse> implements Rel
* An action that lists the relevant shard data that needs to be fetched.
*/
public interface Lister<NodesResponse extends BaseNodesResponse<NodeResponse>, NodeResponse extends BaseNodeResponse> {
void list(ShardId shardId, @Nullable String customDataPath, DiscoveryNode[] nodes, ActionListener<NodesResponse> listener);
void list(Map<ShardId, String> shardIdsWithCustomDataPath, DiscoveryNode[] nodes, ActionListener<NodesResponse> listener);

}

protected final Logger logger;
protected final String type;
protected final ShardId shardId;
protected final String customDataPath;

protected final Map<ShardId,String> shardToCustomDataPath;
private final Lister<BaseNodesResponse<T>, T> action;
private final Map<String, NodeEntry<T>> cache = new HashMap<>();
private final Set<String> nodesToIgnore = new HashSet<>();
private final AtomicLong round = new AtomicLong();
private boolean closed;
private final String logKey;
private final Map<ShardId, Set<String>> shardToIgnoreNodes = new HashMap<>();

private final boolean enableBatchMode;

@SuppressWarnings("unchecked")
protected AsyncShardFetch(
Expand All @@ -100,11 +103,30 @@ protected AsyncShardFetch(
) {
this.logger = logger;
this.type = type;
this.shardId = Objects.requireNonNull(shardId);
this.customDataPath = Objects.requireNonNull(customDataPath);
shardToCustomDataPath =new HashMap<>();
shardToCustomDataPath.put(shardId, customDataPath);
this.action = (Lister<BaseNodesResponse<T>, T>) action;
this.logKey = "ShardId=[" + shardId.toString() + "]";
enableBatchMode = false;
}

@SuppressWarnings("unchecked")
protected AsyncShardFetch(
Logger logger,
String type,
Map<ShardId, String> shardToCustomDataPath,
Lister<? extends BaseNodesResponse<T>, T> action,
String batchId
) {
this.logger = logger;
this.type = type;
this.shardToCustomDataPath = shardToCustomDataPath;
this.action = (Lister<BaseNodesResponse<T>, T>) action;
this.logKey = "BatchID=[" + batchId+ "]";
enableBatchMode = true;
}


@Override
public synchronized void close() {
this.closed = true;
Expand All @@ -130,11 +152,26 @@ public synchronized int getNumberOfInFlightFetches() {
* The ignoreNodes are nodes that are supposed to be ignored for this round, since fetching is async, we need
* to keep them around and make sure we add them back when all the responses are fetched and returned.
*/
public synchronized FetchResult<T> fetchData(DiscoveryNodes nodes, Set<String> ignoreNodes) {
public synchronized FetchResult<T> fetchData(DiscoveryNodes nodes, Map<ShardId, Set<String>> ignoreNodes) {
if (closed) {
throw new IllegalStateException(shardId + ": can't fetch data on closed async fetch");
throw new IllegalStateException(logKey + ": can't fetch data on closed async fetch");
}

if(enableBatchMode == false){
// we will do assertions here on ignoreNodes
assert ignoreNodes.size() <=1 : "Can only have at-most one shard";
if(ignoreNodes.size() == 1) {
assert shardToCustomDataPath.containsKey(ignoreNodes.keySet().iterator().next()) : "ShardId should be same as initialised in fetcher";
}
}

// add the nodes to ignore to the list of nodes to ignore for each shard
for (Map.Entry<ShardId, Set<String>> ignoreNodesEntry : ignoreNodes.entrySet()) {
Set<String> ignoreNodesSet = shardToIgnoreNodes.getOrDefault(ignoreNodesEntry.getKey(), new HashSet<>());
ignoreNodesSet.addAll(ignoreNodesEntry.getValue());
shardToIgnoreNodes.put(ignoreNodesEntry.getKey(), ignoreNodesSet);
}
nodesToIgnore.addAll(ignoreNodes);

fillShardCacheWithDataNodes(cache, nodes);
List<NodeEntry<T>> nodesToFetch = findNodesToFetch(cache);
if (nodesToFetch.isEmpty() == false) {
Expand All @@ -153,7 +190,7 @@ public synchronized FetchResult<T> fetchData(DiscoveryNodes nodes, Set<String> i

// if we are still fetching, return null to indicate it
if (hasAnyNodeFetching(cache)) {
return new FetchResult<>(shardId, null, emptySet());
return new FetchResult<>(null, emptyMap());
} else {
// nothing to fetch, yay, build the return value
Map<DiscoveryNode, T> fetchData = new HashMap<>();
Expand All @@ -177,16 +214,19 @@ public synchronized FetchResult<T> fetchData(DiscoveryNodes nodes, Set<String> i
}
}
}
Set<String> allIgnoreNodes = unmodifiableSet(new HashSet<>(nodesToIgnore));

Map<ShardId, Set<String>> allIgnoreNodesMap = unmodifiableMap(new HashMap<>(shardToIgnoreNodes));
// clear the nodes to ignore, we had a successful run in fetching everything we can
// we need to try them if another full run is needed
nodesToIgnore.clear();
shardToIgnoreNodes.clear();
// if at least one node failed, make sure to have a protective reroute
// here, just case this round won't find anything, and we need to retry fetching data
if (failedNodes.isEmpty() == false || allIgnoreNodes.isEmpty() == false) {
reroute(shardId, "nodes failed [" + failedNodes.size() + "], ignored [" + allIgnoreNodes.size() + "]");
if (failedNodes.isEmpty() == false || allIgnoreNodesMap.values().stream().anyMatch(ignoreNodeSet -> ignoreNodeSet.isEmpty() == false)) {
reroute(logKey, "nodes failed [" + failedNodes.size() + "], ignored ["
+ allIgnoreNodesMap.values().stream().mapToInt(Set::size).sum() + "]");
}
return new FetchResult<>(shardId, fetchData, allIgnoreNodes);

return new FetchResult<>(fetchData, allIgnoreNodesMap);
}
}

Expand All @@ -199,10 +239,10 @@ public synchronized FetchResult<T> fetchData(DiscoveryNodes nodes, Set<String> i
protected synchronized void processAsyncFetch(List<T> responses, List<FailedNodeException> failures, long fetchingRound) {
if (closed) {
// we are closed, no need to process this async fetch at all
logger.trace("{} ignoring fetched [{}] results, already closed", shardId, type);
logger.trace("{} ignoring fetched [{}] results, already closed", logKey, type);
return;
}
logger.trace("{} processing fetched [{}] results", shardId, type);
logger.trace("{} processing fetched [{}] results", logKey, type);

if (responses != null) {
for (T response : responses) {
Expand All @@ -212,7 +252,7 @@ protected synchronized void processAsyncFetch(List<T> responses, List<FailedNode
assert nodeEntry.getFetchingRound() > fetchingRound : "node entries only replaced by newer rounds";
logger.trace(
"{} received response for [{}] from node {} for an older fetching round (expected: {} but was: {})",
shardId,
logKey,
nodeEntry.getNodeId(),
type,
nodeEntry.getFetchingRound(),
Expand All @@ -221,29 +261,29 @@ protected synchronized void processAsyncFetch(List<T> responses, List<FailedNode
} else if (nodeEntry.isFailed()) {
logger.trace(
"{} node {} has failed for [{}] (failure [{}])",
shardId,
logKey,
nodeEntry.getNodeId(),
type,
nodeEntry.getFailure()
);
} else {
// if the entry is there, for the right fetching round and not marked as failed already, process it
logger.trace("{} marking {} as done for [{}], result is [{}]", shardId, nodeEntry.getNodeId(), type, response);
logger.trace("{} marking {} as done for [{}], result is [{}]", logKey, nodeEntry.getNodeId(), type, response);
nodeEntry.doneFetching(response);
}
}
}
}
if (failures != null) {
for (FailedNodeException failure : failures) {
logger.trace("{} processing failure {} for [{}]", shardId, failure, type);
logger.trace("{} processing failure {} for [{}]", logKey, failure, type);
NodeEntry<T> nodeEntry = cache.get(failure.nodeId());
if (nodeEntry != null) {
if (nodeEntry.getFetchingRound() != fetchingRound) {
assert nodeEntry.getFetchingRound() > fetchingRound : "node entries only replaced by newer rounds";
logger.trace(
"{} received failure for [{}] from node {} for an older fetching round (expected: {} but was: {})",
shardId,
logKey,
nodeEntry.getNodeId(),
type,
nodeEntry.getFetchingRound(),
Expand All @@ -261,7 +301,7 @@ protected synchronized void processAsyncFetch(List<T> responses, List<FailedNode
logger.warn(
() -> new ParameterizedMessage(
"{}: failed to list shard for {} on node [{}]",
shardId,
logKey,
type,
failure.nodeId()
),
Expand All @@ -273,13 +313,13 @@ protected synchronized void processAsyncFetch(List<T> responses, List<FailedNode
}
}
}
reroute(shardId, "post_response");
reroute(logKey, "post_response");
}

/**
* Implement this in order to scheduled another round that causes a call to fetch data.
*/
protected abstract void reroute(ShardId shardId, String reason);
protected abstract void reroute(String logKey, String reason);

/**
* Clear cache for node, ensuring next fetch will fetch a fresh copy.
Expand Down Expand Up @@ -334,8 +374,8 @@ private boolean hasAnyNodeFetching(Map<String, NodeEntry<T>> shardCache) {
*/
// visible for testing
void asyncFetch(final DiscoveryNode[] nodes, long fetchingRound) {
logger.trace("{} fetching [{}] from {}", shardId, type, nodes);
action.list(shardId, customDataPath, nodes, new ActionListener<BaseNodesResponse<T>>() {
logger.trace("{} fetching [{}] from {}", logKey, type, nodes);
action.list(shardToCustomDataPath, nodes, new ActionListener<BaseNodesResponse<T>>() {
@Override
public void onResponse(BaseNodesResponse<T> response) {
processAsyncFetch(response.getNodes(), response.failures(), fetchingRound);
Expand All @@ -358,15 +398,13 @@ public void onFailure(Exception e) {
*/
public static class FetchResult<T extends BaseNodeResponse> {

private final ShardId shardId;
private final Map<DiscoveryNode, T> data;
private final Set<String> ignoreNodes;
private final Map<DiscoveryNode, T> data;
private final Map<ShardId, Set<String>> ignoredShardToNodes;

public FetchResult(ShardId shardId, Map<DiscoveryNode, T> data, Set<String> ignoreNodes) {
this.shardId = shardId;
this.data = data;
this.ignoreNodes = ignoreNodes;
}
public FetchResult(Map<DiscoveryNode, T> data, Map<ShardId, Set<String>> ignoreNodes) {
this.data = data;
this.ignoredShardToNodes = ignoreNodes;
}

/**
* Does the result actually contain data? If not, then there are on going fetch
Expand All @@ -385,15 +423,20 @@ public Map<DiscoveryNode, T> getData() {
return this.data;
}

/**
* Process any changes needed to the allocation based on this fetch result.
*/
public void processAllocation(RoutingAllocation allocation) {
for (String ignoreNode : ignoreNodes) {
allocation.addIgnoreShardForNode(shardId, ignoreNode);
/**
* Process any changes needed to the allocation based on this fetch result.
*/
public void processAllocation(RoutingAllocation allocation) {
for(Map.Entry<ShardId, Set<String>> entry : ignoredShardToNodes.entrySet()) {
ShardId shardId = entry.getKey();
Set<String> ignoreNodes = entry.getValue();
if (ignoreNodes.isEmpty() == false) {
ignoreNodes.forEach(nodeId -> allocation.addIgnoreShardForNode(shardId, nodeId));
}
}

}
}

/**
* A node entry, holding the state of the fetched data for a specific shard
Expand Down
Loading

0 comments on commit 490d77f

Please sign in to comment.